Skip to content

Commit

Permalink
Improves client header handling (#46)
Browse files Browse the repository at this point in the history
* Improves header handling and adds support for user defined headers for client upgrades

* Improves header handling and logging
  • Loading branch information
SirCipher authored Sep 26, 2024
1 parent 122efbe commit 36a420b
Show file tree
Hide file tree
Showing 8 changed files with 313 additions and 145 deletions.
5 changes: 4 additions & 1 deletion ratchet_core/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ pub enum HttpError {
Status(StatusCode),
/// An invalid HTTP version was received in a request.
#[error("Invalid HTTP version: `{0:?}`")]
HttpVersion(Option<u8>),
HttpVersion(String),
/// A request or response was missing an expected header.
#[error("Missing header: `{0}`")]
MissingHeader(HeaderName),
Expand All @@ -176,6 +176,9 @@ pub enum HttpError {
/// A provided header was malformatted
#[error("A provided header was malformatted")]
MalformattedHeader(String),
/// A request was missing the authority.
#[error("Missing authority")]
MissingAuthority,
}

impl From<HttpError> for Error {
Expand Down
205 changes: 92 additions & 113 deletions ratchet_core/src/handshake/client/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
// limitations under the License.

use base64::Engine;
use bytes::{BufMut, BytesMut};
use http::header::{AsHeaderName, HeaderName, IntoHeaderName};
use bytes::BytesMut;
use http::header::{HOST, SEC_WEBSOCKET_EXTENSIONS, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL};
use http::request::Parts;
use http::{header, HeaderMap, HeaderValue, Method, Request, Version};
use http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Version};

use ratchet_ext::ExtensionProvider;

Expand All @@ -27,13 +27,13 @@ use crate::handshake::{
};

use base64::engine::general_purpose::STANDARD;
use log::error;

pub fn encode_request(dst: &mut BytesMut, request: ValidatedRequest, nonce_buffer: &mut Nonce) {
let ValidatedRequest {
version,
headers,
path_and_query,
host,
} = request;

let nonce = rand::random::<[u8; 16]>();
Expand All @@ -49,76 +49,34 @@ pub fn encode_request(dst: &mut BytesMut, request: ValidatedRequest, nonce_buffe
let request = format!(
"\
GET {path} {version:?}\r\n\
Host: {host}\r\n\
Connection: Upgrade\r\n\
Upgrade: websocket\r\n\
sec-websocket-version: 13\r\n\
sec-websocket-key: {nonce}",
version = version,
path = path_and_query,
host = host,
nonce = nonce_str
);

// 28 = request terminator + nonce buffer len
let mut len = 28 + request.len();
extend(dst, request.as_bytes());

let origin = write_header(&headers, header::ORIGIN);
let protocol = write_header(&headers, header::SEC_WEBSOCKET_PROTOCOL);
let ext = write_header(&headers, header::SEC_WEBSOCKET_EXTENSIONS);
let auth = write_header(&headers, header::AUTHORIZATION);

if let Some((name, value)) = &origin {
len += name.len() + value.len() + 2;
}
if let Some((name, value)) = &protocol {
len += name.len() + value.len() + 2;
}
if let Some((name, value)) = &ext {
len += name.len() + value.len() + 2;
for (name, value) in &headers {
extend(dst, b"\r\n");
extend(dst, name.as_str().as_bytes());
extend(dst, b": ");
extend(dst, value.as_bytes());
}
if let Some((name, value)) = &auth {
len += name.len() + value.len() + 2;
}

dst.reserve(len);
dst.put_slice(request.as_bytes());

if let Some((name, value)) = origin {
dst.put_slice(b"\r\n");
dst.put_slice(name.as_bytes());
dst.put_slice(value);
}
if let Some((name, value)) = protocol {
dst.put_slice(b"\r\n");
dst.put_slice(name.as_bytes());
dst.put_slice(value);
}
if let Some((name, value)) = ext {
dst.put_slice(b"\r\n");
dst.put_slice(name.as_bytes());
dst.put_slice(value);
}
if let Some((name, value)) = auth {
dst.put_slice(b"\r\n");
dst.put_slice(name.as_bytes());
dst.put_slice(value);
}

dst.put_slice(b"\r\n\r\n");
extend(dst, b"\r\n\r\n");
}

fn write_header(headers: &HeaderMap<HeaderValue>, name: HeaderName) -> Option<(String, &[u8])> {
headers
.get(&name)
.map(|value| (format!("{}: ", name), value.as_bytes()))
#[inline]
fn extend(dst: &mut BytesMut, data: &[u8]) {
dst.extend_from_slice(data);
}

#[derive(Debug)]
pub struct ValidatedRequest {
version: Version,
headers: HeaderMap,
path_and_query: String,
host: String,
}

// rfc6455 § 4.2.1
Expand Down Expand Up @@ -149,20 +107,46 @@ where
if version != Version::HTTP_11 {
return Err(Error::with_cause(
ErrorKind::Http,
HttpError::HttpVersion(None),
HttpError::HttpVersion(format!("{version:?}")),
));
}

let authority = uri
.authority()
.ok_or_else(|| Error::with_cause(ErrorKind::Http, "Missing authority"))?
.as_str()
.to_string();
validate_or_insert(
&mut headers,
header::HOST,
HeaderValue::from_str(authority.as_ref())?,
)?;
if headers.get(SEC_WEBSOCKET_EXTENSIONS).is_some() {
error!(
"{} should only be set by extensions",
SEC_WEBSOCKET_EXTENSIONS
);
return Err(Error::with_cause(
ErrorKind::Http,
HttpError::InvalidHeader(SEC_WEBSOCKET_EXTENSIONS),
));
}

// Run this first to ensure that the extension doesn't invalidate the headers.
extension.apply_headers(&mut headers);

match validate_host_header(&headers) {
Ok(()) => {
// The request should only contain *one* 'host' header, and it must be a single value,
// not a comma seperated list. If the request doesn't already have one then derive it
// from the URI if it contains an authority. If it doesn't, then the request is invalid
// and any correct server implementation would reject it - including Ratchet.
let authority = uri
.authority()
.ok_or_else(|| Error::with_cause(ErrorKind::Http, HttpError::MissingAuthority))?
.as_str()
.to_string();
validate_or_insert(
&mut headers,
header::HOST,
HeaderValue::from_str(authority.as_ref())?,
)?;
}
Err(e) => {
error!("Request should only contain one 'host' header. {e}");
return Err(e);
}
}

validate_or_insert(
&mut headers,
Expand All @@ -180,54 +164,28 @@ where
HeaderValue::from_static(WEBSOCKET_VERSION_STR),
)?;

if headers.get(header::SEC_WEBSOCKET_EXTENSIONS).is_some() {
if headers.get(SEC_WEBSOCKET_PROTOCOL).is_some() {
error!(
"{} should only be set by extensions",
SEC_WEBSOCKET_PROTOCOL
);
// WebSocket protocols can only be applied using a ProtocolRegistry
return Err(Error::with_cause(
ErrorKind::Http,
HttpError::InvalidHeader(header::SEC_WEBSOCKET_EXTENSIONS),
HttpError::InvalidHeader(SEC_WEBSOCKET_PROTOCOL),
));
}

extension.apply_headers(&mut headers);
apply_to(subprotocols, &mut headers);

if headers.get(header::SEC_WEBSOCKET_PROTOCOL).is_some() {
// WebSocket protocols can only be applied using a ProtocolRegistry
if headers.get(SEC_WEBSOCKET_KEY).is_some() {
error!("{} should not be set", SEC_WEBSOCKET_KEY);
return Err(Error::with_cause(
ErrorKind::Http,
HttpError::InvalidHeader(header::SEC_WEBSOCKET_PROTOCOL),
HttpError::InvalidHeader(SEC_WEBSOCKET_KEY),
));
}

apply_to(subprotocols, &mut headers);

let option = headers
.get(header::SEC_WEBSOCKET_KEY)
.map(|head| head.to_str());
match option {
Some(Ok(version)) if version == WEBSOCKET_VERSION_STR => {}
None => {
headers.insert(
header::SEC_WEBSOCKET_VERSION,
HeaderValue::from_static(WEBSOCKET_VERSION_STR),
);
}
_ => {
return Err(Error::with_cause(
ErrorKind::Http,
HttpError::InvalidHeader(header::SEC_WEBSOCKET_KEY),
));
}
}

let host = uri
.authority()
.ok_or_else(|| {
Error::with_cause(
ErrorKind::Http,
HttpError::MalformattedUri(Some("Missing authority".to_string())),
)
})?
.to_string();

let path_and_query = uri
.path_and_query()
.map(ToString::to_string)
Expand All @@ -237,25 +195,46 @@ where
version,
headers,
path_and_query,
host,
})
}

fn validate_or_insert<A>(
fn validate_or_insert(
headers: &mut HeaderMap,
header_name: A,
header_name: HeaderName,
expected: HeaderValue,
) -> Result<(), Error>
where
A: AsHeaderName + IntoHeaderName + Clone,
{
) -> Result<(), HttpError> {
if let Some(header_value) = headers.get(header_name.clone()) {
match header_value.to_str() {
Ok(v) if v.as_bytes().eq_ignore_ascii_case(expected.as_bytes()) => Ok(()),
_ => Err(Error::new(ErrorKind::Http)),
_ => {
error!("Invalid header set: {} -> {:?}", header_name, header_value);
Err(HttpError::InvalidHeader(header_name))
}
}
} else {
headers.insert(header_name, expected);
Ok(())
}
}

/// Validates that 'headers' contains at most one 'host' header and that it is not a seperated list.
fn validate_host_header(headers: &HeaderMap) -> Result<(), Error> {
let len = headers
.iter()
.filter_map(|(name, value)| {
if name.as_str().eq_ignore_ascii_case(HOST.as_str()) {
Some(value.as_bytes().split(|c| c == &b' ' || c == &b','))
} else {
None
}
})
.count();
if len <= 1 {
Ok(())
} else {
Err(Error::with_cause(
ErrorKind::Http,
HttpError::InvalidHeader(HOST),
))
}
}
12 changes: 7 additions & 5 deletions ratchet_core/src/handshake/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ mod encoding;
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use bytes::BytesMut;
use http::{header, Request, StatusCode};
use http::{header, Request, StatusCode, Version};
use httparse::{Response, Status};
use log::{error, trace};
use sha1::{Digest, Sha1};
Expand Down Expand Up @@ -227,12 +227,14 @@ where
subprotocols,
} = self;

trace!("Encoding request: {request:?}");
let validated_request = build_request(request, extension, subprotocols)?;
encode_request(buffered.buffer, validated_request, nonce);
Ok(())
}

async fn write(&mut self) -> Result<(), Error> {
trace!("Writing buffered data");
self.buffered.write().await
}

Expand Down Expand Up @@ -285,10 +287,10 @@ fn check_partial_response(response: &Response) -> Result<(), Error> {
// httparse sets this to 0 for HTTP/1.0 or 1 for HTTP/1.1
// rfc6455 § 4.2.1.1: must be HTTP/1.1 or higher
Some(1) | None => {}
Some(v) => {
Some(_) => {
return Err(Error::with_cause(
ErrorKind::Http,
HttpError::HttpVersion(Some(v)),
HttpError::HttpVersion(format!("{:?}", Version::HTTP_10)),
))
}
}
Expand Down Expand Up @@ -340,10 +342,10 @@ where
match response.version {
// rfc6455 § 4.2.1.1: must be HTTP/1.1 or higher
Some(1) => {}
v => {
_ => {
return Err(Error::with_cause(
ErrorKind::Http,
HttpError::HttpVersion(v),
HttpError::HttpVersion(format!("{:?}", Version::HTTP_10)),
))
}
}
Expand Down
Loading

0 comments on commit 36a420b

Please sign in to comment.