LDU机器学习大作业TCR-抗原结合预测

目录标题

要求

需要使用机器学习的方法,训练模型,完成生物分子相互作用的预测任务,并撰写报告。

  1. 模型性能判别的标准是AUC数值大小。
  2. 需要在代码文件中注释用的什么机器学习方法,每个函数功能需要注释(输入,输出)
  3. 要求将训练好的模型保存好并提交,所有的文件放在model文件夹中,最后测试时直接调用这个模型进行测试即可
  4. 代码文件的最下面放训练模型和预测模型的函数,命名为train、predict

要有main函数 ,直接运行main函数可以导入训练好的模型并测试测试集的性能,最后打印输出测试集的AUC值。测试集文件只提供名称,代码提交时设置好即可
数据内容 :本研究的数据集由 TCR(T细胞受体)与Epitope(抗原表位)序列配对组成。每条样本包含一个 Epitope 序列、一个 CDR3β 序列及其对应标签。
标签:标签 1 表示该配对具有真实结合关系,0 表示随机生成的非结合样本。

1、根据要求配置环境

进入conda创建环境python=3.11.10

python 复制代码
conda create -n ML python=3.11.10

激活环境

python 复制代码
conda activate ML

创建torch==2.5.1版本

python 复制代码
pip install torch==2.5.1 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121

配置其他库,因为比较少就一个一个安装了

python 复制代码
pip install numpy==2.1.3
pip install pandas==2.2.3
pip install tqdm
pip install scikit-learn==1.5.2
pip install subword-nmt

查看配置结果

python 复制代码
pip list

2、对已给文件进行分析

2.1 简单阅读论文了解背景

  • 本课题研究了T细胞受体简称TCR抗原表位简称Epitope的结合关系,这种结合关系是免疫系统的基础,准确的预测谁与谁的结合,对于了解免疫系统的疾病和开发针对免疫系统的疫苗至关重要。
  • 现有的传统方法只关注TCR和Epitope的序列 ,忽视了他们在人体内形成的复杂的相互作用网络及拓扑机构
  • 本文提出了GTE一种图学习框架 ,这是一个新型的图神经网络模型,他不再是孤立地看待每一对序列,而是将所有序列看做图中的节点,通过学习这个网络拓扑结构来预测关系。
  • 本文创新点引入了动态边缘更新来解决负样本不足和Deep AUC最大化来解决数据不平衡优化模型性能。

2.2 train.csv和vaild.csv 数据集分析

CDR3.beta: T 细胞受体 β \beta β 链的 CDR3 序列 (TCR CDR3B sequence)。这是 T 细胞受体上用于识别抗原表位的关键区域序列
Epitope: 抗原表位序列 (Epitope sequence)。这是被 T 细胞识别的外来分子序列 。
Label: 标签 (Label)。表示该 Epitope-CDR3B 配对是否具有真实的结合关系 。

2.3 了解基本实现

2.3.1 data_processing.py

这是函数是从csv文件中加载训练数据和测试数据,从一个pickle文件加载节点的嵌入向量,转换成PyTorch Geometric 可以使用的图格式。图中节点表示 "Epitope"(抗原决定簇)和"CDR3.beta"(T 细胞受体的一个区域),边表示它们之间的相互作用,标签 y 表示这种相互作用的类型或强度。

python 复制代码
import torch
import pickle
import numpy as np
from torch_geometric.data import Data
import pandas as pd
from arg_parser import parse_args
import os
import yaml


def TEINet_embeddings_5fold(config_path):
    """
    加载数据、节点的嵌入向量,并为 TEINet 模型将数据转换为 PyTorch Geometric 的 Data 对象。
    该函数通常用于 K-fold 交叉验证的其中一折(fold)。

    Args:
        config_path (str): YAML 配置文件路径,包含数据路径、嵌入路径和训练/测试文件列表。

    Returns:
        list: 包含两个 PyTorch Geometric Data 对象的列表:[训练集 Data, 测试集 Data]。
    """
    
    # 假设 parse_args() 函数返回的参数对象包含 'split' 属性
    # 'split' 用于指定在配置文件中加载哪个交叉验证折(例如 'fold0', 'fold1' 等)
    args = parse_args() 

    # --- 1. 加载配置和嵌入 ---
    
    with open(config_path) as file:
        # 从 YAML 文件加载配置字典
        config = yaml.safe_load(file)

    with open(config["embeddings_path"], 'rb') as f:
        # 从指定的路径加载预计算的节点嵌入字典
        # 键为 Epitope 或 CDR3.beta 序列,值为对应的嵌入向量(如 numpy 数组)
        embedding_dict = pickle.load(f)

    # 从配置中获取当前折的训练和测试文件列表
    # args.split 决定使用哪个折(如 config['fold0'])
    train_file_list = config[args.split]['train_data']['file_list']
    test_file_list = config[args.split]['test_data']['file_list']
    # 获取数据文件的基本路径
    file_path = config['path']

    # --- 2. 加载和合并训练数据 ---
    
    train_data = []
    for file_name in train_file_list:
        # 读取单个训练数据文件 (CSV)
        data = pd.read_csv(os.path.join(file_path, file_name))
        train_data.append(data)
    # 将所有训练数据文件合并成一个 DataFrame
    train_data = pd.concat(train_data)


    # --- 3. 加载和合并测试数据 ---
    
    test_data = []
    for file_name in test_file_list:
        # 读取单个测试数据文件 (CSV)
        data = pd.read_csv(os.path.join(file_path, file_name))
        test_data.append(data)
    # 将所有测试数据文件合并成一个 DataFrame
    test_data = pd.concat(test_data) 

    # --- 4. 处理数据并转换为 PyG Data 对象 ---
    
    all_data = []
    # 依次处理训练数据和测试数据
    for data in [train_data, test_data]:
        node_index = {} # 字典:用于将节点名称(序列)映射到图中的唯一整数索引
        num_nodes = 0   # 图中节点的总数
        edge_list = []  # 边的列表,存储为 (node_idx_1, node_idx_2) 元组
        X = []          # 节点的特征矩阵列表(存储嵌入向量)
        y_list = []     # 边的标签列表(存储 'Label')
        
        # 遍历 DataFrame 中的每一行,每一行代表一个相互作用(即图中的一条边)
        for _, row in data.iterrows():
            label = float(row["Label"])
            # 相互作用的两个节点
            nodes = [row["Epitope"], row["CDR3.beta"]] 
            
            # 遍历两个节点,为它们分配唯一的索引并收集嵌入
            for node in nodes:
                if node not in node_index:
                    # 如果节点是新发现的,分配一个新索引
                    node_index[node] = num_nodes
                    num_nodes += 1
                    # 从嵌入字典中获取该节点的嵌入向量,并加入特征列表 X
                    X.append(embedding_dict[node]) 
            
            # 收集该相互作用(边)的标签
            y_list.append(label)
            # 记录这条边(使用节点的整数索引)
            edge_list.append((node_index[nodes[0]], node_index[nodes[1]]))

        
        # 将特征列表 X 转换为 PyTorch tensor
        X = torch.tensor(np.array(X), dtype=torch.float)
        # 将边列表转换为 PyTorch Geometric 所需的格式 (2 x num_edges)
        edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
        # 将标签列表转换为 PyTorch tensor
        y = torch.tensor(y_list, dtype=torch.float)

        # 创建 PyTorch Geometric 的 Data 对象,存储图的结构和特征
        all_data.append(Data(x=X, edge_index=edge_index, y=y, num_nodes=num_nodes))

    return all_data

返回的列表包含两个data对象[Train_Graph,Test_Graph]

每个data封装了关键属性
data.x: 节点特征矩阵。图上所有不重复的 Epitope 和 CDR3.beta 序列的嵌入向量(N,f)
data.dege_index: 边连接矩阵。表示图中的所有相互作用(边)。存储了源节点 ID 到 目标节点 ID 的映射(2,E)
data.y: 边标签向量。与 edge_index 中的每条边一一对应,表示该相互作用的标签(0 或 1)(E)
data.num_nodes: 节点总数。图中不重复序列的总数量 N N N

处理步骤:

  • 原始的csv文件输入,假如:

  • 序列嵌入字典

    每个节点都对应一个自己独有的向量,假如:

    data.x=嵌入向量组合

  • 首先给节点分配一个数字C1=0,P1=1,C2=2,P2=3

    处理前:(C1,P1)(C2,P2) (C2,P1)

    处理后:

    data.edge_index=[0 2 2 / 1 3 1]

  • 标签关系 1 0 1

  • Train_Graph = Data ( x = 张量 ( N × F ) , / / 所有序列的身份卡 edge_index = 张量 ( 2 × E ) , / / 相互作用的连接地图 y = 张量 ( E ) / / 相互作用的标签 ) \text{Train\_Graph} = \text{Data}( \text{x}=\text{张量}(N \times F), \quad // \text{所有序列的身份卡} \\ \text{edge\_index}=\text{张量}(2 \times E), \quad // \text{相互作用的连接地图} \\ \text{y}=\text{张量}(E) \quad // \text{相互作用的标签} \\ ) Train_Graph=Data(x=张量(N×F),//所有序列的身份卡edge_index=张量(2×E),//相互作用的连接地图y=张量(E)//相互作用的标签)

2.3.2 create_embeding.py

细读data_processing.py发现给出的文件没有embeddings.pkl嵌入文件也没有配置yaml文件,并且还没有arg_parser.py参数文件。

处理生成embeddings.pkl嵌入文件步骤:

  1. 首先安装esm库

    python 复制代码
    pip install fair-esm 
  2. 生成嵌入向量库create_embeding.py

    python 复制代码
    import torch
    import esm
    import pandas as pd
    import numpy as np
    import pickle
    import os
    from tqdm import tqdm
    
    # --- 1. 配置参数 ---
    # 替换为您的实际文件路径 (假设文件位于 'dataset' 文件夹)
    TRAIN_FILE = "dataset/train.csv"
    VALID_FILE = "dataset/valid.csv"
    OUTPUT_DIR = "models"
    OUTPUT_PKL_PATH = os.path.join(OUTPUT_DIR, "my_project_embeddings.pkl")
    
    
    # --- 2. 序列提取与准备 ---
    def get_unique_sequences(train_path, valid_path):
        """提取所有不重复的 Epitope 和 CDR3.beta 序列,并返回 ESM 要求的格式。"""
        print("-> 正在读取并提取所有独特的序列...")
        train_df = pd.read_csv(train_path)
        valid_df = pd.read_csv(valid_path)
    
        all_sequences = set()
    
        # 提取所有独特的序列
        all_sequences.update(train_df['Epitope'].astype(str).unique())
        all_sequences.update(valid_df['Epitope'].astype(str).unique())
        all_sequences.update(train_df['CDR3.beta'].astype(str).unique())
        all_sequences.update(valid_df['CDR3.beta'].astype(str).unique())
    
        # 清理和过滤掉无效/缺失的序列
        unique_sequences = {seq for seq in all_sequences if seq and seq != 'nan'}
    
        print(f"-> 共找到 {len(unique_sequences)} 个独特的序列。")
    
        # 将序列转换为 ESM 批量编码所需的格式: [("name1", "SEQ1"), ("name2", "SEQ2"), ...]
        # 我们使用序列本身作为 'name'
        esm_data_input = [(seq, seq) for seq in unique_sequences]
    
        return esm_data_input, unique_sequences
    
    
    # --- 3. 嵌入生成与保存函数 ---
    def generate_and_save_embeddings(esm_data_input, unique_sequences):
        # --- A. 加载模型 ---
        print("-> 正在加载 ESM-2 650M 模型...")
        # 使用您代码中指定的强大模型
        model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        batch_converter = alphabet.get_batch_converter()
        model.eval()
    
        # 自动选择 CUDA 或 CPU
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"-> 正在使用设备: {device}")
        model = model.to(device)
    
        # --- B. 批量编码 ---
        # ESM 编码器一次处理所有序列可能导致内存溢出。
        # 我们按批次(Batch)处理,这里设置一个合理的批次大小
        BATCH_SIZE = 128
    
        embedding_dict = {}
    
        # 启用 tqdm 进度条,并按批次进行处理
        for i in tqdm(range(0, len(esm_data_input), BATCH_SIZE), desc="生成嵌入批次"):
            batch_data = esm_data_input[i:i + BATCH_SIZE]
    
            # 1. 转换数据为 tokens
            batch_labels, batch_strs, batch_tokens = batch_converter(batch_data)
            batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
            batch_tokens = batch_tokens.to(device)  # 将 token 移动到 GPU
    
            # 2. 提取残基表示 (Per-residue representations)
            with torch.no_grad():
                # repr_layers=[33] 确保我们使用第33层(即最后一层)的表示
                results = model(batch_tokens, repr_layers=[33], return_contacts=False)
    
            # 获取第33层的 tokens 嵌入
            token_representations = results["representations"][33].cpu()  # 移回 CPU 处理
    
            # 3. 核心步骤:生成序列级表示 (Per-sequence representations)
            # 通过对残基嵌入求平均,将可变长度的序列转换为固定长度的向量
    
            for j, tokens_len in enumerate(batch_lens):
                # NOTE: token 0 是 BOS (序列开始) token,最后一个 token 是 EOS (序列结束) token
                # 我们对从 token 1 到 tokens_len - 1 的残基嵌入求平均
    
                # (tokens_len.item() - 1) 是 EOS 的位置
                sequence_embedding = token_representations[j, 1: tokens_len.item() - 1].mean(0).numpy()
    
                # 找到对应的原始序列字符串
                original_seq = batch_data[j][1]
    
                # 保存到字典
                embedding_dict[original_seq] = sequence_embedding
    
        # --- C. 保存文件 ---
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        with open(OUTPUT_PKL_PATH, 'wb') as f:
            pickle.dump(embedding_dict, f)
    
        # 4. 验证保存结果
        embedding_dim = embedding_dict[list(embedding_dict.keys())[0]].shape[0]
        print("\n==============================================")
        print(f"✅ 嵌入生成成功!文件已保存到: {OUTPUT_PKL_PATH}")
        print(f"✅ 字典中包含 {len(embedding_dict)} 个序列的嵌入。")
        print(f"✅ 每个嵌入向量的维度 (F) 为: {embedding_dim}")
        print("==============================================")
    
    
    # --- 4. 主程序执行 ---
    if __name__ == "__main__":
        esm_data_input, unique_sequences = get_unique_sequences(TRAIN_FILE, VALID_FILE)
        generate_and_save_embeddings(esm_data_input, unique_sequences)
  3. 检查是否生成

2.3.3 配置yaml文件
yaml 复制代码
# ----------------------------------------------------------------------
# 全局配置 (Global Configuration)
# ----------------------------------------------------------------------

# 项目数据集名称,便于识别
dataset_name: MyDataSets

# 数据文件所在的根目录。
# 假设您的 train.csv 和 valid.csv 文件位于项目根目录下的 'dataset' 文件夹中。
path: dataset 

# 模型保存路径。请根据您的需要修改文件名。
save_model: results/MyProject_FinalModel.pth

# 预计算嵌入文件的路径。
# 必须指向您上一步使用 ESM-2 生成的那个文件。
embeddings_path: models/my_project_embeddings.pkl 


# ----------------------------------------------------------------------
# 默认划分配置 (Default Split Configuration)
# ----------------------------------------------------------------------
# 
# 这是一个单一的、固定的划分方式,用于读取 train.csv 和 valid.csv
# 在您的程序中,您需要将 'default_split' 作为参数传递给数据加载函数。
#
default_split:
  train_data:
    # 训练集文件列表,只有一个 train.csv
    file_list:
      - train.csv
  test_data:
    # 测试集文件列表,使用 valid.csv 作为验证/测试集
    file_list:
      - valid.csv

3、项目设计model

11月1日写了一点,哈哈今天是11月30日又开始了。

3.1、生成嵌入脚本create_embeding.py

使用上面的脚本生成train.csv和val.csv的嵌入pkl文件。如下图这样的。

3.2、加载数据和构件图脚本data_loader.py

我们修改上面的加载方式,修改成单纯的加载构建模块,把配置模块拆分出来。

3.2.1 加载已经生成的嵌入文件
python 复制代码
def load_embeddings(embeddings_path):
    """
    从pickle文件加载预先计算的序列嵌入。
    
    Args:
        embeddings_path: 嵌入pickle文件的路径
        
    Returns:
        dict: 字典将序列映射到嵌入向量
    """
    with open(embeddings_path, 'rb') as f:
        embeddings = pickle.load(f)
    return embeddings
3.2.2 根据嵌入文件从csv文件中构建图
python 复制代码
def build_graph_from_csv(csv_path, embeddings_dict):
    """
    从CSV文件中构建图
    
    CSV format: columns=['Epitope', 'CDR3.beta', 'Label']
    
    Args:
        csv_path: CSV文件路径
        embeddings_dict: 字典嵌入序列
        
    Returns:
        torch_geometric.data.Data: 图的数据格式
    """
    # Read CSV
    data = pd.read_csv(csv_path)
    
    # 初始化图组件
    node_index = {}  # sequence -> node_id
    num_nodes = 0
    edge_list = []
    X = []  # Node features
    y_list = []  # Edge labels
    
    # 构件图
    for _, row in data.iterrows():
        label = float(row['Label'])
        sequences = [row['Epitope'], row['CDR3.beta']]
        
        # 创建不存在的节点
        for seq in sequences:
            if seq not in node_index:
                if seq not in embeddings_dict:
                    print(f"Warning: Sequence '{seq}' not found in embeddings, skipping...")
                    break
                node_index[seq] = num_nodes
                num_nodes += 1
                X.append(embeddings_dict[seq])
        else:
            # 当两个节点创建时,创建边
            y_list.append(label)
            edge_list.append((node_index[sequences[0]], node_index[sequences[1]]))
    
    # 转换张量
    X = torch.tensor(np.array(X), dtype=torch.float)
    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    y = torch.tensor(y_list, dtype=torch.float)
    
    return Data(x=X, edge_index=edge_index, y=y, num_nodes=num_nodes)
3.2.3 加载数据
python 复制代码
def load_data(train_path, valid_path, embeddings_path):
    """
    加载数据
    
    Args:
        train_path: Path to training CSV
        valid_path: Path to validation CSV
        embeddings_path: Path to embeddings file
        
    Returns:
        tuple: (train_data, valid_data) 作为PyG数据对象
    """
    print("加载嵌入序列...")
    embeddings = load_embeddings(embeddings_path)
    print(f"加载 {len(embeddings)} 嵌入序列")
    
    print("构建训练图...")
    train_data = build_graph_from_csv(train_path, embeddings)
    print(f"train: {train_data.num_nodes} nodes, {train_data.num_edges} edges")
    
    print("构建验证图...")
    valid_data = build_graph_from_csv(valid_path, embeddings)
    print(f"Valid: {valid_data.num_nodes} nodes, {valid_data.num_edges} edges")
    
    return train_data, valid_data

3.3、配置模块config.py

需要的一些路径,模型参数单独放在一起,训练、测试时引入

python 复制代码
"""
配置文件
"""

# 数据路径
TRAIN_DATA_PATH = 'dataset/train.csv'
VALID_DATA_PATH = 'dataset/valid.csv'
MODEL_SAVE_PATH = 'dataset/best_model.pth'
EMBEDDINGS_PATH = 'dataset/my_project_embeddings.pkl'  # 嵌入文件

# 模型参数
EMBEDDING_DIM = 320  # 序列嵌入维度
HIDDEN_DIM = 128     # 隐藏层尺寸
MLP_HIDDEN_DIM = 256 # MLP隐藏维度
DROPOUT_RATE = 0.1   # Dropout

# 训练参数
LEARNING_RATE = 0.001
EPOCHS = 100
BATCH_SIZE = 1       # 图形级批处理
WEIGHT_DECAY = 5e-4

# 损失权重
BCE_WEIGHT = 1.0     # 二元交叉熵损失权重
AUC_WEIGHT = 0.0     # AUC损失

# 不平衡处理
POSITIVE_WEIGHT = 1.0  # 正类权重的乘数

# 设置
DEVICE = 'cuda'
GPU_ID = 0

# 随机种子
RANDOM_SEED = 42

# 评估
SAVE_BEST_MODEL = True  # 是否保存

3.4、最重要的模型model.py

3.4、训练脚本train.py

3.5、预测脚本predict.py

3.6、配置测试集预测脚本main.py

4、测试结果

相关推荐
ByteCraze38 分钟前
如何处理大模型幻觉问题?
前端·人工智能·深度学习·机器学习·node.js
LCG米39 分钟前
实战:基于ESP32-S3的微型边缘AI计算棒设计,实现低成本图像识别
人工智能
丝斯201140 分钟前
AI学习笔记整理(23)—— AI核心技术(深度学习7)
人工智能·笔记·学习
双木的木40 分钟前
Coggle数据科学 | 并行智能体:洞察复杂系统的 14 种并发设计模式
运维·人工智能·python·设计模式·chatgpt·自动化·音视频
LitchiCheng40 分钟前
Mujoco 机械臂 OMPL 进行 RRT 关节空间路径规划避障、绕障
开发语言·人工智能·python
三年呀43 分钟前
深入探索量子机器学习:原理、实践与未来趋势的全景剖析
人工智能·深度学习·机器学习·量子计算
阿杰学AI43 分钟前
AI核心知识22——大语言模型之重要参数Top-P(简洁且通俗易懂版)
人工智能·ai·语言模型·aigc·模型参数·top-p
腾讯云开发者43 分钟前
架构火花|35岁程序员该做些什么:留在国企vs切换赛道
人工智能
Christo344 分钟前
AAAI-2013《Spectral Rotation versus K-Means in Spectral Clustering》
人工智能·算法·机器学习·数据挖掘·kmeans