nvlink 训练笔记

目录

还没测试出效果


还没测试出效果

python 复制代码
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor

# 定义上述的大型全连接层模型
class LargeFullyConnectedModel(nn.Module):
    def __init__(self):
        super(LargeFullyConnectedModel, self).__init__()
        input_size = 10000
        hidden_size1 = 20000
        hidden_size2 = 15000
        hidden_size3 = 12000
        output_size = 5000

        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size2, hidden_size3)
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(hidden_size3, output_size)

    def forward(self, x):
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.relu3(self.fc3(x))
        x = self.fc4(x)
        return x

# 初始化模型并准备多卡环境
devices = [0, 1]  # 指定要使用的显卡编号列表
model = LargeFullyConnectedModel()
if torch.cuda.device_count() > 1 and len(devices) > 1:
    print(f"使用 {len(devices)} 个 GPU 进行推理")
    model = nn.DataParallel(model, device_ids=devices)
else:
    print("仅使用单个 GPU 进行推理")
model.to(torch.device(f"cuda:{devices[0]}" if torch.cuda.is_available() else "cpu"))

# 模拟数据加载(这里只是示例,实际需根据你的数据进行调整)
batch_size = 32
input_size = 10000
data = torch.randn(batch_size, input_size).to(torch.device(f"cuda:{devices[0]}"))
targets = torch.randint(0, 5000, (batch_size,)).to(torch.device(f"cuda:{devices[0]}"))

# 定义推理函数
def inference():
    model.eval()
    with torch.no_grad():
        outputs = model(data)
        # 可以根据需要进行后续处理,如计算损失、准确率等
    return outputs

if __name__ == "__main__":
    inference()
相关推荐
Learn Beyond Limits几秒前
循环神经网络的问题:梯度消失与梯度爆炸|Problems with RNNs: Vanishing and Exploding Gradients
人工智能·rnn·深度学习·神经网络·机器学习·自然语言处理·nlp
John_ToDebug44 分钟前
死锁案例:UI 线程阻塞等待跨进程 COM 注入
c++·windows·笔记
_饭团1 小时前
指针核心知识:5篇系统梳理2
c语言·笔记·学习·leetcode·面试·改行学it
WangJunXiang61 小时前
Nginx性能优化与监控笔记
笔记·nginx·性能优化
四谎真好看1 小时前
Redis学习笔记(实战篇2)
redis·笔记·学习·学习笔记
北岛寒沫2 小时前
北京大学国家发展研究员 中国经济专题 课程笔记(第二课 农村土地改革)
经验分享·笔记·学习
龙腾AI白云2 小时前
数字孪生底层逻辑和技术
深度学习·django·flask·fastapi·tornado
Alsian2 小时前
Day45 神经网络调参
深度学习·神经网络·机器学习
Piccab0o2 小时前
【学习笔记】——电磁相关
笔记·学习
boy快快长大2 小时前
【PyTorch】2.0 入门学习
人工智能·pytorch·学习