diff --git a/Cargo.lock b/Cargo.lock index 5b599e6b..1782324e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -815,7 +815,6 @@ name = "engineioxide" version = "0.17.5" dependencies = [ "axum", - "base64", "bytes", "codspeed-criterion-compat", "engineioxide-core", @@ -826,8 +825,6 @@ dependencies = [ "http-body-util", "hyper", "hyper-util", - "itoa", - "memchr", "pin-project-lite", "serde", "serde_json", @@ -849,8 +846,20 @@ version = "0.2.1" dependencies = [ "base64", "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "itoa", + "memchr", + "pin-project-lite", "rand 0.10.1", "serde", + "serde_json", + "smallvec", + "tokio", + "tokio-stream", + "tracing", ] [[package]] diff --git a/crates/engineioxide-core/Cargo.toml b/crates/engineioxide-core/Cargo.toml index 485e57e3..cd4b642d 100644 --- a/crates/engineioxide-core/Cargo.toml +++ b/crates/engineioxide-core/Cargo.toml @@ -17,6 +17,29 @@ rand = "0.10" base64 = "0.22" serde.workspace = true bytes.workspace = true +serde_json.workspace = true +http-body.workspace = true +http-body-util.workspace = true +http.workspace = true +futures-util.workspace = true +smallvec.workspace = true +pin-project-lite.workspace = true + +# Engine.io V3 payload +itoa = { workspace = true, optional = true } +memchr = { version = "2.7", optional = true } + +# Tracing +tracing = { workspace = true, optional = true } + + +[features] +v3 = ["dep:memchr", "dep:itoa"] +tracing = ["dep:tracing"] + +[dev-dependencies] +tokio = { workspace = true, features = ["macros", "rt", "time", "sync"] } +tokio-stream.workspace = true [lints] workspace = true diff --git a/crates/engineioxide-core/src/lib.rs b/crates/engineioxide-core/src/lib.rs index 9b5fff8f..fe94cc6b 100644 --- a/crates/engineioxide-core/src/lib.rs +++ b/crates/engineioxide-core/src/lib.rs @@ -1,8 +1,13 @@ #![warn(missing_docs)] #![doc = include_str!("../README.md")] +mod packet; +mod protocol; mod sid; mod str; +pub use packet::{OpenPacket, Packet, PacketBuf, PacketParseError}; +pub use protocol::{ProtocolVersion, TransportType, UnknownTransportError}; pub use sid::Sid; pub use str::Str; +pub mod payload; diff --git a/crates/engineioxide/src/packet.rs b/crates/engineioxide-core/src/packet.rs similarity index 68% rename from crates/engineioxide/src/packet.rs rename to crates/engineioxide-core/src/packet.rs index ead18cae..0d16416d 100644 --- a/crates/engineioxide/src/packet.rs +++ b/crates/engineioxide-core/src/packet.rs @@ -1,16 +1,14 @@ +use std::{fmt, time::Duration}; + use base64::{Engine, engine::general_purpose}; use bytes::Bytes; -use engineioxide_core::{Sid, Str}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use smallvec::{SmallVec, smallvec}; -use std::time::Duration; +use smallvec::SmallVec; -use crate::config::EngineIoConfig; -use crate::errors::Error; -use crate::{ProtocolVersion, TransportType}; +use crate::{ProtocolVersion, Sid, Str, TransportType}; /// A Packet type to use when receiving and sending data from the client -#[derive(Debug, Clone, PartialEq, PartialOrd)] +#[derive(Debug, Clone, PartialEq)] pub enum Packet { /// Open packet used to initiate a connection Open(OpenPacket), @@ -53,6 +51,67 @@ pub enum Packet { BinaryV3(Bytes), // Not part of the protocol, used internally } +/// An error that occurs when parsing a packet. +#[derive(Debug)] +pub enum PacketParseError { + /// Invalid connect packet + InvalidConnectPacket(serde_json::Error), + /// The packet type is invalid. + InvalidPacketType(Option), + /// The packet payload is invalid. + InvalidPacketPayload, + /// The packet length is invalid. + InvalidPacketLen, + /// The packet chunk is invalid + InvalidUtf8Boundary(std::str::Utf8Error), + /// The base64 decoding failed. + Base64Decode(base64::DecodeError), + /// The payload is too large. + PayloadTooLarge { + /// The maximum allowed payload size. + max: u64, + }, +} +impl fmt::Display for PacketParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + PacketParseError::InvalidConnectPacket(e) => write!(f, "invalid connect packet: {e}"), + PacketParseError::InvalidPacketType(c) => write!(f, "invalid packet type: {c:?}"), + PacketParseError::InvalidPacketPayload => write!(f, "invalid packet payload"), + PacketParseError::InvalidPacketLen => write!(f, "invalid packet length"), + PacketParseError::InvalidUtf8Boundary(err) => write!( + f, + "invalid utf8 boundary when parsing payload into packet chunks: {err}" + ), + PacketParseError::Base64Decode(err) => write!(f, "base64 decode error: {err}"), + PacketParseError::PayloadTooLarge { max } => { + write!(f, "payload too large: max {max}") + } + } + } +} +impl From for PacketParseError { + fn from(err: base64::DecodeError) -> Self { + PacketParseError::Base64Decode(err) + } +} +impl From for PacketParseError { + fn from(err: std::string::FromUtf8Error) -> Self { + PacketParseError::InvalidUtf8Boundary(err.utf8_error()) + } +} +impl From for PacketParseError { + fn from(err: std::str::Utf8Error) -> Self { + PacketParseError::InvalidUtf8Boundary(err) + } +} +impl From for PacketParseError { + fn from(err: serde_json::Error) -> Self { + PacketParseError::InvalidConnectPacket(err) + } +} +impl std::error::Error for PacketParseError {} + impl Packet { /// Check if the packet is a binary packet pub fn is_binary(&self) -> bool { @@ -60,7 +119,7 @@ impl Packet { } /// If the packet is a message packet (text), it returns the message - pub(crate) fn into_message(self) -> Str { + pub fn into_message(self) -> Str { match self { Packet::Message(msg) => msg, _ => panic!("Packet is not a message"), @@ -68,7 +127,7 @@ impl Packet { } /// If the packet is a binary packet, it returns the binary data - pub(crate) fn into_binary(self) -> Bytes { + pub fn into_binary(self) -> Bytes { match self { Packet::Binary(data) => data, Packet::BinaryV3(data) => data, @@ -81,7 +140,7 @@ impl Packet { /// If b64 is true, it returns the max size when serialized to base64 /// /// The base64 max size factor is `ceil(n / 3) * 4` - pub(crate) fn get_size_hint(&self, b64: bool) -> usize { + pub fn get_size_hint(&self, b64: bool) -> usize { match self { Packet::Open(_) => 156, // max possible size for the open packet serialized Packet::Close => 1, @@ -110,6 +169,12 @@ impl Packet { } } +impl From for Bytes { + fn from(value: Packet) -> Self { + String::from(value).into() + } +} + /// Serialize a [Packet] to a [String] according to the Engine.IO protocol impl From for String { fn from(packet: Packet) -> String { @@ -143,25 +208,19 @@ impl From for String { buffer } } -impl From for tokio_tungstenite::tungstenite::Utf8Bytes { - fn from(value: Packet) -> Self { - String::from(value).into() - } -} -impl From for Bytes { - fn from(value: Packet) -> Self { - String::from(value).into() - } -} +/// Deserialize a [Packet] from a [String] according to the Engine.IO protocol impl Packet { /// Parses a packet from a string value using the specified protocol version. - pub fn parse(protocol: ProtocolVersion, value: impl Into) -> Result { + pub fn parse( + protocol: ProtocolVersion, + value: impl Into, + ) -> Result { let value = value.into(); let packet_type = value .as_bytes() .first() - .ok_or(Error::InvalidPacketType(None))?; + .ok_or(PacketParseError::InvalidPacketType(None))?; let is_upgrade = value.len() == 6 && &value[1..6] == "probe"; let res = match packet_type { b'1' => Packet::Close, @@ -182,29 +241,39 @@ impl Packet { .decode(value.slice(1..).as_bytes())? .into(), ), - c => Err(Error::InvalidPacketType(Some(*c as char)))?, + c => Err(PacketParseError::InvalidPacketType(Some(*c as char)))?, }; Ok(res) } } /// An OpenPacket is used to initiate a connection -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd)] +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] #[serde(rename_all = "camelCase")] pub struct OpenPacket { - sid: Sid, - upgrades: SmallVec<[TransportType; 1]>, + /// The session ID. + pub sid: Sid, + + /// The list of available transport upgrades. + pub upgrades: SmallVec<[TransportType; 1]>, + + /// The ping interval, used in the heartbeat mechanism. #[serde( serialize_with = "serialize_duration_millis", deserialize_with = "deserialize_duration_from_millis" )] - ping_interval: Duration, + pub ping_interval: Duration, + + /// The ping timeout, used in the heartbeat mechanism. #[serde( serialize_with = "serialize_duration_millis", deserialize_with = "deserialize_duration_from_millis" )] - ping_timeout: Duration, - max_payload: u64, + pub ping_timeout: Duration, + + /// The maximum number of bytes per chunk, used by the client to + /// aggregate packets into payloads. + pub max_payload: u64, } /// Helper to serialize a duration as milliseconds @@ -225,28 +294,28 @@ where Ok(Duration::from_millis(millis)) } -impl OpenPacket { - /// Create a new [OpenPacket] - /// If the current transport is polling, the server will always allow the client to upgrade to websocket - pub fn new(transport: TransportType, sid: Sid, config: &EngineIoConfig) -> Self { - let upgrades = if transport == TransportType::Polling { - smallvec![TransportType::Websocket] - } else { - smallvec![] - }; - OpenPacket { - sid, - upgrades, - ping_interval: config.ping_interval, - ping_timeout: config.ping_timeout, - max_payload: config.max_payload, +/// This default implementation should only be used for testing purposes. +impl Default for OpenPacket { + fn default() -> Self { + Self { + sid: Sid::ZERO, + upgrades: smallvec::smallvec![TransportType::Websocket], + ping_interval: Duration::from_millis(25000), + ping_timeout: Duration::from_millis(20000), + max_payload: 100000, } } } +/// Buffered packets to send to the client. +/// It is used to ensure atomicity when sending multiple packets to the client. +/// +/// The [`PacketBuf`] stack size will impact the dynamically allocated buffer +/// of the internal mpsc channel. +pub type PacketBuf = SmallVec<[Packet; 2]>; + #[cfg(test)] mod tests { - use crate::config::EngineIoConfig; use super::*; use std::time::Duration; @@ -254,11 +323,13 @@ mod tests { #[test] fn test_open_packet() { let sid = Sid::new(); - let packet = Packet::Open(OpenPacket::new( - TransportType::Polling, + let packet = Packet::Open(OpenPacket { sid, - &EngineIoConfig::default(), - )); + upgrades: smallvec::smallvec![TransportType::Websocket], + ping_interval: Duration::from_millis(25000), + ping_timeout: Duration::from_millis(20000), + max_payload: 100000, + }); let packet_str: String = packet.into(); assert_eq!( packet_str, @@ -275,12 +346,6 @@ mod tests { assert_eq!(packet_str, "4hello"); } - #[test] - fn test_message_packet_deserialize() { - let packet = Packet::parse(ProtocolVersion::V4, "4hello").unwrap(); - assert_eq!(packet, Packet::Message("hello".into())); - } - #[test] fn test_binary_packet() { let packet = Packet::Binary(vec![1, 2, 3].into()); @@ -288,12 +353,6 @@ mod tests { assert_eq!(packet_str, "bAQID"); } - #[test] - fn test_binary_packet_deserialize() { - let packet = Packet::parse(ProtocolVersion::V4, "bAQID").unwrap(); - assert_eq!(packet, Packet::Binary(vec![1, 2, 3].into())); - } - #[test] fn test_binary_packet_v4_deserialize_payload_starting_with_4() { let data = vec![0xE0, 0xE1, 0xE2]; @@ -312,27 +371,16 @@ mod tests { assert_eq!(packet_str, "b4AQID"); } - #[test] - fn test_binary_packet_v3_deserialize() { - let packet = Packet::parse(ProtocolVersion::V3, "b4AQID").unwrap(); - assert_eq!(packet, Packet::BinaryV3(vec![1, 2, 3].into())); - } - #[test] fn test_packet_get_size_hint() { // Max serialized packet - let open = OpenPacket::new( - TransportType::Polling, - Sid::new(), - &EngineIoConfig { - max_buffer_size: usize::MAX, - max_payload: u64::MAX, - ping_interval: Duration::MAX, - ping_timeout: Duration::MAX, - transports: TransportType::Polling as u8 | TransportType::Websocket as u8, - ..Default::default() - }, - ); + let open = OpenPacket { + sid: Sid::new(), + ping_interval: Duration::MAX, + ping_timeout: Duration::MAX, + max_payload: u64::MAX, + upgrades: smallvec::smallvec![TransportType::Websocket], + }; let size = serde_json::to_string(&open).unwrap().len(); let packet = Packet::Open(open); assert_eq!(packet.get_size_hint(false), size); diff --git a/crates/engineioxide/src/transport/polling/payload/buf.rs b/crates/engineioxide-core/src/payload/buf.rs similarity index 100% rename from crates/engineioxide/src/transport/polling/payload/buf.rs rename to crates/engineioxide-core/src/payload/buf.rs diff --git a/crates/engineioxide/src/transport/polling/payload/decoder.rs b/crates/engineioxide-core/src/payload/decoder.rs similarity index 89% rename from crates/engineioxide/src/transport/polling/payload/decoder.rs rename to crates/engineioxide-core/src/payload/decoder.rs index a748fd81..da8098fd 100644 --- a/crates/engineioxide/src/transport/polling/payload/decoder.rs +++ b/crates/engineioxide-core/src/payload/decoder.rs @@ -5,11 +5,9 @@ //! - v3_decoder: Decodes the payload stream according to the [engine.io v3 protocol](https://github.com/socketio/engine.io-protocol/tree/v3#payload) //! -use futures_core::Stream; -use futures_util::StreamExt; -use http::StatusCode; +use crate::{Packet, PacketParseError, ProtocolVersion}; +use futures_util::{Stream, StreamExt}; -use crate::{ProtocolVersion, errors::Error, packet::Packet}; use bytes::Buf; use http_body::Body; use http_body_util::BodyStream; @@ -23,6 +21,8 @@ struct Payload { end_of_stream: bool, current_payload_size: u64, + /// counter to detect if packets have already been + /// yielded or if the poller needs to wait #[cfg(feature = "v3")] yield_packets: u32, } @@ -42,7 +42,7 @@ impl Payload { /// Polls the body stream for data and adds it to the chunk list in the state /// Returns an error if the packet length exceeds the maximum allowed payload size -async fn poll_body(state: &mut Payload, max_payload: u64) -> Result<(), Error> +async fn poll_body(state: &mut Payload, max_payload: u64) -> Result<(), PacketParseError> where B: Body + Unpin, E: std::fmt::Debug, @@ -59,7 +59,7 @@ where Err(_e) => { #[cfg(feature = "tracing")] tracing::debug!("error reading body stream: {:?}", _e); - Err(Error::HttpErrorResponse(StatusCode::BAD_REQUEST)) + Err(PacketParseError::InvalidPacketPayload) } }?; if state.current_payload_size + (data.remaining() as u64) <= max_payload { @@ -67,11 +67,14 @@ where state.buffer.push(data); Ok(()) } else { - Err(Error::PayloadTooLarge) + Err(PacketParseError::PayloadTooLarge { max: max_payload }) } } -pub fn v4_decoder(body: B, max_payload: u64) -> impl Stream> +pub fn v4_decoder( + body: B, + max_payload: u64, +) -> impl Stream> where B: Body + Unpin, E: std::fmt::Debug, @@ -93,11 +96,14 @@ where } // Read from the buffer until the packet separator is found - if let Err(e) = (&mut state.buffer) + if let Err(_err) = (&mut state.buffer) .reader() .read_until(PACKET_SEPARATOR_V4, &mut packet_buf) { - break Some((Err(Error::Io(e)), state)); + #[cfg(feature = "tracing")] + tracing::debug!("failed to read packet payload: {_err}"); + + break Some((Err(PacketParseError::InvalidPacketPayload), state)); } let separator_found = packet_buf.ends_with(&[PACKET_SEPARATOR_V4]); @@ -111,7 +117,7 @@ where || (state.end_of_stream && state.buffer.remaining() == 0 && !packet_buf.is_empty()) { let packet = String::from_utf8(packet_buf) - .map_err(|_| Error::InvalidPacketLength) + .map_err(PacketParseError::from) .and_then(|v| Packet::parse(ProtocolVersion::V4, v)); // Convert the packet buffer to a Packet object break Some((packet, state)); // Emit the packet and the updated state } else if state.end_of_stream && state.buffer.remaining() == 0 { @@ -125,14 +131,14 @@ where pub fn v3_binary_decoder( body: B, max_payload: u64, -) -> impl Stream> +) -> impl Stream> where B: Body + Unpin, E: std::fmt::Debug, { use std::io::Read; - use crate::transport::polling::payload::{ + use crate::payload::{ BINARY_PACKET_IDENTIFIER_V3, BINARY_PACKET_SEPARATOR_V3, STRING_PACKET_IDENTIFIER_V3, }; @@ -156,11 +162,14 @@ where // If there is no packet_type found if packet_type.is_none() && state.buffer.remaining() > 0 { // Read from the buffer until the packet separator is found - if let Err(e) = (&mut state.buffer) + if let Err(_err) = (&mut state.buffer) .reader() .read_until(BINARY_PACKET_SEPARATOR_V3, &mut packet_buf) { - break Some((Err(Error::Io(e)), state)); + #[cfg(feature = "tracing")] + tracing::debug!("failed to read packet payload: {_err}"); + + break Some((Err(PacketParseError::InvalidPacketPayload), state)); } // Extract packet_type and packet_size @@ -173,11 +182,11 @@ where Some(&STRING_PACKET_IDENTIFIER_V3) => { packet_type = Some(STRING_PACKET_IDENTIFIER_V3) } - _ => break Some((Err(Error::InvalidPacketLength), state)), + _ => break Some((Err(PacketParseError::InvalidPacketLen), state)), } if packet_buf.len() > 9 { - break Some((Err(Error::InvalidPacketLength), state)); + break Some((Err(PacketParseError::InvalidPacketLen), state)); } let size_str = &packet_buf[1..] @@ -187,7 +196,7 @@ where if let Ok(size) = size_str.parse() { packet_size = size; } else { - break Some((Err(Error::InvalidPacketLength), state)); + break Some((Err(PacketParseError::InvalidPacketLen), state)); } packet_buf.clear(); } @@ -204,19 +213,23 @@ where // Read the packet data let packet = match packet_type.unwrap() { STRING_PACKET_IDENTIFIER_V3 => String::from_utf8(packet_buf) - .map_err(|_| Error::InvalidPacketLength) + .map_err(PacketParseError::from) .and_then(|v| Packet::parse(ProtocolVersion::V3, v)), // Convert the packet buffer to a Packet object BINARY_PACKET_IDENTIFIER_V3 => Ok(Packet::BinaryV3(packet_buf.into())), - _ => Err(Error::InvalidPacketLength), + _ => Err(PacketParseError::InvalidPacketLen), }; + state.yield_packets += 1; break Some((packet, state)); - } else if state.end_of_stream && state.buffer.remaining() == 0 { + } else if state.end_of_stream + && state.buffer.remaining() == 0 + && state.yield_packets > 0 + { break None; } else if state.end_of_stream { // EOS reached with leftover bytes that cannot form a complete // packet (truncated header or truncated body). - break Some((Err(Error::InvalidPacketLength), state)); + break Some((Err(PacketParseError::InvalidPacketLen), state)); } } }) @@ -250,10 +263,10 @@ fn utf16_len(s: &str) -> usize { pub fn v3_string_decoder( body: impl Body + Unpin, max_payload: u64, -) -> impl Stream> { +) -> impl Stream> { use std::io::ErrorKind; - use crate::transport::polling::payload::STRING_PACKET_SEPARATOR_V3; + use crate::payload::STRING_PACKET_SEPARATOR_V3; #[cfg(feature = "tracing")] tracing::debug!("decoding payload with v3 string decoder"); @@ -272,7 +285,7 @@ pub fn v3_string_decoder( if state.end_of_stream && state.buffer.remaining() == 0 && state.yield_packets > 0 { break None; // Reached end of stream with no more data, end the stream } else if state.end_of_stream && state.buffer.remaining() == 0 { - return Some((Err(Error::InvalidPacketLength), state)); + return Some((Err(PacketParseError::InvalidPacketLen), state)); } let mut reader = (&mut state.buffer).reader(); @@ -285,7 +298,12 @@ pub fn v3_string_decoder( let available = match reader.fill_buf() { Ok(n) => n, Err(ref e) if e.kind() == ErrorKind::Interrupted => continue, - Err(e) => return Some((Err(Error::Io(e)), state)), + Err(_err) => { + #[cfg(feature = "tracing")] + tracing::debug!("failed to read packet payload: {_err}"); + + return Some((Err(PacketParseError::InvalidPacketPayload), state)); + } }; let old_len = packet_buf.len(); packet_buf.extend_from_slice(available); @@ -294,9 +312,10 @@ pub fn v3_string_decoder( Some(i) => { // Extract the packet length from the available data packet_utf16_len = match std::str::from_utf8(&packet_buf[..i]) - .map_err(|_| Error::InvalidPacketLength) + .map_err(PacketParseError::from) .and_then(|s| { - s.parse::().map_err(|_| Error::InvalidPacketLength) + s.parse::() + .map_err(|_| PacketParseError::InvalidPacketLen) }) { Ok(size) => size, Err(e) => return Some((Err(e), state)), @@ -306,7 +325,7 @@ pub fn v3_string_decoder( (true, i + 1 - old_len) // Mark as done and set the used bytes count } None if state.end_of_stream && remaining - available.len() == 0 => { - return Some((Err(Error::InvalidPacketLength), state)); + return Some((Err(PacketParseError::InvalidPacketLen), state)); } // Reached end of stream and end of bufferered chunks without finding the separator None => (false, available.len()), // Continue reading more data } @@ -365,7 +384,7 @@ pub fn v3_string_decoder( // SAFETY: packet_buf is a valid utf8 string checkd above let packet = unsafe { String::from_utf8_unchecked(packet_buf) }; let packet = Packet::parse(ProtocolVersion::V3, packet) - .map_err(|_| Error::InvalidPacketLength); + .map_err(|_| PacketParseError::InvalidPacketLen); state.yield_packets += 1; break Some((packet, state)); // Emit the packet and the updated state } @@ -384,8 +403,6 @@ mod tests { use http_body::Frame; use http_body_util::{Full, StreamBody}; - use crate::packet::Packet; - use super::*; const MAX_PAYLOAD: u64 = 100_000; @@ -450,7 +467,10 @@ mod tests { let payload = v4_decoder(stream, MAX_PAYLOAD); futures_util::pin_mut!(payload); let packet = payload.next().await.unwrap(); - assert!(matches!(packet, Err(Error::PayloadTooLarge))); + assert!(matches!( + packet, + Err(PacketParseError::PayloadTooLarge { max: MAX_PAYLOAD }) + )); } } @@ -607,7 +627,10 @@ mod tests { let payload = v3_binary_decoder(stream, MAX_PAYLOAD); futures_util::pin_mut!(payload); let packet = payload.next().await.unwrap(); - assert!(matches!(packet, Err(Error::PayloadTooLarge))); + assert!(matches!( + packet, + Err(PacketParseError::PayloadTooLarge { max: MAX_PAYLOAD }) + )); } for i in 1..DATA.len() { let stream = StreamBody::new(futures_util::stream::iter( @@ -618,7 +641,10 @@ mod tests { let payload = v3_string_decoder(stream, MAX_PAYLOAD); futures_util::pin_mut!(payload); let packet = payload.next().await.unwrap(); - assert!(matches!(packet, Err(Error::PayloadTooLarge))); + assert!(matches!( + packet, + Err(PacketParseError::PayloadTooLarge { max: MAX_PAYLOAD }) + )); } } diff --git a/crates/engineioxide/src/transport/polling/payload/encoder.rs b/crates/engineioxide-core/src/payload/encoder.rs similarity index 63% rename from crates/engineioxide/src/transport/polling/payload/encoder.rs rename to crates/engineioxide-core/src/payload/encoder.rs index a56a9c9a..48f7d996 100644 --- a/crates/engineioxide/src/transport/polling/payload/encoder.rs +++ b/crates/engineioxide-core/src/payload/encoder.rs @@ -7,12 +7,12 @@ //! * binary encoder (used when there are binary packets and the client supports binary) //! -use tokio::sync::MutexGuard; +use std::pin::Pin; -use crate::{ - errors::Error, packet::Packet, peekable::PeekableReceiver, socket::PacketBuf, - transport::polling::payload::Payload, -}; +use futures_util::{FutureExt, Stream, StreamExt, stream::Peekable}; +use smallvec::smallvec; + +use crate::{Packet, packet::PacketBuf, payload::Payload}; /// Try to immediately poll a new packet buf from the rx channel and check that the new packet can be added to the payload /// @@ -25,12 +25,12 @@ use crate::{ /// * `max_payload` - The maximum payload length /// * `b64` - If binary packets should be encoded in base64 fn try_recv_packet( - rx: &mut MutexGuard<'_, PeekableReceiver>, + mut rx: Pin<&mut Peekable>>, payload_len: usize, max_payload: u64, b64: bool, ) -> Option { - if let Some(packets) = rx.peek() { + if let Some(packets) = rx.as_mut().peek().now_or_never().flatten() { let size = packets.iter().map(|p| p.get_size_hint(b64)).sum::(); if (payload_len + size) as u64 > max_payload { #[cfg(feature = "tracing")] @@ -39,14 +39,14 @@ fn try_recv_packet( } } - let packets = rx.try_recv().ok(); + let packets = rx.next().now_or_never().flatten(); - if Some(&Packet::Close) == packets.as_ref().and_then(|p| p.first()) { - #[cfg(feature = "tracing")] - tracing::debug!("Received close packet, closing channel"); - rx.try_recv().ok(); - rx.close(); - } + // if Some(&Packet::Close) == packets.as_ref().and_then(|p| p.first()) { + // #[cfg(feature = "tracing")] + // tracing::debug!("Received close packet, closing channel"); + // rx.try_recv().ok(); + // rx.close(); + // } #[cfg(feature = "tracing")] tracing::debug!("sending packet: {:?}", packets); @@ -55,28 +55,28 @@ fn try_recv_packet( /// Same as [`try_recv_packet`] /// but wait for a new packet if there is no packet in the buffer -async fn recv_packet( - rx: &mut MutexGuard<'_, PeekableReceiver>, -) -> Result { - let packet = rx.recv().await.ok_or(Error::Aborted)?; +async fn recv_packet(mut rx: Pin<&mut Peekable>>) -> PacketBuf { + let packet = rx.next().await.unwrap_or(smallvec![]); + if Some(&Packet::Close) == packet.first() { #[cfg(feature = "tracing")] tracing::debug!("Received close packet, closing channel"); - rx.close(); + + // rx.close(); } #[cfg(feature = "tracing")] tracing::debug!("sending packet: {:?}", packet); - Ok(packet) + packet } /// Encode multiple packets into a string payload according to the /// [engine.io v4 protocol](https://socket.io/fr/docs/v4/engine-io-protocol/#http-long-polling-1) pub async fn v4_encoder( - mut rx: MutexGuard<'_, PeekableReceiver>, + mut rx: Pin<&mut Peekable>>, max_payload: u64, -) -> Result { - use crate::transport::polling::payload::PACKET_SEPARATOR_V4; +) -> Payload { + use crate::payload::PACKET_SEPARATOR_V4; #[cfg(feature = "tracing")] tracing::debug!("encoding payload with v4 encoder"); @@ -85,7 +85,7 @@ pub async fn v4_encoder( // Send all packets in the buffer const PUNCTUATION_LEN: usize = 1; while let Some(packets) = - try_recv_packet(&mut rx, data.len() + PUNCTUATION_LEN, max_payload, true) + try_recv_packet(rx.as_mut(), data.len() + PUNCTUATION_LEN, max_payload, true) { for packet in packets { let packet: String = packet.into(); @@ -99,7 +99,7 @@ pub async fn v4_encoder( // If there is no packet in the buffer, wait for the next packet if data.is_empty() { - let packets = recv_packet(&mut rx).await?; + let packets = recv_packet(rx.as_mut()).await; for packet in packets { if !data.is_empty() { data.push(std::char::from_u32(PACKET_SEPARATOR_V4 as u32).unwrap()); @@ -110,14 +110,14 @@ pub async fn v4_encoder( } } - Ok(Payload::new(data.into(), false)) + Payload::new(data.into(), false) } /// Encode one packet into a *binary* payload according to the /// [engine.io v3 protocol](https://github.com/socketio/engine.io-protocol/tree/v3#payload) #[cfg(feature = "v3")] pub fn v3_bin_packet_encoder(packet: Packet, data: &mut bytes::BytesMut) { - use crate::transport::polling::payload::BINARY_PACKET_SEPARATOR_V3; + use crate::payload::BINARY_PACKET_SEPARATOR_V3; use bytes::BufMut; let mut itoa = itoa::Buffer::new(); @@ -157,7 +157,7 @@ pub fn v3_bin_packet_encoder(packet: Packet, data: &mut bytes::BytesMut) { /// [engine.io v3 protocol](https://github.com/socketio/engine.io-protocol/tree/v3#payload) #[cfg(feature = "v3")] pub fn v3_string_packet_encoder(packet: Packet, data: &mut bytes::BytesMut) { - use crate::transport::polling::payload::STRING_PACKET_SEPARATOR_V3; + use crate::payload::STRING_PACKET_SEPARATOR_V3; use bytes::BufMut; let packet: String = packet.into(); let packet = format!( @@ -173,9 +173,9 @@ pub fn v3_string_packet_encoder(packet: Packet, data: &mut bytes::BytesMut) { /// according to the [engine.io v3 protocol](https://github.com/socketio/engine.io-protocol/tree/v3#payload) #[cfg(feature = "v3")] pub async fn v3_binary_encoder( - mut rx: MutexGuard<'_, PeekableReceiver>, + mut rx: Pin<&mut Peekable>>, max_payload: u64, -) -> Result { +) -> Payload { let mut data = bytes::BytesMut::new(); let mut packet_buffer: Vec = Vec::new(); @@ -189,7 +189,7 @@ pub async fn v3_binary_encoder( // buffer all packets to find if there is binary packets let mut has_binary = false; - while let Some(packets) = try_recv_packet(&mut rx, estimated_size, max_payload, false) { + while let Some(packets) = try_recv_packet(rx.as_mut(), estimated_size, max_payload, false) { for packet in packets { if packet.is_binary() { has_binary = true; @@ -214,7 +214,7 @@ pub async fn v3_binary_encoder( // If there is no packet in the buffer, wait for the next packet if data.is_empty() { - let packets = recv_packet(&mut rx).await?; + let packets = recv_packet(rx.as_mut()).await; has_binary = packets.iter().any(|p| p.is_binary()); for packet in packets { if has_binary { @@ -227,16 +227,16 @@ pub async fn v3_binary_encoder( #[cfg(feature = "tracing")] tracing::debug!("sending packet: {:?}", &data); - Ok(Payload::new(data.freeze(), has_binary)) + Payload::new(data.freeze(), has_binary) } /// Encode multiple packet packet into a *string* payload according to the /// [engine.io v3 protocol](https://github.com/socketio/engine.io-protocol/tree/v3#payload) #[cfg(feature = "v3")] pub async fn v3_string_encoder( - mut rx: MutexGuard<'_, PeekableReceiver>, + mut rx: Pin<&mut Peekable>>, max_payload: u64, -) -> Result { +) -> Payload { let mut data = bytes::BytesMut::new(); #[cfg(feature = "tracing")] @@ -246,7 +246,7 @@ pub async fn v3_string_encoder( // number of digits of the max packet size, used to approximate the payload size let max_packet_size_len = max_payload.checked_ilog10().unwrap_or(0) as usize + 1; while let Some(packets) = try_recv_packet( - &mut rx, + rx.as_mut(), data.len() + PUNCTUATION_LEN + max_packet_size_len, max_payload, true, @@ -258,21 +258,19 @@ pub async fn v3_string_encoder( // If there is no packet in the buffer, wait for the next packet if data.is_empty() { - let packets = recv_packet(&mut rx).await?; + let packets = recv_packet(rx.as_mut()).await; for packet in packets { v3_string_packet_encoder(packet, &mut data); } } - Ok(Payload::new(data.freeze(), false)) + Payload::new(data.freeze(), false) } #[cfg(test)] mod tests { use bytes::Bytes; - use tokio::sync::Mutex; - - use PacketBuf; + use futures_util::stream; use super::*; const MAX_PAYLOAD: u64 = 100_000; @@ -280,18 +278,15 @@ mod tests { #[tokio::test] async fn encode_v4_payload() { const PAYLOAD: &str = "4hello€\x1ebAQIDBA==\x1e4hello€"; - let (tx, rx) = tokio::sync::mpsc::channel::(10); - let rx = Mutex::new(PeekableReceiver::new(rx)); - let rx = rx.lock().await; - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::Binary(Bytes::from_static(&[ - 1, 2, 3, 4 - ]))]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); - let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD).await.unwrap(); + + let rx = stream::iter([ + smallvec![Packet::Message("hello€".into())], + smallvec![Packet::Binary(Bytes::from_static(&[1, 2, 3, 4]))], + smallvec![Packet::Message("hello€".into())], + ]); + let rx = std::pin::pin!(rx.peekable()); + + let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD).await; assert_eq!(data, PAYLOAD.as_bytes()); } @@ -299,8 +294,8 @@ mod tests { async fn encode_v4_payload_parked_poll_multi_packet_batch() { const PAYLOAD: &str = "4hello€\x1ebAQIDBA=="; let (tx, rx) = tokio::sync::mpsc::channel::(10); - let rx = Mutex::new(PeekableReceiver::new(rx)); - let rx = rx.lock().await; + let rx = tokio_stream::wrappers::ReceiverStream::new(rx); + let rx = std::pin::pin!(rx.peekable()); tokio::spawn(async move { tokio::time::sleep(std::time::Duration::from_millis(50)).await; tx.try_send(smallvec::smallvec![ @@ -309,38 +304,33 @@ mod tests { ]) .unwrap(); }); - let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD).await.unwrap(); - assert_eq!(data, PAYLOAD); + let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD).await; + assert_eq!(data, PAYLOAD.as_bytes()); } #[tokio::test] async fn max_payload_v4() { const MAX_PAYLOAD: u64 = 10; - let (tx, rx) = tokio::sync::mpsc::channel::(10); - let mutex = Mutex::new(PeekableReceiver::new(rx)); - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::Binary(Bytes::from_static(&[ - 1, 2, 3, 4 - ]))]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); + + let rx = stream::iter([ + smallvec![Packet::Message("hello€".into())], + smallvec![Packet::Binary(Bytes::from_static(&[1, 2, 3, 4]))], + smallvec![Packet::Message("hello€".into())], + smallvec![Packet::Message("hello€".into())], + ]); + + let mut rx = std::pin::pin!(rx.peekable()); + { - let rx = mutex.lock().await; - let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD).await.unwrap(); + let Payload { data, .. } = v4_encoder(rx.as_mut(), MAX_PAYLOAD).await; assert_eq!(data, "4hello€".as_bytes()); } { - let rx = mutex.lock().await; - let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD + 10).await.unwrap(); + let Payload { data, .. } = v4_encoder(rx.as_mut(), MAX_PAYLOAD + 10).await; assert_eq!(data, "bAQIDBA==\x1e4hello€".as_bytes()); } { - let rx = mutex.lock().await; - let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD + 10).await.unwrap(); + let Payload { data, .. } = v4_encoder(rx.as_mut(), MAX_PAYLOAD + 10).await; assert_eq!(data, "4hello€".as_bytes()); } } @@ -353,12 +343,9 @@ mod tests { // (packet type '4' + non-BMP codepoint U+1D54A) has 2 codepoints // but 3 UTF-16 code units. const PAYLOAD: &str = "3:4𝕊"; - let (tx, rx) = tokio::sync::mpsc::channel::(10); - let mutex = Mutex::new(PeekableReceiver::new(rx)); - let rx = mutex.lock().await; - tx.try_send(smallvec::smallvec![Packet::Message("𝕊".into())]) - .unwrap(); - let Payload { data, .. } = v3_string_encoder(rx, MAX_PAYLOAD).await.unwrap(); + let rx = stream::iter([smallvec![Packet::Message("𝕊".into())]]); + let rx = std::pin::pin!(rx.peekable()); + let Payload { data, .. } = v3_string_encoder(rx, MAX_PAYLOAD).await; assert_eq!(data, PAYLOAD.as_bytes()); } @@ -366,21 +353,14 @@ mod tests { #[tokio::test] async fn encode_v3b64_payload() { const PAYLOAD: &str = "7:4hello€10:b4AQIDBA==7:4hello€"; - let (tx, rx) = tokio::sync::mpsc::channel::(10); - let mutex = Mutex::new(PeekableReceiver::new(rx)); - let rx = mutex.lock().await; - - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::BinaryV3(Bytes::from_static( - &[1, 2, 3, 4] - ))]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); - let Payload { - data, has_binary, .. - } = v3_string_encoder(rx, MAX_PAYLOAD).await.unwrap(); + let rx = stream::iter([ + smallvec![Packet::Message("hello€".into())], + smallvec![Packet::BinaryV3(Bytes::from_static(&[1, 2, 3, 4]))], + smallvec![Packet::Message("hello€".into())], + ]); + + let rx = std::pin::pin!(rx.peekable()); + let Payload { data, has_binary } = v3_string_encoder(rx, MAX_PAYLOAD).await; assert_eq!(data, PAYLOAD.as_bytes()); assert!(!has_binary); } @@ -390,33 +370,21 @@ mod tests { async fn max_payload_v3_b64() { const MAX_PAYLOAD: u64 = 10; - let (tx, rx) = tokio::sync::mpsc::channel::(10); - let mutex = Mutex::new(PeekableReceiver::new(rx)); - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::BinaryV3(Bytes::from_static( - &[1, 2, 3, 4] - ))]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); + let rx = stream::iter(vec![ + smallvec::smallvec![Packet::Message("hello€".into())], + smallvec::smallvec![Packet::BinaryV3(Bytes::from_static(&[1, 2, 3, 4]))], + smallvec::smallvec![Packet::Message("hello€".into())], + smallvec::smallvec![Packet::Message("hello€".into())], + ]); + let mut rx = std::pin::pin!(rx.peekable()); + { - let rx = mutex.lock().await; - let Payload { data, .. } = v3_string_encoder(rx, MAX_PAYLOAD).await.unwrap(); + let Payload { data, .. } = v3_string_encoder(rx.as_mut(), MAX_PAYLOAD).await; assert_eq!(data, "7:4hello€".as_bytes()); } { - let rx = mutex.lock().await; - let Payload { data, .. } = v3_string_encoder(rx, MAX_PAYLOAD + 10).await.unwrap(); - assert_eq!(data, "10:b4AQIDBA==".as_bytes()); - } - { - // Next call drains one of the remaining Message packets. - let rx = mutex.lock().await; - let Payload { data, .. } = v3_string_encoder(rx, MAX_PAYLOAD + 10).await.unwrap(); - assert_eq!(data, "7:4hello€".as_bytes()); + let Payload { data, .. } = v3_string_encoder(rx.as_mut(), MAX_PAYLOAD + 10).await; + assert_eq!(data, "10:b4AQIDBA==7:4hello€7:4hello€".as_bytes()); } } @@ -426,19 +394,14 @@ mod tests { const PAYLOAD: [u8; 20] = [ 0, 9, 255, 52, 104, 101, 108, 108, 111, 226, 130, 172, 1, 5, 255, 4, 1, 2, 3, 4, ]; - let (tx, rx) = tokio::sync::mpsc::channel::(10); - let mutex = Mutex::new(PeekableReceiver::new(rx)); - let rx = mutex.lock().await; - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::BinaryV3(Bytes::from_static( - &[1, 2, 3, 4] - ))]) - .unwrap(); - let Payload { - data, has_binary, .. - } = v3_binary_encoder(rx, MAX_PAYLOAD).await.unwrap(); + let rx = stream::iter([ + smallvec![Packet::Message("hello€".into())], + smallvec![Packet::BinaryV3(Bytes::from_static(&[1, 2, 3, 4]))], + ]); + let rx = std::pin::pin!(rx.peekable()); + + let Payload { data, has_binary } = v3_binary_encoder(rx, MAX_PAYLOAD).await; assert_eq!(*data, PAYLOAD); assert!(has_binary); } @@ -455,8 +418,8 @@ mod tests { 0, 9, 255, 52, 104, 101, 108, 108, 111, 226, 130, 172, 1, 5, 255, 4, 1, 2, 3, 4, ]; let (tx, rx) = tokio::sync::mpsc::channel::(10); - let mutex = Mutex::new(PeekableReceiver::new(rx)); - let rx = mutex.lock().await; + let rx = tokio_stream::wrappers::ReceiverStream::new(rx); + let rx = std::pin::pin!(rx.peekable()); tokio::spawn(async move { tokio::time::sleep(std::time::Duration::from_millis(50)).await; tx.try_send(smallvec::smallvec![ @@ -467,7 +430,7 @@ mod tests { }); let Payload { data, has_binary, .. - } = v3_binary_encoder(rx, MAX_PAYLOAD).await.unwrap(); + } = v3_binary_encoder(rx, MAX_PAYLOAD).await; assert_eq!(*data, PAYLOAD); assert!(has_binary); } @@ -481,26 +444,21 @@ mod tests { 0, 1, 1, 255, 52, 104, 101, 108, 108, 111, 111, 111, 226, 130, 172, 1, 5, 255, 4, 1, 2, 3, 4, ]; - let (tx, rx) = tokio::sync::mpsc::channel::(10); - let mutex = Mutex::new(PeekableReceiver::new(rx)); - tx.try_send(smallvec::smallvec![Packet::Message("hellooo€".into())]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::BinaryV3(Bytes::from_static( - &[1, 2, 3, 4] - ))]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); - tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) - .unwrap(); + + let rx = stream::iter([ + smallvec![Packet::Message("hellooo€".into())], + smallvec![Packet::BinaryV3(Bytes::from_static(&[1, 2, 3, 4]))], + smallvec![Packet::Message("hello€".into())], + smallvec![Packet::Message("hello€".into())], + ]); + let mut rx = std::pin::pin!(rx.peekable()); + { - let rx = mutex.lock().await; - let Payload { data, .. } = v3_binary_encoder(rx, MAX_PAYLOAD).await.unwrap(); + let Payload { data, .. } = v3_binary_encoder(rx.as_mut(), MAX_PAYLOAD).await; assert_eq!(*data, PAYLOAD); } { - let rx = mutex.lock().await; - let Payload { data, .. } = v3_binary_encoder(rx, MAX_PAYLOAD).await.unwrap(); + let Payload { data, .. } = v3_binary_encoder(rx.as_mut(), MAX_PAYLOAD).await; assert_eq!(data, "7:4hello€7:4hello€".as_bytes()); } } diff --git a/crates/engineioxide/src/transport/polling/payload/mod.rs b/crates/engineioxide-core/src/payload/mod.rs similarity index 70% rename from crates/engineioxide/src/transport/polling/payload/mod.rs rename to crates/engineioxide-core/src/payload/mod.rs index db62cc96..4ab1c33a 100644 --- a/crates/engineioxide/src/transport/polling/payload/mod.rs +++ b/crates/engineioxide-core/src/payload/mod.rs @@ -1,13 +1,10 @@ //! Payload encoder and decoder for polling transport. -use crate::{ - errors::Error, packet::Packet, peekable::PeekableReceiver, service::ProtocolVersion, - socket::PacketBuf, -}; +use crate::{Packet, PacketParseError, ProtocolVersion, packet::PacketBuf}; + use bytes::Bytes; -use futures_core::Stream; -use http::Request; -use tokio::sync::MutexGuard; +use futures_util::{Stream, StreamExt}; +use http::HeaderValue; mod buf; mod decoder; @@ -22,20 +19,24 @@ const BINARY_PACKET_SEPARATOR_V3: u8 = 0xff; const STRING_PACKET_IDENTIFIER_V3: u8 = 0x00; #[cfg(feature = "v3")] const BINARY_PACKET_IDENTIFIER_V3: u8 = 0x01; +#[cfg(feature = "v3")] +const BINARY_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/octet-stream"); +/// Decode a payload into a stream of packets. pub fn decoder( - body: Request + Unpin>, - #[allow(unused_variables)] protocol: ProtocolVersion, + body: impl http_body::Body + Unpin, + #[allow(unused)] content_type: Option<&HeaderValue>, + #[allow(unused)] protocol: ProtocolVersion, max_payload: u64, -) -> impl Stream> { +) -> impl Stream> { + #[cfg(feature = "tracing")] + tracing::debug!(?content_type, %protocol, "decoding payload"); + #[cfg(feature = "v3")] { use futures_util::future::Either; - use http::header::CONTENT_TYPE; - #[cfg(feature = "tracing")] - tracing::debug!("decoding payload {:?}", body.headers().get(CONTENT_TYPE)); - let is_binary = - body.headers().get(CONTENT_TYPE) == Some(&"application/octet-stream".parse().unwrap()); + + let is_binary = content_type == Some(&BINARY_CONTENT_TYPE); match protocol { ProtocolVersion::V4 => Either::Left(decoder::v4_decoder(body, max_payload)), ProtocolVersion::V3 if is_binary => { @@ -54,23 +55,28 @@ pub fn decoder( } /// A payload to transmit to the client through http polling -#[non_exhaustive] pub struct Payload { + /// The data of the payload. pub data: Bytes, + /// Whether the payload contains binary data. pub has_binary: bool, } impl Payload { + /// Creates a new payload with the given data and binary flag. pub fn new(data: Bytes, has_binary: bool) -> Self { Self { data, has_binary } } } +/// Encodes a payload into a byte stream. pub async fn encoder( - rx: MutexGuard<'_, PeekableReceiver>, - #[allow(unused_variables)] protocol: ProtocolVersion, - #[cfg(feature = "v3")] supports_binary: bool, + rx: impl Stream, + #[allow(unused)] protocol: ProtocolVersion, + #[allow(unused)] supports_binary: bool, max_payload: u64, -) -> Result { +) -> Payload { + let rx = std::pin::pin!(rx.peekable()); + #[cfg(feature = "v3")] { match protocol { @@ -91,8 +97,8 @@ pub async fn encoder( /// Encodes a single packet into a byte array. pub fn packet_encoder( packet: Packet, - #[allow(unused_variables)] protocol: ProtocolVersion, - #[cfg(feature = "v3")] supports_binary: bool, + #[allow(unused)] protocol: ProtocolVersion, + #[allow(unused)] supports_binary: bool, ) -> Bytes { #[cfg(feature = "v3")] { diff --git a/crates/engineioxide-core/src/protocol.rs b/crates/engineioxide-core/src/protocol.rs new file mode 100644 index 00000000..acc9714f --- /dev/null +++ b/crates/engineioxide-core/src/protocol.rs @@ -0,0 +1,125 @@ +use std::{fmt, str::FromStr}; + +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +/// The type of `transport` used to connect to the client/server. +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum TransportType { + /// Polling transport + Polling = 0x01, + /// Websocket transport + Websocket = 0x02, +} + +impl From for TransportType { + fn from(t: u8) -> Self { + match t { + 0x01 => TransportType::Polling, + 0x02 => TransportType::Websocket, + _ => panic!("unknown transport type"), + } + } +} + +impl FromStr for TransportType { + type Err = UnknownTransportError; + + fn from_str(s: &str) -> Result { + match s { + "websocket" => Ok(TransportType::Websocket), + "polling" => Ok(TransportType::Polling), + _ => Err(UnknownTransportError), + } + } +} +impl From for &'static str { + fn from(t: TransportType) -> Self { + match t { + TransportType::Polling => "polling", + TransportType::Websocket => "websocket", + } + } +} +impl From for String { + fn from(t: TransportType) -> Self { + match t { + TransportType::Polling => "polling".into(), + TransportType::Websocket => "websocket".into(), + } + } +} + +impl Serialize for TransportType { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str((*self).into()) + } +} + +impl<'de> Deserialize<'de> for TransportType { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + Self::from_str(&s).map_err(serde::de::Error::custom) + } +} + +/// Cannot determine the transport type to connect to the client/server. +#[derive(Debug, Copy, Clone)] +pub struct UnknownTransportError; +impl std::fmt::Display for UnknownTransportError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "unknown transport type") + } +} +impl std::error::Error for UnknownTransportError {} + +/// == ProtocolVersion == + +#[derive(Debug)] +pub struct UnknownProtocolVersionError; +impl std::fmt::Display for UnknownProtocolVersionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "unknown protocol version") + } +} +impl std::error::Error for UnknownProtocolVersionError {} + +/// The engine.io protocol version +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum ProtocolVersion { + /// The protocol version 3 + V3 = 3, + /// The protocol version 4 + V4 = 4, +} + +impl fmt::Display for ProtocolVersion { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&(*self as u8), f) + } +} +impl FromStr for ProtocolVersion { + type Err = UnknownProtocolVersionError; + + #[cfg(feature = "v3")] + fn from_str(s: &str) -> Result { + match s { + "3" => Ok(ProtocolVersion::V3), + "4" => Ok(ProtocolVersion::V4), + _ => Err(UnknownProtocolVersionError), + } + } + + #[cfg(not(feature = "v3"))] + fn from_str(s: &str) -> Result { + match s { + "4" => Ok(ProtocolVersion::V4), + _ => Err(UnknownProtocolVersionError), + } + } +} diff --git a/crates/engineioxide/Cargo.toml b/crates/engineioxide/Cargo.toml index b604481d..105ab198 100644 --- a/crates/engineioxide/Cargo.toml +++ b/crates/engineioxide/Cargo.toml @@ -33,21 +33,15 @@ tokio-util.workspace = true tower-service.workspace = true tower-layer.workspace = true hyper.workspace = true -tokio-tungstenite.workspace = true +tokio-tungstenite = { workspace = true, features = ["handshake"] } http-body-util.workspace = true pin-project-lite.workspace = true smallvec.workspace = true hyper-util = { workspace = true, features = ["tokio"] } -base64 = "0.22" - # Tracing tracing = { workspace = true, optional = true } -# Engine.io V3 payload -itoa = { workspace = true, optional = true } -memchr = { version = "2.7", optional = true } - [dev-dependencies] tokio = { workspace = true, features = ["macros", "parking_lot"] } tracing-subscriber.workspace = true @@ -55,8 +49,10 @@ hyper = { workspace = true, features = ["server", "http1"] } criterion.workspace = true axum.workspace = true tokio-stream.workspace = true +tokio-util.workspace = true + [features] -v3 = ["memchr", "itoa"] +v3 = ["engineioxide-core/v3"] tracing = ["dep:tracing"] __test_harness = [] diff --git a/crates/engineioxide/benches/packet_encode.rs b/crates/engineioxide/benches/packet_encode.rs index 0cfc8842..f1d8f483 100644 --- a/crates/engineioxide/benches/packet_encode.rs +++ b/crates/engineioxide/benches/packet_encode.rs @@ -1,15 +1,11 @@ use bytes::Bytes; use criterion::{BatchSize, Criterion, black_box, criterion_group, criterion_main}; -use engineioxide::{OpenPacket, Packet, TransportType, config::EngineIoConfig, socket::Sid}; +use engineioxide::{OpenPacket, Packet}; fn criterion_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("engineio_packet/encode"); group.bench_function("Encode packet open", |b| { - let packet = Packet::Open(OpenPacket::new( - black_box(TransportType::Polling), - black_box(Sid::ZERO), - &EngineIoConfig::default(), - )); + let packet = Packet::Open(OpenPacket::default()); b.iter_batched( || packet.clone(), TryInto::::try_into, diff --git a/crates/engineioxide/src/config.rs b/crates/engineioxide/src/config.rs index aa74a583..2cd9f603 100644 --- a/crates/engineioxide/src/config.rs +++ b/crates/engineioxide/src/config.rs @@ -32,7 +32,7 @@ use std::{borrow::Cow, time::Duration}; -use crate::service::TransportType; +use engineioxide_core::TransportType; /// Configuration for the engine.io engine & transports #[derive(Debug, Clone)] diff --git a/crates/engineioxide/src/engine.rs b/crates/engineioxide/src/engine.rs index 99353b7c..b2a694e7 100644 --- a/crates/engineioxide/src/engine.rs +++ b/crates/engineioxide/src/engine.rs @@ -3,13 +3,12 @@ use std::{ sync::{Arc, RwLock}, }; -use engineioxide_core::Sid; +use engineioxide_core::{ProtocolVersion, Sid, TransportType}; use http::request::Parts; use crate::{ config::EngineIoConfig, handler::EngineIoHandler, - service::{ProtocolVersion, TransportType}, socket::{DisconnectReason, Socket}, }; @@ -45,7 +44,7 @@ impl EngineIo { protocol: ProtocolVersion, transport: TransportType, req: Parts, - #[cfg(feature = "v3")] supports_binary: bool, + supports_binary: bool, ) -> Arc> { let engine = self.clone(); let close_fn = Box::new(move |sid, reason| engine.close_session(sid, reason)); @@ -56,7 +55,6 @@ impl EngineIo { &self.config, req, close_fn, - #[cfg(feature = "v3")] supports_binary, ); let socket = Arc::new(socket); diff --git a/crates/engineioxide/src/errors.rs b/crates/engineioxide/src/errors.rs index 29b335a0..71296e9c 100644 --- a/crates/engineioxide/src/errors.rs +++ b/crates/engineioxide/src/errors.rs @@ -3,26 +3,23 @@ use tokio::sync::mpsc; use tokio_tungstenite::tungstenite; use crate::body::ResponseBody; -use crate::packet::Packet; -use engineioxide_core::Sid; +use engineioxide_core::{Packet, Sid}; + +pub use engineioxide_core::PacketParseError; #[derive(thiserror::Error, Debug)] pub enum Error { - #[error("error decoding binary packet from polling request: {0:?}")] - Base64(#[from] base64::DecodeError), - #[error("error decoding packet: {0:?}")] - StrUtf8(#[from] std::str::Utf8Error), - #[error("io error: {0:?}")] - Io(#[from] std::io::Error), + #[error("error decoding packet from request: {0}")] + PacketParse(#[from] PacketParseError), #[error("bad packet received")] BadPacket(Packet), - #[error("ws transport error: {0:?}")] + #[error("ws transport error: {0}")] WsTransport(#[from] Box), - #[error("http error: {0:?}")] + #[error("http error: {0}")] Http(#[from] http::Error), - #[error("internal channel error: {0:?}")] + #[error("internal channel error: {0}")] SendChannel(#[from] mpsc::error::TrySendError), - #[error("internal channel error: {0:?}")] + #[error("internal channel error: {0}")] RecvChannel(#[from] mpsc::error::TryRecvError), #[error("heartbeat timeout")] HeartbeatTimeout, @@ -30,23 +27,16 @@ pub enum Error { Upgrade, #[error("multiple ws upgrade requests")] MultipleWebsocketRequests, - #[error("aborted connection")] - Aborted, - #[error("http error response: {0:?}")] - HttpErrorResponse(StatusCode), + #[error("multiple http polling error")] + MultipleHttpPolling, + #[error("invalid websocket Sec-WebSocket-Key http header")] + InvalidWebSocketKey, #[error("unknown session id")] UnknownSessionID(Sid), #[error("transport mismatch")] TransportMismatch, - #[error("payload too large")] - PayloadTooLarge, - - #[error("Invalid packet length")] - InvalidPacketLength, - #[error("Invalid packet type")] - InvalidPacketType(Option), } /// Convert an error into an http response @@ -62,18 +52,15 @@ impl From for Response> { .unwrap() }; match err { - Error::HttpErrorResponse(code) => Response::builder() - .status(code) + Error::PacketParse(PacketParseError::PayloadTooLarge { .. }) => Response::builder() + .status(413) .body(ResponseBody::empty_response()) .unwrap(), - Error::BadPacket(_) | Error::InvalidPacketLength | Error::InvalidPacketType(_) => { - Response::builder() - .status(400) - .body(ResponseBody::empty_response()) - .unwrap() - } - Error::PayloadTooLarge => Response::builder() - .status(413) + Error::BadPacket(_) + | Error::PacketParse(_) + | Error::MultipleHttpPolling + | Error::InvalidWebSocketKey => Response::builder() + .status(400) .body(ResponseBody::empty_response()) .unwrap(), diff --git a/crates/engineioxide/src/lib.rs b/crates/engineioxide/src/lib.rs index b8f34b3b..94cb84d7 100644 --- a/crates/engineioxide/src/lib.rs +++ b/crates/engineioxide/src/lib.rs @@ -2,13 +2,12 @@ #![cfg_attr(docsrs, feature(doc_cfg))] #![doc = include_str!("../README.md")] -pub use engineioxide_core::Str; -pub use service::{ProtocolVersion, TransportType}; +pub use engineioxide_core::{ProtocolVersion, Str, TransportType}; pub use socket::{DisconnectReason, Socket}; #[doc(hidden)] #[cfg(feature = "__test_harness")] -pub use packet::*; +pub use engineioxide_core::{OpenPacket, Packet, PacketParseError}; pub mod config; pub mod handler; @@ -25,6 +24,4 @@ pub mod sid { mod body; mod engine; mod errors; -mod packet; -mod peekable; mod transport; diff --git a/crates/engineioxide/src/peekable.rs b/crates/engineioxide/src/peekable.rs deleted file mode 100644 index bc8ddb6b..00000000 --- a/crates/engineioxide/src/peekable.rs +++ /dev/null @@ -1,79 +0,0 @@ -use tokio::sync::mpsc::{Receiver, error::TryRecvError}; - -/// Peekable receiver for polling transport -/// It is a thin wrapper around a [`Receiver`](tokio::sync::mpsc::Receiver) that allows to peek the next packet without consuming it -/// -/// Its main goal is to be able to peek the next packet without consuming it to calculate the -/// packet length when using polling transport to check if it fits according to the max_payload setting -#[derive(Debug)] -pub struct PeekableReceiver { - rx: Receiver, - next: Option, -} -impl PeekableReceiver { - pub fn new(rx: Receiver) -> Self { - Self { rx, next: None } - } - pub fn peek(&mut self) -> Option<&T> { - if self.next.is_none() { - self.next = self.rx.try_recv().ok(); - } - self.next.as_ref() - } - pub async fn recv(&mut self) -> Option { - if self.next.is_none() { - self.rx.recv().await - } else { - self.next.take() - } - } - pub fn try_recv(&mut self) -> Result { - if self.next.is_none() { - self.rx.try_recv() - } else { - Ok(self.next.take().unwrap()) - } - } - - pub fn close(&mut self) { - self.rx.close() - } -} - -#[cfg(test)] -mod tests { - use tokio::sync::Mutex; - - #[tokio::test] - async fn peek() { - use super::PeekableReceiver; - use crate::packet::Packet; - use tokio::sync::mpsc::channel; - - let (tx, rx) = channel(1); - let rx = Mutex::new(PeekableReceiver::new(rx)); - let mut rx = rx.lock().await; - - assert!(rx.peek().is_none()); - - tx.send(Packet::Ping).await.unwrap(); - assert_eq!(rx.peek(), Some(&Packet::Ping)); - assert_eq!(rx.recv().await, Some(Packet::Ping)); - assert!(rx.peek().is_none()); - - tx.send(Packet::Pong).await.unwrap(); - assert_eq!(rx.peek(), Some(&Packet::Pong)); - assert_eq!(rx.recv().await, Some(Packet::Pong)); - assert!(rx.peek().is_none()); - - tx.send(Packet::Close).await.unwrap(); - assert_eq!(rx.peek(), Some(&Packet::Close)); - assert_eq!(rx.recv().await, Some(Packet::Close)); - assert!(rx.peek().is_none()); - - tx.send(Packet::Close).await.unwrap(); - assert_eq!(rx.peek(), Some(&Packet::Close)); - assert_eq!(rx.recv().await, Some(Packet::Close)); - assert!(rx.peek().is_none()); - } -} diff --git a/crates/engineioxide/src/service/mod.rs b/crates/engineioxide/src/service/mod.rs index 4bce8fe6..6b9d1c13 100644 --- a/crates/engineioxide/src/service/mod.rs +++ b/crates/engineioxide/src/service/mod.rs @@ -33,6 +33,8 @@ use std::{ }; use bytes::Bytes; +#[cfg(feature = "__test_harness")] +use engineioxide_core::ProtocolVersion; use futures_util::future::{self, Ready}; use http::{Request, Response}; use http_body::Body; @@ -47,7 +49,6 @@ use crate::{ mod futures; mod parser; -pub use self::parser::{ProtocolVersion, TransportType}; use self::{futures::ResponseFuture, parser::dispatch_req}; /// A `Service` that handles engine.io requests as a middleware. diff --git a/crates/engineioxide/src/service/parser.rs b/crates/engineioxide/src/service/parser.rs index f1235063..2cb5623a 100644 --- a/crates/engineioxide/src/service/parser.rs +++ b/crates/engineioxide/src/service/parser.rs @@ -1,11 +1,10 @@ //! A Parser module to parse any `EngineIo` query -use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use std::{future::Future, str::FromStr, sync::Arc}; +use std::{future::Future, sync::Arc}; use http::{Method, Request, Response}; -use engineioxide_core::Sid; +use engineioxide_core::{ProtocolVersion, Sid, TransportType}; use crate::{ body::ResponseBody, @@ -35,15 +34,8 @@ where sid: None, transport: TransportType::Polling, method: Method::GET, - #[cfg(feature = "v3")] b64, - }) => ResponseFuture::ready(polling::open_req( - engine, - protocol, - req, - #[cfg(feature = "v3")] - !b64, - )), + }) => ResponseFuture::ready(polling::open_req(engine, protocol, req, !b64)), Ok(RequestInfo { protocol, sid: Some(sid), @@ -118,102 +110,6 @@ impl From for Response> { } } -/// The engine.io protocol version -#[derive(Debug, Copy, Clone, PartialEq)] -pub enum ProtocolVersion { - /// The protocol version 3 - V3 = 3, - /// The protocol version 4 - V4 = 4, -} - -impl FromStr for ProtocolVersion { - type Err = ParseError; - - #[cfg(feature = "v3")] - fn from_str(s: &str) -> Result { - match s { - "3" => Ok(ProtocolVersion::V3), - "4" => Ok(ProtocolVersion::V4), - _ => Err(ParseError::UnsupportedProtocolVersion), - } - } - - #[cfg(not(feature = "v3"))] - fn from_str(s: &str) -> Result { - match s { - "4" => Ok(ProtocolVersion::V4), - _ => Err(ParseError::UnsupportedProtocolVersion), - } - } -} - -/// The type of `transport` used by the client. -#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)] -pub enum TransportType { - /// Polling transport - Polling = 0x01, - /// Websocket transport - Websocket = 0x02, -} - -impl Serialize for TransportType { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - serializer.serialize_str((*self).into()) - } -} - -impl<'de> Deserialize<'de> for TransportType { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let s = String::deserialize(deserializer)?; - Self::from_str(&s).map_err(serde::de::Error::custom) - } -} - -impl From for TransportType { - fn from(t: u8) -> Self { - match t { - 0x01 => TransportType::Polling, - 0x02 => TransportType::Websocket, - _ => panic!("unknown transport type"), - } - } -} - -impl FromStr for TransportType { - type Err = ParseError; - - fn from_str(s: &str) -> Result { - match s { - "websocket" => Ok(TransportType::Websocket), - "polling" => Ok(TransportType::Polling), - _ => Err(ParseError::UnknownTransport), - } - } -} -impl From for &'static str { - fn from(t: TransportType) -> Self { - match t { - TransportType::Polling => "polling", - TransportType::Websocket => "websocket", - } - } -} -impl From for String { - fn from(t: TransportType) -> Self { - match t { - TransportType::Polling => "polling".into(), - TransportType::Websocket => "websocket".into(), - } - } -} - /// The request information extracted from the request URI. #[derive(Debug)] pub struct RequestInfo { @@ -226,7 +122,6 @@ pub struct RequestInfo { /// The request method. pub method: Method, /// If the client asked for base64 encoding only. - #[cfg(feature = "v3")] pub b64: bool, } @@ -240,8 +135,8 @@ impl RequestInfo { .split('&') .find(|s| s.starts_with("EIO=")) .and_then(|s| s.split('=').nth(1)) - .ok_or(UnsupportedProtocolVersion) - .and_then(|t| t.parse())?; + .and_then(|t| t.parse().ok()) + .ok_or(UnsupportedProtocolVersion)?; let sid = query .split('&') @@ -253,8 +148,8 @@ impl RequestInfo { .split('&') .find(|s| s.starts_with("transport=")) .and_then(|s| s.split('=').nth(1)) - .ok_or(UnknownTransport) - .and_then(|t| t.parse())?; + .and_then(|t| t.parse().ok()) + .ok_or(UnknownTransport)?; if !config.allowed_transport(transport) { return Err(TransportMismatch); @@ -267,6 +162,9 @@ impl RequestInfo { .map(|v| v == "1" || v == "true") .unwrap_or_default(); + #[cfg(not(feature = "v3"))] + let b64: bool = false; + let method = req.method().clone(); if !matches!(method, Method::GET) && sid.is_none() { Err(BadHandshakeMethod) @@ -276,7 +174,6 @@ impl RequestInfo { sid, transport, method, - #[cfg(feature = "v3")] b64, }) } diff --git a/crates/engineioxide/src/socket.rs b/crates/engineioxide/src/socket.rs index 222e3d8f..5a09d14c 100644 --- a/crates/engineioxide/src/socket.rs +++ b/crates/engineioxide/src/socket.rs @@ -63,15 +63,9 @@ use std::{ time::Duration, }; -use crate::{ - config::EngineIoConfig, - errors::Error, - packet::Packet, - peekable::PeekableReceiver, - service::{ProtocolVersion, TransportType}, -}; +use crate::{config::EngineIoConfig, errors::Error}; use bytes::Bytes; -use engineioxide_core::Str; +use engineioxide_core::{Packet, PacketBuf, ProtocolVersion, Str, TransportType}; use futures_util::FutureExt; use http::request::Parts; use smallvec::{SmallVec, smallvec}; @@ -108,9 +102,8 @@ impl From<&Error> for Option { fn from(err: &Error) -> Self { use Error::*; match err { - WsTransport(_) | Io(_) => Some(DisconnectReason::TransportError), - BadPacket(_) | Base64(_) | StrUtf8(_) | PayloadTooLarge | InvalidPacketLength - | InvalidPacketType(_) => Some(DisconnectReason::PacketParsingError), + WsTransport(_) => Some(DisconnectReason::TransportError), + BadPacket(_) | PacketParse(_) => Some(DisconnectReason::PacketParsingError), HeartbeatTimeout => Some(DisconnectReason::HeartbeatTimeout), _ => None, } @@ -159,13 +152,6 @@ impl Permit<'_> { } } -/// Buffered packets to send to the client. -/// It is used to ensure atomicity when sending multiple packets to the client. -/// -/// The [`PacketBuf`] stack size will impact the dynamically allocated buffer -/// of the internal mpsc channel. -pub(crate) type PacketBuf = SmallVec<[Packet; 2]>; - /// A [`Socket`] represents a client connection to the server. /// It is agnostic to the [`TransportType`]. /// @@ -207,7 +193,7 @@ where /// Because with polling transport, if the client is not currently polling then the encoder will never be able to close the channel /// /// The channel is made of a [`SmallVec`] of [`Packet`]s so that adjacent packets can be sent atomically. - pub(crate) internal_rx: Mutex>, + pub(crate) internal_rx: Mutex>, /// Channel to send [PacketBuf] to the internal connection internal_tx: mpsc::Sender, @@ -231,7 +217,6 @@ where pub req_parts: Parts, /// If the client supports binary packets (via polling XHR2) - #[cfg(feature = "v3")] pub(crate) supports_binary: bool, } @@ -245,7 +230,7 @@ where config: &EngineIoConfig, req_parts: Parts, close_fn: Box, - #[cfg(feature = "v3")] supports_binary: bool, + supports_binary: bool, ) -> Self { let (internal_tx, internal_rx) = mpsc::channel(config.max_buffer_size); let (heartbeat_tx, heartbeat_rx) = mpsc::channel(1); @@ -256,7 +241,7 @@ where transport: AtomicU8::new(transport as u8), upgrading: AtomicBool::new(false), - internal_rx: Mutex::new(PeekableReceiver::new(internal_rx)), + internal_rx: Mutex::new(internal_rx), internal_tx, heartbeat_rx: Mutex::new(heartbeat_rx), @@ -268,7 +253,6 @@ where data: D::default(), req_parts, - #[cfg(feature = "v3")] supports_binary, } } @@ -555,7 +539,7 @@ where transport: AtomicU8::new(TransportType::Websocket as u8), upgrading: AtomicBool::new(false), - internal_rx: Mutex::new(PeekableReceiver::new(internal_rx)), + internal_rx: Mutex::new(internal_rx), internal_tx, heartbeat_rx: Mutex::new(heartbeat_rx), @@ -566,7 +550,6 @@ where data: D::default(), req_parts: http::Request::<()>::default().into_parts().0, - #[cfg(feature = "v3")] supports_binary: true, }; let sock = Arc::new(sock); diff --git a/crates/engineioxide/src/transport/mod.rs b/crates/engineioxide/src/transport/mod.rs index 94d843a5..d0ad9a45 100644 --- a/crates/engineioxide/src/transport/mod.rs +++ b/crates/engineioxide/src/transport/mod.rs @@ -1,4 +1,23 @@ //! All transports modules available in engineioxide +use engineioxide_core::{OpenPacket, Sid, TransportType}; + +use crate::config::EngineIoConfig; + pub mod polling; pub mod ws; + +fn make_open_packet(transport: TransportType, id: Sid, config: &EngineIoConfig) -> OpenPacket { + let upgrades = if transport == TransportType::Polling { + smallvec::smallvec![TransportType::Websocket] + } else { + smallvec::smallvec![] + }; + OpenPacket { + sid: id, + upgrades, + ping_timeout: config.ping_timeout, + ping_interval: config.ping_interval, + max_payload: config.max_payload, + } +} diff --git a/crates/engineioxide/src/transport/polling/mod.rs b/crates/engineioxide/src/transport/polling.rs similarity index 73% rename from crates/engineioxide/src/transport/polling/mod.rs rename to crates/engineioxide/src/transport/polling.rs index e5207d35..8120a8c6 100644 --- a/crates/engineioxide/src/transport/polling/mod.rs +++ b/crates/engineioxide/src/transport/polling.rs @@ -2,26 +2,19 @@ use std::sync::Arc; use bytes::Bytes; +use engineioxide_core::payload::{self, Payload}; use futures_util::StreamExt; -use http::{Request, Response, StatusCode}; +use http::{Request, Response, StatusCode, header::CONTENT_TYPE}; use http_body::Body; use http_body_util::Full; -use engineioxide_core::Sid; +use engineioxide_core::{Packet, ProtocolVersion, Sid, TransportType}; use crate::{ - DisconnectReason, - body::ResponseBody, - engine::EngineIo, - errors::Error, - handler::EngineIoHandler, - packet::{OpenPacket, Packet}, - service::{ProtocolVersion, TransportType}, - transport::polling::payload::Payload, + DisconnectReason, body::ResponseBody, engine::EngineIo, errors::Error, + handler::EngineIoHandler, transport::make_open_packet, }; -mod payload; - /// Create a response for http request fn http_response( code: StatusCode, @@ -48,7 +41,7 @@ pub fn open_req( engine: Arc>, protocol: ProtocolVersion, req: Request, - #[cfg(feature = "v3")] supports_binary: bool, + supports_binary: bool, ) -> Result>, Error> where H: EngineIoHandler, @@ -58,11 +51,10 @@ where protocol, TransportType::Polling, req.into_parts().0, - #[cfg(feature = "v3")] supports_binary, ); - let packet = OpenPacket::new(TransportType::Polling, socket.id, &engine.config); + let packet = make_open_packet(TransportType::Polling, socket.id, &engine.config); socket.spawn_heartbeat(engine.config.ping_interval, engine.config.ping_timeout); @@ -106,10 +98,7 @@ where #[cfg(feature = "tracing")] tracing::debug!(?sid, "socket is upgrading, sending NOOP packet"); - #[cfg(feature = "v3")] let data = payload::packet_encoder(Packet::Noop, socket.protocol, socket.supports_binary); - #[cfg(not(feature = "v3"))] - let data = payload::packet_encoder(Packet::Noop, socket.protocol); let is_binary = false; // The noop packet is guaranteed to be serialized as text return Ok(http_response(StatusCode::OK, data, is_binary)?); @@ -117,27 +106,29 @@ where // If the socket is already locked, it means that the socket is being used by another request // In case of multiple http polling, session should be closed - let rx = match socket.internal_rx.try_lock() { + let mut rx = match socket.internal_rx.try_lock() { Ok(s) => s, Err(_) => { socket.close(DisconnectReason::MultipleHttpPollingError); - return Err(Error::HttpErrorResponse(StatusCode::BAD_REQUEST)); + return Err(Error::MultipleHttpPolling); } }; + //TODO: handle closing channel packet better than in %the encoding process. + #[cfg(feature = "tracing")] - tracing::debug!("[sid={sid}] polling request"); + tracing::debug!(%sid, %protocol, supports_binary = socket.supports_binary, "polling request"); let max_payload = engine.config.max_payload; - #[cfg(feature = "v3")] + let rx = rx_stream::ReceiverStream::new(&mut rx); + let Payload { data, has_binary } = - payload::encoder(rx, protocol, socket.supports_binary, max_payload).await?; - #[cfg(not(feature = "v3"))] - let Payload { data, has_binary } = payload::encoder(rx, protocol, max_payload).await?; + payload::encoder(rx, protocol, socket.supports_binary, max_payload).await; #[cfg(feature = "tracing")] - tracing::debug!("[sid={sid}] sending data: {:?}", data); + tracing::trace!(%sid, %protocol, supports_binary = socket.supports_binary, "sending data: {:?}", data); + Ok(http_response(StatusCode::OK, data, has_binary)?) } @@ -148,7 +139,7 @@ pub async fn post_req( engine: Arc>, protocol: ProtocolVersion, sid: Sid, - body: Request, + req: Request, ) -> Result>, Error> where H: EngineIoHandler, @@ -162,7 +153,9 @@ where return Err(Error::TransportMismatch); } - let packets = payload::decoder(body, protocol, engine.config.max_payload); + let (parts, body) = req.into_parts(); + let content_type = parts.headers.get(CONTENT_TYPE); + let packets = payload::decoder(body, content_type, protocol, engine.config.max_payload); futures_util::pin_mut!(packets); while let Some(packet) = packets.next().await { @@ -195,9 +188,39 @@ where #[cfg(feature = "tracing")] tracing::debug!("[sid={sid}] error parsing packet: {:?}", e); engine.close_session(sid, DisconnectReason::PacketParsingError); - return Err(e); + return Err(e.into()); } }?; } Ok(http_response(StatusCode::OK, "ok", false)?) } + +mod rx_stream { + use std::{ + pin::Pin, + task::{Context, Poll}, + }; + + use futures_core::Stream; + use tokio::sync::mpsc::Receiver; + + /// [`ReceiverStream`] is a stream that wraps a tokio::sync::mpsc::Receiver by reference. + /// Allowing to use it as a stream even if it is behind a mutex. + pub struct ReceiverStream<'a, T> { + inner: &'a mut Receiver, + } + + impl<'a, T> ReceiverStream<'a, T> { + pub fn new(inner: &'a mut Receiver) -> Self { + Self { inner } + } + } + + impl<'a, T> Stream for ReceiverStream<'a, T> { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_recv(cx) + } + } +} diff --git a/crates/engineioxide/src/transport/ws.rs b/crates/engineioxide/src/transport/ws.rs index 565a9940..f73f5ce5 100644 --- a/crates/engineioxide/src/transport/ws.rs +++ b/crates/engineioxide/src/transport/ws.rs @@ -15,24 +15,17 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio_tungstenite::{ WebSocketStream, tungstenite::{ - Message, + self, Message, handshake::derive_accept_key, protocol::{Role, WebSocketConfig}, }, }; -use engineioxide_core::{Sid, Str}; +use engineioxide_core::{Packet, ProtocolVersion, Sid, Str, TransportType}; use crate::{ - DisconnectReason, Socket, - body::ResponseBody, - config::EngineIoConfig, - engine::EngineIo, - errors::Error, - handler::EngineIoHandler, - packet::{OpenPacket, Packet}, - service::ProtocolVersion, - service::TransportType, + DisconnectReason, Socket, body::ResponseBody, config::EngineIoConfig, engine::EngineIo, + errors::Error, handler::EngineIoHandler, transport::make_open_packet, }; /// Create a response for websocket upgrade @@ -66,7 +59,7 @@ pub fn new_req( let ws_key = parts .headers .get("Sec-WebSocket-Key") - .ok_or(Error::HttpErrorResponse(StatusCode::BAD_REQUEST))? + .ok_or(Error::InvalidWebSocketKey)? .clone(); tokio::spawn(async move { @@ -147,13 +140,7 @@ where } } } else { - let socket = engine.create_session( - protocol, - TransportType::Websocket, - req_data, - #[cfg(feature = "v3")] - false, - ); + let socket = engine.create_session(protocol, TransportType::Websocket, req_data, false); #[cfg(feature = "tracing")] tracing::debug!("new websocket connection"); @@ -203,8 +190,13 @@ where { while let Some(msg) = rx.try_next().await? { match msg { - Message::Text(msg) => match Packet::parse(socket.protocol, utf8_bytes_to_str(msg))? { - Packet::Close => break, + Message::Text(msg) => match Packet::parse(socket.protocol, ws_bytes_to_str(msg))? { + Packet::Close => { + #[cfg(feature = "tracing")] + tracing::debug!("[sid={}] closing session", socket.id); + engine.close_session(socket.id, DisconnectReason::TransportClose); + break; + } Packet::Pong | Packet::Ping => socket .heartbeat_tx .try_send(()) @@ -311,7 +303,8 @@ async fn init_handshake( where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - let packet = Packet::Open(OpenPacket::new(TransportType::Websocket, sid, config)); + let packet = Packet::Open(make_open_packet(TransportType::Websocket, sid, config)); + let packet: String = packet.into(); ws.send(Message::Text(packet.into())).await?; Ok(()) } @@ -369,13 +362,13 @@ where Some(Ok(Message::Text(d))) => d, _ => Err(Error::Upgrade)?, }; - match Packet::parse(socket.protocol, utf8_bytes_to_str(msg))? { + match Packet::parse(socket.protocol, ws_bytes_to_str(msg))? { Packet::PingUpgrade => { #[cfg(feature = "tracing")] tracing::debug!("received first ping upgrade"); - // Respond with a PongUpgrade packet - ws.send(Message::Text(Packet::PongUpgrade.into())).await?; + ws.send(Message::Text(String::from(Packet::PongUpgrade).into())) + .await?; } p => Err(Error::BadPacket(p))?, }; @@ -394,7 +387,7 @@ where Err(Error::Upgrade)? } }; - match Packet::parse(socket.protocol, utf8_bytes_to_str(msg))? { + match Packet::parse(socket.protocol, ws_bytes_to_str(msg))? { Packet::Upgrade => { #[cfg(feature = "tracing")] tracing::debug!("ws upgraded successfully") @@ -406,7 +399,8 @@ where Ok(()) } -fn utf8_bytes_to_str(bytes: tokio_tungstenite::tungstenite::Utf8Bytes) -> Str { - // SAFETY: the bytes are guaranteed to be valid UTF-8 by the tungstenite parser +fn ws_bytes_to_str(bytes: tungstenite::Utf8Bytes) -> Str { + // SAFETY: We are converting a valid UTF-8 byte slice + // to a string without checking its validity. unsafe { Str::from_bytes_unchecked(bytes.into()) } }