- 将文本分割为单词或子单元(tokens),然后将 tokens 映射为数字
0 加载数据集和分词器
python
from transformers import AutoTokenizer
from datasets import load_dataset
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dataset = load_dataset("rotten_tomatoes", split="train")
1 对单条样本进行分词
python
tokenizer(dataset[0]["text"])
'''
{'input_ids': [101, 1996, 2600, 2003, 16036, 2000, 2022, 1996, 7398, 2301, 1005, 1055, 2047, 1000, 16608, 1000, 1998, 2008, 2002, 1005, 1055, 2183, 2000, 2191, 1037, 17624, 2130, 3618, 2084, 7779, 29058, 8625, 13327, 1010, 3744, 1011, 18856, 19513, 3158, 5477, 4168, 2030, 7112, 16562, 2140, 1012, 102],
'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
'''
2 对整个数据集进行批量分词
python
def tokenization(example):
return tokenizer(example["text"])
dataset = dataset.map(tokenization, batched=True)
dataset[0].keys()
'''
dict_keys(['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'])
'''
原先的text和label的基础上,多了input_ids,token_type_ids和attention_mask三个key
3 设置数据集格式以适配 PyTorch
python
dataset.set_format(type="torch", columns=["input_ids", "token_type_ids", "attention_mask", "label"])
dataset[0]
'''
{'label': tensor(1),
'input_ids': tensor([ 101, 1996, 2600, 2003, 16036, 2000, 2022, 1996, 7398, 2301,
1005, 1055, 2047, 1000, 16608, 1000, 1998, 2008, 2002, 1005,
1055, 2183, 2000, 2191, 1037, 17624, 2130, 3618, 2084, 7779,
29058, 8625, 13327, 1010, 3744, 1011, 18856, 19513, 3158, 5477,
4168, 2030, 7112, 16562, 2140, 1012, 102]),
'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}
'''