逻辑斯特回归

*分类是离散的,回归是连续的

下载数据集

train=True:下载训练集

逻辑斯蒂函数保证输出值在0-1之间

能够把实数值映射到0-1之间

导函数类似正态分布

其他饱和函数sigmoid functions

循环神经网络经常使用tanh函数

与线性回归区别

塞戈马无参数,构造函数无区别

更改损失函数MSE->BCE损失(越小越好)

分布的差异:KL散度,cross-entropy交叉熵

二分类的交叉熵

python 复制代码
# -*- coding: utf-8 -*-
# @Time    : 2023-07-18 20:26
# @Author  : yuer
# @FileName: exercise06.py
# @Software: PyCharm
import matplotlib.pyplot as plt
import numpy as np
import torch

# 数据集
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])


# 先根据x算出y值再根据y的范围找到分类

class logisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super(logisticRegressionModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)
        # x_data,y_data都是一维,与线性回归相比构造没有函数区别

    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred


model = logisticRegressionModel()

# 默认情况size_average=True 即loss是1/n倍的,False设置loss不除n
criterion = torch.nn.BCELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# SGD梯度下降优化方法 初始化w,b都为0

for epoch in range(1000):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())

    optimizer.zero_grad()  # 清空梯度
    loss.backward()  # 反馈算梯度并更新
    optimizer.step()  # 更新w,b的值

print('w=', model.linear.weight.item())
print('b=', model.linear.bias.item())

x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred=', y_test.data)

x = np.linspace(0, 10, 200)  # 在线性空间中以均匀步长生成数字序列;在0-10之间的200个点
x_t = torch.Tensor(x).view((200, 1))  # 转换为200*1的矩阵
y_t = model(x_t)  # 利用模型训练
y = y_t.data.numpy()
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], c='r')
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()
相关推荐
小白狮ww7 分钟前
开源新基准!OmniGen2 文本图像对齐度提升 8.6%,视觉一致性超越现有开源模型15%
人工智能·机器学习·开源
算家计算8 分钟前
阿里开源最强编程模型Qwen3-Coder!超越GPT-4.1,登顶开源榜首
人工智能·ai编程·资讯
老周聊大模型9 分钟前
破界协同:企业级HITL AI Agent双闭环架构解密——让机器智能与人类智慧共舞
人工智能·机器学习·程序员
说私域16 分钟前
开源AI智能客服、AI智能名片与S2B2C商城小程序在客户复购与转介绍中的协同效应研究
人工智能·小程序·开源
NineData27 分钟前
NineData新增SQL Server到MySQL复制链路,高效助力异构数据库迁移
数据库·人工智能·mysql
Code_流苏1 小时前
ChatGPT Agent深度解析:告别单纯问答,一个指令搞定复杂任务?
人工智能·自然语言处理·chatgpt·openai·agent·智能体
别摸我的婴儿肥1 小时前
从0开始LLM-注意力机制-4
人工智能·python·算法
聚客AI1 小时前
LLM→RAG→Agent→Training的企业级AI应用落地分层实施路线图
人工智能·llm·agent
芒果快进我嘴里1 小时前
opencv-图像处理
人工智能·opencv·计算机视觉
新智元1 小时前
金牌模型三位核心华人光速离职!谷歌IMO夺金24h即遭小扎闪电抄家
人工智能·openai