机器学习线性回归算法是入门机器学习理解人工智能模型很好示例

线性回归是通过一个或多个自变量与因变量 之间进行建模的回归分析,其特点为一个或多个称为回归系数的模型参数的线性组合。如下图所示,样本点为历史数据,回归曲线要能最贴切的模拟样本点的趋势,将误差降到最小

用python是实现这个模型

python 复制代码
import numpy as np
import matplotlib.pyplot as plt

class LinearRegression:
    def __init__(self, learning_rate=0.01, epochs=1000):
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.weights = None
        self.bias = None
    
    def fit(self, X, y):
        """训练线性回归模型"""
        # 获取样本数和特征数
        n_samples, n_features = X.shape
        
        # 确保y是一维数组
        y = y.ravel()
        
        # 初始化权重和偏置
        self.weights = np.zeros(n_features)
        self.bias = 0
        
        # 训练过程
        for epoch in range(self.epochs):
            # 前向传播
            y_pred = np.dot(X, self.weights) + self.bias
            
            # 计算损失
            loss = (1 / (2 * n_samples)) * np.sum((y_pred - y) ** 2)
            
            # 计算梯度
            dw = (1 / n_samples) * np.dot(X.T, (y_pred - y))
            db = (1 / n_samples) * np.sum(y_pred - y)
            
            # 确保dw是一维数组
            dw = dw.ravel()
            
            # 更新权重和偏置
            self.weights -= self.learning_rate * dw
            self.bias -= self.learning_rate * db
            
            # 每100个epoch打印一次损失
            if (epoch + 1) % 100 == 0:
                print(f"Epoch {epoch+1}/{self.epochs}, Loss: {loss:.4f}")
    
    def predict(self, X):
        """预测"""
        return np.dot(X, self.weights) + self.bias
    
    def score(self, X, y):
        """计算R²评分"""
        y_pred = self.predict(X)
        # 确保y是一维数组
        y = y.ravel()
        y_pred = y_pred.ravel()
        # 计算总平方和
        ss_total = np.sum((y - np.mean(y)) ** 2)
        # 计算残差平方和
        ss_residual = np.sum((y - y_pred) ** 2)
        # 计算R²
        r2 = 1 - (ss_residual / ss_total)
        return r2

def generate_data(n_samples=100, noise=0.1):
    """生成线性回归数据"""
    # 生成特征
    X = np.random.rand(n_samples, 1) * 10
    # 生成真实标签
    y = 2 * X + 3 + np.random.randn(n_samples, 1) * noise
    return X, y

def main():
    print("=== 线性回归算法实现 ===")
    
    # 生成数据
    X, y = generate_data(n_samples=100, noise=0.5)
    print(f"生成数据形状: X={X.shape}, y={y.shape}")
    
    # 划分训练集和测试集
    split_idx = int(0.8 * len(X))
    X_train, X_test = X[:split_idx], X[split_idx:]
    y_train, y_test = y[:split_idx], y[split_idx:]
    print(f"训练集大小: {X_train.shape[0]}, 测试集大小: {X_test.shape[0]}")
    
    # 初始化模型
    model = LinearRegression(learning_rate=0.01, epochs=1000)
    
    # 训练模型
    print("\n开始训练...")
    model.fit(X_train, y_train)
    
    # 测试模型
    y_pred = model.predict(X_test)
    r2 = model.score(X_test, y_test)
    print(f"\n测试集R²评分: {r2:.4f}")
    
    # 打印模型参数
    print(f"\n模型参数:")
    print(f"权重: {model.weights[0]:.4f}")
    print(f"偏置: {model.bias:.4f}")
    
    # 可视化结果
    plt.figure(figsize=(10, 6))
    # 绘制训练数据
    plt.scatter(X_train, y_train, color='blue', label='训练数据')
    # 绘制测试数据
    plt.scatter(X_test, y_test, color='green', label='测试数据')
    # 绘制预测线
    x_range = np.linspace(0, 10, 100).reshape(-1, 1)
    y_range_pred = model.predict(x_range)
    plt.plot(x_range, y_range_pred, color='red', label='预测线')
    plt.xlabel('X')
    plt.ylabel('y')
    plt.title('线性回归模型')
    plt.legend()
    plt.grid(True)
    plt.savefig('linear_regression_result.png')
    print("\n结果已保存为 linear_regression_result.png")
    
    # 打印预测示例
    print("\n预测示例:")
    for i in range(5):
        x_sample = X_test[i]
        y_true = y_test[i][0]
        y_pred_sample = model.predict(x_sample.reshape(1, -1))[0]
        print(f"X={x_sample[0]:.2f}, 真实值={y_true:.2f}, 预测值={y_pred_sample:.2f}")
    
    print("\n=== 线性回归实现完成 ===")

if __name__ == "__main__":
    main()
相关推荐
碳基硅坊7 小时前
电商场景下的商品自动识别与辅助上架
人工智能
熊猫钓鱼>_>8 小时前
强化学习与决策优化:从理论到工程落地的完整指南
人工智能·llm·强化学习·rl·马尔可夫·mdp·决策过程
-柚子皮-8 小时前
强化学习DPO算法
人工智能
tzc_fly8 小时前
AnisoAlign:各向异性模态对齐
人工智能·深度学习·机器学习
企客宝CRM8 小时前
2026年中小企业CRM选型指南:企客宝CRM处于什么位置?
android·算法·企业微信·rxjava·crm
极客老王说Agent8 小时前
2026供应链智变:实在Agent供应链库存预测助手核心能力与配置深度教程
人工智能·机器学习·ai·chatgpt
刘一说8 小时前
AI热点资讯日报 - 2026年5月15日
人工智能
橙淮8 小时前
二叉树核心概念与Java实现详解
数据结构·算法
冬奇Lab8 小时前
RAG 系列(十七):Agentic RAG——让 Agent 主导检索过程
人工智能·llm·源码