金融数据_Scikit-Learn决策树(DecisionTreeClassifier)实例

金融数据_Scikit-Learn决策树(DecisionTreeClassifier)实例

逻辑回归: 逻辑回归常被用于二分类问题, 比如涨跌预测。你可以将涨跌标记为类别, 然后使用逻辑回归进行训练。

决策树和随机森林: 决策树和随机森林是用于分类问题的强大模型。它们能够处理非线性关系, 并且对于特征的重要性有较好的解释。

实例数据

本实例截取了 "湖北宜化(000422)" 2015年08月06日 - 2015年12月31日的数据。

HBYH_000422_20150806_20151231.csv

csv 复制代码
Date,Code,Open,High,Low,Close,Pre_Close,Change,Turnover_Rate,Volume,MA5,MA10
2015-12-31,'000422,7.93,7.95,7.76,7.77,7.93,-0.020177,0.015498,13915200,7.86,7.85
2015-12-30,'000422,7.86,7.93,7.75,7.93,7.84,0.011480,0.018662,16755900,7.90,7.85
2015-12-29,'000422,7.72,7.85,7.69,7.84,7.71,0.016861,0.015886,14263800,7.90,7.81
2015-12-28,'000422,8.03,8.08,7.70,7.71,8.03,-0.039851,0.030821,27672800,7.91,7.78
2015-12-25,'000422,8.03,8.05,7.93,8.03,7.99,0.005006,0.021132,18974000,7.93,7.78
2015-12-24,'000422,7.93,8.16,7.87,7.99,7.92,0.008838,0.026487,23781900,7.85,7.72
2015-12-23,'000422,7.97,8.11,7.88,7.92,7.89,0.003802,0.042360,38033600,7.80,7.69
2015-12-22,'000422,7.86,7.93,7.76,7.89,7.83,0.007663,0.026929,24178700,7.73,7.68
2015-12-21,'000422,7.59,7.89,7.56,7.83,7.63,0.026212,0.030777,27633600,7.66,7.67
2015-12-18,'000422,7.71,7.74,7.57,7.63,7.74,-0.014212,0.024764,22234900,7.62,7.71
2015-12-17,'000422,7.58,7.75,7.57,7.74,7.55,0.025166,0.028054,25188400,7.59,7.77
2015-12-16,'000422,7.57,7.62,7.53,7.55,7.55,0.000000,0.020718,18601600,7.58,7.79
2015-12-15,'000422,7.63,7.66,7.52,7.55,7.62,-0.009186,0.025902,23256600,7.64,7.78
2015-12-14,'000422,7.40,7.64,7.36,7.62,7.51,0.014647,0.021005,18860100,7.68,7.76
2015-12-11,'000422,7.65,7.70,7.41,7.51,7.67,-0.020860,0.020477,18385900,7.80,7.73
2015-12-10,'000422,7.78,7.87,7.65,7.67,7.83,-0.020434,0.019972,17931900,7.95,7.69
2015-12-09,'000422,7.76,8.00,7.75,7.83,7.77,0.007722,0.025137,22569700,8.00,7.68
2015-12-08,'000422,8.08,8.18,7.76,7.77,8.24,-0.057039,0.036696,32948200,7.92,7.66
2015-12-07,'000422,8.12,8.39,7.94,8.24,8.23,0.001215,0.064590,57993100,7.84,7.64
2015-12-04,'000422,7.85,8.48,7.80,8.23,7.92,0.039141,0.100106,89881900,7.65,7.58
2015-12-03,'000422,7.42,8.09,7.38,7.92,7.43,0.065949,0.045416,40777500,7.43,7.52
2015-12-02,'000422,7.35,7.48,7.20,7.43,7.36,0.009511,0.015968,14337600,7.37,7.49
2015-12-01,'000422,7.28,7.39,7.23,7.36,7.33,0.004093,0.012308,11050700,7.41,7.48
2015-11-30,'000422,7.18,7.36,6.95,7.33,7.11,0.030942,0.020323,18247500,7.45,7.50
2015-11-27,'000422,7.59,7.59,6.95,7.11,7.60,-0.064474,0.027673,24846700,7.51,7.52
2015-11-26,'000422,7.63,7.73,7.58,7.60,7.63,-0.003932,0.024836,22299800,7.61,7.54
2015-11-25,'000422,7.56,7.64,7.51,7.63,7.59,0.005270,0.020919,18782900,7.61,7.54
2015-11-24,'000422,7.60,7.63,7.48,7.59,7.62,-0.003937,0.014867,13348200,7.56,7.53
2015-11-23,'000422,7.59,7.72,7.55,7.62,7.61,0.001314,0.028406,25505000,7.54,7.53
2015-11-20,'000422,7.59,7.71,7.53,7.61,7.59,0.002635,0.028277,25389100,7.52,7.53
2015-11-19,'000422,7.45,7.62,7.41,7.59,7.39,0.027064,0.038638,34691700,7.47,7.52
2015-11-18,'000422,7.53,7.54,7.38,7.39,7.51,-0.015979,0.014173,12725000,7.46,7.50
2015-11-17,'000422,7.53,7.63,7.44,7.51,7.50,0.001333,0.028640,25714500,7.51,7.50
2015-11-16,'000422,7.27,7.52,7.24,7.50,7.38,0.016260,0.016230,14572000,7.52,7.46
2015-11-13,'000422,7.49,7.55,7.36,7.38,7.54,-0.021220,0.029196,26214400,7.53,7.41
2015-11-12,'000422,7.65,7.68,7.49,7.54,7.61,-0.009198,0.026501,23794800,7.56,7.40
2015-11-11,'000422,7.57,7.64,7.52,7.61,7.57,0.005284,0.026113,23445900,7.54,7.37
2015-11-10,'000422,7.51,7.61,7.45,7.57,7.55,0.002649,0.024979,22427700,7.49,7.32
2015-11-09,'000422,7.51,7.62,7.45,7.55,7.53,0.002656,0.033367,29959500,7.39,7.31
2015-11-06,'000422,7.47,7.53,7.37,7.53,7.45,0.010738,0.037058,33273100,7.29,7.27
2015-11-05,'000422,7.34,7.54,7.32,7.45,7.37,0.010855,0.040463,36330200,7.24,7.24
2015-11-04,'000422,7.10,7.38,7.07,7.37,7.05,0.045390,0.034817,31260800,7.20,7.17
2015-11-03,'000422,7.08,7.13,7.02,7.05,7.06,-0.001416,0.014938,13412400,7.15,7.10
2015-11-02,'000422,7.11,7.26,7.05,7.06,7.26,-0.027548,0.016865,15142100,7.23,7.10
2015-10-30,'000422,7.22,7.38,7.10,7.26,7.24,0.002762,0.022821,20490200,7.25,7.10
2015-10-29,'000422,7.27,7.33,7.16,7.24,7.16,0.011173,0.025726,23098500,7.23,7.08
2015-10-28,'000422,7.32,7.40,7.09,7.16,7.42,-0.035040,0.035572,31938500,7.15,7.05
2015-10-27,'000422,7.21,7.48,7.08,7.42,7.18,0.033426,0.057658,51769300,7.04,7.01
2015-10-26,'000422,7.20,7.25,7.01,7.18,7.17,0.001395,0.036840,33077800,6.98,6.96
2015-10-23,'000422,6.84,7.22,6.81,7.17,6.80,0.054412,0.047169,42351500,6.95,6.93
2015-10-22,'000422,6.68,6.81,6.64,6.80,6.65,0.022556,0.020609,18503800,6.93,6.87
2015-10-21,'000422,7.08,7.11,6.61,6.65,7.09,-0.062059,0.039388,35365300,6.96,6.85
2015-10-20,'000422,7.00,7.09,6.94,7.09,7.03,0.008535,0.024472,21972900,6.98,6.81
2015-10-19,'000422,7.09,7.13,6.92,7.03,7.08,-0.007062,0.031262,28068800,6.94,6.72
2015-10-16,'000422,6.97,7.08,6.91,7.08,6.93,0.021645,0.039632,35584700,6.91,6.66
2015-10-15,'000422,6.77,6.94,6.75,6.93,6.77,0.023634,0.031645,28412700,6.82,6.59
2015-10-14,'000422,6.87,6.94,6.74,6.77,6.89,-0.017417,0.027226,24445500,6.74,6.55
2015-10-13,'000422,6.86,6.96,6.80,6.89,6.88,0.001453,0.028704,25771900,6.64,6.51
2015-10-12,'000422,6.62,6.91,6.58,6.88,6.61,0.040847,0.037037,33254300,6.50,6.49
2015-10-09,'000422,6.54,6.65,6.45,6.61,6.54,0.010703,0.018528,16635900,6.41,6.46
2015-10-08,'000422,6.45,6.70,6.37,6.54,6.26,0.044728,0.018857,16931000,6.35,6.44
2015-09-30,'000422,6.25,6.30,6.22,6.26,6.23,0.004815,0.007327,6579090,6.35,6.43
2015-09-29,'000422,6.30,6.32,6.18,6.23,6.40,-0.026562,0.008991,8072900,6.39,6.48
2015-09-28,'000422,6.35,6.42,6.25,6.40,6.34,0.009464,0.008824,7922890,6.48,6.47
2015-09-25,'000422,6.51,6.56,6.25,6.34,6.53,-0.029096,0.012584,11298800,6.51,6.45
2015-09-24,'000422,6.48,6.56,6.45,6.53,6.45,0.012403,0.011339,10180900,6.53,6.51
2015-09-23,'000422,6.51,6.60,6.41,6.45,6.67,-0.032984,0.015920,14294100,6.52,6.54
2015-09-22,'000422,6.58,6.73,6.54,6.67,6.58,0.013678,0.023356,20970200,6.56,6.60
2015-09-21,'000422,6.34,6.61,6.29,6.58,6.44,0.021739,0.017036,15295900,6.46,6.62
2015-09-18,'000422,6.52,6.58,6.30,6.44,6.44,0.000000,0.016622,14924700,6.39,6.62
2015-09-17,'000422,6.59,6.76,6.43,6.44,6.68,-0.035928,0.019517,17523900,6.48,6.62
2015-09-16,'000422,6.21,6.76,6.17,6.68,6.15,0.086179,0.019671,17662300,6.56,6.65
2015-09-15,'000422,6.24,6.38,6.05,6.15,6.26,-0.017572,0.015338,13771200,6.64,6.66
2015-09-14,'000422,6.89,6.95,6.18,6.26,6.87,-0.088792,0.021233,18559600,6.78,6.75
2015-09-11,'000422,6.87,6.96,6.77,6.87,6.84,0.004386,0.010853,9486290,6.85,6.79
2015-09-10,'000422,6.95,7.01,6.76,6.84,7.06,-0.031161,0.017423,15229100,6.76,6.74
2015-09-09,'000422,6.90,7.09,6.86,7.06,6.88,0.026163,0.028974,25325600,6.74,6.68
2015-09-08,'000422,6.65,6.91,6.55,6.88,6.62,0.039275,0.017858,15609100,6.69,6.67
2015-09-07,'000422,6.50,6.81,6.50,6.62,6.38,0.037618,0.017850,15602600,6.72,6.75
2015-09-02,'000422,6.45,6.88,6.30,6.38,6.74,-0.053412,0.022286,19480100,6.73,6.91
2015-09-01,'000422,6.88,6.99,6.67,6.74,6.81,-0.010279,0.025829,22576700,6.72,7.12
2015-08-31,'000422,6.90,6.97,6.71,6.81,7.07,-0.036775,0.018385,16069600,6.62,7.24
2015-08-28,'000422,6.75,7.08,6.71,7.07,6.67,0.059970,0.026692,23330800,6.65,7.44
2015-08-27,'000422,6.53,6.67,6.34,6.67,6.32,0.055380,0.022455,19627900,6.78,7.59
2015-08-26,'000422,6.31,6.77,6.09,6.32,6.25,0.011200,0.029963,26190200,7.08,7.76
2015-08-25,'000422,6.40,6.77,6.25,6.25,6.94,-0.099424,0.029492,25778600,7.52,7.96
2015-08-24,'000422,7.49,7.49,6.94,6.94,7.71,-0.099870,0.036552,31949900,7.86,8.18
2015-08-21,'000422,8.00,8.11,7.60,7.71,8.17,-0.056304,0.032199,28144800,8.23,8.33
2015-08-20,'000422,8.38,8.56,8.14,8.17,8.53,-0.042204,0.031764,27764200,8.40,8.38
2015-08-19,'000422,7.73,8.57,7.72,8.53,7.96,0.071608,0.052192,45619900,8.45,8.37
2015-08-18,'000422,8.81,8.86,7.92,7.96,8.80,-0.095455,0.056179,49105500,8.39,8.32
2015-08-17,'000422,8.49,8.83,8.42,8.80,8.52,0.032864,0.048161,42096900,8.50,8.35
2015-08-14,'000422,8.48,8.65,8.43,8.52,8.44,0.009479,0.041169,35985000,8.43,8.24
2015-08-13,'000422,8.20,8.45,8.15,8.44,8.24,0.024272,0.029768,26019600,8.37,8.16
2015-08-12,'000422,8.38,8.48,8.21,8.24,8.48,-0.028302,0.035421,30960700,8.30,8.08
2015-08-11,'000422,8.41,8.68,8.32,8.48,8.49,-0.001178,0.048444,42343900,8.26,8.03
2015-08-10,'000422,8.28,8.58,8.18,8.49,8.21,0.034105,0.041268,36071600,8.20,7.92
2015-08-07,'000422,8.15,8.28,8.08,8.21,8.07,0.017348,0.025855,22599800,8.05,7.81
2015-08-06,'000422,7.88,8.21,7.80,8.07,8.03,0.004981,0.020074,17546700,7.95,7.80

探索思路

这里只是简单示例, 目的在于熟悉 Scikit-Learn 中的决策树分类器使用方法, 无任何投资引导。

目标:

通过当日数值情况, 预测当日收盘涨跌, 如果 "涨跌幅(Change) >= 0", 则用 1 表示, 如果 "涨跌幅(Change) < 0", 则用 0 表示 (二分类标签)。

变量:

  1. 当日最高价

  2. 当日最低价

  3. 当日换手率

  4. 当日成交量

  5. 当日星期几 (星期对价格的影响)

  6. 当日 "短期均线(MA5)" 与 "长期均线(MA10)" 的关系, 如果 "MA5 > MA10", 则用 1 表示, 如果 "MA5 = MA10", 则用 0 表示, 如果 "MA5 < MA10", 则用 -1 表示。

  7. 节日 (节日对 A 股的影响, 中国节日 A 股休市, 所以只能探索国外节日对 A 股的影响, 这里仅用 "圣诞节(Christmas)" 和 "平安夜(Christmas Eve)" 做示例)。

导入 Pandas 相关模块

Pandas 是基于 NumPy 的一种工具, 该工具是为解决数据分析任务而创建的。Pandas 纳入了大量库和一些标准的数据模型, 提供了高效地操作大型数据集所需的工具。

Pandas 提供了大量能使我们快速便捷地处理数据的函数和方法。你很快就会发现, 它是使 Python 成为强大而高效的数据分析环境的重要因素之一。

python 复制代码
import pandas as pd

导入 Scikit-Learn 相关模块

Scikit-Learn (以前称为 scikits.learn, 也称为 sklearn) 是针对 Python 编程语言的免费软件机器学习库。

它具有各种分类, 回归和聚类算法, 包括支持向量机, 随机森林, 梯度提升, K均值 和 DBSCAN, 并且旨在与 Python 数值科学库 NumPy 和 SciPy 联合使用。

python 复制代码
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import StandardScaler

使用 Pandas 读取 CSV 数据

调用 Pandas 的 .read_csv 方法读取 CSV 数据:

其中 header 参数指定 CSV 文件的表头行, 这里的 header=0 表示表头行在 1 行, 如果 header=None 则表示数据没有列索引, Pandas 则会自动加上索引。

其中 sep 参数指定 CSV 文件的分隔符, 默认情况下都是以 "," 作为分隔符, 这里的 sep="," 表示指定 CSV 文件的分隔符为 ","。

还有 dtype 参数指定 CSV 某些特定列以特定的数据类型进行读取, 例如 dtype={"Close":float, "Volume":int} 表示 "Close" 列以 浮点(float) 类型读取, "Volume" 列以 整数(integer) 类型读取。

python 复制代码
PDF = pd.read_csv("D:\\HBYH_000422_20150806_20151231.csv", header=0, sep=",")

输出 DataFrame 数据框:

python 复制代码
print("[Message] Readed CSV File: D:\\HBYH_000422_20150806_20151231.csv")
print(PDF)

输出:

txt 复制代码
[Message] Readed CSV File: D:\\HBYH_000422_20150806_20151231.csv
          Date     Code  Open  High   Low  Close  Pre_Close    Change  Turnover_Rate    Volume   MA5  MA10
0   2015-12-31  '000422  7.93  7.95  7.76   7.77       7.93 -0.020177       0.015498  13915200  7.86  7.85
1   2015-12-30  '000422  7.86  7.93  7.75   7.93       7.84  0.011480       0.018662  16755900  7.90  7.85
2   2015-12-29  '000422  7.72  7.85  7.69   7.84       7.71  0.016861       0.015886  14263800  7.90  7.81
3   2015-12-28  '000422  8.03  8.08  7.70   7.71       8.03 -0.039851       0.030821  27672800  7.91  7.78
4   2015-12-25  '000422  8.03  8.05  7.93   8.03       7.99  0.005006       0.021132  18974000  7.93  7.78
..         ...      ...   ...   ...   ...    ...        ...       ...            ...       ...   ...   ...
94  2015-08-12  '000422  8.38  8.48  8.21   8.24       8.48 -0.028302       0.035421  30960700  8.30  8.08
95  2015-08-11  '000422  8.41  8.68  8.32   8.48       8.49 -0.001178       0.048444  42343900  8.26  8.03
96  2015-08-10  '000422  8.28  8.58  8.18   8.49       8.21  0.034105       0.041268  36071600  8.20  7.92
97  2015-08-07  '000422  8.15  8.28  8.08   8.21       8.07  0.017348       0.025855  22599800  8.05  7.81
98  2015-08-06  '000422  7.88  8.21  7.80   8.07       8.03  0.004981       0.020074  17546700  7.95  7.80

[99 rows x 12 columns]

转换 Pandas 中 DateFrame 各列数据类型

通常情况下, 为了避免计算出现数据类型的错误, 都需要重新转换一下数据类型。

python 复制代码
# 转换 Pandas 中 DateFrame 数据类型。
PDF["Date"] =          PDF["Date"].astype("datetime64[ns]")
PDF["Open"] =          PDF["Open"].astype("float64")
PDF["High"] =          PDF["High"].astype("float64")
PDF["Low"] =           PDF["Low"].astype("float64")
PDF["Close"] =         PDF["Close"].astype("float64")
PDF["Pre_Close"] =     PDF["Pre_Close"].astype("float64")
PDF["Change"] =        PDF["Change"].astype("float64")
PDF["Turnover_Rate"] = PDF["Turnover_Rate"].astype("float64")
PDF["Volume"] =        PDF["Volume"].astype("int64")
PDF["MA5"] =           PDF["MA5"].astype("float64")
PDF["MA10"] =          PDF["MA10"].astype("float64")

# 输出 Pandas 中 DataFrame 字段和数据类型。
print("[Message] Changed Pandas DataFrame Data Type:")
print(PDF.dtypes)

输出:

txt 复制代码
[Message] Changed Pandas DataFrame Data Type:
Date             datetime64[ns]
Code                     object
Open                    float64
High                    float64
Low                     float64
Close                   float64
Pre_Close               float64
Change                  float64
Turnover_Rate           float64
Volume                    int64
MA5                     float64
MA10                    float64
dtype: object

在 Pandas 的 DataFrame 中计算数据

编写 "判断股票短期均线和长期均线关系" 函数:

python 复制代码
def MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA:float, Long_MA:float) -> int:

    if (Short_MA >= Long_MA): return  1
    if (Short_MA == Long_MA): return  0
    if (Short_MA <= Long_MA): return -1

    # ==============================================
    # End of Function.

在 Pandas 的 DataFrame 中直接计算或调用自定义函数:

python 复制代码
# 计算数据: 提取星期的索引, 从 0 到 6 (0 代表周一, 6 代表周日)。
PDF["Weekday(Idx)"] =    PDF["Date"].apply(lambda X: X.weekday())
# ..................................................
# 计算数据: 计算节日 (节日对 A 股的影响, 中国节日 A 股休市, 所以只能探索国外节日对 A 股的影响, 这里仅用 "圣诞节(Christmas)" 和 "平安夜(Christmas Eve)" 做示例)。
PDF["Festival"] = None
for Idx in PDF.index:
    if PDF.loc[Idx, "Date"] == datetime.datetime(2015,12,24): PDF.loc[Idx, "Festival"] = "Christmas_Eve" # -> 平安夜。
    if PDF.loc[Idx, "Date"] == datetime.datetime(2015,12,25): PDF.loc[Idx, "Festival"] = "Christmas"     # -> 圣诞节。
# ..................................................
# 计算数据: 判断股票涨跌。
PDF["Rise_Fall"] =       PDF["Change"].apply(lambda X: int(1) if X >= 0 else int(0))
# ..................................................
# 计算数据: 调用函数, 判断股票短期均线和长期均线关系。
PDF["MA_Relationship"] = PDF.apply(lambda X: MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA=X["MA5"], Long_MA=X["MA10"]), axis=1)

# 输出计算好的 DataFrame 数据框。
print("[Message] Calculated DataFrame:")
print(PDF)

输出:

txt 复制代码
[Message] Calculated DataFrame:
         Date     Code  Open  High   Low  Close  Pre_Close    Change  Turnover_Rate    Volume   MA5  MA10  Weekday(Idx)   Festival  Rise_Fall  MA_Relationship
0  2015-12-31  '000422  7.93  7.95  7.76   7.77       7.93 -0.020177       0.015498  13915200  7.86  7.85             3       None          0                1
1  2015-12-30  '000422  7.86  7.93  7.75   7.93       7.84  0.011480       0.018662  16755900  7.90  7.85             2       None          1                1
2  2015-12-29  '000422  7.72  7.85  7.69   7.84       7.71  0.016861       0.015886  14263800  7.90  7.81             1       None          1                1
3  2015-12-28  '000422  8.03  8.08  7.70   7.71       8.03 -0.039851       0.030821  27672800  7.91  7.78             0       None          0                1
4  2015-12-25  '000422  8.03  8.05  7.93   8.03       7.99  0.005006       0.021132  18974000  7.93  7.78             4  Christmas          1                1
..        ...      ...   ...   ...   ...    ...        ...       ...            ...       ...   ...   ...           ...        ...        ...              ...
94 2015-08-12  '000422  8.38  8.48  8.21   8.24       8.48 -0.028302       0.035421  30960700  8.30  8.08             2       None          0                1
95 2015-08-11  '000422  8.41  8.68  8.32   8.48       8.49 -0.001178       0.048444  42343900  8.26  8.03             1       None          0                1
96 2015-08-10  '000422  8.28  8.58  8.18   8.49       8.21  0.034105       0.041268  36071600  8.20  7.92             0       None          1                1
97 2015-08-07  '000422  8.15  8.28  8.08   8.21       8.07  0.017348       0.025855  22599800  8.05  7.81             4       None          1                1
98 2015-08-06  '000422  7.88  8.21  7.80   8.07       8.03  0.004981       0.020074  17546700  7.95  7.80             3       None          1                1

[99 rows x 16 columns]

在 Pandas 的 DataFrame 中将字符串类型的特征列转换为数值 (One-Hot Encoding)

pd.get_dummies() 是 Pandas 库中用于独热编码 (One-Hot Encoding) 的函数。它的作用是将分类 (离散) 变量的每个不同取值都拓展为一个新的二进制特征 (0 或 1), 从而方便机器学习模型处理。

python 复制代码
# 函数签名:
pd.get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False, columns=None, sparse=False, drop_first=False, dtype=None)

# 参数说明:
# - data: 要进行独热编码的 DataFrame 或 Series。
# - prefix: 生成的独热编码列的前缀。
# - prefix_sep: 生成的独热编码列的前缀和原始列名之间的分隔符。
# - dummy_na: 是否为原始数据中的缺失值生成独热编码列。
# - columns: 要进行独热编码的列的名称, 如果指定, 则只对这些列进行操作。
# - drop_first: 是否删除第一个独热编码列, 以避免共线性问题。

转换 Festival 特征列为数值:

python 复制代码
# 将字符串类型的特征列转换为数值 (独热编码)。
PDF = pd.get_dummies(PDF, columns=["Festival"], drop_first=False)

# 输出转换后的 DataFrame 数据框。
print("[Message] DataFrame After One-Hot Encoding:")
print(PDF)

输出:

txt 复制代码
[Message] DataFrame After One-Hot Encoding:
         Date     Code  Open  High   Low  Close  Pre_Close    Change  Turnover_Rate    Volume   MA5  MA10  Weekday(Idx)  Rise_Fall  MA_Relationship  Festival_Christmas  Festival_Christmas_Eve
0  2015-12-31  '000422  7.93  7.95  7.76   7.77       7.93 -0.020177       0.015498  13915200  7.86  7.85             3          0                1                   0                       0
1  2015-12-30  '000422  7.86  7.93  7.75   7.93       7.84  0.011480       0.018662  16755900  7.90  7.85             2          1                1                   0                       0
2  2015-12-29  '000422  7.72  7.85  7.69   7.84       7.71  0.016861       0.015886  14263800  7.90  7.81             1          1                1                   0                       0
3  2015-12-28  '000422  8.03  8.08  7.70   7.71       8.03 -0.039851       0.030821  27672800  7.91  7.78             0          0                1                   0                       0
4  2015-12-25  '000422  8.03  8.05  7.93   8.03       7.99  0.005006       0.021132  18974000  7.93  7.78             4          1                1                   1                       0
..        ...      ...   ...   ...   ...    ...        ...       ...            ...       ...   ...   ...           ...        ...              ...                 ...                     ...
94 2015-08-12  '000422  8.38  8.48  8.21   8.24       8.48 -0.028302       0.035421  30960700  8.30  8.08             2          0                1                   0                       0
95 2015-08-11  '000422  8.41  8.68  8.32   8.48       8.49 -0.001178       0.048444  42343900  8.26  8.03             1          0                1                   0                       0
96 2015-08-10  '000422  8.28  8.58  8.18   8.49       8.21  0.034105       0.041268  36071600  8.20  7.92             0          1                1                   0                       0
97 2015-08-07  '000422  8.15  8.28  8.08   8.21       8.07  0.017348       0.025855  22599800  8.05  7.81             4          1                1                   0                       0
98 2015-08-06  '000422  7.88  8.21  7.80   8.07       8.03  0.004981       0.020074  17546700  7.95  7.80             3          1                1                   0                       0

[99 rows x 17 columns]

提取 标签(Label)列 和 特征(Feature)列

提取 标签(Label) 列:

python 复制代码
# 提取 标签(Label) 列。
Y = PDF["Rise_Fall"]

提取 特征(Feature) 列:

python 复制代码
# 提取 特征(Feature) 列。
X = PDF.drop(["Date", "Code", "Open", "Close", "Pre_Close", "Change", "MA5", "MA10", "Rise_Fall"], axis=1)

划分训练集和测试集(train_test_split) 以及 特征标准化(StandardScaler)

划分训练集和测试集(train_test_split):

python 复制代码
# 数据集划分训练集和测试集(train_test_split)。
X_Train, X_Test, Y_Train, Y_Test = train_test_split(X, Y, test_size=0.2, random_state=42)

特征标准化(StandardScaler):

在机器学习中, fit_transform 和 transform 是用于数据预处理的常见方法, 它们的作用略有不同:

fit_transform: 该方法将同时拟合和转换数据。

  • 它会根据输入的数据计算所需的转换参数 (例如均值、标准差等), 然后将数据应用这些参数进行转换。

  • 在训练阶段, 通常使用 fit_transform 来对训练集进行拟合和转换。

  • 拟合过程会根据训练集数据计算并保存所需的转换参数, 然后将训练集数据应用这些参数进行转换。

  • 这样做的目的是确保在后续对测试集或新数据进行转换时使用相同的转换参数。

transform: 该方法仅对数据进行转换, 不进行拟合过程。

  • 它根据之前使用 fit_transform 得到的转换参数, 将这些参数应用于新的数据, 使其按照相同的转换方式进行处理。

  • 在测试阶段或对新数据应用模型时, 通常使用 transform 方法对测试集或新数据进行转换。

简而言之, fit_transform 方法用于拟合转换器并将数据进行转换, 而 transform 方法仅用于将数据按照已经拟合的转换器进行转换。

在代码中的具体应用上, 通常将 fit_transform 用于训练集的拟合和转换, 将 transform 用于测试集或新数据的转换, 以保证数据的一致性和正确的预处理操作。

python 复制代码
# 特征标准化(StandardScaler)。
Obj_Scaler = StandardScaler()
X_Train_Scaled = Obj_Scaler.fit_transform(X_Train)
X_Test_Scaled = Obj_Scaler.transform(X_Test)

训练 决策树分类器(DecisionTreeClassifier) 模型

创建 决策树分类器(DecisionTreeClassifier):

python 复制代码
# 创建 决策树分类器(DecisionTreeClassifier)。
DTC = DecisionTreeClassifier(random_state=42)

训练 决策树分类器(DecisionTreeClassifier) 模型:

python 复制代码
# 训练 决策树分类器(DecisionTreeClassifier) 模型。
DTC.fit(X_Train_Scaled, Y_Train)

# Value of Return:
# +----------------------------------------+
# |▼        DecisionTreeClassifier         |
# +----------------------------------------+
# | DecisionTreeClassifier(random_state=42)|
# +----------------------------------------+

使用 决策树分类器(DecisionTreeClassifier) 模型预测数据

python 复制代码
# 在测试集上进行预测。
Y_Pred = DTC.predict(X_Test_Scaled)

# 合并预测结果。
Result = X_Test.copy()
Result["Actually"] = Y_Test
Result["Prediction"] = Y_Pred

print("[Message] Prediction Results on The Test Data Set for DecisionTreeClassifier:")
print(Result)

输出:

txt 复制代码
[Message] Prediction Results on The Test Data Set for DecisionTreeClassifier:
    High   Low  Turnover_Rate    Volume  Weekday(Idx)  MA_Relationship  Festival_Christmas  Festival_Christmas_Eve  Actually  Prediction
62  6.32  6.18       0.008991   8072900             1               -1                   0                       0         0           1
40  7.54  7.32       0.040463  36330200             3                1                   0                       0         1           1
95  8.68  8.32       0.048444  42343900             1                1                   0                       0         0           1
18  8.39  7.94       0.064590  57993100             0                1                   0                       0         1           1
97  8.28  8.08       0.025855  22599800             4                1                   0                       0         1           1
84  6.77  6.09       0.029963  26190200             2               -1                   0                       0         1           0
64  6.56  6.25       0.012584  11298800             4                1                   0                       0         0           1
42  7.13  7.02       0.014938  13412400             1                1                   0                       0         0           0
10  7.75  7.57       0.028054  25188400             3               -1                   0                       0         1           0
0   7.95  7.76       0.015498  13915200             3                1                   0                       0         0           1
31  7.54  7.38       0.014173  12725000             2               -1                   0                       0         0           0
76  7.09  6.86       0.028974  25325600             2                1                   0                       0         1           1
47  7.48  7.08       0.057658  51769300             1                1                   0                       0         1           1
26  7.64  7.51       0.020919  18782900             2                1                   0                       0         1           1
44  7.38  7.10       0.022821  20490200             4                1                   0                       0         1           0
4   8.05  7.93       0.021132  18974000             4                1                   1                       0         1           1
22  7.39  7.23       0.012308  11050700             1               -1                   0                       0         1           1
12  7.66  7.52       0.025902  23256600             1               -1                   0                       0         0           1
88  8.56  8.14       0.031764  27764200             3                1                   0                       0         0           1
73  6.95  6.18       0.021233  18559600             0                1                   0                       0         0           1

使用 accuracy_score 评估模型性能

python 复制代码
# 评估模型性能。
Accuracy = accuracy_score(Y_Test, Y_Pred)
print("Accuracy:", Accuracy)
print("\n")

# 输出分类报告。
print("Classification Report:")
print(classification_report(Y_Test, Y_Pred))

输出:

txt 复制代码
Accuracy: 0.5

Classification Report:
              precision    recall  f1-score   support

           0       0.40      0.22      0.29         9
           1       0.53      0.73      0.62        11

    accuracy                           0.50        20
   macro avg       0.47      0.47      0.45        20
weighted avg       0.47      0.50      0.47        20

完整代码

python 复制代码
#!/usr/bin/python3
# Create By GF 2024-01-04

# 在这个示例中, 我们使用 DecisionTreeClassifier 构建决策树模型。
# 为了处理字符串类型的特征列, 我们使用了 pd.get_dummies 进行独热编码。
# 然后, 我们对特征进行标准化, 并使用 train_test_split 将数据集划分为训练集和测试集。
# 最后, 我们训练模型、进行预测, 并评估模型性能。
# 请注意, 这只是一个基本的示例, 实际应用中你可能需要更多的特征工程、调参和模型评估。

import datetime
# --------------------------------------------------
import pandas as pd
# --------------------------------------------------
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import StandardScaler

# 编写 "判断股票短期均线和长期均线关系" 函数。
def MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA:float, Long_MA:float) -> int:

    if (Short_MA >= Long_MA): return  1
    if (Short_MA == Long_MA): return  0
    if (Short_MA <= Long_MA): return -1

    # ==============================================
    # End of Function.

if __name__ == "__main__":

    PDF = pd.read_csv("D:\\HBYH_000422_20150806_20151231.csv", header=0, sep=",")
    
    print("[Message] Readed CSV File: D:\\HBYH_000422_20150806_20151231.csv")
    print(PDF)
    
    # 转换 Pandas 中 DateFrame 数据类型。
    PDF["Date"] =          PDF["Date"].astype("datetime64[ns]")
    PDF["Open"] =          PDF["Open"].astype("float64")
    PDF["High"] =          PDF["High"].astype("float64")
    PDF["Low"] =           PDF["Low"].astype("float64")
    PDF["Close"] =         PDF["Close"].astype("float64")
    PDF["Pre_Close"] =     PDF["Pre_Close"].astype("float64")
    PDF["Change"] =        PDF["Change"].astype("float64")
    PDF["Turnover_Rate"] = PDF["Turnover_Rate"].astype("float64")
    PDF["Volume"] =        PDF["Volume"].astype("int64")
    PDF["MA5"] =           PDF["MA5"].astype("float64")
    PDF["MA10"] =          PDF["MA10"].astype("float64")
    
    # 输出 Pandas 中 DataFrame 字段和数据类型。
    print("[Message] Changed Pandas DataFrame Data Type:")
    print(PDF.dtypes)
    
    # 计算数据: 提取星期的索引, 从 0 到 6 (0 代表周一, 6 代表周日)。
    PDF["Weekday(Idx)"] =    PDF["Date"].apply(lambda X: X.weekday())
    # ..................................................
    # 计算数据: 计算节日 (节日对 A 股的影响, 中国节日 A 股休市, 所以只能探索国外节日对 A 股的影响, 这里仅用 "圣诞节(Christmas)" 和 "平安夜(Christmas Eve)" 做示例)。
    PDF["Festival"] = None
    for Idx in PDF.index:
        if PDF.loc[Idx, "Date"] == datetime.datetime(2015,12,24): PDF.loc[Idx, "Festival"] = "Christmas_Eve" # -> 平安夜。
        if PDF.loc[Idx, "Date"] == datetime.datetime(2015,12,25): PDF.loc[Idx, "Festival"] = "Christmas"     # -> 圣诞节。
    # ..................................................
    # 计算数据: 判断股票涨跌。
    PDF["Rise_Fall"] =       PDF["Change"].apply(lambda X: int(1) if X >= 0 else int(0))
    # ..................................................
    # 计算数据: 调用函数, 判断股票短期均线和长期均线关系。
    PDF["MA_Relationship"] = PDF.apply(lambda X: MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA=X["MA5"], Long_MA=X["MA10"]), axis=1)
    
    # 输出计算好的 DataFrame 数据框。
    print("[Message] Calculated DataFrame:")
    print(PDF)
    
    # 将字符串类型的特征列转换为数值 (独热编码)。
    PDF = pd.get_dummies(PDF, columns=["Festival"], drop_first=False)
    
    # 输出转换后的 DataFrame 数据框。
    print("[Message] DataFrame After One-Hot Encoding:")
    print(PDF)
    
    # 提取 标签(Label) 列。
    Y = PDF["Rise_Fall"]
    
    # 提取 特征(Feature) 列。
    X = PDF.drop(["Date", "Code", "Open", "Close", "Pre_Close", "Change", "MA5", "MA10", "Rise_Fall"], axis=1)
    
    # 数据集划分训练集和测试集(train_test_split)。
    X_Train, X_Test, Y_Train, Y_Test = train_test_split(X, Y, test_size=0.2, random_state=42)
    
    # 特征标准化(StandardScaler)。
    Obj_Scaler = StandardScaler()
    X_Train_Scaled = Obj_Scaler.fit_transform(X_Train)
    X_Test_Scaled = Obj_Scaler.transform(X_Test)
    
    # 创建 决策树分类器(DecisionTreeClassifier)。
    DTC = DecisionTreeClassifier(random_state=42)
    
    # 训练 决策树分类器(DecisionTreeClassifier) 模型。
    DTC.fit(X_Train_Scaled, Y_Train)
    
    # Value of Return:
    # +----------------------------------------+
    # |▼        DecisionTreeClassifier         |
    # +----------------------------------------+
    # | DecisionTreeClassifier(random_state=42)|
    # +----------------------------------------+
    
    # 在测试集上进行预测。
    Y_Pred = DTC.predict(X_Test_Scaled)
    
    # 合并预测结果。
    Result = X_Test.copy()
    Result["Actually"] = Y_Test
    Result["Prediction"] = Y_Pred
    
    print("[Message] Prediction Results on The Test Data Set for DecisionTreeClassifier:")
    print(Result)
    
    # 评估模型性能。
    Accuracy = accuracy_score(Y_Test, Y_Pred)
    print("Accuracy:", Accuracy)
    print("\n")
    
    # 输出分类报告。
    print("Classification Report:")
    print(classification_report(Y_Test, Y_Pred))

其它

在这个示例中, 我们使用 DecisionTreeClassifier 构建决策树模型。

为了处理字符串类型的特征列, 我们使用了 pd.get_dummies 进行独热编码。

然后, 我们对特征进行标准化, 并使用 train_test_split 将数据集划分为训练集和测试集。

最后, 我们训练模型、进行预测, 并评估模型性能。

请注意, 这只是一个基本的示例, 实际应用中你可能需要更多的特征工程、调参和模型评估。

总结

以上就是关于 金融数据 Scikit-Learn决策树(DecisionTreeClassifier)实例 的全部内容。

更多内容可以访问我的代码仓库:

https://gitee.com/goufeng928/public

https://github.com/goufeng928/public

相关推荐
蹉跎x44 分钟前
力扣1358. 包含所有三种字符的子字符串数目
数据结构·算法·leetcode·职场和发展
巫师不要去魔法部乱说2 小时前
PyCharm专项训练4 最小生成树算法
算法·pycharm
IT猿手2 小时前
最新高性能多目标优化算法:多目标麋鹿优化算法(MOEHO)求解GLSMOP1-GLSMOP9及工程应用---盘式制动器设计,提供完整MATLAB代码
开发语言·算法·机器学习·matlab·强化学习
阿七想学习2 小时前
数据结构《排序》
java·数据结构·学习·算法·排序算法
王老师青少年编程2 小时前
gesp(二级)(12)洛谷:B3955:[GESP202403 二级] 小杨的日字矩阵
c++·算法·矩阵·gesp·csp·信奥赛
Kenneth風车3 小时前
【机器学习(九)】分类和回归任务-多层感知机(Multilayer Perceptron,MLP)算法-Sentosa_DSML社区版 (1)111
算法·机器学习·分类
eternal__day3 小时前
数据结构(哈希表(中)纯概念版)
java·数据结构·算法·哈希算法·推荐算法
APP 肖提莫3 小时前
MyBatis-Plus分页拦截器,源码的重构(重构total总数的计算逻辑)
java·前端·算法
OTWOL3 小时前
两道数组有关的OJ练习题
c语言·开发语言·数据结构·c++·算法
qq_433554544 小时前
C++ 面向对象编程:递增重载
开发语言·c++·算法