【Python机器学习】回归——局部加权线性回归

线性回归有一个问题就是有可能出现过拟合现象,因为它求的是具有最小均方误差的无偏估计。如果模型欠拟合将不能取得最好的预测效果。所以有些方法允许在估计中引入一些偏差,从而降低预测的均方误差。

其中一个方法就是局部加权线性回归(LWLR),在该算法中,我们给待预测点附近的每个点赋予一定的权重,然后在这个子集上基于最小均方差来进行普通的回归,与kNN一样,这种算法每次预测均需要事先选取出对应的数据子集。

该算法解出回归系数w的形式如下:

其中w是一个矩阵,用来给每个数据点赋予权重。

LWLR使用"核"来对附近的点赋予更高的权重。核的类型可以自由选择,最常用的核是高斯核,高斯核对应的权重如下:

这样就构建了一个只包含对角元素的权重矩阵w,并且点x与x(i)越近,w(i,i)将会越大。上述工时包含一个需要用户指定的参数k,它决定了对附近的点赋予多大的权重,这也是使用LWLR时唯一需要考虑的参数。

下面是具体的代码实现:

python 复制代码
def lwlr(testPoint,xArr,yArr,k=1.0):
    xMat=mat(xArr)
    yMat=mat(yArr).T
    m=shape(xMat)[0]
    #创建对角矩阵
    weights=mat(eye((m)))
    for j in range(m):
        diffMat=testPoint-xMat[j,:]
        weights[j,j]=exp(diffMat*diffMat.T/(-2.0*k**2))
    xTx=xMat.T*(weights*xMat)
    if linalg.det(xTx)==0.0:
        print('行列式为0')
        return 
    ws=xTx.I*(xMat.T*(weights*yMat))
    return testPoint*ws

def lwlrTest(testArr,xArr,yArr,k=1.0):
    m=shape(testArr)[0]
    yHat=zeros(m)
    for i in range(m):
        yHat[i]=lwlr(testArr[i],xArr,yArr,k)
    return yHat

上述代码的作用是给定x空间中的任意一点,计算出对应的预测值yHat。函数lwlr()的开头读入数据并创建所需矩阵,之后创建对角权重矩阵weights。权重矩阵时一个方针,阶数等于样本点个数。也就是说,该矩阵为每个样本点初始化了一个权重,接着算法将遍历数据集,计算每个样本点对应的权重值:随着样本点与待预测点距离的递增,权重将以指数级衰减。输入参数k控制衰减的速度。在权重矩阵计算完毕后,就可以得到对回归系数ws的一个估计。

另一个函数是lwlrTest(),用于为数据集中每个点调用lwlr(),这有助于求解k的大小。

载入数据并对单点进行估计:

python 复制代码
xArr,yArr=loadDataSet('ex0.txt')
print(lwlr(xArr[0],xArr,yArr,1.0))
print(lwlr(xArr[0],xArr,yArr,0.001))

计算拟合曲线并绘图

python 复制代码
yHat=lwlrTest(xArr,xArr,yArr,0.003)
xMat=mat(xArr)
srtInd=xMat[:,1].argsort(0)
xSort=xMat[srtInd][:,0,:]

import matplotlib.pyplot as plt
fig=plt.figure()
ax=fig.add_subplot(111)
ax.plot(xSort[:,1],yHat[srtInd])
ax.scatter(xMat[:,1].flatten().A[0],mat(yArr).T.flatten().A[0],s=2,c='red')
plt.show()

可以观察上图的效果。

局部加权线性回归也存在一个问题,那就是增加了计算了,因为它对每个点做预测时都必须使用整个数据集。如果避免这些计算就可以减少程序运行时间,从而环节因计算了增加带来的问题。

相关推荐
机器学习之心3 分钟前
基于GSWOA-SVM三种策略改进鲸鱼算法优化支持向量机的数据多变量时间序列预测,Matlab代码
算法·支持向量机·matlab·优化支持向量机·gswoa-svm·三种策略改进鲸鱼算法
旖-旎8 分钟前
前缀和(和为K的子数组)(5)
c++·算法·leetcode·前缀和·哈希算法·散列表
进击的雷神10 分钟前
多展会框架复用、Next.js结构统一、北非网络优化、参数差异化配置——阿尔及利亚展爬虫四大技术难关攻克纪实
javascript·网络·爬虫·python
ZTLJQ11 分钟前
网络通信的基石:Python HTTP请求库完全解析
开发语言·python·http
进击的荆棘12 分钟前
优选算法——链表
数据结构·算法·链表·stl
凌波粒15 分钟前
LeetCode--203.移除链表元素(链表)
java·算法·leetcode·链表
不染尘.17 分钟前
背包问题BP
开发语言·c++·算法
华科大胡子17 分钟前
爬虫对抗:ZLibrary反爬机制实战分析
python
进击的小头19 分钟前
第17篇:卡尔曼滤波器之概率论初步
python·算法·概率论
是梦终空19 分钟前
计算机毕业设计269—基于python+深度学习+YOLOV8的交通标志识别系统(源代码+数据库+报告)
python·深度学习·opencv·毕业设计·torch·课程设计·pyqt5