从零构建中间件:Tower 核心设计的来龙去脉

在《手把手搞懂 Service 特质:Tower 核心设计的来龙去脉》那篇内容里,我们已经搞懂了 Service 的设计初衷,以及它为什么是现在这个样子。之前我们也写过几个简单的中间件,但当时走了不少捷径。这次咱们不偷懒,完完整整地复现一遍当前 Tower 框架里 "Timeout 中间件" 的实现过程。

要写一个靠谱的中间件,得在异步 Rust 的底层层面开发 ------ 这个层面会比你平时常用的层面稍深一点。不过别担心,这篇指南会把复杂的概念和逻辑讲明白,等你看完,不仅能自己写中间件,说不定还能给 Tower 生态贡献代码呢!

开始上手

我们要做的这个中间件,就是 Tower 里的tower::timeout::Timeout。它的核心作用很简单:给内部 Service 的 "响应任务"(也就是 Future)设个最大执行时间。如果内部 Service 在规定时间内没返回结果,就直接返回一个 "超时错误"。这样客户端就不用一直等,要么重试请求,要么告诉用户出问题了。

首先,我们明确第一步:定义一个 Timeout 结构体。这个结构体要存两样东西 ------ 被它包装的 "内部 Service",以及请求的超时时长。代码如下:

rust 复制代码
use std::time::Duration;

// 定义Timeout结构体:inner存被包装的内部Service,timeout存请求的超时时长
struct Timeout<S> {
    inner: S,
    timeout: Duration,
}

之前在《手把手搞懂 Service 特质:Tower 核心设计的来龙去脉》里提过一个关键点:Service 必须实现Clone特征。为啥?因为有时候需要把Service::call方法里的 "可变引用(&mut self)",变成 "能转移所有权的 self",再放进后续的 Future 里。所以,我们得给 Timeout 结构体加两个派生宏:#[derive(Debug)](方便调试看日志)和#[derive(Clone)](满足所有权转移需求):

R 复制代码
// 派生Debug(调试时能打印结构体信息)和Clone(支持所有权转移)特征
#[derive(Debug, Clone)]
struct Timeout<S> {
    inner: S,
    timeout: Duration,
}

接下来,给 Timeout 写个 "构造函数"------ 就是一个能创建 Timeout 实例的方法:

rust 复制代码
impl<S> Timeout<S> {
    // 构造函数:接收"内部Service"和"超时时长",返回Timeout实例
    pub fn new(inner: S, timeout: Duration) -> Self {
        Timeout { inner, timeout }
    }
}

这里有个小细节:虽然我们知道S最终要实现Service特征,但按照 Rust 的 API 规范,暂时不给S加约束 ------ 等后面需要的时候再加也不迟。

现在进入关键环节:给 Timeout 实现Service特征。咱们先搭个基础框架,这个框架啥也不做,就把所有请求 "转发" 给内部 Service。先把架子立起来,后面再加超时逻辑:

rust 复制代码
use tower::Service;
use std::task::{Context, Poll};

// 给Timeout<S>实现Service特征,约束:S必须是能处理Request的Service
impl<S, Request> Service<Request> for Timeout<S>
where
    S: Service<Request>,
{
    type Response = S::Response; // 响应类型和内部Service保持一致
    type Error = S::Error;       // 错误类型和内部Service保持一致
    type Future = S::Future;     // 异步任务类型(Future)和内部Service保持一致

    // 轮询"是否就绪":判断当前能不能接收新请求
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // 咱们的中间件不关心"背压"(比如请求太多处理不过来),只要内部Service就绪,咱们就就绪
        self.inner.poll_ready(cx)
    }

    // 处理请求:把收到的请求直接传给内部Service
    fn call(&mut self, request: Request) -> Self::Future {
        self.inner.call(request)
    }
}

对新手来说,先写这种 "转发框架" 很有用 ------ 能帮你理清Service特征的结构,后面加逻辑时不容易乱。


核心逻辑:怎么加超时?

要实现超时,核心思路其实很简单:

  1. 调用内部 Service 的call方法,拿到它返回的 "响应任务(Future)";
  2. 同时创建一个 "超时任务(Future)"------ 比如用tokio::time::sleep,等指定时长后就完成;
  3. 同时盯着这两个任务:哪个先完成,就先处理哪个。如果 "超时任务" 先完成,就返回超时错误。

先试试写第一步:创建两个任务(响应任务和超时任务):

rust 复制代码
use tokio::time::sleep;

fn call(&mut self, request: Request) -> Self::Future {
    // 1. 调用内部Service,拿到"处理请求的响应任务"
    let response_future = self.inner.call(request);

    // 2. 创建"超时任务":等self.timeout这么久后就完成
    // 注意:Duration类型支持"复制",不用clone,直接传就行
    let sleep = tokio::time::sleep(self.timeout);

    // 这里后面要写"怎么同时处理两个任务"的逻辑
}

这里有个小问题:如果直接返回 "装箱的 Future"(比如Pin<Box<dyn Future<...>>>),会用到堆内存(Box),有额外开销。要是中间件嵌套很多层(比如 10 个、20 个),每个请求都要分配一次堆内存,性能会受影响。所以咱们得想个办法:不⽤ Box,自己定义一个 Future 类型。


自定义响应任务:ResponseFuture

咱们自己写一个ResponseFuture结构体,专门用来 "包装两个任务":内部 Service 的响应任务,和超时用的 sleep 任务。这个逻辑和 "用 Timeout 包装 Service" 很像,只不过这次包装的是 "Future(异步任务)":

rust 复制代码
use tokio::time::Sleep;

// 自定义的响应任务结构体:包装两个任务
pub struct ResponseFuture<F> {
    response_future: F, // 内部Service的"响应任务"
    sleep: Sleep,       // 超时用的"睡眠任务"
}

这里的泛型F,就是内部 Service 返回的 Future 类型。接下来,咱们更新 Timeout 的Service实现 ------ 把返回的 Future 类型,改成这个自定义的ResponseFuture

rust 复制代码
impl<S, Request> Service<Request> for Timeout<S>
where
    S: Service<Request>,
{
    type Response = S::Response;
    type Error = S::Error;

    // 把Future类型改成自定义的ResponseFuture(用内部Service的Future当泛型参数)
    type Future = ResponseFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: Request) -> Self::Future {
        // 1. 拿到内部Service的响应任务
        let response_future = self.inner.call(request);
        // 2. 创建超时睡眠任务
        let sleep = tokio::time::sleep(self.timeout);

        // 3. 把两个任务包装成自定义的ResponseFuture,返回出去
        ResponseFuture {
            response_future,
            sleep,
        }
    }
}

这里要特别注意一个点:Rust 的 Future 是 "惰性的"。啥意思?就是调用inner.call(request)时,不会立刻执行请求处理,只会返回一个 Future 对象;只有后面调用poll(轮询)时,这个任务才会真正开始干活。


给 ResponseFuture 实现 Future 特征

要让ResponseFuture能像普通 Future 一样被 "轮询",就得给它实现Future特征。咱们先搭个架子:

rust 复制代码
use std::{pin::Pin, future::Future};

// 给ResponseFuture<F>实现Future特征
// 约束:F必须是返回"Result<响应, 错误>"的Future
impl<F, Response, Error> Future for ResponseFuture<F>
where
    F: Future<Output = Result<Response, Error>>,
{
    type Output = Result<Response, Error>; // 输出类型和内部Future一致

    // 轮询逻辑:核心是"同时盯两个任务"
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // 后面要写具体的轮询逻辑
    }
}

咱们想要的轮询逻辑很明确:

  1. 先看看 "响应任务"(response_future)有没有结果:有结果就直接返回;
  2. 如果响应任务还没好,再看看 "超时任务"(sleep)有没有完成:完成了就返回超时错误;
  3. 要是两个都没好,就告诉调用者 "还在等(Poll::Pending)"。

先试试写第一步:轮询响应任务。但直接写会报错:

rust 复制代码
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
    // 尝试轮询响应任务------但这里会报错!
    match self.response_future.poll(cx) {
        Poll::Ready(result) => return Poll::Ready(result),
        Poll::Pending => {}
    }

    todo!()
}

报错原因是:selfPin<&mut Self>(固定的可变引用),直接访问self.response_future拿到的不是 "固定引用",而调用poll必须要Pin<&mut F>类型。这就涉及到 Rust 里的 "Pin(固定)" 概念 ------ 简单说,Pin 是为了防止某些异步任务被 "移动",导致内存安全问题。

不过不用怕,有个叫pin-project的库能帮我们解决这个问题。它能自动生成 "固定投影" 代码 ------ 所谓 "固定投影",就是从 "对整个结构体的固定引用",安全地拿到 "对结构体里某个字段的固定引用"。

用 pin-project 解决固定引用问题

先给ResponseFuture#[pin_project]派生宏,再给需要 "固定引用" 的字段加#[pin]属性:

rust 复制代码
use pin_project::pin_project;

// 加#[pin_project]:自动生成"固定投影"的代码
#[pin_project]
pub struct ResponseFuture<F> {
    #[pin] // 给response_future加#[pin]:需要固定引用
    response_future: F,
    #[pin] // 给sleep加#[pin]:也需要固定引用
    sleep: Sleep,
}

然后,在poll方法里用self.project()拿到 "带固定引用的字段":

rust 复制代码
impl<F, Response, Error> Future for ResponseFuture<F>
where
    F: Future<Output = Result<Response, Error>>,
{
    type Output = Result<Response, Error>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // 调用project():拿到每个字段的"固定引用"(如果字段加了#[pin])
        let this = self.project();

        // this.response_future 现在是 Pin<&mut F>,能调用poll了
        let response_future: Pin<&mut F> = this.response_future;
        // this.sleep 现在是 Pin<&mut Sleep>,也能调用poll了
        let sleep: Pin<&mut Sleep> = this.sleep;

        // 后面写轮询逻辑
    }
}

有了固定引用,咱们就能完整实现轮询逻辑了:

rust 复制代码
impl<F, Response, Error> Future for ResponseFuture<F>
where
    F: Future<Output = Result<Response, Error>>,
{
    type Output = Result<Response, Error>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        // 第一步:先查响应任务有没有结果
        match this.response_future.poll(cx) {
            Poll::Ready(result) => {
                // 内部Service已经返回结果了,直接把结果传出去
                return Poll::Ready(result);
            }
            Poll::Pending => {
                // 响应任务还没好,继续查超时
            }
        }

        // 第二步:查超时任务有没有完成(也就是超时了没)
        match this.sleep.poll(cx) {
            Poll::Ready(()) => {
                // 超时时间到了!但这里有个问题:返回什么错误?
                todo!()
            }
            Poll::Pending => {
                // 还没超时,继续等
            }
        }

        // 第三步:两个任务都没好,返回"还在等"
        Poll::Pending
    }
}

现在卡在最后一个问题上:超时的时候,该返回什么类型的错误?


解决错误类型问题

目前,我们说好了 "Timeout 的错误类型和内部 Service 一致",但内部 Service 的错误类型是泛型Error------ 我们根本不知道这个Error是什么,也没法创建一个 "超时错误" 的Error实例。

咱们有三种解决方案,咱们一个个分析,最后选最适合 Tower 的方案。

方案 1:用 "装箱的错误特征对象"

就是返回Box<dyn std::error::Error + Send + Sync>------ 简单说,不管是什么错误,都装到一个 "通用错误盒子" 里。这样不管内部 Service 返回什么错误,都能转成这个盒子类型,超时错误也能装进去。

方案 2:用枚举包两种错误

定义一个枚举,里面包含 "超时错误" 和 "内部服务错误" 两个选项:

rust 复制代码
enum TimeoutError<Error> {
    Timeout,       // 超时错误
    Service(Error) // 内部服务错误
}

但这个方案有个大问题:如果中间件嵌套多层(比如 A 包装 B,B 包装 C),错误类型就会变成AError<BError<CError<MyError>>>,写匹配逻辑时会非常麻烦,而且改中间件顺序会导致错误类型变样。

方案 3:要求内部错误能转成超时错误

定义一个TimeoutError结构体,然后要求内部 Service 的Error能从TimeoutError转过来(比如TimeoutError: Into<Error>)。但这样用户用自定义错误时,得手动写转换逻辑,很麻烦。

综合来看,方案 1 最适合 Tower------ 虽然需要一点堆内存(装箱),但胜在简单、灵活,嵌套多层也不怕。

实现方案 1:定义超时错误和通用错误类型

第一步:定义TimeoutError结构体,实现 Rust 的标准错误特征(std::error::Error):

rust 复制代码
use std::fmt;

// 超时错误结构体:加个私有字段(()),防止外部随便创建
#[derive(Debug, Default)]
pub struct TimeoutError(());

// 实现Display:错误信息的文字描述
impl fmt::Display for TimeoutError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad("request timed out") // 错误信息:"请求超时"
    }
}

// 实现Error:标记这是一个标准错误类型
impl std::error::Error for TimeoutError {}

第二步:给 "通用错误盒子" 起个别名,省得每次都写一大串:

rust 复制代码
// 通用错误类型别名(Tower里已经有这个类型,叫tower::BoxError)
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;

第三步:更新ResponseFutureFuture实现 ------ 把错误类型改成BoxError,同时要求内部 Service 的错误能转成BoxError

rust 复制代码
impl<F, Response, Error> Future for ResponseFuture<F>
where
    F: Future<Output = Result<Response, Error>>,
    // 约束:内部Service的错误能转成BoxError
    Error: Into<BoxError>,
{
    // 输出类型的错误改成BoxError
    type Output = Result<Response, BoxError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        // 轮询响应任务:把内部错误转成BoxError
        match this.response_future.poll(cx) {
            Poll::Ready(result) => {
                // 用map_err把内部错误转成BoxError
                let result = result.map_err(Into::into);
                return Poll::Ready(result);
            }
            Poll::Pending => {}
        }

        // 超时了:创建TimeoutError,装箱后返回
        match this.sleep.poll(cx) {
            Poll::Ready(()) => {
                let error = Box::new(TimeoutError(())); // 把超时错误装箱
                return Poll::Ready(Err(error));
            }
            Poll::Pending => {}
        }

        Poll::Pending
    }
}

最后,更新 Timeout 的Service实现 ------ 错误类型也要改成BoxError,并且加上同样的约束:

rust 复制代码
impl<S, Request> Service<Request> for Timeout<S>
where
    S: Service<Request>,
    // 和ResponseFuture保持一致:内部错误能转成BoxError
    S::Error: Into<BoxError>,
{
    type Response = S::Response;
    type Error = BoxError; // 错误类型改成BoxError
    type Future = ResponseFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // 轮询就绪时,也要把内部错误转成BoxError
        self.inner.poll_ready(cx).map_err(Into::into)
    }

    fn call(&mut self, request: Request) -> Self::Future {
        let response_future = self.inner.call(request);
        let sleep = tokio::time::sleep(self.timeout);

        ResponseFuture {
            response_future,
            sleep,
        }
    }
}

总结

到这里,咱们就完整复现了 Tower 里 Timeout 中间件的实现!最终代码如下:

rust 复制代码
use pin_project::pin_project;
use std::time::Duration;
use std::{
    fmt,
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};
use tokio::time::Sleep;
use tower::Service;

// 超时中间件结构体:包装内部Service和超时时长,支持调试和克隆
#[derive(Debug, Clone)]
struct Timeout<S> {
    inner: S,
    timeout: Duration,
}

impl<S> Timeout<S> {
    // 构造函数:接收内部Service和超时时长,返回Timeout实例
    fn new(inner: S, timeout: Duration) -> Self {
        Timeout { inner, timeout }
    }
}

// 给Timeout<S>实现Service特征
impl<S, Request> Service<Request> for Timeout<S>
where
    S: Service<Request>,
    S::Error: Into<BoxError>, // 约束:内部错误能转成BoxError
{
    type Response = S::Response;
    type Error = BoxError;
    type Future = ResponseFuture<S::Future>;

    // 轮询就绪状态:转发内部Service的状态,同时把错误转成BoxError
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx).map_err(Into::into)
    }

    // 处理请求:创建两个任务(响应任务+超时任务),包装成ResponseFuture返回
    fn call(&mut self, request: Request) -> Self::Future {
        let response_future = self.inner.call(request);
        let sleep = tokio::time::sleep(self.timeout);

        ResponseFuture {
            response_future,
            sleep,
        }
    }
}

// 自定义响应任务:包装响应任务和超时任务,支持固定投影
#[pin_project]
struct ResponseFuture<F> {
    #[pin]
    response_future: F,
    #[pin]
    sleep: Sleep,
}

// 给ResponseFuture实现Future特征
impl<F, Response, Error> Future for ResponseFuture<F>
where
    F: Future<Output = Result<Response, Error>>,
    Error: Into<BoxError>, // 约束:内部错误能转成BoxError
{
    type Output = Result<Response, BoxError>;

    // 轮询逻辑:先查响应,再查超时,都没好就返回Pending
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        // 先查响应任务
        match this.response_future.poll(cx) {
            Poll::Ready(result) => {
                let result = result.map_err(Into::into);
                return Poll::Ready(result);
            }
            Poll::Pending => {}
        }

        // 再查超时任务
        match this.sleep.poll(cx) {
            Poll::Ready(()) => {
                let error = Box::new(TimeoutError(()));
                return Poll::Ready(Err(error));
            }
            Poll::Pending => {}
        }

        Poll::Pending
    }
}

// 超时错误结构体:私有字段防止外部构造,实现标准错误特征
#[derive(Debug, Default)]
struct TimeoutError(());

impl fmt::Display for TimeoutError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad("request timed out")
    }
}

impl std::error::Error for TimeoutError {}

// 通用错误类型别名:简化"装箱错误特征对象"的写法
type BoxError = Box<dyn std::error::Error + Send + Sync>;

其实大多数 Tower 中间件,都是用这种 "包装 + 转发" 的思路实现的:

  1. 定义一个结构体,包装内部 Service;
  2. 给这个结构体实现Service特征,核心逻辑在call里;
  3. 自定义一个 Future,包装内部 Service 的 Future,实现Future特征处理异步逻辑。

除了 Timeout,还有几个常用的中间件也用了这个模式:

  • ConcurrencyLimit:限制同时处理的最大请求数;
  • LoadShed:当内部 Service 忙不过来时,直接拒绝新请求(削峰);
  • Steer:把请求路由到不同的 Service(类似负载均衡)。

现在你已经掌握了写中间件的核心方法!如果想多练手,可以试试这几个小任务:

  1. 不用tokio::time::sleep,改用tokio::time::timeout实现超时逻辑;
  2. 写一个 "适配器":用闭包修改请求、响应或错误(类似Result::map);
  3. 实现ConcurrencyLimit(提示:需要用PollSemaphore控制并发数)。