Rust 练习册 :Pythagorean Triplet与数学算法

毕达哥拉斯三元组(Pythagorean Triplet)是满足毕达哥拉斯定理 a² + b² = c² 的三个正整数 (a, b, c)。在 Exercism 的 "pythagorean-triplet" 练习中,我们需要找到所有满足 a + b + c = sum 条件的毕达哥拉斯三元组。这不仅能帮助我们掌握数学算法和优化技巧,还能深入学习Rust中的集合操作、迭代器使用和性能优化。

什么是毕达哥拉斯三元组?

毕达哥拉斯三元组是满足毕达哥拉斯定理 a² + b² = c² 的三个正整数 (a, b, c)。其中最著名的例子是 (3, 4, 5),因为 3² + 4² = 9 + 16 = 25 = 5²。

在本练习中,我们需要找到所有满足以下条件的三元组:

  1. a² + b² = c²(毕达哥拉斯定理)
  2. a + b + c = sum(给定的和)
  3. a < b < c(按升序排列)

让我们先看看练习提供的函数签名:

rust 复制代码
use std::collections::HashSet;

pub fn find(sum: u32) -> HashSet<[u32; 3]> {
    unimplemented!("Given the sum {}, return all possible Pythagorean triplets, which produce the said sum, or an empty HashSet if there are no such triplets. Note that you are expected to return triplets in [a, b, c] order, where a < b < c", sum);
}

我们需要实现 find 函数,根据给定的和找到所有符合条件的毕达哥拉斯三元组。

设计分析

1. 核心要求

  1. 数学计算:正确实现毕达哥拉斯定理的验证
  2. 条件筛选:筛选满足和条件的三元组
  3. 排序约束:确保返回的三元组满足 a < b < c
  4. 去重处理:使用 HashSet 避免重复的三元组

2. 技术要点

  1. 暴力搜索:遍历所有可能的组合
  2. 数学优化:利用数学性质减少搜索空间
  3. 集合操作:使用 HashSet 存储和去重结果
  4. 性能优化:优化算法以处理大数计算

完整实现

1. 基础暴力搜索实现

rust 复制代码
use std::collections::HashSet;

pub fn find(sum: u32) -> HashSet<[u32; 3]> {
    let mut triplets = HashSet::new();
    
    // 遍历所有可能的 a 和 b 值
    for a in 1..sum / 3 {
        for b in (a + 1)..sum / 2 {
            let c = sum - a - b;
            
            // 确保 c > b 以满足 a < b < c 的条件
            if c > b {
                // 验证毕达哥拉斯定理
                if a * a + b * b == c * c {
                    triplets.insert([a, b, c]);
                }
            }
        }
    }
    
    triplets
}

2. 优化的数学实现

rust 复制代码
use std::collections::HashSet;

pub fn find(sum: u32) -> HashSet<[u32; 3]> {
    let mut triplets = HashSet::new();
    
    // 根据数学推导优化搜索范围
    // 从 a = 1 开始,到 a < sum / 3 结束
    for a in 1..sum / 3 {
        // 利用 a + b + c = sum 和 a² + b² = c² 推导出:
        // b = (sum² - 2*sum*a) / (2*(sum - a))
        let numerator = (sum as u64) * (sum as u64) - 2 * (sum as u64) * (a as u64);
        let denominator = 2 * ((sum as u64) - (a as u64));
        
        // 检查是否能整除
        if denominator != 0 && numerator % denominator == 0 {
            let b = (numerator / denominator) as u32;
            
            // 确保 b > a 且 c > b
            if b > a {
                let c = sum - a - b;
                if c > b {
                    // 验证毕达哥拉斯定理
                    if a * a + b * b == c * c {
                        triplets.insert([a, b, c]);
                    }
                }
            }
        }
    }
    
    triplets
}

3. 使用Euclid公式的实现

rust 复制代码
use std::collections::HashSet;

pub fn find(sum: u32) -> HashSet<[u32; 3]> {
    let mut triplets = HashSet::new();
    
    // 使用Euclid公式生成原始毕达哥拉斯三元组
    // 对于互质的 m > n > 0,且 m,n 不同时为奇数:
    // a = m² - n²
    // b = 2mn
    // c = m² + n²
    
    let limit = ((sum as f64).sqrt() as u32) + 1;
    
    for m in 2..limit {
        for n in 1..m {
            // 确保 m 和 n 不同时为奇数
            if (m - n) % 2 == 1 {
                let a = m * m - n * n;
                let b = 2 * m * n;
                let c = m * m + n * n;
                
                // 确保 a < b(如果不满足则交换)
                let (a, b) = if a < b { (a, b) } else { (b, a) };
                
                let triplet_sum = a + b + c;
                
                // 如果基本三元组的和能整除目标和,则存在解
                if sum % triplet_sum == 0 {
                    let k = sum / triplet_sum;
                    triplets.insert([k * a, k * b, k * c]);
                }
            }
        }
    }
    
    triplets
}

测试用例分析

通过查看测试用例,我们可以更好地理解需求:

rust 复制代码
#[test]
fn test_triplets_whose_sum_is_12() {
    process_tripletswithsum_case(12, &[[3, 4, 5]]);
}

和为12的毕达哥拉斯三元组只有(3, 4, 5)。

rust 复制代码
#[test]
fn test_triplets_whose_sum_is_108() {
    process_tripletswithsum_case(108, &[[27, 36, 45]]);
}

和为108的毕达哥拉斯三元组只有(27, 36, 45)。

rust 复制代码
#[test]
fn test_triplets_whose_sum_is_1000() {
    process_tripletswithsum_case(1000, &[[200, 375, 425]]);
}

和为1000的毕达哥拉斯三元组只有(200, 375, 425)。

rust 复制代码
#[test]
fn test_no_matching_triplets_for_1001() {
    process_tripletswithsum_case(1001, &[]);
}

和为1001时没有符合条件的毕达哥拉斯三元组。

rust 复制代码
#[test]
fn test_returns_all_matching_triplets() {
    process_tripletswithsum_case(90, &[[9, 40, 41], [15, 36, 39]]);
}

和为90时有两个符合条件的毕达哥拉斯三元组。

rust 复制代码
#[test]
fn test_several_matching_triplets() {
    process_tripletswithsum_case(
        840,
        &[
            [40, 399, 401],
            [56, 390, 394],
            [105, 360, 375],
            [120, 350, 370],
            [140, 336, 364],
            [168, 315, 357],
            [210, 280, 350],
            [240, 252, 348],
        ],
    );
}

和为840时有8个符合条件的毕达哥拉斯三元组。

rust 复制代码
#[test]
fn test_triplets_for_large_number() {
    process_tripletswithsum_case(
        30_000,
        &[
            [1200, 14_375, 14_425],
            [1875, 14_000, 14_125],
            [5000, 12_000, 13_000],
            [6000, 11_250, 12_750],
            [7500, 10_000, 12_500],
        ],
    );
}

和为30000时有5个符合条件的毕达哥拉斯三元组。

性能优化版本

考虑性能的优化实现:

rust 复制代码
use std::collections::HashSet;

pub fn find(sum: u32) -> HashSet<[u32; 3]> {
    let mut triplets = HashSet::new();
    
    // 边界情况处理
    if sum < 12 {
        return triplets;
    }
    
    // 优化搜索范围
    let max_a = sum / 3;
    
    for a in 1..=max_a {
        // 使用数学公式直接计算 b
        // 从 a + b + c = sum 和 a² + b² = c² 可得:
        // b = (sum² - 2*sum*a) / (2*(sum - a))
        let sum_sq = (sum as u64) * (sum as u64);
        let numerator = sum_sq - 2 * (sum as u64) * (a as u64);
        let denominator = 2 * ((sum as u64) - (a as u64));
        
        if denominator > 0 && numerator % denominator == 0 {
            let b = (numerator / denominator) as u32;
            
            // 确保 b > a 且满足三元组条件
            if b > a {
                let c = sum - a - b;
                
                if c > b {
                    // 验证毕达哥拉斯定理(二次验证)
                    if (a as u64) * (a as u64) + (b as u64) * (b as u64) == (c as u64) * (c as u64) {
                        triplets.insert([a, b, c]);
                    }
                }
            }
        }
    }
    
    triplets
}

// 使用预分配的版本
pub fn find_with_capacity(sum: u32) -> HashSet<[u32; 3]> {
    // 预估容量以减少重新分配
    let estimated_capacity = (sum as f64).sqrt() as usize;
    let mut triplets = HashSet::with_capacity(estimated_capacity);
    
    if sum < 12 {
        return triplets;
    }
    
    let max_a = sum / 3;
    
    for a in 1..=max_a {
        let sum_sq = (sum as u64) * (sum as u64);
        let numerator = sum_sq - 2 * (sum as u64) * (a as u64);
        let denominator = 2 * ((sum as u64) - (a as u64));
        
        if denominator > 0 && numerator % denominator == 0 {
            let b = (numerator / denominator) as u32;
            
            if b > a {
                let c = sum - a - b;
                
                if c > b {
                    if (a as u64) * (a as u64) + (b as u64) * (b as u64) == (c as u64) * (c as u64) {
                        triplets.insert([a, b, c]);
                    }
                }
            }
        }
    }
    
    triplets
}

// 使用Vec代替HashSet的版本(如果不需要去重)
pub fn find_vec(sum: u32) -> Vec<[u32; 3]> {
    let mut triplets = Vec::new();
    
    if sum < 12 {
        return triplets;
    }
    
    let max_a = sum / 3;
    
    for a in 1..=max_a {
        let sum_sq = (sum as u64) * (sum as u64);
        let numerator = sum_sq - 2 * (sum as u64) * (a as u64);
        let denominator = 2 * ((sum as u64) - (a as u64));
        
        if denominator > 0 && numerator % denominator == 0 {
            let b = (numerator / denominator) as u32;
            
            if b > a {
                let c = sum - a - b;
                
                if c > b {
                    if (a as u64) * (a as u64) + (b as u64) * (b as u64) == (c as u64) * (c as u64) {
                        triplets.push([a, b, c]);
                    }
                }
            }
        }
    }
    
    triplets
}

错误处理和边界情况

考虑更多边界情况的实现:

rust 复制代码
use std::collections::HashSet;

#[derive(Debug, PartialEq)]
pub enum TripletError {
    InvalidSum,
    NoTripletsFound,
}

impl std::fmt::Display for TripletError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            TripletError::InvalidSum => write!(f, "无效的和值"),
            TripletError::NoTripletsFound => write!(f, "未找到符合条件的三元组"),
        }
    }
}

impl std::error::Error for TripletError {}

pub fn find(sum: u32) -> HashSet<[u32; 3]> {
    // 处理边界情况
    if sum < 12 {
        return HashSet::new();
    }
    
    let mut triplets = HashSet::new();
    let max_a = sum / 3;
    
    for a in 1..=max_a {
        let sum_sq = (sum as u64) * (sum as u64);
        let numerator = sum_sq - 2 * (sum as u64) * (a as u64);
        let denominator = 2 * ((sum as u64) - (a as u64));
        
        if denominator > 0 && numerator % denominator == 0 {
            let b = (numerator / denominator) as u32;
            
            if b > a {
                let c = sum - a - b;
                
                if c > b {
                    if (a as u64) * (a as u64) + (b as u64) * (b as u64) == (c as u64) * (c as u64) {
                        triplets.insert([a, b, c]);
                    }
                }
            }
        }
    }
    
    triplets
}

// 返回Result的版本
pub fn find_safe(sum: u32) -> Result<HashSet<[u32; 3]>, TripletError> {
    if sum < 12 {
        return Err(TripletError::InvalidSum);
    }
    
    let triplets = find(sum);
    
    if triplets.is_empty() {
        Err(TripletError::NoTripletsFound)
    } else {
        Ok(triplets)
    }
}

// 支持更大整数类型的版本
use std::collections::HashSet as StdHashSet;

pub fn find_u64(sum: u64) -> StdHashSet<[u64; 3]> {
    let mut triplets = StdHashSet::new();
    
    if sum < 12 {
        return triplets;
    }
    
    let max_a = sum / 3;
    
    for a in 1..=max_a {
        let sum_sq = sum * sum;
        let numerator = sum_sq - 2 * sum * a;
        let denominator = 2 * (sum - a);
        
        if denominator > 0 && numerator % denominator == 0 {
            let b = numerator / denominator;
            
            if b > a {
                let c = sum - a - b;
                
                if c > b {
                    if a * a + b * b == c * c {
                        triplets.insert([a, b, c]);
                    }
                }
            }
        }
    }
    
    triplets
}

扩展功能

基于基础实现,我们可以添加更多功能:

rust 复制代码
use std::collections::HashSet;

pub struct PythagoreanTripletFinder;

impl PythagoreanTripletFinder {
    pub fn new() -> Self {
        PythagoreanTripletFinder
    }
    
    pub fn find(&self, sum: u32) -> HashSet<[u32; 3]> {
        find(sum)
    }
    
    // 验证三元组是否为毕达哥拉斯三元组
    pub fn is_pythagorean_triplet(&self, triplet: &[u32; 3]) -> bool {
        let [a, b, c] = triplet;
        a * a + b * b == c * c
    }
    
    // 生成原始毕达哥拉斯三元组(互质的三元组)
    pub fn generate_primitive_triplets(&self, limit: u32) -> HashSet<[u32; 3]> {
        let mut triplets = HashSet::new();
        
        let m_limit = ((limit as f64).sqrt() as u32) + 1;
        
        for m in 2..m_limit {
            for n in 1..m {
                if (m - n) % 2 == 1 && gcd(m, n) == 1 {
                    let a = m * m - n * n;
                    let b = 2 * m * n;
                    let c = m * m + n * n;
                    
                    let (a, b) = if a < b { (a, b) } else { (b, a) };
                    
                    if a + b + c <= limit {
                        triplets.insert([a, b, c]);
                    }
                }
            }
        }
        
        triplets
    }
    
    // 从原始三元组生成所有三元组
    pub fn generate_all_triplets_from_primitive(&self, primitive: &[u32; 3], limit: u32) -> Vec<[u32; 3]> {
        let mut triplets = Vec::new();
        let [a, b, c] = primitive;
        let sum = a + b + c;
        
        let mut k = 1;
        while k * sum <= limit {
            triplets.push([k * a, k * b, k * c]);
            k += 1;
        }
        
        triplets
    }
    
    // 查找特定范围内的所有毕达哥拉斯三元组
    pub fn find_in_range(&self, min_sum: u32, max_sum: u32) -> Vec<(u32, HashSet<[u32; 3]>)> {
        let mut results = Vec::new();
        
        for sum in min_sum..=max_sum {
            let triplets = self.find(sum);
            if !triplets.is_empty() {
                results.push((sum, triplets));
            }
        }
        
        results
    }
    
    // 统计特定和值的三元组数量
    pub fn count_triplets(&self, sum: u32) -> usize {
        self.find(sum).len()
    }
}

// 计算最大公约数
fn gcd(mut a: u32, mut b: u32) -> u32 {
    while b != 0 {
        let temp = b;
        b = a % b;
        a = temp;
    }
    a
}

pub fn find(sum: u32) -> HashSet<[u32; 3]> {
    let mut triplets = HashSet::new();
    
    if sum < 12 {
        return triplets;
    }
    
    let max_a = sum / 3;
    
    for a in 1..=max_a {
        let sum_sq = (sum as u64) * (sum as u64);
        let numerator = sum_sq - 2 * (sum as u64) * (a as u64);
        let denominator = 2 * ((sum as u64) - (a as u64));
        
        if denominator > 0 && numerator % denominator == 0 {
            let b = (numerator / denominator) as u32;
            
            if b > a {
                let c = sum - a - b;
                
                if c > b {
                    if (a as u64) * (a as u64) + (b as u64) * (b as u64) == (c as u64) * (c as u64) {
                        triplets.insert([a, b, c]);
                    }
                }
            }
        }
    }
    
    triplets
}

// 毕达哥拉斯三元组分析
pub struct TripletAnalysis {
    pub sum: u32,
    pub triplets: HashSet<[u32; 3]>,
    pub count: usize,
    pub smallest_triplet: Option<[u32; 3]>,
    pub largest_triplet: Option<[u32; 3]>,
}

impl PythagoreanTripletFinder {
    pub fn analyze(&self, sum: u32) -> TripletAnalysis {
        let triplets = self.find(sum);
        let count = triplets.len();
        
        let smallest_triplet = triplets.iter().min().copied();
        let largest_triplet = triplets.iter().max().copied();
        
        TripletAnalysis {
            sum,
            triplets,
            count,
            smallest_triplet,
            largest_triplet,
        }
    }
}

// 便利函数
pub fn is_pythagorean_triplet(triplet: &[u32; 3]) -> bool {
    let [a, b, c] = triplet;
    a * a + b * b == c * c
}

pub fn format_triplets(triplets: &HashSet<[u32; 3]>) -> String {
    let mut formatted = Vec::new();
    
    for triplet in triplets {
        formatted.push(format!("({}, {}, {})", triplet[0], triplet[1], triplet[2]));
    }
    
    formatted.sort();
    formatted.join(", ")
}

实际应用场景

毕达哥拉斯三元组在实际开发中有以下应用:

  1. 数学软件:数学计算和教育工具
  2. 游戏开发:几何游戏和益智游戏
  3. 图形学:计算机图形学中的几何计算
  4. 密码学:数论相关算法
  5. 教育工具:数学教学演示
  6. 算法竞赛:数学算法问题解决
  7. 工程计算:三角形相关计算
  8. 科学研究:数论研究工具

算法复杂度分析

  1. 时间复杂度

    • 暴力搜索:O(n²)
    • 优化实现:O(n)
    • Euclid公式:O(√n)
  2. 空间复杂度:O(k)

    • 其中k是符合条件的三元组数量

与其他实现方式的比较

rust 复制代码
// 使用递归的实现
pub fn find_recursive(sum: u32) -> HashSet<[u32; 3]> {
    fn find_helper(sum: u32, a: u32, b: u32) -> HashSet<[u32; 3]> {
        if a >= sum / 3 {
            return HashSet::new();
        }
        
        if b >= sum / 2 {
            return find_helper(sum, a + 1, a + 2);
        }
        
        let c = sum - a - b;
        
        let mut triplets = if c > b && a * a + b * b == c * c {
            let mut set = HashSet::new();
            set.insert([a, b, c]);
            set
        } else {
            HashSet::new()
        };
        
        triplets.extend(find_helper(sum, a, b + 1));
        triplets
    }
    
    find_helper(sum, 1, 2)
}

// 使用第三方库的实现
// [dependencies]
// num = "0.4"

use num::integer::gcd;

pub fn find_with_num_library(sum: u32) -> HashSet<[u32; 3]> {
    let mut triplets = HashSet::new();
    
    if sum < 12 {
        return triplets;
    }
    
    let max_a = sum / 3;
    
    for a in 1..=max_a {
        let sum_sq = (sum as u64) * (sum as u64);
        let numerator = sum_sq - 2 * (sum as u64) * (a as u64);
        let denominator = 2 * ((sum as u64) - (a as u64));
        
        if denominator > 0 && numerator % denominator == 0 {
            let b = (numerator / denominator) as u32;
            
            if b > a {
                let c = sum - a - b;
                
                if c > b {
                    if (a as u64) * (a as u64) + (b as u64) * (b as u64) == (c as u64) * (c as u64) {
                        // 使用第三方库的gcd函数验证是否为原始三元组
                        if gcd(gcd(a, b), c) == 1 {
                            // 这是一个原始三元组
                        }
                        triplets.insert([a, b, c]);
                    }
                }
            }
        }
    }
    
    triplets
}

// 使用并行计算的实现
// [dependencies]
// rayon = "1.5"

use rayon::prelude::*;

pub fn find_parallel(sum: u32) -> HashSet<[u32; 3]> {
    if sum < 12 {
        return HashSet::new();
    }
    
    let max_a = sum / 3;
    
    (1..=max_a)
        .into_par_iter()
        .filter_map(|a| {
            let sum_sq = (sum as u64) * (sum as u64);
            let numerator = sum_sq - 2 * (sum as u64) * (a as u64);
            let denominator = 2 * ((sum as u64) - (a as u64));
            
            if denominator > 0 && numerator % denominator == 0 {
                let b = (numerator / denominator) as u32;
                
                if b > a {
                    let c = sum - a - b;
                    
                    if c > b {
                        if (a as u64) * (a as u64) + (b as u64) * (b as u64) == (c as u64) * (c as u64) {
                            Some([a, b, c])
                        } else {
                            None
                        }
                    } else {
                        None
                    }
                } else {
                    None
                }
            } else {
                None
            }
        })
        .collect()
}

// 使用生成器模式的实现
pub struct TripletGenerator {
    sum: u32,
    a: u32,
    b: u32,
}

impl TripletGenerator {
    pub fn new(sum: u32) -> Self {
        TripletGenerator { sum, a: 1, b: 2 }
    }
}

impl Iterator for TripletGenerator {
    type Item = [u32; 3];
    
    fn next(&mut self) -> Option<Self::Item> {
        while self.a < self.sum / 3 {
            while self.b < self.sum / 2 {
                let c = self.sum - self.a - self.b;
                
                if c > self.b {
                    if self.a * self.a + self.b * self.b == c * c {
                        let triplet = [self.a, self.b, c];
                        self.b += 1;
                        return Some(triplet);
                    }
                }
                
                self.b += 1;
            }
            
            self.a += 1;
            self.b = self.a + 1;
        }
        
        None
    }
}

总结

通过 pythagorean-triplet 练习,我们学到了:

  1. 数学算法:掌握了毕达哥拉斯三元组的数学性质和计算方法
  2. 优化技巧:学会了使用数学公式优化搜索算法
  3. 集合操作:深入理解了HashSet的使用和去重机制
  4. 边界处理:理解了如何处理各种边界情况
  5. 性能优化:学会了算法复杂度分析和优化技巧
  6. 设计模式:理解了生成器模式和工厂模式的应用

这些技能在实际开发中非常有用,特别是在数学计算、算法设计、游戏开发等场景中。毕达哥拉斯三元组虽然是一个数学问题,但它涉及到了算法优化、数学计算、集合操作、性能优化等许多核心概念,是学习Rust实用编程的良好起点。

通过这个练习,我们也看到了Rust在数学计算和算法实现方面的强大能力,以及如何用安全且高效的方式实现经典数学算法。这种结合了安全性和性能的语言特性正是Rust的魅力所在。

相关推荐
星释1 小时前
Rust 练习册 :Nth Prime与素数算法
开发语言·算法·rust
lkbhua莱克瓦242 小时前
Java基础——集合进阶3
java·开发语言·笔记
多喝开水少熬夜2 小时前
Trie树相关算法题java实现
java·开发语言·算法
QT 小鲜肉2 小时前
【QT/C++】Qt定时器QTimer类的实现方法详解(超详细)
开发语言·数据库·c++·笔记·qt·学习
WBluuue3 小时前
数据结构与算法:树上倍增与LCA
数据结构·c++·算法
lsx2024063 小时前
MySQL WHERE 子句详解
开发语言
bruk_spp3 小时前
牛客网华为在线编程题
算法
Tony Bai3 小时前
【Go模块构建与依赖管理】09 企业级实践:私有仓库与私有 Proxy
开发语言·后端·golang
Lucky小小吴3 小时前
开源项目5——Go版本快速管理工具
开发语言·golang·开源