关于MediaEval数据集的Dataset构建(Text部分-使用PLM BERT)

python 复制代码
import random
import numpy as np
import pandas as pd
import torch
from transformers import BertModel,BertTokenizer
from tqdm.auto import tqdm
from torch.utils.data import Dataset
import re
"""参考Game-On论文"""
"""util.py"""
def set_seed(seed_value=42):
    random.seed(seed_value)
    np.random.seed(seed_value)
    # 用于设置生成随机数的种子
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
"""util.py"""

"""文本预处理-textGraph.py"""
# 文本DataSet类

def text_preprocessing(text):
    """
    - Remove entity mentions (eg. '@united')
    - Correct errors (eg. '&' to '&')
    @param    text (str): a string to be processed.
    @return   text (Str): the processed string.
    """
    # Remove '@name'
    text = re.sub(r'(@.*?)[\s]', ' ', text)

    # Replace '&' with '&'
    text = re.sub(r'&', '&', text)

    # Remove trailing whitespace
    text = re.sub(r'\s+', ' ', text).strip()

    # removes links
    text = re.sub(r'(?P<url>https?://[^\s]+)', r'', text)

    # remove @usernames
    text = re.sub(r"\@(\w+)", "", text)

    # remove # from #tags
    text = text.replace('#', '')

    return text

class TextDataset(Dataset):
    def __init__(self,df,tokenizer):
        # 包含推文的主文件框架
        self.df = df.reset_index(drop=True)

        # 使用的分词器
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # 帖子的文本内容
        text = self.df['tweetText'][idx]
        # 作为唯一标识符的id 'tweetId'
        unique_id = self.df['tweetId'][idx]

        # 创建一个空的列表来存储输出结果
        input_ids = []
        attention_mask = []
        # 使用tokenizer分词器
        encoded_sent = self.tokenizer.encode_plus(
            text = text_preprocessing(text), # 这里使用的是预处理的句子,而不是直接对原句子使用tokenizer
            add_special_tokens=True,        # 添加[CLS]以及[SEP]等特殊词元
            max_length=512,                 # 最大截断长度
            padding='max_length',            # padding的最大长度
            return_attention_mask=True,     # 返回attention_mask
            truncation=True                 #
        )
        # 获取编码效果
        input_ids = encoded_sent.get('input_ids')
        # 获取attention_mask结果
        attention_mask = encoded_sent.get('attention_mask')

        # 将列表转换成张量
        input_ids = torch.tensor(input_ids)
        attention_mask =torch.tensor(attention_mask)

        return {'input_ids':input_ids,'attention_mask':attention_mask,'unique_id':unique_id}

def store_data(bert,device,df,dataset,store_dir):
    lengths = []
    bert.eval()

    for idx in tqdm(range(len(df))):
        sample = dataset.__getitem__(idx)
        print('原始sample[input_ids]和sample[attention_mask]的维度:',sample['input_ids'].shape,sample['attention_mask'].shape)
        # 升维
        input_ids,attention_mask = sample['input_ids'].unsqueeze(0),sample['attention_mask'].unsqueeze(0)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        # 得到唯一标识属性
        unique_id = sample['unique_id']

        # 计算token的个数
        num_tokens = attention_mask.sum().detach().cpu().item()
        """不生成新的计算图,而是只做权重更新"""
        with torch.no_grad():
            out = bert(input_ids=input_ids,attention_mask=attention_mask)
        # last_hidden_state.shape是(batch_size,sequence_length,hidden_size)
        out_tokens = out.last_hidden_state[:,1:num_tokens,:].detach().cpu().squeeze(0).numpy() # token向量

        # 保存token级别表示
        filename = f'{emed_dir}{unique_id}.npy'

        try:
            np.save(filename, out_tokens)
            print(f"文件{filename}保存成功")
        except FileNotFoundError:
            # 文件不存在,创建新文件并保存
            np.save(filename, out_tokens)
            print(f"文件{filename}创建成功并保存成功")
        lengths.append(num_tokens)

        ## Save semantic/ whole text representation
        # 保存语义  也就是整个文本的表示
        out_cls = out.last_hidden_state[:,0,:].unsqueeze(0).detach().cpu().squeeze(0).numpy() ## cls vector
        filename = f'{emed_dir}{unique_id}_full_text.npy'
        # 尝试保存.npy文件,如果文件不存在则自动创建
        try:
            np.save(filename, out_cls)
            print(f"文件{filename}保存成功")
        except FileNotFoundError:
            # 文件不存在,创建新文件并保存
            np.save(filename, out_cls)
            print(f"文件{filename}创建成功并保存成功")
    return lengths

if __name__=='__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # 根目录
    root_dir = "./dataset/image-verification-corpus-master/image-verification-corpus-master/mediaeval2015/"
    emed_dir = './Embedding_File'
    # 文件路径
    train_csv_name = "tweetsTrain.csv"
    test_csv_name = "tweetsTest.csv"

    # 加载PLM和分词器
    tokenizer = BertTokenizer.from_pretrained('./bert/')
    bert = BertModel.from_pretrained('./bert/', return_dict=True)
    bert = bert.to(device)

    # 用于存储每个推文的Embedding
    store_dir ="Embed_Post/"

    # 创建训练数据集的Embedding表示
    df_train = pd.read_csv(f'{root_dir}{train_csv_name}')
    df_train = df_train.dropna().reset_index(drop=True)

    # 训练数据集的编码结果
    train_dataset = TextDataset(df_train,tokenizer)
    lengths = store_data(bert, device, df_train, train_dataset, store_dir)

    ## Create graph data for testing set
    # 为测试集创建Embedding表示
    df_test = pd.read_csv(f'{root_dir}{test_csv_name}')
    df_test = df_test.dropna().reset_index(drop=True)
    test_dataset = TextDataset(df_test, tokenizer)

    lengths = store_data(bert, device, df_test, test_dataset, store_dir)

"""文本预处理-textGraph.py"""
相关推荐
love530love5 分钟前
【笔记】旧版MSYS2 环境中 Rust 升级问题及解决过程
开发语言·人工智能·windows·笔记·python·rust·virtualenv
VR最前沿11 分钟前
Xsens-AAA工作室品质,为动画师准备
人工智能·科技
白熊18812 分钟前
【推荐算法】NeuralCF:深度学习重构协同过滤的革命性突破
深度学习·重构·推荐算法
Leo.yuan23 分钟前
API是什么意思?如何实现开放API?
大数据·运维·数据仓库·人工智能·信息可视化
MarkHD24 分钟前
第十四天 设计一个OTA升级AB测试方案
网络·人工智能·ab测试
VR最前沿43 分钟前
全新Xsens Animate版本是迄今为止最大的软件升级,提供更清晰的数据、快捷的工作流程以及从录制开始就更直观的体验
人工智能·科技·机器人·自动化
禺垣1 小时前
知识图谱技术概述
大数据·人工智能·深度学习·知识图谱
zhongqu_3dnest1 小时前
众趣科技与我爱我家达成战略合作:AI空间计算技术赋能重塑房产服务新范式
人工智能·科技·三维建模·空间计算·vr看房·房产经纪
我就是全世界1 小时前
2025主流智能体Agent终极指南:Manus、OpenManus、MetaGPT、AutoGPT与CrewAI深度横评
人工智能·python·机器学习
MYH5161 小时前
类Transformer架构
人工智能