diff --git a/crates/engineioxide/src/socket.rs b/crates/engineioxide/src/socket.rs index 222e3d8f..b046bacb 100644 --- a/crates/engineioxide/src/socket.rs +++ b/crates/engineioxide/src/socket.rs @@ -78,6 +78,7 @@ use smallvec::{SmallVec, smallvec}; use tokio::sync::{ Mutex, mpsc::{self, Receiver, error::TrySendError}, + watch, }; pub use engineioxide_core::Sid; @@ -212,6 +213,13 @@ where /// Channel to send [PacketBuf] to the internal connection internal_tx: mpsc::Sender, + /// Channel to send volatile [PacketBuf]s that bypass the internal buffer. + /// Uses a [`watch`](tokio::sync::watch) channel so only the latest volatile + /// message is retained; subsequent volatile sends overwrite the previous one. + volatile_tx: watch::Sender>, + /// Receiver for the volatile channel, read by the transport with priority. + pub(crate) volatile_rx: watch::Receiver>, + /// Internal channel to receive Pong [`Packets`](Packet) (v4 protocol) or Ping (v3 protocol) in the heartbeat job /// which is running in a separate task heartbeat_rx: Mutex>, @@ -249,6 +257,7 @@ where ) -> Self { let (internal_tx, internal_rx) = mpsc::channel(config.max_buffer_size); let (heartbeat_tx, heartbeat_rx) = mpsc::channel(1); + let (volatile_tx, volatile_rx) = watch::channel(None); Self { id: Sid::new(), @@ -259,6 +268,9 @@ where internal_rx: Mutex::new(PeekableReceiver::new(internal_rx)), internal_tx, + volatile_rx, + volatile_tx, + heartbeat_rx: Mutex::new(heartbeat_rx), heartbeat_tx, cancellation_token: CancellationToken::new(), @@ -488,6 +500,55 @@ where TrySendError::Closed(p) => TrySendError::Closed(p.into_binary()), }) } + + /// Try to send a volatile message bypassing the internal buffer channel. + /// Volatile messages may be dropped if the transport is not ready to + /// receive them. + /// + /// Because volatile messages bypass the main mpsc buffer queue, they may + /// arrive out of order relative to regular messages. + /// + /// Returns `true` if the message was queued for sending, `false` if it + /// was dropped (channel full or transport shutting down). + #[inline] + pub fn emit_volatile(&self, msg: impl Into) -> bool { + self.send_volatile(smallvec![Packet::Message(msg.into())]) + } + + /// Try to send a volatile binary message bypassing the internal buffer channel. + /// Volatile messages may be dropped if the transport is not ready to + /// receive them. + /// + /// Returns `true` if the message was queued for sending, `false` if it + /// was dropped. + #[inline] + pub fn emit_binary_volatile>(&self, data: B) -> bool { + if self.protocol == ProtocolVersion::V3 { + self.send_volatile(smallvec![Packet::BinaryV3(data.into())]) + } else { + self.send_volatile(smallvec![Packet::Binary(data.into())]) + } + } + + /// Try to send a volatile message with multiple adjacent binary payloads. + /// The message and all binary payloads are sent atomically as a single + /// volatile write. + /// + /// Returns `true` if the message was queued for sending, `false` if it + /// was dropped. + #[inline] + pub fn emit_many_volatile(&self, msg: Str, data: VecDeque) -> bool { + let mut packets = SmallVec::with_capacity(1 + data.len()); + packets.push(Packet::Message(msg)); + for bin in data { + packets.push(Packet::Binary(bin)); + } + self.send_volatile(packets) + } + + pub(crate) fn send_volatile(&self, packets: PacketBuf) -> bool { + self.volatile_tx.send(Some(packets)).is_ok() + } } impl std::fmt::Debug for Socket { @@ -548,6 +609,7 @@ where ) -> (Arc>, tokio::sync::mpsc::Receiver) { let (internal_tx, internal_rx) = mpsc::channel(buffer_size); let (heartbeat_tx, heartbeat_rx) = mpsc::channel(1); + let (volatile_tx, volatile_rx) = watch::channel(None); let sock = Self { id: sid, @@ -558,6 +620,9 @@ where internal_rx: Mutex::new(PeekableReceiver::new(internal_rx)), internal_tx, + volatile_rx, + volatile_tx, + heartbeat_rx: Mutex::new(heartbeat_rx), heartbeat_tx, cancellation_token: CancellationToken::new(), diff --git a/crates/engineioxide/src/transport/polling/mod.rs b/crates/engineioxide/src/transport/polling/mod.rs index e5207d35..bb8826c9 100644 --- a/crates/engineioxide/src/transport/polling/mod.rs +++ b/crates/engineioxide/src/transport/polling/mod.rs @@ -130,11 +130,20 @@ where let max_payload = engine.config.max_payload; + let mut volatile_rx = socket.volatile_rx.clone(); + #[cfg(feature = "v3")] - let Payload { data, has_binary } = - payload::encoder(rx, protocol, socket.supports_binary, max_payload).await?; + let Payload { data, has_binary } = payload::encoder( + rx, + protocol, + socket.supports_binary, + max_payload, + &mut volatile_rx, + ) + .await?; #[cfg(not(feature = "v3"))] - let Payload { data, has_binary } = payload::encoder(rx, protocol, max_payload).await?; + let Payload { data, has_binary } = + payload::encoder(rx, protocol, max_payload, &mut volatile_rx).await?; #[cfg(feature = "tracing")] tracing::debug!("[sid={sid}] sending data: {:?}", data); diff --git a/crates/engineioxide/src/transport/polling/payload/encoder.rs b/crates/engineioxide/src/transport/polling/payload/encoder.rs index a56a9c9a..c509cee3 100644 --- a/crates/engineioxide/src/transport/polling/payload/encoder.rs +++ b/crates/engineioxide/src/transport/polling/payload/encoder.rs @@ -75,6 +75,7 @@ async fn recv_packet( pub async fn v4_encoder( mut rx: MutexGuard<'_, PeekableReceiver>, max_payload: u64, + volatile_rx: &mut tokio::sync::watch::Receiver>, ) -> Result { use crate::transport::polling::payload::PACKET_SEPARATOR_V4; @@ -82,11 +83,24 @@ pub async fn v4_encoder( tracing::debug!("encoding payload with v4 encoder"); let mut data: String = String::new(); + // Encode any pending volatile packets first so they are included in + // the current response even if the main channel has a backlog. + // The watch receiver is checked again at each iteration of the main + // channel drain loop, so volatile packets arriving during encoding + // (between .await points) are also captured. + encode_volatile_packets_v4(&mut data, volatile_rx); + // Send all packets in the buffer const PUNCTUATION_LEN: usize = 1; - while let Some(packets) = - try_recv_packet(&mut rx, data.len() + PUNCTUATION_LEN, max_payload, true) - { + loop { + // Check for volatile data before each main channel read + encode_volatile_packets_v4(&mut data, volatile_rx); + + let Some(packets) = + try_recv_packet(&mut rx, data.len() + PUNCTUATION_LEN, max_payload, true) + else { + break; + }; for packet in packets { let packet: String = packet.into(); @@ -108,11 +122,46 @@ pub async fn v4_encoder( let packet: String = packet.into(); data.push_str(&packet); } + + // Check for volatile packets that arrived during the parked recv + let mut volatile_data = String::new(); + encode_volatile_packets_v4(&mut volatile_data, volatile_rx); + if !volatile_data.is_empty() { + if !data.is_empty() { + volatile_data.push(std::char::from_u32(PACKET_SEPARATOR_V4 as u32).unwrap()); + } + volatile_data.push_str(&data); + data = volatile_data; + } } Ok(Payload::new(data.into(), false)) } +/// Encode any pending volatile packets from the watch receiver into the +/// v4 payload string. The watch is checked lazily (only when new data is +/// available) so volatile packets arriving during encoding are captured. +fn encode_volatile_packets_v4( + data: &mut String, + volatile_rx: &mut tokio::sync::watch::Receiver>, +) { + use crate::transport::polling::payload::PACKET_SEPARATOR_V4; + + if !volatile_rx.has_changed().unwrap_or(false) { + return; + } + let value = volatile_rx.borrow_and_update().clone(); + if let Some(packets) = value { + for packet in packets { + let packet: String = packet.into(); + if !data.is_empty() { + data.push(std::char::from_u32(PACKET_SEPARATOR_V4 as u32).unwrap()); + } + data.push_str(&packet); + } + } +} + /// Encode one packet into a *binary* payload according to the /// [engine.io v3 protocol](https://github.com/socketio/engine.io-protocol/tree/v3#payload) #[cfg(feature = "v3")] @@ -175,6 +224,7 @@ pub fn v3_string_packet_encoder(packet: Packet, data: &mut bytes::BytesMut) { pub async fn v3_binary_encoder( mut rx: MutexGuard<'_, PeekableReceiver>, max_payload: u64, + volatile_rx: &mut tokio::sync::watch::Receiver>, ) -> Result { let mut data = bytes::BytesMut::new(); let mut packet_buffer: Vec = Vec::new(); @@ -189,7 +239,28 @@ pub async fn v3_binary_encoder( // buffer all packets to find if there is binary packets let mut has_binary = false; - while let Some(packets) = try_recv_packet(&mut rx, estimated_size, max_payload, false) { + // Encode any pending volatile packets first. Checked again inside the + // main drain loop so volatile packets arriving during encoding are captured. + buffer_volatile_packets( + &mut packet_buffer, + &mut estimated_size, + &mut has_binary, + max_packet_size_len, + volatile_rx, + ); + + loop { + buffer_volatile_packets( + &mut packet_buffer, + &mut estimated_size, + &mut has_binary, + max_packet_size_len, + volatile_rx, + ); + + let Some(packets) = try_recv_packet(&mut rx, estimated_size, max_payload, false) else { + break; + }; for packet in packets { if packet.is_binary() { has_binary = true; @@ -203,11 +274,11 @@ pub async fn v3_binary_encoder( } if has_binary { - for packet in packet_buffer { + for packet in packet_buffer.drain(..) { v3_bin_packet_encoder(packet, &mut data); } } else { - for packet in packet_buffer { + for packet in packet_buffer.drain(..) { v3_string_packet_encoder(packet, &mut data); } } @@ -215,7 +286,24 @@ pub async fn v3_binary_encoder( // If there is no packet in the buffer, wait for the next packet if data.is_empty() { let packets = recv_packet(&mut rx).await?; - has_binary = packets.iter().any(|p| p.is_binary()); + has_binary = packets.iter().any(|p| p.is_binary()) || has_binary; + + // Check for volatile that arrived during the park + buffer_volatile_packets( + &mut packet_buffer, + &mut estimated_size, + &mut has_binary, + max_packet_size_len, + volatile_rx, + ); + + for packet in packet_buffer.drain(..) { + if has_binary { + v3_bin_packet_encoder(packet, &mut data); + } else { + v3_string_packet_encoder(packet, &mut data); + } + } for packet in packets { if has_binary { v3_bin_packet_encoder(packet, &mut data); @@ -230,12 +318,40 @@ pub async fn v3_binary_encoder( Ok(Payload::new(data.freeze(), has_binary)) } +/// Buffer any pending volatile packets from the watch receiver into the +/// packet buffer. Called at each iteration so volatile packets arriving +/// during encoding (between .await points) are captured. +#[cfg(feature = "v3")] +fn buffer_volatile_packets( + packet_buffer: &mut Vec, + estimated_size: &mut usize, + has_binary: &mut bool, + max_packet_size_len: usize, + volatile_rx: &mut tokio::sync::watch::Receiver>, +) { + if !volatile_rx.has_changed().unwrap_or(false) { + return; + } + let value = volatile_rx.borrow_and_update().clone(); + if let Some(packets) = value { + for packet in packets { + if packet.is_binary() { + *has_binary = true; + } + const PUNCTUATION_LEN: usize = 2; + *estimated_size += packet.get_size_hint(false) + max_packet_size_len + PUNCTUATION_LEN; + packet_buffer.push(packet); + } + } +} + /// Encode multiple packet packet into a *string* payload according to the /// [engine.io v3 protocol](https://github.com/socketio/engine.io-protocol/tree/v3#payload) #[cfg(feature = "v3")] pub async fn v3_string_encoder( mut rx: MutexGuard<'_, PeekableReceiver>, max_payload: u64, + volatile_rx: &mut tokio::sync::watch::Receiver>, ) -> Result { let mut data = bytes::BytesMut::new(); @@ -245,12 +361,22 @@ pub async fn v3_string_encoder( const PUNCTUATION_LEN: usize = 2; // number of digits of the max packet size, used to approximate the payload size let max_packet_size_len = max_payload.checked_ilog10().unwrap_or(0) as usize + 1; - while let Some(packets) = try_recv_packet( - &mut rx, - data.len() + PUNCTUATION_LEN + max_packet_size_len, - max_payload, - true, - ) { + + // Encode any pending volatile packets first; checked again in the main + // drain loop so volatile packets arriving during encoding are captured. + encode_volatile_v3_string(&mut data, volatile_rx); + + loop { + encode_volatile_v3_string(&mut data, volatile_rx); + + let Some(packets) = try_recv_packet( + &mut rx, + data.len() + PUNCTUATION_LEN + max_packet_size_len, + max_payload, + true, + ) else { + break; + }; for packet in packets { v3_string_packet_encoder(packet, &mut data); } @@ -262,11 +388,37 @@ pub async fn v3_string_encoder( for packet in packets { v3_string_packet_encoder(packet, &mut data); } + + let mut volatile_data = bytes::BytesMut::new(); + encode_volatile_v3_string(&mut volatile_data, volatile_rx); + if !volatile_data.is_empty() { + volatile_data.unsplit(data); + data = volatile_data; + } } Ok(Payload::new(data.freeze(), false)) } +/// Encode any pending volatile packets from the watch receiver into the +/// v3 string payload. Called at each iteration so volatile packets +/// arriving during encoding are captured. +#[cfg(feature = "v3")] +fn encode_volatile_v3_string( + data: &mut bytes::BytesMut, + volatile_rx: &mut tokio::sync::watch::Receiver>, +) { + if !volatile_rx.has_changed().unwrap_or(false) { + return; + } + let value = volatile_rx.borrow_and_update().clone(); + if let Some(packets) = value { + for packet in packets { + v3_string_packet_encoder(packet, data); + } + } +} + #[cfg(test)] mod tests { use bytes::Bytes; @@ -277,8 +429,14 @@ mod tests { use super::*; const MAX_PAYLOAD: u64 = 100_000; + fn dummy_volatile_rx() -> tokio::sync::watch::Receiver> { + let (_, rx) = tokio::sync::watch::channel(None); + rx + } + #[tokio::test] async fn encode_v4_payload() { + let mut vr = dummy_volatile_rx(); const PAYLOAD: &str = "4hello€\x1ebAQIDBA==\x1e4hello€"; let (tx, rx) = tokio::sync::mpsc::channel::(10); let rx = Mutex::new(PeekableReceiver::new(rx)); @@ -291,12 +449,13 @@ mod tests { .unwrap(); tx.try_send(smallvec::smallvec![Packet::Message("hello€".into())]) .unwrap(); - let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD).await.unwrap(); + let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); assert_eq!(data, PAYLOAD.as_bytes()); } #[tokio::test] async fn encode_v4_payload_parked_poll_multi_packet_batch() { + let mut vr = dummy_volatile_rx(); const PAYLOAD: &str = "4hello€\x1ebAQIDBA=="; let (tx, rx) = tokio::sync::mpsc::channel::(10); let rx = Mutex::new(PeekableReceiver::new(rx)); @@ -309,12 +468,13 @@ mod tests { ]) .unwrap(); }); - let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD).await.unwrap(); + let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); assert_eq!(data, PAYLOAD); } #[tokio::test] async fn max_payload_v4() { + let mut vr = dummy_volatile_rx(); const MAX_PAYLOAD: u64 = 10; let (tx, rx) = tokio::sync::mpsc::channel::(10); let mutex = Mutex::new(PeekableReceiver::new(rx)); @@ -330,17 +490,17 @@ mod tests { .unwrap(); { let rx = mutex.lock().await; - let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD).await.unwrap(); + let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); assert_eq!(data, "4hello€".as_bytes()); } { let rx = mutex.lock().await; - let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD + 10).await.unwrap(); + let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD + 10, &mut vr).await.unwrap(); assert_eq!(data, "bAQIDBA==\x1e4hello€".as_bytes()); } { let rx = mutex.lock().await; - let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD + 10).await.unwrap(); + let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD + 10, &mut vr).await.unwrap(); assert_eq!(data, "4hello€".as_bytes()); } } @@ -348,6 +508,7 @@ mod tests { #[cfg(feature = "v3")] #[tokio::test] async fn encode_v3_string_payload_utf16_length() { + let mut vr = dummy_volatile_rx(); // Length must be the number of UTF-16 code units to match the // engine.io v3 JS reference implementation. The message "4𝕊" // (packet type '4' + non-BMP codepoint U+1D54A) has 2 codepoints @@ -358,13 +519,14 @@ mod tests { let rx = mutex.lock().await; tx.try_send(smallvec::smallvec![Packet::Message("𝕊".into())]) .unwrap(); - let Payload { data, .. } = v3_string_encoder(rx, MAX_PAYLOAD).await.unwrap(); + let Payload { data, .. } = v3_string_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); assert_eq!(data, PAYLOAD.as_bytes()); } #[cfg(feature = "v3")] #[tokio::test] async fn encode_v3b64_payload() { + let mut vr = dummy_volatile_rx(); const PAYLOAD: &str = "7:4hello€10:b4AQIDBA==7:4hello€"; let (tx, rx) = tokio::sync::mpsc::channel::(10); let mutex = Mutex::new(PeekableReceiver::new(rx)); @@ -380,7 +542,7 @@ mod tests { .unwrap(); let Payload { data, has_binary, .. - } = v3_string_encoder(rx, MAX_PAYLOAD).await.unwrap(); + } = v3_string_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); assert_eq!(data, PAYLOAD.as_bytes()); assert!(!has_binary); } @@ -388,6 +550,7 @@ mod tests { #[cfg(feature = "v3")] #[tokio::test] async fn max_payload_v3_b64() { + let mut vr = dummy_volatile_rx(); const MAX_PAYLOAD: u64 = 10; let (tx, rx) = tokio::sync::mpsc::channel::(10); @@ -404,18 +567,22 @@ mod tests { .unwrap(); { let rx = mutex.lock().await; - let Payload { data, .. } = v3_string_encoder(rx, MAX_PAYLOAD).await.unwrap(); + let Payload { data, .. } = v3_string_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); assert_eq!(data, "7:4hello€".as_bytes()); } { let rx = mutex.lock().await; - let Payload { data, .. } = v3_string_encoder(rx, MAX_PAYLOAD + 10).await.unwrap(); + let Payload { data, .. } = v3_string_encoder(rx, MAX_PAYLOAD + 10, &mut vr) + .await + .unwrap(); assert_eq!(data, "10:b4AQIDBA==".as_bytes()); } { // Next call drains one of the remaining Message packets. let rx = mutex.lock().await; - let Payload { data, .. } = v3_string_encoder(rx, MAX_PAYLOAD + 10).await.unwrap(); + let Payload { data, .. } = v3_string_encoder(rx, MAX_PAYLOAD + 10, &mut vr) + .await + .unwrap(); assert_eq!(data, "7:4hello€".as_bytes()); } } @@ -423,6 +590,7 @@ mod tests { #[cfg(feature = "v3")] #[tokio::test] async fn encode_v3binary_payload() { + let mut vr = dummy_volatile_rx(); const PAYLOAD: [u8; 20] = [ 0, 9, 255, 52, 104, 101, 108, 108, 111, 226, 130, 172, 1, 5, 255, 4, 1, 2, 3, 4, ]; @@ -438,7 +606,7 @@ mod tests { .unwrap(); let Payload { data, has_binary, .. - } = v3_binary_encoder(rx, MAX_PAYLOAD).await.unwrap(); + } = v3_binary_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); assert_eq!(*data, PAYLOAD); assert!(has_binary); } @@ -446,6 +614,7 @@ mod tests { #[cfg(feature = "v3")] #[tokio::test] async fn encode_v3binary_payload_parked_poll_multi_packet_batch() { + let mut vr = dummy_volatile_rx(); // When the v3 binary encoder is parked on an empty buffer and a // multi-packet batch arrives, every packet must be encoded with the // same framing (binary framing if any packet in the batch is binary). @@ -467,7 +636,7 @@ mod tests { }); let Payload { data, has_binary, .. - } = v3_binary_encoder(rx, MAX_PAYLOAD).await.unwrap(); + } = v3_binary_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); assert_eq!(*data, PAYLOAD); assert!(has_binary); } @@ -475,6 +644,7 @@ mod tests { #[cfg(feature = "v3")] #[tokio::test] async fn max_payload_v3_binary() { + let mut vr = dummy_volatile_rx(); const MAX_PAYLOAD: u64 = 25; const PAYLOAD: [u8; 23] = [ @@ -495,13 +665,398 @@ mod tests { .unwrap(); { let rx = mutex.lock().await; - let Payload { data, .. } = v3_binary_encoder(rx, MAX_PAYLOAD).await.unwrap(); + let Payload { data, .. } = v3_binary_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); assert_eq!(*data, PAYLOAD); } { let rx = mutex.lock().await; - let Payload { data, .. } = v3_binary_encoder(rx, MAX_PAYLOAD).await.unwrap(); + let Payload { data, .. } = v3_binary_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); assert_eq!(data, "7:4hello€7:4hello€".as_bytes()); } } + + fn make_volatile_chan( + packets: PacketBuf, + ) -> ( + tokio::sync::watch::Sender>, + tokio::sync::watch::Receiver>, + ) { + let (tx, rx) = tokio::sync::watch::channel(None); + tx.send(Some(packets)).unwrap(); + (tx, rx) + } + + #[tokio::test] + async fn v4_volatile_before_normal() { + let (tx, rx) = tokio::sync::mpsc::channel::(10); + let mutex = Mutex::new(PeekableReceiver::new(rx)); + tx.try_send(smallvec::smallvec![Packet::Message("foo".into())]) + .unwrap(); + tx.try_send(smallvec::smallvec![Packet::Message("bar".into())]) + .unwrap(); + let (_volatile_tx, mut vr) = + make_volatile_chan(smallvec::smallvec![Packet::Message("v1".into())]); + + let rx = mutex.lock().await; + let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); + assert_eq!(data, "4v1\x1e4foo\x1e4bar".as_bytes()); + } + + #[tokio::test] + async fn v4_normal_before_volatile() { + let (tx, rx) = tokio::sync::mpsc::channel::(10); + let mutex = Mutex::new(PeekableReceiver::new(rx)); + tx.try_send(smallvec::smallvec![Packet::Message("foo".into())]) + .unwrap(); + tx.try_send(smallvec::smallvec![Packet::Message("bar".into())]) + .unwrap(); + let (_volatile_tx, mut vr) = + make_volatile_chan(smallvec::smallvec![Packet::Message("v1".into())]); + + let rx = mutex.lock().await; + let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); + assert_eq!(data, "4v1\x1e4foo\x1e4bar".as_bytes()); + } + + #[tokio::test] + async fn v4_volatile_overwrite() { + let (tx, rx) = tokio::sync::mpsc::channel::(10); + let mutex = Mutex::new(PeekableReceiver::new(rx)); + tx.try_send(smallvec::smallvec![Packet::Message("normal".into())]) + .unwrap(); + let (volatile_tx, mut vr) = tokio::sync::watch::channel(None); + volatile_tx + .send(Some(smallvec::smallvec![Packet::Message("dropped".into())])) + .unwrap(); + volatile_tx + .send(Some(smallvec::smallvec![Packet::Message("kept".into())])) + .unwrap(); + + let rx = mutex.lock().await; + let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); + assert_eq!(data, "4kept\x1e4normal".as_bytes()); + } + + #[tokio::test] + async fn v4_volatile_mid_encoding() { + let (main_tx, rx) = tokio::sync::mpsc::channel::(10); + let mutex = Mutex::new(PeekableReceiver::new(rx)); + let (volatile_tx, mut vr) = tokio::sync::watch::channel(None); + main_tx + .try_send(smallvec::smallvec![Packet::Message("first".into())]) + .unwrap(); + main_tx + .try_send(smallvec::smallvec![Packet::Message("second".into())]) + .unwrap(); + + let rx = mutex.lock().await; + volatile_tx + .send(Some(smallvec::smallvec![Packet::Message("mid".into())])) + .unwrap(); + + let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); + assert_eq!(data, "4mid\x1e4first\x1e4second".as_bytes()); + } + + #[tokio::test] + async fn v4_volatile_only_no_drain() { + let (_main_tx, rx) = tokio::sync::mpsc::channel::(10); + let mutex = Mutex::new(PeekableReceiver::new(rx)); + let (_volatile_tx, mut vr) = + make_volatile_chan(smallvec::smallvec![Packet::Message("v_only".into())]); + drop(_main_tx); + + let rx = mutex.lock().await; + let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); + assert_eq!(data, "4v_only".as_bytes()); + } + + #[cfg(feature = "v3")] + #[tokio::test] + async fn v3_string_volatile_mixed() { + let (tx, rx) = tokio::sync::mpsc::channel::(10); + let mutex = Mutex::new(PeekableReceiver::new(rx)); + tx.try_send(smallvec::smallvec![Packet::Message("foo".into())]) + .unwrap(); + tx.try_send(smallvec::smallvec![Packet::Message("bar".into())]) + .unwrap(); + let (_volatile_tx, mut vr) = + make_volatile_chan(smallvec::smallvec![Packet::Message("v1".into())]); + + let rx = mutex.lock().await; + let Payload { + data, has_binary, .. + } = v3_string_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); + assert_eq!(data, "3:4v14:4foo4:4bar".as_bytes()); + assert!(!has_binary); + } + + #[cfg(feature = "v3")] + #[tokio::test] + async fn v3_string_volatile_overwrite() { + let (tx, rx) = tokio::sync::mpsc::channel::(10); + let mutex = Mutex::new(PeekableReceiver::new(rx)); + tx.try_send(smallvec::smallvec![Packet::Message("normal".into())]) + .unwrap(); + let (volatile_tx, mut vr) = tokio::sync::watch::channel(None); + volatile_tx + .send(Some(smallvec::smallvec![Packet::Message("drop1".into())])) + .unwrap(); + volatile_tx + .send(Some(smallvec::smallvec![Packet::Message("keep1".into())])) + .unwrap(); + + let rx = mutex.lock().await; + let Payload { + data, has_binary, .. + } = v3_string_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); + assert_eq!(data, "6:4keep17:4normal".as_bytes()); + assert!(!has_binary); + } + + #[cfg(feature = "v3")] + #[tokio::test] + async fn v3_binary_volatile_mixed() { + let (tx, rx) = tokio::sync::mpsc::channel::(10); + let mutex = Mutex::new(PeekableReceiver::new(rx)); + tx.try_send(smallvec::smallvec![Packet::Message("foo".into())]) + .unwrap(); + let (_volatile_tx, mut vr) = make_volatile_chan(smallvec::smallvec![Packet::BinaryV3( + Bytes::from_static(&[1, 2, 3, 4]) + )]); + + let rx = mutex.lock().await; + let Payload { + data, has_binary, .. + } = v3_binary_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); + assert!(has_binary); + assert_eq!( + &data[..8], + &[0x01, 0x05, 0xff, 0x04, 0x01, 0x02, 0x03, 0x04][..] + ); + assert_eq!(&data[8..], &[0x00, 0x04, 0xff, 0x34, 0x66, 0x6f, 0x6f][..]); + } + + #[cfg(feature = "v3")] + #[tokio::test] + async fn v3_binary_volatile_overwrite() { + let (tx, rx) = tokio::sync::mpsc::channel::(10); + let mutex = Mutex::new(PeekableReceiver::new(rx)); + tx.try_send(smallvec::smallvec![Packet::Message("normal".into())]) + .unwrap(); + let (volatile_tx, mut vr) = tokio::sync::watch::channel(None); + volatile_tx + .send(Some(smallvec::smallvec![Packet::Message("drop_v".into())])) + .unwrap(); + volatile_tx + .send(Some(smallvec::smallvec![Packet::BinaryV3( + Bytes::from_static(&[9, 9]) + )])) + .unwrap(); + + let rx = mutex.lock().await; + let Payload { + data, has_binary, .. + } = v3_binary_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); + assert!(has_binary); + assert_eq!(&data[..6], &[0x01, 0x03, 0xff, 0x04, 0x09, 0x09][..]); + assert_eq!( + &data[6..], + &[0x00, 0x07, 0xff, 0x34, 0x6e, 0x6f, 0x72, 0x6d, 0x61, 0x6c][..] + ); + } + + #[cfg(feature = "v3")] + #[tokio::test] + async fn v3_binary_volatile_determines_has_binary() { + let (tx, rx) = tokio::sync::mpsc::channel::(10); + let mutex = Mutex::new(PeekableReceiver::new(rx)); + tx.try_send(smallvec::smallvec![Packet::Message("foo".into())]) + .unwrap(); + let (_volatile_tx, mut vr) = make_volatile_chan(smallvec::smallvec![Packet::BinaryV3( + Bytes::from_static(&[1, 2, 3]) + )]); + + let rx = mutex.lock().await; + let Payload { + data: _, + has_binary, + .. + } = v3_binary_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); + assert!(has_binary); + } + + #[tokio::test] + async fn v4_volatile_arrives_during_parked_poll() { + let (tx, rx) = tokio::sync::mpsc::channel::(10); + let mutex = Mutex::new(PeekableReceiver::new(rx)); + let (volatile_tx, mut vr) = tokio::sync::watch::channel(None); + let volatile_tx = std::sync::Arc::new(std::sync::Mutex::new(volatile_tx)); + + let tx_clone = tx.clone(); + let vt = volatile_tx.clone(); + tokio::spawn(async move { + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + vt.lock() + .unwrap() + .send(Some(smallvec::smallvec![Packet::Message( + "volatile".into() + )])) + .unwrap(); + tx_clone + .try_send(smallvec::smallvec![Packet::Message("normal".into())]) + .unwrap(); + }); + + let rx = mutex.lock().await; + let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); + // After the fix: volatile arrived during park, now captured same payload + assert_eq!(data, "4volatile\x1e4normal".as_bytes()); + + drop(tx); + let rx = mutex.lock().await; + let result = v4_encoder(rx, MAX_PAYLOAD, &mut vr).await; + assert!(result.is_err()); // no more data + } + + #[tokio::test] + async fn v4_volatile_pushes_past_max_payload() { + const SMALL_LIMIT: u64 = 12; + let (tx, rx) = tokio::sync::mpsc::channel::(10); + let mutex = Mutex::new(PeekableReceiver::new(rx)); + let (_volatile_tx, mut vr) = + make_volatile_chan(smallvec::smallvec![Packet::Message("big_vol".into())]); + tx.try_send(smallvec::smallvec![Packet::Message("normal".into())]) + .unwrap(); + tx.try_send(smallvec::smallvec![Packet::Message("extra".into())]) + .unwrap(); + + let rx = mutex.lock().await; + let Payload { data, .. } = v4_encoder(rx, SMALL_LIMIT, &mut vr).await.unwrap(); + assert_eq!(data, "4big_vol".as_bytes()); + + let rx = mutex.lock().await; + let Payload { data, .. } = v4_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); + assert_eq!(data, "4normal\x1e4extra".as_bytes()); + } + + #[cfg(feature = "v3")] + #[tokio::test] + async fn v3_string_volatile_during_parked_poll() { + let (tx, rx) = tokio::sync::mpsc::channel::(10); + let mutex = Mutex::new(PeekableReceiver::new(rx)); + let (volatile_tx, mut vr) = tokio::sync::watch::channel(None); + let volatile_tx = std::sync::Arc::new(std::sync::Mutex::new(volatile_tx)); + + let tx_clone = tx.clone(); + let vt = volatile_tx.clone(); + tokio::spawn(async move { + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + vt.lock() + .unwrap() + .send(Some(smallvec::smallvec![Packet::Message("v".into())])) + .unwrap(); + tx_clone + .try_send(smallvec::smallvec![Packet::Message("n".into())]) + .unwrap(); + }); + + let rx = mutex.lock().await; + let Payload { data, .. } = v3_string_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); + assert_eq!(data, "2:4v2:4n".as_bytes()); + + drop(tx); + let rx = mutex.lock().await; + let result = v3_string_encoder(rx, MAX_PAYLOAD, &mut vr).await; + assert!(result.is_err()); + } + + #[cfg(feature = "v3")] + #[tokio::test] + async fn v3_string_volatile_max_payload() { + const SMALL_LIMIT: u64 = 12; + let (tx, rx) = tokio::sync::mpsc::channel::(10); + let mutex = Mutex::new(PeekableReceiver::new(rx)); + let (_volatile_tx, mut vr) = + make_volatile_chan(smallvec::smallvec![Packet::Message("big".into())]); + tx.try_send(smallvec::smallvec![Packet::Message("normal".into())]) + .unwrap(); + tx.try_send(smallvec::smallvec![Packet::Message("extra".into())]) + .unwrap(); + + let rx = mutex.lock().await; + let Payload { + data, has_binary, .. + } = v3_string_encoder(rx, SMALL_LIMIT, &mut vr).await.unwrap(); + assert_eq!(data, "4:4big".as_bytes()); + assert!(!has_binary); + + let rx = mutex.lock().await; + let Payload { data, .. } = v3_string_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); + assert_eq!(data, "7:4normal6:4extra".as_bytes()); + } + + #[cfg(feature = "v3")] + #[tokio::test] + async fn v3_binary_volatile_during_parked_poll() { + let (tx, rx) = tokio::sync::mpsc::channel::(10); + let mutex = Mutex::new(PeekableReceiver::new(rx)); + let (volatile_tx, mut vr) = tokio::sync::watch::channel(None); + let volatile_tx = std::sync::Arc::new(std::sync::Mutex::new(volatile_tx)); + + let tx_clone = tx.clone(); + let vt = volatile_tx.clone(); + tokio::spawn(async move { + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + vt.lock() + .unwrap() + .send(Some(smallvec::smallvec![Packet::BinaryV3( + Bytes::from_static(&[1, 2]) + )])) + .unwrap(); + tx_clone + .try_send(smallvec::smallvec![Packet::Message("n".into())]) + .unwrap(); + }); + + let rx = mutex.lock().await; + let Payload { + data, has_binary, .. + } = v3_binary_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); + // After fix: volatile binary + normal message both captured, has_binary=true + assert!(has_binary); + // Volatile binary: 0x01 0x03 0xFF 0x04 0x01 0x02 (6 bytes) + // Normal message in binary frame: 0x00 0x02 0xFF 0x34 0x6E (5 bytes) + assert_eq!(data.len(), 11); + assert_eq!(&data[..6], &[0x01, 0x03, 0xff, 0x04, 0x01, 0x02][..]); + assert_eq!(&data[6..], &[0x00, 0x02, 0xff, 0x34, 0x6e][..]); + + drop(tx); + let rx = mutex.lock().await; + let result = v3_binary_encoder(rx, MAX_PAYLOAD, &mut vr).await; + assert!(result.is_err()); + } + + #[cfg(feature = "v3")] + #[tokio::test] + async fn v3_binary_volatile_max_payload() { + const SMALL_LIMIT: u64 = 15; + let (tx, rx) = tokio::sync::mpsc::channel::(10); + let mutex = Mutex::new(PeekableReceiver::new(rx)); + let (_volatile_tx, mut vr) = + make_volatile_chan(smallvec::smallvec![Packet::Message("big_volatile".into())]); + tx.try_send(smallvec::smallvec![Packet::Message("after".into())]) + .unwrap(); + + let rx = mutex.lock().await; + let Payload { + data, has_binary, .. + } = v3_binary_encoder(rx, SMALL_LIMIT, &mut vr).await.unwrap(); + assert!(!has_binary); + assert_eq!(data, "13:4big_volatile".as_bytes()); + + let rx = mutex.lock().await; + let Payload { data, .. } = v3_binary_encoder(rx, MAX_PAYLOAD, &mut vr).await.unwrap(); + assert_eq!(data, "6:4after".as_bytes()); + } } diff --git a/crates/engineioxide/src/transport/polling/payload/mod.rs b/crates/engineioxide/src/transport/polling/payload/mod.rs index db62cc96..beafbdb4 100644 --- a/crates/engineioxide/src/transport/polling/payload/mod.rs +++ b/crates/engineioxide/src/transport/polling/payload/mod.rs @@ -70,21 +70,22 @@ pub async fn encoder( #[allow(unused_variables)] protocol: ProtocolVersion, #[cfg(feature = "v3")] supports_binary: bool, max_payload: u64, + volatile_rx: &mut tokio::sync::watch::Receiver>, ) -> Result { #[cfg(feature = "v3")] { match protocol { - ProtocolVersion::V4 => encoder::v4_encoder(rx, max_payload).await, + ProtocolVersion::V4 => encoder::v4_encoder(rx, max_payload, volatile_rx).await, ProtocolVersion::V3 if supports_binary => { - encoder::v3_binary_encoder(rx, max_payload).await + encoder::v3_binary_encoder(rx, max_payload, volatile_rx).await } - ProtocolVersion::V3 => encoder::v3_string_encoder(rx, max_payload).await, + ProtocolVersion::V3 => encoder::v3_string_encoder(rx, max_payload, volatile_rx).await, } } #[cfg(not(feature = "v3"))] { - encoder::v4_encoder(rx, max_payload).await + encoder::v4_encoder(rx, max_payload, volatile_rx).await } } diff --git a/crates/engineioxide/src/transport/ws.rs b/crates/engineioxide/src/transport/ws.rs index 565a9940..ec7451ef 100644 --- a/crates/engineioxide/src/transport/ws.rs +++ b/crates/engineioxide/src/transport/ws.rs @@ -249,6 +249,7 @@ async fn forward_to_socket( S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let mut internal_rx = socket.internal_rx.try_lock().unwrap(); + let mut volatile_rx = socket.volatile_rx.clone(); // map a packet to a websocket message // It is declared as a macro rather than a closure to avoid ownership issues @@ -288,17 +289,35 @@ async fn forward_to_socket( }; } - while let Some(items) = internal_rx.recv().await { + loop { + // Priority: wait for and drain the main channel. + // Volatile events are checked after each main-channel cycle. + let items = match internal_rx.recv().await { + Some(items) => items, + None => break, + }; for item in items { map_fn!(item); } - // For every available packet we continue to send until the channel is drained while let Ok(items) = internal_rx.try_recv() { for item in items { map_fn!(item); } } + // Check volatile channel after main is drained. + // `has_changed` is non-blocking; if no volatile data is + // pending we simply go back to waiting for the main channel. + if volatile_rx.has_changed().unwrap_or(false) { + let val = volatile_rx.borrow_and_update().clone(); + if let Some(packets) = val { + #[cfg(feature = "tracing")] + tracing::info!(sid = ?socket.id, "ws volatile check: flushing {:?}", &packets); + for item in packets { + map_fn!(item); + } + } + } tx.flush().await.ok(); } } diff --git a/crates/engineioxide/tests/volatile.rs b/crates/engineioxide/tests/volatile.rs new file mode 100644 index 00000000..743b2807 --- /dev/null +++ b/crates/engineioxide/tests/volatile.rs @@ -0,0 +1,112 @@ +use std::sync::Arc; + +use bytes::Bytes; +use engineioxide::{ + Str, + handler::EngineIoHandler, + socket::{DisconnectReason, Socket}, +}; +use tokio::sync::mpsc; + +mod fixture; +use fixture::{create_polling_connection, create_server, send_req}; + +#[derive(Debug, Clone)] +struct VolatileHandler { + socket_tx: mpsc::Sender>>, +} +impl EngineIoHandler for VolatileHandler { + type Data = (); + fn on_connect(self: Arc, socket: Arc>) { + self.socket_tx.try_send(socket).ok(); + } + fn on_disconnect(&self, _socket: Arc>, _reason: DisconnectReason) {} + fn on_message(self: &Arc, msg: Str, socket: Arc>) { + if msg == "trigger_volatile" { + socket.emit_volatile("volatile_response"); + } else if msg == "echo" { + socket.emit(msg).ok(); + } + } + fn on_binary(self: &Arc, data: Bytes, socket: Arc>) { + socket.emit_binary(data).ok(); + } +} + +#[tokio::test] +async fn volatile_message_arrives_via_polling() { + let (socket_tx, _socket_rx) = mpsc::channel(10); + let mut svc = create_server(VolatileHandler { socket_tx }).await; + let sid = create_polling_connection(&mut svc).await; + + send_req( + &mut svc, + format!("transport=polling&sid={sid}"), + http::Method::POST, + Some("4trigger_volatile".into()), + ) + .await; + + let response = send_req( + &mut svc, + format!("transport=polling&sid={sid}"), + http::Method::GET, + None, + ) + .await; + // send_req skips the first character (packet type '4') + assert_eq!(response, "volatile_response"); +} + +#[tokio::test] +async fn mixed_volatile_and_normal_via_polling() { + let (socket_tx, mut socket_rx) = mpsc::channel(10); + let mut svc = create_server(VolatileHandler { socket_tx }).await; + let sid = create_polling_connection(&mut svc).await; + + let socket = socket_rx + .recv() + .await + .expect("socket not received from on_connect"); + + // Send a normal message AND a volatile + socket.emit("normal_msg").ok(); + assert!(socket.emit_volatile("volatile_msg")); + + let response = send_req( + &mut svc, + format!("transport=polling&sid={sid}"), + http::Method::GET, + None, + ) + .await; + // Volatile should have priority and appear first, before normal. + // send_req skips only the first character (the leading '4' of the volatile). + assert_eq!(response, "volatile_msg\x1e4normal_msg"); +} + +#[tokio::test] +async fn volatile_overwrite_only_latest_survives() { + let (socket_tx, mut socket_rx) = mpsc::channel(10); + let mut svc = create_server(VolatileHandler { socket_tx }).await; + let sid = create_polling_connection(&mut svc).await; + + let socket = socket_rx + .recv() + .await + .expect("socket not received from on_connect"); + + assert!(socket.emit_volatile("dropped")); + assert!(socket.emit_volatile("kept")); + assert!(socket.emit("normal").is_ok()); + + let response = send_req( + &mut svc, + format!("transport=polling&sid={sid}"), + http::Method::GET, + None, + ) + .await; + // After send_req's skip(1): "kept" then \x1e separator then "4normal" + assert_eq!(response, "kept\x1e4normal"); +} diff --git a/crates/socketioxide-core/src/adapter/mod.rs b/crates/socketioxide-core/src/adapter/mod.rs index 7e055ed9..e5f5447d 100644 --- a/crates/socketioxide-core/src/adapter/mod.rs +++ b/crates/socketioxide-core/src/adapter/mod.rs @@ -38,6 +38,11 @@ pub enum BroadcastFlags { Local = 0x01, /// Broadcast to all clients except the sender Broadcast = 0x02, + /// The event may be dropped if the client is not ready to receive it + /// (e.g. the connection is buffering or not connected). + /// This is useful for events that are not critical, like position updates in a game. + /// See [socket.io volatile events](https://socket.io/docs/v4/emitting-events/#volatile-events). + Volatile = 0x04, } /// Options that can be used to modify the behavior of the broadcast methods. @@ -202,6 +207,12 @@ pub trait SocketEmitter: Send + Sync + 'static { fn get_remote_sockets(&self, sids: BroadcastIter<'_>) -> Vec; /// Send data to the list of socket ids. fn send_many(&self, sids: BroadcastIter<'_>, data: Value) -> Result<(), Vec>; + /// Send data to the list of socket ids with volatile semantics. + /// Errors are silently discarded; packets may be dropped if the + /// transport is not ready. + fn send_many_volatile(&self, sids: BroadcastIter<'_>, data: Value) { + _ = self.send_many(sids, data); + } /// Send data to the list of socket ids and get a stream of acks and the number of expected acks. fn send_many_with_ack( &self, @@ -430,8 +441,14 @@ impl CoreLocalAdapter { return Ok(()); } + let is_volatile = opts.has_flag(BroadcastFlags::Volatile); let data = self.emitter.parser().encode(packet); - self.emitter.send_many(sids, data) + if is_volatile { + self.emitter.send_many_volatile(sids, data); + Ok(()) + } else { + self.emitter.send_many(sids, data) + } } /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`] and return a stream of ack responses. diff --git a/crates/socketioxide/docs/operators/volatile.md b/crates/socketioxide/docs/operators/volatile.md new file mode 100644 index 00000000..42023754 --- /dev/null +++ b/crates/socketioxide/docs/operators/volatile.md @@ -0,0 +1,28 @@ +# Set the volatile flag for the emit. When set, the event may be dropped +if the client is not ready to receive it (e.g. the connection is buffering +or not connected). This is useful for events that are not critical, such +as position updates in a game. + +Because volatile events use a separate channel that bypasses the main +mpsc buffer, they may arrive **out of order** relative to regular events +emitted around the same time. Only use volatile when ordering relative to +regular events is not important. + +See [socket.io volatile events](https://socket.io/docs/v4/emitting-events/#volatile-events). + +# Example +```rust +# use socketioxide::{SocketIo, extract::*}; +# use serde::Serialize; +#[derive(Serialize)] +struct GameState { x: f64, y: f64 } + +let (_, io) = SocketIo::new_svc(); +io.ns("/", async |socket: SocketRef| { + // Direct volatile emit — may be dropped if the socket is not ready + socket.volatile().emit("position", &GameState { x: 1.0, y: 2.0 }).ok(); + + // Volatile broadcast to a room + socket.volatile().to("game_room").emit("update", &42).await.ok(); +}); +``` diff --git a/crates/socketioxide/src/io.rs b/crates/socketioxide/src/io.rs index 15618278..045822e3 100644 --- a/crates/socketioxide/src/io.rs +++ b/crates/socketioxide/src/io.rs @@ -591,6 +591,16 @@ impl SocketIo { self.get_default_op() } + /// _Alias for `io.of("/").unwrap().volatile()`_. If the **default namespace "/" is not found** this fn will panic! + /// + /// Sets the volatile flag on the broadcast operators. Volatile events may be dropped + /// if the client is not ready to receive them. + /// See [socket.io volatile events](https://socket.io/docs/v4/emitting-events/#volatile-events). + #[inline] + pub fn volatile(&self) -> BroadcastOperators { + self.get_default_op().volatile() + } + #[cfg(feature = "state")] pub(crate) fn get_state(&self) -> Option { self.0.state.try_get::().cloned() diff --git a/crates/socketioxide/src/ns.rs b/crates/socketioxide/src/ns.rs index 78236875..c4b94a0f 100644 --- a/crates/socketioxide/src/ns.rs +++ b/crates/socketioxide/src/ns.rs @@ -227,6 +227,9 @@ trait InnerEmitter: Send + Sync + 'static { fn get_all_sids(&self, filter: &dyn Fn(&Sid) -> bool) -> Vec; /// Send data to the list of socket ids. fn send_many(&self, sids: BroadcastIter<'_>, data: Value) -> Result<(), Vec>; + /// Send data to the list of socket ids with volatile semantics. + /// Errors are silently discarded. + fn send_many_volatile(&self, sids: BroadcastIter<'_>, data: Value); /// Send data to the list of socket ids and get a stream of acks. fn send_many_with_ack( &self, @@ -268,6 +271,15 @@ impl InnerEmitter for Namespace { if errs.is_empty() { Ok(()) } else { Err(errs) } } + fn send_many_volatile(&self, sids: BroadcastIter<'_>, data: Value) { + let sockets = self.sockets.read().unwrap(); + for sid in sids { + if let Some(socket) = sockets.get(&sid) { + socket.send_raw_volatile(data.clone()); + } + } + } + fn send_many_with_ack( &self, sids: BroadcastIter<'_>, @@ -354,6 +366,12 @@ impl SocketEmitter for Emitter { } } + fn send_many_volatile(&self, sids: BroadcastIter<'_>, data: Value) { + if let Some(ns) = self.ns.upgrade() { + ns.send_many_volatile(sids, data); + } + } + fn send_many_with_ack( &self, sids: BroadcastIter<'_>, diff --git a/crates/socketioxide/src/operators.rs b/crates/socketioxide/src/operators.rs index efdde4a8..c92bf534 100644 --- a/crates/socketioxide/src/operators.rs +++ b/crates/socketioxide/src/operators.rs @@ -30,6 +30,7 @@ use socketioxide_core::{ /// Chainable operators to configure the message to be sent. pub struct ConfOperators<'a, A: Adapter = LocalAdapter> { timeout: Option, + volatile: bool, socket: &'a Socket, } /// Chainable operators to select sockets to send a message to and to configure the message to be sent. @@ -42,7 +43,10 @@ pub struct BroadcastOperators { impl From> for BroadcastOperators { fn from(conf: ConfOperators<'_, A>) -> Self { - let opts = BroadcastOptions::new(conf.socket.id); + let mut opts = BroadcastOptions::new(conf.socket.id); + if conf.volatile { + opts.add_flag(BroadcastFlags::Volatile); + } Self { timeout: conf.timeout, ns: conf.socket.ns.clone(), @@ -57,6 +61,7 @@ impl<'a, A: Adapter> ConfOperators<'a, A> { pub(crate) fn new(sender: &'a Socket) -> Self { Self { timeout: None, + volatile: false, socket: sender, } } @@ -91,6 +96,12 @@ impl<'a, A: Adapter> ConfOperators<'a, A> { self.timeout = Some(timeout); self } + + #[doc = include_str!("../docs/operators/volatile.md")] + pub fn volatile(mut self) -> Self { + self.volatile = true; + self + } } // ==== impl ConfOperators consume fns ==== @@ -101,6 +112,18 @@ impl ConfOperators<'_, A> { event: impl AsRef, data: &T, ) -> Result<(), SendError> { + if self.volatile { + if !self.socket.connected() { + return Ok(()); + } + let Ok(packet) = self.get_packet(event, data) else { + return Ok(()); + }; + self.socket + .send_raw_volatile(self.socket.parser.encode(packet)); + return Ok(()); + } + use crate::SocketError; use crate::socket::PermitExt; if !self.socket.connected() { @@ -228,6 +251,12 @@ impl BroadcastOperators { self.timeout = Some(timeout); self } + + #[doc = include_str!("../docs/operators/volatile.md")] + pub fn volatile(mut self) -> Self { + self.opts.add_flag(BroadcastFlags::Volatile); + self + } } // ==== impl BroadcastOperators consume fns ==== @@ -336,6 +365,7 @@ impl<'a, A: Adapter> Clone for ConfOperators<'a, A> { fn clone(&self) -> Self { Self { timeout: self.timeout, + volatile: self.volatile, socket: self.socket, } } diff --git a/crates/socketioxide/src/socket.rs b/crates/socketioxide/src/socket.rs index 4d4d7b68..1073cef5 100644 --- a/crates/socketioxide/src/socket.rs +++ b/crates/socketioxide/src/socket.rs @@ -639,6 +639,27 @@ impl Socket { BroadcastOperators::from_sock(self.ns.clone(), self.id, self.parser).broadcast() } + /// Returns a [`ConfOperators`] with the volatile flag set, so that any + /// subsequent `emit()` will drop the event instead of buffering it if the + /// client is not ready to receive it. + /// + /// # Example + /// ``` + /// # use socketioxide::{SocketIo, extract::SocketRef}; + /// # use serde::Serialize; + /// #[derive(Serialize)] + /// struct GameState { x: f64, y: f64 } + /// + /// let (_, io) = SocketIo::new_svc(); + /// io.ns("/", async |socket: SocketRef| { + /// socket.volatile().emit("position", &GameState { x: 1.0, y: 2.0 }).ok(); + /// socket.volatile().to("room1").emit("update", &42).await.ok(); + /// }); + /// ``` + pub fn volatile(&self) -> ConfOperators<'_, A> { + ConfOperators::new(self).volatile() + } + /// # Get the [`SocketIo`] context related to this socket /// /// # Panics @@ -738,6 +759,20 @@ impl Socket { Ok(()) } + pub(crate) fn send_raw_volatile(&self, value: Value) { + match value { + Value::Str(msg, None) => { + self.esocket.emit_volatile(msg); + } + Value::Str(msg, Some(bin_payloads)) => { + self.esocket.emit_many_volatile(msg, bin_payloads); + } + Value::Bytes(bin) => { + self.esocket.emit_binary_volatile(bin); + } + } + } + pub(crate) fn send_with_ack_permit( &self, mut packet: Packet, diff --git a/crates/socketioxide/tests/volatile.rs b/crates/socketioxide/tests/volatile.rs new file mode 100644 index 00000000..cfc273af --- /dev/null +++ b/crates/socketioxide/tests/volatile.rs @@ -0,0 +1,156 @@ +//! Integration tests for volatile events on socketioxide. +//! Verifies the volatile operator API and that volatile emits +//! silently drop errors rather than propagating them. +mod fixture; +mod utils; + +use fixture::{create_polling_connection, create_server, send_req}; +use http::Method; +use socketioxide::{SocketIo, extract::SocketRef}; + +#[tokio::test] +async fn volatile_emit_returns_ok() { + use serde_json::json; + let (_svc, io) = SocketIo::new_svc(); + + let (tx, rx) = std::sync::mpsc::channel(); + io.ns("/", async move |socket: SocketRef| { + let result = socket.volatile().emit("test", &json!({"key": "val"})); + tx.send(result).unwrap(); + }); + + io.new_dummy_sock("/", ()).await; + assert!(rx.recv().unwrap().is_ok()); +} + +#[tokio::test] +async fn volatile_emit_broadcast_does_not_panic() { + let (_svc, io) = SocketIo::new_svc(); + + let (tx, rx) = std::sync::mpsc::channel(); + io.ns("/", async move |socket: SocketRef| { + // Broadcast with volatile flag should complete without panicking + socket + .within("room") + .volatile() + .emit("event", &"data") + .await + .ok(); + tx.send(()).unwrap(); + }); + + io.new_dummy_sock("/", ()).await; + rx.recv().unwrap(); +} + +#[tokio::test] +async fn io_volatile_composes() { + let (_svc, io) = SocketIo::new_svc(); + io.ns("/", |_: SocketRef| async {}); + + io.new_dummy_sock("/", ()).await; + + // Volatile on the io handle delegates to the default namespace + let _ = io.volatile(); + let _ = io.of("/").unwrap().volatile(); +} + +#[tokio::test] +async fn volatile_emit_on_disconnected_socket_returns_ok() { + use serde_json::json; + let (_svc, io) = SocketIo::new_svc(); + + let (tx, rx) = std::sync::mpsc::channel(); + io.ns("/", { + let tx = tx.clone(); + async move |_: SocketRef| { + tx.send(()).unwrap(); + } + }); + + io.new_dummy_sock("/", ()).await; + rx.recv().unwrap(); + + // At this point the dummy socket's connect handler has run. + // The socket is technically connected — test that volatile + // emit returns Ok(()) without error. + let (tx2, rx2) = std::sync::mpsc::channel(); + io.ns("/test", async move |socket: SocketRef| { + let result = socket.volatile().emit("event", &json!({"data": 42})); + tx2.send(result).unwrap(); + }); + + io.new_dummy_sock("/test", ()).await; + assert!(rx2.recv().unwrap().is_ok()); +} + +#[tokio::test] +async fn volatile_broadcast_arrives_via_polling_transport() { + let (svc, io) = create_server().await; + + io.ns("/", |s: SocketRef| async move { + s.on( + "drawing", + |s: SocketRef, socketioxide::extract::Data::(data)| async move { + s.broadcast().volatile().emit("drawing", &data).await.ok(); + }, + ); + }); + + let sender_sid = create_polling_connection(&svc).await; + let receiver_sid = create_polling_connection(&svc).await; + + // Drain any queued handshake data from the connections + send_req( + &svc, + format!("transport=polling&sid={sender_sid}"), + Method::GET, + None, + ) + .await; + send_req( + &svc, + format!("transport=polling&sid={receiver_sid}"), + Method::GET, + None, + ) + .await; + // Respond to pings to keep sessions alive + send_req( + &svc, + format!("transport=polling&sid={sender_sid}"), + Method::POST, + Some("3".into()), + ) + .await; + send_req( + &svc, + format!("transport=polling&sid={receiver_sid}"), + Method::POST, + Some("3".into()), + ) + .await; + + // Send drawing event from sender + send_req( + &svc, + format!("transport=polling&sid={sender_sid}"), + Method::POST, + Some("42[\"drawing\",\"hello\"]".into()), + ) + .await; + + // Poll receiver — should receive the volatile broadcast + let response = send_req( + &svc, + format!("transport=polling&sid={receiver_sid}"), + Method::GET, + None, + ) + .await; + // send_req skips first char (engine.io message type '4') + assert!( + response.contains("drawing"), + "Expected volatile broadcast with 'drawing', got: {response}" + ); +} diff --git a/examples/Cargo.lock b/examples/Cargo.lock index 24d5140e..a941f4b9 100644 --- a/examples/Cargo.lock +++ b/examples/Cargo.lock @@ -1081,7 +1081,7 @@ dependencies = [ [[package]] name = "engineioxide" -version = "0.17.3" +version = "0.17.5" dependencies = [ "base64", "bytes", @@ -3789,7 +3789,7 @@ dependencies = [ [[package]] name = "socketioxide" -version = "0.18.3" +version = "0.18.4" dependencies = [ "bytes", "engineioxide", @@ -3814,7 +3814,7 @@ dependencies = [ [[package]] name = "socketioxide-core" -version = "0.18.0" +version = "0.18.1" dependencies = [ "arbitrary", "bytes", diff --git a/examples/whiteboard/src/main.rs b/examples/whiteboard/src/main.rs index 52ba121d..c78454ae 100644 --- a/examples/whiteboard/src/main.rs +++ b/examples/whiteboard/src/main.rs @@ -23,7 +23,9 @@ async fn main() -> Result<(), Box> { io.ns("/", async |s: SocketRef| { s.on("drawing", async |s: SocketRef, Data::(data)| { - s.broadcast().emit("drawing", &data).await.unwrap(); + info!("Drawing event received, broadcasting with volatile"); + s.broadcast().volatile().emit("drawing", &data).await.unwrap(); + info!("Volatile broadcast completed"); }); });