文章目录
- [1. 代码解析](#1. 代码解析)
- [2. 版本兼容背景](#2. 版本兼容背景)
1. 代码解析
下面这段代码通过 try-except 异常捕获机制,适配 sklearn 新旧版本中 RMSE 函数的不同写法:
- 新版 sklearn 直接提供
root_mean_squared_error函数(专门计算 RMSE); - 旧版 sklearn 只有
mean_squared_error函数(默认计算 MSE),需通过squared=False参数转为 RMSE;
最终无论 sklearn 版本是新是旧,代码中 mean_squared_error 变量都指向计算 RMSE 的函数。
python
from functools import partial # 必须导入,否则会报NameError
try:
# 尝试从sklearn.metrics导入新版的RMSE函数,并将其重命名为mean_squared_error
from sklearn.metrics import root_mean_squared_error as mean_squared_error
except ImportError:
# 如果导入失败(说明是旧版sklearn),执行以下逻辑
# 1. 导入旧版的MSE函数
from sklearn.metrics import mean_squared_error
# 2. 使用partial固定参数squared=False,将MSE函数转为RMSE函数
mean_squared_error = partial(mean_squared_error, squared=False)
关键细节解释
root_mean_squared_error:sklearn 1.0 及以上版本新增的函数,直接返回均方根误差(RMSE),公式为:
R M S E = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 RMSE = \sqrt{\frac{1}{n}\sum_{i=1}^n (y_i - \hat{y}_i)^2} RMSE=n1i=1∑n(yi−y^i)2mean_squared_error(旧版) :sklearn 1.0 之前的版本只有这个函数,默认squared=True,返回均方误差(MSE);当squared=False时,返回 RMSE。partial:Pythonfunctools模块中的函数(代码中省略了from functools import partial,需确保已导入),作用是"固定函数的部分参数",这里把mean_squared_error的squared参数固定为False,相当于创建了一个"默认计算 RMSE 的新函数",并重新赋值给mean_squared_error变量。
2. 版本兼容背景
| sklearn 版本 | 计算 RMSE 的方式 |
|---|---|
| ≥1.0 | root_mean_squared_error(y_true, y_pred) |
| <1.0 | mean_squared_error(y_true, y_pred, squared=False) |
如果直接写死其中一种方式,会导致:
- 用新版 sklearn 运行旧版写法:虽然能运行(sklearn 1.0+ 仍兼容
squared=False),但不够优雅; - 用旧版 sklearn 运行新版写法:会报
ImportError(找不到root_mean_squared_error),代码直接崩溃。