## 十、训练入口 — train.py
import os, json, joblib
import numpy as np
from pathlib import Path
from collections import Counter
from parse_log import parse_log_file
from features import GameFeatureExtractor
from predictors import MLPredictor
# ════════════════════════════════════════════════════
# TrainingSampleBuilder
# ════════════════════════════════════════════════════
class TrainingSampleBuilder:
    """Build labeled training samples from category-organized log files.

    Each sub-folder of ``train_root`` is a category whose name maps to a
    fixed intent label (CATEGORY_INTENT) and a canned action-sequence label
    (CATEGORY_ACTIONS); every ``*.txt`` log in the folder becomes one sample.
    """

    # Folder name -> intent label.
    # NOTE(review): 'BeingResuce' looks like a typo for 'BeingRescue', but it
    # must match the on-disk folder name — confirm before renaming.
    CATEGORY_INTENT = {
        'Action': '交战',
        'Fire': '交战',
        'SkillStart': '技能交战',
        'Grenade': '投弹交战',
        'BeingResuce': '被救援/等待救援',
        'Looting': '搜寻物资'
    }

    # Folder name -> canonical action list emitted as the label.
    CATEGORY_ACTIONS = {
        'Action': ['靠近敌人', '开镜瞄准', '开火射击'],
        'Fire': ['开火射击', '开镜瞄准', '换弹'],
        'SkillStart': ['使用技能', '开镜瞄准', '开火射击'],
        'Grenade': ['投掷手雷', '靠近敌人', '开火射击'],
        'BeingResuce': ['原地不动/潜伏', '被救援中'],
        'Looting': ['移动至物资点', '搜寻物资', '拾取物品']
    }

    # BUG FIX: was `def init(...)` — without the dunder underscores the
    # constructor never ran and instances had no attributes.
    def __init__(self, train_root: str, extractor: "GameFeatureExtractor"):
        """
        Args:
            train_root: root directory containing one sub-folder per category.
            extractor: feature extractor used to turn parsed logs into features.
        """
        self.train_root = Path(train_root)
        self.extractor = extractor

    def build_all(self) -> list:
        """Build samples for every known category; returns a list of sample dicts."""
        samples = []
        for cat, intent in self.CATEGORY_INTENT.items():
            folder = self.train_root / cat
            if not folder.exists():
                print(f" ⚠ 文件夹不存在: {folder}")
                continue
            files = list(folder.glob('*.txt'))
            print(f" {cat:15s} → {len(files):4d} 个文件")
            for f in files:
                try:
                    s = self._build_one(f, cat)
                    if s:
                        samples.append(s)
                except Exception as e:
                    # Best-effort: skip files that fail to parse/extract.
                    print(f" 处理失败 {f.name}: {e}")
        return samples

    def _build_one(self, filepath: Path, category: str) -> dict:
        """Build one sample from a single log file; returns None if unusable."""
        parsed = parse_log_file(str(filepath))
        if not parsed['player_states']:
            return None
        player_team_map = {
            i['player_id']: i['team_id']
            for i in parsed['game_info']
        }
        # Determine the main player: most frequent actor, falling back to
        # the first player listed in game_info.
        if parsed['actions']:
            counts = Counter(a.player_id for a in parsed['actions'])
            main_id = counts.most_common(1)[0][0]
        elif parsed['game_info']:
            main_id = parsed['game_info'][0]['player_id']
        else:
            return None
        # Extract features (player state + spatial + action).
        features = {}
        features.update(self.extractor.extract_player_features(
            parsed['player_states'], main_id, {}
        ))
        features.update(self.extractor.extract_spatial_features(
            parsed['player_states'], main_id, player_team_map
        ))
        features.update(self.extractor.extract_action_features(
            parsed['actions'], main_id
        ))
        return {
            'file': str(filepath),
            'category': category,
            'main_player': main_id,
            'features': features,
            'label': {
                'intent': self.CATEGORY_INTENT[category],
                'actions': self.CATEGORY_ACTIONS[category]
            }
        }
# ════════════════════════════════════════════════════
# 训练主函数
# ════════════════════════════════════════════════════
def main():
    """Training pipeline: build samples from data/train, then train the ML classifier.

    Side effects: creates ``models/`` and writes samples_meta.json plus the
    classifier artifacts produced by MLPredictor.train.
    """
    print("=" * 60)
    print(" 三角洲 GameGPT 训练流程")
    print("=" * 60)
    os.makedirs('models', exist_ok=True)

    # Step 1: feature extractor
    print("\n[1/3] 初始化特征提取器")
    extractor = GameFeatureExtractor()

    # Step 2: build training samples
    print("\n[2/3] 构造训练样本")
    builder = TrainingSampleBuilder(
        train_root='data/train',
        extractor=extractor
    )
    samples = builder.build_all()
    print(f"\n 共构造样本: {len(samples)} 个")

    # Category distribution
    cat_dist = Counter(s['category'] for s in samples)
    print(" 类别分布:")
    for cat, cnt in sorted(cat_dist.items()):
        print(f" {cat:15s}: {cnt}")

    # Save sample metadata
    meta = [{
        'file': s['file'],
        'category': s['category'],
        'intent': s['label']['intent']
    } for s in samples]
    with open('models/samples_meta.json', 'w', encoding='utf-8') as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)

    # Step 3: train the ML classifier
    print("\n[3/3] 训练ML分类器")
    ml = MLPredictor()
    ml.train(samples)
    print("\n ✓ 训练完成!")
    print(" 模型文件:")
    print(" models/intent_clf.pkl")
    print(" models/intent_encoder.pkl")
    print(" models/samples_meta.json")


# BUG FIX: was `if name == 'main':` — the dunder underscores were lost,
# which would raise NameError and never run main().
if __name__ == '__main__':
    main()
## 十一、主入口 — main.py(统一入口)
import argparse
import sys
def parse_args(argv=None):
    """Parse command-line arguments for the train / infer / eval subcommands.

    Args:
        argv: optional argument list to parse; None means use ``sys.argv[1:]``
              (backward-compatible default; explicit lists enable testing).

    Returns:
        argparse.Namespace whose ``mode`` attribute is the chosen subcommand,
        or None when no subcommand was given.
    """
    parser = argparse.ArgumentParser(
        description='三角洲 GameGPT 初赛方案'
    )
    subparsers = parser.add_subparsers(dest='mode')

    # ── train subcommand ───────────────────────────────
    train_p = subparsers.add_parser('train', help='训练ML分类器')
    train_p.add_argument(
        '--train_dir', default='data/train',
        help='训练数据根目录'
    )

    # ── infer subcommand ───────────────────────────────
    infer_p = subparsers.add_parser('infer', help='批量推理')
    infer_p.add_argument(
        '--test_dir', default='data/test',
        help='测试txt文件目录'
    )
    infer_p.add_argument(
        '--output', default='output/submit.xlsx',
        help='输出xlsx路径'
    )
    infer_p.add_argument(
        '--backend', default='auto',
        choices=['auto', 'ollama', 'vllm', 'transformers'],
        help='LLM后端选择'
    )
    infer_p.add_argument(
        '--model_name', default='qwen2.5:7b',
        help='Ollama/vLLM模型名称'
    )
    infer_p.add_argument(
        '--model_path', default=None,
        help='本地模型路径(transformers后端使用)'
    )
    infer_p.add_argument(
        '--ollama_url', default='http://localhost:11434',
        help='Ollama服务地址'
    )
    infer_p.add_argument(
        '--vllm_url', default='http://localhost:8000/v1',
        help='vLLM服务地址'
    )
    infer_p.add_argument(
        '--no_ml', action='store_true',
        help='不使用ML分类器'
    )
    infer_p.add_argument(
        '--no_debug', action='store_true',
        help='关闭详细日志'
    )

    # ── eval subcommand (cross-validation on the training set) ──
    eval_p = subparsers.add_parser('eval', help='在训练集上评估')
    eval_p.add_argument('--train_dir', default='data/train')
    eval_p.add_argument('--backend', default='auto')
    eval_p.add_argument('--model_name', default='qwen2.5:7b')
    eval_p.add_argument('--model_path', default=None)

    return parser.parse_args(argv)
def run_train(args):
    """Run the training pipeline (delegates to train.main).

    A non-default ``--train_dir`` is exported through the environment.
    NOTE(review): this assumes train.py reads TRAIN_DIR — confirm; as
    written, train.py hard-codes 'data/train'.
    """
    import os
    import train as train_module

    # Allow a custom training directory to reach the training script.
    if hasattr(args, 'train_dir') and args.train_dir != 'data/train':
        os.environ['TRAIN_DIR'] = args.train_dir
    train_module.main()
def run_infer(args):
    """Run batch inference over ``args.test_dir`` and write ``args.output``."""
    from inference import run_inference

    # Flags are stored negated on the namespace (--no_ml / --no_debug),
    # so invert them when forwarding to the inference entry point.
    run_inference(
        test_dir=args.test_dir,
        output_path=args.output,
        model_name=args.model_name,
        model_path=args.model_path,
        backend=args.backend,
        use_ml=not args.no_ml,
        debug_log=not args.no_debug,
    )
def run_eval(args):
    """Evaluate the ensemble predictor on a held-out slice of the training set.

    Randomly samples 20% of the files in each category folder as a
    validation split, predicts the intent for each, and reports overall
    and per-class accuracy. Writes a JSON report to
    output/eval_report.json.
    """
    import random
    import json
    from pathlib import Path
    from collections import defaultdict
    from parse_log import parse_log_file
    from features import GameFeatureExtractor
    from local_llm import LocalLLMPredictor
    from predictors import RuleBasedPredictor, MLPredictor, EnsemblePredictor
    from inference import determine_main_player, build_compact_text
    import joblib

    print("=" * 60)
    print(" 训练集评估流程")
    print("=" * 60)

    # Folder name -> ground-truth intent label (must mirror train.py).
    CATEGORY_INTENT = {
        'Action': '交战',
        'Fire': '交战',
        'SkillStart': '技能交战',
        'Grenade': '投弹交战',
        'BeingResuce': '被救援/等待救援',
        'Looting': '搜寻物资'
    }

    train_root = Path(args.train_dir)
    extractor = GameFeatureExtractor()

    # Collect validation samples: 20% of each category, at least one file.
    val_samples = []
    for cat, intent in CATEGORY_INTENT.items():
        folder = train_root / cat
        if not folder.exists():
            continue
        files = list(folder.glob('*.txt'))
        random.shuffle(files)
        val_files = files[:max(1, int(len(files) * 0.2))]
        for f in val_files:
            val_samples.append({'file': f, 'true_intent': intent})
    print(f"验证样本数: {len(val_samples)}")

    # Initialize predictors; the ML classifier only joins the ensemble
    # when a previously trained model exists on disk.
    rule_p = RuleBasedPredictor()
    ml_p = None
    if Path('models/intent_clf.pkl').exists():
        ml_p = MLPredictor()
        ml_p.intent_clf = joblib.load('models/intent_clf.pkl')
        ml_p.intent_encoder = joblib.load('models/intent_encoder.pkl')
    llm_p = LocalLLMPredictor(
        backend=args.backend,
        model_name=args.model_name,
        model_path=args.model_path
    )
    predictor = EnsemblePredictor(
        rule_predictor=rule_p,
        ml_predictor=ml_p,
        llm_predictor=llm_p,
        # Without an ML model, its weight shifts onto the LLM.
        weights={
            'rule': 0.15,
            'ml': 0.25 if ml_p else 0.0,
            'llm': 0.60 if ml_p else 0.85
        }
    )

    # Evaluate sample by sample.
    correct, total = 0, 0
    per_class = defaultdict(lambda: {'correct': 0, 'total': 0})
    errors = []
    for item in val_samples:
        try:
            parsed = parse_log_file(str(item['file']))
            main_id = determine_main_player(parsed)
            ptmap = {
                i['player_id']: i['team_id']
                for i in parsed['game_info']
            }
            feats = {}
            feats.update(extractor.extract_player_features(
                parsed['player_states'], main_id, {}
            ))
            feats.update(extractor.extract_spatial_features(
                parsed['player_states'], main_id, ptmap
            ))
            feats.update(extractor.extract_action_features(
                parsed['actions'], main_id
            ))
            text = build_compact_text(parsed, main_id, ptmap, feats)
            pred = predictor.predict(feats, text, parsed['actions'])
            true_intent = item['true_intent']
            pred_intent = pred['intent']
            is_correct = (true_intent == pred_intent)
            if is_correct:
                correct += 1
            total += 1
            per_class[true_intent]['total'] += 1
            per_class[true_intent]['correct'] += int(is_correct)
            if not is_correct:
                errors.append({
                    'file': str(item['file']),
                    'true': true_intent,
                    'pred': pred_intent
                })
        except Exception as e:
            # Best-effort: a failed sample counts against overall accuracy
            # but does not abort the evaluation run.
            print(f" 评估失败 {item['file']}: {e}")
            total += 1

    # Print evaluation results.
    print("\n" + "=" * 50)
    print(f" 总体准确率: {correct}/{total} = {correct/max(total,1)*100:.2f}%")
    print("=" * 50)
    print(" 各类准确率:")
    for intent, stat in sorted(per_class.items()):
        acc = stat['correct'] / max(stat['total'], 1) * 100
        bar = '█' * int(acc / 5)  # 5% per bar segment
        print(f" {intent:20s} {acc:5.1f}% [{bar}]")
    print(f"\n 错误样本数: {len(errors)}")
    if errors[:3]:
        print(" 错误示例:")
        for e in errors[:3]:
            print(f" 文件: {Path(e['file']).name}")
            print(f" 真实: {e['true']} 预测: {e['pred']}")

    # Save the evaluation report.
    import os
    os.makedirs('output', exist_ok=True)
    with open('output/eval_report.json', 'w', encoding='utf-8') as f:
        json.dump({
            'accuracy': correct / max(total, 1),
            'per_class': dict(per_class),
            'errors': errors
        }, f, ensure_ascii=False, indent=2)
    print("\n 评估报告已保存至 output/eval_report.json")
# BUG FIX: was `if name == 'main':` — the dunder underscores were lost,
# which would raise NameError instead of dispatching the subcommand.
if __name__ == '__main__':
    args = parse_args()
    if args.mode == 'train':
        run_train(args)
    elif args.mode == 'infer':
        run_infer(args)
    elif args.mode == 'eval':
        run_eval(args)
    else:
        # No subcommand given: print usage and exit non-zero.
        print("请指定子命令: train / infer / eval")
        print("示例:")
        print(" python main.py train")
        print(" python main.py infer --backend ollama --model_name qwen2.5:7b")
        print(" python main.py eval --backend ollama --model_name qwen2.5:7b")
        sys.exit(1)