如何兼容不同版本的 scikit-learn(sklearn)库,统一获取“均方根误差(RMSE)”的计算函数

文章目录

  • [1. 代码解析](#1. 代码解析)
  • [2. 版本兼容背景](#2. 版本兼容背景)

1. 代码解析

下面这段代码通过 try-except 异常捕获机制,适配 sklearn 新旧版本中 RMSE 函数的不同写法:

  • 新版 sklearn 直接提供 root_mean_squared_error 函数(专门计算 RMSE);
  • 旧版 sklearn 只有 mean_squared_error 函数(默认计算 MSE),需通过 squared=False 参数转为 RMSE;

最终无论 sklearn 版本是新是旧,代码中 mean_squared_error 变量都指向计算 RMSE 的函数

python 复制代码
from functools import partial  # 必须导入,否则会报NameError

try:
    # 尝试从sklearn.metrics导入新版的RMSE函数,并将其重命名为mean_squared_error
    from sklearn.metrics import root_mean_squared_error as mean_squared_error
except ImportError:
    # 如果导入失败(说明是旧版sklearn),执行以下逻辑
    # 1. 导入旧版的MSE函数
    from sklearn.metrics import mean_squared_error
    # 2. 使用partial固定参数squared=False,将MSE函数转为RMSE函数
    mean_squared_error = partial(mean_squared_error, squared=False)

关键细节解释

  • root_mean_squared_error :sklearn 1.0 及以上版本新增的函数,直接返回均方根误差(RMSE),公式为:
    R M S E = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 RMSE = \sqrt{\frac{1}{n}\sum_{i=1}^n (y_i - \hat{y}_i)^2} RMSE=n1i=1∑n(yi−y^i)2
  • mean_squared_error(旧版) :sklearn 1.0 之前的版本只有这个函数,默认 squared=True,返回均方误差(MSE);当 squared=False 时,返回 RMSE。
  • partial :Python functools 模块中的函数(代码中省略了 from functools import partial,需确保已导入),作用是"固定函数的部分参数",这里把 mean_squared_errorsquared 参数固定为 False,相当于创建了一个"默认计算 RMSE 的新函数",并重新赋值给 mean_squared_error 变量。

2. 版本兼容背景

sklearn 版本 计算 RMSE 的方式
≥1.0 root_mean_squared_error(y_true, y_pred)
<1.0 mean_squared_error(y_true, y_pred, squared=False)

如果直接写死其中一种方式,会导致:

  • 用新版 sklearn 运行旧版写法:虽然能运行(sklearn 1.0+ 仍兼容 squared=False),但不够优雅;
  • 用旧版 sklearn 运行新版写法:会报 ImportError(找不到 root_mean_squared_error),代码直接崩溃。
相关推荐
小二·1 小时前
Python Web 开发进阶实战:可持续计算 —— 在 Flask + Vue 中构建碳感知应用(Carbon-Aware Computing)
前端·python·flask
Java程序员威哥1 小时前
【包教包会】SpringBoot依赖Jar指定位置打包:配置+原理+避坑全解析
java·开发语言·spring boot·后端·python·微服务·jar
Java程序员威哥1 小时前
Java微服务可观测性实战:Prometheus+Grafana+SkyWalking全链路监控落地
java·开发语言·python·docker·微服务·grafana·prometheus
UR的出不克2 小时前
基于PyTorch的MNIST手写数字识别系统 - 从零到实战
人工智能·python·数字识别
one____dream2 小时前
【算法】大整数数组连续进位
python·算法
one____dream2 小时前
【算法】合并两个有序链表
数据结构·python·算法·链表
程序员敲代码吗2 小时前
持续集成/持续部署(CI/CD) for Python
jvm·数据库·python
人工智能AI技术2 小时前
【Agent从入门到实践】16 接口与网络:API调用、HTTP请求,Agent与外部交互的基础
人工智能·python
余衫马2 小时前
Qt for Python:PySide6 入门指南(下篇)
c++·python·qt