使用苹果的CoreML来生成一个机器学习的模型,用于测试虚拟币的5分k线的涨跌

前几天看到一个CoreML的教程视频,感觉挺有意思的样子,于是去了解了一下,决定尝试将他用于预测涨跌,看下机器学习的能不能预测的准,就算不行,也无所谓,就相当于学习好了

demo传递门

模型类型

CoreML需要选定一个模型来进行训练 打开developer tool,然后点击New document就可以看到有多少类型 前面几个都是图片分类,图片识别,手势识别,文字识别啥的,我也没细看,今天我们就用表格的Tabular Regression来给他一些基础数据,让他来预测

模型选型

曾经也想过给复杂的数据给他去训练,有点复杂,还是先给一些简单的了,我的想法就是以最后一根k线的涨跌幅作为目标,往前数五根k线作为基础数据 生成模型需要准备一个csv文件,我准备像下面这样的格式,前面的k是基础数据,target是目标数据,这样的数据交给模型训练器进行训练

| k0 | k1 | k2 | k3 | k4 | target |
|-------|-------|-------|-------|-------|----------|---|
| 0.09 | -0.16 | 0.23 | -0.06 | 0.07 | -0.05 |
| -0.16 | 0.23 | -0.06 | 0.07 | -0.05 | -0.10 | |
| 0.23 | -0.06 | 0.07 | -0.05 | -0.10 | 0.05 |
| -0.06 | 0.07 | -0.05 | -0.10 | 0.05 | -0.06 |

单位是百分点

获取数据

要得到这样的数据,需要从交易平台请求过往的k线数据 我从币安的api接口查到接口和参数,然后查询从某个时刻开始的数据,一次请求1500条,1500条乘以5分钟

swift 复制代码
open func requestCandles(startTime: Int? = nil, limit: Int? = nil) async throws -> (response: HTTPURLResponse, candles: [Candle]) {
        let path = "GET /api/v1/klines"
        var params = ["symbol": instId, "interval": intervalStr] as [String: Any]
        if let startTime = startTime {
            params["startTime"] = startTime
        }
        params["limit"] = limit ?? self.limit
        let response = await RestAPI.send(path: path, params: params)
        if response.succeed {
            if let data = response.data as? [[Any]] {
                var candles = [Candle]()
                for arr in data {
                    let candle = Candle(array: arr)
                    candles.append(candle)
                }
                return (response.res.urlResponse!, candles)
            } else {
                throw CommonError(message: "data类型不对")
            }
        } else {
            logInfo("请求k线失败:\(response.errMsg ?? "")")
            throw CommonError(message: response.errMsg ?? "")
        }
    }

通过接口测试,我发现最早只能查到2020/01/01 00:00:00的数据,时间戳就是1577808000000, 于是我就起了一个定时器不断调用接口把数据拉下来

swift 复制代码
/// 执行一次请求
    func request() async throws {
        let lastTime = lastTime ?? initStartTime
        let next = lastTime + 1
        do {
            logInfo("开始请求\(next.dateDesc)的数据")
            // 请求当前时间的k线
            let (response, candles) = try await candleManager.requestCandles(startTime: next)
            let requestCompletion = await self.saveCandles(candles)
            if !requestCompletion {
                // 请求没完成,继续请求下一页
                try await self.continueWith(response: response)
                return
            }
            logInfo("请求完成,开始生成Candle模型")
            try await createCandleModels()
        } catch {
            logInfo("请求失败:\(error)")
        }
    }

Candle是一个基础的k线数据结构,CandleModel就是上面的模型了,所有Candle会组装到一个数组中,然后再来生成所有的CandleModel,生成完之后,就开始组装csv需要的格式了

生成csv文件

swift 复制代码
    /// 生成csv文件
    func createCSVFile() async throws {
        // 每行的内容
        var contents = [String]()
       
        // 文件表头
        let header = columns.joined(separator: ",")
        contents.append(header)
        logInfo("组装Header:\(header)")
       
        // 每天的
        for candleModel in candleModels {
            var lines = [String]()
            // 前面的涨跌幅
            lines += candleModel.previousCandles
                .map({ $0.rate })
            
            // 当天的涨跌幅
            if let current = candleModel.current {
                lines.append(current.rate)
            }
           
            // 当前的字符串
            let line = lines.joined(separator: ",")
            logInfo("组装字符串:\(line)")
            contents.append(line)
        }

        // 组装成最后字符
        let content = contents.joined(separator: "\n")
        if let data = content.data(using: .utf8) {
            try data.write(to: csvFileURL)
        }
        
        logInfo("生csv成功,准备生成CoreML模型")
        try await createML()
    }

生成了一个14M的文件,模型还是比较多的

生成CoreML模型

接下来就是生成CoreML模型了

swift 复制代码
/// 生成模型
    func createML() async throws {
        guard FileManager.default.fileExists(atPath: csvFileURL.path()) else {
            print("csv路径不存在")
            throw CreateError.csvNotFound
        }
        
        // 生成MLTable
        let dataFrame = try DataFrame(contentsOfCSVFile: csvFileURL, columns: columns)
        
        // 划分数据,0.8用于训练,0.2用于验证
        let (trainingData, testingData) = dataFrame.randomSplit(by: 0.2, seed: 5)
        
        // 开始训练
        logInfo("开始训练")
        let regressor = try MLLinearRegressor(trainingData: DataFrame(trainingData), targetColumn: "rate")
        
        /// 获取训练结果
        let trainintError = regressor.trainingMetrics.error
        let trainintValid = regressor.trainingMetrics.isValid
        let worstTrainingError = regressor.trainingMetrics.maximumError
        logInfo("训练结果->: error: \(String(describing: trainintError)),是否有效:\(trainintValid),识别率:\(worstTrainingError)")
        
        let validationError = regressor.validationMetrics.error
        let validationValid = regressor.validationMetrics.isValid
        let worstValidationError = regressor.validationMetrics.maximumError
        logInfo("验证结果->: error: \(String(describing: validationError)),是否有效:\(validationValid),识别率:\(worstValidationError)")
        
        /// 评估
        logInfo("开始评估")
        let regressorEvalutation = regressor.evaluation(on: DataFrame(testingData))
        
        
        /// 评估e的结果
        let evalutationError = regressorEvalutation.error
        let evalutationValid = regressorEvalutation.isValid
        let worstEvaluationError = regressorEvalutation.maximumError
        logInfo("评估结果->: error: \(String(describing: evalutationError)),是否有效:\(evalutationValid),识别率:\(worstEvaluationError)")
        
        // 保存
        let regressorMetaData = MLModelMetadata(author: "zhtg@me.com", shortDescription: "BTC5mk线涨跌预测模型", version: "1.0")
        try regressor.write(to: modelFileURL, metadata: regressorMetaData)
        
        // 测试
//        let testcsvFile = documentsPath + "/test.csv"
//        let testDataFrame = try DataFrame(contentsOfCSVFile: URL(filePath: testcsvFile), columns: ["first", "second"])
//        let pResult = try regressor.predictions(from: testDataFrame)
//        print("result: \(pResult)")
        
        logInfo("完成生成")
        
//        xcrun coremlcompiler compile CandleModelRegressor.mlmodel .
//        xcrun coremlcompiler generate --language Swift CandleModelRegressor.mlmodel .
    }

有了csv也可以用developer tool来生成,我这里为了方便,就干脆用代码一起生成好了,其中有效和结果字段没用上,这个应该要判断一下,如果不满足,则直接中断的

生成的模型是一个CandleModelRegressor.mlmodel名称的,mlmodel结尾,这个文件还不可以用于spm,用于xcode project倒是可以,

不可以用于spm可能是bug,他编译的时候说缺少一个语言,但是spm没有设置语言的地方,走不下去了,只能另外想办法,google到可以手动编译,编译完再放进去就可以

shell 复制代码
xcrun coremlcompiler compile CandleModelRegressor.mlmodel .
xcrun coremlcompiler generate --language Swift CandleModelRegressor.mlmodel .

手动命令如上,他会生成以下两个文件,

集成

拖入spm项目根目录即可, 然后target还要添加一个resource的配置,需要使用.copy,其他都不行

swift 复制代码
resources: [
    .copy("CandleModelRegressor.mlmodelc"),
      ]

使用

用的时候需要先生成一个regressor

swift 复制代码
    // 生成识别器
    var regressor: CandleModelRegressor = {
        let bundle = Bundle.module
        let url = bundle.url(forResource: "CandleModelRegressor", withExtension:"mlmodelc")!
        return try! CandleModelRegressor(contentsOf: url)
    }()

然后识别的时候,传入当前k线往前的5个节点的涨跌幅,即可用模型预测出来一个结果了

swift 复制代码
        guard let c0 = candleModel.previousCandles[0].rate.double,
              let c1 = candleModel.previousCandles[1].rate.double,
              let c2 = candleModel.previousCandles[2].rate.double,
              let c3 = candleModel.previousCandles[3].rate.double,
              let c4 = candleModel.previousCandles[4].rate.double else {
            throw CommonError(message: "涨跌幅转成double失败")
        }
        logInfo("开始识别:\(c0),\(c1),\(c2),\(c3),\(c4)")
        let output = try regressor.prediction(_0: c0, _1: c1, _2: c2, _3: c3, _4: c4)
        logInfo("识别结果:\(output.rate)")

接下来就是交易的代码了,交易的代码就不帖了,有兴趣大家可以去github自己看,回头我会部署到linux服务上,跑一段时间,再来看结果,结果也会更新到这里

相关推荐
IT古董44 分钟前
【机器学习】机器学习的基本分类-强化学习-策略梯度(Policy Gradient,PG)
人工智能·机器学习·分类
睡觉狂魔er1 小时前
自动驾驶控制与规划——Project 3: LQR车辆横向控制
人工智能·机器学习·自动驾驶
scan7241 小时前
LILAC采样算法
人工智能·算法·机器学习
菌菌的快乐生活2 小时前
理解支持向量机
算法·机器学习·支持向量机
爱喝热水的呀哈喽2 小时前
《机器学习》支持向量机
人工智能·决策树·机器学习
大山同学2 小时前
第三章线性判别函数(二)
线性代数·算法·机器学习
苏言の狗2 小时前
Pytorch中关于Tensor的操作
人工智能·pytorch·python·深度学习·机器学习
bastgia3 小时前
Tokenformer: 下一代Transformer架构
人工智能·机器学习·llm
paixiaoxin5 小时前
CV-OCR经典论文解读|An Empirical Study of Scaling Law for OCR/OCR 缩放定律的实证研究
人工智能·深度学习·机器学习·生成对抗网络·计算机视觉·ocr·.net
恋猫de小郭5 小时前
什么?Flutter 可能会被 SwiftUI/ArkUI 化?全新的 Flutter Roadmap
flutter·ios·swiftui