diff --git a/system/webrtc/webrtcd.py b/system/webrtc/webrtcd.py index c2bceb5d908f14..afb1818866955a 100755 --- a/system/webrtc/webrtcd.py +++ b/system/webrtc/webrtcd.py @@ -22,6 +22,8 @@ if TYPE_CHECKING: from aiortc.rtcdatachannel import RTCDataChannel import aioice.ice +import aiortc.rtcrtpsender +from aiortc.rtp import RTCP_PSFB_APP, RtcpPsfbPacket, unpack_remb_fci from openpilot.system.webrtc.models import StreamRequestBody from openpilot.system.webrtc.schema import generate_field @@ -51,6 +53,17 @@ def _primary_host_addresses(use_ipv4: bool, use_ipv6: bool) -> list[str]: return [primary, ] aioice.ice.get_host_addresses = _primary_host_addresses +# aiortc patch: capture the browser's Receiver Estimated Maximum Bitrate on each sender +_handle_rtcp_packet = aiortc.rtcrtpsender.RTCRtpSender._handle_rtcp_packet +async def _handle_rtcp_packet_with_remb(self, packet): + if isinstance(packet, RtcpPsfbPacket) and packet.fmt == RTCP_PSFB_APP: + with contextlib.suppress(ValueError): + bitrate, ssrcs = unpack_remb_fci(packet.fci) + if self._ssrc in ssrcs: + self._remb_bitrate = bitrate + return await _handle_rtcp_packet(self, packet) +aiortc.rtcrtpsender.RTCRtpSender._handle_rtcp_packet = _handle_rtcp_packet_with_remb + class AsyncTaskRunner: def __init__(self): @@ -163,13 +176,20 @@ async def add_services_if_needed(self, services): class LivestreamBitrateController(AsyncTaskRunner): - bitrates = [500_000, 1_500_000, int(os.environ.get("STREAM_BITRATE", 5_000_000))] - label_to_bitrate = { "high": bitrates[2], "med": bitrates[1], "low": bitrates[0]} - sample_interval = 0.2 - high_level = 0.1 # drop immediately - med_level = 0.05 # drop after # of samples - low_level = 0 # raise after # of samples - down_samples = 5 # 1s + bitrates = [ + 500_000, + 1_000_000, + 1_500_000, + 2_500_000, + 4_000_000, + 5_000_000, + ] + label_to_bitrate = { "high": bitrates[5], "med": bitrates[2], "low": bitrates[0]} + sample_interval = 0.5 + higher_factor = 1.5 + lower_factor = 0.9 + loss_threshold = 0.05 + backoff_steps = 2 param_name = "LivestreamEncoderBitrate" def __init__(self, peer_connection: Any, params: Params, enabled: bool = True): @@ -177,11 +197,9 @@ def __init__(self, peer_connection: Any, params: Params, enabled: bool = True): self.pc = peer_connection self.params = params - self.level = 2 + self.level = 5 self._publish(self.bitrates[self.level]) - self.prev_lost, self.prev_sent = None, None - self.counter = 0 - self.up_samples = 5 # 1s + self.backoff = 0 self._auto = True self._enabled = enabled @@ -191,44 +209,41 @@ def enable(self, enable: bool): async def run(self): while True: await asyncio.sleep(self.sample_interval) - if not self._enabled: - continue - if not self._auto: + if not self._enabled or not self._auto: continue - loss_rate = await self._sample() - if loss_rate is None: + estimate = self._bandwidth_estimate() + loss = await self._packet_loss() + if estimate is None or loss is None: continue - if loss_rate >= self.med_level and self.level > 0: - self.counter += 1 - if self.counter >= self.down_samples or loss_rate >= self.high_level: + + if estimate < self.bitrates[self.level] * self.lower_factor or loss > self.loss_threshold: + if self.level > 0: self.level -= 1 - self.up_samples *= 2 # exponential backoff before raising again - self.counter = 0 - self._publish(self.bitrates[self.level]) - elif loss_rate <= self.low_level and self.level < len(self.bitrates) - 1: - self.counter -= 1 - if -self.counter >= self.up_samples: - self.level += 1 - self.counter = 0 - self._publish(self.bitrates[self.level]) - - async def _sample(self) -> float | None: + self.backoff = self.backoff_steps + self.backoff_steps *= 2 + elif estimate > self.bitrates[self.level] * self.higher_factor or loss < self.loss_threshold: + if self.backoff > 0: + self.backoff -= 1 + continue + if self.level < 2: self.level += 1 + self._publish(self.bitrates[self.level]) + + def _bandwidth_estimate(self) -> float | None: + estimate = None + for sender in self.pc.getSenders(): + bitrate = getattr(sender, "_remb_bitrate", None) + if bitrate is not None: + estimate = bitrate if estimate is None else min(estimate, bitrate) + return estimate + + async def _packet_loss(self) -> float: report = await self.pc.getStats() - packets_lost = packets_sent = 0 + loss = 0.0 for s in report.values(): if s.type == "remote-inbound-rtp": - packets_lost += s.packetsLost - elif s.type == "outbound-rtp": - packets_sent += s.packetsSent - - if self.prev_lost is None: - self.prev_lost, self.prev_sent = packets_lost, packets_sent - return None - lost_delta = max(0, packets_lost - self.prev_lost) - sent_delta = max(0, packets_sent - self.prev_sent) - self.prev_lost, self.prev_sent = packets_lost, packets_sent - return lost_delta / sent_delta if sent_delta else 0.0 + loss = max(loss, s.fractionLost / 256) # fractionLost is the raw 8-bit RR field, not a fraction + return loss def _publish(self, bitrate: float): self.params.put(self.param_name, bitrate)