# --- Copied code: CSV-backed dataset and data loaders ---
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
import pandas as pd
import os
class custom_dataset(Dataset):
    """Map-style dataset backed by a CSV file with ``data`` and ``target`` columns.

    Args:
        root_path: Directory containing the CSV file.
        file_path: CSV file name (or path relative to ``root_path``).
        transform: Optional callable applied to each ``data`` value.
        target_transform: Optional callable applied to each ``target`` value.
    """

    def __init__(self, root_path, file_path, transform=None, target_transform=None):
        self.data = pd.read_csv(os.path.join(root_path, file_path))
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of rows read from the CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(data, target)`` pair at row position ``idx``.

        Each value is read from its own column so it keeps that column's
        dtype. ``self.data.iloc[idx]`` would first build one Series for the
        whole row, which upcasts mixed-dtype rows (an int ``target`` next to a
        float ``data`` would come back as a float).
        """
        data = self.data['data'].iloc[idx]
        target = self.data['target'].iloc[idx]
        if self.transform is not None:
            data = self.transform(data)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return data, target
# NOTE(review): ``training_data`` and ``test_data`` are not defined in this
# snippet; presumably ``custom_dataset`` instances built before this point.
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
# Evaluation does not benefit from shuffling. ``test`` averages per-batch
# losses, so a shuffled order only makes the reported test loss vary between
# runs.
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)
# --- Copied code: model, training and evaluation ---
import torch
from torch import nn
class network(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Linear, with an optional trailing ReLU.

    Args:
        input_dim: Size of the input's last dimension.
        hidden_dim: Width of the hidden layer.
        output_dim: Size of the output's last dimension.
        relu_output: If True, append a final ReLU so outputs are non-negative.
            Off by default because this script trains the model as a regressor
            with ``nn.MSELoss``. A ReLU on the output makes negative targets
            unreachable and gives zero gradient to any output unit whose
            pre-activation is negative.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, *, relu_output=False):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        if relu_output:
            layers.append(nn.ReLU())
        # The Linear layers stay at indices 0 and 2 either way, so state_dict
        # keys (``linear.0.*``, ``linear.2.*``) are unchanged.
        self.linear = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x``, whose last dimension must be ``input_dim``."""
        return self.linear(x)
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# The model must live on the same device that ``train`` moves each batch to.
# Pinning it to the CPU breaks training with a device-mismatch error whenever
# CUDA is available.
model = network(input_dim=8, hidden_dim=16, output_dim=8).to(device)

learning_rate = 1e-3
# NOTE: the DataLoaders pass batch_size=64 directly instead of reading this
# value; keep the two in sync.
batch_size = 64
epochs = 5

loss_fn = nn.MSELoss()
# Built after the model is moved, so the optimizer holds the on-device params.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
def train(dataloader, model, loss_fn, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Switches ``model`` to training mode and moves each ``(inputs, targets)``
    batch to ``device``. It takes one optimizer step per batch and prints the
    batch loss on the first batch and every 100th batch after it.
    """
    model.train()
    for step, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Clearing grads before the forward pass is equivalent: the forward
        # pass never touches ``.grad``.
        optimizer.zero_grad()
        batch_loss = loss_fn(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        if step % 100 != 0:
            continue
        print(f'train_loss:{batch_loss.item()}')
def test(dataloader, model, loss_fn, device=None):
    """Evaluate ``model`` on ``dataloader`` and print the mean per-batch loss.

    Args:
        dataloader: Sized iterable of ``(inputs, targets)`` batches.
        model: Module to evaluate; it is switched to eval mode.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar tensor.
        device: Where to move each batch. Defaults to the device of the
            model's first parameter, so batches follow the model onto the GPU.
            If the model has no parameters, batches are left where they are.

    Returns:
        The loss averaged over batches, as a float.

    Raises:
        ValueError: If ``dataloader`` has length 0.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError('test dataloader is empty')
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None

    model.eval()
    test_loss = 0.0  # accumulator must exist before the ``+=`` below
    with torch.no_grad():
        for x, y in dataloader:
            if device is not None:
                x, y = x.to(device), y.to(device)
            pred = model(x)
            test_loss += loss_fn(pred, y).item()
    test_loss /= num_batches
    print(f'test_loss:{test_loss}')
    return test_loss


# A module-level function named ``test`` is collected as a test case by pytest
# if this module is star-imported into a test file; opt it out of collection.
test.__test__ = False
# Alternate one training pass and one evaluation pass per epoch (1-based).
for epoch in range(1, epochs + 1):
    print(f'epoch {epoch}\n---------------------------')
    train(train_dataloader, model, loss_fn, optimizer, device)
    test(test_dataloader, model, loss_fn)