transformer 最简单学习3, 训练文本数据输入的形式

1、输入数据中,源数据和目标数据的定义

cpp 复制代码
def get_batch(source,i):
    '''
    
    用于获取每个批数据合理大小的源数据和目标数据
    参数source 是通过batchfy   得到的划分batch个 ,的所有数据,并且转置列表示
    i第几个batch
    '''
    bptt = 15  #超参数,一次输入多少个batch 数据,现在数据矩阵,一行表示一个batch, 一共有n个行,  

    # len(source) - 1 - i  从大往小变化,知道小到bptt,所以seq_len,大部分时间都是bptt 个=15个,最后几个训练才越来越少
    seq_len = min(bptt, len(source) -1-i)  #一共是列的元素长度,30个,  行是10个,一共三个batch ,
    # 这是转置过的,现在,就变成30个batch,每个batch 长度是3
    
    # 行数错一位,目标数据是原数据向下一位,
    data = source[i:i+seq_len]
    # 这里最后会越界,使用view(-1) 保证形状正常
    target = source[i+1:i+1+seq_len]
    return data,target #

文本数据,是每个单词对应的索引,需要对数据进行切分成整块的batch, (n行,batch列), 变成竖着的,

(batch行,n列)

然后,横着一个一个 切分成一个个batch数据,下移一个索引获取目标数据,

(n行,batch列)

cpp 复制代码
【 
     [A,B,C,D,E,F]
     [G,H,I,J,K,L]
     [M,N,O,P,Q,R],
     ......
 】

(batch行,n列)

横着看,每一位 AGMS 对应 BHNT, AB, GH, MN, ST, 是相邻的两个字

相关推荐
医工交叉实验工坊14 分钟前
iPS 细胞帕金森疗法落地日本:治疗费 5530 万日元(237.57万人民币)
学习
李白不吃坚果36 分钟前
误差量化分析的思考_5_17
学习·cmos·集成电路·误差·量化分析·模拟集成电路设计
xian_wwq36 分钟前
【学习笔记】探讨大模型应用安全建设系列2——安全评估:攻击面梳理与差距分析
笔记·学习·安全
Gigavision1 小时前
rPPGMamba:面向 PURE-UBFC-MMPD 跨被试远程生理感知的 Mamba 时序建模方案
python·深度学习·rppg
星夜夏空991 小时前
STM32单片机学习(15) —— PC串口通信实验
stm32·单片机·学习
网络工程小王1 小时前
【大模型vLLM 使用】学习笔记
笔记·学习·llama
初心未改HD1 小时前
深度学习之优化器详解
人工智能·深度学习
星夜夏空991 小时前
STM32单片机学习(14) —— STM32的串口外设
stm32·单片机·学习
栉甜1 小时前
APIs学习
前端·javascript·css·学习·html
吃好睡好便好2 小时前
说说梳头的保健作用
学习