Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 46 additions & 42 deletions quiche/src/recovery/bandwidth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ impl std::ops::Mul<Duration> for Bandwidth {
}
}

const BITS_NANOS_PER_BYTE: u64 = 8 * NUM_NANOS_PER_SECOND;
const BITS_NANOS_PER_BYTE_U128: u128 = BITS_NANOS_PER_BYTE as u128;
const U64_MAX_U128: u128 = u64::MAX as u128;

impl Bandwidth {
pub const fn from_bytes_and_time_delta(
bytes: usize, time_delta: Duration,
Expand Down Expand Up @@ -150,65 +154,65 @@ impl Bandwidth {
/// Returns `Duration::ZERO` for infinite or zero bandwidth.
/// Saturates to `Duration::from_nanos(u64::MAX)` if the
/// calculation would overflow.
#[inline]
pub fn transfer_time(&self, bytes: u64) -> Duration {
// Handle infinite bandwidth sentinel: transfer is instantaneous
if self.bits_per_second == u64::MAX {
return Duration::ZERO;
}
let bps = self.bits_per_second;

if self.bits_per_second == 0 {
if bytes == 0 || bps == 0 || bps == u64::MAX {
return Duration::ZERO;
}

// Fast path: try u64 arithmetic first. At typical packet sizes
// (< 10 KB) and bandwidths, this won't overflow.
if let Some(nanos) = bytes.checked_mul(8 * NUM_NANOS_PER_SECOND) {
return Duration::from_nanos(nanos / self.bits_per_second);
// Fast path: exact u64 calculation.
if bytes <= u64::MAX / BITS_NANOS_PER_BYTE {
let nanos = bytes * BITS_NANOS_PER_BYTE / bps;
return Duration::from_nanos(nanos);
}

// Slow path: use u128 for intermediate calculation to avoid overflow.
// At very large byte counts, bytes * 8 * NUM_NANOS_PER_SECOND can
// overflow u64.
let nanos = (bytes as u128) * (8 * NUM_NANOS_PER_SECOND) as u128;
let nanos = nanos / (self.bits_per_second as u128);

// Saturate to Duration::MAX if result exceeds u64 range.
Duration::from_nanos(nanos.min(u64::MAX as u128) as u64)
// Slow path: exact u128 intermediate, then saturate.
let nanos = (bytes as u128) * BITS_NANOS_PER_BYTE_U128 / (bps as u128);
Duration::from_nanos(nanos.min(U64_MAX_U128) as u64)
}

/// Returns the number of bytes that can be sent in
/// `time_period` at this bandwidth.
/// Returns the number of bytes that can be sent in `time_period`.
///
/// Returns `u64::MAX` for infinite bandwidth (unless
/// `time_period` is zero). Saturates to `u64::MAX` if the
/// calculation would overflow.
/// Returns `u64::MAX` for infinite bandwidth and non-zero duration.
/// Saturates to `u64::MAX` on overflow.
#[inline]
pub fn to_bytes_per_period(self, time_period: Duration) -> u64 {
// Handle infinite bandwidth sentinel.
if self.bits_per_second == u64::MAX {
if time_period != Duration::ZERO {
return u64::MAX;
} else {
return 0;
}
let bps = self.bits_per_second;
let nanos = time_period.as_nanos();

if bps == 0 || nanos == 0 {
return 0;
}

if bps == u64::MAX {
return u64::MAX;
}

// Fast path: try u64 arithmetic first. At typical bandwidths (< 10
// Gbps) and short time periods (< 1 second), this won't overflow.
if let Ok(time_nanos) = u64::try_from(time_period.as_nanos()) {
if let Some(bits) = self.bits_per_second.checked_mul(time_nanos) {
return bits / (8 * NUM_NANOS_PER_SECOND);
// Fast path: exact u64 calculation.
if nanos <= u64::MAX as u128 {
let nanos_u64 = nanos as u64;

if bps <= u64::MAX / nanos_u64 {
return bps * nanos_u64 / BITS_NANOS_PER_BYTE;
}
}

// Slow path: use u128 for intermediate calculation to avoid overflow.
// At high bandwidths (e.g., 10+ Gbps) with non-trivial time periods,
// bits_per_second * time_period.as_nanos() can overflow u64.
let time_nanos = time_period.as_nanos();
let bits = (self.bits_per_second as u128).saturating_mul(time_nanos);
let bytes = bits / (8 * NUM_NANOS_PER_SECOND) as u128;
// If the final result must exceed u64::MAX, saturate before multiplying.
//
// floor((bps * nanos) / BITS_NANOS_PER_BYTE) > u64::MAX
// iff bps * nanos >= (u64::MAX + 1) * BITS_NANOS_PER_BYTE
let saturation_threshold = ((u64::MAX as u128) + 1) * BITS_NANOS_PER_BYTE_U128;

if (bps as u128) > (saturation_threshold - 1) / nanos {
return u64::MAX;
}

// Saturate to u64::MAX if result exceeds u64 range.
bytes.min(u64::MAX as u128) as u64
// Slow path: exact u128 arithmetic once the fast-path bounds no longer
// hold but the final result still fits in u64.
let bytes = (bps as u128) * nanos / BITS_NANOS_PER_BYTE_U128;
bytes as u64
}
}

Expand Down
6 changes: 3 additions & 3 deletions quiche/src/recovery/congestion/cubic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,13 @@ fn on_packet_sent(
}

fn on_packets_acked(
r: &mut Congestion, bytes_in_flight: usize, packets: &mut Vec<Acked>,
r: &mut Congestion, bytes_in_flight: usize, packets: &mut [Acked],
now: Instant, rtt_stats: &RttStats,
) {
r.cubic_state.last_ack_time = Some(now);

for pkt in packets.drain(..) {
on_packet_acked(r, bytes_in_flight, &pkt, now, rtt_stats);
for pkt in packets.iter() {
on_packet_acked(r, bytes_in_flight, pkt, now, rtt_stats);
}
}

Expand Down
5 changes: 3 additions & 2 deletions quiche/src/recovery/congestion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,11 @@ impl Congestion {
(self.cc_ops.on_packets_acked)(
self,
bytes_in_flight,
acked,
acked.as_mut_slice(),
now,
rtt_stats,
);
acked.clear();
}
}

Expand All @@ -257,7 +258,7 @@ pub(crate) struct CongestionControlOps {
pub on_packets_acked: fn(
r: &mut Congestion,
bytes_in_flight: usize,
packets: &mut Vec<Acked>,
packets: &mut [Acked],
now: Instant,
rtt_stats: &RttStats,
),
Expand Down
2 changes: 1 addition & 1 deletion quiche/src/recovery/congestion/recovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ impl LegacyRecovery {
&self, handshake_status: HandshakeStatus, now: Instant,
) -> (Option<Instant>, Epoch) {
let mut duration =
self.pto() * 2_u32.pow(self.pto_count.min(MAX_PTO_EXPONENT));
self.pto() * (1_u32 << self.pto_count.min(MAX_PTO_EXPONENT));

// Arm PTO from now when there are no inflight packets.
if self.bytes_in_flight.is_zero() {
Expand Down
6 changes: 3 additions & 3 deletions quiche/src/recovery/congestion/reno.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ pub fn on_packet_sent(
}

fn on_packets_acked(
r: &mut Congestion, _bytes_in_flight: usize, packets: &mut Vec<Acked>,
r: &mut Congestion, _bytes_in_flight: usize, packets: &mut [Acked],
now: Instant, rtt_stats: &RttStats,
) {
for pkt in packets.drain(..) {
on_packet_acked(r, &pkt, now, rtt_stats);
for pkt in packets.iter() {
on_packet_acked(r, pkt, now, rtt_stats);
}
}

Expand Down
4 changes: 2 additions & 2 deletions quiche/src/recovery/gcongestion/recovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ impl GRecovery {
&self, handshake_status: HandshakeStatus, now: Instant,
) -> (Option<Instant>, packet::Epoch) {
let mut duration =
self.pto() * 2_u32.pow(self.pto_count.min(MAX_PTO_EXPONENT));
self.pto() * (1_u32 << self.pto_count.min(MAX_PTO_EXPONENT));

// Arm PTO from now when there are no inflight packets.
if self.bytes_in_flight.is_zero() {
Expand Down Expand Up @@ -645,7 +645,7 @@ impl GRecovery {

// Include max_ack_delay and backoff for Application Data.
duration += self.rtt_stats.max_ack_delay *
2_u32.pow(self.pto_count.min(MAX_PTO_EXPONENT));
(1_u32 << self.pto_count.min(MAX_PTO_EXPONENT));
}

let new_time = self.epochs[e]
Expand Down
6 changes: 3 additions & 3 deletions quiche/src/stream/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,7 @@ impl Ord for StreamPriorityKey {
}
}

intrusive_adapter!(pub StreamWritablePriorityAdapter = Arc<StreamPriorityKey>: StreamPriorityKey { writable: RBTreeAtomicLink });
intrusive_adapter!(pub StreamWritablePriorityAdapter = Arc<StreamPriorityKey>: StreamPriorityKey { writable => RBTreeAtomicLink });

impl KeyAdapter<'_> for StreamWritablePriorityAdapter {
type Key = StreamPriorityKey;
Expand All @@ -964,7 +964,7 @@ impl KeyAdapter<'_> for StreamWritablePriorityAdapter {
}
}

intrusive_adapter!(pub StreamReadablePriorityAdapter = Arc<StreamPriorityKey>: StreamPriorityKey { readable: RBTreeAtomicLink });
intrusive_adapter!(pub StreamReadablePriorityAdapter = Arc<StreamPriorityKey>: StreamPriorityKey { readable => RBTreeAtomicLink });

impl KeyAdapter<'_> for StreamReadablePriorityAdapter {
type Key = StreamPriorityKey;
Expand All @@ -974,7 +974,7 @@ impl KeyAdapter<'_> for StreamReadablePriorityAdapter {
}
}

intrusive_adapter!(pub StreamFlushablePriorityAdapter = Arc<StreamPriorityKey>: StreamPriorityKey { flushable: RBTreeAtomicLink });
intrusive_adapter!(pub StreamFlushablePriorityAdapter = Arc<StreamPriorityKey>: StreamPriorityKey { flushable => RBTreeAtomicLink });

impl KeyAdapter<'_> for StreamFlushablePriorityAdapter {
type Key = StreamPriorityKey;
Expand Down