【Python机器学习】回归——局部加权线性回归

线性回归有一个问题就是有可能出现过拟合现象,因为它求的是具有最小均方误差的无偏估计。如果模型欠拟合将不能取得最好的预测效果。所以有些方法允许在估计中引入一些偏差,从而降低预测的均方误差。

其中一个方法就是局部加权线性回归(LWLR),在该算法中,我们给待预测点附近的每个点赋予一定的权重,然后在这个子集上基于最小均方差来进行普通的回归,与kNN一样,这种算法每次预测均需要事先选取出对应的数据子集。

该算法解出回归系数w的形式如下:

其中w是一个矩阵,用来给每个数据点赋予权重。

LWLR使用"核"来对附近的点赋予更高的权重。核的类型可以自由选择,最常用的核是高斯核,高斯核对应的权重如下:

这样就构建了一个只包含对角元素的权重矩阵w,并且点x与x(i)越近,w(i,i)将会越大。上述工时包含一个需要用户指定的参数k,它决定了对附近的点赋予多大的权重,这也是使用LWLR时唯一需要考虑的参数。

下面是具体的代码实现:

python 复制代码
def lwlr(testPoint,xArr,yArr,k=1.0):
    xMat=mat(xArr)
    yMat=mat(yArr).T
    m=shape(xMat)[0]
    #创建对角矩阵
    weights=mat(eye((m)))
    for j in range(m):
        diffMat=testPoint-xMat[j,:]
        weights[j,j]=exp(diffMat*diffMat.T/(-2.0*k**2))
    xTx=xMat.T*(weights*xMat)
    if linalg.det(xTx)==0.0:
        print('行列式为0')
        return 
    ws=xTx.I*(xMat.T*(weights*yMat))
    return testPoint*ws

def lwlrTest(testArr,xArr,yArr,k=1.0):
    m=shape(testArr)[0]
    yHat=zeros(m)
    for i in range(m):
        yHat[i]=lwlr(testArr[i],xArr,yArr,k)
    return yHat

上述代码的作用是给定x空间中的任意一点,计算出对应的预测值yHat。函数lwlr()的开头读入数据并创建所需矩阵,之后创建对角权重矩阵weights。权重矩阵时一个方针,阶数等于样本点个数。也就是说,该矩阵为每个样本点初始化了一个权重,接着算法将遍历数据集,计算每个样本点对应的权重值:随着样本点与待预测点距离的递增,权重将以指数级衰减。输入参数k控制衰减的速度。在权重矩阵计算完毕后,就可以得到对回归系数ws的一个估计。

另一个函数是lwlrTest(),用于为数据集中每个点调用lwlr(),这有助于求解k的大小。

载入数据并对单点进行估计:

python 复制代码
xArr,yArr=loadDataSet('ex0.txt')
print(lwlr(xArr[0],xArr,yArr,1.0))
print(lwlr(xArr[0],xArr,yArr,0.001))

计算拟合曲线并绘图

python 复制代码
yHat=lwlrTest(xArr,xArr,yArr,0.003)
xMat=mat(xArr)
srtInd=xMat[:,1].argsort(0)
xSort=xMat[srtInd][:,0,:]

import matplotlib.pyplot as plt
fig=plt.figure()
ax=fig.add_subplot(111)
ax.plot(xSort[:,1],yHat[srtInd])
ax.scatter(xMat[:,1].flatten().A[0],mat(yArr).T.flatten().A[0],s=2,c='red')
plt.show()

可以观察上图的效果。

局部加权线性回归也存在一个问题,那就是增加了计算了,因为它对每个点做预测时都必须使用整个数据集。如果避免这些计算就可以减少程序运行时间,从而环节因计算了增加带来的问题。

相关推荐
Blossom.1182 小时前
把AI“绣”进丝绸:生成式刺绣神经网络让古装自带摄像头
人工智能·pytorch·python·深度学习·神经网络·机器学习·fpga开发
PyHaVolask3 小时前
数据结构与算法分析
数据结构·算法·图论
小王C语言3 小时前
封装红黑树实现mymap和myset
linux·服务器·算法
星星也在雾里3 小时前
【管理多版本Python环境】Anaconda安装及使用
python·anaconda
用户3721574261353 小时前
使用 Python 将 CSV 文件转换为 PDF 的实践指南
python
大佬,救命!!!3 小时前
算法实现迭代2_堆排序
数据结构·python·算法·学习笔记·堆排序
天桥下的卖艺者3 小时前
R语言手搓一个计算生存分析C指数(C-index)的函数算法
c语言·算法·r语言
max5006003 小时前
多GPU数据并行训练中GPU利用率不均衡问题深度分析与解决方案
人工智能·机器学习·分类·数据挖掘
Espresso Macchiato3 小时前
Leetcode 3715. Sum of Perfect Square Ancestors
算法·leetcode·职场和发展·leetcode hard·树的遍历·leetcode 3715·leetcode周赛471
总有刁民想爱朕ha4 小时前
Python自动化从入门到实战(24)如何高效的备份mysql数据库,数据备份datadir目录直接复制可行吗?一篇给小白的完全指南
数据库·python·自动化·mysql数据库备份