流式加载文件到内存
由于内存有限,无法一次将全部文件从磁盘加载到内存,因此,我们需要分块加载读取,需要设置一个常数 CHUNK_SIZE 确定每次加载的数据大小。
但是我们无法保证每次切分恰好都是完整的单词,所以,我们利用文档边界 <|endoftext|> 来实现每次切分,防止把一个完整的文档从中间切开,也也就不会把一个单词切分到前后两个 chunk 中。
我们每次读取一个 CHUNK_SIZE 大小的chunk,但是只取<|endoftext|> 的部分,<|endoftext|> 保存到一个变量 leftover 中,下次拼接。 注意:我们使用流式读取文档内容,只切块,不做任何清洗,chunk 依然包含特殊 Token。
下面是我让 claude 根据代码用文档中的例子模拟分块读取的流程。
手动模拟(
在纸上或者思维中,模拟这个过程:
文件内容:ABCDE<|end|>FGH<|end|>
chunk_size = 8
leftover = ""
第1次:
- read(8) 读到:ABCDE<|e
- leftover + block = "" + "ABCDE<|e" = "ABCDE<|e"
- rfind("<|end|>") = -1
- leftover = "ABCDE<|e"
第2次:
- read(8) 读到:nd|>FGH
- leftover + block = "ABCDE<|e" + "nd|>FGH<" = "ABCDE<|end|>FGH<"
- rfind("<|end|>") = 5
- yield: "ABCDE<|end|>"
- leftover = "FGH<"
... 继续
# 生成器函数
def chunk_documents(file_path):
# ... 读取和切分
yield "Doc1<|endoftext|>" # 🔴 暂停在这里,把 Doc1 交给外部
# ... 继续
yield "Doc2<|endoftext|>" # 🔴 暂停在这里,把 Doc2 交给外部
# ...
# 外部使用(这是关键!)
word_counts = {} # 🎯 最终数据存这里
for chunk in chunk_documents(file_path): # for 循环是"接收者"
# chunk 就是 yield 的那个字符串
# 第1次循环:chunk = "Doc1<|endoftext|>"
# 第2次循环:chunk = "Doc2<|endoftext|>"
# 处理这个 chunk
words = re.findall(pattern, chunk)
for word in words:
word_counts[word] += 1 # 🎯 结果累积在这里
# chunk 用完了,被垃圾回收,释放内存
流程示意:
生成器 (yield) ─────> for 循环 (接收) ─────> 处理 ─────> word_counts (存储)
│ │ │
│ │ └─ 统计、计数
│ │
│ └─ 临时持有 chunk
│
└─ 不存储数据,只产生数据
内存中的情况:
时刻1:yield "Doc1<|endoftext|>"
内存:chunk = "Doc1<|endoftext|>" (~20 bytes)
word_counts = {"Doc1": 1}
时刻2:yield "Doc2<|endoftext|>"
内存:chunk = "Doc2<|endoftext|>" (~20 bytes)
word_counts = {"Doc1": 1, "Doc2": 1}
注意:"Doc1<|endoftext|>" 字符串已经被释放了!
对应的预分词也要逐个 chunk 进行:
大文件
↓
_chunk_documents 流式读取
↓
一个 chunk(保证 EOT 边界)
↓
按 special token 切开
↓
对每块普通文本做 GPT 正则预分词
↓
实时累计词频
↓
最终得到全局 word_counts
代码和简单测试
import os
import regex as re
from collections import defaultdict
from pathlib import Path
CHUNK_SIZE = 10
# the folser of the current file
BASE_DIR = Path(__file__).resolve().parent
# the root folder of the project(assignment1-basics)
PROJECT_ROOT = BASE_DIR.parent
def chunk_documents_streaming(
path: str,
special_token: str = "<|endoftext|>",
chunk_size: int = CHUNK_SIZE
):
"""
工具函数:流式分块读取
可以被任何代码调用,不绑定到类
"""
leftover = ""
# 打开文件,'r'是只读模式
with open(input_path, 'r', encoding='utf-8') as f:
while True:
# 1. 每次只读一小块 (比如 1MB)
block = f.read(CHUNK_SIZE)
# 2. 如果读不到了,说明文件读完了,退出循环
if not block:
break
# 3. 【拼接】把上次没处理完的尾巴,拼到这次的头上
block = leftover + block
leftover = "" # 清空尾巴缓存
# 4. 【找切分点】找最后一个特殊分隔符的位置
# rfind 返回最后一次出现的索引,没找到返回 -1
last_idx = block.rfind(special_token)
if last_idx == -1:
# 如果这块数据里一个分隔符都没有,说明这块数据可能只是一个超长句子的一部分
# 全部存入 leftover,等待下一次读取拼接
leftover = block
else:
# 5. 【切分与产出】
# yield 前半部分(包含分隔符),这是完整有效的数据
yield block[:last_idx + len(special_token)]
# 剩下的后半部分存入 leftover
leftover = block[last_idx + len(special_token):]
# 6. 处理最后剩下的一点点(通常是文件末尾)
if leftover:
yield leftover
pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
def test_chunk_streaming(path):
chunks = list(chunk_documents_streaming(path,special_token="<|endoftext|>"))
print(f"Got {len(chunks)} 个 chunks\n")
for i, chunk in enumerate(chunks):
print(f"--- Chunk {i} ---")
print(repr(chunk)) # repr display \n 和 special characters
print("Ends with <|endoftext|>? ", chunk.endswith("<|endoftext|>"))
print()
return chunks
def _pretokenize_streaming(input_path, special_tokens):
"""pretokenize and count"""
word_counts = defaultdict(int)
pattern_compiled = re.compile(pattern)
special_pattern = "|".join(re.escape(t) for t in special_tokens)
# call the chunk_documents_streaming function
for chunk in chunk_documents_streaming(input_path):
if special_tokens:
blocks = re.split(special_pattern, chunk)
else:
blocks =[chunk]
for block in blocks:
for match in re.finditer(pattern_compiled, block):
word_counts[match.group(0)] += 1
return dict(word_counts)
# Construct the path of test file
TEST_FILE = PROJECT_ROOT / "tests" / "fixtures" / "tinystories_sample.txt"
print("The path of the test file:", TEST_FILE)
if __name__ == "__main__":
test_chunk_streaming(TEST_FILE)
word_counts = _pretokenize_streaming(TEST_FILE, special_tokens=["<|endoftext|>"])
print(word_counts)
核心代码解释
with open(...): 上下文管理器。不管程序是否出错,它都会自动关闭文件,释放资源。f.read(CHUNK_SIZE): 每次只从硬盘把一小部分数据(例如 1MB)搬到内存。这是省内存的关键。yield: 这是一个生成器(Generator)。它不会一次性把所有数据都给你,而是像挤牙膏一样,你循环一次,它挤出来一块。leftover(残余/剩余): 这是最精髓的部分。
注意:
re.findall和re.finditer区别
findall: 一次性找到所有匹配,生成一个巨大的列表 ['low', ' low', ...]。如果文件块很大,这会瞬间吃光内存。
finditer: 迭代器。找到一个,给你一个 match 对象。内存极度友好。
- 两个函数中的
special_token和special_tokens
- 数据类型不同:
strvsList[str] _chunk_documents_streaming中 (str) 前面解释过为了避免将完整的句子或者token 从中间切断。_pretokenize中 (List[str]) 是因为在统计词频时需要知道所有的特殊 token (比如<|endoftext|>,<|padding|>,<|mask|>),- 使用正则
re.split把这些特殊 Token 从普通文本中剔除出去。因为 BPE 统计词频时,不能把特殊 Token 拆碎了统计
输出
he path of the test file: ./assignment1-basics/tests/fixtures/tinystories_sample.txt
Got 5 个 chunks
--- Chunk 0 ---
"\nOnce upon a time there was a little boy named Ben. Ben loved to explore the world around him. He saw many amazing things, like beautiful vases that were on display in a store. One day, Ben was walking through the store when he came across a very special vase. When Ben saw it he was amazed!\nHe said, "Wow, that is a really amazing vase! Can I buy it?"\nThe shopkeeper smiled and said, "Of course you can. You can take it home and show all your friends how amazing it is!"\nSo Ben took the vase home and he was so proud of it! He called his friends over and showed them the amazing vase. All his friends thought the vase was beautiful and couldn't believe how lucky Ben was.\nAnd that's how Ben found an amazing vase in the store!\n<|endoftext|>"
Ends with <|endoftext|>? True
--- Chunk 1 ---
.........................
{'\n': 24, 'O': 8, 'c': 5, ' up': 7, ' a': 52, ' ': 219, 'im': 4, 'h': 55, 'r': 55, ' was': 21, ' li': 4, 'l': 12, ' b': 11, 'y': 18, 'am': 4, ' B': 15, '.': 64, ' l': 10, 'v': 13, 'pl': 1, ' w': 11, 'rl': 1, ' ar': 2, 'u': 16, ' him': 3, ' H': 8, ' saw': 4, ' ma': 2, ' amazi': 5, 'g': 24, 'hi': 4, 'gs': 2, ',': 39, ' lik': 3, 'au': 2, 'i': 13, 'ul': 4, ' vas': 7, 's': 12, 'ha': 7, 'isplay': 1, ' i': 17, ' s': 17, ' O': 3, 'ay': 9, ' walki': 1, 'hr': 1, 'ugh': 3, ' wh': 1, ' h': 24, ' cam': 1, ' acr': 1, 'ss': 2, ' v': 6, 'ry': 8, ' sp': 2, 'cial': 2, ' Wh': 2, ' amaz': 2, '!': 6, 'H': 2, ' sai': 7, ' "': 2, 'W': 1, 'w': 18, ' is': 4, ' r': 5, 'ally': 1, ' Ca': 1, ' I': 5, ' buy': 1, '?"': 1, 'Th': 2, ' sh': 8, 'pk': 1, 'p': 3, ' smil': 3, ' c': 4, 'urs': 1, ' y': 3, ' ca': 7, ' Y': 2, 'ak': 3, 'm': 17, ' all': 4, 'ur': 2, 'ri': 15, '!"': 1, 'S': 1, 'k': 6, ' pr': 2, ' call': 3, ' his': 11, ' All': 1, "'": 1, 'li': 3, ' lucky': 1, 'A': 2, "'s": 3, 'liabl': 1, ' Olli': 7, ' liv': 2, ' riv': 1, ' wi': 12, 'amily': 5, ' Th': 14, ' play': 10, ' swim': 1, ' m': 9, ' "': 8, 'Olli': 1, ' hurry': 2, ' g': 4, 'ish': 6, '!"': 4, ' swam': 1, 'as': 1, 'ch': 2, 'uck': 3, 'Hi': 2, 'I': 2, ' my': 1, '."': 2, 'Whil': 1, 'chi': 1, ' big': 7, ' shi': 4, 'This': 1, ' bu': 2, ' happy': 2, 'rg': 1, ' ab': 2, ' Tim': 5, ' park': 1, 'ig': 5, 'a': 5, 'asy': 2, ' ha': 4, ' u': 1, ' happ': 1, 'ar': 1, ' shak': 1, ' scar': 1, ' k': 3, ' wha': 1, ' Bu': 2, 'ic': 3, ' surpris': 1, 'Tim': 1, ' A': 2, 'ly': 4, 'b': 8, ' pick': 5, 'rs': 3, ' bir': 1, 'ci': 2, 'si': 1, ' su': 1, 'ir': 1, ' gr': 2, 'humb': 5, ' curi': 1, 'usly': 1, ':': 1, 'his': 1, ' Wha': 1, '?"': 2, 'His': 1, ' S': 1, ' car': 1, 'ully': 1, ' carri': 1, ' back': 2, ' arriv': 1, ' happily': 1, ' His': 1, ' hugg': 1, ' appr': 1, 'cia': 1, 'Fr': 1, ' always': 1, ' as': 1, 'mi': 1, 'ywh': 1, ' small': 1, 'us': 1, ' girl': 1, ' Lucy': 11, 'ra': 6, ' Sh': 2, ' ball': 3, ' This': 1, ' spiri': 8, ' playi': 1, ' havi': 1, ' much': 1, ' hav': 2, ' Will': 1, ' la': 1, 'Spiri': 1, 'al': 2, ' imagi': 1, ' sa': 1, ' ig': 1}
符合预期,可以将测试过的从分块加载文件和预分词功能加入到BPETrainer 类中了。
完整的代码在 cs336_basics/train_bpe.py ,下面就是修改adapters.run_train_bpe如下:
-
adapters.py中导入我们的实现from cs336_basics import train_bpe as bpe
-
在
run_train_bpe函数中添加如下bpe_trainer = bpe.BPETrainer_Simple() return bpe_trainer.train(input_path, vocab_size, special_tokens) -
执行 按照文档中执行
uv run pytest tests/test_train_bpe.py
内存不够会OOM
我一开始训练开始一段时间后得到如下:========================================================================================== test session starts ==========================================================================================
platform linux -- Python 3.11.12, pytest-8.4.1, pluggy-1.6.0
rootdir: /home/fdq/cources/cs336/assignment1-basics
configfile: pyproject.toml
plugins: jaxtyping-0.3.2
collected 3 itemstests/test_train_bpe.py::test_train_bpe_speed Starting BPE training...
没有得到任何结果,这与测试脚本 test_train_bpe.py的逻辑是不符的,我又检查测试自己的代码,确认符合预期,又执行了一遍,还是如此,我突然想到可能是OOM了,再次执行时,我持续观察内存,果然。。。 我关闭了所有能关闭的服务,
最终发现是我在新加一个逻辑判断时,加错了位置,导致了OOM。修正后测试结果如下:
platform linux -- Python 3.11.12, pytest-8.4.1, pluggy-1.6.0
rootdir: /home/fdq/cources/cs336/assignment1-basics
configfile: pyproject.toml
plugins: jaxtyping-0.3.2
collected 3 items
tests/test_train_bpe.py::test_train_bpe_speed Starting BPE training...
FAILED
tests/test_train_bpe.py::test_train_bpe Starting BPE training...
PASSED
tests/test_train_bpe.py::test_train_bpe_special_tokens Starting BPE training...
PASSED
失败原因:
tests/test_train_bpe.py:24: AssertionError
============================================================================== short test summary info ===============================================================================
FAILED tests/test_train_bpe.py::test_train_bpe_speed - assert (1770088617.786312 - 1770088615.684015) < 1.5
============================================================================ 1 failed, 2 passed in 15.37s ============================================================================
现在说明我们的代码没有问题了,只是训练花费时间太长,下一步就是找出慢的点并逐步优化。