diff --git a/tower/src/timeout/mod.rs b/tower/src/timeout/mod.rs index da3bbf98d..5f5c06261 100644 --- a/tower/src/timeout/mod.rs +++ b/tower/src/timeout/mod.rs @@ -68,3 +68,72 @@ where ResponseFuture::new(response, sleep) } } + +#[cfg(test)] +mod tests { + use super::*; + use std::{ + convert::Infallible, + future::Future, + pin::Pin, + task::{Context, Poll}, + time::Duration, + }; + use tokio::time::sleep; + use tower_service::Service; + + struct SlowService(Duration); + + impl Service<()> for SlowService { + type Response = (); + type Error = Infallible; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: ()) -> Self::Future { + let delay = self.0; + Box::pin(async move { + sleep(delay).await; + Ok(()) + }) + } + } + + struct FastService; + + impl Service<()> for FastService { + type Response = &'static str; + type Error = Infallible; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: ()) -> Self::Future { + Box::pin(async move { Ok("ok") }) + } + } + + #[tokio::test(start_paused = true)] + async fn elapsed_error_when_timeout_exceeded() { + let mut svc = Timeout::new(SlowService(Duration::from_secs(10)), Duration::from_secs(1)); + + let res = svc.call(()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(err.downcast_ref::().is_some()); + } + + #[tokio::test(start_paused = true)] + async fn response_passes_through_when_under_timeout() { + let mut svc = Timeout::new(FastService, Duration::from_secs(1)); + + let res = svc.call(()).await; + assert_eq!(res.unwrap(), "ok"); + } +}