十、训练入口
train.py
import os, json, joblib
import numpy as np
from pathlib import Path
from collections import Counter
from parse_log import parse_log_file
from features import GameFeatureExtractor
from predictors import MLPredictor
# ════════════════════════════════════════════════════
# TrainingSampleBuilder
# ════════════════════════════════════════════════════
class TrainingSampleBuilder:
    """Build labelled training samples from per-category log folders.

    Every key of ``CATEGORY_INTENT`` names a sub-folder of ``train_root``.
    Each ``*.txt`` file found in such a folder is parsed, reduced to a flat
    feature dict by the injected extractor, and labelled with the intent and
    action list associated with the folder name.
    """

    # Folder name -> intent label given to every log inside that folder.
    # NOTE(review): 'BeingResuce' looks like a misspelling of 'BeingRescue';
    # it is kept verbatim because it is used as a directory name — confirm
    # against the dataset layout before renaming.
    CATEGORY_INTENT = {
        'Action': '交战',
        'Fire': '交战',
        'SkillStart': '技能交战',
        'Grenade': '投弹交战',
        'BeingResuce': '被救援/等待救援',
        'Looting': '搜寻物资',
    }

    # Folder name -> action labels attached to every sample of that folder.
    CATEGORY_ACTIONS = {
        'Action': ['靠近敌人', '开镜瞄准', '开火射击'],
        'Fire': ['开火射击', '开镜瞄准', '换弹'],
        'SkillStart': ['使用技能', '开镜瞄准', '开火射击'],
        'Grenade': ['投掷手雷', '靠近敌人', '开火射击'],
        'BeingResuce': ['原地不动/潜伏', '被救援中'],
        'Looting': ['移动至物资点', '搜寻物资', '拾取物品'],
    }

    def __init__(self, train_root: str, extractor: "GameFeatureExtractor"):
        """Remember where the training data lives and how to featurise it.

        Args:
            train_root: Directory containing one sub-folder per category.
            extractor: Object providing ``extract_player_features``,
                ``extract_spatial_features`` and ``extract_action_features``.
        """
        self.train_root = Path(train_root)
        self.extractor = extractor

    def build_all(self) -> list:
        """Build samples for every known category folder.

        Missing folders are reported and skipped. A file whose processing
        raises is reported and skipped so one bad log does not abort the run.

        Returns:
            List of sample dicts as produced by ``_build_one``.
        """
        samples = []
        for cat in self.CATEGORY_INTENT:
            folder = self.train_root / cat
            if not folder.exists():
                print(f" ⚠ 文件夹不存在: {folder}")
                continue
            # Sorted so the sample order does not depend on directory order.
            files = sorted(folder.glob('*.txt'))
            print(f" {cat:15s} → {len(files):4d} 个文件")
            for f in files:
                try:
                    sample = self._build_one(f, cat)
                except Exception as e:
                    # Best effort: report the failing file and keep going.
                    print(f" 处理失败 {f.name}: {e}")
                    continue
                if sample:
                    samples.append(sample)
        return samples

    def _build_one(self, filepath: Path, category: str) -> "dict | None":
        """Turn one log file into a labelled sample.

        Args:
            filepath: Path of the log file to parse.
            category: Folder name the file came from; must be a key of
                ``CATEGORY_INTENT`` and ``CATEGORY_ACTIONS``.

        Returns:
            A dict with keys ``file``, ``category``, ``main_player``,
            ``features`` and ``label``, or ``None`` when the log has no
            player states or no way to pick a main player.
        """
        parsed = parse_log_file(str(filepath))
        if not parsed['player_states']:
            return None

        player_team_map = {
            i['player_id']: i['team_id']
            for i in parsed['game_info']
        }

        # Pick the main player: the id seen most often in the actions,
        # otherwise the first entry of game_info.
        if parsed['actions']:
            counts = Counter(a.player_id for a in parsed['actions'])
            main_id = counts.most_common(1)[0][0]
        elif parsed['game_info']:
            main_id = parsed['game_info'][0]['player_id']
        else:
            return None

        # Merge the three feature groups into one flat dict.
        features = {}
        features.update(self.extractor.extract_player_features(
            parsed['player_states'], main_id, {}
        ))
        features.update(self.extractor.extract_spatial_features(
            parsed['player_states'], main_id, player_team_map
        ))
        features.update(self.extractor.extract_action_features(
            parsed['actions'], main_id
        ))

        return {
            'file': str(filepath),
            'category': category,
            'main_player': main_id,
            'features': features,
            'label': {
                'intent': self.CATEGORY_INTENT[category],
                # Copied so samples never share one mutable class-level list.
                'actions': list(self.CATEGORY_ACTIONS[category]),
            },
        }
# ════════════════════════════════════════════════════
# Main training routine
# ════════════════════════════════════════════════════
def main():
    """Run the three-step training pipeline.

    Steps: create the feature extractor, build labelled samples from the
    training directory, then train the ML classifier on them. Sample
    metadata is written to ``models/samples_meta.json``.

    The training directory defaults to ``data/train`` and can be overridden
    through the ``TRAIN_DIR`` environment variable.
    """
    print("=" * 60)
    print(" 三角洲 GameGPT 训练流程")
    print("=" * 60)
    os.makedirs('models', exist_ok=True)

    # Step 1: feature extractor
    print("\n1/3 初始化特征提取器")
    extractor = GameFeatureExtractor()

    # Step 2: build the training samples
    print("\n2/3 构造训练样本")
    builder = TrainingSampleBuilder(
        # Honour TRAIN_DIR so a caller can redirect training without
        # editing this file; the default is unchanged.
        train_root=os.environ.get('TRAIN_DIR', 'data/train'),
        extractor=extractor,
    )
    samples = builder.build_all()
    print(f"\n 共构造样本: {len(samples)} 个")

    # Per-category sample counts
    cat_dist = Counter(s['category'] for s in samples)
    print(" 类别分布:")
    for cat, cnt in sorted(cat_dist.items()):
        print(f" {cat:15s}: {cnt}")

    # Persist sample metadata (file, category, intent) for later inspection
    meta = [
        {
            'file': s['file'],
            'category': s['category'],
            'intent': s['label']['intent'],
        }
        for s in samples
    ]
    with open('models/samples_meta.json', 'w', encoding='utf-8') as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)

    # Step 3: train the ML classifier
    print("\n3/3 训练ML分类器")
    ml = MLPredictor()
    ml.train(samples)

    print("\n ✓ 训练完成!")
    print(" 模型文件:")
    # NOTE(review): the two .pkl files are presumably written by
    # MLPredictor.train — not visible here, confirm.
    print(" models/intent_clf.pkl")
    print(" models/intent_encoder.pkl")
    print(" models/samples_meta.json")
# Run the training pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()
十一、主入口
main.py ── 统一入口
import argparse
import sys
def parse_args(argv=None):
    """Parse the command line of the unified entry point.

    Three sub-commands are defined: ``train``, ``infer`` and ``eval``. The
    chosen sub-command is stored in ``mode``; it is ``None`` when no
    sub-command is given.

    Args:
        argv: Optional list of argument strings. When ``None`` (the
            default) the arguments are taken from ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description='三角洲 GameGPT 初赛方案'
    )
    subparsers = parser.add_subparsers(dest='mode')

    # ── train sub-command ──────────────────────────
    train_p = subparsers.add_parser('train', help='训练ML分类器')
    train_p.add_argument(
        '--train_dir', default='data/train',
        help='训练数据根目录'
    )

    # ── infer sub-command ──────────────────────────
    infer_p = subparsers.add_parser('infer', help='批量推理')
    infer_p.add_argument(
        '--test_dir', default='data/test',
        help='测试txt文件目录'
    )
    infer_p.add_argument(
        '--output', default='output/submit.xlsx',
        help='输出xlsx路径'
    )
    infer_p.add_argument(
        '--backend', default='auto',
        choices=['auto', 'ollama', 'vllm', 'transformers'],
        help='LLM后端选择'
    )
    infer_p.add_argument(
        '--model_name', default='qwen2.5:7b',
        help='Ollama/vLLM模型名称'
    )
    infer_p.add_argument(
        '--model_path', default=None,
        help='本地模型路径(transformers后端使用)'
    )
    infer_p.add_argument(
        '--ollama_url', default='http://localhost:11434',
        help='Ollama服务地址'
    )
    infer_p.add_argument(
        '--vllm_url', default='http://localhost:8000/v1',
        help='vLLM服务地址'
    )
    infer_p.add_argument(
        '--no_ml', action='store_true',
        help='不使用ML分类器'
    )
    infer_p.add_argument(
        '--no_debug', action='store_true',
        help='关闭详细日志'
    )

    # ── eval sub-command (validation on the training set) ──
    eval_p = subparsers.add_parser('eval', help='在训练集上评估')
    eval_p.add_argument('--train_dir', default='data/train')
    eval_p.add_argument('--backend', default='auto')
    eval_p.add_argument('--model_name', default='qwen2.5:7b')
    eval_p.add_argument('--model_path', default=None)

    return parser.parse_args(argv)
def run_train(args):
    """Run the training pipeline of the ``train`` module.

    A non-default ``args.train_dir`` is exported as the ``TRAIN_DIR``
    environment variable before training starts.

    Args:
        args: Parsed command-line namespace; ``train_dir`` is optional.
    """
    import os

    train_dir = getattr(args, 'train_dir', 'data/train')
    if train_dir != 'data/train':
        # Exported before the import so the value is visible whether the
        # train module reads it at import time or inside main().
        # NOTE(review): this only has an effect if the train module
        # actually consults TRAIN_DIR — confirm.
        os.environ['TRAIN_DIR'] = train_dir

    import train as train_module
    train_module.main()
def run_infer(args):
    """Run batch inference with the options taken from ``args``.

    The two negative flags ``no_ml`` and ``no_debug`` are inverted into the
    positive ``use_ml`` and ``debug_log`` options of ``run_inference``.

    Args:
        args: Parsed command-line namespace of the ``infer`` sub-command.
    """
    from inference import run_inference

    # NOTE(review): args.ollama_url and args.vllm_url are parsed by the CLI
    # but not forwarded here — confirm whether run_inference accepts them.
    options = {
        'test_dir': args.test_dir,
        'output_path': args.output,
        'model_name': args.model_name,
        'model_path': args.model_path,
        'backend': args.backend,
        'use_ml': not args.no_ml,
        'debug_log': not args.no_debug,
    }
    run_inference(**options)
def run_eval(args):
    """Evaluate intent prediction on a slice of the training set.

    For each category folder under ``args.train_dir`` the files are shuffled
    and the first 20% (at least one file when the folder is not empty) are
    used as validation samples. Each sample is run through an ensemble of
    the rule-based, ML (when saved model files exist) and LLM predictors,
    and the predicted intent is compared with the folder's intent.

    Overall and per-intent accuracy are printed, and a report is written to
    ``output/eval_report.json``. A sample whose evaluation raises is
    reported and counted in the overall total (as not correct) but not in
    the per-intent statistics.

    Args:
        args: Parsed command-line namespace providing ``train_dir``,
            ``backend``, ``model_name`` and ``model_path``.
    """
    import json
    import os
    import random
    from collections import defaultdict
    from pathlib import Path

    import joblib

    from features import GameFeatureExtractor
    from inference import build_compact_text, determine_main_player
    from local_llm import LocalLLMPredictor
    from parse_log import parse_log_file
    from predictors import EnsemblePredictor, MLPredictor, RuleBasedPredictor

    print("=" * 60)
    print(" 训练集评估流程")
    print("=" * 60)

    # Folder name -> ground-truth intent of every log in that folder.
    CATEGORY_INTENT = {
        'Action': '交战',
        'Fire': '交战',
        'SkillStart': '技能交战',
        'Grenade': '投弹交战',
        'BeingResuce': '被救援/等待救援',
        'Looting': '搜寻物资',
    }

    train_root = Path(args.train_dir)
    extractor = GameFeatureExtractor()

    # Collect validation samples: 20% of each category, at least one file.
    val_samples = []
    for cat, intent in CATEGORY_INTENT.items():
        folder = train_root / cat
        if not folder.exists():
            continue
        files = list(folder.glob('*.txt'))
        random.shuffle(files)
        val_files = files[:max(1, int(len(files) * 0.2))]
        for f in val_files:
            val_samples.append({'file': f, 'true_intent': intent})
    print(f"验证样本数: {len(val_samples)}")

    # Set up the predictors; the ML one is optional.
    rule_p = RuleBasedPredictor()
    ml_p = None
    if Path('models/intent_clf.pkl').exists():
        ml_p = MLPredictor()
        ml_p.intent_clf = joblib.load('models/intent_clf.pkl')
        ml_p.intent_encoder = joblib.load('models/intent_encoder.pkl')
    llm_p = LocalLLMPredictor(
        backend=args.backend,
        model_name=args.model_name,
        model_path=args.model_path,
    )
    predictor = EnsemblePredictor(
        rule_predictor=rule_p,
        ml_predictor=ml_p,
        llm_predictor=llm_p,
        # Without an ML model its weight is moved onto the LLM.
        weights={
            'rule': 0.15,
            'ml': 0.25 if ml_p else 0.0,
            'llm': 0.60 if ml_p else 0.85,
        },
    )

    # Evaluate sample by sample.
    correct, total = 0, 0
    per_class = defaultdict(lambda: {'correct': 0, 'total': 0})
    errors = []
    for item in val_samples:
        try:
            parsed = parse_log_file(str(item['file']))
            main_id = determine_main_player(parsed)
            ptmap = {
                i['player_id']: i['team_id']
                for i in parsed['game_info']
            }
            feats = {}
            feats.update(extractor.extract_player_features(
                parsed['player_states'], main_id, {}
            ))
            feats.update(extractor.extract_spatial_features(
                parsed['player_states'], main_id, ptmap
            ))
            feats.update(extractor.extract_action_features(
                parsed['actions'], main_id
            ))
            text = build_compact_text(parsed, main_id, ptmap, feats)
            pred = predictor.predict(feats, text, parsed['actions'])

            true_intent = item['true_intent']
            pred_intent = pred['intent']
            is_correct = (true_intent == pred_intent)
            if is_correct:
                correct += 1
            total += 1
            per_class[true_intent]['total'] += 1
            per_class[true_intent]['correct'] += int(is_correct)
            if not is_correct:
                errors.append({
                    'file': str(item['file']),
                    'true': true_intent,
                    'pred': pred_intent,
                })
        except Exception as e:
            # A failed sample still counts towards the overall total, so it
            # lowers the reported accuracy instead of being dropped.
            print(f" 评估失败 {item['file']}: {e}")
            total += 1

    # Print the results.
    print("\n" + "=" * 50)
    print(f" 总体准确率: {correct}/{total} = {correct/max(total,1)*100:.2f}%")
    print("=" * 50)
    print(" 各类准确率:")
    for intent, stat in sorted(per_class.items()):
        acc = stat['correct'] / max(stat['total'], 1) * 100
        bar = '█' * int(acc / 5)  # one bar segment per 5 percentage points
        print(f" {intent:20s} {acc:5.1f}% {bar}")
    print(f"\n 错误样本数: {len(errors)}")
    if errors:
        print(" 错误示例:")
        # Only the first three mistakes are shown on screen.
        for err in errors[:3]:
            print(f" 文件: {Path(err['file']).name}")
            print(f" 真实: {err['true']} 预测: {err['pred']}")

    # Save the evaluation report.
    os.makedirs('output', exist_ok=True)
    with open('output/eval_report.json', 'w', encoding='utf-8') as f:
        json.dump({
            'accuracy': correct / max(total, 1),
            'per_class': dict(per_class),
            'errors': errors,
        }, f, ensure_ascii=False, indent=2)
    print("\n 评估报告已保存至 output/eval_report.json")
# Dispatch to the selected sub-command when executed as a script.
if __name__ == '__main__':
    args = parse_args()
    if args.mode == 'train':
        run_train(args)
    elif args.mode == 'infer':
        run_infer(args)
    elif args.mode == 'eval':
        run_eval(args)
    else:
        # No sub-command given: show usage examples and exit with an error.
        print("请指定子命令: train / infer / eval")
        print("示例:")
        print(" python main.py train")
        print(" python main.py infer --backend ollama --model_name qwen2.5:7b")
        print(" python main.py eval --backend ollama --model_name qwen2.5:7b")
        sys.exit(1)