前几天看到一个CoreML的教程视频,感觉挺有意思的样子,于是去了解了一下,决定尝试将他用于预测涨跌,看下机器学习的能不能预测的准,就算不行,也无所谓,就相当于学习好了
模型类型
CoreML需要选定一个模型来进行训练 打开developer tool,然后点击New document就可以看到有多少类型 前面几个都是图片分类,图片识别,手势识别,文字识别啥的,我也没细看,今天我们就用表格的Tabular Regression来给他一些基础数据,让他来预测
模型选型
曾经也想过给复杂的数据给他去训练,有点复杂,还是先给一些简单的了,我的想法就是以最后一根k线的涨跌幅作为目标,往前数五根k线作为基础数据 生成模型需要准备一个csv文件,我准备像下面这样的格式,前面的k是基础数据,target是目标数据,这样的数据交给模型训练器进行训练
| k0 | k1 | k2 | k3 | k4 | target |
|-------|-------|-------|-------|-------|----------|---|
| 0.09 | -0.16 | 0.23 | -0.06 | 0.07 | -0.05 |
| -0.16 | 0.23 | -0.06 | 0.07 | -0.05 | -0.10 | |
| 0.23 | -0.06 | 0.07 | -0.05 | -0.10 | 0.05 |
| -0.06 | 0.07 | -0.05 | -0.10 | 0.05 | -0.06 |
单位是百分点
获取数据
要得到这样的数据,需要从交易平台请求过往的k线数据 我从币安的api接口查到接口和参数,然后查询从某个时刻开始的数据,一次请求1500条,1500条乘以5分钟
swift
open func requestCandles(startTime: Int? = nil, limit: Int? = nil) async throws -> (response: HTTPURLResponse, candles: [Candle]) {
let path = "GET /api/v1/klines"
var params = ["symbol": instId, "interval": intervalStr] as [String: Any]
if let startTime = startTime {
params["startTime"] = startTime
}
params["limit"] = limit ?? self.limit
let response = await RestAPI.send(path: path, params: params)
if response.succeed {
if let data = response.data as? [[Any]] {
var candles = [Candle]()
for arr in data {
let candle = Candle(array: arr)
candles.append(candle)
}
return (response.res.urlResponse!, candles)
} else {
throw CommonError(message: "data类型不对")
}
} else {
logInfo("请求k线失败:\(response.errMsg ?? "")")
throw CommonError(message: response.errMsg ?? "")
}
}
通过接口测试,我发现最早只能查到2020/01/01 00:00:00的数据,时间戳就是1577808000000, 于是我就起了一个定时器不断调用接口把数据拉下来
swift
/// 执行一次请求
func request() async throws {
let lastTime = lastTime ?? initStartTime
let next = lastTime + 1
do {
logInfo("开始请求\(next.dateDesc)的数据")
// 请求当前时间的k线
let (response, candles) = try await candleManager.requestCandles(startTime: next)
let requestCompletion = await self.saveCandles(candles)
if !requestCompletion {
// 请求没完成,继续请求下一页
try await self.continueWith(response: response)
return
}
logInfo("请求完成,开始生成Candle模型")
try await createCandleModels()
} catch {
logInfo("请求失败:\(error)")
}
}
Candle是一个基础的k线数据结构,CandleModel就是上面的模型了,所有Candle会组装到一个数组中,然后再来生成所有的CandleModel,生成完之后,就开始组装csv需要的格式了
生成csv文件
swift
/// 生成csv文件
func createCSVFile() async throws {
// 每行的内容
var contents = [String]()
// 文件表头
let header = columns.joined(separator: ",")
contents.append(header)
logInfo("组装Header:\(header)")
// 每天的
for candleModel in candleModels {
var lines = [String]()
// 前面的涨跌幅
lines += candleModel.previousCandles
.map({ $0.rate })
// 当天的涨跌幅
if let current = candleModel.current {
lines.append(current.rate)
}
// 当前的字符串
let line = lines.joined(separator: ",")
logInfo("组装字符串:\(line)")
contents.append(line)
}
// 组装成最后字符
let content = contents.joined(separator: "\n")
if let data = content.data(using: .utf8) {
try data.write(to: csvFileURL)
}
logInfo("生csv成功,准备生成CoreML模型")
try await createML()
}
生成了一个14M的文件,模型还是比较多的
生成CoreML模型
接下来就是生成CoreML模型了
swift
/// 生成模型
func createML() async throws {
guard FileManager.default.fileExists(atPath: csvFileURL.path()) else {
print("csv路径不存在")
throw CreateError.csvNotFound
}
// 生成MLTable
let dataFrame = try DataFrame(contentsOfCSVFile: csvFileURL, columns: columns)
// 划分数据,0.8用于训练,0.2用于验证
let (trainingData, testingData) = dataFrame.randomSplit(by: 0.2, seed: 5)
// 开始训练
logInfo("开始训练")
let regressor = try MLLinearRegressor(trainingData: DataFrame(trainingData), targetColumn: "rate")
/// 获取训练结果
let trainintError = regressor.trainingMetrics.error
let trainintValid = regressor.trainingMetrics.isValid
let worstTrainingError = regressor.trainingMetrics.maximumError
logInfo("训练结果->: error: \(String(describing: trainintError)),是否有效:\(trainintValid),识别率:\(worstTrainingError)")
let validationError = regressor.validationMetrics.error
let validationValid = regressor.validationMetrics.isValid
let worstValidationError = regressor.validationMetrics.maximumError
logInfo("验证结果->: error: \(String(describing: validationError)),是否有效:\(validationValid),识别率:\(worstValidationError)")
/// 评估
logInfo("开始评估")
let regressorEvalutation = regressor.evaluation(on: DataFrame(testingData))
/// 评估e的结果
let evalutationError = regressorEvalutation.error
let evalutationValid = regressorEvalutation.isValid
let worstEvaluationError = regressorEvalutation.maximumError
logInfo("评估结果->: error: \(String(describing: evalutationError)),是否有效:\(evalutationValid),识别率:\(worstEvaluationError)")
// 保存
let regressorMetaData = MLModelMetadata(author: "zhtg@me.com", shortDescription: "BTC5mk线涨跌预测模型", version: "1.0")
try regressor.write(to: modelFileURL, metadata: regressorMetaData)
// 测试
// let testcsvFile = documentsPath + "/test.csv"
// let testDataFrame = try DataFrame(contentsOfCSVFile: URL(filePath: testcsvFile), columns: ["first", "second"])
// let pResult = try regressor.predictions(from: testDataFrame)
// print("result: \(pResult)")
logInfo("完成生成")
// xcrun coremlcompiler compile CandleModelRegressor.mlmodel .
// xcrun coremlcompiler generate --language Swift CandleModelRegressor.mlmodel .
}
有了csv也可以用developer tool来生成,我这里为了方便,就干脆用代码一起生成好了,其中有效和结果字段没用上,这个应该要判断一下,如果不满足,则直接中断的
生成的模型是一个CandleModelRegressor.mlmodel名称的,mlmodel结尾,这个文件还不可以用于spm,用于xcode project倒是可以,
不可以用于spm可能是bug,他编译的时候说缺少一个语言,但是spm没有设置语言的地方,走不下去了,只能另外想办法,google到可以手动编译,编译完再放进去就可以
shell
xcrun coremlcompiler compile CandleModelRegressor.mlmodel .
xcrun coremlcompiler generate --language Swift CandleModelRegressor.mlmodel .
手动命令如上,他会生成以下两个文件,
集成
拖入spm项目根目录即可, 然后target还要添加一个resource的配置,需要使用.copy
,其他都不行
swift
resources: [
.copy("CandleModelRegressor.mlmodelc"),
]
使用
用的时候需要先生成一个regressor
swift
// 生成识别器
var regressor: CandleModelRegressor = {
let bundle = Bundle.module
let url = bundle.url(forResource: "CandleModelRegressor", withExtension:"mlmodelc")!
return try! CandleModelRegressor(contentsOf: url)
}()
然后识别的时候,传入当前k线往前的5个节点的涨跌幅,即可用模型预测出来一个结果了
swift
guard let c0 = candleModel.previousCandles[0].rate.double,
let c1 = candleModel.previousCandles[1].rate.double,
let c2 = candleModel.previousCandles[2].rate.double,
let c3 = candleModel.previousCandles[3].rate.double,
let c4 = candleModel.previousCandles[4].rate.double else {
throw CommonError(message: "涨跌幅转成double失败")
}
logInfo("开始识别:\(c0),\(c1),\(c2),\(c3),\(c4)")
let output = try regressor.prediction(_0: c0, _1: c1, _2: c2, _3: c3, _4: c4)
logInfo("识别结果:\(output.rate)")
接下来就是交易的代码了,交易的代码就不帖了,有兴趣大家可以去github自己看,回头我会部署到linux服务上,跑一段时间,再来看结果,结果也会更新到这里