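Advanced Procedural Macros
In this part of the Rust Mastery series, we take a deep dive into Rust's procedural macro system. Procedural macros are a powerful metaprogramming tool: they run at compile time and generate code. In this chapter we build each kind of procedural macro (derive, attribute, and function-like) and explore some advanced use cases.
A Recap of Procedural Macro Basics
Before diving into advanced topics, let's briefly review the procedural macro system. A procedural macro is a function in a dedicated crate that receives a TokenStream and returns a new TokenStream: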
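```rust
// Declare the crate as a procedural macro crate in Cargo.toml:
// [lib]
// proc-macro = true
use proc_macro::TokenStream;

#[proc_macro_derive(MyDerive)]
pub fn my_derive(_input: TokenStream) -> TokenStream {
    // A real derive would parse `_input` (the annotated struct or enum),
    // generate new code, and return that code as a TokenStream.
    // Here we simply parse a string into tokens. A derive's output is
    // appended after the annotated item; it does not replace it.
    "fn generated_function() { println!(\"Hello from generated function!\"); }"
        .parse()
        .unwrap()
}
```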
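Rust supports three kinds of procedural macros:
- Derive macros: invoked with `#[derive(MacroName)]` on a struct, enum, or union. They generate additional items (typically trait implementations) and leave the annotated item itself unchanged.
- Attribute macros: invoked with `#[macro_name]`, optionally with arguments. They receive the annotated item and replace it with their output, so they can modify or extend it.
- Function-like macros: invoked with `macro_name!(...)`. The call looks like a declarative macro, but arbitrary Rust code computes the expansion, which makes them more powerful.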
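Tools for Procedural Macro Development
The syn and quote crates
Procedural macro development usually relies on two key crates:
- syn: parses a token stream into a Rust syntax tree
- quote: turns syntax-tree nodes and code templates back into a token stream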
```rust
// Cargo.toml
// [dependencies]
// syn = { version = "1.0", features = ["full"] }
// quote = "1.0"
// proc-macro2 = "1.0"
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput};

#[proc_macro_derive(HelloWorld)]
pub fn hello_world_derive(input: TokenStream) -> TokenStream {
    // Parse the input into a syntax tree
    let input = parse_macro_input!(input as DeriveInput);
    let name = input.ident;
    // Generate code with quote!; #name interpolates the type's identifier.
    // The HelloWorld trait itself must be defined by the user's crate.
    let expanded = quote! {
        impl HelloWorld for #name {
            fn hello_world() {
                println!("Hello, World! My name is {}", stringify!(#name));
            }
        }
    };
    // Convert the proc_macro2::TokenStream back into a proc_macro::TokenStream
    TokenStream::from(expanded)
}
```
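The derive emits only the impl, so the HelloWorld trait must already exist where the derive is used. Below is a hand-written sketch of that trait, together with what #[derive(HelloWorld)] would produce for a hypothetical Pancakes struct. Both names are illustrative only.

```rust
// The trait the generated impl targets. The derive does not define it,
// so the user's crate (or a companion crate) must.
pub trait HelloWorld {
    fn hello_world();
}

// Hypothetical type annotated with #[derive(HelloWorld)]
pub struct Pancakes;

// Hand-written equivalent of the macro's expansion for Pancakes
impl HelloWorld for Pancakes {
    fn hello_world() {
        // stringify!(Pancakes) becomes the string "Pancakes" at compile time
        println!("Hello, World! My name is {}", stringify!(Pancakes));
    }
}
```

Because stringify! is evaluated at compile time, embedding the type name has no runtime cost.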
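The proc-macro2 crate
`proc-macro2` provides types that mirror those of the compiler-provided `proc_macro` crate. Unlike `proc_macro`, these types can also be used outside a procedural macro crate, which makes code-generation logic easy to unit-test: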
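```rust
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::quote;
use syn::Ident;

// Usable in an ordinary (non-proc-macro) crate, so it can be unit-tested.
// The generated `Self {}` only compiles for structs without fields.
fn generate_impl(name: &str) -> TokenStream2 {
    let ident = Ident::new(name, Span::call_site());
    quote! {
        impl #ident {
            fn new() -> Self {
                Self {}
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn generates_constructor() {
        let code = generate_impl("Config").to_string();
        assert!(code.contains("impl Config"));
    }
}
```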
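Advanced Derive Macros
A custom derive macro for serialization
The following derive macro generates a simple JSON-like serialization for structs with named fields: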
```rust
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Fields};

// Generates an impl of a user-defined `Serialize` trait (assumed to be
// `trait Serialize { fn serialize(&self) -> String; }`, not serde's trait).
// Every field type must implement that trait as well.
#[proc_macro_derive(Serialize)]
pub fn serialize_derive(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = input.ident;
    // Collect the struct's named fields
    let fields = match input.data {
        Data::Struct(data) => match data.fields {
            Fields::Named(fields) => fields.named,
            _ => panic!("Serialize only supports structs with named fields"),
        },
        _ => panic!("Serialize only supports structs"),
    };
    // Emit one `"field": value, ` fragment per field
    let field_serializations = fields.iter().map(|field| {
        let field_name = field.ident.as_ref().unwrap();
        let field_name_str = field_name.to_string();
        quote! {
            serialized.push_str(&format!("\"{}\": {}, ", #field_name_str, self.#field_name.serialize()));
        }
    });
    // Assemble the impl
    let expanded = quote! {
        impl Serialize for #name {
            fn serialize(&self) -> String {
                let mut serialized = String::from("{");
                #(#field_serializations)*
                // Remove the trailing ", " (absent when the struct has no fields)
                if serialized.len() > 1 {
                    serialized.truncate(serialized.len() - 2);
                }
                serialized.push_str("}");
                serialized
            }
        }
    };
    TokenStream::from(expanded)
}
```
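The generated code calls .serialize() on every field, so the trait needs implementations for the leaf types too. Here is a minimal std-only sketch of that runtime side. It includes a hand-written copy of what the derive would emit for a hypothetical User struct. The trait definition and the leaf impls are assumptions, and string escaping is deliberately omitted.

```rust
// Assumed user-defined trait (not serde's Serialize)
pub trait Serialize {
    fn serialize(&self) -> String;
}

// Leaf implementations the generated code relies on
impl Serialize for u32 {
    fn serialize(&self) -> String {
        self.to_string()
    }
}

impl Serialize for String {
    fn serialize(&self) -> String {
        // Simplified: quotes and control characters are not escaped
        format!("\"{}\"", self)
    }
}

// Hypothetical struct annotated with #[derive(Serialize)]
pub struct User {
    pub name: String,
    pub age: u32,
}

// Hand-written equivalent of what the derive generates for User
impl Serialize for User {
    fn serialize(&self) -> String {
        let mut serialized = String::from("{");
        serialized.push_str(&format!("\"{}\": {}, ", "name", self.name.serialize()));
        serialized.push_str(&format!("\"{}\": {}, ", "age", self.age.serialize()));
        // Remove the trailing ", "
        if serialized.len() > 1 {
            serialized.truncate(serialized.len() - 2);
        }
        serialized.push_str("}");
        serialized
    }
}
```

Because each field delegates to its own serialize, nested structs that also derive Serialize compose without extra work.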
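Derive macros with arguments
A derive macro can declare helper attributes in order to accept arguments. The Builder macro below registers the `builder` helper attribute, so a field can specify a default value with `#[builder(default = "...")]`: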
```rust
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_macro_input, DeriveInput, Expr, Field, Lit, Meta, NestedMeta};

// Returns the expression from a field's #[builder(default = "...")] attribute, if any.
// The string is parsed as a Rust expression, so "8080" yields the integer 8080
// rather than the string literal "8080". (Uses the syn 1.x attribute API.)
fn builder_default(field: &Field) -> Option<Expr> {
    field
        .attrs
        .iter()
        .filter(|attr| attr.path.is_ident("builder"))
        .filter_map(|attr| attr.parse_meta().ok())
        .filter_map(|meta| match meta {
            Meta::List(list) => Some(list.nested),
            _ => None,
        })
        .flat_map(|nested| nested.into_iter())
        .find_map(|nested| match nested {
            NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("default") => match nv.lit {
                Lit::Str(lit) => lit.parse::<Expr>().ok(),
                _ => None,
            },
            _ => None,
        })
}

#[proc_macro_derive(Builder, attributes(builder))]
pub fn builder_derive(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = input.ident;
    let builder_name = format_ident!("{}Builder", name);
    // Only structs with named fields are supported
    let fields = match input.data {
        syn::Data::Struct(data) => match data.fields {
            syn::Fields::Named(fields) => fields.named,
            _ => panic!("Builder only supports structs with named fields"),
        },
        _ => panic!("Builder only supports structs"),
    };
    // Every builder field is an Option that starts out as None
    let field_defs = fields.iter().map(|field| {
        let field_name = field.ident.as_ref().unwrap();
        let field_type = &field.ty;
        quote! { #field_name: Option<#field_type>, }
    });
    // One chainable setter per field
    let setters = fields.iter().map(|field| {
        let field_name = field.ident.as_ref().unwrap();
        let field_type = &field.ty;
        quote! {
            pub fn #field_name(&mut self, value: #field_type) -> &mut Self {
                self.#field_name = Some(value);
                self
            }
        }
    });
    // build(): use the declared default if present, otherwise report the missing field
    let build_fields = fields.iter().map(|field| {
        let field_name = field.ident.as_ref().unwrap();
        match builder_default(field) {
            Some(default) => quote! {
                #field_name: self.#field_name.clone().unwrap_or_else(|| #default),
            },
            None => quote! {
                #field_name: self.#field_name.clone()
                    .ok_or_else(|| format!("Field {} is required", stringify!(#field_name)))?,
            },
        }
    });
    // Assemble the builder; #[derive(Clone)] requires every field type to be Clone
    let expanded = quote! {
        #[derive(Clone, Default)]
        pub struct #builder_name {
            #(#field_defs)*
        }
        impl #builder_name {
            pub fn new() -> Self {
                Default::default()
            }
            #(#setters)*
            pub fn build(&self) -> Result<#name, String> {
                Ok(#name {
                    #(#build_fields)*
                })
            }
        }
        impl #name {
            pub fn builder() -> #builder_name {
                #builder_name::new()
            }
        }
    };
    TokenStream::from(expanded)
}
```
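To make the expansion concrete, here is a hand-written sketch of what the Builder derive is meant to emit for a hypothetical Server struct, with #[builder(default = "8080")] on its port field. The default string is spliced in as an expression, so "8080" becomes the integer 8080. Server and its fields are illustrative assumptions.

```rust
// Hypothetical struct; imagine it carries #[derive(Builder)]
// and #[builder(default = "8080")] on `port`.
#[derive(Debug, PartialEq)]
pub struct Server {
    pub host: String,
    pub port: u16,
}

// Roughly what the derive generates for Server
#[derive(Clone, Default)]
pub struct ServerBuilder {
    host: Option<String>,
    port: Option<u16>,
}

impl ServerBuilder {
    pub fn new() -> Self {
        Default::default()
    }
    pub fn host(&mut self, value: String) -> &mut Self {
        self.host = Some(value);
        self
    }
    pub fn port(&mut self, value: u16) -> &mut Self {
        self.port = Some(value);
        self
    }
    pub fn build(&self) -> Result<Server, String> {
        Ok(Server {
            // No default declared: a missing value is an error
            host: self.host.clone().ok_or_else(|| format!("Field {} is required", stringify!(host)))?,
            // default = "8080", spliced in as an expression
            port: self.port.clone().unwrap_or_else(|| 8080),
        })
    }
}

impl Server {
    pub fn builder() -> ServerBuilder {
        ServerBuilder::new()
    }
}
```

The setters take &mut self and build takes &self, so a single builder can produce several values. That is why build clones each field instead of moving it out.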
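Advanced Attribute Macros
A custom routing attribute macro
Below is a routing attribute macro for a web framework. It turns the annotated function into a registered route: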
```rust
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, AttributeArgs, ItemFn, Lit, Meta, NestedMeta};

// Usage: #[route("/users", method = "POST")]; the method defaults to GET.
// Assumes the framework defines Route, Request, Response and RouteItem,
// and declares inventory::collect!(RouteItem).
#[proc_macro_attribute]
pub fn route(args: TokenStream, input: TokenStream) -> TokenStream {
    let args = parse_macro_input!(args as AttributeArgs);
    let input_fn = parse_macro_input!(input as ItemFn);
    // Extract the function's name and body
    let fn_name = &input_fn.sig.ident;
    let fn_block = &input_fn.block;
    // Parse the route arguments
    let mut method = String::from("GET");
    let mut path = String::new();
    for arg in args {
        match arg {
            // method = "POST"
            NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("method") => {
                if let Lit::Str(lit) = nv.lit {
                    method = lit.value();
                }
            }
            // "/users"
            NestedMeta::Lit(Lit::Str(lit)) => path = lit.value(),
            _ => {}
        }
    }
    // Replace the function with a unit struct that implements Route
    let expanded = quote! {
        #[allow(non_camel_case_types)]
        pub struct #fn_name;
        impl Route for #fn_name {
            fn method(&self) -> &'static str {
                #method
            }
            fn path(&self) -> &'static str {
                #path
            }
            fn handler(&self, req: Request) -> Response {
                // Wrapping the original body in a closure keeps `return` working
                let handler = || #fn_block;
                handler()
            }
        }
        // Register the route. inventory 0.3 requires a constant expression here,
        // so RouteItem is assumed to store a `&'static dyn Route`.
        inventory::submit! {
            RouteItem { route: &#fn_name }
        }
    };
    TokenStream::from(expanded)
}
```
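The macro leaves Route, Request, Response and the registry to the framework. The std-only sketch below shows how those pieces could fit together. It uses hypothetical minimal types and a hand-expanded get_users route. A static slice stands in for the registry that the inventory crate would collect.

```rust
// Hypothetical minimal types the expansion assumes the framework provides
pub struct Request {
    pub path: String,
}

#[derive(Debug, PartialEq)]
pub struct Response {
    pub status: u16,
    pub body: String,
}

pub trait Route {
    fn method(&self) -> &'static str;
    fn path(&self) -> &'static str;
    fn handler(&self, req: Request) -> Response;
}

// Roughly what #[route("/users")] on a hypothetical `fn get_users()` expands to
#[allow(non_camel_case_types)]
pub struct get_users;

impl Route for get_users {
    fn method(&self) -> &'static str {
        "GET" // the macro's default method
    }
    fn path(&self) -> &'static str {
        "/users"
    }
    fn handler(&self, req: Request) -> Response {
        // The original function body, wrapped in a closure as the macro does
        let handler = || Response { status: 200, body: format!("users at {}", req.path) };
        handler()
    }
}

// Stand-in for the inventory-based registry: a static list of routes
pub static ROUTES: &[&(dyn Route + Sync)] = &[&get_users];

// Find the route matching the method and path, then call its handler
pub fn dispatch(method: &str, req: Request) -> Option<Response> {
    ROUTES
        .iter()
        .find(|route| route.method() == method && route.path() == req.path)
        .map(|route| route.handler(req))
}
```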
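A conditional-compilation attribute macro
Next, an attribute macro for conditional compilation. The annotated function keeps its body only on the listed operating systems, and every other platform gets a stub that panics: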
```rust
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, AttributeArgs, ItemFn, Lit, NestedMeta};

// Usage: #[platform("linux", "macos")]
#[proc_macro_attribute]
pub fn platform(args: TokenStream, input: TokenStream) -> TokenStream {
    let args = parse_macro_input!(args as AttributeArgs);
    let input_fn = parse_macro_input!(input as ItemFn);
    // Collect the target platforms given as string literals
    let target_platforms: Vec<String> = args
        .into_iter()
        .filter_map(|arg| match arg {
            NestedMeta::Lit(Lit::Str(lit)) => Some(lit.value()),
            _ => None,
        })
        .collect();
    // Do not test the platform with cfg!() here. A procedural macro is compiled
    // for, and runs on, the host, so cfg!(target_os = ...) would describe the
    // host, which is wrong when cross-compiling. Instead, emit #[cfg] attributes
    // and let the compiler evaluate them for the actual target.
    let vis = &input_fn.vis;
    let sig = &input_fn.sig;
    let output = quote! {
        // Keep the original function on the listed platforms
        #[cfg(any(#(target_os = #target_platforms),*))]
        #input_fn

        // Everywhere else, a stub with the same signature that panics
        #[cfg(not(any(#(target_os = #target_platforms),*)))]
        #[allow(unused_variables)]
        #vis #sig {
            panic!("Function not available on this platform");
        }
    };
    TokenStream::from(output)
}
```
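The key design question is where the platform decision happens. The compiler evaluates cfg attributes for the crate being built, that is, for the target. By contrast, a cfg! call inside a procedural macro's own code is evaluated when the macro crate is compiled, and that crate is always built for the host. Here is a hand-written sketch of a cfg-based expansion for a hypothetical config_dir function. The CONFIG_DIR_AVAILABLE constant is an illustrative addition, not something the macro generates.

```rust
// Hand-written equivalent of:
//   #[platform("linux", "macos")]
//   pub fn config_dir() -> String { String::from("~/.config") }
// The real body is compiled only for Linux or macOS targets...
#[cfg(any(target_os = "linux", target_os = "macos"))]
pub fn config_dir() -> String {
    String::from("~/.config")
}

// ...and every other target gets a stub with the same signature
#[cfg(not(any(target_os = "linux", target_os = "macos")))]
pub fn config_dir() -> String {
    panic!("Function not available on this platform");
}

// Hypothetical helper: lets callers check availability instead of hitting the stub
pub const CONFIG_DIR_AVAILABLE: bool = cfg!(any(target_os = "linux", target_os = "macos"));
```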
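Advanced Function-like Macros
A SQL query-building macro
The following function-like macro builds a SQL query from a string literal and a list of arguments. At compile time, it checks that every numbered placeholder (?1, ?2, ...) has a matching argument: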
```rust
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse::Parse, parse::ParseStream, parse_macro_input, Ident, LitStr, Result as SynResult, Token};

// Input: sql!("SELECT * FROM users WHERE id = ?1", user_id)
struct SqlQuery {
    query: LitStr,
    params: Vec<Ident>,
}

impl Parse for SqlQuery {
    fn parse(input: ParseStream) -> SynResult<Self> {
        let query = input.parse::<LitStr>()?;
        let mut params = Vec::new();
        // Parse the optional comma-separated argument list
        if input.peek(Token![,]) {
            input.parse::<Token![,]>()?;
            while !input.is_empty() {
                params.push(input.parse::<Ident>()?);
                if input.peek(Token![,]) {
                    input.parse::<Token![,]>()?;
                } else {
                    break;
                }
            }
        }
        Ok(SqlQuery { query, params })
    }
}

#[proc_macro]
pub fn sql(input: TokenStream) -> TokenStream {
    let SqlQuery { query, params } = parse_macro_input!(input as SqlQuery);
    let query_string = query.value();
    // Rewrite "?N" placeholders to "?" and record (position, argument index) pairs
    let mut param_positions = Vec::new();
    let mut modified_query = String::new();
    let mut chars = query_string.chars().peekable();
    while let Some(c) = chars.next() {
        if c != '?' {
            modified_query.push(c);
            continue;
        }
        // Consume every digit after the '?', e.g. "?1" or "?12"
        let mut digits = String::new();
        while let Some(d) = chars.next_if(|d| d.is_ascii_digit()) {
            digits.push(d);
        }
        modified_query.push('?');
        match digits.parse::<usize>() {
            // Placeholders are 1-based; the vector's length is this placeholder's position
            Ok(n) if n >= 1 => param_positions.push((param_positions.len(), n - 1)),
            // A bare "?" or "?0" is kept as written and not bound
            _ => modified_query.push_str(&digits),
        }
    }
    // Generate the binding code; an out-of-range index becomes a compile error
    let param_bindings = param_positions.iter().map(|(pos, idx)| match params.get(*idx) {
        Some(param) => quote! { query.bind_param(#pos, &#param); },
        None => {
            let msg = format!("placeholder ?{} has no matching argument", idx + 1);
            quote! { compile_error!(#msg); }
        }
    });
    // Query and bind_param are assumed to come from the user's database layer
    let expanded = quote! {
        {
            let mut query = Query::new(#modified_query);
            #(#param_bindings)*
            query
        }
    };
    TokenStream::from(expanded)
}
```
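Placeholder scanning is ordinary string processing. It can therefore live in a plain function and be unit-tested without a proc-macro crate. Below is a std-only sketch of that logic; the name rewrite_placeholders is hypothetical. It consumes every digit after a ?, so ?1 becomes ? and ?10 refers to the tenth argument. A bare ? is left untouched and not recorded.

```rust
/// Rewrites numbered placeholders ("?1", "?2", ...) to bare "?" and returns the
/// rewritten query plus (placeholder position, zero-based argument index) pairs.
pub fn rewrite_placeholders(query: &str) -> (String, Vec<(usize, usize)>) {
    let mut positions = Vec::new();
    let mut rewritten = String::new();
    let mut chars = query.chars().peekable();
    while let Some(c) = chars.next() {
        if c != '?' {
            rewritten.push(c);
            continue;
        }
        // Collect every digit following the '?'
        let mut digits = String::new();
        while let Some(d) = chars.next_if(|d| d.is_ascii_digit()) {
            digits.push(d);
        }
        rewritten.push('?');
        match digits.parse::<usize>() {
            Ok(n) if n >= 1 => positions.push((positions.len(), n - 1)),
            // Bare "?" (or "?0"): keep the original text, record nothing
            _ => rewritten.push_str(&digits),
        }
    }
    (rewritten, positions)
}
```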
A test-generation macro
Next, we create a function-like macro that generates test cases automatically.
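The TestCases parser that follows accepts input shaped like `function, case_name: (inputs...) => expected, ...`. Before turning to the procedural version, that grammar can be sketched as a declarative macro for comparison. This is a swapped-in technique (macro_rules!), not the macro built in this section. The names test_cases_sketch, add, adds_small and adds_zero are hypothetical, and the sketch emits plain functions named after each case rather than #[test] functions.

```rust
// Matches: function, name: (inputs...) => expected, ... (trailing comma allowed)
macro_rules! test_cases_sketch {
    ($func:ident, $($name:ident : ($($input:expr),*) => $expected:expr),* $(,)?) => {
        $(
            // One checking function per case
            pub fn $name() {
                assert_eq!($func($($input),*), $expected);
            }
        )*
    };
}

pub fn add(a: i32, b: i32) -> i32 {
    a + b
}

test_cases_sketch!(add,
    adds_small: (1, 2) => 3,
    adds_zero: (0, 5) => 5,
);
```

A declarative macro works here because the input is a fixed token pattern. The procedural version becomes worthwhile once you need computed identifiers or custom error messages, which macro_rules! cannot produce.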
rust
use proc_macro::TokenStream;
use quote::{quote, format_ident};
use syn::{parse_macro_input, LitStr, parse::Parse, parse::ParseStream, Token, Ident, Result as SynResult, Expr};
// Structure describing a single test case
struct TestCase {
name: Ident,
inputs: Vec<Expr>,
expected: Expr,
}
struct TestCases {
function_name: Ident,
cases: Vec<TestCase>,
}
impl Parse for TestCases {
fn parse(input: ParseStream) -> SynResult<Self> {
let function_name = input.parse::<Ident>()?;
input.parse::<Token![,]>()?;
let mut cases = Vec::new();
while !input.is_empty() {
// Parse the test name
let name = input.parse::<Ident>()?;
input.parse::<Token![:]>()?;
// Parse the input arguments
let content;
syn::parenthesized!(content in input);
let mut inputs = Vec::new();
while !content.is_empty() {
inputs.push(content.parse::<Expr>()?);
if content.peek(Token![,]) {
content.parse::<Token![,]>()?;
} else {
break;
}
}
// Parse the expected output
input.parse::<Token![=>]>()?;
let expected = input.parse::<Expr>()?;
cases.push(TestCase { name, inputs, expected });
if input.peek(Token![,]) {
input.parse::<Token![,]>()?;
} else {
break;
}
}
Ok(TestCases { function_name, cases })
}
}
#[proc_macro]
pub fn test_cases(input: TokenStream) -> TokenStream {
let TestCases { function_name, cases } = parse_macro_input!(input as TestCases);
// Generate one test function per test case
let test_functions = cases.iter().map(|case| {
let test_name = format_ident!("test_{}_{}"