transformer 最简单学习3, 训练文本数据输入的形式

1、输入数据中,源数据和目标数据的定义

cpp 复制代码
def get_batch(source,i):
    '''
    
    用于获取每个批数据合理大小的源数据和目标数据
    参数source 是通过batchfy   得到的划分batch个 ,的所有数据,并且转置列表示
    i第几个batch
    '''
    bptt = 15  #超参数,一次输入多少个batch 数据,现在数据矩阵,一行表示一个batch, 一共有n个行,  

    # len(source) - 1 - i  从大往小变化,知道小到bptt,所以seq_len,大部分时间都是bptt 个=15个,最后几个训练才越来越少
    seq_len = min(bptt, len(source) -1-i)  #一共是列的元素长度,30个,  行是10个,一共三个batch ,
    # 这是转置过的,现在,就变成30个batch,每个batch 长度是3
    
    # 行数错一位,目标数据是原数据向下一位,
    data = source[i:i+seq_len]
    # 这里最后会越界,使用view(-1) 保证形状正常
    target = source[i+1:i+1+seq_len]
    return data,target #

文本数据,是每个单词对应的索引,需要对数据进行切分成整块的batch, (n行,batch列), 变成竖着的,

(batch行,n列)

然后,横着一个一个 切分成一个个batch数据,下移一个索引获取目标数据,

(n行,batch列)

cpp 复制代码
【 
     [A,B,C,D,E,F]
     [G,H,I,J,K,L]
     [M,N,O,P,Q,R],
     ......
 】

(batch行,n列)

横着看,每一位 AGMS 对应 BHNT, AB, GH, MN, ST, 是相邻的两个字

相关推荐
绝顶大聪明9 小时前
【深度学习】神经网络-part2
人工智能·深度学习·神经网络
Danceful_YJ9 小时前
16.使用ResNet网络进行Fashion-Mnist分类
人工智能·深度学习·神经网络·resnet
iFulling10 小时前
【计算机网络】第四章:网络层(上)
学习·计算机网络
香蕉可乐荷包蛋10 小时前
AI算法之图像识别与分类
人工智能·学习·算法
xiaoli232711 小时前
课题学习笔记1——文本问答与信息抽取关键技术研究论文阅读(用于无结构化文本问答的文本生成技术)
笔记·学习
人生游戏牛马NPC1号11 小时前
学习 Flutter (四):玩安卓项目实战 - 中
android·学习·flutter
LGGGGGQ12 小时前
嵌入式学习-PyTorch(7)-day23
人工智能·pytorch·学习
甄卷12 小时前
李沐动手学深度学习Pytorch-v2笔记【08线性回归+基础优化算法】2
pytorch·深度学习·算法
stm 学习ing12 小时前
Python暑期学习笔记3
笔记·python·学习
豆豆12 小时前
神经网络构建
人工智能·深度学习·神经网络