Pytorch 使用报错 RuntimeError: Caught RuntimeError in DataLoader worker process 0.

这个错误是可能是由于在DataLoader的工作进程中尝试访问CUDA设备导致的。PyTorch的DataLoader使用多进程加载数据,而CUDA上下文不能在子进程中直接使用。

修改前的代码为:

复制代码
def prepare_data(file_path):
    # 读取Excel文件
    df = pd.read_excel(file_path, header=None)
    df = df.iloc[1:]
    print(df)

    # 提取特征和标签
    features = df.iloc[:, :-1].values.astype('float32')  # extract feature
    labels = df.iloc[:, -1].values.astype('int64')  # extract label



    # 数据标准化
    scaler = StandardScaler()
    features = scaler.fit_transform(features)

    # 划分训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=0.2, random_state=42
    )

    # 转换为PyTorch张量并移动到设备
    X_train = torch.tensor(X_train, device=device)
    X_test = torch.tensor(X_test, device=device)
    y_train = torch.tensor(y_train, device=device)
    y_test = torch.tensor(y_test, device=device)

    return X_train, X_test, y_train, y_test, scaler

数据加载修改为下运行OK:

复制代码
class ExcelDataset(Dataset):
    def __init__(self, features, labels):
        # 确保数据在CPU上
        self.features = features.cpu() if features.is_cuda else features
        self.labels = labels.cpu() if labels.is_cuda else labels
相关推荐
赵英英俊36 分钟前
Python day15
开发语言·python
zxsd_xyz1 小时前
基于LabVIEW与Python混合编程的变声器设计与实现
开发语言·python·labview
Danceful_YJ1 小时前
15.手动实现BatchNorm(BN)
人工智能·深度学习·神经网络·batchnorm
wh_xia_jun2 小时前
医疗数据分析中标准化的作用
人工智能·机器学习
李昊哲小课3 小时前
K近邻算法的分类与回归应用场景
python·机器学习·分类·数据挖掘·回归·近邻算法·sklearn
jndingxin3 小时前
OpenCV直线段检测算法类cv::line_descriptor::LSDDetector
人工智能·opencv·算法
胖达不服输3 小时前
「日拱一码」027 深度学习库——PyTorch Geometric(PyG)
人工智能·pytorch·深度学习·pyg·深度学习库
deephub3 小时前
贝叶斯状态空间神经网络:融合概率推理和状态空间实现高精度预测和可解释性
人工智能·深度学习·神经网络·贝叶斯概率·状态空间
壹立科技3 小时前
壹脉销客AI电子名片源码核心架构
人工智能·架构·电子名片
YUQI的博客4 小时前
小白入门:通过手搓神经网络理解深度学习
人工智能·深度学习·神经网络