《动手学深度学习》-60translate实现

复制代码
1.读取数据
2.数据预处理:加空格,处理标点
3.分词:去掉第三列,并改为按单词分词
4.序列截断或填充
5.序列最后加结束符,并计算有效长度
6.生成数据集
def data_nmt(path='D:/PycharmDocument/limu/data/fra.txt'):
    with open(path,'r',encoding='utf-8',errors='ignore') as f:
        return f.read()
raw_txt=data_nmt()
# print(raw_txt[:175])
def preprocess_nmt(text):
    def no_space(char,prev_char):
        return char in set(',.!?') and prev_char !=' '
    text=text.replace('\u202f',' ').replace('\xa0',' ').lower()#将(窄)不换行空格换成普通空格
    out=[' '+char if i>0 and no_space(char,text[i-1]) else char for i,char in enumerate(text)]#如果是char标点符号,标点符号前加空格,否则直接加char
    return ''.join(out)
text=preprocess_nmt(raw_txt)
# print(text[:80])
def tokenize_nmt(text,num_examples=None):#划分标签和输入集
    sourch,target = [],[]
    for i,line in enumerate(text.split('\n')):
        if num_examples and i>num_examples:
            break
        parts=line.split('\t')[:2]
        if len(parts)==2:
            sourch.append(list(parts[0].split(' ')))
            target.append(list(parts[1].split(' ')))
    return sourch,target
source,target=tokenize_nmt(text)
# print(source[:6],target[:6])
def show_list_len_pair_hist(legend,xlabel,ylabel,xlist,ylist):
    plt.figure(figsize=(10,6))
    _,_,patches=plt.hist([[len(l) for l in xlist],[len(l) for l in ylist]])
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    for patch in patches[1].patches:
        patch.set_hatch('/')
    plt.legend(legend)
    plt.show()
show_list_len_pair_hist(['source','target'],'# tokens per sequence','count',source,target)
复制代码
src_voc=test52text_process.Vocab(source,min_freq=2,reserved_tokens=['<pad>','<bos>','<eos>'])#句子填充、句子开始、结束
def truncate_pad(line,num_steps,padding_token):
    if len(line)>num_steps:
        return line[:num_steps]
    return line+[padding_token]*(num_steps-len(line))
def build_array_nmt(lines,vocab,num_steps):
    lines=[vocab[l] for l in lines]
    array=torch.tensor([truncate_pad(l, num_steps,vocab['<pad>']) for l in lines])
    valid_len=(array!=vocab['<pad>']).type(torch.int32).sum(1)
    return array,valid_len
复制代码
def load_data_nmt(batch_size, num_steps, num_examples=600):
    text = preprocess_nmt(raw_txt)
    source, target = tokenize_nmt(text, num_examples)
    # 1. 创建词表对象 (不要试图在这里解包 valid_len)
    src_vocab = test52text_process.Vocab(source, min_freq=2,
                                         reserved_tokens=['<pad>', '<bos>', '<eos>'])
    tgt_vocab = test52text_process.Vocab(target, min_freq=2,
                                         reserved_tokens=['<pad>', '<bos>', '<eos>'])
    src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
    tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)
    # 3. 将 Tensor 放入数组,而不是放 Vocab 对象
    data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)
    # 4. 生成迭代器
    data_iter = d2l.load_array(data_arrays, batch_size)
    # --- 修复结束 ---
    return data_iter, src_vocab, tgt_vocab
train_iter,src_vocab,tgt_vocab=load_data_nmt(2,
复制代码
train_iter,src_vocab,tgt_vocab=load_data_nmt(2,8)
for X,X_valid_len,Y,Y_valid_len in train_iter:
    print(X.type(torch.int32))
    print(X_valid_len)
    print(Y.type(torch.int32))
    print(Y_valid_len)
    break
相关推荐
Orchestrator_me几秒前
Python pip install报SSL错误
python·ssl·pip
老金带你玩AI几秒前
用ChatGPT管项目,让Codex只做Ticket
人工智能
开源量化GO2 分钟前
期货 K 线算信号 tick 级止损:天勤双序列 wait_update 触发规则
linux·运维·服务器·python
聆春烟雨簌簌10 分钟前
LangChain4j使用文档
开发语言·python
前端不太难10 分钟前
从模型部署到智能运营:企业AI的新挑战
人工智能
ZFSS18 分钟前
VS Code + Luma MCP 使用教程
人工智能·ai·ai作画·copilot·ai编程·ai写作
某林21218 分钟前
ROS2 语音机器人实战:从 KCF 跟随失效到 RTAB-Map 建图闭环的完整排障
人工智能·机器人·语音识别·ros2·架构重构·技术复盘·c++底层排错
Tongpao_SSDHDD20 分钟前
希捷酷鹰ST6000VX008实测解析:中小安防监控高性价比存储方案
大数据·数据库·人工智能
Ricky055322 分钟前
基于作物特性的语义分割技术用于高效农业病害评估(西班牙德国2025年联合研究)
人工智能·目标检测·图像分割