如何在Spark中使用gbdt模型分布式预测

这目录

  • [1 训练gbdt模型](#1 训练gbdt模型)
  • [2 第三方包python环境打包](#2 第三方包python环境打包)
  • [3 Spark中使用gbdt模型](#3 Spark中使用gbdt模型)
  • [4 spark任务提交](#4 spark任务提交)

1 训练gbdt模型

我们可以基于lightgbm快速的训练一个gbdt模型,训练相对比较简单,只要把训练样本处理好,几行代码可以快速训练好模型,如下是训练一个多分类模型训练核心代码如下:

python 复制代码
import lightgbm as lgb
import joblib
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
#假设处理好的训练样本为train.csv
df = pd.read_csv('./train.csv')
X = pd.drop(['label'],axis=1)
Y = df.label
# split data for val
x_train,x_val,y_train,y_val = train_test_split(X,Y,test_size=0.2,random_state=123)
# train model
cate_features=['sex','brand']
train_data = train_data = lgb.Dataset(x_train,label=y_train,categoryical_featrues=cate_features)
params = {
	'objective':'multiclass',
	'learning_rate':0.1,
	'n_estimators':100,
	'num_class':23
	}
model = lgb.train(params, train_data,100)
#predict val
y_pred = model.predict(x_val)
y_pred = y_pred.argmax(axis=1)

# acc
acc = accuracy_score(y_val, y_pred)
print(acc)

# feature importance
feature_name = model.feature_name()
feature_importance = model.feature_importance()
feature_score = dict(zip(feature_name, feature_importance))
feature_score_sort = sorted(feature_score.items(),key=lambda x:x[1], reverse=True)

# save model
joblib.dump(model, 'model.pkl')

上述就是基于lightgbm训练gbdt模型的代码,训练完后我们通过joblib保存了我们训练好的模型,这个模型接下来我们可以在spark进行分布式预测。

2 第三方包python环境打包

在使用spark的时候,我们可以自定义python环境,并且把我们需要的第三方包都可以安装该python环境里,这样在spark里我们就可以用python第三方包,比如等会我们需要的joblib, numpy等。具体如何配置python环境和第三方包,可以参考我上一篇博客:如何在spark中使用scikit-learn和tensorflow等第三方python包

3 Spark中使用gbdt模型

通过上述步骤,把需要的python环境和第三方包制作好了,包名为python39.zip,接下来我们介绍一下如何在spark中使用我们刚才训练好的gbdt模型进行分布式快速预测。

3.1 spark配置文件

提交spark任务的时候,配置文件这块也需要稍微修改一下,配置文件信息如下:

bash 复制代码
$SPARK_HOME/bin/spark-submit \
--master yarn \
--deploy-memory 12G
--executor-memory 20G \
--executor-cores 4 \
--queue root.your_queue_name \
--archives ./python39.zip#python39 \
--conf spark.yarn.appMasterEnv.PYSPARK_PYTHON=./python39/python39/bin/python3.9 \
--conf spark.yarn.appMasterEnv,HADOOP_USER_NAME=your_hduser_name \
--conf spark.shuffle.service.enabled=true \
--conf spark.dynamicAllocation.enbled=true \
--conf spark.dynamicAllocation.maxExecutors=50 \
--conf spark.dynamicAllocation.minExecutors=50 \
--conf spark.braodcast.compress=True \
--conf saprk.network.timeout=1000s \
--conf spark.sql.hive.mergeFiles=true \
--conf spark.speculation=false \
--conf spark.yarn.executor.memoryOverhead=4096 \
--files $HIVE_CONF_DIR/hive-site.xml \
--py-files ./model.pkl \
$@

上述是基本的提交spark任务的配置文件,其中
--archives ./python39.zip#python39 \

--archives参数用于在Spark应用程序运行期间将本地压缩档案文件解压到YARN集群节点上。#python39 是为档案文件定义的别名,这将在Spark应用程序中使用。
这个参数的目的是将名为python39.zip的压缩文件解压到YARN集群节点,并将其路径设置为python39,以供Spark应用程序使用。这通常用于指定特定版本的Python环境,以便在Spark任务中使用。
--conf spark.yarn.appMasterEnv.PYSPARK_PYTHON=./python39/python39/bin/python3.9

--conf参数用于设置Spark配置属性。

spark.yarn.appMasterEnv.PYSPARK_PYTHON 是一个Spark配置属性,它指定了YARN应用程序的主节点(ApplicationMaster)使用的Python解释器。

./python39/python39/bin/python3.9 是实际Python解释器的路径,它将在YARN应用程序的主节点上执行。

这个参数的目的是告诉Spark应用程序在YARN的主节点上使用特定的Python解释器即./python39/python39/bin/python3.9。这通常用于确保Spark应用程序使用正确的Python版本和环境来运行任务。
--py-files ./model.pkl \

--py-file是 Spark 提交任务时的一个参数,用于将指定的 .py 文件、.zip 文件或 .egg 文件分发到集群的所有 Worker 节点。Spark 会将这些文件自动添加到 Python 的模块路径中(即 sys.path),使得这些文件可以被任务中的代码引用。所以在这里我们将 model.pkl 模型文件分发到 Spark 集群的每个节点,确保每个节点在运行任务时都能访问并使用这个模型。

3.2 主函数main.py

接下来,我们来看下如何在在spark调用我们训练好的gbdt模型进行预测,核心代码主要如下:

1)import基础函数功能包等

python 复制代码
# -*-coding:utf8 -*-
import sys
from pyspark.sql.types import Row
from pyspark.sql.types import *
from pyspark import SparkConf, SparkContext, HiveContext
import datetime
import numpy as np
import joblib

save_table='your_target_table_name'
source_table = 'your_source_predict_table_name'
index_table = 'your_index_table_name'
# define saved table schema
schema = StructType([
	StructFiled('userid', StringType(), True),
	StructFiled('names', ArrayType(StringType()), True),
	StructFiled('scores', ArrayType(FloatType()), True)
	])
  1. main执行入口基础配置和执行流程
python 复制代码
if __name__=='__main__':
	conf = SparkConf()
	sc = SparkContext(conf=conf, appName='gbdt_spark_predict')
	sc.setLogLevel("WARN")
	hiveCtx = HiveContext(sc)
	#hive基础配置
	hiveCtx.setConf('spark.shuffle.consolidateFiles','true')
	hiveCtx.setConf('spark.shuffle.memoryFraction','0.4')
	hiveCtx.setConf('spark.sql.shuffle.partitions','1000')
	if len(sys.argv) == 1:
		dt = datetime.datetime.now() + datetime.timedelta(-1)
	else:
		dt = datetime.datetime.strptime(sys.argv[1],"%Y%m%d').date()
	dt_str = dt.strftime('%Y-%m-%d')
	hiveCtx.sql("use your_datebase_name")

	#注册函数,在sql时候可以使用
	hiveCtx.registerFunction('null_chage', null_chage, StringType())
	#创建目标表
	create_table(hiveCtx)
	#主函数
	get_predict(hiveCtx)
	

上面主函数给出了一个基本的流程步骤:1)spark, hive context等初始化 2)注册函数可以直接在sql中使用,方便数据处理 3)建立目标hive表 4)执行功能函数。

  1. 函数功能模块实现

在第2步骤里,我们主要有三个函数需要编写,一个是可以在sql中调用的基础函数,第二个就是创建表函数,第三个就是功能函数,我们接下来实现这三个的基本功能:

python 复制代码
#用在sql中的基本数据操作处理
def null_chage(x):
	return 'unknow' if x is None else x

#创建目标表
def create_table(hiveCtx):
	create_tbl = """
		CREATE EXTERNAL TABLE IF NOT EXISTS your_database_name.{table_name} (
		userid       string       COMMENT 'user id';
		names        array<string>   COMMENT 'predict label names')
		scores        array<float>   COMMENT 'predict socre')
	PARTITIONED BY(dt string, dp string)
	STORED AS ORC
	LOCATION 'hdfs://your_database_name.db/{table_name}'
	TBLPROPERTIES('orc.compress'='SNAPPY','comment'='gbdt predict user score')
	""".format(table_name=save_table)
# 功能函数
def get_predict(hiveCtx):
	# get label and idex data
	sql="""
		select index, value
		from {index_table}
		where dt='active;
		""".format(index_table=index_table)
	print(sql)
	vocab = hiveCtx.sql(sql).rdd.collect()
	vocab_dict = dict()
	for x in vocab:
		vocab_dict.setdefault(x[0],x[1])
	# broadcast
	br_vocab_dict = sc.broadcast(vocab_dict)
	# get predict data
	sql="""
		select null_chage(userid) as userid, features
		from {source_table}
		where dt='active'
		""".format(source_table=source_table)
	print(sql)
	hiveCtx.sql(sql).rdd.mapPartitions(lambda rows: main_func(rows, br_vocab_dict)) \
	.toDF(schema=schema) \
	.registerTempTable('final_tbl')
	
	# insert table
	insert_sql = """
		insert overwrite table {save_table} partition (dt='{dt}')
		select * from final_tbl
		""".format(save_table=save_table,dt='active')
	print(insert_sql)
	hiveCtx.sql(insert_sql)

接下来,我们来看下main_func函数的实现:

python 复制代码
def main_func(rows, br_vocab_dict):
	# load model
	model = joblib.load('./model.pkl')
	vocab_dict = br_vobab_dict.value
	for row in rows:
		userid, features = row
		features = np.array(features)
		predict = model.predict(features)
		predict_sort = np.argsort(-predict[0])
		names = [vocab_dict[idx] for idx in predict_sort]
		scores = [float(predict[0][idx]) for idx in predict_sort]
		yield userid, names, scores

整个代码的实现我们在这里就写完了,整体实现逻辑是比较清晰易懂的,按照这个流程来,我们可以很高效快速的基于spark分布式的跑一些数据处理和模型预测性的任务。

4 spark任务提交

接下来,就是提交我们的spark任务了,在工作环境目录如下文件信息:

  • 提前准备好的python环境包python39.zip
  • spark config文件 run_spark_arg.sh
  • 主函数代码 main.py
  • gbdt模型文件model.pkl

最后环节就是提交spark任务,我么可以在服务器提交命令如下:

powershell 复制代码
nohup sh run_spark_arg.sh main.py >log.txt 2>&1 &
相关推荐
MasterNeverDown3 小时前
如何将 DotNetFramework 项目打包成 NuGet 包并发布
大数据·hadoop·hdfs
中科岩创3 小时前
广西钦州刘永福故居钦江爆破振动自动化监测
大数据·物联网
大数据编程之光4 小时前
Flink-CDC 全面解析
大数据·flink
GZ_TOGOGO5 小时前
华为大数据考试模拟真题(附答案)题库领取
大数据·华为
王子良.7 小时前
大数据生态系统:Hadoop(HDFS)、Hive、Spark、Flink、Kafka、Redis、ECharts、Zookeeper之间的关系详解
大数据·hive·hadoop·经验分享·学习·hdfs·spark
大力财经7 小时前
激发本地生意,抖音生活服务连锁商家生意同比增长超80%
大数据·人工智能
weixin_437398217 小时前
Elasticsearch学习(1) : 简介、索引库操作、文档操作、RestAPI、RestClient操作
java·大数据·spring boot·后端·学习·elasticsearch·全文检索
安的列斯凯奇7 小时前
Elasticsearch—索引库操作(增删查改)
大数据·elasticsearch·搜索引擎
金州饿霸8 小时前
hadoop-yarn常用命令
大数据·前端·hadoop
SeaTunnel8 小时前
对话新晋 Apache SeaTunnel Committer:张圣航的开源之路与技术洞察
大数据