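Financial Data: A Scikit-Learn Decision Tree (DecisionTreeClassifier) Example
Logistic regression: Logistic regression is commonly used for binary classification problems such as predicting whether a price rises or falls. You can label rises and falls as classes and then train a logistic regression model on them.
Decision trees and random forests: Both are powerful models for classification problems. They can handle non-linear relationships and give a reasonably interpretable measure of feature importance.
Example Data
This example uses data for "Hubei Yihua (湖北宜化, 000422)" from 2015-08-06 to 2015-12-31.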
HBYH_000422_20150806_20151231.csv
csv
Date,Code,Open,High,Low,Close,Pre_Close,Change,Turnover_Rate,Volume,MA5,MA10
2015-12-31,'000422,7.93,7.95,7.76,7.77,7.93,-0.020177,0.015498,13915200,7.86,7.85
2015-12-30,'000422,7.86,7.93,7.75,7.93,7.84,0.011480,0.018662,16755900,7.90,7.85
2015-12-29,'000422,7.72,7.85,7.69,7.84,7.71,0.016861,0.015886,14263800,7.90,7.81
2015-12-28,'000422,8.03,8.08,7.70,7.71,8.03,-0.039851,0.030821,27672800,7.91,7.78
2015-12-25,'000422,8.03,8.05,7.93,8.03,7.99,0.005006,0.021132,18974000,7.93,7.78
2015-12-24,'000422,7.93,8.16,7.87,7.99,7.92,0.008838,0.026487,23781900,7.85,7.72
2015-12-23,'000422,7.97,8.11,7.88,7.92,7.89,0.003802,0.042360,38033600,7.80,7.69
2015-12-22,'000422,7.86,7.93,7.76,7.89,7.83,0.007663,0.026929,24178700,7.73,7.68
2015-12-21,'000422,7.59,7.89,7.56,7.83,7.63,0.026212,0.030777,27633600,7.66,7.67
2015-12-18,'000422,7.71,7.74,7.57,7.63,7.74,-0.014212,0.024764,22234900,7.62,7.71
2015-12-17,'000422,7.58,7.75,7.57,7.74,7.55,0.025166,0.028054,25188400,7.59,7.77
2015-12-16,'000422,7.57,7.62,7.53,7.55,7.55,0.000000,0.020718,18601600,7.58,7.79
2015-12-15,'000422,7.63,7.66,7.52,7.55,7.62,-0.009186,0.025902,23256600,7.64,7.78
2015-12-14,'000422,7.40,7.64,7.36,7.62,7.51,0.014647,0.021005,18860100,7.68,7.76
2015-12-11,'000422,7.65,7.70,7.41,7.51,7.67,-0.020860,0.020477,18385900,7.80,7.73
2015-12-10,'000422,7.78,7.87,7.65,7.67,7.83,-0.020434,0.019972,17931900,7.95,7.69
2015-12-09,'000422,7.76,8.00,7.75,7.83,7.77,0.007722,0.025137,22569700,8.00,7.68
2015-12-08,'000422,8.08,8.18,7.76,7.77,8.24,-0.057039,0.036696,32948200,7.92,7.66
2015-12-07,'000422,8.12,8.39,7.94,8.24,8.23,0.001215,0.064590,57993100,7.84,7.64
2015-12-04,'000422,7.85,8.48,7.80,8.23,7.92,0.039141,0.100106,89881900,7.65,7.58
2015-12-03,'000422,7.42,8.09,7.38,7.92,7.43,0.065949,0.045416,40777500,7.43,7.52
2015-12-02,'000422,7.35,7.48,7.20,7.43,7.36,0.009511,0.015968,14337600,7.37,7.49
2015-12-01,'000422,7.28,7.39,7.23,7.36,7.33,0.004093,0.012308,11050700,7.41,7.48
2015-11-30,'000422,7.18,7.36,6.95,7.33,7.11,0.030942,0.020323,18247500,7.45,7.50
2015-11-27,'000422,7.59,7.59,6.95,7.11,7.60,-0.064474,0.027673,24846700,7.51,7.52
2015-11-26,'000422,7.63,7.73,7.58,7.60,7.63,-0.003932,0.024836,22299800,7.61,7.54
2015-11-25,'000422,7.56,7.64,7.51,7.63,7.59,0.005270,0.020919,18782900,7.61,7.54
2015-11-24,'000422,7.60,7.63,7.48,7.59,7.62,-0.003937,0.014867,13348200,7.56,7.53
2015-11-23,'000422,7.59,7.72,7.55,7.62,7.61,0.001314,0.028406,25505000,7.54,7.53
2015-11-20,'000422,7.59,7.71,7.53,7.61,7.59,0.002635,0.028277,25389100,7.52,7.53
2015-11-19,'000422,7.45,7.62,7.41,7.59,7.39,0.027064,0.038638,34691700,7.47,7.52
2015-11-18,'000422,7.53,7.54,7.38,7.39,7.51,-0.015979,0.014173,12725000,7.46,7.50
2015-11-17,'000422,7.53,7.63,7.44,7.51,7.50,0.001333,0.028640,25714500,7.51,7.50
2015-11-16,'000422,7.27,7.52,7.24,7.50,7.38,0.016260,0.016230,14572000,7.52,7.46
2015-11-13,'000422,7.49,7.55,7.36,7.38,7.54,-0.021220,0.029196,26214400,7.53,7.41
2015-11-12,'000422,7.65,7.68,7.49,7.54,7.61,-0.009198,0.026501,23794800,7.56,7.40
2015-11-11,'000422,7.57,7.64,7.52,7.61,7.57,0.005284,0.026113,23445900,7.54,7.37
2015-11-10,'000422,7.51,7.61,7.45,7.57,7.55,0.002649,0.024979,22427700,7.49,7.32
2015-11-09,'000422,7.51,7.62,7.45,7.55,7.53,0.002656,0.033367,29959500,7.39,7.31
2015-11-06,'000422,7.47,7.53,7.37,7.53,7.45,0.010738,0.037058,33273100,7.29,7.27
2015-11-05,'000422,7.34,7.54,7.32,7.45,7.37,0.010855,0.040463,36330200,7.24,7.24
2015-11-04,'000422,7.10,7.38,7.07,7.37,7.05,0.045390,0.034817,31260800,7.20,7.17
2015-11-03,'000422,7.08,7.13,7.02,7.05,7.06,-0.001416,0.014938,13412400,7.15,7.10
2015-11-02,'000422,7.11,7.26,7.05,7.06,7.26,-0.027548,0.016865,15142100,7.23,7.10
2015-10-30,'000422,7.22,7.38,7.10,7.26,7.24,0.002762,0.022821,20490200,7.25,7.10
2015-10-29,'000422,7.27,7.33,7.16,7.24,7.16,0.011173,0.025726,23098500,7.23,7.08
2015-10-28,'000422,7.32,7.40,7.09,7.16,7.42,-0.035040,0.035572,31938500,7.15,7.05
2015-10-27,'000422,7.21,7.48,7.08,7.42,7.18,0.033426,0.057658,51769300,7.04,7.01
2015-10-26,'000422,7.20,7.25,7.01,7.18,7.17,0.001395,0.036840,33077800,6.98,6.96
2015-10-23,'000422,6.84,7.22,6.81,7.17,6.80,0.054412,0.047169,42351500,6.95,6.93
2015-10-22,'000422,6.68,6.81,6.64,6.80,6.65,0.022556,0.020609,18503800,6.93,6.87
2015-10-21,'000422,7.08,7.11,6.61,6.65,7.09,-0.062059,0.039388,35365300,6.96,6.85
2015-10-20,'000422,7.00,7.09,6.94,7.09,7.03,0.008535,0.024472,21972900,6.98,6.81
2015-10-19,'000422,7.09,7.13,6.92,7.03,7.08,-0.007062,0.031262,28068800,6.94,6.72
2015-10-16,'000422,6.97,7.08,6.91,7.08,6.93,0.021645,0.039632,35584700,6.91,6.66
2015-10-15,'000422,6.77,6.94,6.75,6.93,6.77,0.023634,0.031645,28412700,6.82,6.59
2015-10-14,'000422,6.87,6.94,6.74,6.77,6.89,-0.017417,0.027226,24445500,6.74,6.55
2015-10-13,'000422,6.86,6.96,6.80,6.89,6.88,0.001453,0.028704,25771900,6.64,6.51
2015-10-12,'000422,6.62,6.91,6.58,6.88,6.61,0.040847,0.037037,33254300,6.50,6.49
2015-10-09,'000422,6.54,6.65,6.45,6.61,6.54,0.010703,0.018528,16635900,6.41,6.46
2015-10-08,'000422,6.45,6.70,6.37,6.54,6.26,0.044728,0.018857,16931000,6.35,6.44
2015-09-30,'000422,6.25,6.30,6.22,6.26,6.23,0.004815,0.007327,6579090,6.35,6.43
2015-09-29,'000422,6.30,6.32,6.18,6.23,6.40,-0.026562,0.008991,8072900,6.39,6.48
2015-09-28,'000422,6.35,6.42,6.25,6.40,6.34,0.009464,0.008824,7922890,6.48,6.47
2015-09-25,'000422,6.51,6.56,6.25,6.34,6.53,-0.029096,0.012584,11298800,6.51,6.45
2015-09-24,'000422,6.48,6.56,6.45,6.53,6.45,0.012403,0.011339,10180900,6.53,6.51
2015-09-23,'000422,6.51,6.60,6.41,6.45,6.67,-0.032984,0.015920,14294100,6.52,6.54
2015-09-22,'000422,6.58,6.73,6.54,6.67,6.58,0.013678,0.023356,20970200,6.56,6.60
2015-09-21,'000422,6.34,6.61,6.29,6.58,6.44,0.021739,0.017036,15295900,6.46,6.62
2015-09-18,'000422,6.52,6.58,6.30,6.44,6.44,0.000000,0.016622,14924700,6.39,6.62
2015-09-17,'000422,6.59,6.76,6.43,6.44,6.68,-0.035928,0.019517,17523900,6.48,6.62
2015-09-16,'000422,6.21,6.76,6.17,6.68,6.15,0.086179,0.019671,17662300,6.56,6.65
2015-09-15,'000422,6.24,6.38,6.05,6.15,6.26,-0.017572,0.015338,13771200,6.64,6.66
2015-09-14,'000422,6.89,6.95,6.18,6.26,6.87,-0.088792,0.021233,18559600,6.78,6.75
2015-09-11,'000422,6.87,6.96,6.77,6.87,6.84,0.004386,0.010853,9486290,6.85,6.79
2015-09-10,'000422,6.95,7.01,6.76,6.84,7.06,-0.031161,0.017423,15229100,6.76,6.74
2015-09-09,'000422,6.90,7.09,6.86,7.06,6.88,0.026163,0.028974,25325600,6.74,6.68
2015-09-08,'000422,6.65,6.91,6.55,6.88,6.62,0.039275,0.017858,15609100,6.69,6.67
2015-09-07,'000422,6.50,6.81,6.50,6.62,6.38,0.037618,0.017850,15602600,6.72,6.75
2015-09-02,'000422,6.45,6.88,6.30,6.38,6.74,-0.053412,0.022286,19480100,6.73,6.91
2015-09-01,'000422,6.88,6.99,6.67,6.74,6.81,-0.010279,0.025829,22576700,6.72,7.12
2015-08-31,'000422,6.90,6.97,6.71,6.81,7.07,-0.036775,0.018385,16069600,6.62,7.24
2015-08-28,'000422,6.75,7.08,6.71,7.07,6.67,0.059970,0.026692,23330800,6.65,7.44
2015-08-27,'000422,6.53,6.67,6.34,6.67,6.32,0.055380,0.022455,19627900,6.78,7.59
2015-08-26,'000422,6.31,6.77,6.09,6.32,6.25,0.011200,0.029963,26190200,7.08,7.76
2015-08-25,'000422,6.40,6.77,6.25,6.25,6.94,-0.099424,0.029492,25778600,7.52,7.96
2015-08-24,'000422,7.49,7.49,6.94,6.94,7.71,-0.099870,0.036552,31949900,7.86,8.18
2015-08-21,'000422,8.00,8.11,7.60,7.71,8.17,-0.056304,0.032199,28144800,8.23,8.33
2015-08-20,'000422,8.38,8.56,8.14,8.17,8.53,-0.042204,0.031764,27764200,8.40,8.38
2015-08-19,'000422,7.73,8.57,7.72,8.53,7.96,0.071608,0.052192,45619900,8.45,8.37
2015-08-18,'000422,8.81,8.86,7.92,7.96,8.80,-0.095455,0.056179,49105500,8.39,8.32
2015-08-17,'000422,8.49,8.83,8.42,8.80,8.52,0.032864,0.048161,42096900,8.50,8.35
2015-08-14,'000422,8.48,8.65,8.43,8.52,8.44,0.009479,0.041169,35985000,8.43,8.24
2015-08-13,'000422,8.20,8.45,8.15,8.44,8.24,0.024272,0.029768,26019600,8.37,8.16
2015-08-12,'000422,8.38,8.48,8.21,8.24,8.48,-0.028302,0.035421,30960700,8.30,8.08
2015-08-11,'000422,8.41,8.68,8.32,8.48,8.49,-0.001178,0.048444,42343900,8.26,8.03
2015-08-10,'000422,8.28,8.58,8.18,8.49,8.21,0.034105,0.041268,36071600,8.20,7.92
2015-08-07,'000422,8.15,8.28,8.08,8.21,8.07,0.017348,0.025855,22599800,8.05,7.81
2015-08-06,'000422,7.88,8.21,7.80,8.07,8.03,0.004981,0.020074,17546700,7.95,7.80
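Approach
This is only a simple example meant to show how the decision tree classifier in Scikit-Learn is used. It is not investment advice.
Target:
Using the same day's values, predict whether the stock closes up or down on that day. If "Change >= 0", the label is 1; if "Change < 0", the label is 0 (a binary classification label).
Variables:
- The day's high price
- The day's low price
- The day's turnover rate
- The day's trading volume
- The day of the week (to capture any weekday effect on price)
- The relationship between the day's short-term moving average (MA5) and long-term moving average (MA10): 1 if "MA5 > MA10", 0 if "MA5 = MA10", and -1 if "MA5 < MA10".
- Festivals (the effect of festivals on A-shares; the A-share market is closed on Chinese holidays, so only foreign festivals can be explored, and this example uses only "Christmas" and "Christmas Eve").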
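Importing Pandas and Related Modules
Pandas is a data analysis library built on top of NumPy. It provides standard data structures and the tools needed to work efficiently with large datasets.
Pandas offers a large number of functions and methods for processing data quickly and conveniently, and it is one of the main reasons Python is a powerful and efficient environment for data analysis.
python
import datetime  # Used later to compare dates when marking festivals.
import pandas as pd
Importing Scikit-Learn Modules
Scikit-Learn (formerly scikits.learn, also known as sklearn) is a free software machine learning library for the Python programming language.
It features various classification, regression, and clustering algorithms, including support vector machines, random forests, gradient boosting, k-means, and DBSCAN, and it is designed to interoperate with the Python numerical and scientific libraries NumPy and SciPy.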
python
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import StandardScaler
Reading CSV Data with Pandas
Call the Pandas read_csv function to read the CSV data:
The header parameter specifies which row of the CSV file holds the column names. header=0 means the column names are on the first row. With header=None the file is treated as having no header row, and Pandas assigns integer column labels automatically.
The sep parameter specifies the delimiter of the CSV file. The default is ",", so sep="," here simply makes the delimiter explicit.
The dtype parameter makes specific columns be read with specific data types. For example, dtype={"Close":float, "Volume":int} reads the "Close" column as float and the "Volume" column as integer.
python
PDF = pd.read_csv("D:\\HBYH_000422_20150806_20151231.csv", header=0, sep=",")
Print the DataFrame:
python
print("[Message] Read CSV File: D:\\HBYH_000422_20150806_20151231.csv")
print(PDF)
Output:
txt
[Message] Read CSV File: D:\HBYH_000422_20150806_20151231.csv
Date Code Open High Low Close Pre_Close Change Turnover_Rate Volume MA5 MA10
0 2015-12-31 '000422 7.93 7.95 7.76 7.77 7.93 -0.020177 0.015498 13915200 7.86 7.85
1 2015-12-30 '000422 7.86 7.93 7.75 7.93 7.84 0.011480 0.018662 16755900 7.90 7.85
2 2015-12-29 '000422 7.72 7.85 7.69 7.84 7.71 0.016861 0.015886 14263800 7.90 7.81
3 2015-12-28 '000422 8.03 8.08 7.70 7.71 8.03 -0.039851 0.030821 27672800 7.91 7.78
4 2015-12-25 '000422 8.03 8.05 7.93 8.03 7.99 0.005006 0.021132 18974000 7.93 7.78
.. ... ... ... ... ... ... ... ... ... ... ... ...
94 2015-08-12 '000422 8.38 8.48 8.21 8.24 8.48 -0.028302 0.035421 30960700 8.30 8.08
95 2015-08-11 '000422 8.41 8.68 8.32 8.48 8.49 -0.001178 0.048444 42343900 8.26 8.03
96 2015-08-10 '000422 8.28 8.58 8.18 8.49 8.21 0.034105 0.041268 36071600 8.20 7.92
97 2015-08-07 '000422 8.15 8.28 8.08 8.21 8.07 0.017348 0.025855 22599800 8.05 7.81
98 2015-08-06 '000422 7.88 8.21 7.80 8.07 8.03 0.004981 0.020074 17546700 7.95 7.80
[99 rows x 12 columns]
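As a minimal sketch of how header, sep, and dtype interact, the snippet below reads two rows copied from the sample file out of an in-memory string instead of the file on disk. The names csv_text and sample, the use of io.StringIO, the parse_dates argument, and the stripping of the leading apostrophe in the Code column are illustrative assumptions, not part of the original workflow.

```python
import io

import pandas as pd

# Two rows copied from the sample file; StringIO stands in for the file on disk.
csv_text = (
    "Date,Code,Open,High,Low,Close,Pre_Close,Change,Turnover_Rate,Volume,MA5,MA10\n"
    "2015-12-31,'000422,7.93,7.95,7.76,7.77,7.93,-0.020177,0.015498,13915200,7.86,7.85\n"
    "2015-12-30,'000422,7.86,7.93,7.75,7.93,7.84,0.011480,0.018662,16755900,7.90,7.85\n"
)

sample = pd.read_csv(
    io.StringIO(csv_text),
    header=0,                               # column names are on the first row
    sep=",",                                # explicit delimiter (also the default)
    dtype={"Close": float, "Volume": int},  # force types for selected columns
    parse_dates=["Date"],                   # assumption: parse dates while reading
)

# The Code column carries a leading apostrophe in the file; strip it if a clean code is wanted.
sample["Code"] = sample["Code"].str.lstrip("'")

print(sample.dtypes)
```

Passing dtype at read time is one option; converting after reading, as the next section does, works just as well for a file this small.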
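Converting Column Data Types in the Pandas DataFrame
To avoid data type errors in later calculations, it is usually worth converting each column to an explicit data type.
python
# Convert the data types of the Pandas DataFrame columns.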
PDF["Date"] = PDF["Date"].astype("datetime64[ns]")
PDF["Open"] = PDF["Open"].astype("float64")
PDF["High"] = PDF["High"].astype("float64")
PDF["Low"] = PDF["Low"].astype("float64")
PDF["Close"] = PDF["Close"].astype("float64")
PDF["Pre_Close"] = PDF["Pre_Close"].astype("float64")
PDF["Change"] = PDF["Change"].astype("float64")
PDF["Turnover_Rate"] = PDF["Turnover_Rate"].astype("float64")
PDF["Volume"] = PDF["Volume"].astype("int64")
PDF["MA5"] = PDF["MA5"].astype("float64")
PDF["MA10"] = PDF["MA10"].astype("float64")
# Print the column names and data types of the Pandas DataFrame.
print("[Message] Changed Pandas DataFrame Data Type:")
print(PDF.dtypes)
Output:
txt
[Message] Changed Pandas DataFrame Data Type:
Date datetime64[ns]
Code object
Open float64
High float64
Low float64
Close float64
Pre_Close float64
Change float64
Turnover_Rate float64
Volume int64
MA5 float64
MA10 float64
dtype: object
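The column-by-column conversion above can also be written more compactly. The following is a sketch only, using a small hypothetical frame with string values: pd.to_datetime handles the date column and a single astype call with a mapping handles the numeric ones.

```python
import pandas as pd

# Hypothetical frame in which every value arrived as a string.
frame = pd.DataFrame({
    "Date": ["2015-12-31", "2015-12-30"],
    "Close": ["7.77", "7.93"],
    "Volume": ["13915200", "16755900"],
})

# Parse the date column with an explicit format.
frame["Date"] = pd.to_datetime(frame["Date"], format="%Y-%m-%d")

# Convert several columns in one call by passing a column -> dtype mapping.
frame = frame.astype({"Close": "float64", "Volume": "int64"})

print(frame.dtypes)
```

The result is the same as calling astype once per column; the mapping form just keeps the target types in one place.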
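Computing Data in the Pandas DataFrame
Write a function that judges the relationship between a stock's short-term and long-term moving averages:
python
def MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA:float, Long_MA:float) -> int:
    # Return 1 if the short-term MA is above the long-term MA,
    # 0 if the two are equal, and -1 if it is below.
    if (Short_MA > Long_MA): return 1
    if (Short_MA == Long_MA): return 0
    if (Short_MA < Long_MA): return -1
    # ==============================================
    # End of Function.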
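For comparison, the same three-way encoding (1, 0, -1) can be computed for a whole column at once. This is a sketch assuming NumPy is available; the frame ma is hypothetical and holds three MA5/MA10 pairs taken from the sample data (2015-12-31, 2015-11-05, and 2015-12-21).

```python
import numpy as np
import pandas as pd

# Three MA5/MA10 pairs from the sample data: above, equal, below.
ma = pd.DataFrame({
    "MA5":  [7.86, 7.24, 7.66],
    "MA10": [7.85, 7.24, 7.67],
})

# np.sign returns 1.0, 0.0, or -1.0 depending on the sign of the difference.
ma["MA_Relationship"] = np.sign(ma["MA5"] - ma["MA10"]).astype(int)

print(ma)
```

The row-wise function is easier to read and extend, while the vectorized form avoids calling a Python function once per row.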
Compute directly in the Pandas DataFrame, or call the custom function:
python
# Compute: extract the weekday index, from 0 to 6 (0 is Monday, 6 is Sunday).
PDF["Weekday(Idx)"] = PDF["Date"].apply(lambda X: X.weekday())
# ..................................................
# Compute: mark festivals (the A-share market is closed on Chinese holidays, so only foreign festivals can be explored; this example uses only "Christmas" and "Christmas Eve").
PDF["Festival"] = None
for Idx in PDF.index:
    if PDF.loc[Idx, "Date"] == datetime.datetime(2015,12,24): PDF.loc[Idx, "Festival"] = "Christmas_Eve" # -> Christmas Eve.
    if PDF.loc[Idx, "Date"] == datetime.datetime(2015,12,25): PDF.loc[Idx, "Festival"] = "Christmas" # -> Christmas.
# ..................................................
# Compute: label whether the stock rose or fell.
PDF["Rise_Fall"] = PDF["Change"].apply(lambda X: int(1) if X >= 0 else int(0))
# ..................................................
# Compute: call the function to judge the relationship between the short-term and long-term moving averages.
PDF["MA_Relationship"] = PDF.apply(lambda X: MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA=X["MA5"], Long_MA=X["MA10"]), axis=1)
# Print the computed DataFrame.
print("[Message] Calculated DataFrame:")
print(PDF)
Output:
txt
[Message] Calculated DataFrame:
Date Code Open High Low Close Pre_Close Change Turnover_Rate Volume MA5 MA10 Weekday(Idx) Festival Rise_Fall MA_Relationship
0 2015-12-31 '000422 7.93 7.95 7.76 7.77 7.93 -0.020177 0.015498 13915200 7.86 7.85 3 None 0 1
1 2015-12-30 '000422 7.86 7.93 7.75 7.93 7.84 0.011480 0.018662 16755900 7.90 7.85 2 None 1 1
2 2015-12-29 '000422 7.72 7.85 7.69 7.84 7.71 0.016861 0.015886 14263800 7.90 7.81 1 None 1 1
3 2015-12-28 '000422 8.03 8.08 7.70 7.71 8.03 -0.039851 0.030821 27672800 7.91 7.78 0 None 0 1
4 2015-12-25 '000422 8.03 8.05 7.93 8.03 7.99 0.005006 0.021132 18974000 7.93 7.78 4 Christmas 1 1
.. ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
94 2015-08-12 '000422 8.38 8.48 8.21 8.24 8.48 -0.028302 0.035421 30960700 8.30 8.08 2 None 0 1
95 2015-08-11 '000422 8.41 8.68 8.32 8.48 8.49 -0.001178 0.048444 42343900 8.26 8.03 1 None 0 1
96 2015-08-10 '000422 8.28 8.58 8.18 8.49 8.21 0.034105 0.041268 36071600 8.20 7.92 0 None 1 1
97 2015-08-07 '000422 8.15 8.28 8.08 8.21 8.07 0.017348 0.025855 22599800 8.05 7.81 4 None 1 1
98 2015-08-06 '000422 7.88 8.21 7.80 8.07 8.03 0.004981 0.020074 17546700 7.95 7.80 3 None 1 1
[99 rows x 16 columns]
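The weekday and festival columns can also be derived without an explicit loop. The sketch below is one possible alternative, not the approach used in this example; the frame days and the dictionary festival_by_date are hypothetical names.

```python
import pandas as pd

# Three consecutive trading days around Christmas 2015.
days = pd.DataFrame({"Date": pd.to_datetime(["2015-12-23", "2015-12-24", "2015-12-25"])})

# The .dt accessor exposes the weekday index (0 is Monday, 6 is Sunday) for the whole column.
days["Weekday(Idx)"] = days["Date"].dt.weekday

# Map known dates to festival names; dates not in the mapping become missing values.
festival_by_date = {
    pd.Timestamp("2015-12-24"): "Christmas_Eve",
    pd.Timestamp("2015-12-25"): "Christmas",
}
days["Festival"] = days["Date"].map(festival_by_date)

print(days)
```

Adding another festival then only requires one more entry in the mapping.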
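Converting String Feature Columns to Numeric Values in the Pandas DataFrame (One-Hot Encoding)
pd.get_dummies() is the Pandas function for one-hot encoding. It expands each distinct value of a categorical (discrete) variable into a new binary feature (0 or 1), which makes the data easier for machine learning models to handle.
python
# Function signature:
pd.get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False, columns=None, sparse=False, drop_first=False, dtype=None)
# Parameters:
# - data: the DataFrame or Series to one-hot encode.
# - prefix: the prefix for the generated one-hot columns.
# - prefix_sep: the separator between the prefix and the category value in the generated column names.
# - dummy_na: whether to generate a one-hot column for missing values in the original data.
# - columns: the names of the columns to encode; if given, only these columns are encoded.
# - sparse: whether the generated columns are backed by a sparse array.
# - drop_first: whether to drop the first one-hot column to avoid collinearity.
# - dtype: the data type of the generated columns.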
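To make the effect of dummy_na and drop_first concrete, here is a small sketch on a hypothetical three-row frame named toy. dtype=int is passed so the new columns hold 0/1 regardless of the Pandas version.

```python
import pandas as pd

# One missing value and two festival names.
toy = pd.DataFrame({"Festival": [None, "Christmas_Eve", "Christmas"]})

# Default behaviour: one column per distinct non-missing value.
default = pd.get_dummies(toy, columns=["Festival"], dtype=int)

# dummy_na=True adds an extra indicator column for missing values.
with_na = pd.get_dummies(toy, columns=["Festival"], dummy_na=True, dtype=int)

# drop_first=True drops the first category column.
dropped = pd.get_dummies(toy, columns=["Festival"], drop_first=True, dtype=int)

print(list(default.columns))
print(list(with_na.columns))
print(list(dropped.columns))
```

With the default settings a row whose festival is missing is encoded as all zeros, which is how ordinary trading days are represented in this example.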
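Convert the Festival feature column to numeric values:
python
# Convert the string feature column to numeric values (one-hot encoding).
# dtype=int keeps the new columns as 0/1; pandas 2.0 and later would otherwise produce booleans.
PDF = pd.get_dummies(PDF, columns=["Festival"], drop_first=False, dtype=int)
# Print the converted DataFrame.
print("[Message] DataFrame After One-Hot Encoding:")
print(PDF)
Output: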
txt
[Message] DataFrame After One-Hot Encoding:
Date Code Open High Low Close Pre_Close Change Turnover_Rate Volume MA5 MA10 Weekday(Idx) Rise_Fall MA_Relationship Festival_Christmas Festival_Christmas_Eve
0 2015-12-31 '000422 7.93 7.95 7.76 7.77 7.93 -0.020177 0.015498 13915200 7.86 7.85 3 0 1 0 0
1 2015-12-30 '000422 7.86 7.93 7.75 7.93 7.84 0.011480 0.018662 16755900 7.90 7.85 2 1 1 0 0
2 2015-12-29 '000422 7.72 7.85 7.69 7.84 7.71 0.016861 0.015886 14263800 7.90 7.81 1 1 1 0 0
3 2015-12-28 '000422 8.03 8.08 7.70 7.71 8.03 -0.039851 0.030821 27672800 7.91 7.78 0 0 1 0 0
4 2015-12-25 '000422 8.03 8.05 7.93 8.03 7.99 0.005006 0.021132 18974000 7.93 7.78 4 1 1 1 0
.. ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
94 2015-08-12 '000422 8.38 8.48 8.21 8.24 8.48 -0.028302 0.035421 30960700 8.30 8.08 2 0 1 0 0
95 2015-08-11 '000422 8.41 8.68 8.32 8.48 8.49 -0.001178 0.048444 42343900 8.26 8.03 1 0 1 0 0
96 2015-08-10 '000422 8.28 8.58 8.18 8.49 8.21 0.034105 0.041268 36071600 8.20 7.92 0 1 1 0 0
97 2015-08-07 '000422 8.15 8.28 8.08 8.21 8.07 0.017348 0.025855 22599800 8.05 7.81 4 1 1 0 0
98 2015-08-06 '000422 7.88 8.21 7.80 8.07 8.03 0.004981 0.020074 17546700 7.95 7.80 3 1 1 0 0
[99 rows x 17 columns]
Extracting the Label Column and the Feature Columns
Extract the label column:
python
# Extract the label column.
Y = PDF["Rise_Fall"]
Extract the feature columns:
python
# Extract the feature columns.
X = PDF.drop(["Date", "Code", "Open", "Close", "Pre_Close", "Change", "MA5", "MA10", "Rise_Fall"], axis=1)
Splitting into Training and Test Sets (train_test_split) and Feature Standardization (StandardScaler)
Split into training and test sets (train_test_split):
python
# Split the dataset into a training set and a test set (train_test_split).
X_Train, X_Test, Y_Train, Y_Test = train_test_split(X, Y, test_size=0.2, random_state=42)
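As a quick check of what test_size=0.2 and random_state=42 mean for a dataset of 99 rows, the sketch below splits placeholder arrays of the same length. The arrays features and labels are synthetic stand-ins, not the stock data.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Placeholder data with the same number of rows as the sample file.
features = np.arange(99 * 2).reshape(99, 2)
labels = np.arange(99) % 2

x_train, x_test, y_train, y_test = train_test_split(
    features, labels, test_size=0.2, random_state=42
)

# Splitting again with the same random_state reproduces the same partition.
x_train_again, x_test_again, _, _ = train_test_split(
    features, labels, test_size=0.2, random_state=42
)

print(x_train.shape, x_test.shape)
```

The test share is rounded up, so 99 rows give 20 test rows and 79 training rows, which matches the support total of 20 in the classification report of this example.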
Feature standardization (StandardScaler):
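In machine learning, fit_transform and transform are two common preprocessing methods with slightly different roles:
fit_transform: fits the transformer and transforms the data in one step.
- It computes the required transformation parameters (such as the mean and standard deviation) from the input data and then applies them to that data.
- During training, fit_transform is normally called on the training set.
- The fitting step computes and stores the parameters from the training set, and the training set is then transformed with them.
- This ensures that the test set or new data is later transformed with the same parameters.
transform: only transforms the data, without fitting.
- It applies the parameters obtained from an earlier fit (for example through fit_transform) to new data, so the new data is processed in the same way.
- At test time, or when applying the model to new data, transform is normally used.
In short, fit_transform fits the transformer and transforms the data, while transform only applies an already fitted transformer.
In practice, call fit_transform on the training set and transform on the test set or new data, so that preprocessing stays consistent and no information from the test set leaks into the fitted parameters.
python
# Feature standardization (StandardScaler).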
Obj_Scaler = StandardScaler()
X_Train_Scaled = Obj_Scaler.fit_transform(X_Train)
X_Test_Scaled = Obj_Scaler.transform(X_Test)
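The difference between the two calls is easiest to see on a tiny array. In this sketch the arrays train and test are hypothetical one-column datasets; the scaler learns its mean and standard deviation from train only and reuses them for test.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

train = np.array([[1.0], [2.0], [3.0]])
test = np.array([[4.0]])

scaler = StandardScaler()
train_scaled = scaler.fit_transform(train)  # learns mean and std from train, then scales train
test_scaled = scaler.transform(test)        # reuses the learned mean and std, no refitting

# fit_transform is equivalent to calling fit and then transform on the same data.
same_result = StandardScaler().fit(train).transform(train)

print(scaler.mean_, scaler.scale_)
print(test_scaled)
```

Here the learned mean is 2.0, so the test value 4.0 is scaled relative to the training data rather than to itself.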
Training the Decision Tree Classifier (DecisionTreeClassifier) Model
Create the decision tree classifier (DecisionTreeClassifier):
python
# Create the decision tree classifier (DecisionTreeClassifier).
DTC = DecisionTreeClassifier(random_state=42)
Train the decision tree classifier (DecisionTreeClassifier) model:
python
# Train the decision tree classifier (DecisionTreeClassifier) model.
DTC.fit(X_Train_Scaled, Y_Train)
# Return value:
# +----------------------------------------+
# |▼ DecisionTreeClassifier |
# +----------------------------------------+
# | DecisionTreeClassifier(random_state=42)|
# +----------------------------------------+
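Once a tree is fitted, its learned rules and feature importances can be inspected. The sketch below uses synthetic data rather than the stock data: the label depends only on the first of two random columns, so the tree should attribute essentially all importance to it. The names toy_x, toy_y, toy_tree, and the feature names "signal" and "noise" are hypothetical, and max_depth=2 is an arbitrary choice for readability.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

# Synthetic data: the label is determined by the sign of the first column only.
rng = np.random.default_rng(0)
toy_x = rng.normal(size=(60, 2))
toy_y = (toy_x[:, 0] > 0).astype(int)

toy_tree = DecisionTreeClassifier(max_depth=2, random_state=42)
toy_tree.fit(toy_x, toy_y)

# One importance value per feature; the values sum to 1.
print(toy_tree.feature_importances_)

# A plain-text rendering of the learned split rules.
print(export_text(toy_tree, feature_names=["signal", "noise"]))
```

The same two calls can be applied to the fitted DTC model in this example, passing the feature column names to export_text.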
Predicting with the Decision Tree Classifier (DecisionTreeClassifier) Model
python
# Predict on the test set.
Y_Pred = DTC.predict(X_Test_Scaled)
# Combine the prediction results with the test features.
Result = X_Test.copy()
Result["Actually"] = Y_Test
Result["Prediction"] = Y_Pred
print("[Message] Prediction Results on The Test Data Set for DecisionTreeClassifier:")
print(Result)
Output:
txt
[Message] Prediction Results on The Test Data Set for DecisionTreeClassifier:
High Low Turnover_Rate Volume Weekday(Idx) MA_Relationship Festival_Christmas Festival_Christmas_Eve Actually Prediction
62 6.32 6.18 0.008991 8072900 1 -1 0 0 0 1
40 7.54 7.32 0.040463 36330200 3 1 0 0 1 1
95 8.68 8.32 0.048444 42343900 1 1 0 0 0 1
18 8.39 7.94 0.064590 57993100 0 1 0 0 1 1
97 8.28 8.08 0.025855 22599800 4 1 0 0 1 1
84 6.77 6.09 0.029963 26190200 2 -1 0 0 1 0
64 6.56 6.25 0.012584 11298800 4 1 0 0 0 1
42 7.13 7.02 0.014938 13412400 1 1 0 0 0 0
10 7.75 7.57 0.028054 25188400 3 -1 0 0 1 0
0 7.95 7.76 0.015498 13915200 3 1 0 0 0 1
31 7.54 7.38 0.014173 12725000 2 -1 0 0 0 0
76 7.09 6.86 0.028974 25325600 2 1 0 0 1 1
47 7.48 7.08 0.057658 51769300 1 1 0 0 1 1
26 7.64 7.51 0.020919 18782900 2 1 0 0 1 1
44 7.38 7.10 0.022821 20490200 4 1 0 0 1 0
4 8.05 7.93 0.021132 18974000 4 1 1 0 1 1
22 7.39 7.23 0.012308 11050700 1 -1 0 0 1 1
12 7.66 7.52 0.025902 23256600 1 -1 0 0 0 1
88 8.56 8.14 0.031764 27764200 3 1 0 0 0 1
73 6.95 6.18 0.021233 18559600 0 1 0 0 0 1
Evaluating Model Performance with accuracy_score
python
# Evaluate model performance.
Accuracy = accuracy_score(Y_Test, Y_Pred)
print("Accuracy:", Accuracy)
print("\n")
# Print the classification report.
print("Classification Report:")
print(classification_report(Y_Test, Y_Pred))
Output:
txt
Accuracy: 0.5


Classification Report:
              precision    recall  f1-score   support

           0       0.40      0.22      0.29         9
           1       0.53      0.73      0.62        11

    accuracy                           0.50        20
   macro avg       0.47      0.47      0.45        20
weighted avg       0.47      0.50      0.47        20
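The numbers in the report can be traced back to a confusion matrix. The sketch below copies the "Actually" and "Prediction" columns from the test-set table printed earlier, in row order, and recomputes accuracy plus the precision and recall of class 1 by hand. The variable names are illustrative.

```python
from sklearn.metrics import accuracy_score, confusion_matrix

# "Actually" and "Prediction" columns copied from the printed test-set table, in row order.
y_true = [0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0]
y_pred = [1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1]

# Rows are true classes, columns are predicted classes: [[TN, FP], [FN, TP]].
matrix = confusion_matrix(y_true, y_pred)
tn, fp, fn, tp = matrix.ravel()

accuracy = accuracy_score(y_true, y_pred)
precision_1 = tp / (tp + fp)  # of all "rise" predictions, how many were right
recall_1 = tp / (tp + fn)     # of all actual rises, how many were found

print(matrix)
print(accuracy, round(precision_1, 2), round(recall_1, 2))
```

The matrix shows that the model predicts class 1 for 15 of the 20 test rows, which explains the relatively high recall for class 1 and the low recall for class 0.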
Complete Code
python
#!/usr/bin/python3
# Create By GF 2024-01-04
# In this example we build a decision tree model with DecisionTreeClassifier.
# To handle the string feature column we use pd.get_dummies for one-hot encoding.
# We then split the dataset into training and test sets with train_test_split and standardize the features.
# Finally we train the model, make predictions, and evaluate model performance.
# Note that this is only a basic example; real applications may need more feature engineering, hyperparameter tuning, and model evaluation.
import datetime
# --------------------------------------------------
import pandas as pd
# --------------------------------------------------
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import StandardScaler
# Function that judges the relationship between a stock's short-term and long-term moving averages.
def MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA:float, Long_MA:float) -> int:
    # Return 1 if the short-term MA is above the long-term MA,
    # 0 if the two are equal, and -1 if it is below.
    if (Short_MA > Long_MA): return 1
    if (Short_MA == Long_MA): return 0
    if (Short_MA < Long_MA): return -1
    # ==============================================
    # End of Function.
if __name__ == "__main__":
    PDF = pd.read_csv("D:\\HBYH_000422_20150806_20151231.csv", header=0, sep=",")
    print("[Message] Read CSV File: D:\\HBYH_000422_20150806_20151231.csv")
    print(PDF)
    # Convert the data types of the Pandas DataFrame columns.
    PDF["Date"] = PDF["Date"].astype("datetime64[ns]")
    PDF["Open"] = PDF["Open"].astype("float64")
    PDF["High"] = PDF["High"].astype("float64")
    PDF["Low"] = PDF["Low"].astype("float64")
    PDF["Close"] = PDF["Close"].astype("float64")
    PDF["Pre_Close"] = PDF["Pre_Close"].astype("float64")
    PDF["Change"] = PDF["Change"].astype("float64")
    PDF["Turnover_Rate"] = PDF["Turnover_Rate"].astype("float64")
    PDF["Volume"] = PDF["Volume"].astype("int64")
    PDF["MA5"] = PDF["MA5"].astype("float64")
    PDF["MA10"] = PDF["MA10"].astype("float64")
    # Print the column names and data types of the Pandas DataFrame.
    print("[Message] Changed Pandas DataFrame Data Type:")
    print(PDF.dtypes)
    # Compute: extract the weekday index, from 0 to 6 (0 is Monday, 6 is Sunday).
    PDF["Weekday(Idx)"] = PDF["Date"].apply(lambda X: X.weekday())
    # ..................................................
    # Compute: mark festivals (the A-share market is closed on Chinese holidays, so only foreign festivals can be explored; this example uses only "Christmas" and "Christmas Eve").
    PDF["Festival"] = None
    for Idx in PDF.index:
        if PDF.loc[Idx, "Date"] == datetime.datetime(2015,12,24): PDF.loc[Idx, "Festival"] = "Christmas_Eve" # -> Christmas Eve.
        if PDF.loc[Idx, "Date"] == datetime.datetime(2015,12,25): PDF.loc[Idx, "Festival"] = "Christmas" # -> Christmas.
    # ..................................................
    # Compute: label whether the stock rose or fell.
    PDF["Rise_Fall"] = PDF["Change"].apply(lambda X: int(1) if X >= 0 else int(0))
    # ..................................................
    # Compute: call the function to judge the relationship between the short-term and long-term moving averages.
    PDF["MA_Relationship"] = PDF.apply(lambda X: MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA=X["MA5"], Long_MA=X["MA10"]), axis=1)
    # Print the computed DataFrame.
    print("[Message] Calculated DataFrame:")
    print(PDF)
    # Convert the string feature column to numeric values (one-hot encoding).
    # dtype=int keeps the new columns as 0/1; pandas 2.0 and later would otherwise produce booleans.
    PDF = pd.get_dummies(PDF, columns=["Festival"], drop_first=False, dtype=int)
    # Print the converted DataFrame.
    print("[Message] DataFrame After One-Hot Encoding:")
    print(PDF)
    # Extract the label column.
    Y = PDF["Rise_Fall"]
    # Extract the feature columns.
    X = PDF.drop(["Date", "Code", "Open", "Close", "Pre_Close", "Change", "MA5", "MA10", "Rise_Fall"], axis=1)
    # Split the dataset into a training set and a test set (train_test_split).
    X_Train, X_Test, Y_Train, Y_Test = train_test_split(X, Y, test_size=0.2, random_state=42)
    # Feature standardization (StandardScaler).
    Obj_Scaler = StandardScaler()
    X_Train_Scaled = Obj_Scaler.fit_transform(X_Train)
    X_Test_Scaled = Obj_Scaler.transform(X_Test)
    # Create the decision tree classifier (DecisionTreeClassifier).
    DTC = DecisionTreeClassifier(random_state=42)
    # Train the decision tree classifier (DecisionTreeClassifier) model.
    DTC.fit(X_Train_Scaled, Y_Train)
    # Return value:
    # +----------------------------------------+
    # |▼ DecisionTreeClassifier |
    # +----------------------------------------+
    # | DecisionTreeClassifier(random_state=42)|
    # +----------------------------------------+
    # Predict on the test set.
    Y_Pred = DTC.predict(X_Test_Scaled)
    # Combine the prediction results with the test features.
    Result = X_Test.copy()
    Result["Actually"] = Y_Test
    Result["Prediction"] = Y_Pred
    print("[Message] Prediction Results on The Test Data Set for DecisionTreeClassifier:")
    print(Result)
    # Evaluate model performance.
    Accuracy = accuracy_score(Y_Test, Y_Pred)
    print("Accuracy:", Accuracy)
    print("\n")
    # Print the classification report.
    print("Classification Report:")
    print(classification_report(Y_Test, Y_Pred))
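Other
In this example we build a decision tree model with DecisionTreeClassifier.
To handle the string feature column we use pd.get_dummies for one-hot encoding.
We then split the dataset into training and test sets with train_test_split and standardize the features.
Finally we train the model, make predictions, and evaluate model performance.
Note that this is only a basic example; real applications may need more feature engineering, hyperparameter tuning, and model evaluation.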