功能说明
本代码实现了一个基于LSTM神经网络的多任务量化交易策略框架,通过协同优化遗忘门参数实现趋势跟踪与均值回归双目标的动态平衡。系统包含数据预处理模块、双任务损失函数设计、遗忘门协同优化机制和实盘交易接口,支持在保持模型泛化能力的同时抑制过拟合风险。核心创新在于将传统单任务LSTM扩展为双输出结构,分别捕捉价格序列的趋势延续性和均值回复特性。
技术架构设计
1. 多任务网络拓扑
采用共享编码层+任务特化解译层的混合架构,底层LSTM单元通过参数绑定实现特征共享,上层分设趋势解码器和均值解码器。关键创新点在于引入可学习的遗忘门协调矩阵,动态调整两个子任务对隐藏状态更新的影响权重。
python
class DualTaskLSTM(nn.Module):
def __init__(self, input_size, hidden_size, num_layers):
super().__init__()
self.shared_lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
self.trend_head = nn.Linear(hidden_size, 1)
self.mean_head = nn.Linear(hidden_size, 1)
self.forget_gate = nn.Sigmoid()
self.coord_matrix = nn.Parameter(torch.randn(2, 2)) # 遗忘门协调矩阵
def forward(self, x):
output, (hn, cn) = self.shared_lstm(x)
# 协调性遗忘操作
f_trend, f_mean = self.forget_gate(self.coord_matrix[0] * hn[-1]), \
self.forget_gate(self.coord_matrix[1] * hn[-1])
# 任务特定解码
trend_pred = self.trend_head(f_trend * output[:, -1, :])
mean_pred = self.mean_head(f_mean * output[:, -1, :])
return trend_pred, mean_pred
2. 双目标损失函数
构建包含趋势跟随惩罚项和均值回归正则化的复合损失函数,通过帕累托最优前沿确定权重分配。趋势分量采用方向一致性损失,均值分量使用波动率调整的均方误差。
python
def dual_loss(trend_pred, mean_pred, y_trend, y_mean, alpha=0.5):
# 趋势分量:方向敏感损失
trend_loss = F.mse_loss(trend_pred, y_trend) * (1 + torch.sign(y_trend).float())
# 均值分量:波动率加权损失
vol_weight = torch.abs(y_mean) / (torch.std(y_mean) + 1e-8)
mean_loss = F.mse_loss(mean_pred, y_mean) * vol_weight
# 动态权重平衡
total_loss = alpha * trend_loss + (1 - alpha) * mean_loss
return total_loss, trend_loss, mean_loss
遗忘门协同优化机制
1. 梯度追踪协调算法
在反向传播过程中,实时监测两个子任务对隐藏状态梯度的贡献度,通过竞争性学习自动调整协调矩阵。当某任务梯度范数超过阈值时,相应增强其在协调矩阵中的主导权。
python
class ForgetCoordinator:
def __init__(self, threshold=0.7, momentum=0.9):
self.threshold = threshold
self.grad_tracker = {}
self.momentum = momentum
def update_coord_matrix(self, model, step):
# 获取各任务对隐藏状态的梯度贡献
grad_trend = torch.norm(model.trend_head.weight.grad)
grad_mean = torch.norm(model.mean_head.weight.grad)
total_grad = grad_trend + grad_mean + 1e-8
# 计算相对影响力
influence_trend = grad_trend / total_grad
influence_mean = grad_mean / total_grad
# 动态调整协调矩阵
with torch.no_grad():
new_row0 = model.coord_matrix[0] * self.momentum + \
torch.tensor([influence_trend, influence_mean]) * (1 - self.momentum)
new_row1 = model.coord_matrix[1] * self.momentum + \
torch.tensor([1-influence_trend, 1-influence_mean]) * (1 - self.momentum)
model.coord_matrix.data = torch.stack([new_row0, new_row1])
# 记录极端情况用于异常处理
self.grad_tracker[step] = {
'trend': influence_trend.item(),
'mean': influence_mean.item()
}
2. 在线校准协议
部署阶段实施滚动窗口验证,每完成N个交易周期后,根据近期市场状态重新评估双目标权重。当检测到趋势持续性增强或减弱时,自动迁移协调矩阵至新的均衡点。
python
def online_calibration(model, recent_data, window=60):
# 使用最近window个样本进行在线校准
model.eval()
with torch.no_grad():
outputs = []
for i in range(window):
seq = recent_data[i:i+30].view(1, 30, -1)
t_out, m_out = model(seq)
outputs.append((t_out.item(), m_out.item()))
# 分析当前市场状态
trend_score = np.mean([x[0] for x in outputs])
mean_score = np.mean([x[1] for x in outputs])
# 动态调整alpha值
if abs(trend_score) > 0.3 and abs(mean_score) < 0.1:
new_alpha = min(0.8, max(0.2, trend_score/2))
else:
new_alpha = 0.5 # 默认均衡点
return new_alpha
实证研究设计
1. 实验数据集构建
选取沪深300指数成分股分钟级行情数据,构造包含价格序列、成交量变化率、买卖盘口差等12维特征集。按时间顺序划分为训练集(70%)、验证集(15%)和测试集(15%),确保严格的前向观测。
python
class DataPreprocessor:
def __init__(self, lookback=30, forecast_horizon=5):
self.lookback = lookback
self.fh = forecast_horizon
def create_sequences(self, data):
X, y_trend, y_mean = [], [], []
for i in range(len(data)-self.lookback-self.fh+1):
# 输入序列: [t-30, t-29, ..., t-1]
seq = data[i:i+self.lookback]
# 目标变量: 未来5期收益率
future = data[i+self.lookback:i+self.lookback+self.fh]
# 趋势标签: 连续上涨/下跌标记
trend_label = 1 if all(future[j]>0 for j in range(self.fh)) else \
-1 if all(future[j]<0 for j in range(self.fh)) else 0
# 均值标签: 偏离移动平均程度
ma = np.mean(future)
mean_label = sum(abs(x-ma) for x in future)/self.fh
X.append(seq)
y_trend.append(trend_label)
y_mean.append(mean_label)
return np.array(X), np.array(y_trend), np.array(y_mean)
2. 对比基准设置
建立四个对照实验组:①传统单任务LSTM;②简单拼接的双任务LSTM;③固定权重的多任务LSTM;④本文提出的协同优化模型。所有模型保持相同的超参数配置,仅改变任务组织方式。
| 模型类型 | 特点描述 | 预期表现 |
|---|---|---|
| Single-Task LSTM | 单一输出层,专注趋势预测 | 高夏普比率,易过度拟合 |
| Naive Multi-Task | 独立输出分支,无参数共享 | 中等性能,存在冗余计算 |
| Fixed-Weight Multi-Task | 预定义权重分配,静态平衡 | 稳定性好,适应性不足 |
| Coordinated Multi-Task | 动态遗忘门协调,自适应平衡 | 综合性能最优 |
风险控制体系
1. 仓位管理规则
实施三级风险预算机制,根据模型信心度动态调整头寸规模。当任一子任务出现连续亏损时,触发阶梯式降仓直至平仓。特别设置最大回撤硬约束,防止极端行情导致的灾难性后果。
python
class PositionSizing:
def __init__(self, max_dd=0.15, risk_free=0.03):
self.max_drawdown = max_dd
self.rf = risk_free
self.equity_curve = []
def calculate_position(self, confidence, current_price, entry_price):
# 根据凯利公式改进版计算仓位
edge = abs(current_price - entry_price) / entry_price
prob_win = confidence if confidence > 0.5 else 1 - confidence
kelly_fraction = (prob_win - (1 - prob_win)) / (edge + 1e-8)
# 应用最大回撤限制
max_possible = min(kelly_fraction, self.max_drawdown/(1 - self.max_drawdown))
# 考虑机会成本进行调整
adjusted_pos = max(0, max_possible - self.rf/2)
return adjusted_pos
2. 熔断保护机制
监控三个关键指标:①趋势预测置信度低于阈值;②均值回归残差突破历史极值;③协调矩阵元素发生剧烈震荡。任一条件触发即启动应急模式,暂停交易并等待人工干预。
python
class CircuitBreaker:
def __init__(self, conf_thresh=0.6, res_thresh=0.05, matrix_thresh=0.3):
self.conf_threshold = conf_thresh
self.residual_threshold = res_thresh
self.matrix_threshold = matrix_thresh
self.last_confidence = None
self.last_residual = None
self.last_matrix_change = None
def check_trigger(self, current_conf, current_res, matrix_diff):
triggers = []
if current_conf < self.conf_threshold:
triggers.append("low_confidence")
if abs(current_res) > self.residual_threshold:
triggers.append("large_residual")
if matrix_diff > self.matrix_threshold:
triggers.append("matrix_instability")
return triggers
完整实现示例
以下是整合上述组件的完整Python实现,包含数据管道、模型训练和交易执行流程:
python
import torch
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
# 数据准备阶段
class TradingDataset(torch.utils.data.Dataset):
def __init__(self, features, targets_trend, targets_mean):
self.X = torch.FloatTensor(features)
self.y_trend = torch.LongTensor(targets_trend)
self.y_mean = torch.FloatTensor(targets_mean)
def __len__(self):
return len(self.X)
def __getitem__(self, idx):
return self.X[idx], self.y_trend[idx], self.y_mean[idx]
# 主训练循环
def train_model(train_loader, val_loader, model, optimizer, coordinator, num_epochs=100):
best_val_loss = float('inf')
patience_counter = 0
for epoch in range(num_epochs):
# 训练阶段
model.train()
train_loss, train_tl, train_ml = 0, 0, 0
for batch_idx, (data, t_labels, m_labels) in enumerate(train_loader):
optimizer.zero_grad()
t_pred, m_pred = model(data)
loss, tl, ml = dual_loss(t_pred, m_pred, t_labels, m_labels)
loss.backward()
optimizer.step()
# 更新协调矩阵
coordinator.update_coord_matrix(model, epoch*len(train_loader)+batch_idx)
# 累计指标
train_loss += loss.item()
train_tl += tl.item()
train_ml += ml.item()
# 验证阶段
model.eval()
val_loss, val_tl, val_ml = 0, 0, 0
with torch.no_grad():
for data, t_labels, m_labels in val_loader:
t_pred, m_pred = model(data)
loss, tl, ml = dual_loss(t_pred, m_pred, t_labels, m_labels)
val_loss += loss.item()
val_tl += tl.item()
val_ml += ml.item()
# 早停判断
if val_loss < best_val_loss:
best_val_loss = val_loss
patience_counter = 0
# 保存最佳模型
torch.save(model.state_dict(), 'best_model.pth')
else:
patience_counter += 1
if patience_counter >= 10:
print(f"Early stopping at epoch {epoch}")
break
# 打印日志
avg_train_loss = train_loss/len(train_loader)
avg_val_loss = val_loss/len(val_loader)
print(f"Epoch {epoch}: Train Loss={avg_train_loss:.4f}, Val Loss={avg_val_loss:.4f}")
print(f"Trend/Mean Losses - Train: ({train_tl/len(train_loader):.4f}, {train_ml/len(train_loader):.4f}), "
f"Val: ({val_tl/len(val_loader):.4f}, {val_ml/len(val_loader):.4f})")
# 交易执行引擎
class TradingEngine:
def __init__(self, model_path, capital_base=1000000):
self.model = DualTaskLSTM(input_size=12, hidden_size=64, num_layers=2)
self.model.load_state_dict(torch.load(model_path))
self.capital = capital_base
self.position = 0
self.entry_price = 0
self.stop_loss = 0.05 # 5%止损线
self.take_profit = 0.1 # 10%止盈线
def execute_trade(self, current_data):
# 生成预测信号
self.model.eval()
with torch.no_grad():
signal, _ = self.model(current_data.unsqueeze(0))
# 执行交易逻辑
if signal > 0.5: # 做多信号
if self.position == 0:
self.entry_price = current_data['close'].iloc[-1]
self.position = self.capital // current_data['close'].iloc[-1]
self.capital -= self.position * self.entry_price
elif signal < -0.5: # 做空信号
if self.position == 0:
self.entry_price = current_data['close'].iloc[-1]
self.position = -(self.capital // current_data['close'].iloc[-1])
self.capital += abs(self.position) * self.entry_price
# 风险管理检查
current_price = current_data['close'].iloc[-1]
if self.position != 0:
profit_loss = (current_price - self.entry_price) * self.position
if profit_loss / (self.entry_price * abs(self.position)) >= self.stop_loss or \
profit_loss / (self.entry_price * abs(self.position)) <= -self.stop_loss:
# 触发止损/止盈
self.capital += profit_loss
self.position = 0
return {'position': self.position, 'cash': self.capital}