神经网络基础-神经网络补充概念-09-m个样本的梯度下降

概念

当应用梯度下降算法到具有 m 个训练样本的逻辑回归问题时,我们需要对每个样本计算梯度并进行平均,从而更新模型参数。这个过程通常称为批量梯度下降(Batch Gradient Descent)。

代码实现

python 复制代码
import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def compute_loss(X, y, theta):
    m = len(y)
    h = sigmoid(X.dot(theta))
    loss = (-1/m) * np.sum(y * np.log(h) + (1 - y) * np.log(1 - h))
    return loss

def batch_gradient_descent(X, y, theta, learning_rate, num_iterations):
    m = len(y)
    losses = []
    
    for _ in range(num_iterations):
        h = sigmoid(X.dot(theta))
        gradient = X.T.dot(h - y) / m
        theta -= learning_rate * gradient
        
        loss = compute_loss(X, y, theta)
        losses.append(loss)
        
    return theta, losses

# 生成一些模拟数据
np.random.seed(42)
m = 100
n = 2
X = np.random.randn(m, n)
X = np.hstack((np.ones((m, 1)), X))
theta_true = np.array([1, 2, 3])
y = (X.dot(theta_true) + np.random.randn(m) * 0.2) > 0

# 初始化参数和超参数
theta = np.zeros(X.shape[1])
learning_rate = 0.01
num_iterations = 1000

# 执行批量梯度下降
theta_optimized, losses = batch_gradient_descent(X, y, theta, learning_rate, num_iterations)

# 打印优化后的参数
print("优化后的参数:", theta_optimized)

# 绘制损失函数下降曲线
import matplotlib.pyplot as plt
plt.plot(losses)
plt.xlabel('迭代次数')
plt.ylabel('损失')
plt.title('损失函数下降曲线')
plt.show()

使用了 m 个训练样本,而不是一个。我们首先定义了 sigmoid 函数和计算损失的函数 compute_loss,然后实现了 batch_gradient_descent 函数来执行批量梯度下降。

相关推荐
AC赳赳老秦几秒前
代码生成超越 GPT-4:DeepSeek-V4 编程任务实战与 2026 开发者效率提升指南
数据库·数据仓库·人工智能·科技·rabbitmq·memcache·deepseek
液态不合群2 分钟前
推荐算法中的位置消偏,如何解决?
人工智能·机器学习·推荐算法
饭饭大王6666 分钟前
当 AI 系统开始“自省”——在 `ops-transformer` 中嵌入元认知能力
人工智能·深度学习·transformer
ujainu7 分钟前
CANN仓库中的AIGC可移植性工程:昇腾AI软件栈如何实现“一次开发,多端部署”的跨生态兼容
人工智能·aigc
初恋叫萱萱8 分钟前
CANN 生态实战指南:从零构建一个高性能边缘 AI 应用的完整流程
人工智能
Lethehong11 分钟前
CANN ops-nn仓库深度解读:AIGC时代的神经网络算子优化实践
人工智能·神经网络·aigc
开开心心就好13 分钟前
AI人声伴奏分离工具,离线提取伴奏K歌用
java·linux·开发语言·网络·人工智能·电脑·blender
TechWJ13 分钟前
CANN ops-nn神经网络算子库技术剖析:NPU加速的基石
人工智能·深度学习·神经网络·cann·ops-nn
凌杰13 分钟前
AI 学习笔记:LLM 的部署与测试
人工智能
心疼你的一切14 分钟前
拆解 CANN 仓库:实现 AIGC 文本生成昇腾端部署
数据仓库·深度学习·aigc·cann