回归(多项式回归)

例子:回归(多项式回归)

训练数据:text.csv

复制代码
x,y
235,591
216,539
148,413
35,310
85,308
204,519
49,325
25,332
173,498
191,498
134,392
99,334
117,385
112,387
162,425
272,659
159,400
159,427
59,319
198,522
python 复制代码
import numpy as np
import matplotlib.pyplot as plt

#读入训练数据
train = np.loadtxt('text.csv',delimiter=',',skiprows=1)
train_x = train[:,0]
train_y = train[:,1]

#展示训练数据
#plt.plot(train_x,train_y,'o')
#plt.show()

#标准化数据
mu = train_x.mean()
sigma = train_x.std()
def standardize(x):
    return (x - mu)/sigma

train_z = standardize(train_x)
#plt.plot(train_z,train_y,'o')
#plt.show()

#均方误差
#在停止重复的条件里用上
def MSE(x,y):
    return (1/x.shape[0])*np.sum((y-f(x)) ** 2)

#生成三个随机数 代表三个参数 theta是参数列表
theta = np.random.rand(3)

#均方误差的历史记录
errors = []

#创建训练数据的矩阵
#因为训练数据很多 把它们都放在一个矩阵里
#直接和theta相乘
#theta0 + theta1*x1 + theta2*x2
def to_matrix(x):
    return np.vstack([np.ones(x.shape[0]),x,x**2]).T

X = to_matrix(train_z)

#预测函数
#theta0 + theta1*x1 + theta2*x2
#dot:矩阵乘法
def f(x):
    return np.dot(x,theta)

#目标函数 error误差 最小二乘法
def E(x,y):
    return 0.5*np.sum((y-f(x))**2)

#learning rate 学习率
ETA = 1e-3

#误差的差值
diff = 1;

#重复学习
errors.append(MSE(X,train_y))
error = E(X,train_y)
while diff>1e-2:
    #更新参数
    theta = theta - ETA*np.dot(f(X)-train_y,X)

    #计算差值
    errors.append(MSE(X,train_y))
    current_error = E(X,train_y)
    diff = errors[-2] - errors[-1]
    #不用均方误差的diff
    #diff = error - current_error
    error = current_error

'''
图表拟合展示
x = np.linspace(-3,3,100)
plt.plot(train_z,train_y,'o')
plt.plot(x,f(to_matrix(x)))
plt.show()
'''

#绘制误差变化图
x = np.arange(len(errors))
plt.plot(x,errors)
plt.show()
相关推荐
人工小情绪4 分钟前
Clawbot (OpenClaw)简介
人工智能
2501_9333295530 分钟前
品牌公关AI化实践:Infoseek舆情系统技术架构解析
人工智能·自然语言处理
咋吃都不胖lyh36 分钟前
CLIP 不是一个 “自主判断图像内容” 的图像分类模型,而是一个 “图文语义相似度匹配模型”—
人工智能·深度学习·机器学习
xiucai_cs39 分钟前
AI RAG 本地知识库实战
人工智能·知识库·dify·rag·ollama
zhangfeng113344 分钟前
大模型微调时 Firefly(流萤)和 LlamaFactory(LLaMA Factory)这两个工具/框架之间做出合适的选择
人工智能·llama
zhangyifang_0091 小时前
MCP——AI连接现实世界的“标准接口”
人工智能
LOnghas12111 小时前
电动汽车充电接口自动识别与定位_yolo13-C3k2-Converse_六种主流充电接口检测分类
人工智能·目标跟踪·分类
编码小哥1 小时前
OpenCV图像滤波技术详解:从均值滤波到双边滤波
人工智能·opencv·均值算法
阿杰学AI2 小时前
AI核心知识78——大语言模型之CLM(简洁且通俗易懂版)
人工智能·算法·ai·语言模型·rag·clm·语境化语言模型
新缸中之脑2 小时前
氛围编程一个全栈AI交易应用
人工智能