关于MediaEval数据集的Dataset构建(Text部分-使用PLM BERT)

python 复制代码
import random
import numpy as np
import pandas as pd
import torch
from transformers import BertModel,BertTokenizer
from tqdm.auto import tqdm
from torch.utils.data import Dataset
import re
"""参考Game-On论文"""
"""util.py"""
def set_seed(seed_value=42):
    random.seed(seed_value)
    np.random.seed(seed_value)
    # 用于设置生成随机数的种子
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
"""util.py"""

"""文本预处理-textGraph.py"""
# 文本DataSet类

def text_preprocessing(text):
    """
    - Remove entity mentions (eg. '@united')
    - Correct errors (eg. '&' to '&')
    @param    text (str): a string to be processed.
    @return   text (Str): the processed string.
    """
    # Remove '@name'
    text = re.sub(r'(@.*?)[\s]', ' ', text)

    # Replace '&' with '&'
    text = re.sub(r'&', '&', text)

    # Remove trailing whitespace
    text = re.sub(r'\s+', ' ', text).strip()

    # removes links
    text = re.sub(r'(?P<url>https?://[^\s]+)', r'', text)

    # remove @usernames
    text = re.sub(r"\@(\w+)", "", text)

    # remove # from #tags
    text = text.replace('#', '')

    return text

class TextDataset(Dataset):
    def __init__(self,df,tokenizer):
        # 包含推文的主文件框架
        self.df = df.reset_index(drop=True)

        # 使用的分词器
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # 帖子的文本内容
        text = self.df['tweetText'][idx]
        # 作为唯一标识符的id 'tweetId'
        unique_id = self.df['tweetId'][idx]

        # 创建一个空的列表来存储输出结果
        input_ids = []
        attention_mask = []
        # 使用tokenizer分词器
        encoded_sent = self.tokenizer.encode_plus(
            text = text_preprocessing(text), # 这里使用的是预处理的句子,而不是直接对原句子使用tokenizer
            add_special_tokens=True,        # 添加[CLS]以及[SEP]等特殊词元
            max_length=512,                 # 最大截断长度
            padding='max_length',            # padding的最大长度
            return_attention_mask=True,     # 返回attention_mask
            truncation=True                 #
        )
        # 获取编码效果
        input_ids = encoded_sent.get('input_ids')
        # 获取attention_mask结果
        attention_mask = encoded_sent.get('attention_mask')

        # 将列表转换成张量
        input_ids = torch.tensor(input_ids)
        attention_mask =torch.tensor(attention_mask)

        return {'input_ids':input_ids,'attention_mask':attention_mask,'unique_id':unique_id}

def store_data(bert,device,df,dataset,store_dir):
    lengths = []
    bert.eval()

    for idx in tqdm(range(len(df))):
        sample = dataset.__getitem__(idx)
        print('原始sample[input_ids]和sample[attention_mask]的维度:',sample['input_ids'].shape,sample['attention_mask'].shape)
        # 升维
        input_ids,attention_mask = sample['input_ids'].unsqueeze(0),sample['attention_mask'].unsqueeze(0)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        # 得到唯一标识属性
        unique_id = sample['unique_id']

        # 计算token的个数
        num_tokens = attention_mask.sum().detach().cpu().item()
        """不生成新的计算图,而是只做权重更新"""
        with torch.no_grad():
            out = bert(input_ids=input_ids,attention_mask=attention_mask)
        # last_hidden_state.shape是(batch_size,sequence_length,hidden_size)
        out_tokens = out.last_hidden_state[:,1:num_tokens,:].detach().cpu().squeeze(0).numpy() # token向量

        # 保存token级别表示
        filename = f'{emed_dir}{unique_id}.npy'

        try:
            np.save(filename, out_tokens)
            print(f"文件{filename}保存成功")
        except FileNotFoundError:
            # 文件不存在,创建新文件并保存
            np.save(filename, out_tokens)
            print(f"文件{filename}创建成功并保存成功")
        lengths.append(num_tokens)

        ## Save semantic/ whole text representation
        # 保存语义  也就是整个文本的表示
        out_cls = out.last_hidden_state[:,0,:].unsqueeze(0).detach().cpu().squeeze(0).numpy() ## cls vector
        filename = f'{emed_dir}{unique_id}_full_text.npy'
        # 尝试保存.npy文件,如果文件不存在则自动创建
        try:
            np.save(filename, out_cls)
            print(f"文件{filename}保存成功")
        except FileNotFoundError:
            # 文件不存在,创建新文件并保存
            np.save(filename, out_cls)
            print(f"文件{filename}创建成功并保存成功")
    return lengths

if __name__=='__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # 根目录
    root_dir = "./dataset/image-verification-corpus-master/image-verification-corpus-master/mediaeval2015/"
    emed_dir = './Embedding_File'
    # 文件路径
    train_csv_name = "tweetsTrain.csv"
    test_csv_name = "tweetsTest.csv"

    # 加载PLM和分词器
    tokenizer = BertTokenizer.from_pretrained('./bert/')
    bert = BertModel.from_pretrained('./bert/', return_dict=True)
    bert = bert.to(device)

    # 用于存储每个推文的Embedding
    store_dir ="Embed_Post/"

    # 创建训练数据集的Embedding表示
    df_train = pd.read_csv(f'{root_dir}{train_csv_name}')
    df_train = df_train.dropna().reset_index(drop=True)

    # 训练数据集的编码结果
    train_dataset = TextDataset(df_train,tokenizer)
    lengths = store_data(bert, device, df_train, train_dataset, store_dir)

    ## Create graph data for testing set
    # 为测试集创建Embedding表示
    df_test = pd.read_csv(f'{root_dir}{test_csv_name}')
    df_test = df_test.dropna().reset_index(drop=True)
    test_dataset = TextDataset(df_test, tokenizer)

    lengths = store_data(bert, device, df_test, test_dataset, store_dir)

"""文本预处理-textGraph.py"""
相关推荐
水如烟5 分钟前
孤能子视角:“组织行为学–组织文化“
人工智能
大山同学9 分钟前
图片补全-Context Encoder
人工智能·机器学习·计算机视觉
薛定谔的猫198221 分钟前
十七、用 GPT2 中文对联模型实现经典上联自动对下联:
人工智能·深度学习·gpt2·大模型 训练 调优
壮Sir不壮33 分钟前
2026年奇点:Clawdbot引爆个人AI代理
人工智能·ai·大模型·claude·clawdbot·moltbot·openclaw
PaperRed ai写作降重助手41 分钟前
高性价比 AI 论文写作软件推荐:2026 年预算友好型
人工智能·aigc·论文·写作·ai写作·智能降重
玉梅小洋1 小时前
Claude Code 从入门到精通(七):Sub Agent 与 Skill 终极PK
人工智能·ai·大模型·ai编程·claude·ai工具
-嘟囔着拯救世界-1 小时前
【保姆级教程】Win11 下从零部署 Claude Code:本地环境配置 + VSCode 可视化界面全流程指南
人工智能·vscode·ai·编辑器·html5·ai编程·claude code
正见TrueView1 小时前
程一笑的价值选择:AI金玉其外,“收割”老人败絮其中
人工智能
Imm7771 小时前
中国知名的车膜品牌推荐几家
人工智能·python
风静如云1 小时前
Claude Code:进入dash模式
人工智能