基本数学问题
一、判断闰年
闰年规则:
- 能被400整除的年份 → 是闰年
- 能被100整除但不能被400整除的年份 → 不是闰年
- 能被4整除但不能被100整除的年份 → 是闰年
- 其他年份 → 不是闰年
Java实现,基础版本:
java
/**
 * Reports whether a year is a leap year under the Gregorian rules.
 *
 * @param year the year to test; must be positive
 * @return {@code true} for a leap year, {@code false} for a common year
 * @throws IllegalArgumentException if {@code year} is zero or negative
 */
public static boolean isLeapYear(int year) {
    if (year < 1) {
        throw new IllegalArgumentException("年份必须为正整数");
    }
    // Century years are leap years only when they are also divisible by 400.
    if (year % 100 == 0) {
        return year % 400 == 0;
    }
    // Every other year is a leap year exactly when it is divisible by 4.
    return year % 4 == 0;
}
/**
 * Leap-year check with the Gregorian rules folded into a single boolean expression.
 *
 * @param year the year to test; must be positive
 * @return {@code true} for a leap year, {@code false} for a common year
 * @throws IllegalArgumentException if {@code year} is zero or negative
 */
public static boolean isLeapYearSimple(int year) {
    if (year < 1) {
        throw new IllegalArgumentException("年份必须为正整数");
    }
    boolean divisibleBy400 = year % 400 == 0;
    boolean divisibleBy4 = year % 4 == 0;
    boolean centuryYear = year % 100 == 0;
    return divisibleBy400 || (divisibleBy4 && !centuryYear);
}
/**
 * Leap-year check that delegates the calendar rule to {@link java.time.Year}.
 *
 * @param year the year to test; must be positive
 * @return {@code true} for a leap year, {@code false} for a common year
 * @throws IllegalArgumentException if {@code year} is zero or negative
 */
public static boolean isLeapYearUsingJavaTime(int year) {
    if (year < 1) {
        throw new IllegalArgumentException("年份必须为正整数");
    }
    java.time.Year target = java.time.Year.of(year);
    return target.isLeap();
}
// Smoke test: prints the verdict of both hand-written implementations for a set of
// representative years, then the exception message produced by non-positive input.
@Test
public void test() {
    int[] testYears = {
        1900, // common year (century not divisible by 400)
        2000, // leap year (divisible by 400)
        2020, // leap year
        2021, // common year
        2024, // leap year
        2100, // common year
        2400, // leap year
        1,    // common year
        4,    // leap year
        100,  // common year
        400   // leap year
    };

    System.out.println("=== 手动实现(基础版本)===");
    for (int i = 0; i < testYears.length; i++) {
        String verdict = MathDemo.isLeapYear(testYears[i]) ? "闰年" : "平年";
        System.out.println(testYears[i] + "年: " + verdict);
    }

    System.out.println("\n=== 手动实现(简化版本)===");
    for (int i = 0; i < testYears.length; i++) {
        String verdict = MathDemo.isLeapYearSimple(testYears[i]) ? "闰年" : "平年";
        System.out.println(testYears[i] + "年: " + verdict);
    }

    // Boundary inputs: zero and a negative year must both be rejected.
    System.out.println("\n=== 异常处理测试 ===");
    for (int invalidYear : new int[] {0, -2024}) {
        String label = invalidYear + "年: ";
        try {
            System.out.println(label + MathDemo.isLeapYear(invalidYear));
        } catch (IllegalArgumentException e) {
            System.out.println(label + e.getMessage());
        }
    }
}
=== 手动实现(基础版本)===
1900年: 平年
2000年: 闰年
2020年: 闰年
2021年: 平年
2024年: 闰年
2100年: 平年
2400年: 闰年
1年: 平年
4年: 闰年
100年: 平年
400年: 闰年
=== 手动实现(简化版本)===
1900年: 平年
2000年: 闰年
2020年: 闰年
2021年: 平年
2024年: 闰年
2100年: 平年
2400年: 闰年
1年: 平年
4年: 闰年
100年: 平年
400年: 闰年
=== 异常处理测试 ===
0年: 年份必须为正整数
-2024年: 年份必须为正整数
二、随机数生成算法
2.1、Java中的随机方法
编程语言中提供了随机数的生成方法。例如,Java语言中就提供了如下三种随机数生成方法:
- Math.random()方法,产生的随机数是[0, 1)之间的一个double,可以把它乘以一定的数,比如乘以100,它就是100以内的随机数。
- java.util包里面提供Random的类,可以新建一个Random的对象来产生随机数,它可以产生随机整数、随机float、随机double、随机long。
- System类中有一个currentTimeMillis()方法,该方法返回一个从1970年1月1日0点0分0秒到目前的一个毫秒数,返回类型是long,可以将它作为一个随机数,对一些数取模,就可以把它限制在一个范围之内。
2.2、Math.random()
java
/**
 * Demonstrates common {@link Math#random()} recipes, printing one sample of each.
 */
public void random1() {
    System.out.println("=== 使用 Math.random() ===");

    // Raw value: a double in [0.0, 1.0).
    double unit = Math.random();
    System.out.println("随机double [0,1): " + unit);

    // Scale by 10 and truncate: an integer in 0..9.
    int zeroToNine = (int) (Math.random() * 10);
    System.out.println("随机整数 [0,9]: " + zeroToNine);

    // Same recipe shifted by one: an integer in 1..10.
    int oneToTen = 1 + (int) (Math.random() * 10);
    System.out.println("随机整数 [1,10]: " + oneToTen);

    // General inclusive integer range [low, high].
    final int low = 5;
    final int high = 15;
    final int span = high - low + 1;
    int inRange = low + (int) (Math.random() * span);
    System.out.println("随机整数 [5,15]: " + inRange);

    // General half-open double range [from, to).
    final double from = 2.5;
    final double to = 7.5;
    double scaled = from + Math.random() * (to - from);
    System.out.println("随机double [2.5,7.5): " + String.format("%.2f", scaled));
}
2.3、java.util.Random
java
/**
 * Demonstrates the value types {@link Random} can produce, printing one sample of each.
 */
public void random2() {
    Random generator = new Random();
    System.out.println("\n=== 使用 java.util.Random ===");

    // Any int value.
    int anyInt = generator.nextInt();
    System.out.println("int范围内的随机整数: " + anyInt);

    // Bounded int: 0 (inclusive) to 10 (exclusive).
    int belowTen = generator.nextInt(10);
    System.out.println("随机整数 [0,9]: " + belowTen);

    // Inclusive integer range [low, high].
    final int low = 5;
    final int high = 15;
    int inRange = low + generator.nextInt(high - low + 1);
    System.out.println("随机整数 [5,15]: " + inRange);

    // Any long value.
    long anyLong = generator.nextLong();
    System.out.println("随机long: " + anyLong);

    // float in [0.0, 1.0).
    float unitFloat = generator.nextFloat();
    System.out.println("随机float: " + unitFloat);

    // double in [0.0, 1.0).
    double unitDouble = generator.nextDouble();
    System.out.println("随机double: " + unitDouble);

    // Coin flip.
    boolean coin = generator.nextBoolean();
    System.out.println("随机boolean: " + coin);

    // Fill a byte buffer with random content and print it space-separated.
    byte[] buffer = new byte[5];
    generator.nextBytes(buffer);
    System.out.print("随机字节数组: ");
    for (int i = 0; i < buffer.length; i++) {
        System.out.print(buffer[i] + " ");
    }
    System.out.println();
}
2.4、ThreadLocalRandom
java
/**
 * Demonstrates {@link ThreadLocalRandom}, including its ranged overloads,
 * printing one sample of each.
 */
public void random3() {
    System.out.println("\n=== 使用 ThreadLocalRandom (多线程推荐) ===");

    // The instance bound to the calling thread.
    ThreadLocalRandom local = ThreadLocalRandom.current();

    int anyInt = local.nextInt();
    System.out.println("随机整数: " + anyInt);

    // Upper bounds are exclusive, so (1, 11) yields 1..10.
    int oneToTen = local.nextInt(1, 11);
    System.out.println("随机整数 [1,10]: " + oneToTen);

    double fiveToTen = local.nextDouble(5.0, 10.0);
    System.out.println("随机double [5.0,10.0): " + String.format("%.2f", fiveToTen));

    long fourDigits = local.nextLong(1000L, 10000L);
    System.out.println("随机long [1000,9999]: " + fourDigits);

    // Sample from the standard normal distribution.
    double normalSample = local.nextGaussian();
    System.out.println("高斯分布随机数: " + String.format("%.4f", normalSample));
}
2.5、SecureRandom
java
/**
 * Demonstrates {@link SecureRandom}: default instance, bounded ints, raw bytes
 * printed as hex, and an instance requested by algorithm name.
 *
 * @throws NoSuchAlgorithmException if the {@code SHA1PRNG} algorithm is not available
 */
public void random4() throws NoSuchAlgorithmException {
    System.out.println("\n=== 使用 SecureRandom (加密安全) ===");
    SecureRandom strong = new SecureRandom();

    // Any int value.
    int anyInt = strong.nextInt();
    System.out.println("安全随机整数: " + anyInt);

    // Bounded int: 0 (inclusive) to 100 (exclusive).
    int belowHundred = strong.nextInt(100);
    System.out.println("安全随机整数 [0,99]: " + belowHundred);

    // Sixteen random bytes, printed as two lowercase hex digits each.
    byte[] buffer = new byte[16];
    strong.nextBytes(buffer);
    System.out.print("16字节安全随机数: ");
    for (int i = 0; i < buffer.length; i++) {
        System.out.printf("%02x", buffer[i]);
    }
    System.out.println();

    // Request a specific algorithm by name.
    SecureRandom named = SecureRandom.getInstance("SHA1PRNG");
    int belowThousand = named.nextInt(1000);
    System.out.println("SHA1PRNG随机数: " + belowThousand);
}
2.6、随机数工具类
java
/**
 * Generates a random password containing at least one uppercase letter, one
 * lowercase letter, one digit and one special character.
 *
 * @param length the exact password length; must be at least 4 so that every
 *               character class can be represented
 * @return a password of exactly {@code length} characters
 * @throws IllegalArgumentException if {@code length} is less than 4
 */
public static String generateSecurePassword(int length) {
    final String uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    final String lowercase = "abcdefghijklmnopqrstuvwxyz";
    final String digits = "0123456789";
    final String special = "!@#$%^&*()_+-=[]{}|;:,.<>?";
    final String allChars = uppercase + lowercase + digits + special;

    // Four mandatory characters are always emitted, so a shorter request
    // cannot be honoured and is rejected instead of returning the wrong length.
    if (length < 4) {
        throw new IllegalArgumentException("密码长度不能小于4: " + length);
    }

    char[] password = new char[length];

    // One character from each class guarantees the composition rule.
    password[0] = uppercase.charAt(SECURE_RANDOM.nextInt(uppercase.length()));
    password[1] = lowercase.charAt(SECURE_RANDOM.nextInt(lowercase.length()));
    password[2] = digits.charAt(SECURE_RANDOM.nextInt(digits.length()));
    password[3] = special.charAt(SECURE_RANDOM.nextInt(special.length()));

    // The remaining positions draw from the full alphabet.
    for (int i = 4; i < length; i++) {
        password[i] = allChars.charAt(SECURE_RANDOM.nextInt(allChars.length()));
    }

    // Fisher-Yates shuffle so the mandatory characters are not always in front.
    for (int i = password.length - 1; i > 0; i--) {
        int j = SECURE_RANDOM.nextInt(i + 1);
        char swapped = password[i];
        password[i] = password[j];
        password[j] = swapped;
    }
    return new String(password);
}
// ========== 高级随机功能 ==========
/**
 * Picks one element of an array at random.
 *
 * @param array the candidates; may be {@code null} or empty
 * @return a randomly chosen element, or {@code null} when there is nothing to choose from
 */
public static <T> T randomChoice(T[] array) {
    boolean nothingToPick = array == null || array.length == 0;
    return nothingToPick ? null : array[RANDOM.nextInt(array.length)];
}
/**
 * Picks one element of a list at random.
 *
 * @param list the candidates; may be {@code null} or empty
 * @return a randomly chosen element, or {@code null} when there is nothing to choose from
 */
public static <T> T randomChoice(List<T> list) {
    boolean nothingToPick = list == null || list.isEmpty();
    if (nothingToPick) {
        return null;
    }
    int index = RANDOM.nextInt(list.size());
    return list.get(index);
}
/**
 * Draws up to {@code sampleSize} elements from an array without choosing the
 * same position twice.
 *
 * @param array      the candidates; may be {@code null} or empty
 * @param sampleSize the number of elements wanted; values larger than the
 *                   array length are capped at the array length
 * @return a new, independent, mutable list holding the sample; empty when the
 *         array is {@code null} or empty or {@code sampleSize} is not positive
 */
public static <T> List<T> randomSample(T[] array, int sampleSize) {
    if (array == null || array.length == 0 || sampleSize <= 0) {
        return new ArrayList<>();
    }
    List<T> shuffled = new ArrayList<>(Arrays.asList(array));
    Collections.shuffle(shuffled, RANDOM);
    int count = Math.min(sampleSize, shuffled.size());
    // Copy the prefix instead of returning a subList view, so the result does
    // not keep the whole shuffled copy alive and matches the empty-case return type.
    return new ArrayList<>(shuffled.subList(0, count));
}
/**
 * Draws a normally distributed value with the given mean and standard deviation.
 *
 * @param mean   the centre of the distribution
 * @param stdDev the standard deviation used to scale the standard-normal sample
 * @return the scaled and shifted sample
 */
public static double nextGaussian(double mean, double stdDev) {
    double standardSample = RANDOM.nextGaussian();
    return mean + stdDev * standardSample;
}
/**
 * Shuffles an array in place with the Fisher-Yates algorithm.
 *
 * @param array the array whose elements are reordered
 */
public static <T> void shuffleArray(T[] array) {
    int last = array.length - 1;
    while (last > 0) {
        // Swap the current tail with a random element at or before it.
        int pick = RANDOM.nextInt(last + 1);
        T held = array[last];
        array[last] = array[pick];
        array[pick] = held;
        last--;
    }
}
/**
 * Shuffles a list in place by delegating to {@code Collections.shuffle} with
 * the {@code RANDOM} generator.
 *
 * @param list the list whose elements are reordered
 */
public static <T> void shuffleList(List<T> list) {
Collections.shuffle(list, RANDOM);
}
/**
 * Produces a random colour packed as {@code 0xRRGGBB}.
 *
 * @return an int whose low 24 bits hold the red, green and blue channels
 */
public static int randomColor() {
    int packed = 0;
    // Channels are drawn in red, green, blue order; each shift moves the
    // earlier channels one byte higher.
    for (int channel = 0; channel < 3; channel++) {
        packed = (packed << 8) | RANDOM.nextInt(256);
    }
    return packed;
}
// ========== Demo entry point ==========
/**
 * Exercises the random utility methods and prints their results: basic,
 * thread-local and secure generators, selection and sampling, normal
 * distribution, shuffling, colour generation, a simple timing loop, and two
 * threads drawing from {@code ThreadLocalRandom}.
 */
public static void main(String[] args) {
System.out.println("=== 随机数工具类测试 ===\n");
// 1. Basic random values
System.out.println("1. 基本随机数:");
System.out.println("随机整数 [1, 100]: " + randomInt(1, 100));
System.out.println("随机double [0.0, 1.0): " + String.format("%.4f", randomDouble()));
System.out.println("随机double [10.5, 20.5): " + String.format("%.2f", randomDouble(10.5, 20.5)));
System.out.println("随机boolean (70% true): " + randomBoolean(0.7));
// 2. Thread-safe random values
System.out.println("\n2. 线程安全随机数:");
System.out.println("随机整数 [50, 100]: " + threadSafeRandomInt(50, 100));
// 3. Cryptographically secure random values
System.out.println("\n3. 加密安全随机数:");
System.out.println("安全随机整数 [1000, 9999]: " + secureRandomInt(1000, 9999));
System.out.println("安全随机字符串 (8位): " + secureRandomString(8));
System.out.println("安全随机密码 (12位): " + generateSecurePassword(12));
// 4. Selection and sampling
System.out.println("\n4. 高级随机功能:");
String[] fruits = {"苹果", "香蕉", "橙子", "葡萄", "西瓜", "芒果"};
System.out.println("随机选择水果: " + randomChoice(fruits));
List<String> sample = randomSample(fruits, 3);
System.out.println("随机选择3种水果: " + sample);
// 5. Normally distributed values
System.out.println("\n5. 正态分布随机数:");
for (int i = 0; i < 5; i++) {
double gaussian = nextGaussian(100, 15); // mean 100, standard deviation 15
System.out.println("正态分布: " + String.format("%.2f", gaussian));
}
// 6. Shuffle an array in place
Integer[] numbers = {1, 2, 3, 4, 5, 6, 7, 8, 9};
shuffleArray(numbers);
System.out.println("\n6. 打乱的数组: " + Arrays.toString(numbers));
// 7. Random colour, masked to 24 bits and printed as six hex digits
int color = randomColor();
System.out.println("\n7. 随机颜色 (RGB): #" +
String.format("%06X", color & 0xFFFFFF));
// 8. Rough timing of 100,000 calls (wall-clock, not a rigorous benchmark)
System.out.println("\n8. 性能测试:");
long start = System.currentTimeMillis();
for (int i = 0; i < 100000; i++) {
randomInt(1, 100);
}
long end = System.currentTimeMillis();
System.out.println("生成10万个随机数耗时: " + (end - start) + "ms");
// 9. Two threads each print five values from their own ThreadLocalRandom;
// the threads are started but not joined, so their output may interleave.
System.out.println("\n9. 多线程测试:");
Runnable task = () -> {
for (int i = 0; i < 5; i++) {
int num = ThreadLocalRandom.current().nextInt(1, 100);
System.out.println(Thread.currentThread().getName() + ": " + num);
}
};
Thread t1 = new Thread(task, "线程1");
Thread t2 = new Thread(task, "线程2");
t1.start();
t2.start();
}
}
三、π近似值
3.1、割圆术
割圆术原理:
- 从正六边形开始(内接于单位圆)
- 不断倍增边数:6 → 12 → 24 → 48 → 96 → ...
- 用正多边形的周长逼近圆周长
- 计算 π 的近似值
数学公式:
- 设单位圆半径为 1:
- 正 n 边形的边长:aₙ
- 正 2n 边形的边长:a₂ₙ
- 递推公式:a₂ₙ = √(2 - √(4 - aₙ²))
Java实现:
java
/**
 * Approximates pi with Liu Hui's polygon method: starting from a regular
 * hexagon inscribed in the unit circle, the number of sides is doubled
 * {@code iterations} times and half the perimeter is returned.
 *
 * <p>The side length is updated with
 * {@code a2n = an / sqrt(2 + sqrt(4 - an^2))}, which is algebraically equal to
 * the textbook form {@code sqrt(2 - sqrt(4 - an^2))} but avoids subtracting two
 * nearly equal numbers once the sides become tiny. The side count is kept as a
 * {@code double} so that doubling it cannot overflow an {@code int}.
 *
 * @param iterations how many times the side count is doubled; values of zero
 *                   or less return the hexagon estimate {@code 3.0}
 * @return the approximation of pi
 */
public static double calculatePiByInscribedPolygon(int iterations) {
    // Double precision is exhausted long before this many doublings; the cap
    // only keeps the side count and side length inside the double range.
    final int maxUsefulIterations = 500;
    int rounds = Math.min(iterations, maxUsefulIterations);

    double sides = 6.0;       // regular hexagon
    double sideLength = 1.0;  // hexagon side equals the unit radius
    for (int i = 0; i < rounds; i++) {
        double inner = Math.sqrt(4.0 - sideLength * sideLength);
        sideLength = sideLength / Math.sqrt(2.0 + inner);
        sides *= 2.0;
    }
    // Half the perimeter approximates pi for a circle of radius 1.
    return sides * sideLength / 2.0;
}
/**
 * Polygon-doubling approximation of pi that prints one table row (side count,
 * estimate, absolute error against {@link Math#PI}) per iteration.
 *
 * <p>The side length is updated with the cancellation-free form
 * {@code a2n = an / sqrt(2 + sqrt(4 - an^2))}, and the side count is held in a
 * {@code double} so it cannot overflow an {@code int}. Beyond roughly one
 * thousand iterations the values leave the range of {@code double}.
 *
 * @param iterations the number of table rows to compute
 * @return the estimate shown in the last printed row, or {@code 0} when
 *         {@code iterations} is zero or negative
 */
public static double calculatePiImproved(int iterations) {
    double sides = 6.0;
    double sideLength = 1.0;
    double piApprox = 0;
    System.out.println("迭代过程:");
    System.out.println("边数\t\tπ近似值\t\t\t误差");
    System.out.println("---------------------------------------------");
    for (int i = 0; i < iterations; i++) {
        // Estimate for the current polygon: half of its perimeter.
        piApprox = sides * sideLength / 2.0;
        double error = Math.abs(Math.PI - piApprox);
        // %.0f prints the whole-number side count without a fractional part.
        System.out.printf("%.0f\t\t%.10f\t%.10f%n", sides, piApprox, error);

        // Side length of the polygon with twice as many sides.
        double inner = Math.sqrt(4.0 - sideLength * sideLength);
        sideLength = sideLength / Math.sqrt(2.0 + inner);
        sides *= 2.0;
    }
    return piApprox;
}
3.2、蒙特卡罗算法
蒙特卡罗法计算 π 的原理:
- 基本思想:
- 在一个边长为 2 的正方形中,内接一个半径为 1 的圆:
- 正方形面积:S_square = 4
- 圆的面积:S_circle = π
- 圆的面积与正方形面积之比:π/4
- 算法步骤:
- 在正方形区域内随机生成 N 个点
- 统计落在圆内的点数 M
- 计算比例:M/N ≈ 圆的面积/正方形面积 = π/4
- 计算 π 的近似值:π ≈ 4 × M/N
Java实现:
java
/**
 * Monte Carlo estimators of pi: random points are thrown at a square and the
 * fraction landing inside the inscribed circle approximates {@code pi / 4}.
 */
public class MonteCarloPi {

    /**
     * Estimates pi with a fresh {@link Random}, sampling the square
     * {@code [-1, 1) x [-1, 1)}.
     *
     * @param totalPoints number of random points; must be positive
     * @return the estimate {@code 4 * inside / totalPoints}
     * @throws IllegalArgumentException if {@code totalPoints} is not positive
     */
    public static double calculatePiBasic(long totalPoints) {
        return estimateOnFullSquare(new Random(), totalPoints);
    }

    /**
     * Estimates pi by sampling only the first quadrant {@code [0, 1) x [0, 1)}.
     * The quarter circle covers the same {@code pi / 4} share of the unit
     * square, so the scale-and-shift arithmetic per coordinate is not needed.
     *
     * @param totalPoints number of random points; must be positive
     * @return the estimate {@code 4 * inside / totalPoints}
     * @throws IllegalArgumentException if {@code totalPoints} is not positive
     */
    public static double calculatePiOptimized(long totalPoints) {
        requirePositive(totalPoints);
        Random random = new Random();
        long insideCircle = 0;
        for (long i = 0; i < totalPoints; i++) {
            double x = random.nextDouble();
            double y = random.nextDouble();
            if (x * x + y * y <= 1.0) {
                insideCircle++;
            }
        }
        return 4.0 * insideCircle / totalPoints;
    }

    /**
     * Estimates pi using the calling thread's {@link ThreadLocalRandom}, which
     * is looked up once rather than on every iteration.
     *
     * @param totalPoints number of random points; must be positive
     * @return the estimate {@code 4 * inside / totalPoints}
     * @throws IllegalArgumentException if {@code totalPoints} is not positive
     */
    public static double calculatePiWithThreadLocalRandom(long totalPoints) {
        return estimateOnFullSquare(ThreadLocalRandom.current(), totalPoints);
    }

    /**
     * Estimates pi using {@link SecureRandom} as the source of randomness.
     *
     * @param totalPoints number of random points; must be positive
     * @return the estimate {@code 4 * inside / totalPoints}
     * @throws IllegalArgumentException if {@code totalPoints} is not positive
     */
    public static double calculatePiSecure(long totalPoints) {
        return estimateOnFullSquare(new SecureRandom(), totalPoints);
    }

    /** Shared sampling loop over the square [-1, 1) x [-1, 1). */
    private static double estimateOnFullSquare(Random random, long totalPoints) {
        requirePositive(totalPoints);
        long insideCircle = 0;
        for (long i = 0; i < totalPoints; i++) {
            double x = random.nextDouble() * 2 - 1;
            double y = random.nextDouble() * 2 - 1;
            if (x * x + y * y <= 1.0) {
                insideCircle++;
            }
        }
        return 4.0 * insideCircle / totalPoints;
    }

    /** Rejects point counts for which the ratio would be 0/0 (NaN). */
    private static void requirePositive(long totalPoints) {
        if (totalPoints <= 0) {
            throw new IllegalArgumentException("总点数必须为正数: " + totalPoints);
        }
    }

    /**
     * Prints the estimate, absolute error, elapsed time and relative error of
     * the basic estimator for increasing point counts.
     */
    public static void main(String[] args) {
        System.out.println("=== 蒙特卡罗法计算 π 值 ===\n");
        long[] pointCounts = {1000L, 10000L, 100000L, 1000000L, 10000000L, 100000000L};
        for (long points : pointCounts) {
            System.out.printf("点数: %,d%n", points);
            long startTime = System.currentTimeMillis();
            double pi = calculatePiBasic(points);
            long endTime = System.currentTimeMillis();
            System.out.printf(" π ≈ %.10f%n", pi);
            System.out.printf(" 误差: %.10f%n", Math.abs(Math.PI - pi));
            System.out.printf(" 时间: %,d ms%n", endTime - startTime);
            System.out.printf(" 相对误差: %.6f%%%n", Math.abs(Math.PI - pi) / Math.PI * 100);
            System.out.println();
        }
    }
}
=== 蒙特卡罗法计算 π 值 ===
点数: 1,000
π ≈ 3.2080000000
误差: 0.0664073464
时间: 2 ms
相对误差: 2.113811%
点数: 10,000
π ≈ 3.1492000000
误差: 0.0076073464
时间: 1 ms
相对误差: 0.242149%
点数: 100,000
π ≈ 3.1386400000
误差: 0.0029526536
时间: 8 ms
相对误差: 0.093986%
点数: 1,000,000
π ≈ 3.1429160000
误差: 0.0013233464
时间: 43 ms
相对误差: 0.042123%
点数: 10,000,000
π ≈ 3.1416544000
误差: 0.0000617464
时间: 358 ms
相对误差: 0.001965%
点数: 100,000,000
π ≈ 3.1416113200
误差: 0.0000186664
时间: 3,522 ms
相对误差: 0.000594%
3.3、级数公式
常见级数公式:
- 莱布尼茨级数:收敛最慢
- 尼拉坎塔级数:收敛比莱布尼茨快
- 马青公式:收敛较快,适合计算机计算
- 拉马努金公式:收敛极快,每项增加约8位精度
- 楚德诺夫斯基公式:收敛最快,每项增加约14位精度
Java实现:
java
/**
 * Series-based approximations of pi.
 */
public class SeriesPiCalculator {

    /**
     * Leibniz series: {@code pi/4 = 1 - 1/3 + 1/5 - 1/7 + 1/9 - ...}
     *
     * <p>Convergence is very slow: the error after N terms is on the order of
     * {@code 1/N}, so each additional correct decimal digit needs roughly ten
     * times as many terms.
     */
    public static class LeibnizSeries {

        /**
         * Sums the first {@code terms} terms of the series in double precision.
         *
         * @param terms number of terms to add; zero or less yields {@code 0.0}
         * @return four times the partial sum
         */
        public static double calculatePi(int terms) {
            double piOver4 = 0.0;
            for (int n = 0; n < terms; n++) {
                // The denominator is computed in double: 2 * n + 1 would
                // overflow int once n reaches 2^30.
                double term = 1.0 / (2.0 * n + 1.0);
                if (n % 2 == 0) {
                    piOver4 += term; // 1st, 3rd, 5th ... terms are added
                } else {
                    piOver4 -= term; // 2nd, 4th, 6th ... terms are subtracted
                }
            }
            return piOver4 * 4.0;
        }
    }

    /**
     * Sums the Leibniz series with {@link BigDecimal} arithmetic.
     *
     * @param terms     number of terms to add
     * @param precision number of decimal places in the returned value; the
     *                  intermediate arithmetic carries ten extra significant digits
     * @return four times the partial sum, rounded half-up to {@code precision} places
     */
    public static BigDecimal calculatePiHighPrecision(int terms, int precision) {
        MathContext mc = new MathContext(precision + 10, RoundingMode.HALF_UP);
        BigDecimal piOver4 = BigDecimal.ZERO;
        for (int n = 0; n < terms; n++) {
            // long arithmetic keeps 2n + 1 from overflowing for large n.
            BigDecimal denominator = BigDecimal.valueOf(2L * n + 1L);
            BigDecimal term = BigDecimal.ONE.divide(denominator, mc);
            piOver4 = (n % 2 == 0) ? piOver4.add(term, mc) : piOver4.subtract(term, mc);
        }
        return piOver4.multiply(BigDecimal.valueOf(4), mc)
                .setScale(precision, RoundingMode.HALF_UP);
    }
}
四、矩阵运算
java
import java.util.Arrays;
import java.util.Random;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
/**
 * Dense matrix of {@code double} values with element access, arithmetic,
 * reductions, structural predicates and element-wise transforms.
 *
 * <p>Arithmetic and transform methods return new matrices and leave their
 * operands untouched; {@code set}, {@code setRow}, {@code setColumn} and the
 * {@code fill} methods modify the matrix in place. Operations based on a
 * factorisation delegate to {@code LUDecomposition}, {@code QRDecomposition},
 * {@code CholeskyDecomposition}, {@code SingularValueDecomposition} and
 * {@code EigenvalueDecomposition}, which are defined outside this class.
 */
public class Matrix {
    /** Absolute tolerance used by the structural predicates and {@link #equals(Object)}. */
    private static final double EPSILON = 1e-10;

    private final int rows;
    private final int cols;
    private final double[][] data;

    // ========== Constructors ==========

    /** Creates a {@code rows x cols} matrix whose elements are all zero. */
    public Matrix(int rows, int cols) {
        this.rows = rows;
        this.cols = cols;
        this.data = new double[rows][cols];
    }

    /**
     * Creates a matrix holding a deep copy of a rectangular array.
     *
     * @param data the source rows; the column count is taken from the first row
     * @throws IllegalArgumentException if {@code data} is {@code null}, has no
     *         rows, contains a {@code null} row, or its rows differ in length
     */
    public Matrix(double[][] data) {
        if (data == null || data.length == 0 || data[0] == null) {
            throw new IllegalArgumentException("矩阵数据不能为空");
        }
        this.rows = data.length;
        this.cols = data[0].length;
        this.data = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            // A jagged input would otherwise be truncated or overrun silently.
            if (data[i] == null || data[i].length != cols) {
                throw new IllegalArgumentException("矩阵各行长度必须一致");
            }
            System.arraycopy(data[i], 0, this.data[i], 0, cols);
        }
    }

    /** Creates a {@code rows x cols} matrix with every element set to {@code value}. */
    public Matrix(int rows, int cols, double value) {
        this(rows, cols);
        // Filled directly rather than through fill(double), so that no
        // overridable method runs during construction.
        for (double[] row : data) {
            Arrays.fill(row, value);
        }
    }

    /** Creates a deep copy of {@code other}. */
    public Matrix(Matrix other) {
        this(other.data);
    }

    // ========== Factory methods ==========

    /** Returns a {@code rows x cols} matrix of zeros. */
    public static Matrix zeros(int rows, int cols) {
        return new Matrix(rows, cols);
    }

    /** Returns a {@code rows x cols} matrix of ones. */
    public static Matrix ones(int rows, int cols) {
        return new Matrix(rows, cols, 1.0);
    }

    /** Returns the {@code n x n} identity matrix. */
    public static Matrix identity(int n) {
        Matrix result = new Matrix(n, n);
        for (int i = 0; i < n; i++) {
            result.data[i][i] = 1.0;
        }
        return result;
    }

    /** Returns a matrix whose elements are {@code min + (max - min) * u}, u drawn from [0, 1). */
    public static Matrix random(int rows, int cols, double min, double max) {
        Matrix result = new Matrix(rows, cols);
        Random rand = new Random();
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                result.data[i][j] = min + (max - min) * rand.nextDouble();
            }
        }
        return result;
    }

    /** Returns a matrix whose elements are {@code mean + std * g}, g drawn from the standard normal. */
    public static Matrix randomGaussian(int rows, int cols, double mean, double std) {
        Matrix result = new Matrix(rows, cols);
        Random rand = new Random();
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                result.data[i][j] = mean + std * rand.nextGaussian();
            }
        }
        return result;
    }

    /** Returns a square matrix with {@code diagonal} on its main diagonal and zeros elsewhere. */
    public static Matrix diagonal(double[] diagonal) {
        int n = diagonal.length;
        Matrix result = new Matrix(n, n);
        for (int i = 0; i < n; i++) {
            result.data[i][i] = diagonal[i];
        }
        return result;
    }

    /**
     * Wraps a one-dimensional array as a vector.
     *
     * @param array    the values
     * @param asColumn {@code true} for an {@code n x 1} column, {@code false} for a {@code 1 x n} row
     */
    public static Matrix fromArray(double[] array, boolean asColumn) {
        int n = array.length;
        Matrix result = asColumn ? new Matrix(n, 1) : new Matrix(1, n);
        for (int k = 0; k < n; k++) {
            if (asColumn) {
                result.data[k][0] = array[k];
            } else {
                result.data[0][k] = array[k];
            }
        }
        return result;
    }

    // ========== Basic access ==========

    public int getRows() {
        return rows;
    }

    public int getCols() {
        return cols;
    }

    /**
     * Returns the backing array itself, not a copy: changes made through it are
     * visible in this matrix. Use {@link #toArray()} for an independent copy.
     */
    public double[][] getData() {
        return data;
    }

    /** Returns element (i, j); throws {@link IndexOutOfBoundsException} when out of range. */
    public double get(int i, int j) {
        checkBounds(i, j);
        return data[i][j];
    }

    /** Sets element (i, j); throws {@link IndexOutOfBoundsException} when out of range. */
    public void set(int i, int j, double value) {
        checkBounds(i, j);
        data[i][j] = value;
    }

    /** Returns a copy of row {@code i}. */
    public double[] getRow(int i) {
        checkRow(i);
        return Arrays.copyOf(data[i], cols);
    }

    /** Returns a copy of column {@code j}. */
    public double[] getColumn(int j) {
        checkColumn(j);
        double[] column = new double[rows];
        for (int i = 0; i < rows; i++) {
            column[i] = data[i][j];
        }
        return column;
    }

    /** Overwrites row {@code i}; {@code rowData} must have exactly {@code cols} entries. */
    public void setRow(int i, double[] rowData) {
        checkRow(i);
        if (rowData.length != cols) {
            throw new IllegalArgumentException("行数据长度不匹配");
        }
        System.arraycopy(rowData, 0, data[i], 0, cols);
    }

    /** Overwrites column {@code j}; {@code colData} must have exactly {@code rows} entries. */
    public void setColumn(int j, double[] colData) {
        checkColumn(j);
        if (colData.length != rows) {
            throw new IllegalArgumentException("列数据长度不匹配");
        }
        for (int i = 0; i < rows; i++) {
            data[i][j] = colData[i];
        }
    }

    /**
     * Returns a matrix with the same elements, read row by row, laid out in a
     * new shape.
     *
     * @throws IllegalArgumentException if the element counts differ
     */
    public Matrix reshape(int newRows, int newCols) {
        if (rows * cols != newRows * newCols) {
            throw new IllegalArgumentException("元素数量不匹配");
        }
        Matrix result = new Matrix(newRows, newCols);
        double[] flat = toFlatArray();
        for (int k = 0; k < flat.length; k++) {
            result.data[k / newCols][k % newCols] = flat[k];
        }
        return result;
    }

    /** Returns all elements, row by row, as a single column. */
    public Matrix flatten() {
        return reshape(rows * cols, 1);
    }

    /** Sets every element to {@code value}. */
    public void fill(double value) {
        for (double[] row : data) {
            Arrays.fill(row, value);
        }
    }

    /** Replaces every element with {@code generator} applied to its current value. */
    public void fill(DoubleUnaryOperator generator) {
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                data[i][j] = generator.applyAsDouble(data[i][j]);
            }
        }
    }

    /** Returns a deep copy of this matrix. */
    public Matrix copy() {
        return new Matrix(this);
    }

    /** Returns an independent two-dimensional copy of the elements. */
    public double[][] toArray() {
        double[][] copy = new double[rows][];
        for (int i = 0; i < rows; i++) {
            copy[i] = Arrays.copyOf(data[i], cols);
        }
        return copy;
    }

    /** Returns the elements, row by row, in a one-dimensional array. */
    public double[] toFlatArray() {
        double[] flat = new double[rows * cols];
        for (int i = 0; i < rows; i++) {
            System.arraycopy(data[i], 0, flat, i * cols, cols);
        }
        return flat;
    }

    // ========== Arithmetic ==========

    /** Element-wise sum; both matrices must have the same dimensions. */
    public Matrix add(Matrix other) {
        return combine(other, (a, b) -> a + b);
    }

    /** Adds {@code scalar} to every element. */
    public Matrix add(double scalar) {
        return apply(x -> x + scalar);
    }

    /** Element-wise difference; both matrices must have the same dimensions. */
    public Matrix subtract(Matrix other) {
        return combine(other, (a, b) -> a - b);
    }

    /** Subtracts {@code scalar} from every element. */
    public Matrix subtract(double scalar) {
        return add(-scalar);
    }

    /**
     * Matrix product {@code this * other}.
     *
     * @throws IllegalArgumentException if {@code this.cols != other.rows}
     */
    public Matrix multiply(Matrix other) {
        if (cols != other.rows) {
            throw new IllegalArgumentException(
                    String.format("矩阵维度不匹配: %dx%d * %dx%d", rows, cols, other.rows, other.cols));
        }
        Matrix result = new Matrix(rows, other.cols);
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < other.cols; j++) {
                double sum = 0.0;
                for (int k = 0; k < cols; k++) {
                    sum += data[i][k] * other.data[k][j];
                }
                result.data[i][j] = sum;
            }
        }
        return result;
    }

    /** Multiplies every element by {@code scalar}. */
    public Matrix multiply(double scalar) {
        return apply(x -> x * scalar);
    }

    /** Element-wise (Hadamard) product; both matrices must have the same dimensions. */
    public Matrix elementwiseMultiply(Matrix other) {
        return combine(other, (a, b) -> a * b);
    }

    /**
     * Element-wise quotient; both matrices must have the same dimensions.
     *
     * @throws ArithmeticException if any element of {@code other} equals zero
     */
    public Matrix elementwiseDivide(Matrix other) {
        return combine(other, (numerator, denominator) -> {
            if (denominator == 0) {
                throw new ArithmeticException("除零错误");
            }
            return numerator / denominator;
        });
    }

    /** Returns the transpose. */
    public Matrix transpose() {
        Matrix result = new Matrix(cols, rows);
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                result.data[j][i] = data[i][j];
            }
        }
        return result;
    }

    /** Returns the sum of all elements. */
    public double sum() {
        double total = 0.0;
        for (double[] row : data) {
            for (double value : row) {
                total += value;
            }
        }
        return total;
    }

    /** Returns the arithmetic mean of all elements. */
    public double mean() {
        return sum() / (rows * cols);
    }

    /** Returns the sum of the main diagonal; the matrix must be square. */
    public double trace() {
        requireSquare();
        double trace = 0.0;
        for (int i = 0; i < rows; i++) {
            trace += data[i][i];
        }
        return trace;
    }

    /**
     * Returns the determinant. Sizes 1 to 3 use closed-form expansions; larger
     * matrices delegate to {@code LUDecomposition}.
     *
     * @throws IllegalArgumentException if the matrix is not square
     */
    public double determinant() {
        requireSquare();
        switch (rows) {
            case 1:
                return data[0][0];
            case 2:
                return data[0][0] * data[1][1] - data[0][1] * data[1][0];
            case 3:
                return data[0][0] * (data[1][1] * data[2][2] - data[1][2] * data[2][1])
                        - data[0][1] * (data[1][0] * data[2][2] - data[1][2] * data[2][0])
                        + data[0][2] * (data[1][0] * data[2][1] - data[1][1] * data[2][0]);
            default:
                return new LUDecomposition(this).determinant();
        }
    }

    /**
     * Returns the inverse. Sizes 1 and 2 use closed forms; larger matrices
     * delegate to {@code LUDecomposition}.
     *
     * @throws IllegalArgumentException if the matrix is not square
     * @throws ArithmeticException if a 1x1 or 2x2 matrix has determinant exactly zero
     */
    public Matrix inverse() {
        requireSquare();
        if (rows == 1) {
            if (data[0][0] == 0) {
                throw new ArithmeticException("矩阵不可逆");
            }
            return new Matrix(1, 1, 1.0 / data[0][0]);
        }
        if (rows == 2) {
            double det = determinant();
            if (det == 0) {
                throw new ArithmeticException("矩阵不可逆");
            }
            Matrix result = new Matrix(2, 2);
            result.data[0][0] = data[1][1] / det;
            result.data[0][1] = -data[0][1] / det;
            result.data[1][0] = -data[1][0] / det;
            result.data[1][1] = data[0][0] / det;
            return result;
        }
        return new LUDecomposition(this).inverse();
    }

    /**
     * Raises a square matrix to an integer power by repeated squaring. Negative
     * exponents are applied to the inverse.
     *
     * @throws IllegalArgumentException if the matrix is not square
     */
    public Matrix power(int n) {
        requireSquare();
        if (n < 0) {
            Matrix inv = inverse();
            // -Integer.MIN_VALUE is not representable as an int, so that one
            // exponent is split into MAX_VALUE + 1 instead of being negated.
            if (n == Integer.MIN_VALUE) {
                return inv.power(Integer.MAX_VALUE).multiply(inv);
            }
            return inv.power(-n);
        }
        if (n == 0) {
            return identity(rows);
        }
        if (n == 1) {
            return copy();
        }
        Matrix result = identity(rows);
        Matrix base = this;
        int exp = n;
        while (exp > 0) {
            if ((exp & 1) == 1) {
                result = result.multiply(base);
            }
            exp >>= 1;
            // The square is only needed if another bit remains to be processed.
            if (exp > 0) {
                base = base.multiply(base);
            }
        }
        return result;
    }

    /**
     * Returns a matrix norm.
     *
     * @param type {@code "fro"}/{@code "frobenius"}, {@code "1"} (maximum
     *             absolute column sum), {@code "inf"} (maximum absolute row sum)
     *             or {@code "2"} (first singular value reported by
     *             {@code SingularValueDecomposition}); case-insensitive
     * @throws IllegalArgumentException for any other {@code type}
     */
    public double norm(String type) {
        // Locale.ROOT keeps the lookup independent of the default locale
        // (for example "INF" under a Turkish locale).
        switch (type.toLowerCase(java.util.Locale.ROOT)) {
            case "frobenius":
            case "fro":
                return Math.sqrt(sumOfSquares());
            case "1":
                return maxAbsColumnSum();
            case "inf":
                return maxAbsRowSum();
            case "2":
                return new SingularValueDecomposition(this).getSingularValues()[0];
            default:
                throw new IllegalArgumentException("不支持的类型: " + type);
        }
    }

    public double frobeniusNorm() {
        return norm("fro");
    }

    public double norm1() {
        return norm("1");
    }

    public double normInf() {
        return norm("inf");
    }

    public double norm2() {
        return norm("2");
    }

    /** Returns {@code norm(type)} of this matrix times {@code norm(type)} of its inverse. */
    public double conditionNumber(String type) {
        requireSquare();
        Matrix inverse = inverse();
        return norm(type) * inverse.norm(type);
    }

    /**
     * Solves {@code this * X = B}. Square systems delegate to
     * {@code LUDecomposition}, non-square ones to {@code QRDecomposition}.
     *
     * @throws IllegalArgumentException if the row counts differ
     */
    public Matrix solve(Matrix B) {
        if (rows != B.rows) {
            throw new IllegalArgumentException("方程维度不匹配");
        }
        return isSquare()
                ? new LUDecomposition(this).solve(B)
                : new QRDecomposition(this).solve(B);
    }

    /** Solves {@code this * x = b} for a single right-hand-side vector. */
    public double[] solve(double[] b) {
        return solve(Matrix.fromArray(b, true)).getColumn(0);
    }

    // ========== Decompositions ==========

    public LUDecomposition lu() {
        return new LUDecomposition(this);
    }

    public QRDecomposition qr() {
        return new QRDecomposition(this);
    }

    public CholeskyDecomposition cholesky() {
        return new CholeskyDecomposition(this);
    }

    public SingularValueDecomposition svd() {
        return new SingularValueDecomposition(this);
    }

    public EigenvalueDecomposition eig() {
        return new EigenvalueDecomposition(this);
    }

    // ========== Predicates ==========

    public boolean isSquare() {
        return rows == cols;
    }

    /** True when the matrix is square and mirrored elements differ by at most 1e-10. */
    public boolean isSymmetric() {
        if (!isSquare()) {
            return false;
        }
        for (int i = 0; i < rows; i++) {
            for (int j = i + 1; j < cols; j++) {
                if (Math.abs(data[i][j] - data[j][i]) > EPSILON) {
                    return false;
                }
            }
        }
        return true;
    }

    /** True when the matrix is square and every off-diagonal magnitude is at most 1e-10. */
    public boolean isDiagonal() {
        return isUpperTriangular() && isLowerTriangular();
    }

    /** True when the matrix is square and every magnitude below the diagonal is at most 1e-10. */
    public boolean isUpperTriangular() {
        if (!isSquare()) {
            return false;
        }
        for (int i = 1; i < rows; i++) {
            for (int j = 0; j < i; j++) {
                if (Math.abs(data[i][j]) > EPSILON) {
                    return false;
                }
            }
        }
        return true;
    }

    /** True when the matrix is square and every magnitude above the diagonal is at most 1e-10. */
    public boolean isLowerTriangular() {
        if (!isSquare()) {
            return false;
        }
        for (int i = 0; i < rows; i++) {
            for (int j = i + 1; j < cols; j++) {
                if (Math.abs(data[i][j]) > EPSILON) {
                    return false;
                }
            }
        }
        return true;
    }

    /** True when the matrix is square and its transpose times itself equals the identity within 1e-10. */
    public boolean isOrthogonal() {
        if (!isSquare()) {
            return false;
        }
        return transpose().multiply(this).equals(identity(rows), EPSILON);
    }

    /** True when both matrices have the same shape and no pair of elements differs by more than {@code tolerance}. */
    public boolean equals(Matrix other, double tolerance) {
        if (rows != other.rows || cols != other.cols) {
            return false;
        }
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                if (Math.abs(data[i][j] - other.data[i][j]) > tolerance) {
                    return false;
                }
            }
        }
        return true;
    }

    /** Compares with another {@code Matrix} using an absolute tolerance of 1e-10. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof Matrix)) {
            return false;
        }
        return equals((Matrix) obj, EPSILON);
    }

    /**
     * Hash derived from the dimensions only. Because {@link #equals(Object)}
     * tolerates small element differences, element values cannot take part in
     * the hash without breaking "equal objects have equal hash codes".
     */
    @Override
    public int hashCode() {
        return 31 * rows + cols;
    }

    // ========== Element-wise transforms and reductions ==========

    /** Returns a new matrix with {@code function} applied to every element. */
    public Matrix apply(DoubleUnaryOperator function) {
        Matrix result = new Matrix(rows, cols);
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                result.data[i][j] = function.applyAsDouble(data[i][j]);
            }
        }
        return result;
    }

    /** Element-wise {@code Math.max(x, value)}. */
    public Matrix max(double value) {
        return apply(x -> Math.max(x, value));
    }

    /** Element-wise {@code Math.min(x, value)}. */
    public Matrix min(double value) {
        return apply(x -> Math.min(x, value));
    }

    public Matrix abs() {
        return apply(Math::abs);
    }

    public Matrix exp() {
        return apply(Math::exp);
    }

    /** Element-wise {@code Math.log(x + 1e-10)}; the offset keeps log(0) finite. */
    public Matrix log() {
        return apply(x -> Math.log(x + 1e-10));
    }

    public Matrix sqrt() {
        return apply(Math::sqrt);
    }

    /** Element-wise logistic function {@code 1 / (1 + e^-x)}. */
    public Matrix sigmoid() {
        return apply(x -> 1.0 / (1.0 + Math.exp(-x)));
    }

    /** Element-wise {@code Math.max(0, x)}. */
    public Matrix relu() {
        return apply(x -> Math.max(0, x));
    }

    /**
     * Softmax over all elements of the matrix taken together. The largest
     * element is subtracted before exponentiating.
     */
    public Matrix softmax() {
        Matrix exponentials = subtract(max()).exp();
        double total = exponentials.sum();
        return exponentials.multiply(1.0 / total);
    }

    /** Returns the largest element. */
    public double max() {
        int[] position = argmax();
        return data[position[0]][position[1]];
    }

    /** Returns the smallest element. */
    public double min() {
        int[] position = argmin();
        return data[position[0]][position[1]];
    }

    /** Returns the largest absolute value of any element. */
    public double maxAbs() {
        double largest = Math.abs(data[0][0]);
        for (double[] row : data) {
            for (double value : row) {
                double magnitude = Math.abs(value);
                if (magnitude > largest) {
                    largest = magnitude;
                }
            }
        }
        return largest;
    }

    /** Returns {@code {row, column}} of the first occurrence of the largest element. */
    public int[] argmax() {
        double best = data[0][0];
        int bestRow = 0;
        int bestCol = 0;
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                if (data[i][j] > best) {
                    best = data[i][j];
                    bestRow = i;
                    bestCol = j;
                }
            }
        }
        return new int[] {bestRow, bestCol};
    }

    /** Returns {@code {row, column}} of the first occurrence of the smallest element. */
    public int[] argmin() {
        double best = data[0][0];
        int bestRow = 0;
        int bestCol = 0;
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                if (data[i][j] < best) {
                    best = data[i][j];
                    bestRow = i;
                    bestCol = j;
                }
            }
        }
        return new int[] {bestRow, bestCol};
    }

    /**
     * Copies the sub-matrix covering rows {@code [startRow, endRow)} and
     * columns {@code [startCol, endCol)}.
     *
     * @throws IllegalArgumentException if either range is empty or out of bounds
     */
    public Matrix slice(int startRow, int endRow, int startCol, int endCol) {
        if (startRow < 0 || endRow > rows || startRow >= endRow ||
                startCol < 0 || endCol > cols || startCol >= endCol) {
            throw new IllegalArgumentException("切片范围无效");
        }
        int newRows = endRow - startRow;
        int newCols = endCol - startCol;
        Matrix result = new Matrix(newRows, newCols);
        for (int i = 0; i < newRows; i++) {
            System.arraycopy(data[startRow + i], startCol, result.data[i], 0, newCols);
        }
        return result;
    }

    /**
     * Concatenates {@code other} below this matrix when {@code axis} is
     * {@code "vertical"} or {@code "v"} (case-insensitive); any other value
     * places it to the right.
     *
     * @throws IllegalArgumentException if the shared dimension differs
     */
    public Matrix concat(Matrix other, String axis) {
        boolean vertical = "vertical".equalsIgnoreCase(axis) || "v".equalsIgnoreCase(axis);
        if (vertical) {
            if (cols != other.cols) {
                throw new IllegalArgumentException("列数必须相同");
            }
            Matrix result = new Matrix(rows + other.rows, cols);
            for (int i = 0; i < rows; i++) {
                System.arraycopy(data[i], 0, result.data[i], 0, cols);
            }
            for (int i = 0; i < other.rows; i++) {
                System.arraycopy(other.data[i], 0, result.data[rows + i], 0, cols);
            }
            return result;
        }
        if (rows != other.rows) {
            throw new IllegalArgumentException("行数必须相同");
        }
        Matrix result = new Matrix(rows, cols + other.cols);
        for (int i = 0; i < rows; i++) {
            System.arraycopy(data[i], 0, result.data[i], 0, cols);
            System.arraycopy(other.data[i], 0, result.data[i], cols, other.cols);
        }
        return result;
    }

    public Matrix vstack(Matrix other) {
        return concat(other, "vertical");
    }

    public Matrix hstack(Matrix other) {
        return concat(other, "horizontal");
    }

    // ========== Static operations ==========

    public static Matrix add(Matrix A, Matrix B) {
        return A.add(B);
    }

    public static Matrix subtract(Matrix A, Matrix B) {
        return A.subtract(B);
    }

    public static Matrix multiply(Matrix A, Matrix B) {
        return A.multiply(B);
    }

    /** Matrix product; alias of {@link #multiply(Matrix, Matrix)}. */
    public static Matrix dot(Matrix A, Matrix B) {
        return A.multiply(B);
    }

    /** Returns the Kronecker product of {@code A} and {@code B}. */
    public static Matrix kron(Matrix A, Matrix B) {
        int p = B.rows;
        int q = B.cols;
        Matrix result = new Matrix(A.rows * p, A.cols * q);
        for (int i = 0; i < A.rows; i++) {
            for (int j = 0; j < A.cols; j++) {
                double scale = A.data[i][j];
                for (int k = 0; k < p; k++) {
                    for (int l = 0; l < q; l++) {
                        result.data[i * p + k][j * q + l] = scale * B.data[k][l];
                    }
                }
            }
        }
        return result;
    }

    /**
     * Returns the cross product of two 3x1 column vectors.
     *
     * @throws IllegalArgumentException if either argument is not 3x1
     */
    public static Matrix cross(Matrix A, Matrix B) {
        if (A.rows != 3 || A.cols != 1 || B.rows != 3 || B.cols != 1) {
            throw new IllegalArgumentException("必须是3x1向量");
        }
        double ax = A.data[0][0];
        double ay = A.data[1][0];
        double az = A.data[2][0];
        double bx = B.data[0][0];
        double by = B.data[1][0];
        double bz = B.data[2][0];
        Matrix result = new Matrix(3, 1);
        result.data[0][0] = ay * bz - az * by;
        result.data[1][0] = az * bx - ax * bz;
        result.data[2][0] = ax * by - ay * bx;
        return result;
    }

    /**
     * Returns the scalar (inner) product of two vectors with the same number of
     * elements. It carries its own name because Java does not allow a second
     * {@code dot(Matrix, Matrix)} that differs only in return type.
     *
     * @throws IllegalArgumentException if either argument is not a vector or
     *         the element counts differ
     */
    public static double innerProduct(Matrix A, Matrix B) {
        if (!A.isVector() || !B.isVector() || A.rows * A.cols != B.rows * B.cols) {
            throw new IllegalArgumentException("必须是相同维度的向量");
        }
        double[] left = A.toFlatArray();
        double[] right = B.toFlatArray();
        double sum = 0.0;
        for (int i = 0; i < left.length; i++) {
            sum += left[i] * right[i];
        }
        return sum;
    }

    /** True when the matrix has exactly one row or exactly one column. */
    public boolean isVector() {
        return rows == 1 || cols == 1;
    }

    // ========== Private helpers ==========

    /** Applies {@code op} to each pair of corresponding elements of this matrix and {@code other}. */
    private Matrix combine(Matrix other, DoubleBinaryOperator op) {
        checkSameDimensions(other);
        Matrix result = new Matrix(rows, cols);
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                result.data[i][j] = op.applyAsDouble(data[i][j], other.data[i][j]);
            }
        }
        return result;
    }

    /** Sum of the squares of all elements. */
    private double sumOfSquares() {
        double total = 0.0;
        for (double[] row : data) {
            for (double value : row) {
                total += value * value;
            }
        }
        return total;
    }

    /** Largest sum of absolute values over the columns. */
    private double maxAbsColumnSum() {
        double largest = 0;
        for (int j = 0; j < cols; j++) {
            double columnSum = 0;
            for (int i = 0; i < rows; i++) {
                columnSum += Math.abs(data[i][j]);
            }
            largest = Math.max(largest, columnSum);
        }
        return largest;
    }

    /** Largest sum of absolute values over the rows. */
    private double maxAbsRowSum() {
        double largest = 0;
        for (int i = 0; i < rows; i++) {
            double rowSum = 0;
            for (int j = 0; j < cols; j++) {
                rowSum += Math.abs(data[i][j]);
            }
            largest = Math.max(largest, rowSum);
        }
        return largest;
    }

    private void requireSquare() {
        if (!isSquare()) {
            throw new IllegalArgumentException("矩阵必须是方阵");
        }
    }

    private void checkBounds(int i, int j) {
        if (i < 0 || i >= rows || j < 0 || j >= cols) {
            throw new IndexOutOfBoundsException(
                    String.format("索引 [%d,%d] 超出范围 [%d,%d]", i, j, rows, cols));
        }
    }

    private void checkRow(int i) {
        if (i < 0 || i >= rows) {
            throw new IndexOutOfBoundsException("行索引超出范围: " + i);
        }
    }

    private void checkColumn(int j) {
        if (j < 0 || j >= cols) {
            throw new IndexOutOfBoundsException("列索引超出范围: " + j);
        }
    }

    private void checkSameDimensions(Matrix other) {
        if (rows != other.rows || cols != other.cols) {
            throw new IllegalArgumentException(
                    String.format("矩阵维度不匹配: %dx%d vs %dx%d", rows, cols, other.rows, other.cols));
        }
    }

    // ========== Display ==========

    /** Formats the matrix with each element rendered as {@code %10.4f}. */
    @Override
    public String toString() {
        return toString("%10.4f");
    }

    /**
     * Formats the matrix as a header line followed by one bracketed line per row.
     *
     * @param format a {@link String#format} pattern applied to each element
     */
    public String toString(String format) {
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("Matrix(%dx%d):\n", rows, cols));
        for (int i = 0; i < rows; i++) {
            sb.append("[");
            for (int j = 0; j < cols; j++) {
                if (j > 0) {
                    sb.append(" ");
                }
                sb.append(String.format(format, data[i][j]));
            }
            sb.append("]\n");
        }
        return sb.toString();
    }

    public void print() {
        System.out.println(this);
    }

    public void print(String format) {
        System.out.println(toString(format));
    }
}