代码1 实现线性回归并保存模型(注:代码使用 nn.Linear + MSELoss、无 sigmoid,实现的是线性回归而非逻辑回归)
python
import torch
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
# --- Code 1: train a simple linear-regression model and save it ---------
# Fits y ≈ w*x + b to 10 (x, y) pairs with SGD + MSE loss, then serializes
# the entire model object to 'entire_model.pth' (loaded again by code 2).
data = [[-0.5, 7.7], [1.8, 98.5], [0.9, 57.8], [0.4, 39.2], [-1.4, -15.7], [-1.4, -37.3], [-1.8, -49.1], [1.5, 75.6],
        [0.4, 34.0], [0.8, 62.3]]
data = np.array(data)
x_data = data[:, 0]  # features, shape (10,)
y_data = data[:, 1]  # targets, shape (10,)
x_train = torch.tensor(x_data, dtype=torch.float32)
y_train = torch.tensor(y_data, dtype=torch.float32)
dataset = TensorDataset(x_train, y_train)
dataloader = DataLoader(dataset, batch_size=5, shuffle=True)  # 10 samples -> 2 batches/epoch
print(x_train)
criterion = nn.MSELoss()
model = nn.Sequential(nn.Linear(1, 1))  # one input feature -> one output
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
epochs = 500
for n in range(1, epochs + 1):
    epoch_loss = 0.0
    for batch_x, batch_y in dataloader:
        # nn.Linear expects (batch, in_features); add the feature dim on the
        # way in and drop it again before comparing with the 1-D targets.
        y_pred = model(batch_x.unsqueeze(1))
        batch_loss = criterion(y_pred.squeeze(1), batch_y)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        epoch_loss += batch_loss.item()
    avg_loss = epoch_loss / len(dataloader)
    if n % 10 == 0 or n == 1:
        # Bug fix: avg_loss was computed but never used — the original
        # printed the raw per-epoch sum (epoch_loss) instead of the
        # per-batch average it had just calculated.
        print(f'epoch:{n}, loss:{avg_loss}')
torch.save(model, 'entire_model.pth')  # pickles the whole nn.Module
运行结果(注:下面每个 epoch 出现两行 loss,应为两次独立运行的输出混贴在了一起):
python
epoch:1, loss:1784.125
epoch:1, loss:5864.14892578125
epoch:10, loss:1340.3311767578125
epoch:10, loss:2131.239013671875
epoch:20, loss:473.2213439941406
epoch:20, loss:750.4476013183594
epoch:30, loss:146.10401916503906
epoch:30, loss:306.4013214111328
epoch:40, loss:93.8193588256836
epoch:40, loss:161.09381866455078
epoch:50, loss:22.086835861206055
epoch:50, loss:113.13406944274902
epoch:60, loss:66.36778259277344
epoch:60, loss:95.92328262329102
epoch:70, loss:37.97149658203125
epoch:70, loss:90.1829833984375
epoch:80, loss:27.833377838134766
epoch:80, loss:88.26276016235352
epoch:90, loss:18.019649505615234
epoch:90, loss:86.4774284362793
epoch:100, loss:36.294681549072266
epoch:100, loss:86.4249382019043
epoch:110, loss:25.34766960144043
epoch:110, loss:85.6535472869873
epoch:120, loss:71.55767059326172
epoch:120, loss:85.6304121017456
epoch:130, loss:54.34508514404297
epoch:130, loss:85.70595932006836
epoch:140, loss:58.45751953125
epoch:140, loss:85.98778343200684
epoch:150, loss:24.026874542236328
epoch:150, loss:85.38119125366211
epoch:160, loss:31.197525024414062
epoch:160, loss:85.36103820800781
epoch:170, loss:23.816781997680664
epoch:170, loss:85.37735176086426
epoch:180, loss:66.44145202636719
epoch:180, loss:86.07975769042969
epoch:190, loss:49.096153259277344
epoch:190, loss:85.98376846313477
epoch:200, loss:38.83055877685547
epoch:200, loss:87.02980041503906
epoch:210, loss:22.55113410949707
epoch:210, loss:85.6132755279541
epoch:220, loss:60.618438720703125
epoch:220, loss:85.66439247131348
epoch:230, loss:24.166812896728516
epoch:230, loss:85.43827819824219
epoch:240, loss:36.66695022583008
epoch:240, loss:85.72342681884766
epoch:250, loss:50.92716979980469
epoch:250, loss:86.27684783935547
epoch:260, loss:37.27833557128906
epoch:260, loss:85.69609069824219
epoch:270, loss:50.637638092041016
epoch:270, loss:86.2179069519043
epoch:280, loss:60.93098068237305
epoch:280, loss:85.41929817199707
epoch:290, loss:34.782196044921875
epoch:290, loss:85.72705841064453
epoch:300, loss:30.515146255493164
epoch:300, loss:85.36332130432129
epoch:310, loss:33.87446594238281
epoch:310, loss:85.8970718383789
epoch:320, loss:72.44877624511719
epoch:320, loss:85.54687786102295
epoch:330, loss:61.19231414794922
epoch:330, loss:85.43900299072266
epoch:340, loss:48.75373840332031
epoch:340, loss:85.7229118347168
epoch:350, loss:33.820648193359375
epoch:350, loss:85.89838409423828
epoch:360, loss:34.31058883666992
epoch:360, loss:85.568359375
epoch:370, loss:42.243125915527344
epoch:370, loss:86.17256927490234
epoch:380, loss:42.217655181884766
epoch:380, loss:86.17264938354492
epoch:390, loss:29.57950210571289
epoch:390, loss:86.71274185180664
epoch:400, loss:65.79289245605469
epoch:400, loss:86.1904239654541
epoch:410, loss:26.13401222229004
epoch:410, loss:85.56365013122559
epoch:420, loss:28.22481918334961
epoch:420, loss:86.30517959594727
epoch:430, loss:66.69523620605469
epoch:430, loss:85.44241905212402
epoch:440, loss:46.568904876708984
epoch:440, loss:87.0429573059082
epoch:450, loss:49.70370101928711
epoch:450, loss:86.14195251464844
epoch:460, loss:59.23515701293945
epoch:460, loss:85.56475257873535
epoch:470, loss:43.026187896728516
epoch:470, loss:86.16817855834961
epoch:480, loss:62.85814666748047
epoch:480, loss:86.68252944946289
epoch:490, loss:26.73046875
epoch:490, loss:86.20927047729492
epoch:500, loss:25.716747283935547
epoch:500, loss:85.53901290893555
代码2 加载模型进行预测:
python
import torch
# --- Code 2: load the serialized model and run one prediction -----------
# NOTE(review): torch.load unpickles arbitrary objects; only load files you
# trust (this one was produced by the training script above).
loaded_model = torch.load('entire_model.pth')
loaded_model.eval()  # switch to inference mode before predicting
sample = torch.tensor([1.8], dtype=torch.float32)
with torch.no_grad():  # no gradients needed for a forward-only pass
    prediction = loaded_model(sample)
print(prediction)
print(loaded_model)
结果:
python
tensor([93.7325])
Sequential(
(0): Linear(in_features=1, out_features=1, bias=True)
)