计算一组新因子、并分析它们与已有因子间的相关性
- [1. 导入库和初始化环境](#1. 导入库和初始化环境)
- [2. 定义新因子计算函数](#2. 定义新因子计算函数)
- [3. 计算新因子并添加到数据框](#3. 计算新因子并添加到数据框)
- [4. 计算相关系数矩阵](#4. 计算相关系数矩阵)
- [5. 高相关性因子分析](#5. 高相关性因子分析)
- [6. 相关性热力图可视化](#6. 相关性热力图可视化)
- [7. 加载同质异质分离数据(可选)](#7. 加载同质异质分离数据(可选))
- [8. 程序完成提示](#8. 程序完成提示)
- 总结
- 完整代码
程序的目标是计算一组新因子,并分析它们与已有因子(旧因子)之间的相关性,以评估新因子的独特性。
1. 导入库和初始化环境
功能
这一部分导入了必要的 Python 库,设置了工作环境,并加载了数据服务和因子数据,为后续计算和分析做准备。
代码解析
python
import sys
sys.path.append('/public/src')
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import joblib
import os
import warnings
warnings.filterwarnings('ignore')
from factor_evaluation_server import FactorEvaluation, DataService
sys.path.append('/public/src')
:将/public/src
目录添加到 Python 的模块搜索路径,确保可以导入自定义模块(如factor_evaluation_server
)。- 导入库 :
numpy
和pandas
:用于数值计算和数据处理。matplotlib.pyplot
和seaborn
:用于数据可视化(如热力图)。tqdm
:显示进度条,方便监控循环进度。joblib
:用于加载/保存序列化数据。os
:用于文件路径操作。warnings.filterwarnings('ignore')
:忽略运行过程中的警告信息,以保持输出简洁。
- 自定义模块 :从
factor_evaluation_server
导入FactorEvaluation
和DataService
,可能是自定义的数据处理和因子评估工具。
数据加载
python
# 初始化数据服务
ds = DataService()
df = ds['ETHUSDT_15m_2020_2025']['2021-10-01':]
# 读取已有因子数据
factor_path = "/public/data/factor_data/ETHUSDT_15m_2020_2025_factor_data.pkl"
factors = pd.read_pickle(factor_path)
-
DataService
:初始化数据服务,加载ETHUSDT_15m_2020_2025
数据集(可能是以太坊对美元的 15 分钟 K 线数据),并从 2021-10-01 开始切片。 -
factors
:从.pkl
文件加载已有因子数据,存储为pandas.DataFrame
。 -
时间索引处理 :
pythonif isinstance(factors.index, pd.DatetimeIndex) or factors.index.dtype == 'datetime64[ns]': factors = factors.reset_index(drop=True)
检查因子数据是否具有时间索引,若有则重置为整数索引,确保后续计算时索引对齐。
-
记录旧因子 :
pythonoriginal_columns = factors.columns.tolist() print(f"加载了 {len(original_columns)} 个旧因子")
保存旧因子的列名,并打印旧因子的数量。
2. 定义新因子计算函数
功能
这一部分定义了 20 个新因子计算函数,每个函数基于输入的 K 线数据(如 close
、high
、low
、volume
等)生成一个新的因子。这些因子涵盖了波动率、成交量、价格位置、买卖压力等多个维度,用于量化交易策略的研究。
代码解析
以下是几个代表性因子的定义和解释(完整代码包含 20 个因子,这里仅解析部分以展示逻辑):
因子 1:波动率过滤器(filter_001_1
)
python
def calculate_filter_001_1(df):
'''衡量当前波动率高低的过滤器'''
log_ratio = np.log(df['close'] / df['close'].shift(1))
hv = log_ratio.rolling(20).std()
return hv
- 功能:计算历史波动率(Historical Volatility, HV),基于收盘价的对数收益率在 20 个周期内的标准差。
- 意义:衡量价格波动的剧烈程度,高波动率可能预示市场活跃或不稳定。
因子 2:ATR 过滤器(filter_001_2
)
python
def calculate_filter_001_2(df):
'''ATR过滤器'''
high_low = df['high'] - df['low']
high_close = abs(df['high'] - df['close'].shift(1))
low_close = abs(df['low'] - df['close'].shift(1))
true_range = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
atr = true_range.rolling(14).mean()
return atr
- 功能:计算平均真实波幅(Average True Range, ATR),基于 14 个周期的真实波幅均值。
- 意义:ATR 是衡量价格波动范围的常用指标,可用于设置止损或判断趋势强度。
因子 3:凯尔特纳通道(filter_001_3
)
python
def calculate_filter_001_3_keltner_channels(df, ema_period=20, atr_period=10, multiplier=2):
'''凯尔特纳通道:基于ATR的波动通道'''
ema = df['close'].ewm(span=ema_period, adjust=False).mean()
high_low = df['high'] - df['low']
high_close = abs(df['high'] - df['close'].shift(1))
low_close = abs(df['low'] - df['close'].shift(1))
true_range = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
atr = true_range.ewm(span=atr_period, adjust=False).mean()
channel_width = multiplier * atr
return (df['close'] - (ema - channel_width)) / (2 * channel_width)
- 功能:计算凯尔特纳通道(Keltner Channels),基于指数移动平均线(EMA)和 ATR,衡量价格相对于通道的位置。
- 意义:标准化后的值(范围 0-1)表示价格在通道中的相对位置,可用于判断超买或超卖。
其他因子
filter_002_1
:成交量偏差,衡量当前成交量相对于 20 周期均值的偏离程度。filter_002_2_obv
:能量潮指标(On-Balance Volume, OBV),基于价格方向和成交量的累积指标。filter_010_1
:相对强弱指数(RSI),衡量价格超买或超卖状态。filter_011
:MACD,衡量短期和长期趋势的差异。- 更多因子:包括量价趋势、阿隆指标、Chaikin 资金流等,涵盖了技术分析的多个方面。
因子映射
python
factor_calculators = {
'filter_001_1': calculate_filter_001_1,
'filter_001_2': calculate_filter_001_2,
# ... 其他因子映射
}
- 将因子名称与对应的计算函数映射,方便后续批量计算。
3. 计算新因子并添加到数据框
功能
这一部分遍历所有新因子计算函数,计算因子值并将其添加到因子数据框中,同时处理可能的长度不匹配或计算错误。
代码解析
python
print("\n开始计算新因子...")
for factor_name, calculator in tqdm(factor_calculators.items(), desc="计算因子"):
try:
factor_values = calculator(df)
if isinstance(factor_values, pd.Series):
if len(factor_values) != len(factors):
factor_values = factor_values.iloc[:len(factors)].reset_index(drop=True)
factors[factor_name] = factor_values.values
else:
if len(factor_values) != len(factors):
factor_values = factor_values[:len(factors)]
factors[factor_name] = factor_values
except Exception as e:
print(f"计算因子 {factor_name} 时出错: {str(e)}")
factors[factor_name] = np.nan
-
遍历计算 :使用
tqdm
显示进度条,逐个调用因子计算函数。 -
长度对齐 :
- 如果计算结果是
pd.Series
,确保其长度与factors
一致,必要时截断并重置索引。 - 如果是
numpy
数组,同样截断以匹配长度。
- 如果计算结果是
-
错误处理 :捕获异常并打印错误信息,将出错的因子列填充为
NaN
,避免程序中断。 -
索引重置 :
pythonfactors = factors.reset_index(drop=True)
确保因子数据框使用整数索引,统一数据结构。
4. 计算相关系数矩阵
功能
计算所有因子(包括新旧因子)之间的相关系数矩阵,并提取新因子与旧因子之间的相关性,用于后续分析。
代码解析
python
print("\n计算相关系数矩阵...")
corr_matrix = factors.corr()
new_factors = list(factor_calculators.keys())
old_factors = original_columns
new_to_old_corr = corr_matrix.loc[new_factors, old_factors]
factors.corr()
:计算因子数据框的皮尔逊相关系数矩阵。new_to_old_corr
:提取新因子(行)与旧因子(列)的相关性子矩阵,聚焦于新因子与旧因子的关系。
5. 高相关性因子分析
功能
筛选出新因子与旧因子之间绝对相关系数超过阈值(0.7)的因子对,记录并输出高相关性信息。
代码解析
python
threshold = 0.7
high_corr_records = []
print(f"\n高相关性因子分析 (新因子与旧因子, |corr| > {threshold}):")
for new_factor in new_factors:
corr_series = new_to_old_corr.loc[new_factor]
high_corr = corr_series[abs(corr_series) > threshold]
if not high_corr.empty:
print(f"\n🔍 {new_factor} 与以下旧因子有高相关性:")
for factor, corr_value in high_corr.items():
strength = "强" if abs(corr_value) > 0.8 else "中等"
direction = "正" if corr_value > 0 else "负"
print(f" - {factor}: {corr_value:.4f} ({strength}{direction}相关)")
high_corr_records.append({
'新因子': new_factor,
'旧因子': factor,
'相关系数': corr_value,
'相关强度': strength,
'相关方向': direction
})
- 阈值设置:相关系数绝对值大于 0.7 被认为是高相关。
- 筛选高相关性:对于每个新因子,检查其与旧因子的相关系数是否超过阈值。
- 记录信息 :将高相关性因子对的信息(新因子、旧因子、相关系数、强度、方向)存储在
high_corr_records
列表中。 - 强度与方向 :
- 相关系数绝对值 > 0.8:强相关。
- 相关系数绝对值 0.7-0.8:中等相关。
- 正/负相关:根据相关系数的正负判断。
输出高相关性表格
python
if high_corr_records:
high_corr_df = pd.DataFrame(high_corr_records)
print("\n📊 高相关性因子汇总表 (新因子 vs 旧因子):")
display(high_corr_df[['新因子', '旧因子', '相关系数', '相关强度', '相关方向']]
.sort_values(by='相关系数', key=abs, ascending=False)
.style
.background_gradient(cmap='coolwarm', subset=['相关系数'])
.format({'相关系数': "{:.4f}"})
.set_caption(f"高相关性因子汇总 (|corr| > {threshold})"))
else:
print(f"\n✅ 未发现高相关性新因子与旧因子 (|corr| > {threshold})")
- 表格生成 :将高相关性记录转换为
pd.DataFrame
。 - 美化输出 :
- 按相关系数绝对值降序排序。
- 使用
coolwarm
颜色映射突出相关系数的大小。 - 格式化相关系数保留 4 位小数。
- 空结果处理:如果没有高相关性因子对,输出提示信息。
6. 相关性热力图可视化
功能
通过热力图可视化新因子与旧因子之间的相关系数,直观展示相关性强弱。
代码解析
python
plt.figure(figsize=(18, 12))
sns.heatmap(
new_to_old_corr,
annot=True,
fmt=".2f",
cmap='coolwarm',
center=0,
vmin=-1,
vmax=1,
linewidths=0.5
)
plt.title('新因子与旧因子的相关性热力图', fontsize=16, pad=20)
plt.xlabel('旧因子', fontsize=12)
plt.ylabel('新因子', fontsize=12)
plt.xticks(rotation=90, fontsize=8)
plt.yticks(rotation=0, fontsize=8)
cbar = plt.gcf().axes[-1]
cbar.set_ylabel('相关系数', rotation=270, labelpad=20)
plt.tight_layout()
plt.show()
sns.heatmap
:绘制相关系数热力图。annot=True
:在格子中显示相关系数数值(保留 2 位小数)。cmap='coolwarm'
:使用红蓝渐变色,正相关为红色,负相关为蓝色。center=0
:以 0 为中心,区分正负相关。vmin=-1, vmax=1
:设置相关系数范围为 [-1, 1]。
- 美化设置 :
- 图表标题、轴标签、字体大小等优化显示效果。
- X 轴标签旋转 90 度,Y 轴标签保持水平,增强可读性。
- 显示:在 Jupyter Notebook 中直接展示热力图。
7. 加载同质异质分离数据(可选)
功能
尝试加载同质/异质因子分离数据(可能是预处理的分组信息),为后续分析提供参考。
代码解析
python
homo_hetero_path = "/public/data/factor_data/ETHUSDT_15m_2020_2025_homo_heter_split.joblib"
if os.path.exists(homo_hetero_path):
try:
homo_hetero_data = joblib.load(homo_hetero_path)
print("\n加载同质异质分离数据成功!")
except Exception as e:
print(f"\n加载同质异质分离数据时出错: {str(e)}")
else:
print("\n未找到同质异质分离数据文件")
- 文件检查 :使用
os.path.exists
检查文件是否存在。 - 加载数据 :使用
joblib.load
加载.joblib
文件,存储同质/异质因子分组信息。 - 错误处理:捕获加载过程中的异常并打印错误信息。
- 备注 :代码中未进一步使用
homo_hetero_data
,可能是为后续分析预留的扩展点。
8. 程序完成提示
功能
输出程序完成的信息,提示分析任务结束。
代码解析
python
print("\n🎉 因子相关性分析完成!")
- 打印完成信息,带有 emoji 增加友好性。
总结
该程序是一个完整的因子分析工作流,主要功能包括:
- 数据准备:加载 K 线数据和旧因子数据,统一索引格式。
- 新因子计算:定义并计算 20 个新因子,涵盖波动率、成交量、技术指标等。
- 相关性分析:计算新因子与旧因子的相关系数,筛选高相关性因子对(阈值 0.7)。
- 结果展示 :
- 输出高相关性因子表格,包含相关系数、强度和方向。
- 绘制相关性热力图,直观展示新旧因子之间的关系。
- 扩展性:支持加载同质/异质因子数据,为进一步分析预留空间。
逻辑清晰性
- 模块化设计:因子计算函数独立定义,便于维护和扩展。
- 错误处理:对数据长度、计算错误等进行了健壮处理。
- 可视化:通过表格和热力图直观展示分析结果,方便用户理解。
完整代码
python
# %%
import sys
sys.path.append('/public/src')
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import joblib
import os
import warnings
warnings.filterwarnings('ignore')
from factor_evaluation_server import FactorEvaluation, DataService
# 初始化数据服务
ds = DataService()
df = ds['ETHUSDT_15m_2020_2025']['2021-10-01':]
# 读取已有因子数据
factor_path = "/public/data/factor_data/ETHUSDT_15m_2020_2025_factor_data.pkl"
factors = pd.read_pickle(factor_path)
# 检查是否有时间索引,如果有则重置为整数索引
if isinstance(factors.index, pd.DatetimeIndex) or factors.index.dtype == 'datetime64[ns]':
factors = factors.reset_index(drop=True)
# 保存原始列名(旧因子)
original_columns = factors.columns.tolist()
print(f"加载了 {len(original_columns)} 个旧因子")
python
# 定义所有新因子计算函数
def calculate_filter_001_1(df):
'''衡量当前波动率高低的过滤器'''
log_ratio = np.log(df['close'] / df['close'].shift(1))
hv = log_ratio.rolling(20).std()
return hv
def calculate_filter_001_2(df):
'''ATR过滤器'''
high_low = df['high'] - df['low']
high_close = abs(df['high'] - df['close'].shift(1))
low_close = abs(df['low'] - df['close'].shift(1))
true_range = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
atr = true_range.rolling(14).mean()
return atr
def calculate_filter_001_3_keltner_channels(df, ema_period=20, atr_period=10, multiplier=2):
'''凯尔特纳通道:基于ATR的波动通道'''
ema = df['close'].ewm(span=ema_period, adjust=False).mean()
# 计算ATR
high_low = df['high'] - df['low']
high_close = abs(df['high'] - df['close'].shift(1))
low_close = abs(df['low'] - df['close'].shift(1))
true_range = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
atr = true_range.ewm(span=atr_period, adjust=False).mean()
# 计算通道宽度
channel_width = multiplier * atr
return (df['close'] - (ema - channel_width)) / (2 * channel_width)
def calculate_filter_002_1(df):
'''衡量当前成交量高低的过滤器'''
volume_mean = df['volume'].rolling(20).mean()
volume_deviation = (df['volume'] - volume_mean) / volume_mean
return volume_deviation
def calculate_filter_002_2_obv(df):
'''能量潮指标:累积成交量平衡'''
obv = (np.sign(df['close'].diff()) * df['volume'])
obv = obv.cumsum()
return obv
def calculate_filter_002_3_vwap(df):
'''成交量加权平均价'''
typical_price = (df['high'] + df['low'] + df['close']) / 3
vwap = (typical_price * df['volume']).cumsum() / df['volume'].cumsum()
return vwap
def calculate_filter_003(df):
'''衡量当前相对位置高低的过滤器'''
up = df['high'].rolling(20).max()
down = df['low'].rolling(20).min()
price_position = (df['close'] - down) / (up - down)
return price_position
def calculate_filter_004(df):
'''衡量短期价格波动快慢的过滤器'''
std_5 = df['close'].rolling(5).std()
std_30 = df['close'].rolling(30).std()
price_fluctuation = std_5 / std_30
return price_fluctuation
def calculate_filter_005(df):
'''
衡量买卖压力的比例的过滤器
'''
imbalance = 2 * df['taker_buy_volume'] / df['volume'] - 1
return imbalance
def calculate_filter_006(df):
'''
衡量平均交易量的过滤器
'''
average_trade_size = df['volume'] / df['trade_count']
return average_trade_size
def calculate_filter_007(df):
'''
衡量 candle body 相对价格范围的大小的过滤器
'''
body = abs(df['close'] - df['open'])
range_ = df['high'] - df['low']
body_to_range_ratio = body / range_
return body_to_range_ratio
def calculate_filter_008(df):
'''衡量上影线相对价格范围的大小的过滤器'''
upper_wick = df['high'] - df[['open', 'close']].max(axis=1)
range_ = df['high'] - df['low']
upper_wick_ratio = upper_wick / range_
return upper_wick_ratio
def calculate_filter_009(df):
'''衡量下影线相对价格范围的大小的过滤器'''
lower_wick = df[['open', 'close']].min(axis=1) - df['low']
range_ = df['high'] - df['low']
lower_wick_ratio = lower_wick / range_
return lower_wick_ratio
def calculate_filter_010_1(df, period=14):
'''
衡量RSI的过滤器
'''
delta = df['close'].diff()
gain = delta.where(delta > 0, 0)
loss = -delta.where(delta < 0, 0)
avg_gain = gain.ewm(alpha=1/period, adjust=False).mean()
avg_loss = loss.ewm(alpha=1/period, adjust=False).mean()
rs = avg_gain / avg_loss.replace([np.inf, -np.inf], np.nan).fillna(0)
rsi = 100 - (100 / (1 + rs))
return rsi
def calculate_filter_010_2_mfi(df, period=14):
'''资金流量指数:结合价格和成交量的RSI变体'''
typical_price = (df['high'] + df['low'] + df['close']) / 3
raw_money_flow = typical_price * df['volume']
money_flow_direction = np.where(typical_price > typical_price.shift(1), 1, -1)
positive_flow = raw_money_flow.where(money_flow_direction > 0, 0)
negative_flow = raw_money_flow.where(money_flow_direction < 0, 0)
money_ratio = positive_flow.rolling(period).sum() / negative_flow.rolling(period).sum()
money_ratio = money_ratio.replace([np.inf, -np.inf], np.nan).fillna(1)
mfi = 100 - (100 / (1 + money_ratio))
return mfi
def calculate_filter_011(df, short_period=12, long_period=26):
'''
衡量MACD的过滤器
'''
short_ema = df['close'].ewm(span=short_period, adjust=False).mean()
long_ema = df['close'].ewm(span=long_period, adjust=False).mean()
macd = short_ema - long_ema
return macd
def calculate_filter_012_aroon_up(df, period=14):
'''阿隆上升指标:衡量价格创新高的能力'''
high_idx = df['high'].rolling(period).apply(lambda x: x.argmax(), raw=True)
aroon_up = 100 * (period - high_idx) / period
return aroon_up
def calculate_filter_013_aroon_down(df, period=14):
'''阿隆下降指标:衡量价格创新低的能力'''
low_idx = df['low'].rolling(period).apply(lambda x: x.argmin(), raw=True)
aroon_down = 100 * (period - low_idx) / period
return aroon_down
def calculate_filter_014_aroon_oscillator(df, period=14):
'''阿隆震荡器:衡量趋势强度'''
aroon_up = calculate_filter_012_aroon_up(df, period)
aroon_down = calculate_filter_013_aroon_down(df, period)
return aroon_up - aroon_down
def calculate_filter_015_chaikin_money_flow(df, period=20):
'''
Chaikin资金流(CMF)
'''
money_flow_multiplier = ((df['close'] - df['low']) - (df['high'] - df['close'])) / (df['high'] - df['low'])
money_flow_multiplier = money_flow_multiplier.replace([np.inf, -np.inf], 0).fillna(0)
money_flow_volume = money_flow_multiplier * df['volume']
cmf = money_flow_volume.rolling(period).sum() / df['volume'].rolling(period).sum()
return cmf
def calculate_filter_020_volume_price_trend(df):
'''量价趋势指标:结合价格变动和成交量'''
price_change = df['close'].pct_change()
vpt = (price_change * df['volume']).cumsum()
return vpt
python
# 所有新因子计算函数的映射
factor_calculators = {
'filter_001_1': calculate_filter_001_1,
'filter_001_2': calculate_filter_001_2,
'filter_001_3': lambda df: calculate_filter_001_3_keltner_channels(df),
'filter_002_1': calculate_filter_002_1,
'filter_002_2': calculate_filter_002_2_obv,
'filter_002_3': calculate_filter_002_3_vwap,
'filter_003': calculate_filter_003,
'filter_004': calculate_filter_004,
'filter_005': calculate_filter_005,
'filter_006': calculate_filter_006,
'filter_007': calculate_filter_007,
'filter_008': calculate_filter_008,
'filter_009': calculate_filter_009,
'filter_010_1': lambda df: calculate_filter_010_1(df, 14),
'filter_010_2': lambda df: calculate_filter_010_2_mfi(df, 14),
'filter_011': lambda df: calculate_filter_011(df, 12, 26),
'filter_012': lambda df: calculate_filter_012_aroon_up(df, 14),
'filter_013': lambda df: calculate_filter_013_aroon_down(df, 14),
'filter_014': lambda df: calculate_filter_014_aroon_oscillator(df, 14),
'filter_015': lambda df: calculate_filter_015_chaikin_money_flow(df, 20),
'filter_020': calculate_filter_020_volume_price_trend
}
python
# 计算并添加所有新因子
print("\n开始计算新因子...")
for factor_name, calculator in tqdm(factor_calculators.items(), desc="计算因子"):
try:
# 计算因子值并转换为pandas Series
factor_values = calculator(df)
if isinstance(factor_values, pd.Series):
# 确保长度与factors一致
if len(factor_values) != len(factors):
# 截断或填充以匹配长度
factor_values = factor_values.iloc[:len(factors)].reset_index(drop=True)
factors[factor_name] = factor_values.values
else:
# 处理numpy数组
if len(factor_values) != len(factors):
factor_values = factor_values[:len(factors)]
factors[factor_name] = factor_values
except Exception as e:
print(f"计算因子 {factor_name} 时出错: {str(e)}")
# 添加空列以保持数据结构
factors[factor_name] = np.nan
# 确保索引对齐
factors = factors.reset_index(drop=True)
python
# 计算相关系数矩阵
print("\n计算相关系数矩阵...")
corr_matrix = factors.corr()
# 提取新因子与旧因子的相关系数
new_factors = list(factor_calculators.keys())
old_factors = original_columns
new_to_old_corr = corr_matrix.loc[new_factors, old_factors]
python
# 设置相关性阈值
threshold = 0.7
high_corr_records = [] # 用于存储高相关性记录
print(f"\n高相关性因子分析 (新因子与旧因子, |corr| > {threshold}):")
for new_factor in new_factors:
# 只考虑与旧因子的相关性
corr_series = new_to_old_corr.loc[new_factor]
# 找出相关性绝对值超过阈值的因子
high_corr = corr_series[abs(corr_series) > threshold]
if not high_corr.empty:
print(f"\n🔍 {new_factor} 与以下旧因子有高相关性:")
for factor, corr_value in high_corr.items():
# 判断相关性强弱和方向
strength = "强" if abs(corr_value) > 0.8 else "中等"
direction = "正" if corr_value > 0 else "负"
print(f" - {factor}: {corr_value:.4f} ({strength}{direction}相关)")
high_corr_records.append({
'新因子': new_factor,
'旧因子': factor,
'相关系数': corr_value,
'相关强度': strength,
'相关方向': direction
})
python
# 在Notebook中直接显示高相关性因子表格
if high_corr_records:
high_corr_df = pd.DataFrame(high_corr_records)
# 美化显示表格
print("\n📊 高相关性因子汇总表 (新因子 vs 旧因子):")
display(high_corr_df[['新因子', '旧因子', '相关系数', '相关强度', '相关方向']]
.sort_values(by='相关系数', key=abs, ascending=False)
.style
.background_gradient(cmap='coolwarm', subset=['相关系数'])
.format({'相关系数': "{:.4f}"})
.set_caption(f"高相关性因子汇总 (|corr| > {threshold})"))
else:
print(f"\n✅ 未发现高相关性新因子与旧因子 (|corr| > {threshold})")
python
# 可视化:新因子与旧因子的相关性热力图
plt.figure(figsize=(18, 12))
sns.heatmap(
new_to_old_corr,
annot=True,
fmt=".2f",
cmap='coolwarm',
center=0,
vmin=-1,
vmax=1,
linewidths=0.5
)
plt.title('新因子与旧因子的相关性热力图', fontsize=16, pad=20)
plt.xlabel('旧因子', fontsize=12)
plt.ylabel('新因子', fontsize=12)
plt.xticks(rotation=90, fontsize=8)
plt.yticks(rotation=0, fontsize=8)
# 添加颜色条说明
cbar = plt.gcf().axes[-1]
cbar.set_ylabel('相关系数', rotation=270, labelpad=20)
# 直接在Notebook中显示图表
plt.tight_layout()
plt.show()
python
# 加载同质异质分离数据(如果存在)
homo_hetero_path = "/public/data/factor_data/ETHUSDT_15m_2020_2025_homo_heter_split.joblib"
if os.path.exists(homo_hetero_path):
try:
homo_hetero_data = joblib.load(homo_hetero_path)
print("\n加载同质异质分离数据成功!")
# 这里可以添加对同质异质数据的分析代码
except Exception as e:
print(f"\n加载同质异质分离数据时出错: {str(e)}")
else:
print("\n未找到同质异质分离数据文件")
print("\n🎉 因子相关性分析完成!")