深度学习:利用随机数据更快地测试一个新的模型在自己数据格式很复杂的时候

技巧:

比如下面一个新的模型deeponet ,我自己的数据很复杂,这里在代码最后用用随机生成的数据,两分钟就完成了代码的测试成功。

复制代码
import torch
import torch.nn as nn
import torch.optim as optim

# 带偏置项的 DeepONet 结构,包括 Branch 和 Trunk 网络
class DeepONet(nn.Module):
    def __init__(self, branch_input_dim, trunk_input_dim, hidden_dim):
        super(DeepONet, self).__init__()
        
        # Branch 网络,用于处理输入点云的特征(例如位移量、压强)
        self.branch_net = nn.Sequential(
            nn.Linear(branch_input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Trunk 网络,用于处理时间和空间坐标 [x, y, z, t]
        self.trunk_net = nn.Sequential(
            nn.Linear(trunk_input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # 偏置项 bias
        self.bias = nn.Parameter(torch.zeros(1))  # 可训练的偏置项
        
        # 最终的输出层,预测位移或压强等物理状态
        self.fc_output = nn.Linear(hidden_dim, 3)
    
    def forward(self, point_features, coord_time):
        # Branch网络的输出
        branch_output = self.branch_net(point_features)
        
        # Trunk网络的输出
        trunk_output = self.trunk_net(coord_time)
        
        # 将 Branch 和 Trunk 的输出结合,计算最终的输出
        combined = branch_output * trunk_output
        output = self.fc_output(combined) + self.bias  # 加上偏置项
        
        return output

# 数据准备
# 输入的数据格式:
# point_features:3D点云的物理特征(例如位移量 pointDisplacement、压强 p)
# coord_time:空间位置和时间 [x, y, z, t]

# 示例数据的维度设置
branch_input_dim = 3  # 例如 [pointDisplacement, p, ...] 
trunk_input_dim = 4   # [x, y, z, t]
hidden_dim = 64       # 隐藏层维度,可根据需求调整

# 模型初始化
model = DeepONet(branch_input_dim, trunk_input_dim, hidden_dim)

# 损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练流程
def train(model, point_features, coord_time, target, epochs=1000):
    for epoch in range(epochs):
        optimizer.zero_grad()
        
        # 前向传播
        output = model(point_features, coord_time)
        
        # 计算损失
        loss = criterion(output, target)
        
        # 反向传播和优化
        loss.backward()
        optimizer.step()
        
        if epoch % 100 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item()}")

# 示例数据,实际应用时需要替换为真实数据
N = 1000  # 样本数量
point_features = torch.randn(N, branch_input_dim)  # 3D点云的物理特征
coord_time = torch.randn(N, trunk_input_dim)       # [x, y, z, t]
target = torch.randn(N, 3)                         # 目标物理状态

# 训练模型
train(model, point_features, coord_time, target, epochs=1000)

# 推理:给定新的时空点,预测物理状态
def predict(model, point_features, coord_time):
    model.eval()
    with torch.no_grad():
        prediction = model(point_features, coord_time)
    return prediction

# 示例推理
new_point_features = torch.randn(1, branch_input_dim)
new_coord_time = torch.tensor([[0.5, 0.5, 0.5, 0.1]])  # 在 t=0.1 的 (0.5, 0.5, 0.5) 空间点
prediction = predict(model, new_point_features, new_coord_time)
print("Predicted state:", prediction)

输出如下:

复制代码
Epoch 0, Loss: 1.0260347127914429
Epoch 100, Loss: 0.7669863104820251
Epoch 200, Loss: 0.5786211490631104
Epoch 300, Loss: 0.4749055504798889
Epoch 400, Loss: 0.41076529026031494
Epoch 500, Loss: 0.36538082361221313
Epoch 600, Loss: 0.39494913816452026
Epoch 700, Loss: 0.30206459760665894
Epoch 800, Loss: 0.2839098572731018
Epoch 900, Loss: 0.2648167908191681
Predicted state: tensor([[-0.2604,  0.2214,  0.5066]])

Process finished with exit code 0
相关推荐
JoannaJuanCV13 分钟前
BEV和OCC学习-5:数据预处理流程
深度学习·目标检测·3d·occ·bev
KKKlucifer13 分钟前
当AI遇上防火墙:新一代智能安全解决方案全景解析
人工智能
DisonTangor1 小时前
【小红书拥抱开源】小红书开源大规模混合专家模型——dots.llm1
人工智能·计算机视觉·开源·aigc
浠寒AI2 小时前
智能体模式篇(上)- 深入 ReAct:LangGraph构建能自主思考与行动的 AI
人工智能·python
weixin_505154463 小时前
数字孪生在建设智慧城市中可以起到哪些作用或帮助?
大数据·人工智能·智慧城市·数字孪生·数据可视化
Best_Me073 小时前
深度学习模块缝合
人工智能·深度学习
YuTaoShao3 小时前
【论文阅读】YOLOv8在单目下视多车目标检测中的应用
人工智能·yolo·目标检测
算家计算4 小时前
字节开源代码模型——Seed-Coder 本地部署教程,模型自驱动数据筛选,让每行代码都精准落位!
人工智能·开源
伪_装4 小时前
大语言模型(LLM)面试问题集
人工智能·语言模型·自然语言处理
gs801404 小时前
Tavily 技术详解:为大模型提供实时搜索增强的利器
人工智能·rag