diff --git a/quiche/src/recovery/bandwidth.rs b/quiche/src/recovery/bandwidth.rs index 58f1bfcbd4..e02f72f212 100644 --- a/quiche/src/recovery/bandwidth.rs +++ b/quiche/src/recovery/bandwidth.rs @@ -84,6 +84,10 @@ impl std::ops::Mul for Bandwidth { } } +const BITS_NANOS_PER_BYTE: u64 = 8 * NUM_NANOS_PER_SECOND; +const BITS_NANOS_PER_BYTE_U128: u128 = BITS_NANOS_PER_BYTE as u128; +const U64_MAX_U128: u128 = u64::MAX as u128; + impl Bandwidth { pub const fn from_bytes_and_time_delta( bytes: usize, time_delta: Duration, @@ -150,65 +154,65 @@ impl Bandwidth { /// Returns `Duration::ZERO` for infinite or zero bandwidth. /// Saturates to `Duration::from_nanos(u64::MAX)` if the /// calculation would overflow. + #[inline] pub fn transfer_time(&self, bytes: u64) -> Duration { - // Handle infinite bandwidth sentinel: transfer is instantaneous - if self.bits_per_second == u64::MAX { - return Duration::ZERO; - } + let bps = self.bits_per_second; - if self.bits_per_second == 0 { + if bytes == 0 || bps == 0 || bps == u64::MAX { return Duration::ZERO; } - // Fast path: try u64 arithmetic first. At typical packet sizes - // (< 10 KB) and bandwidths, this won't overflow. - if let Some(nanos) = bytes.checked_mul(8 * NUM_NANOS_PER_SECOND) { - return Duration::from_nanos(nanos / self.bits_per_second); + // Fast path: exact u64 calculation. + if bytes <= u64::MAX / BITS_NANOS_PER_BYTE { + let nanos = bytes * BITS_NANOS_PER_BYTE / bps; + return Duration::from_nanos(nanos); } - // Slow path: use u128 for intermediate calculation to avoid overflow. - // At very large byte counts, bytes * 8 * NUM_NANOS_PER_SECOND can - // overflow u64. - let nanos = (bytes as u128) * (8 * NUM_NANOS_PER_SECOND) as u128; - let nanos = nanos / (self.bits_per_second as u128); - - // Saturate to Duration::MAX if result exceeds u64 range. - Duration::from_nanos(nanos.min(u64::MAX as u128) as u64) + // Slow path: exact u128 intermediate, then saturate. + let nanos = (bytes as u128) * BITS_NANOS_PER_BYTE_U128 / (bps as u128); + Duration::from_nanos(nanos.min(U64_MAX_U128) as u64) } - /// Returns the number of bytes that can be sent in - /// `time_period` at this bandwidth. + /// Returns the number of bytes that can be sent in `time_period`. /// - /// Returns `u64::MAX` for infinite bandwidth (unless - /// `time_period` is zero). Saturates to `u64::MAX` if the - /// calculation would overflow. + /// Returns `u64::MAX` for infinite bandwidth and non-zero duration. + /// Saturates to `u64::MAX` on overflow. + #[inline] pub fn to_bytes_per_period(self, time_period: Duration) -> u64 { - // Handle infinite bandwidth sentinel. - if self.bits_per_second == u64::MAX { - if time_period != Duration::ZERO { - return u64::MAX; - } else { - return 0; - } + let bps = self.bits_per_second; + let nanos = time_period.as_nanos(); + + if bps == 0 || nanos == 0 { + return 0; + } + + if bps == u64::MAX { + return u64::MAX; } - // Fast path: try u64 arithmetic first. At typical bandwidths (< 10 - // Gbps) and short time periods (< 1 second), this won't overflow. - if let Ok(time_nanos) = u64::try_from(time_period.as_nanos()) { - if let Some(bits) = self.bits_per_second.checked_mul(time_nanos) { - return bits / (8 * NUM_NANOS_PER_SECOND); + // Fast path: exact u64 calculation. + if nanos <= u64::MAX as u128 { + let nanos_u64 = nanos as u64; + + if bps <= u64::MAX / nanos_u64 { + return bps * nanos_u64 / BITS_NANOS_PER_BYTE; } } - // Slow path: use u128 for intermediate calculation to avoid overflow. - // At high bandwidths (e.g., 10+ Gbps) with non-trivial time periods, - // bits_per_second * time_period.as_nanos() can overflow u64. - let time_nanos = time_period.as_nanos(); - let bits = (self.bits_per_second as u128).saturating_mul(time_nanos); - let bytes = bits / (8 * NUM_NANOS_PER_SECOND) as u128; + // If the final result must exceed u64::MAX, saturate before multiplying. + // + // floor((bps * nanos) / BITS_NANOS_PER_BYTE) > u64::MAX + // iff bps * nanos >= (u64::MAX + 1) * BITS_NANOS_PER_BYTE + let saturation_threshold = ((u64::MAX as u128) + 1) * BITS_NANOS_PER_BYTE_U128; + + if (bps as u128) > (saturation_threshold - 1) / nanos { + return u64::MAX; + } - // Saturate to u64::MAX if result exceeds u64 range. - bytes.min(u64::MAX as u128) as u64 + // Slow path: exact u128 arithmetic once the fast-path bounds no longer + // hold but the final result still fits in u64. + let bytes = (bps as u128) * nanos / BITS_NANOS_PER_BYTE_U128; + bytes as u64 } } diff --git a/quiche/src/recovery/congestion/cubic.rs b/quiche/src/recovery/congestion/cubic.rs index ba9440ae94..493792df27 100644 --- a/quiche/src/recovery/congestion/cubic.rs +++ b/quiche/src/recovery/congestion/cubic.rs @@ -198,13 +198,13 @@ fn on_packet_sent( } fn on_packets_acked( - r: &mut Congestion, bytes_in_flight: usize, packets: &mut Vec, + r: &mut Congestion, bytes_in_flight: usize, packets: &mut [Acked], now: Instant, rtt_stats: &RttStats, ) { r.cubic_state.last_ack_time = Some(now); - for pkt in packets.drain(..) { - on_packet_acked(r, bytes_in_flight, &pkt, now, rtt_stats); + for pkt in packets.iter() { + on_packet_acked(r, bytes_in_flight, pkt, now, rtt_stats); } } diff --git a/quiche/src/recovery/congestion/mod.rs b/quiche/src/recovery/congestion/mod.rs index 8161d2b762..ce76f176f5 100644 --- a/quiche/src/recovery/congestion/mod.rs +++ b/quiche/src/recovery/congestion/mod.rs @@ -237,10 +237,11 @@ impl Congestion { (self.cc_ops.on_packets_acked)( self, bytes_in_flight, - acked, + acked.as_mut_slice(), now, rtt_stats, ); + acked.clear(); } } @@ -257,7 +258,7 @@ pub(crate) struct CongestionControlOps { pub on_packets_acked: fn( r: &mut Congestion, bytes_in_flight: usize, - packets: &mut Vec, + packets: &mut [Acked], now: Instant, rtt_stats: &RttStats, ), diff --git a/quiche/src/recovery/congestion/recovery.rs b/quiche/src/recovery/congestion/recovery.rs index 736854331a..ce86e1ee06 100644 --- a/quiche/src/recovery/congestion/recovery.rs +++ b/quiche/src/recovery/congestion/recovery.rs @@ -458,7 +458,7 @@ impl LegacyRecovery { &self, handshake_status: HandshakeStatus, now: Instant, ) -> (Option, Epoch) { let mut duration = - self.pto() * 2_u32.pow(self.pto_count.min(MAX_PTO_EXPONENT)); + self.pto() * (1_u32 << self.pto_count.min(MAX_PTO_EXPONENT)); // Arm PTO from now when there are no inflight packets. if self.bytes_in_flight.is_zero() { diff --git a/quiche/src/recovery/congestion/reno.rs b/quiche/src/recovery/congestion/reno.rs index 02f58e7b01..87ed246c25 100644 --- a/quiche/src/recovery/congestion/reno.rs +++ b/quiche/src/recovery/congestion/reno.rs @@ -61,11 +61,11 @@ pub fn on_packet_sent( } fn on_packets_acked( - r: &mut Congestion, _bytes_in_flight: usize, packets: &mut Vec, + r: &mut Congestion, _bytes_in_flight: usize, packets: &mut [Acked], now: Instant, rtt_stats: &RttStats, ) { - for pkt in packets.drain(..) { - on_packet_acked(r, &pkt, now, rtt_stats); + for pkt in packets.iter() { + on_packet_acked(r, pkt, now, rtt_stats); } } diff --git a/quiche/src/recovery/gcongestion/recovery.rs b/quiche/src/recovery/gcongestion/recovery.rs index 0372cbc0fd..7a291bb447 100644 --- a/quiche/src/recovery/gcongestion/recovery.rs +++ b/quiche/src/recovery/gcongestion/recovery.rs @@ -615,7 +615,7 @@ impl GRecovery { &self, handshake_status: HandshakeStatus, now: Instant, ) -> (Option, packet::Epoch) { let mut duration = - self.pto() * 2_u32.pow(self.pto_count.min(MAX_PTO_EXPONENT)); + self.pto() * (1_u32 << self.pto_count.min(MAX_PTO_EXPONENT)); // Arm PTO from now when there are no inflight packets. if self.bytes_in_flight.is_zero() { @@ -645,7 +645,7 @@ impl GRecovery { // Include max_ack_delay and backoff for Application Data. duration += self.rtt_stats.max_ack_delay * - 2_u32.pow(self.pto_count.min(MAX_PTO_EXPONENT)); + (1_u32 << self.pto_count.min(MAX_PTO_EXPONENT)); } let new_time = self.epochs[e] diff --git a/quiche/src/stream/mod.rs b/quiche/src/stream/mod.rs index d51bdc60fa..44a19452de 100644 --- a/quiche/src/stream/mod.rs +++ b/quiche/src/stream/mod.rs @@ -954,7 +954,7 @@ impl Ord for StreamPriorityKey { } } -intrusive_adapter!(pub StreamWritablePriorityAdapter = Arc: StreamPriorityKey { writable: RBTreeAtomicLink }); +intrusive_adapter!(pub StreamWritablePriorityAdapter = Arc: StreamPriorityKey { writable => RBTreeAtomicLink }); impl KeyAdapter<'_> for StreamWritablePriorityAdapter { type Key = StreamPriorityKey; @@ -964,7 +964,7 @@ impl KeyAdapter<'_> for StreamWritablePriorityAdapter { } } -intrusive_adapter!(pub StreamReadablePriorityAdapter = Arc: StreamPriorityKey { readable: RBTreeAtomicLink }); +intrusive_adapter!(pub StreamReadablePriorityAdapter = Arc: StreamPriorityKey { readable => RBTreeAtomicLink }); impl KeyAdapter<'_> for StreamReadablePriorityAdapter { type Key = StreamPriorityKey; @@ -974,7 +974,7 @@ impl KeyAdapter<'_> for StreamReadablePriorityAdapter { } } -intrusive_adapter!(pub StreamFlushablePriorityAdapter = Arc: StreamPriorityKey { flushable: RBTreeAtomicLink }); +intrusive_adapter!(pub StreamFlushablePriorityAdapter = Arc: StreamPriorityKey { flushable => RBTreeAtomicLink }); impl KeyAdapter<'_> for StreamFlushablePriorityAdapter { type Key = StreamPriorityKey;