【Deep-ML系列】Linear Regression Using Gradient Descent(手写梯度下降)

题目链接:Deep-ML

这道题主要是要考虑矩阵乘法的维度,保证维度正确,就可以获得最终的theata

python 复制代码
import numpy as np
def linear_regression_gradient_descent(X: np.ndarray, y: np.ndarray, alpha: float, iterations: int) -> np.ndarray:
    """
    Linear regression
    :param X: m * n
    :param y:
    :param alpha:
    :param iterations:
    :return:
    """
    m, n = X.shape
    theta = np.zeros((n, 1))
    y = y.reshape(m, 1)     # 保证y是列向量
    for i in range(iterations):
        prediction = np.dot(X, theta)   # m * 1
        error = prediction - y          # m * 1
        gradient = np.dot(X.T, error)   # n * 1
        theta = theta - alpha * (1 / m) * gradient
    theta = np.round(theta, decimals=4)
    return theta

if __name__ == '__main__':
    X = np.array([[1, 1], [1, 2], [1, 3]])
    y = np.array([1, 2, 3])
    alpha = 0.01
    iterations = 1000
    print(linear_regression_gradient_descent(X, y, alpha, iterations))
相关推荐
数模竞赛Paid answer1 小时前
2025年MathorCup数学建模A题汽车风阻预测解题文档与程序
算法·数学建模·mathorcup
嵌入式小企鹅1 小时前
CPU供需趋紧、DeepSeek V4全链适配、小米开源万亿模型
人工智能·学习·开源·嵌入式·小米·算力·昇腾
草莓熊Lotso1 小时前
Vibe Coding 时代:LangChain 与 LangGraph 全链路解析
linux·运维·服务器·数据库·人工智能·mysql·langchain
快乐非自愿2 小时前
RAG夺命10连问,你能抗住第几问?
人工智能·面试·程序员
千匠网络4 小时前
破局出海壁垒,千匠网络新能源汽车跨境出海解决方案
人工智能
马丁聊GEO6 小时前
解码AI用户心智,筑牢可信GEO根基——悠易科技深度参与《中国AI用户态度与行为研究报告(2026)》发布会
人工智能·科技
nap-joker6 小时前
Fusion - Mamba用于跨模态目标检测
人工智能·目标检测·计算机视觉·fusion-mamba·可见光-红外成像融合·远距离/伪目标问题
一只幸运猫.7 小时前
2026Java 后端面试完整版|八股简答 + AI 大模型集成技术(最新趋势)
人工智能·面试·职场和发展
Old Uncle Tom7 小时前
OpenClaw 记忆系统 -- 记忆预加载
java·数据结构·算法·agent
Promise微笑7 小时前
2026年国产替代油介损测试仪:油介损全场景解决方案与技术演进
大数据·网络·人工智能