Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions tower/src/timeout/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,72 @@ where
ResponseFuture::new(response, sleep)
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::{
convert::Infallible,
future::Future,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use tokio::time::sleep;
use tower_service::Service;

struct SlowService(Duration);

impl Service<()> for SlowService {
type Response = ();
type Error = Infallible;
type Future = Pin<Box<dyn Future<Output = Result<(), Infallible>> + Send>>;

fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}

fn call(&mut self, _req: ()) -> Self::Future {
let delay = self.0;
Box::pin(async move {
sleep(delay).await;
Ok(())
})
}
}

struct FastService;

impl Service<()> for FastService {
type Response = &'static str;
type Error = Infallible;
type Future = Pin<Box<dyn Future<Output = Result<&'static str, Infallible>> + Send>>;

fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}

fn call(&mut self, _req: ()) -> Self::Future {
Box::pin(async move { Ok("ok") })
}
}

#[tokio::test(start_paused = true)]
async fn elapsed_error_when_timeout_exceeded() {
let mut svc = Timeout::new(SlowService(Duration::from_secs(10)), Duration::from_secs(1));

let res = svc.call(()).await;
assert!(res.is_err());

let err = res.unwrap_err();
assert!(err.downcast_ref::<error::Elapsed>().is_some());
}

#[tokio::test(start_paused = true)]
async fn response_passes_through_when_under_timeout() {
let mut svc = Timeout::new(FastService, Duration::from_secs(1));

let res = svc.call(()).await;
assert_eq!(res.unwrap(), "ok");
}
}