目前业界流行的 Transformer 模型架构虽然在大多数场景表现优秀,但其上下文窗口(Window)长度的限制,通常仅为几千到几万个 Token,这使得它们在处理长文本、多轮对话或需要大规模上下文记忆的任务中,往往无法保持语义连贯性和信息准确性。
arxiv 上论文地址
[2501.00663] Titans: Learning to Memorize at Test Time
论文中提到Titans 具有三种架构设计变体,分别是 Memory as a Context(MAC)、Memory as a Gate(MAG)和 Memory as a Layer(MAL),可以根据不同的任务需求整合短期与长期记忆。其中"MAC"架构变体将长期记忆作为上下文的一部分,允许注意力机制动态结合历史信息与当前数据,适合处理需要详细历史上下文的任务。其中
MAC架构包括三个分支:(1)核心分支,(2)上下文(长期)记忆分支,(3)持久性记忆分支。核心分支将相应的长期和持久性记忆与输入序列连接起来。接下来,注意力机制在序列上执行,并决定应该将哪些信息存储在长期记忆中。在测试时,对应于上下文记忆的参数仍在学习,对应于核心分支的参数负责上下文学习,而持久性记忆的参数负责存储任务知识,因此是固定的。13
总结来说,MAC架构通过三个分支来实现记忆功能:
- 核心分支将输入序列与长期和持久性记忆连接起来
- 注意力机制决定将哪些信息存储在长期记忆中
- 持久性记忆存储任务知识,在测试时保持固定
这种设计可以有效地将短期和长期记忆整合到模型中,提高其在各种任务上的性能。
论文声称,Titans 系列模型架构在长序列处理任务中的表现明显优于现有模型,无论是语言建模还是时间序列预测,Titans 在准确性和效率上都展现了"压倒性优势",甚至在某些场景中超越了如 GPT-4 等具有数十倍参数的模型
这张图展示了神经网络训练中的几个关键概念,特别是如何利用并行处理来提高效率。以下是图中每个部分的简要说明:
-
Linear Within-Chunk:
- 这里展示了在一个数据块内部进行线性计算的过程,可能是通过累积(cum sum)来实现的。
- 这种方式可以提高计算效率,尤其是在处理大规模数据时。
-
Non-Linear Cross-Chunk:
- 显示了如何在数据块之间进行非线性计算。
- 这种计算通常涉及梯度的计算,可能会使用一些激活函数来实现。
-
Momentum Calculation:
- 这一部分提到梯度的预计算,通常用于实现动量更新(Momentum)。
- 动量法通过累积过去的梯度来加速训练过程,帮助模型更快收敛。
-
Weight Decay:
- 这段展示了如何在训练过程中实现权重衰减(Weight Decay),以防止过拟合。
- 说明了通过矩阵乘法(matmul)来计算权重的更新。
该图描述了一种名为"Memory as a Context (MAC)"的架构,主要涉及记忆在神经网络中的角色。以下是各个部分的简要说明:
-
Neural Memory:
- 这一部分表示神经网络中的记忆模块,包括三种类型的记忆:上下文记忆(contextual long-term memory)、顺序记忆(sequential memory),以及持久记忆(persistent memory)。
- 这些记忆模块用于存储不同时间尺度的信息,以便在处理输入序列时使用。
-
Core:
- 核心部分负责处理输入序列,并将其与相应的上下文和持久记忆结合。
- 该部分的注意力机制决定了哪些信息在处理过程中是重要的,帮助模型在长期内保持信息。
-
Parameters:
- 图中提到的参数对应于不同的分支:核心分支负责上下文学习,而持久记忆分支则负责与任务相关的知识。
- 这些参数是可学习的,且与输入数据无关。
-
在测试时,模型利用之前存储的记忆来做出决策,展示了记忆对于模型性能的重要性。整体而言,这张图强调了在MAC架构中记忆的多样性和重要性,如何通过不同类型的记忆模块来增强模型的学习和推理能力。这种设计旨在提高模型在复杂任务中的表现,尤其是在需要长期记忆的情境下。
github上有一个非官方的titans的模型实现 titans-pytorch
class MemoryAsContextTransformer(Module):
def __init__(
self,
*,
num_tokens,
dim,
depth,
segment_len,
neural_memory_segment_len=None,
num_longterm_mem_tokens=0,
num_persist_mem_tokens=0,
dim_head=64,
heads=8,
ff_mult=4,
num_residual_streams=4,
neural_memory_kwargs: dict = dict(),
neural_memory_layers: tuple[int, ...] | None = None,
aux_kv_recon_loss_weight=0.
):
super().__init__()
self.token_emb = nn.Embedding(num_tokens, dim)
self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim=dim, num_axial_dims=2)
# long term mem tokens
self.segment_len = segment_len
self.num_longterm_mem_tokens = num_longterm_mem_tokens
has_longterm_mems = num_longterm_mem_tokens > 0
self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
# hyper conection
init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(
num_residual_streams, disable=num_residual_streams == 1)
self.layers = ModuleList([])
self.neural_mem_layers = ModuleList([])
self.neural_memory_segment_len = default(neural_memory_segment_len, num_longterm_mem_tokens + segment_len)
layers = tuple(range(1, depth + 1))
if not exists(neural_memory_layers):
neural_memory_layers = layers if has_longterm_mems else ()
assert not (num_longterm_mem_tokens > 0 and len(
neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'
for layer in layers:
is_first = layer == 1
# neural memory
mem = None
if layer in neural_memory_layers:
assert has_longterm_mems, '`num_longterm_mem_tokens` must be greater than 0'
mem = NeuralMemory(
dim=dim,
chunk_size=self.neural_memory_segment_len,
**neural_memory_kwargs
)
mem = init_hyper_conn(dim=dim, branch=mem)
self.neural_mem_layers.append(mem)
# attention and feedforward
attn = SegmentedAttention(
dim=dim,
dim_head=dim_head,
heads=heads,
segment_len=segment_len,
accept_value_residual=not is_first,
num_longterm_mem_tokens=num_longterm_mem_tokens,
num_persist_mem_tokens=num_persist_mem_tokens
)
ff = FeedForward(dim=dim, mult=ff_mult)
self.layers.append(ModuleList([
init_hyper_conn(dim=dim, branch=attn),
init_hyper_conn(dim=dim, branch=ff)
]))
self.norm = nn.RMSNorm(dim)
self.to_logits = LinearNoBias(dim, num_tokens)
# auxiliary loss on kv recon
self.has_aux_kv_recon_loss = aux_kv_recon_loss_weight > 0.
self.aux_kv_recon_loss_weight = aux_kv_recon_loss_weight
self.register_buffer('zero', torch.tensor(0.), persistent=False)
-
初始化参数:
num_tokens
: 输入 token 的总数。dim
: 嵌入和隐藏层的维度。depth
: 网络的层数。segment_len
: 每个输入段的长度。neural_memory_segment_len
: 神经记忆的分段长度(可选)。num_longterm_mem_tokens
: 长期记忆 token 的数量。num_persist_mem_tokens
: 持久记忆 token 的数量。dim_head
: 注意力头的维度。heads
: 注意力的头数。ff_mult
: 前馈网络的扩展因子。num_residual_streams
: 残差流的数量。neural_memory_kwargs
: 神经记忆的参数。neural_memory_layers
: 使用神经记忆的层数(可选)。aux_kv_recon_loss_weight
: 额外的损失权重。
-
嵌入层:
self.token_emb
: 使用nn.Embedding
将输入的 token ID 转换为嵌入向量。self.axial_pos_emb
: 使用连续的轴向位置嵌入来增强模型对输入序列中 token 相对位置的理解。
-
长期记忆:
self.longterm_mems
: 长期记忆的参数,随机初始化并乘以一个小的常数以避免初始时过大的值。
-
超连接:
- 通过
get_init_and_expand_reduce_stream_functions
获取超连接初始化、扩展和收缩流的函数。这有助于促进不同层之间的信息流动。
- 通过
-
层的构建:
- 使用
ModuleList
构建多层结构,包括神经记忆层、注意力层和前馈层。 - 每层的注意力和前馈网络都通过
init_hyper_conn
函数进行超连接。
- 使用
-
注意力机制 -
SegmentedAttention
:- 该模块负责处理输入序列的注意力计算,支持灵活的注意力机制。
accept_value_residual
控制是否接受值残差,这在第一层中通常设置为 False。
-
前馈网络 -
FeedForward
:- 该模块包括一个归一化层、全连接层和激活函数,最后是另一个全连接层,输出与输入相同维度的结果。
-
归一化和输出层:
self.norm
: 使用RMSNorm
进行归一化处理,确保输入在每个批次中的均值和方差平衡。self.to_logits
: 一个线性变换,将模型的输出映射到 token 的数量,用于生成最终的 logits。
-
辅助损失:
self.has_aux_kv_recon_loss
: 检查是否需要计算辅助的键值重建损失。self.aux_kv_recon_loss_weight
: 控制辅助损失的权重。
关键功能
- 多层网络:通过多个层的堆叠,使得模型能够学习更复杂的模式和特征。
- 灵活的注意力机制 :通过
SegmentedAttention
支持多种注意力变体,增强模型的灵活性。 - 长期和持久记忆:结合长期和持久记忆来增强模型对上下文的理解,适应更复杂的任务。
- 超连接:通过超连接促进信息流,增强模型的表现。
关于SegmentedAttention
作者的实现是
class SegmentedAttention(Module):
def __init__(
self,
dim,
segment_len,
num_persist_mem_tokens=0,
num_longterm_mem_tokens=0,
dim_head=64,
heads=8,
accept_value_residual=False,
attend_kwargs: dict = dict(),
use_flex_attn=False
):
super().__init__()
self.norm = nn.RMSNorm(dim)
dim_inner = dim_head * heads
self.rotary_emb = RotaryEmbedding(dim_head)
self.attend = Attend(causal=True, **attend_kwargs)
self.to_qkv = LinearNoBias(dim, dim_inner * 3)
self.to_out = LinearNoBias(dim_inner, dim)
self.to_learned_v_mix = nn.Sequential(
nn.Linear(dim, heads),
Rearrange('b n h -> b h n 1'),
nn.Sigmoid()
) if accept_value_residual else None
self.segment_len = segment_len
self.num_longterm_mem_tokens = num_longterm_mem_tokens
# total_segment_len = segment_len + num_longterm_mem_tokens
self.split_heads = Rearrange('b n (h d) -> b h n d', h=heads)
self.merge_heads = Rearrange('b h n d -> b n (h d)')
self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
# flex attn related
assert not (use_flex_attn and not exists(
flex_attention)), 'you need to be on the latest pytorch with a cuda device available'
self.use_flex_attn = use_flex_attn
self.segment_len = segment_len
self.num_persist_mem_tokens = num_persist_mem_tokens
def forward_flex(
self,
seq,
value_residual=None,
flex_attn_fn: Callable | None = None
):
assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
batch, seq_len = seq.shape[:2]
# attention
seq = self.norm(seq)
q, k, v = self.to_qkv(seq).chunk(3, dim=-1)
q, k, v = map(self.split_heads, (q, k, v))
# value residual
orig_v = v
if exists(self.to_learned_v_mix):
mix = self.to_learned_v_mix(seq)
v = v.lerp(value_residual, mix)
# take care of persistent memory key / values
pmk, pmv = repeat(self.persistent_memory, 'kv h n d -> kv b h n d', b=batch)
# relative positions
q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
# persistent memory
k = cat((pmk, k), dim=-2)
v = cat((pmv, v), dim=-2)
# prep flex attention
if not exists(flex_attn_fn):
block_mask = create_mac_block_mask(seq_len, self.segment_len, self.num_persist_mem_tokens)
flex_attn_fn = partial(flex_attention, block_mask=block_mask)
# attention
out = flex_attn_fn(q, k, v)
out = self.merge_heads(out)
out = self.to_out(out)
return out, orig_v
def forward(
self,
seq,
value_residual=None,
flex_attn_fn: Callable | None = None
):
if seq.is_cuda and self.use_flex_attn:
return self.forward_flex(seq, value_residual, flex_attn_fn)
assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
total_segment_len = segment_len + num_longterm_mem_tokens
batch, seq_len = seq.shape[:2]
# auto pad to multiple
seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len)
seq = self.norm(seq)
q, k, v = self.to_qkv(seq).chunk(3, dim=-1)
q, k, v = map(self.split_heads, (q, k, v))
# value residual
orig_v = v
if exists(self.to_learned_v_mix):
mix = self.to_learned_v_mix(seq)
v = v.lerp(value_residual, mix)
# take care of persistent memory key / values
pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b=seq.shape[0])
# relative positions
q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
# persistent memory
k = cat((pmk, k), dim=-2)
v = cat((pmv, v), dim=-2)
# attention
out, _ = self.attend(q, k, v)
out = self.merge_heads(out)
out = self.to_out(out)
out = inverse_segment(out)
return out, orig_v
-
初始化参数:
dim
: 输入特征的维度。segment_len
: 每个输入段的长度。num_persist_mem_tokens
: 持久记忆 token 的数量。num_longterm_mem_tokens
: 长期记忆 token 的数量。dim_head
: 注意力头的维度。heads
: 注意力的头数。accept_value_residual
: 是否接受值残差。attend_kwargs
: 其他传递给注意力机制的参数。use_flex_attn
: 是否使用灵活的注意力机制。
-
归一化层:
self.norm
: 使用RMSNorm
进行输入序列的归一化,以提高模型稳定性和收敛速度。
-
旋转嵌入:
self.rotary_emb
: 使用RotaryEmbedding
来处理查询和键的相对位置。
-
注意力机制:
self.attend
: 一个注意力计算模块,支持因果注意力(causal attention)。self.to_qkv
: 将输入序列转换为查询、键和值的线性变换,不使用偏置。self.to_out
: 将注意力输出转换为原始特征维度的线性变换。
-
值残差处理:
self.to_learned_v_mix
: 如果accept_value_residual
为 True,则使用一个线性层来学习如何混合值残差。
-
持久记忆:
self.persistent_memory
: 用于存储持久记忆的参数,初始化为零。
-
头部分割与合并:
self.split_heads
: 将多头注意力的输出分割为多个头。self.merge_heads
: 将多个头的输出合并为一个输出。
前向传播
forward_flex
方法
-
适用情况:如果在 CUDA 上且使用灵活注意力。
-
输入处理:
- 对输入序列进行归一化处理。
- 将输入序列转换为查询、键和值。
- 将这些张量分割为多个头。
-
值残差处理:
- 如果
to_learned_v_mix
存在,则根据混合系数进行值的线性插值。
- 如果
-
持久记忆:
- 将持久记忆的键和值与当前的键和值进行拼接,以增强模型对历史信息的利用。
-
灵活注意力计算:
- 创建块掩码并调用灵活注意力函数
flex_attn_fn
,计算注意力输出。
- 创建块掩码并调用灵活注意力函数
-
输出处理:
- 将输出合并回原始维度,并返回输出和原始值。
forward
方法
-
适用情况:如果没有使用灵活注意力。
-
输入处理:
- 对输入序列进行填充和分段处理。
- 对序列进行归一化。
-
注意力计算:
- 将输入序列转换为查询、键和值,分割为多个头。
- 处理值残差(如有)。
-
持久记忆:
- 处理持久记忆的键和值,拼接到当前的键和值上。
-
注意力计算:
- 调用
self.attend
计算注意力输出。
- 调用
-
输出处理:
- 合并头部输出,进行线性变换,并将输出恢复到原始序列形状。
SegmentedAttention
类通过结合分段处理、持久记忆和灵活的注意力机制,能够有效地处理长序列数据。这种设计使得模型在处理复杂任务时具有更高的灵活性和适应性,特别是在需要利用历史信息的情况下。整体上,该模块增强了变换器模型的能力,适合于多种自然语言处理和序列建模任务。
使用enwik8.gz 数据集训练文本自动生成
with gzip.open('./data/enwik8.gz') as file:
data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
data_train, data_val = np.split(data, [int(90e6)])
data_train, data_val = map(torch.from_numpy, (data_train, data_val))
enwik8.gz
数据集是一个广泛使用的文本数据集,主要用于训练和评估语言模型和文本生成模型。以下是该数据集的一些主要特点以及它为何适合用于文本自动生成的原因:
-
丰富的语言模式:大规模和多样化的文本使得模型可以学习到丰富的语言模式和语法结构,从而生成自然流畅的文本。
-
上下文学习:数据集的连续性有助于模型理解上下文信息,这对于生成相关且连贯的文本至关重要。
-
多样化的主题:模型可以通过学习不同主题的文本来增强其生成能力,使得生成的文本能够在多种上下文中应用。
-
适应性:训练好的模型可以在各种文本生成任务中进行微调,比如故事创作、对话系统、摘要生成等,体现了良好的适应性。
-
强大的评估基准 :
enwik8
常用于评估语言模型的性能,提供了一个标准化的基准,方便与其他模型进行比较。
开启训练,每五百次训练后随机进行一次续写输出
第一次输出
****************************************************************************************************
input_str: which, unchecked, would cause the sugars in the nectar to ferment. After the final regurgitation, t
output_str: �ù:Ä^ ; e u u d ; u Óo eùg �Cu Ze e n ue e de ^e e d¢Óu n e oenu den ene ÓZ ½vÎr}Ce �e de Óue n eeen en n�ede nZu g^ e Ó nu ^ed e ^e dªeÿ´U´V enu n nue Ze u d u Ä o´V n e ed¢ ^ n ee u geT^ du d pee u e ^ ue u eÿrZ^d^u Og^d nZ g n n e e ^Ó^ do^ J�püeuu ÓZoþedoËôede ^d ^ ud gZhnunCpdÎuueÓü Z u Î^do´Óeu ÎZZee´(d �ÿ;(ÿ(�p(eÓen ZnCU�unp´s`ñeeôÓÓeÿdZZuengud^T½C`�n e e¢
经过多轮训练后
****************************************************************************************************
input_str: ants (war veterans), who had been the heart of the liberation struggle 20 years earlier. He agreed t
output_str: he artistical list system. Th flowing the book of the film inerstand confiture of any terms ions party of the Battle far partree and regards from the furthere treator and when examples sprence all an electrone a format the factory of the original naturelses. == Even studies== The [[puter (never Certroof|Lord (Spand Critic Walico, Hapports Crowinged Statistics persons in the Courater. The Eliberal United Sta
****************************************************************************************************
input_str: involve believing that human nature is purely good or that each and every person is capable of livi
output_str: ng an unknown it domain multit in the steams, considered here soletered social information ofrom the [[University of Panufallds far run]] to [[2006]]. This pending the two kills in which sanced a [[Who Continential Crushering]]. In approximation of the between the [[Neil Translate Legar]], a bassad that the followership, from [[sung-falstelless]] finalists are parently music. The armorship to aspect communit
****************************************************************************************************
input_str: it is the longest continuously held long distance foot race in North America. The local newspaper al
output_str: l at the complaint in 1851, ched it has larger to head the sultiples to the increased against actually its the assembled storough a world. In [[1871]], a [[ry]] as a require to the [[Time]]] differentially respectively [[[Beat]]s. This part of [[Universolution]] and the [[Classic Chur]], the [[Song of Alliance Portuition]] (ALC). [[John Section]] Benaun|Khaughan Vickener, and whe [[Temple of Treaty|Antoniomet