目录
-
- 完整代码:
- 代码运行逻辑解释
-
- [1. **导入依赖库**](#1. 导入依赖库)
- [2. **定义全局变量**](#2. 定义全局变量)
- [3. **定义编码函数**](#3. 定义编码函数)
- [4. **定义解码函数**](#4. 定义解码函数)
- [5. **定义分块函数**](#5. 定义分块函数)
- [6. **定义 MD5 哈希函数**](#6. 定义 MD5 哈希函数)
- [7. **读取文件并分块**](#7. 读取文件并分块)
- [8. **输出分块数量**](#8. 输出分块数量)
- 总结
完整代码:
python
复制代码
import tiktoken
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
global ENCODER
if ENCODER is None:
ENCODER = tiktoken.encoding_for_model(model_name)
tokens = ENCODER.encode(content)
return tokens
python
复制代码
def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
global ENCODER
if ENCODER is None:
ENCODER = tiktoken.encoding_for_model(model_name)
content = ENCODER.decode(tokens)
return content
python
复制代码
def chunking_by_token_size(
content: str,
split_by_character=None,
split_by_character_only=False,
overlap_token_size=128,
max_token_size=1024,
tiktoken_model="gpt-4o",
**kwargs,
):
tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
results = []
if split_by_character:
raw_chunks = content.split(split_by_character)
new_chunks = []
if split_by_character_only:
for chunk in raw_chunks:
_tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
new_chunks.append((len(_tokens), chunk))
else:
for chunk in raw_chunks:
_tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
if len(_tokens) > max_token_size:
for start in range(
0, len(_tokens), max_token_size - overlap_token_size
):
chunk_content = decode_tokens_by_tiktoken(
_tokens[start : start + max_token_size],
model_name=tiktoken_model,
)
new_chunks.append(
(min(max_token_size, len(_tokens) - start), chunk_content)
)
else:
new_chunks.append((len(_tokens), chunk))
for index, (_len, chunk) in enumerate(new_chunks):
results.append(
{
"tokens": _len,
"content": chunk.strip(),
"chunk_order_index": index,
}
)
else:
for index, start in enumerate(
range(0, len(tokens), max_token_size - overlap_token_size)
):
chunk_content = decode_tokens_by_tiktoken(
tokens[start : start + max_token_size], model_name=tiktoken_model
)
results.append(
{
"tokens": min(max_token_size, len(tokens) - start),
"content": chunk_content.strip(),
"chunk_order_index": index,
}
)
return results
python
复制代码
from dataclasses import field
split_by_character=None
split_by_character_only=False
chunk_overlap_token_size: int = 100
chunk_token_size: int = 1200
tiktoken_model_name: str = "gpt-4o-mini"
ENCODER = None
chunking_func = chunking_by_token_size
from hashlib import md5
def compute_mdhash_id(content, prefix: str = ""):
return prefix + md5(content.encode()).hexdigest()
with open("./book.txt", "r", encoding="utf-8") as f:
content = f.read()
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": 1,
}
for dp in chunking_func(
content,
split_by_character=split_by_character,
split_by_character_only=split_by_character_only,
overlap_token_size=chunk_overlap_token_size,
max_token_size=chunk_token_size,
tiktoken_model=tiktoken_model_name
)
}
python
复制代码
len(chunks)
42
代码运行逻辑解释
1. 导入依赖库
python
复制代码
import tiktoken
from dataclasses import field
from hashlib import md5
tiktoken
:用于将文本编码为模型所需的 token。
dataclasses
:用于定义数据类(未在代码中实际使用)。
hashlib
:用于生成 MD5 哈希值,为每个 chunk 生成唯一 ID。
2. 定义全局变量
python
复制代码
split_by_character = None
split_by_character_only = False
chunk_overlap_token_size: int = 100
chunk_token_size: int = 1200
tiktoken_model_name: str = "gpt-4o-mini"
ENCODER = None
chunking_func = chunking_by_token_size
split_by_character
:指定按某个字符分割文本(默认为 None
,表示不按字符分割)。
split_by_character_only
:是否仅按字符分割,而不进一步处理 token 大小。
chunk_overlap_token_size
:chunk 之间的重叠 token 数量。
chunk_token_size
:每个 chunk 的最大 token 数量。
tiktoken_model_name
:使用的 tiktoken 模型名称。
ENCODER
:全局变量,用于存储 tiktoken 编码器实例。
chunking_func
:指向 chunking_by_token_size
函数,用于分块处理文本。
3. 定义编码函数
python
复制代码
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
global ENCODER
if ENCODER is None:
ENCODER = tiktoken.encoding_for_model(model_name)
tokens = ENCODER.encode(content)
return tokens
- 功能:将输入的字符串编码为 token。
- 逻辑:
- 检查全局变量
ENCODER
是否为 None
,如果是,则初始化指定模型的编码器。
- 使用编码器将字符串编码为 token 列表并返回。
4. 定义解码函数
python
复制代码
def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
global ENCODER
if ENCODER is None:
ENCODER = tiktoken.encoding_for_model(model_name)
content = ENCODER.decode(tokens)
return content
- 功能:将 token 列表解码为字符串。
- 逻辑:
- 检查全局变量
ENCODER
是否为 None
,如果是,则初始化指定模型的编码器。
- 使用编码器将 token 列表解码为字符串并返回。
5. 定义分块函数
python
复制代码
def chunking_by_token_size(
content: str,
split_by_character=None,
split_by_character_only=False,
overlap_token_size=128,
max_token_size=1024,
tiktoken_model="gpt-4o",
**kwargs,
):
tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
results = []
if split_by_character:
raw_chunks = content.split(split_by_character)
new_chunks = []
if split_by_character_only:
for chunk in raw_chunks:
_tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
new_chunks.append((len(_tokens), chunk))
else:
for chunk in raw_chunks:
_tokens = encode_string_by_tiktoken(chunk, model_name=tiktoken_model)
if len(_tokens) > max_token_size:
for start in range(
0, len(_tokens), max_token_size - overlap_token_size
):
chunk_content = decode_tokens_by_tiktoken(
_tokens[start : start + max_token_size],
model_name=tiktoken_model,
)
new_chunks.append(
(min(max_token_size, len(_tokens) - start), chunk_content)
)
else:
new_chunks.append((len(_tokens), chunk))
for index, (_len, chunk) in enumerate(new_chunks):
results.append(
{
"tokens": _len,
"content": chunk.strip(),
"chunk_order_index": index,
}
)
else:
for index, start in enumerate(
range(0, len(tokens), max_token_size - overlap_token_size)
):
chunk_content = decode_tokens_by_tiktoken(
tokens[start : start + max_token_size], model_name=tiktoken_model
)
results.append(
{
"tokens": min(max_token_size, len(tokens) - start),
"content": chunk_content.strip(),
"chunk_order_index": index,
}
)
return results
- 功能:将输入的文本按 token 大小分块。
- 逻辑:
- 将文本编码为 token 列表。
- 如果指定了
split_by_character
:
- 按指定字符分割文本。
- 如果
split_by_character_only
为 True
,则直接记录每个 chunk 的 token 数量和内容。
- 否则,检查每个 chunk 的 token 数量是否超过
max_token_size
,如果超过则进一步分块。
- 如果未指定
split_by_character
,则直接按 max_token_size
和 overlap_token_size
分块。
- 返回分块结果,每个 chunk 包含 token 数量、内容和顺序索引。
6. 定义 MD5 哈希函数
python
复制代码
def compute_mdhash_id(content, prefix: str = ""):
return prefix + md5(content.encode()).hexdigest()
- 功能:为每个 chunk 生成唯一的 MD5 哈希 ID。
- 逻辑:
- 将内容编码为字节。
- 计算 MD5 哈希值并转换为十六进制字符串。
- 添加前缀(如
"chunk-"
)并返回。
7. 读取文件并分块
python
复制代码
with open("./book.txt", "r", encoding="utf-8") as f:
content = f.read()
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": 1,
}
for dp in chunking_func(
content,
split_by_character=split_by_character,
split_by_character_only=split_by_character_only,
overlap_token_size=chunk_overlap_token_size,
max_token_size=chunk_token_size,
tiktoken_model=tiktoken_model_name
)
}
- 功能:读取文件内容并分块,生成分块字典。
- 逻辑:
- 打开文件
book.txt
并读取内容。
- 调用
chunking_func
对内容进行分块。
- 为每个 chunk 生成唯一的 MD5 哈希 ID,并将分块信息存储到字典中。
- 每个 chunk 包含以下字段:
tokens
:token 数量。
content
:分块内容。
chunk_order_index
:分块顺序索引。
full_doc_id
:文档 ID(此处固定为 1
)。
8. 输出分块数量
python
复制代码
len(chunks)
- 功能:返回分块字典的长度,即分块数量。
- 结果:
42
,表示生成了 42 个 chunk。
总结
- 代码的主要功能是将文本文件按 token 大小分块,并为每个 chunk 生成唯一 ID。
- 支持按字符分割和按 token 大小分块两种模式。
- 最终生成的分块字典可用于进一步处理或存储。