深度学习之pytorch实现逻辑斯蒂回归

深度学习之pytorch实现逻辑斯蒂回归

解决的问题

logistic 适用于分类问题,这里案例( y为0和1 ,0和 1 分别代表一类)

于解决二分类(0 or 1)问题的机器学习方法,用于估计某种事物的可能性

数学公式

logiatic函数

损失值

代码

也是用y=wx+b的模型来举例,之前的输出y属于实数集合R,现在我们要输出一个一个概率,也就是在区间[0,1]之间。我们就想到需要找出一个映射,把我们之前的输出集合R映射到区间[0,1],他就是函数Sigma,这样我们就轻松的实现了实数集合到0~1之间的映射

python 复制代码
import  torch
import  torch.nn.functional as F
import  numpy as np
import matplotlib.pyplot as plt

x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[0],[0],[1]])

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel,self).__init__()
        self.linear = torch.nn.Linear(1,1)
    def forward(self, x):
        y_pred = F.sigmoid(self.linear(x))#这里需要把原来的输出y传给sigmoid,即实现的区间的映射
        return  y_pred

model = LinearModel()

criterion = torch.nn.BCELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

for epoch in range(1000):
    y_pred = model(x_data)
    loss = criterion(y_pred,y_data)
    print(epoch,loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

x = np.linspace(0,10,200)
x_t = torch.Tensor(x).view(200,1)#将数据变成一个二百行一列的矩阵
y_t = model(x_t)
y = y_t.data.numpy()

plt.plot(x,y)
plt.plot([0,10],[0.5,0.5],c='r')
plt.ylabel('probablility of pass')
plt.xlabel('hours')
plt.grid()#画出网格
plt.show()

与线性回归代码的区别

数据

python 复制代码
x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[0],[0],[1]])

#线性回归
#x_data = torch.Tensor([[1.0],[2.0],[3.0]])
#y_data = torch.Tensor([[2.0],[4.0],[=6.0]])

损失值

ruby 复制代码
criterion = torch.nn.BCELoss(size_average=False)
#线性回归
#criterion = torch.nn.MSELoss(size_average=False)

构造回归的函数

python 复制代码
import torch.nn.functional as F
y_pred = F.sigmoid(self.linear(x))

#线性回归
#y_pred = self.linear(x)

结果分析

部分结果数据

964 1.1182234287261963

965 1.1176648139953613

966 1.1171066761016846

967 1.1165491342544556

968 1.1159923076629639

969 1.1154361963272095

970 1.1148808002471924

971 1.1143261194229126

972 1.113771915435791

973 1.1132186651229858

974 1.1126658916473389

975 1.1121137142181396

976 1.1115622520446777

977 1.1110115051269531

978 1.1104612350463867

979 1.1099116802215576

980 1.1093629598617554

981 1.1088148355484009

982 1.1082673072814941

983 1.1077203750610352

984 1.1071741580963135

985 1.106628656387329

986 1.106083631515503

987 1.105539321899414

988 1.104995846748352

989 1.1044528484344482

990 1.1039104461669922

991 1.1033687591552734

992 1.1028276681900024

993 1.1022872924804688

994 1.1017472743988037

995 1.101208209991455

996 1.1006698608398438

997 1.1001317501068115

998 1.0995947122573853

999 1.0990580320358276

相关推荐
程序员清洒11 分钟前
CANN模型安全:从对抗防御到隐私保护的全栈安全实战
人工智能·深度学习·安全
island131415 分钟前
CANN ops-nn 算子库深度解析:神经网络计算引擎的底层架构、硬件映射与融合优化机制
人工智能·神经网络·架构
小白|19 分钟前
CANN与实时音视频AI:构建低延迟智能通信系统的全栈实践
人工智能·实时音视频
Kiyra19 分钟前
作为后端开发你不得不知的 AI 知识——Prompt(提示词)
人工智能·prompt
艾莉丝努力练剑22 分钟前
实时视频流处理:利用ops-cv构建高性能CV应用
人工智能·cann
程序猿追22 分钟前
深度解析CANN ops-nn仓库 神经网络算子的性能优化与实践
人工智能·神经网络·性能优化
User_芊芊君子26 分钟前
CANN_PTO_ISA虚拟指令集全解析打造跨平台高性能计算的抽象层
人工智能·深度学习·神经网络
初恋叫萱萱29 分钟前
CANN 生态安全加固指南:构建可信、鲁棒、可审计的边缘 AI 系统
人工智能·安全
机器视觉的发动机34 分钟前
AI算力中心的能耗挑战与未来破局之路
开发语言·人工智能·自动化·视觉检测·机器视觉
铁蛋AI编程实战38 分钟前
通义千问 3.5 Turbo GGUF 量化版本地部署教程:4G 显存即可运行,数据永不泄露
java·人工智能·python