Rust Trait 与泛型高级用法
概述
Trait 是 Rust 中定义共享行为的机制，类似于其他语言中的接口（interface）。泛型允许编写灵活且可重用的代码。将两者结合使用，可以构建强大的抽象。
Trait 基础
Trait 定义了类型必须实现的方法集合。
基础示例
rust
/// Shared behaviour for anything that can be summarised as text.
trait Summary {
    /// Returns a one-line summary of `self`. Every implementor must provide it.
    fn summarize(&self) -> String;

    /// Default method: the first 20 characters of `summarize()`, wrapped as
    /// `"(预览: <head>...)"`.
    ///
    /// Truncation counts `char`s, not bytes, so summaries that contain
    /// multi-byte UTF-8 text (e.g. Chinese) cannot panic by slicing through
    /// the middle of a code point.
    fn preview(&self) -> String {
        const PREVIEW_CHARS: usize = 20;
        // Call `summarize` once instead of twice (once for the text, once for its length).
        let head: String = self.summarize().chars().take(PREVIEW_CHARS).collect();
        format!("(预览: {}...)", head)
    }
}
/// Example implementor of `Summary`: a piece of text with a title and an author.
struct Article {
    title: String,
    // Not read by `summarize` in this example.
    content: String,
    author: String,
}

impl Summary for Article {
    /// Summarises the article as `"<title> by <author>"`.
    fn summarize(&self) -> String {
        let Article { title, author, .. } = self;
        format!("{} by {}", title, author)
    }
}
复杂案例：实现一个通用的数据处理框架
rust
use std::fmt::Debug;
use std::collections::HashMap;
// 定义数据处理器 trait
/// One stage of a data-processing pipeline.
///
/// Each implementor picks its own `Input`, `Output` and `Error` types through
/// associated types rather than generic parameters.
trait DataProcessor {
    /// Type consumed by `process`.
    type Input;
    /// Type produced on success.
    type Output;
    /// Type produced on failure.
    type Error;

    /// Human-readable name of this processor.
    fn name(&self) -> &str;

    /// Transforms `input`, which is taken by value. The receiver is `&mut self`,
    /// so implementors may update internal state between calls.
    fn process(&mut self, input: Self::Input) -> Result<Self::Output, Self::Error>;

    /// Optional pre-check hook. The default accepts every input.
    fn can_process(&self, _candidate: &Self::Input) -> bool {
        true
    }
}
// 定义可链接的处理器
/// Extension trait that adds `chain` to every `DataProcessor`.
///
/// The blanket impl below covers all processors, so it is never implemented by hand.
trait ChainableProcessor: DataProcessor {
    /// Consumes `self` and `next` and pairs them in a `ProcessorChain`.
    ///
    /// `next` must accept exactly `Self::Output` as its input.
    fn chain<P>(self, next: P) -> ProcessorChain<Self, P>
    where
        Self: Sized,
        P: DataProcessor<Input = Self::Output>,
    {
        let (head, tail) = (self, next);
        ProcessorChain::new(head, tail)
    }
}

impl<T> ChainableProcessor for T where T: DataProcessor {}
// 处理器链
/// Two processors run back to back: the output of `first` is fed into `second`.
///
/// Constructed by `ChainableProcessor::chain`.
struct ProcessorChain<P1, P2> {
    first: P1,
    second: P2,
}

impl<P1, P2> ProcessorChain<P1, P2> {
    /// Pairs two stages. No bounds are checked here; stage compatibility is
    /// required by the `DataProcessor` impl for `ProcessorChain` instead.
    fn new(first: P1, second: P2) -> Self {
        Self { first, second }
    }
}
/// A composed processor: runs `first`, then feeds its output into `second`.
///
/// Errors from `second` are converted into `first`'s error type through the
/// `P1::Error: From<P2::Error>` bound, so both stages share one error type.
impl<P1, P2> DataProcessor for ProcessorChain<P1, P2>
where
    P1: DataProcessor,
    P2: DataProcessor<Input = P1::Output>,
    P1::Error: From<P2::Error>,
{
    type Input = P1::Input;
    type Output = P2::Output;
    type Error = P1::Error;

    fn process(&mut self, input: Self::Input) -> Result<Self::Output, Self::Error> {
        let intermediate = self.first.process(input)?;
        // `?` performs the `From<P2::Error>` conversion for the second stage.
        Ok(self.second.process(intermediate)?)
    }

    fn name(&self) -> &str {
        "ProcessorChain"
    }

    /// The chain accepts exactly what its first stage accepts. The second
    /// stage's input only exists after the first stage has run, so it cannot
    /// be pre-checked here.
    fn can_process(&self, input: &Self::Input) -> bool {
        self.first.can_process(input)
    }
}
// 文本转大写处理器
/// Stage that converts its text input to upper case. It never returns `Err`.
struct UppercaseProcessor;

impl DataProcessor for UppercaseProcessor {
    type Input = String;
    type Output = String;
    type Error = String;

    fn name(&self) -> &str {
        "UppercaseProcessor"
    }

    fn process(&mut self, text: String) -> Result<String, String> {
        let upper = text.to_uppercase();
        Ok(upper)
    }
}
// 文本分词处理器
/// Stage that splits text on whitespace into owned words. It never returns `Err`.
struct TokenizeProcessor;

impl DataProcessor for TokenizeProcessor {
    type Input = String;
    type Output = Vec<String>;
    type Error = String;

    fn process(&mut self, text: String) -> Result<Vec<String>, String> {
        let tokens: Vec<String> = text.split_whitespace().map(String::from).collect();
        Ok(tokens)
    }

    fn name(&self) -> &str {
        "TokenizeProcessor"
    }
}
// 词频统计处理器
/// Stage that counts how many times each word occurs. It never returns `Err`.
struct WordCountProcessor;

impl DataProcessor for WordCountProcessor {
    type Input = Vec<String>;
    type Output = HashMap<String, usize>;
    type Error = String;

    fn process(&mut self, words: Vec<String>) -> Result<HashMap<String, usize>, String> {
        let tally = words
            .into_iter()
            .fold(HashMap::new(), |mut acc: HashMap<String, usize>, word| {
                *acc.entry(word).or_default() += 1;
                acc
            });
        Ok(tally)
    }

    fn name(&self) -> &str {
        "WordCountProcessor"
    }
}
// 过滤处理器
/// Stage that keeps only the elements of a `Vec<T>` for which `predicate` returns `true`.
///
/// The struct itself carries no trait bounds. The `Fn(&T) -> bool` requirement
/// lives on the impl blocks that actually call `predicate`, so naming the type
/// does not force every mention to repeat the bound.
struct FilterProcessor<T, F> {
    predicate: F,
    // Ties the element type `T` to the struct without storing a `T`.
    _phantom: std::marker::PhantomData<T>,
}

impl<T, F> FilterProcessor<T, F>
where
    F: Fn(&T) -> bool,
{
    /// Wraps `predicate` in a new filter stage.
    fn new(predicate: F) -> Self {
        FilterProcessor {
            predicate,
            _phantom: std::marker::PhantomData,
        }
    }
}
/// Filtering never clones elements, so no `T: Clone` bound is required.
impl<T, F> DataProcessor for FilterProcessor<T, F>
where
    F: Fn(&T) -> bool,
{
    type Input = Vec<T>;
    type Output = Vec<T>;
    type Error = String;

    /// Keeps the elements for which `predicate` returns `true`, in their
    /// original order. It never returns `Err`.
    fn process(&mut self, mut input: Self::Input) -> Result<Self::Output, Self::Error> {
        // `retain` filters in place and reuses the input's allocation instead
        // of collecting the survivors into a fresh Vec.
        input.retain(|item| (self.predicate)(item));
        Ok(input)
    }

    fn name(&self) -> &str {
        "FilterProcessor"
    }
}
// 映射处理器
/// Stage that applies `mapper` to every element of a `Vec<I>`, producing a `Vec<O>`.
///
/// The struct itself carries no trait bounds. `Fn(I) -> O` is required only
/// by the impl blocks that call `mapper`.
struct MapProcessor<I, O, F> {
    mapper: F,
    // Records the input/output element types without storing values of them.
    _phantom: std::marker::PhantomData<(I, O)>,
}

impl<I, O, F> MapProcessor<I, O, F>
where
    F: Fn(I) -> O,
{
    /// Wraps `mapper` in a new map stage.
    fn new(mapper: F) -> Self {
        MapProcessor {
            mapper,
            _phantom: std::marker::PhantomData,
        }
    }
}
impl<I, O, F> DataProcessor for MapProcessor<I, O, F>
where
    F: Fn(I) -> O,
{
    type Input = Vec<I>;
    type Output = Vec<O>;
    type Error = String;

    /// Applies `mapper` to each element in order. It never returns `Err`.
    fn process(&mut self, items: Vec<I>) -> Result<Vec<O>, String> {
        // `&F` is itself callable, so the mapper can be passed to `map` directly.
        let mapped = items.into_iter().map(&self.mapper).collect::<Vec<O>>();
        Ok(mapped)
    }

    fn name(&self) -> &str {
        "MapProcessor"
    }
}
// 演示处理器使用
/// Runs the sample sentence through each text processor individually. It then
/// runs the sentence through a single pipeline composed with
/// `ChainableProcessor::chain`, printing each result.
///
/// None of the processors used here ever returns `Err`, so `expect` only
/// guards that invariant.
fn demonstrate_processors() {
    let text = "Rust is amazing Rust is powerful Rust is safe".to_string();

    // A single processor.
    let mut uppercase = UppercaseProcessor;
    let result = uppercase
        .process(text.clone())
        .expect("UppercaseProcessor never fails");
    println!("大写: {}", result);

    // Stages invoked one after another by hand.
    let mut tokenizer = TokenizeProcessor;
    let words = tokenizer
        .process(text.clone())
        .expect("TokenizeProcessor never fails");
    println!("分词: {:?}", words);
    let mut counter = WordCountProcessor;
    let counts = counter.process(words).expect("WordCountProcessor never fails");
    println!("词频: {:?}", counts);

    // Actual chaining: the three stages composed into one processor via `chain`.
    let mut pipeline = UppercaseProcessor
        .chain(TokenizeProcessor)
        .chain(WordCountProcessor);
    let chained = pipeline
        .process(text)
        .expect("every stage in the chain is infallible");
    println!("链式词频 ({}): {:?}", pipeline.name(), chained);
}
// 泛型函数示例
/// Returns a reference to the largest element of `list`, or `None` if it is empty.
///
/// Only `PartialOrd` is required. An element replaces the running maximum only
/// when it compares strictly greater. On ties the earliest element wins, and
/// incomparable values (e.g. `NaN`) never replace the running maximum.
fn find_max<T: PartialOrd>(list: &[T]) -> Option<&T> {
    let (head, tail) = list.split_first()?;
    let largest = tail
        .iter()
        .fold(head, |best, candidate| if candidate > best { candidate } else { best });
    Some(largest)
}
// 多 trait 约束
/// Prints both values, then prints which one is larger, or that they are equal.
///
/// Uses `>` then `<`. When neither holds (equal or incomparable values such as
/// `NaN`), it reports them as equal.
fn print_and_compare<T>(a: &T, b: &T)
where
    T: Debug + PartialOrd,
{
    println!("比较 {:?} 和 {:?}", a, b);
    let larger = if a > b {
        Some(a)
    } else if a < b {
        Some(b)
    } else {
        None
    };
    match larger {
        Some(winner) => println!("{:?} 更大", winner),
        None => println!("相等"),
    }
}
// 关联类型示例
/// A collection abstraction whose element type is fixed by each implementor
/// through the associated type `Item`.
trait Container {
    /// Element type stored by the container.
    type Item;
    /// Appends `item`.
    fn add(&mut self, item: Self::Item);
    /// Returns the element at `index`, or `None` if there is none.
    fn get(&self, index: usize) -> Option<&Self::Item>;
    /// Number of stored elements.
    fn len(&self) -> usize;
}

/// `Container` backed by a `Vec<T>`.
struct VecContainer<T> {
    items: Vec<T>,
}

impl<T> VecContainer<T> {
    /// Creates an empty container.
    fn new() -> Self {
        Self { items: Vec::new() }
    }
}

impl<T> Container for VecContainer<T> {
    type Item = T;

    fn add(&mut self, element: T) {
        self.items.push(element);
    }

    fn get(&self, position: usize) -> Option<&T> {
        self.items.get(position)
    }

    fn len(&self) -> usize {
        self.items.len()
    }
}
// 实现一个泛型缓存
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
/// A small bounded key/value cache.
///
/// Each entry is stored under the 64-bit `DefaultHasher` digest of its key,
/// together with the original key so that lookups can reject hash collisions.
/// When the cache is full and a new slot is needed, an arbitrary existing entry
/// is evicted (`HashMap` iteration order is unspecified), so this is not an LRU.
struct Cache<K, V>
where
    K: Hash + Eq + Clone,
    V: Clone,
{
    // digest of key -> (original key, value)
    storage: HashMap<u64, (K, V)>,
    // Upper bound on the number of stored entries; 0 means nothing is stored.
    max_size: usize,
}

impl<K, V> Cache<K, V>
where
    K: Hash + Eq + Clone,
    V: Clone,
{
    /// Creates an empty cache that holds at most `max_size` entries.
    fn new(max_size: usize) -> Self {
        Cache {
            storage: HashMap::new(),
            max_size,
        }
    }

    /// Digest of `key`, used as its storage slot.
    fn hash_key(&self, key: &K) -> u64 {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        hasher.finish()
    }

    /// Stores `value` under `key`, replacing whatever occupies the same slot.
    ///
    /// An entry is evicted only when the cache is full and `key` needs a new
    /// slot. With `max_size == 0` nothing is stored.
    fn insert(&mut self, key: K, value: V) {
        if self.max_size == 0 {
            return;
        }
        let hash = self.hash_key(&key);
        // Overwriting an occupied slot does not grow the map, so it needs no eviction.
        let needs_room =
            !self.storage.contains_key(&hash) && self.storage.len() >= self.max_size;
        if needs_room {
            // HashMap iteration order is unspecified: this evicts an arbitrary entry.
            if let Some(&victim) = self.storage.keys().next() {
                self.storage.remove(&victim);
            }
        }
        self.storage.insert(hash, (key, value));
    }

    /// Returns a clone of the value cached for `key`, or `None` on a miss.
    ///
    /// The stored key is compared with `key`. A different key that happens to
    /// share the same 64-bit digest is therefore reported as a miss instead of
    /// returning the other key's value.
    fn get(&self, key: &K) -> Option<V> {
        let hash = self.hash_key(key);
        self.storage
            .get(&hash)
            .filter(|(stored_key, _)| stored_key == key)
            .map(|(_, value)| value.clone())
    }
}
// Trait 对象示例
/// Behaviour shared by shapes that can be rendered and measured.
///
/// Both methods take `&self` and have no generic parameters, which lets the
/// trait be used as `dyn Drawable` (see `draw_all`).
trait Drawable {
    /// Renders the shape. The implementations below print a description to stdout.
    fn draw(&self);
    /// Returns the shape's area.
    fn area(&self) -> f64;
}

/// A circle described by its radius.
struct Circle {
    radius: f64,
}

impl Drawable for Circle {
    fn draw(&self) {
        let Circle { radius } = self;
        println!("绘制圆形,半径: {}", radius);
    }

    fn area(&self) -> f64 {
        let r = self.radius;
        std::f64::consts::PI * r * r
    }
}

/// An axis-aligned rectangle described by its width and height.
struct Rectangle {
    width: f64,
    height: f64,
}

impl Drawable for Rectangle {
    fn draw(&self) {
        let Rectangle { width, height } = self;
        println!("绘制矩形,宽: {}, 高: {}", width, height);
    }

    fn area(&self) -> f64 {
        let (w, h) = (self.width, self.height);
        w * h
    }
}

/// Draws every shape through dynamic dispatch and prints its area to two decimals.
fn draw_all(shapes: &[Box<dyn Drawable>]) {
    shapes.iter().for_each(|shape| {
        shape.draw();
        println!(" 面积: {:.2}", shape.area());
    });
}
// 泛型结构体和方法
/// A 2-D point whose two coordinates share a single type `T`.
struct Point<T> {
    x: T,
    y: T,
}

impl<T> Point<T> {
    /// Builds a point from its coordinates.
    fn new(x: T, y: T) -> Self {
        Self { x, y }
    }
}

impl<T> Point<T>
where
    T: std::ops::Add<Output = T> + Copy,
{
    /// Component-wise sum of `self` and `other`, returned as a new point.
    fn add(&self, other: &Point<T>) -> Point<T> {
        let sum_x = self.x + other.x;
        let sum_y = self.y + other.y;
        Point::new(sum_x, sum_y)
    }
}

impl Point<f64> {
    /// Euclidean distance from the origin; available only for `f64` points.
    fn distance_from_origin(&self) -> f64 {
        let squared = self.x * self.x + self.y * self.y;
        squared.sqrt()
    }
}
// trait 继承
/// Anything that exposes a human-readable name.
trait Named {
    fn name(&self) -> &str;
}

/// Supertrait example: every `Identifiable` must also be `Named`. That is
/// what allows the default `full_identifier` to call `self.name()`.
trait Identifiable: Named {
    /// Numeric identifier.
    fn id(&self) -> u64;

    /// Returns `"<id>:<name>"`, built from the two required accessors.
    fn full_identifier(&self) -> String {
        let (id, name) = (self.id(), self.name());
        format!("{}:{}", id, name)
    }
}

/// Example account type implementing both traits.
struct User {
    id: u64,
    username: String,
}

impl Named for User {
    fn name(&self) -> &str {
        self.username.as_str()
    }
}

impl Identifiable for User {
    fn id(&self) -> u64 {
        let User { id, .. } = self;
        *id
    }
}
/// Exercises the generic functions, container, cache, trait objects, generic
/// points and supertraits defined in this file, printing one section per feature.
fn demonstrate_advanced_features() {
    // Generic function: largest element of any `PartialOrd` slice.
    let values = [1, 5, 3, 9, 2];
    if let Some(largest) = find_max(&values) {
        println!("最大值: {}", largest);
    }

    // Multiple trait bounds (`Debug + PartialOrd`) on integers and string slices.
    print_and_compare(&10, &20);
    print_and_compare(&"hello", &"world");

    // Container with an associated `Item` type.
    let mut words = VecContainer::new();
    for &word in &["Rust", "is", "awesome"] {
        words.add(word);
    }
    println!("容器大小: {}", words.len());

    // Bounded generic cache.
    let mut cache = Cache::new(3);
    cache.insert("key1", "value1");
    cache.insert("key2", "value2");
    if let Some(hit) = cache.get(&"key1") {
        println!("缓存值: {}", hit);
    }

    // Heterogeneous shapes behind `dyn Drawable`.
    let shapes: Vec<Box<dyn Drawable>> = vec![
        Box::new(Circle { radius: 5.0 }),
        Box::new(Rectangle { width: 4.0, height: 6.0 }),
    ];
    draw_all(&shapes);

    // Generic point: `add` works for any `Add + Copy` type; distance exists only for `f64`.
    let sum = Point::new(3, 4).add(&Point::new(1, 2));
    println!("点相加: ({}, {})", sum.x, sum.y);
    let float_point = Point::new(3.0, 4.0);
    println!("距离原点: {:.2}", float_point.distance_from_origin());

    // Supertrait: `Identifiable` builds on `Named`.
    let user = User {
        id: 1001,
        username: "rustacean".to_string(),
    };
    println!("用户标识: {}", user.full_identifier());
}
/// Entry point: runs the processor demo, prints a separator, then runs the
/// traits/generics demo.
fn main() {
    let separator = "\n---\n";
    demonstrate_processors();
    println!("{}", separator);
    demonstrate_advanced_features();
}
Trait 对象与泛型的选择
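- 泛型：编译期单态化（为每个具体类型生成一份代码），静态分发且可以内联，运行时性能更好，但会增加二进制体积和编译时间。
- Trait 对象（`dyn Trait`）：运行时通过虚表（vtable）动态分发，只生成一份代码，体积更小，并且可以放入异构集合（如 `Vec<Box<dyn Drawable>>`）；代价是每次调用多一次间接跳转、通常无法内联，且该 trait 必须满足对象安全（dyn 兼容）的要求。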
总结
Trait 和泛型是 Rust 类型系统的核心特性。它们提供了强大的抽象能力，同时保持了类型安全和高性能。