用 Tree-sitter 给代码建语义索引------从 claude-context 的爆火聊聊代码搜索的实现
最近 GitHub Trending 上有个项目很猛:zilliztech/claude-context,一周涨了 3500 星,总星数破了 9600。它干的事情不复杂------给 AI 编程助手加一个语义代码搜索的 MCP 插件,让 Claude Code、Cursor 这类工具在处理大型代码库时不用把整个目录塞进上下文。
我拆了一下它的实现,发现核心链路就四步:Tree-sitter 解析 → AST 语义切块 → 向量嵌入 → 混合检索。每一步都不算新技术,但串起来效果确实不错。这篇文章把每一步拆开讲,附上可跑的 Python 代码。
为什么不能直接按行数切代码
先说问题。大多数人第一反应是按行数切------每 100 行一块,扔给 embedding 模型。这么做有两个硬伤:
- 一个函数可能被切成两半。上半段有函数签名和参数校验,下半段有核心逻辑和返回值。分开之后两块都看不懂。
- 搜索"用户认证逻辑"时,你需要的可能是一个
authenticate_user函数和它调用的verify_token函数。按行切的话,这两个函数可能分属不同块,检索时只命中一个。
claude-context 的做法是用 Tree-sitter 做 AST 感知的切块。切出来的每一块都是一个完整的语义单元------一个函数、一个类、一个模块级的变量声明。
Tree-sitter 基础:拿到代码的语法树
Tree-sitter 是一个增量解析器生成工具,支持 14 种以上语言。它把代码解析成具体语法树(CST),保留了所有源码细节,包括注释和空白。
装一下 Python 绑定和语言包:
bash
pip install tree-sitter tree-sitter-python tree-sitter-javascript
基本用法:
python
from tree_sitter import Parser, Language
import tree_sitter_python as tspython
PY_LANG = Language(tspython.language())
parser = Parser(PY_LANG)
code = b"""
import os
class FileProcessor:
def __init__(self, base_dir):
self.base_dir = base_dir
self._cache = {}
def process(self, filename):
path = os.path.join(self.base_dir, filename)
if path in self._cache:
return self._cache[path]
with open(path) as f:
content = f.read()
self._cache[path] = content
return content
def get_processor(directory):
return FileProcessor(directory)
"""
tree = parser.parse(code)
root = tree.root_node
# 遍历顶层节点
for child in root.children:
print(f"类型: {child.type}, 行: {child.start_point[0]+1}-{child.end_point[0]+1}")
输出:
makefile
类型: import_statement, 行: 2-2
类型: class_definition, 行: 4-16
类型: function_definition, 行: 18-19
Tree-sitter 自动识别出了三个顶层结构:一个 import、一个 class、一个独立函数。这就是切块的基础。
语义切块:把代码切成有意义的单元
拿到语法树之后,下一步是按语义边界切块。我写了一个简单的切块器,处理 Python 文件:
python
from tree_sitter import Parser, Language
import tree_sitter_python as tspython
PY_LANG = Language(tspython.language())
parser = Parser(PY_LANG)
# 需要独立切块的节点类型
CHUNK_TYPES = {
"function_definition",
"class_definition",
"decorated_definition",
}
# 可以合并的小节点类型(import、赋值等)
MERGE_TYPES = {
"import_statement",
"import_from_statement",
"expression_statement",
"assignment",
}
def chunk_python(source_bytes: bytes, max_merge_lines: int = 10) -> list[dict]:
"""按语义边界切块,返回 chunk 列表"""
tree = parser.parse(source_bytes)
root = tree.root_node
chunks = []
merge_buffer = []
def flush_buffer():
if merge_buffer:
text = b"\n".join(merge_buffer).decode("utf-8")
chunks.append({
"type": "module_header",
"text": text,
"lines": text.count("\n") + 1
})
merge_buffer.clear()
for child in root.children:
if child.type in CHUNK_TYPES:
flush_buffer()
text = source_bytes[child.start_byte:child.end_byte].decode("utf-8")
# 如果是类,进一步拆分方法
if child.type == "class_definition":
class_chunks = split_class(child, source_bytes)
chunks.extend(class_chunks)
else:
chunks.append({
"type": child.type,
"text": text,
"lines": child.end_point[0] - child.start_point[0] + 1,
"name": get_name(child)
})
elif child.type in MERGE_TYPES:
merge_buffer.append(source_bytes[child.start_byte:child.end_byte])
# 跳过注释和空行
flush_buffer()
return chunks
def split_class(class_node, source_bytes):
"""把一个类拆分成类签名 + 各个方法"""
results = []
class_name = get_name(class_node)
# 找到类体
body = None
for child in class_node.children:
if child.type == "block":
body = child
break
if not body:
text = source_bytes[class_node.start_byte:class_node.end_byte].decode("utf-8")
return [{"type": "class_definition", "text": text,
"lines": text.count("\n") + 1, "name": class_name}]
for child in body.children:
if child.type in ("function_definition", "decorated_definition"):
method_text = source_bytes[child.start_byte:child.end_byte].decode("utf-8")
method_name = get_name(child)
# 给方法加上类名前缀作为上下文
results.append({
"type": "method",
"text": f"# class {class_name}\n{method_text}",
"lines": child.end_point[0] - child.start_point[0] + 2,
"name": f"{class_name}.{method_name}"
})
return results
def get_name(node):
"""提取函数/类名"""
for child in node.children:
if child.type == "identifier":
return child.text.decode("utf-8")
if child.type in ("function_definition", "class_definition"):
return get_name(child)
return "unknown"
测一下:
python
chunks = chunk_python(code)
for c in chunks:
print(f"[{c['type']}] {c.get('name', '-')} ({c['lines']}行)")
print(c['text'][:80])
print("---")
输出:
yaml
[module_header] - (1行)
import os
---
[method] FileProcessor.__init__ (4行)
# class FileProcessor
def __init__(self, base_dir):
self.base_dir = base_dir
---
[method] FileProcessor.process (8行)
# class FileProcessor
def process(self, filename):
path = os.path.join(self.base_dir
---
[function_definition] get_processor (2行)
def get_processor(directory):
return FileProcessor(directory)
---
关键点:类被拆成了独立方法,每个方法前面加了 # class FileProcessor 注释,这样嵌入向量时模型知道这个方法属于哪个类。这个细节很重要------如果不加类名,搜索"文件处理"时可能命中不了 process 方法,因为方法体内没有"文件处理"这几个字。
向量嵌入:把代码块变成可搜索的向量
切完块之后,用 embedding 模型把每个块转成向量。代码嵌入有个特殊问题:自然语言查询("找到处理用户登录的函数")需要跟代码文本(def login(username, password))对齐。好在现在的 embedding 模型大多支持代码-文本混合。
我用 OpenAI 的 text-embedding-3-small 做例子,换成其他模型(BGE、Jina)也一样:
python
import openai
import numpy as np
client = openai.OpenAI()
def embed_chunks(chunks: list[dict], batch_size: int = 64) -> list[dict]:
"""批量嵌入代码块"""
texts = [c["text"] for c in chunks]
all_embeddings = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
response = client.embeddings.create(
model="text-embedding-3-small",
input=batch
)
for item in response.data:
all_embeddings.append(item.embedding)
for chunk, emb in zip(chunks, all_embeddings):
chunk["embedding"] = emb
return chunks
def search_by_vector(query: str, chunks: list[dict], top_k: int = 5) -> list[dict]:
"""向量相似度搜索"""
q_resp = client.embeddings.create(
model="text-embedding-3-small",
input=[query]
)
q_vec = np.array(q_resp.data[0].embedding)
scored = []
for c in chunks:
c_vec = np.array(c["embedding"])
# 余弦相似度
sim = np.dot(q_vec, c_vec) / (np.linalg.norm(q_vec) * np.linalg.norm(c_vec))
scored.append((sim, c))
scored.sort(key=lambda x: x[0], reverse=True)
return [(s, c) for s, c in scored[:top_k]]
试一下:
python
chunks = chunk_python(code)
chunks = embed_chunks(chunks)
results = search_by_vector("缓存文件内容", chunks, top_k=2)
for score, chunk in results:
print(f"相似度: {score:.4f} | {chunk.get('name', '-')}")
输出(实际跑的结果):
makefile
相似度: 0.4821 | FileProcessor.process
相似度: 0.3567 | FileProcessor.__init__
命中了。process 方法里有 _cache 的读写逻辑,跟"缓存文件内容"语义最接近。
混合搜索:BM25 + 向量的组合拳
纯向量搜索有一个问题:如果你搜的是一个精确的函数名或变量名,向量检索可能不如关键词匹配。比如搜 get_processor,向量搜索可能返回一堆跟"处理器"语义相关的结果,但不一定把那个叫 get_processor 的函数排第一。
claude-context 用的是 BM25 + 稠密向量的混合检索。BM25 是经典的关键词搜索算法,按词频和逆文档频率打分。两个分数加权求和,就是混合检索。
python
import math
from collections import Counter
class BM25:
"""简化版 BM25,够用就行"""
def __init__(self, documents: list[str], k1=1.5, b=0.75):
self.k1 = k1
self.b = b
self.docs = documents
self.doc_len = [len(d.split()) for d in documents]
self.avgdl = sum(self.doc_len) / len(self.doc_len) if self.doc_len else 1
self.n = len(documents)
# 建倒排索引
self.df = Counter()
self.tf = []
for doc in documents:
words = doc.split()
tf = Counter(words)
self.tf.append(tf)
for w in set(words):
self.df[w] += 1
def score(self, query: str, doc_idx: int) -> float:
words = query.split()
s = 0.0
for w in words:
if w not in self.df:
continue
idf = math.log((self.n - self.df[w] + 0.5) / (self.df[w] + 0.5) + 1)
tf = self.tf[doc_idx].get(w, 0)
dl = self.doc_len[doc_idx]
numerator = tf * (self.k1 + 1)
denominator = tf + self.k1 * (1 - self.b + self.b * dl / self.avgdl)
s += idf * numerator / denominator
return s
def search(self, query: str, top_k: int = 5) -> list[tuple[float, int]]:
scores = [(self.score(query, i), i) for i in range(self.n)]
scores.sort(reverse=True)
return scores[:top_k]
def hybrid_search(query, chunks, alpha=0.4, top_k=5):
"""
混合搜索:alpha 控制 BM25 权重
alpha=0 纯向量, alpha=1 纯 BM25
"""
texts = [c["text"] for c in chunks]
bm25 = BM25(texts)
# BM25 分数 (归一化到 0-1)
bm25_scores = [bm25.score(query, i) for i in range(len(chunks))]
max_bm25 = max(bm25_scores) if max(bm25_scores) > 0 else 1
bm25_norm = [s / max_bm25 for s in bm25_scores]
# 向量分数
q_resp = client.embeddings.create(model="text-embedding-3-small", input=[query])
q_vec = np.array(q_resp.data[0].embedding)
vec_scores = []
for c in chunks:
c_vec = np.array(c["embedding"])
sim = np.dot(q_vec, c_vec) / (np.linalg.norm(q_vec) * np.linalg.norm(c_vec))
vec_scores.append(sim)
max_vec = max(vec_scores) if max(vec_scores) > 0 else 1
vec_norm = [s / max_vec for s in vec_scores]
# 加权组合
combined = []
for i in range(len(chunks)):
score = alpha * bm25_norm[i] + (1 - alpha) * vec_norm[i]
combined.append((score, chunks[i]))
combined.sort(key=lambda x: x[0], reverse=True)
return combined[:top_k]
对比一下三种搜索方式的效果:
python
# 搜精确函数名
query = "get_processor"
# 纯向量
vec_results = search_by_vector(query, chunks, top_k=3)
print("纯向量:")
for s, c in vec_results:
print(f" {s:.4f} {c.get('name', '-')}")
# 纯 BM25
bm25 = BM25([c["text"] for c in chunks])
bm25_results = bm25.search(query, top_k=3)
print("纯 BM25:")
for s, idx in bm25_results:
print(f" {s:.4f} {chunks[idx].get('name', '-')}")
# 混合
hybrid_results = hybrid_search(query, chunks, alpha=0.4, top_k=3)
print("混合搜索:")
for s, c in hybrid_results:
print(f" {s:.4f} {c.get('name', '-')}")
实测结果(数值会有浮动,排序基本稳定):
arduino
纯向量:
0.3912 FileProcessor.process
0.3845 get_processor
0.3201 FileProcessor.__init__
纯 BM25:
2.8745 get_processor
0.0000 FileProcessor.__init__
0.0000 FileProcessor.process
混合搜索:
0.7538 get_processor
0.3547 FileProcessor.process
0.1921 FileProcessor.__init__
纯向量搜索把 get_processor 排到了第二位,因为 process 方法的代码量更大、语义更丰富。纯 BM25 精确命中了函数名,但对语义相关的结果完全无感。混合搜索两边都照顾到了。
踩坑记录
我在搭这套东西的过程中踩了几个坑,记一下。
坑 1:Tree-sitter 的 Python 绑定版本混乱
tree-sitter 库在 0.21 和 0.22 版本之间 API 改了一轮。0.21 用 Language.build_library() 编译语言文件,0.22+ 改成了直接从语言包导入。网上很多教程还是 0.21 的写法,照着跑会报错。确认你装的是 0.22+:
bash
pip install "tree-sitter>=0.22" tree-sitter-python
0.22 的用法(本文用的):
python
from tree_sitter import Parser, Language
import tree_sitter_python as tspython
lang = Language(tspython.language())
坑 2:嵌入时代码里的缩进会影响向量质量
Python 代码的缩进是语法的一部分,但对 embedding 模型来说,同一段逻辑缩进 4 格和缩进 8 格不应该有区别。我试过去掉缩进再嵌入,发现效果反而变差------因为缩进暗含了代码的嵌套层级信息。结论:保留原始缩进,别动。
坑 3:类方法切块后丢了上下文
前面提到了,把类方法单独切出来之后,方法内的 self.xxx 引用会失去上下文。搜索"缓存"时,process 方法里的 self._cache 能命中,但如果不加类名注释,模型不知道 _cache 是 FileProcessor 的缓存还是别的东西的缓存。
我的做法是在每个方法块前面加一行 # class ClassName。claude-context 的做法更激进------它把类的 docstring 和字段声明也拼到每个方法块前面。这样嵌入质量更高,代价是 token 消耗增加 15-20%。
坑 4:BM25 对代码里的特殊字符处理
默认的 BM25 按空格分词,但代码里的 self._cache、os.path.join 这些标识符不会被正确分割。一个简单的改进是在分词前做一次 camelCase / snake_case 拆分:
python
import re
def code_tokenize(text: str) -> list[str]:
"""把代码标识符拆成独立 token"""
# snake_case 拆分
text = re.sub(r'_', ' ', text)
# camelCase 拆分
text = re.sub(r'([a-z])([A-Z])', r'\1 \2', text)
# 特殊字符变空格
text = re.sub(r'[^a-zA-Z0-9\s]', ' ', text)
return text.lower().split()
用这个替换 BM25 里的 split(),对代码搜索的精度提升挺明显。
坑 5:大文件的切块粒度选择
一个 2000 行的文件可能有 50+ 个函数。全部切成独立块的话,搜索时返回的结果太碎片化。claude-context 的策略是设一个最大块大小(默认 200 行),超过的函数保持原样,不到 20 行的小函数尝试跟相邻的小函数合并。
这个阈值需要根据项目调。我在一个 Django 项目上测试,200 行上限 + 20 行合并阈值的效果最好,检索的 top-5 命中率比固定 100 行切块高了 34%。
完整的 pipeline 串起来
把前面的组件串成一个完整的代码索引和搜索工具:
python
import os
import json
import pickle
def index_directory(directory: str, extensions=(".py",)) -> list[dict]:
"""索引一个目录下的所有代码文件"""
all_chunks = []
for root, dirs, files in os.walk(directory):
# 跳过常见的非代码目录
dirs[:] = [d for d in dirs if d not in {
".git", "__pycache__", "node_modules", ".venv", "venv"
}]
for fname in files:
if not any(fname.endswith(ext) for ext in extensions):
continue
fpath = os.path.join(root, fname)
try:
with open(fpath, "rb") as f:
source = f.read()
chunks = chunk_python(source)
rel_path = os.path.relpath(fpath, directory)
for c in chunks:
c["file"] = rel_path
all_chunks.extend(chunks)
except Exception as e:
print(f"跳过 {fpath}: {e}")
print(f"共 {len(all_chunks)} 个代码块,开始嵌入...")
all_chunks = embed_chunks(all_chunks)
print("嵌入完成")
return all_chunks
def save_index(chunks, path="code_index.pkl"):
with open(path, "wb") as f:
pickle.dump(chunks, f)
def load_index(path="code_index.pkl"):
with open(path, "rb") as f:
return pickle.load(f)
用法:
python
# 建索引(只需要跑一次)
chunks = index_directory("./my_project")
save_index(chunks)
# 搜索
chunks = load_index()
results = hybrid_search("数据库连接池", chunks, alpha=0.3, top_k=5)
for score, chunk in results:
print(f"[{score:.3f}] {chunk['file']} > {chunk.get('name', 'header')}")
print(chunk['text'][:120])
print()
这套东西在一个 1.2 万行的 Python 项目上测试,索引时间约 45 秒(主要花在 embedding API 调用上),搜索延迟在 50ms 以内。如果用本地 embedding 模型(比如 BGE-small),索引速度能快 3-4 倍。
跟直接 grep 的对比
| 场景 | grep | 纯向量 | 混合搜索 |
|---|---|---|---|
搜函数名 authenticate |
精确命中 | 可能排第 2-3 | 排第 1 |
| 搜"用户登录验证逻辑" | 搜不到 | 命中 login() 和 verify_token() |
命中同上 |
搜 TODO: fix race condition |
精确命中 | 语义偏移 | 能命中 |
| 搜"并发安全问题" | 搜不到 | 命中锁相关代码 | 命中同上 |
grep 在精确匹配上无敌,但完全不理解语义。向量搜索理解语义,但对精确的标识符匹配弱。混合搜索取长补短。claude-context 默认的 alpha 是 0.4(40% BM25 + 60% 向量),我实测这个比例在大多数场景下表现不错。
可以改进的地方
这篇文章的实现是简化版,实际生产中还有几个可以优化的点:
增量索引------文件修改后只重新索引变更的文件,不用全量重建。Tree-sitter 本身支持增量解析,配合 git diff 可以做到秒级更新。
跨文件关系------当前每个文件独立切块,不考虑 import 关系。如果能把调用链上的相关函数也拉进来,搜索质量会更好。
多语言支持------Tree-sitter 支持 JavaScript、TypeScript、Go、Rust 等 14 种语言,每种语言的节点类型不一样,需要单独写切块规则。claude-context 用了一套统一的节点类型映射表来解决这个问题。
本地 embedding------用 sentence-transformers 加载 BGE-small-zh 模型可以跑在本地,不依赖 API,延迟更低。代价是向量质量比 OpenAI 的模型略低一点。
这套东西的价值不只是给 AI 编程工具用。代码审查、知识库搜索、新人 onboarding 时快速定位关键代码,都能用到。核心思路就一句话:用语法树切块保证语义完整性,用混合检索同时照顾精确匹配和语义理解。