第6章 Spark SQL 详细课堂讲义
第一部分:Spark SQL 概述(6.1)
1.1 从 Shark 说起
知识点讲解:
Shark 是 Hive on Spark 的实现,它复用了 Hive 的 HQL 解析、逻辑计划优化等模块,只把物理执行计划从 MapReduce 替换成了 Spark RDD 操作。虽然 Shark 的性能比 Hive 提高了 10-100 倍,但存在两个严重问题:
- 优化依赖 Hive:无法方便地添加新的优化策略
- 线程安全问题:Spark 是线程级并行,而 MapReduce 是进程级并行,导致 Spark 在兼容 Hive 时存在线程安全问题
2014年,Shark 停止开发,转向 Spark SQL 和 Hive on Spark 两个分支。
1.2 Spark SQL 架构
核心组件:
- Catalyst 优化器:负责执行计划的生成和优化
- DataFrame:带 Schema 信息的分布式数据集
- 外部数据源支持:Hive、JSON、Parquet、JDBC 等
1.3 为什么推出 Spark SQL
三大原因:
- 用户需要从多种数据源(结构化、半结构化、非结构化)执行 ETL 操作
- 用户需要执行高级分析(机器学习、图处理)
- 实际大数据应用经常需要融合关系查询和复杂算法
Spark SQL 的解决方案:
- 提供 DataFrame API,支持关系型操作
- 引入 Catalyst 优化器,支持扩展
- 融合 SQL 查询和复杂算法
1.4 Spark SQL 的特点
- 容易整合:SQL 查询和 Spark 程序无缝集成
- 统一数据访问:以相同方式访问 Hive、Avro、Parquet、ORC、JSON、JDBC 等
- 兼容 Hive:支持 HiveQL、SerDes、UDF
- 标准数据库连接:支持 JDBC/ODBC
1.5 SparkSession 入门
知识点讲解:
Spark 2.0 以上版本使用统一的 SparkSession 替代了 SQLContext 和 HiveContext。SparkSession 封装了 SparkContext、SparkConf、StreamingContext 等,是 Spark 应用程序的统一入口。
完整代码示例:
python
# coding:utf8
from pyspark.sql import SparkSession
if __name__ == '__main__':
# 创建 SparkSession 对象
spark = SparkSession \
.builder \
.master("local[*]") \
.appName("Simple Application") \
.getOrCreate()
# 读取文件
logFile = "file:///usr/local/spark/README.md"
logDF = spark.read.text(logFile)
# 统计包含字母 a 和 b 的行数
numAs = logDF.filter(logDF["value"].contains("a")).count()
numBs = logDF.filter(logDF["value"].contains("b")).count()
print('Lines with a: %s, Lines with b: %s' % (numAs, numBs))
# 停止 SparkSession
spark.stop()
说明: 在 spark-shell 中,系统已经默认提供了一个 SparkSession 对象,名称为 spark,可以直接使用。
第二部分:DataFrame 详解(6.2)
2.1 什么是 DataFrame
知识点讲解:
DataFrame 是一种以 RDD 为基础的分布式数据集,提供了详细的结构信息(Schema)。与 RDD 相比:
| 特性 | RDD | DataFrame |
|---|---|---|
| 数据结构 | 无结构,只是 Java 对象集合 | 有明确的列结构(Schema) |
| 优化能力 | 无法优化 | Spark 可以理解和优化 |
| 表达能力 | 需要告诉 Spark "如何做" | 只需要告诉 Spark "做什么" |
| 代码简洁度 | 复杂 | 简洁 |
2.2 DataFrame 的优点(重点案例)
案例:计算每种图书的平均销量
数据集:键值对 ("spark",2)、("hadoop",6)、("hadoop",4)、("spark",6)
RDD 实现方式(复杂):
python
# coding:utf8
from pyspark import SparkConf, SparkContext
if __name__ == '__main__':
conf = SparkConf().setAppName("Simple App")
sc = SparkContext(conf=conf)
bookRDD = sc.parallelize([("spark", 2), ("hadoop", 6), ("hadoop", 4), ("spark", 6)])
saleRDD = bookRDD.map(lambda x: (x[0], (x[1], 1))) \
.reduceByKey(lambda x, y: (x[0] + y[0], x[1] + y[1])) \
.map(lambda x: (x[0], x[1][0] / x[1][1]))
saleRDD.foreach(print)
sc.stop()
DataFrame 实现方式(简洁):
python
# coding:utf8
from pyspark.sql import SparkSession
from pyspark.sql.functions import avg
if __name__ == '__main__':
spark = SparkSession \
.builder \
.master("local[*]") \
.appName("Simple Application") \
.getOrCreate()
# 创建 DataFrame
bookDF = spark \
.createDataFrame([("spark", 2), ("hadoop", 6), ("hadoop", 4), ("spark", 6)]) \
.toDF("book", "amount")
# 分组计算平均值
avgDF = bookDF.groupBy("book").agg(avg("amount"))
avgDF.show()
spark.stop()
执行结果:
+------+-----------+
| book| avg(amount)|
+------+-----------+
|spark| 4.0|
|hadoop| 5.0|
+------+-----------+
第三部分:DataFrame 的创建与保存(6.3)
3.1 Parquet 文件格式
知识点: Parquet 是 Spark 的默认数据源,是一种开源的列式存储文件格式,支持压缩(如 snappy),提供高效的 I/O 优化。
读取 Parquet 文件:
python
# 方式一:使用 format() 指定格式
filePath = "file:///usr/local/spark/examples/src/main/resources/users.parquet"
df = spark.read.format("parquet").load(filePath)
df.show()
# 方式二:直接使用 parquet() 方法
df = spark.read.parquet(filePath)
保存 DataFrame 为 Parquet 文件:
python
# 保存并指定压缩方式
df.write.format("parquet") \
.mode("overwrite") \
.option("compression", "snappy") \
.save("file:///home/hadoop/otherusers")
# 简化写法
df.write.parquet("file:///home/hadoop/otherusers")
3.2 JSON 文件
知识点: JSON 支持单行模式和多行模式两种格式。
读取 JSON 文件:
python
filePath = "file:///usr/local/spark/examples/src/main/resources/people.json"
# 方式一
df = spark.read.format("json").load(filePath)
# 方式二
df = spark.read.json(filePath)
df.show()
保存 DataFrame 为 JSON 文件:
python
df.write.format("json").mode("overwrite").save("file:///home/hadoop/otherpeople")
# 简化写法
df.write.json("file:///home/hadoop/otherpeople")
3.3 CSV 文件
知识点: CSV 是最常见的文本数据格式,字段用逗号分隔,每行一条记录。
读取 CSV 文件:
python
filePath = "file:///usr/local/spark/examples/src/main/resources/people.csv"
# 定义 Schema(表结构)
schema = "name STRING, age INT, job STRING"
# 方式一:使用 format()
df = spark.read.format("csv") \
.schema(schema) \
.option("header", "true") \
.option("sep", ";") \
.load(filePath)
# 方式二:直接使用 csv() 方法
df = spark.read \
.schema(schema) \
.option("header", "true") \
.option("sep", ";") \
.csv(filePath)
df.show()
保存 DataFrame 为 CSV 文件:
python
df.write.format("csv").mode("overwrite").save("file:///home/hadoop/anotherpeople")
# 简化写法
df.write.csv("file:///home/hadoop/anotherpeople")
3.4 文本文件
读取文本文件:
python
filePath = "file:///usr/local/spark/examples/src/main/resources/people.txt"
df = spark.read.format("text").load(filePath)
# 或者
df = spark.read.text(filePath)
df.show()
保存为文本文件:
python
df.write.text("file:///home/hadoop/newpeople")
# 或者
df.write.format("text").save("file:///home/hadoop/newpeople")
注意: 保存为文本文件时,DataFrame 只能包含一列数据。
3.5 从序列集合创建 DataFrame
python
# 基本创建方式
df = spark.createDataFrame([
("Xiaomei", "Female", "21"),
("Xiaoming", "Male", "22"),
("Xiaoxue", "Female", "23")
]).toDF("name", "sex", "age")
df.show()
# 指定数据类型的方式
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
schema = StructType([
StructField("name", StringType(), True),
StructField("sex", StringType(), True),
StructField("age", IntegerType(), True)
])
data = [("Xiaomei", "Female", 21), ("Xiaoming", "Male", 22), ("Xiaoxue", "Female", 23)]
df = spark.createDataFrame(data, schema)
第四部分:DataFrame 的基本操作(6.4)
4.1 DSL 语法风格
准备数据:
python
filePath = "file:///usr/local/spark/examples/src/main/resources/people.json"
df = spark.read.json(filePath)
1. printSchema() - 打印结构信息
python
df.printSchema()
输出:
root
|-- age: long (nullable = true)
|-- name: string (nullable = true)
2. show() - 显示数据
python
df.show()
输出:
+----+-------+
| age| name|
+----+-------+
|null|Michael|
| 30| Andy|
| 19| Justin|
+----+-------+
3. select() - 选择列
python
# 选择单列
df.select("name").show()
# 选择多列
df.select("name", "age").show()
# 在 select 中进行计算
df.select(df["name"], df["age"] + 1).show()
# 重命名列
df.select(df["name"].alias("username"), df["age"]).show()
4. filter() / where() - 条件筛选
python
# 使用 filter
df.filter(df["age"] > 20).show()
# 使用 where(多种写法)
df.where(df["age"] > 20).show()
df.where("age > 20").show()
5. groupBy() - 分组聚合
python
# 分组并统计数量
df.groupBy("age").count().show()
# 使用聚合函数
from pyspark.sql.functions import avg, sum, max, min
df.groupBy("age").agg(avg("age")).show()
6. sort() / orderBy() - 排序
python
# 降序排序
df.sort(df["age"].desc()).show()
df.sort(df["age"].desc(), df["name"].asc()).show()
# 使用 orderBy
df.orderBy("age", ascending=False).show()
df.orderBy(df["age"], ascending=False).show()
7. withColumn() - 添加新列
python
from pyspark.sql.functions import expr
# 根据条件添加新列
df2 = df.withColumn(
"IfWithAge",
expr("CASE WHEN age IS NOT NULL THEN 'YES' ELSE 'NO' END")
)
df2.show()
8. withColumnRenamed() - 重命名列
python
df3 = df.withColumnRenamed("name", "username")
df3.show()
9. drop() - 删除列
python
df4 = df3.drop("IfWithAge")
df4.show()
10. 统计函数
python
# 描述性统计
df.describe().show()
# 单独统计
from pyspark.sql.functions import min, max, avg, sum
df.select(min("age"), max("age"), avg("age"), sum("age")).show()
4.2 SQL 语法风格(重点)
知识点: 使用 SQL 语句操作 DataFrame 之前,必须先创建临时视图。
创建临时视图
python
# 创建临时视图(如果已存在会报错)
df.createTempView("people")
# 创建或替换临时视图(推荐)
df.createOrReplaceTempView("people")
# 创建全局临时视图(跨会话可用)
df.createGlobalTempView("global_people")
基础 SQL 查询
python
# 1. 简单查询
result = spark.sql("SELECT * FROM people")
result.show()
# 2. 条件查询
result = spark.sql("SELECT name, age FROM people WHERE age > 20")
result.show()
# 3. 分组聚合
result = spark.sql("""
SELECT age, COUNT(*) as count
FROM people
GROUP BY age
ORDER BY age DESC
""")
result.show()
SQL 函数使用(完整案例)
案例:使用系统函数和用户自定义函数
python
from pyspark.sql import SparkSession, Row
from pyspark.sql.functions import from_unixtime
from pyspark.sql.types import IntegerType, StringType, LongType, StructField, StructType
# 创建 SparkSession
spark = SparkSession.builder.master("local[*]").appName("SQL Functions").getOrCreate()
# 定义 Schema
schema = StructType([
StructField("name", StringType(), True),
StructField("age", IntegerType(), True),
StructField("create_time", LongType(), True)
])
# 准备数据
data = [
Row("Xiaomei", 21, 1580432800),
Row("Xiaoming", 22, 1580436400),
Row("Xiaoxue", 23, 1580438800)
]
df = spark.createDataFrame(data, schema)
df.show()
# 创建临时视图
df.createOrReplaceTempView("users")
# 1. 使用系统函数 from_unixtime 格式化时间戳
result1 = spark.sql("""
SELECT name,
age,
from_unixtime(create_time) as create_time_str
FROM users
""")
result1.show()
# 2. 定义用户自定义函数(UDF)
def to_upper(name):
return name.upper()
# 注册 UDF
spark.udf.register("upper_name", to_upper, StringType())
# 使用 UDF
result2 = spark.sql("""
SELECT upper_name(name) as upper_name, age
FROM users
""")
result2.show()
# 3. 组合使用系统函数和 UDF
result3 = spark.sql("""
SELECT upper_name(name) as upper_name,
age,
from_unixtime(create_time) as create_time_str
FROM users
WHERE age > 21
""")
result3.show()
spark.stop()
常用的 Spark SQL 系统函数分类
| 分类 | 函数示例 |
|---|---|
| 数学函数 | round(), abs(), sqrt(), pow() |
| 字符串函数 | length(), substring(), concat(), upper(), lower() |
| 日期函数 | year(), month(), day(), date_add(), datediff() |
| 聚合函数 | sum(), avg(), count(), max(), min() |
| 窗口函数 | row_number(), rank(), lag(), lead() |
| 条件函数 | when(), if(), coalesce() |
窗口函数使用示例
python
# 创建示例数据
data = [
("Alice", "Sales", 5000),
("Bob", "Sales", 6000),
("Charlie", "Sales", 5500),
("David", "IT", 7000),
("Eve", "IT", 6500)
]
df = spark.createDataFrame(data, ["name", "dept", "salary"])
df.createOrReplaceTempView("employees")
# 使用窗口函数:每个部门内按工资排名
result = spark.sql("""
SELECT name, dept, salary,
ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC) as rank
FROM employees
""")
result.show()
第五部分:从 RDD 转换得到 DataFrame(6.5)
5.1 利用反射机制推断 RDD 模式
知识点: 当数据结构和字段名称已知时,可以使用反射机制自动推断 Schema。
完整案例:
python
from pyspark.sql import Row
# 读取文本文件
people = spark.sparkContext \
.textFile("file:///usr/local/spark/examples/src/main/resources/people.txt") \
.map(lambda line: line.split(",")) \
.map(lambda p: Row(name=p[0], age=int(p[1])))
# 创建 DataFrame(自动推断 Schema)
schemaPeople = spark.createDataFrame(people)
# 注册为临时表
schemaPeople.createOrReplaceTempView("people")
# 执行 SQL 查询
personsDF = spark.sql("SELECT name, age FROM people WHERE age > 20")
# 将结果转换为 RDD 并处理
personsRDD = personsDF.rdd.map(lambda p: "Name: " + p.name + ", Age: " + str(p.age))
personsRDD.foreach(print)
# 输出:
# Name: Michael, Age: 29
# Name: Andy, Age: 30
5.2 使用编程方式定义 RDD 模式
知识点: 当数据结构无法提前获知时,需要通过编程方式动态定义 Schema。
完整案例:
python
from pyspark.sql.types import *
from pyspark.sql import Row
# 第一步:定义 Schema(表头)
schemaString = "name age"
fields = [StructField(field_name, StringType(), True) for field_name in schemaString.split(" ")]
schema = StructType(fields)
# 第二步:读取数据并转换为 Row 对象(表中的记录)
lines = spark.sparkContext.textFile("file:///usr/local/spark/examples/src/main/resources/people.txt")
parts = lines.map(lambda x: x.split(","))
people = parts.map(lambda p: Row(p[0], p[1].strip()))
# 第三步:组合 Schema 和数据
schemaPeople = spark.createDataFrame(people, schema)
# 第四步:注册临时表并执行查询
schemaPeople.createOrReplaceTempView("people")
results = spark.sql("SELECT name, age FROM people")
results.show()
输出:
+-------+---+
| name|age|
+-------+---+
|Michael| 29|
| Andy| 30|
| Justin| 19|
+-------+---+
第六部分:使用 Spark SQL 读写数据库(6.6)
6.1 准备工作
- 安装 MySQL 数据库(参考:http://dblab.xmu.edu.cn/blog/install-mysql/)
- 启动 MySQL 并创建数据库和表
sql
-- 登录 MySQL
mysql -u root -p
-- 创建数据库
CREATE DATABASE spark;
-- 使用数据库
USE spark;
-- 创建 student 表
CREATE TABLE student (
id INT(4),
name CHAR(20),
gender CHAR(4),
age INT(4)
);
-- 插入测试数据
INSERT INTO student VALUES(1, 'Xueqian', 'F', 23);
INSERT INTO student VALUES(2, 'Weiliang', 'M', 24);
-- 查看数据
SELECT * FROM student;
-
下载 MySQL JDBC 驱动
- 下载
mysql-connector-java-5.1.40.tar.gz - 解压后将
.jar文件拷贝到/usr/local/spark/jars/目录
- 下载
-
启动 PySpark
bash
cd /usr/local/spark
./bin/pyspark
6.2 读取 MySQL 数据库中的数据
在 PySpark Shell 中执行:
python
jdbcDF = spark.read \
.format("jdbc") \
.option("driver", "com.mysql.jdbc.Driver") \
.option("url", "jdbc:mysql://localhost:3306/spark") \
.option("dbtable", "student") \
.option("user", "root") \
.option("password", "123456") \
.load()
jdbcDF.show()
在 PyCharm 独立应用程序中执行:
python
from pyspark.sql import SparkSession
if __name__ == '__main__':
spark = SparkSession \
.builder \
.appName('SparkReadMySQL') \
.master('local[*]') \
.getOrCreate()
jdbcDF = spark.read \
.format("jdbc") \
.option("driver", "com.mysql.jdbc.Driver") \
.option("url", "jdbc:mysql://localhost:3306/spark?useSSL=false") \
.option("dbtable", "student") \
.option("user", "root") \
.option("password", "123456") \
.load()
jdbcDF.show()
spark.stop()
6.3 向 MySQL 数据库写入数据
完整写入程序:
python
#!/usr/bin/env python3
from pyspark.sql import Row
from pyspark.sql.types import *
from pyspark.sql import SparkSession
# 创建 SparkSession
spark = SparkSession.builder \
.master("local[*]") \
.appName("Write to MySQL") \
.getOrCreate()
# 定义 Schema
schema = StructType([
StructField("id", IntegerType(), True),
StructField("name", StringType(), True),
StructField("gender", StringType(), True),
StructField("age", IntegerType(), True)
])
# 准备数据
studentRDD = spark.sparkContext \
.parallelize(["3 Rongcheng M 26", "4 Guanhua M 27"]) \
.map(lambda x: x.split(" "))
# 创建 Row 对象
rowRDD = studentRDD.map(lambda p: Row(
int(p[0].strip()),
p[1].strip(),
p[2].strip(),
int(p[3].strip())
))
# 创建 DataFrame
studentDF = spark.createDataFrame(rowRDD, schema)
# 设置数据库连接参数
prop = {
'user': 'root',
'password': '123456',
'driver': "com.mysql.jdbc.Driver"
}
# 写入数据库
studentDF.write.jdbc(
"jdbc:mysql://localhost:3306/spark",
"student",
'append',
prop
)
# 验证:读取并显示
verifyDF = spark.read \
.format("jdbc") \
.option("driver", "com.mysql.jdbc.Driver") \
.option("url", "jdbc:mysql://localhost:3306/spark") \
.option("dbtable", "student") \
.option("user", "root") \
.option("password", "123456") \
.load()
verifyDF.show()
spark.stop()
写入后 MySQL 中的数据:
+------+-----------+--------+------+
| id | name | gender | age |
+------+-----------+--------+------+
| 1 | Xueqian | F | 23 |
| 2 | Weiliang | M | 24 |
| 3 | Rongcheng | M | 26 |
| 4 | Guanhua | M | 27 |
+------+-----------+--------+------+
第七部分:PySpark 与 pandas 整合(6.7)
7.1 安装 pandas
bash
conda activate pyspark
pip install pandas
pip install pyarrow # 用于高性能数据转换
7.2 pandas 数据结构简介
Series(一维数组)
python
import pandas as pd
# 创建 Series(自动索引)
obj = pd.Series([3, 5, 6, 8, 9, 2])
print(obj)
# 创建 Series(指定索引)
obj2 = pd.Series([3, 5, 6, 8, 9, 2], index=['a', 'b', 'c', 'd', 'e', 'f'])
print(obj2)
# 通过索引访问
print(obj2['a']) # 单个值
print(obj2[['b', 'd', 'f']]) # 多个值
DataFrame(二维表格)
python
import pandas as pd
# 从字典创建 DataFrame
data = {
'sno': ['95001', '95002', '95003', '95004'],
'name': ['Xiaoming', 'Zhangsan', 'Lisi', 'Wangwu'],
'sex': ['M', 'F', 'F', 'M'],
'age': [22, 25, 24, 23]
}
frame = pd.DataFrame(data)
print(frame)
# 指定列顺序
frame = pd.DataFrame(data, columns=['name', 'sno', 'sex', 'age'])
print(frame)
# 指定行索引
frame = pd.DataFrame(data,
columns=['sno', 'name', 'sex', 'age', 'grade'],
index=['a', 'b', 'c', 'd'])
print(frame)
# 访问列
print(frame['sno'])
print(frame.name)
# 访问行
print(frame.loc['b'])
7.3 实例1:两种 DataFrame 相互转换
python
# coding:utf8
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType
import pandas as pd
if __name__ == '__main__':
spark = SparkSession \
.builder \
.master("local[*]") \
.appName("Pandas to Spark") \
.getOrCreate()
# 1. 创建 pandas DataFrame
pd_df = pd.DataFrame({
'id': [1, 2, 3, 4, 5],
'name': ['zhangsan', 'lisi', 'wangwu', 'maliu', 'liuqi'],
'age': [24, 26, 22, 28, 24]
})
print("=== pandas DataFrame ===")
print(pd_df)
# 2. 定义 PySpark DataFrame 的 Schema
schema = StructType([
StructField("id", IntegerType(), True),
StructField("name", StringType(), True),
StructField("age", IntegerType(), True)
])
# 3. pandas DataFrame -> PySpark DataFrame
spark_df = spark.createDataFrame(pd_df, schema)
print("=== PySpark DataFrame ===")
spark_df.show()
# 4. PySpark DataFrame -> pandas DataFrame
pd_df2 = spark_df.toPandas()
print("=== 转换回 pandas DataFrame ===")
print(pd_df2)
spark.stop()
7.4 实例2:使用自定义聚合函数(pandas_udf)
数据文件 people_data.txt 内容:
1,zhangsan,22
2,lisi,26
3,wangwu,28
4,maliu,24
完整代码:
python
# coding:utf8
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType, StringType, StructField, StructType, Row
import pandas as pd
if __name__ == '__main__':
spark = SparkSession \
.builder \
.master("local[*]") \
.appName("Pandas UDF Example") \
.getOrCreate()
# 读取数据
peopleRDD = spark.sparkContext \
.textFile("file:///home/hadoop/people_data.txt") \
.map(lambda line: line.split(",")) \
.map(lambda p: Row(id=int(p[0]), name=p[1], age=int(p[2])))
# 定义 Schema
schema = StructType([
StructField("id", IntegerType(), True),
StructField("name", StringType(), True),
StructField("age", IntegerType(), True)
])
# 创建 DataFrame
peopleDF = peopleRDD.toDF(schema)
peopleDF.show()
# 步骤1:使用装饰器定义 pandas UDF
@F.pandas_udf(IntegerType())
def my_sum(a: pd.Series) -> int:
"""计算年龄总和"""
return a.sum()
# 步骤2:使用 register 方法注册 UDF
spark.udf.register("sum_func", my_sum)
# 方式1:在 DSL 风格中使用
print("=== DSL 风格使用 UDF ===")
sumDF = peopleDF.select(my_sum("age").alias("total_age"))
sumDF.show()
# 方式2:在 SQL 风格中使用
print("=== SQL 风格使用 UDF ===")
peopleDF.createOrReplaceTempView("people")
sumDF2 = spark.sql("SELECT sum_func(age) as total_age FROM people")
sumDF2.show()
spark.stop()
第八部分:综合实例 - 电影评分数据分析(6.8)
数据集说明
电影评分数据集 ratings.dat,每行包含4个字段,用 :: 分隔:
- 用户ID(User ID)
- 电影ID(Movie ID)
- 评分(Rating,1-5分)
- 时间戳(Timestamp)
数据示例:
1::1193::5::978300760
1::661::3::978302109
1::914::3::978301968
2::1193::4::978300760
2::607::3::978302109
完整程序 MovieRating.py
python
# coding:utf8
from pyspark.sql.types import IntegerType, StringType, StructField, StructType
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
if __name__ == '__main__':
spark = SparkSession \
.builder \
.master("local[*]") \
.appName("Movie Rating Analysis") \
.getOrCreate()
# ========== 1. 读取数据集 ==========
print("=" * 50)
print("1. 数据集如下:")
print("=" * 50)
filePath = "file:///home/hadoop/ratings.dat"
# 定义 Schema
schema = StructType([
StructField("user_id", StringType(), True),
StructField("movie_id", StringType(), True),
StructField("rating", IntegerType(), True),
StructField("ts", StringType(), True)
])
# 读取 CSV 文件,分隔符为 "::"
ratingsDF = spark.read.format("csv") \
.option("sep", "::") \
.option("header", False) \
.option("encoding", "utf-8") \
.schema(schema) \
.load(filePath)
ratingsDF.show(10)
ratingsDF.createOrReplaceTempView("ratings")
# ========== 2. 求每个用户的平均打分 ==========
print("=" * 50)
print("2. 求每个用户的平均打分")
print("=" * 50)
# DSL 风格
ratingsDF.groupBy("user_id") \
.avg("rating") \
.withColumnRenamed("avg(rating)", "avg_rating") \
.withColumn("avg_rating", F.round("avg_rating", 3)) \
.orderBy("avg_rating", ascending=False) \
.show(10)
# ========== 3. 求每部电影的平均打分 ==========
print("=" * 50)
print("3. 求每部电影的平均打分")
print("=" * 50)
# SQL 风格
spark.sql("""
SELECT movie_id,
ROUND(AVG(rating), 3) AS avg_rating
FROM ratings
GROUP BY movie_id
ORDER BY avg_rating DESC
""").show(10)
# ========== 4. 查询大于平均打分的电影的数量 ==========
print("=" * 50)
print("4. 查询大于平均打分的电影的数量")
print("=" * 50)
# 先计算整体平均分
avg_rating = ratingsDF.select(F.avg("rating")).collect()[0][0]
print(f"整体平均分: {avg_rating:.3f}")
# 统计高于平均分的电影数量
high_rating_movies = ratingsDF \
.filter(ratingsDF["rating"] > avg_rating) \
.select("movie_id") \
.distinct() \
.count()
print(f"大于平均打分的电影数量: {high_rating_movies}")
# 另一种写法
movieCount = ratingsDF.where(ratingsDF["rating"] > avg_rating).count()
print(f"大于平均打分的评分记录数量: {movieCount}")
# ========== 5. 查询高分(大于3分)电影中打分次数最多的用户 ==========
print("=" * 50)
print("5. 查询高分(大于3分)电影中打分次数最多的用户,并给出此人打分的平均值")
print("=" * 50)
# 找出高分电影中打分次数最多的用户
user_id = ratingsDF.where(ratingsDF["rating"] > 3) \
.groupBy("user_id") \
.count() \
.withColumnRenamed("count", "high_rating_count") \
.orderBy("high_rating_count", ascending=False) \
.limit(1) \
.first()["user_id"]
print(f"高分电影中打分次数最多的用户ID: {user_id}")
# 计算该用户的平均打分
ratingsDF.filter(ratingsDF["user_id"] == user_id) \
.select(F.round(F.avg("rating"), 3).alias("avg_rating")) \
.show()
# ========== 6. 查询每个用户的平均打分、最低打分和最高打分 ==========
print("=" * 50)
print("6. 查询每个用户的平均打分、最低打分和最高打分")
print("=" * 50)
ratingsDF.groupBy("user_id") \
.agg(
F.round(F.avg("rating"), 3).alias("avg_rating"),
F.min("rating").alias("min_rating"),
F.max("rating").alias("max_rating")
) \
.orderBy("avg_rating", ascending=False) \
.show(10)
# ========== 7. 查询打分次数超过100次的电影的平均分排行榜TOP10 ==========
print("=" * 50)
print("7. 查询打分次数超过100次的电影的平均分排行榜TOP10")
print("=" * 50)
ratingsDF.groupBy("movie_id") \
.agg(
F.count("movie_id").alias("rating_count"),
F.round(F.avg("rating"), 3).alias("avg_rating")
) \
.where("rating_count > 100") \
.orderBy("avg_rating", ascending=False) \
.limit(10) \
.show()
# 使用 SQL 风格的写法
print("\n--- SQL 风格实现相同功能 ---")
spark.sql("""
SELECT movie_id,
COUNT(*) AS rating_count,
ROUND(AVG(rating), 3) AS avg_rating
FROM ratings
GROUP BY movie_id
HAVING rating_count > 100
ORDER BY avg_rating DESC
LIMIT 10
""").show()
spark.stop()
各题目输出说明
| 题号 | 题目 | 输出内容 |
|---|---|---|
| 1 | 显示数据集 | 前10行数据 |
| 2 | 每个用户的平均打分 | 用户ID + 平均分(降序) |
| 3 | 每部电影的平均打分 | 电影ID + 平均分(降序) |
| 4 | 大于平均分的电影数量 | 统计数字 |
| 5 | 高分电影中打分最多的用户 | 用户ID + 该用户平均分 |
| 6 | 每个用户的打分统计 | 用户ID + 平均分 + 最低分 + 最高分 |
| 7 | 热门电影的TOP10 | 电影ID + 评分次数 + 平均分 |
总结
本章核心知识点:
- SparkSession:Spark SQL 的统一入口
- DataFrame:带 Schema 的分布式数据集,比 RDD 更简洁高效
- 数据源:支持 Parquet、JSON、CSV、Text、JDBC 等
- 两种操作风格 :
- DSL 风格:
df.select().filter().groupBy() - SQL 风格:
df.createTempView()+spark.sql()
- DSL 风格:
- RDD 转 DataFrame:反射推断 或 编程定义 Schema
- JDBC 读写:MySQL 等关系型数据库
- PySpark + pandas:小数据用 pandas,大数据用 PySpark,两者可互转
- UDF:用户自定义函数,支持普通 UDF 和 pandas_udf