Spark SQL 详细讲义

第6章 Spark SQL 详细课堂讲义


第一部分:Spark SQL 概述(6.1)

1.1 从 Shark 说起

知识点讲解:

Shark 是 Hive on Spark 的实现,它复用了 Hive 的 HQL 解析、逻辑计划优化等模块,只把物理执行计划从 MapReduce 替换成了 Spark RDD 操作。虽然 Shark 的性能比 Hive 提高了 10-100 倍,但存在两个严重问题:

  1. 优化依赖 Hive:无法方便地添加新的优化策略
  2. 线程安全问题:Spark 是线程级并行,而 MapReduce 是进程级并行,导致 Spark 在兼容 Hive 时存在线程安全问题

2014年,Shark 停止开发,转向 Spark SQL 和 Hive on Spark 两个分支。

1.2 Spark SQL 架构

核心组件:

  • Catalyst 优化器:负责执行计划的生成和优化
  • DataFrame:带 Schema 信息的分布式数据集
  • 外部数据源支持:Hive、JSON、Parquet、JDBC 等

1.3 为什么推出 Spark SQL

三大原因:

  1. 用户需要从多种数据源(结构化、半结构化、非结构化)执行 ETL 操作
  2. 用户需要执行高级分析(机器学习、图处理)
  3. 实际大数据应用经常需要融合关系查询和复杂算法

Spark SQL 的解决方案:

  • 提供 DataFrame API,支持关系型操作
  • 引入 Catalyst 优化器,支持扩展
  • 融合 SQL 查询和复杂算法

1.4 Spark SQL 的特点

  1. 容易整合:SQL 查询和 Spark 程序无缝集成
  2. 统一数据访问:以相同方式访问 Hive、Avro、Parquet、ORC、JSON、JDBC 等
  3. 兼容 Hive:支持 HiveQL、SerDes、UDF
  4. 标准数据库连接:支持 JDBC/ODBC

1.5 SparkSession 入门

知识点讲解:

Spark 2.0 以上版本使用统一的 SparkSession 替代了 SQLContextHiveContextSparkSession 封装了 SparkContextSparkConfStreamingContext 等,是 Spark 应用程序的统一入口。

完整代码示例:

python 复制代码
# coding:utf8
from pyspark.sql import SparkSession

if __name__ == '__main__':
    # 创建 SparkSession 对象
    spark = SparkSession \
        .builder \
        .master("local[*]") \
        .appName("Simple Application") \
        .getOrCreate()
    
    # 读取文件
    logFile = "file:///usr/local/spark/README.md"
    logDF = spark.read.text(logFile)
    
    # 统计包含字母 a 和 b 的行数
    numAs = logDF.filter(logDF["value"].contains("a")).count()
    numBs = logDF.filter(logDF["value"].contains("b")).count()
    
    print('Lines with a: %s, Lines with b: %s' % (numAs, numBs))
    
    # 停止 SparkSession
    spark.stop()

说明:spark-shell 中,系统已经默认提供了一个 SparkSession 对象,名称为 spark,可以直接使用。


第二部分:DataFrame 详解(6.2)

2.1 什么是 DataFrame

知识点讲解:

DataFrame 是一种以 RDD 为基础的分布式数据集,提供了详细的结构信息(Schema)。与 RDD 相比:

特性 RDD DataFrame
数据结构 无结构,只是 Java 对象集合 有明确的列结构(Schema)
优化能力 无法优化 Spark 可以理解和优化
表达能力 需要告诉 Spark "如何做" 只需要告诉 Spark "做什么"
代码简洁度 复杂 简洁

2.2 DataFrame 的优点(重点案例)

案例:计算每种图书的平均销量

数据集:键值对 ("spark",2)、("hadoop",6)、("hadoop",4)、("spark",6)

RDD 实现方式(复杂):

python 复制代码
# coding:utf8
from pyspark import SparkConf, SparkContext

if __name__ == '__main__':
    conf = SparkConf().setAppName("Simple App")
    sc = SparkContext(conf=conf)
    
    bookRDD = sc.parallelize([("spark", 2), ("hadoop", 6), ("hadoop", 4), ("spark", 6)])
    
    saleRDD = bookRDD.map(lambda x: (x[0], (x[1], 1))) \
        .reduceByKey(lambda x, y: (x[0] + y[0], x[1] + y[1])) \
        .map(lambda x: (x[0], x[1][0] / x[1][1]))
    
    saleRDD.foreach(print)
    sc.stop()

DataFrame 实现方式(简洁):

python 复制代码
# coding:utf8
from pyspark.sql import SparkSession
from pyspark.sql.functions import avg

if __name__ == '__main__':
    spark = SparkSession \
        .builder \
        .master("local[*]") \
        .appName("Simple Application") \
        .getOrCreate()
    
    # 创建 DataFrame
    bookDF = spark \
        .createDataFrame([("spark", 2), ("hadoop", 6), ("hadoop", 4), ("spark", 6)]) \
        .toDF("book", "amount")
    
    # 分组计算平均值
    avgDF = bookDF.groupBy("book").agg(avg("amount"))
    avgDF.show()
    
    spark.stop()

执行结果:

复制代码
+------+-----------+
|  book| avg(amount)|
+------+-----------+
|spark|        4.0|
|hadoop|        5.0|
+------+-----------+

第三部分:DataFrame 的创建与保存(6.3)

3.1 Parquet 文件格式

知识点: Parquet 是 Spark 的默认数据源,是一种开源的列式存储文件格式,支持压缩(如 snappy),提供高效的 I/O 优化。

读取 Parquet 文件:

python 复制代码
# 方式一:使用 format() 指定格式
filePath = "file:///usr/local/spark/examples/src/main/resources/users.parquet"
df = spark.read.format("parquet").load(filePath)
df.show()

# 方式二:直接使用 parquet() 方法
df = spark.read.parquet(filePath)

保存 DataFrame 为 Parquet 文件:

python 复制代码
# 保存并指定压缩方式
df.write.format("parquet") \
    .mode("overwrite") \
    .option("compression", "snappy") \
    .save("file:///home/hadoop/otherusers")

# 简化写法
df.write.parquet("file:///home/hadoop/otherusers")

3.2 JSON 文件

知识点: JSON 支持单行模式和多行模式两种格式。

读取 JSON 文件:

python 复制代码
filePath = "file:///usr/local/spark/examples/src/main/resources/people.json"

# 方式一
df = spark.read.format("json").load(filePath)

# 方式二
df = spark.read.json(filePath)

df.show()

保存 DataFrame 为 JSON 文件:

python 复制代码
df.write.format("json").mode("overwrite").save("file:///home/hadoop/otherpeople")

# 简化写法
df.write.json("file:///home/hadoop/otherpeople")

3.3 CSV 文件

知识点: CSV 是最常见的文本数据格式,字段用逗号分隔,每行一条记录。

读取 CSV 文件:

python 复制代码
filePath = "file:///usr/local/spark/examples/src/main/resources/people.csv"

# 定义 Schema(表结构)
schema = "name STRING, age INT, job STRING"

# 方式一:使用 format()
df = spark.read.format("csv") \
    .schema(schema) \
    .option("header", "true") \
    .option("sep", ";") \
    .load(filePath)

# 方式二:直接使用 csv() 方法
df = spark.read \
    .schema(schema) \
    .option("header", "true") \
    .option("sep", ";") \
    .csv(filePath)

df.show()

保存 DataFrame 为 CSV 文件:

python 复制代码
df.write.format("csv").mode("overwrite").save("file:///home/hadoop/anotherpeople")

# 简化写法
df.write.csv("file:///home/hadoop/anotherpeople")

3.4 文本文件

读取文本文件:

python 复制代码
filePath = "file:///usr/local/spark/examples/src/main/resources/people.txt"

df = spark.read.format("text").load(filePath)
# 或者
df = spark.read.text(filePath)

df.show()

保存为文本文件:

python 复制代码
df.write.text("file:///home/hadoop/newpeople")
# 或者
df.write.format("text").save("file:///home/hadoop/newpeople")

注意: 保存为文本文件时,DataFrame 只能包含一列数据。

3.5 从序列集合创建 DataFrame

python 复制代码
# 基本创建方式
df = spark.createDataFrame([
    ("Xiaomei", "Female", "21"),
    ("Xiaoming", "Male", "22"),
    ("Xiaoxue", "Female", "23")
]).toDF("name", "sex", "age")

df.show()

# 指定数据类型的方式
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

schema = StructType([
    StructField("name", StringType(), True),
    StructField("sex", StringType(), True),
    StructField("age", IntegerType(), True)
])

data = [("Xiaomei", "Female", 21), ("Xiaoming", "Male", 22), ("Xiaoxue", "Female", 23)]
df = spark.createDataFrame(data, schema)

第四部分:DataFrame 的基本操作(6.4)

4.1 DSL 语法风格

准备数据:

python 复制代码
filePath = "file:///usr/local/spark/examples/src/main/resources/people.json"
df = spark.read.json(filePath)
1. printSchema() - 打印结构信息
python 复制代码
df.printSchema()

输出:

复制代码
root
 |-- age: long (nullable = true)
 |-- name: string (nullable = true)
2. show() - 显示数据
python 复制代码
df.show()

输出:

复制代码
+----+-------+
| age|   name|
+----+-------+
|null|Michael|
|  30|   Andy|
|  19| Justin|
+----+-------+
3. select() - 选择列
python 复制代码
# 选择单列
df.select("name").show()

# 选择多列
df.select("name", "age").show()

# 在 select 中进行计算
df.select(df["name"], df["age"] + 1).show()

# 重命名列
df.select(df["name"].alias("username"), df["age"]).show()
4. filter() / where() - 条件筛选
python 复制代码
# 使用 filter
df.filter(df["age"] > 20).show()

# 使用 where(多种写法)
df.where(df["age"] > 20).show()
df.where("age > 20").show()
5. groupBy() - 分组聚合
python 复制代码
# 分组并统计数量
df.groupBy("age").count().show()

# 使用聚合函数
from pyspark.sql.functions import avg, sum, max, min

df.groupBy("age").agg(avg("age")).show()
6. sort() / orderBy() - 排序
python 复制代码
# 降序排序
df.sort(df["age"].desc()).show()
df.sort(df["age"].desc(), df["name"].asc()).show()

# 使用 orderBy
df.orderBy("age", ascending=False).show()
df.orderBy(df["age"], ascending=False).show()
7. withColumn() - 添加新列
python 复制代码
from pyspark.sql.functions import expr

# 根据条件添加新列
df2 = df.withColumn(
    "IfWithAge",
    expr("CASE WHEN age IS NOT NULL THEN 'YES' ELSE 'NO' END")
)
df2.show()
8. withColumnRenamed() - 重命名列
python 复制代码
df3 = df.withColumnRenamed("name", "username")
df3.show()
9. drop() - 删除列
python 复制代码
df4 = df3.drop("IfWithAge")
df4.show()
10. 统计函数
python 复制代码
# 描述性统计
df.describe().show()

# 单独统计
from pyspark.sql.functions import min, max, avg, sum

df.select(min("age"), max("age"), avg("age"), sum("age")).show()

4.2 SQL 语法风格(重点)

知识点: 使用 SQL 语句操作 DataFrame 之前,必须先创建临时视图。

创建临时视图
python 复制代码
# 创建临时视图(如果已存在会报错)
df.createTempView("people")

# 创建或替换临时视图(推荐)
df.createOrReplaceTempView("people")

# 创建全局临时视图(跨会话可用)
df.createGlobalTempView("global_people")
基础 SQL 查询
python 复制代码
# 1. 简单查询
result = spark.sql("SELECT * FROM people")
result.show()

# 2. 条件查询
result = spark.sql("SELECT name, age FROM people WHERE age > 20")
result.show()

# 3. 分组聚合
result = spark.sql("""
    SELECT age, COUNT(*) as count 
    FROM people 
    GROUP BY age 
    ORDER BY age DESC
""")
result.show()
SQL 函数使用(完整案例)

案例:使用系统函数和用户自定义函数

python 复制代码
from pyspark.sql import SparkSession, Row
from pyspark.sql.functions import from_unixtime
from pyspark.sql.types import IntegerType, StringType, LongType, StructField, StructType

# 创建 SparkSession
spark = SparkSession.builder.master("local[*]").appName("SQL Functions").getOrCreate()

# 定义 Schema
schema = StructType([
    StructField("name", StringType(), True),
    StructField("age", IntegerType(), True),
    StructField("create_time", LongType(), True)
])

# 准备数据
data = [
    Row("Xiaomei", 21, 1580432800),
    Row("Xiaoming", 22, 1580436400),
    Row("Xiaoxue", 23, 1580438800)
]

df = spark.createDataFrame(data, schema)
df.show()

# 创建临时视图
df.createOrReplaceTempView("users")

# 1. 使用系统函数 from_unixtime 格式化时间戳
result1 = spark.sql("""
    SELECT name, 
           age, 
           from_unixtime(create_time) as create_time_str 
    FROM users
""")
result1.show()

# 2. 定义用户自定义函数(UDF)
def to_upper(name):
    return name.upper()

# 注册 UDF
spark.udf.register("upper_name", to_upper, StringType())

# 使用 UDF
result2 = spark.sql("""
    SELECT upper_name(name) as upper_name, age 
    FROM users
""")
result2.show()

# 3. 组合使用系统函数和 UDF
result3 = spark.sql("""
    SELECT upper_name(name) as upper_name,
           age,
           from_unixtime(create_time) as create_time_str
    FROM users
    WHERE age > 21
""")
result3.show()

spark.stop()
常用的 Spark SQL 系统函数分类
分类 函数示例
数学函数 round(), abs(), sqrt(), pow()
字符串函数 length(), substring(), concat(), upper(), lower()
日期函数 year(), month(), day(), date_add(), datediff()
聚合函数 sum(), avg(), count(), max(), min()
窗口函数 row_number(), rank(), lag(), lead()
条件函数 when(), if(), coalesce()
窗口函数使用示例
python 复制代码
# 创建示例数据
data = [
    ("Alice", "Sales", 5000),
    ("Bob", "Sales", 6000),
    ("Charlie", "Sales", 5500),
    ("David", "IT", 7000),
    ("Eve", "IT", 6500)
]
df = spark.createDataFrame(data, ["name", "dept", "salary"])
df.createOrReplaceTempView("employees")

# 使用窗口函数:每个部门内按工资排名
result = spark.sql("""
    SELECT name, dept, salary,
           ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC) as rank
    FROM employees
""")
result.show()

第五部分:从 RDD 转换得到 DataFrame(6.5)

5.1 利用反射机制推断 RDD 模式

知识点: 当数据结构和字段名称已知时,可以使用反射机制自动推断 Schema。

完整案例:

python 复制代码
from pyspark.sql import Row

# 读取文本文件
people = spark.sparkContext \
    .textFile("file:///usr/local/spark/examples/src/main/resources/people.txt") \
    .map(lambda line: line.split(",")) \
    .map(lambda p: Row(name=p[0], age=int(p[1])))

# 创建 DataFrame(自动推断 Schema)
schemaPeople = spark.createDataFrame(people)

# 注册为临时表
schemaPeople.createOrReplaceTempView("people")

# 执行 SQL 查询
personsDF = spark.sql("SELECT name, age FROM people WHERE age > 20")

# 将结果转换为 RDD 并处理
personsRDD = personsDF.rdd.map(lambda p: "Name: " + p.name + ", Age: " + str(p.age))
personsRDD.foreach(print)

# 输出:
# Name: Michael, Age: 29
# Name: Andy, Age: 30

5.2 使用编程方式定义 RDD 模式

知识点: 当数据结构无法提前获知时,需要通过编程方式动态定义 Schema。

完整案例:

python 复制代码
from pyspark.sql.types import *
from pyspark.sql import Row

# 第一步:定义 Schema(表头)
schemaString = "name age"
fields = [StructField(field_name, StringType(), True) for field_name in schemaString.split(" ")]
schema = StructType(fields)

# 第二步:读取数据并转换为 Row 对象(表中的记录)
lines = spark.sparkContext.textFile("file:///usr/local/spark/examples/src/main/resources/people.txt")
parts = lines.map(lambda x: x.split(","))
people = parts.map(lambda p: Row(p[0], p[1].strip()))

# 第三步:组合 Schema 和数据
schemaPeople = spark.createDataFrame(people, schema)

# 第四步:注册临时表并执行查询
schemaPeople.createOrReplaceTempView("people")
results = spark.sql("SELECT name, age FROM people")
results.show()

输出:

复制代码
+-------+---+
|   name|age|
+-------+---+
|Michael| 29|
|   Andy| 30|
| Justin| 19|
+-------+---+

第六部分:使用 Spark SQL 读写数据库(6.6)

6.1 准备工作

  1. 安装 MySQL 数据库(参考:http://dblab.xmu.edu.cn/blog/install-mysql/)
  2. 启动 MySQL 并创建数据库和表
sql 复制代码
-- 登录 MySQL
mysql -u root -p

-- 创建数据库
CREATE DATABASE spark;

-- 使用数据库
USE spark;

-- 创建 student 表
CREATE TABLE student (
    id INT(4), 
    name CHAR(20), 
    gender CHAR(4), 
    age INT(4)
);

-- 插入测试数据
INSERT INTO student VALUES(1, 'Xueqian', 'F', 23);
INSERT INTO student VALUES(2, 'Weiliang', 'M', 24);

-- 查看数据
SELECT * FROM student;
  1. 下载 MySQL JDBC 驱动

    • 下载 mysql-connector-java-5.1.40.tar.gz
    • 解压后将 .jar 文件拷贝到 /usr/local/spark/jars/ 目录
  2. 启动 PySpark

bash 复制代码
cd /usr/local/spark
./bin/pyspark

6.2 读取 MySQL 数据库中的数据

在 PySpark Shell 中执行:

python 复制代码
jdbcDF = spark.read \
    .format("jdbc") \
    .option("driver", "com.mysql.jdbc.Driver") \
    .option("url", "jdbc:mysql://localhost:3306/spark") \
    .option("dbtable", "student") \
    .option("user", "root") \
    .option("password", "123456") \
    .load()

jdbcDF.show()

在 PyCharm 独立应用程序中执行:

python 复制代码
from pyspark.sql import SparkSession

if __name__ == '__main__':
    spark = SparkSession \
        .builder \
        .appName('SparkReadMySQL') \
        .master('local[*]') \
        .getOrCreate()
    
    jdbcDF = spark.read \
        .format("jdbc") \
        .option("driver", "com.mysql.jdbc.Driver") \
        .option("url", "jdbc:mysql://localhost:3306/spark?useSSL=false") \
        .option("dbtable", "student") \
        .option("user", "root") \
        .option("password", "123456") \
        .load()
    
    jdbcDF.show()
    spark.stop()

6.3 向 MySQL 数据库写入数据

完整写入程序:

python 复制代码
#!/usr/bin/env python3
from pyspark.sql import Row
from pyspark.sql.types import *
from pyspark.sql import SparkSession

# 创建 SparkSession
spark = SparkSession.builder \
    .master("local[*]") \
    .appName("Write to MySQL") \
    .getOrCreate()

# 定义 Schema
schema = StructType([
    StructField("id", IntegerType(), True),
    StructField("name", StringType(), True),
    StructField("gender", StringType(), True),
    StructField("age", IntegerType(), True)
])

# 准备数据
studentRDD = spark.sparkContext \
    .parallelize(["3 Rongcheng M 26", "4 Guanhua M 27"]) \
    .map(lambda x: x.split(" "))

# 创建 Row 对象
rowRDD = studentRDD.map(lambda p: Row(
    int(p[0].strip()), 
    p[1].strip(), 
    p[2].strip(), 
    int(p[3].strip())
))

# 创建 DataFrame
studentDF = spark.createDataFrame(rowRDD, schema)

# 设置数据库连接参数
prop = {
    'user': 'root',
    'password': '123456',
    'driver': "com.mysql.jdbc.Driver"
}

# 写入数据库
studentDF.write.jdbc(
    "jdbc:mysql://localhost:3306/spark", 
    "student", 
    'append', 
    prop
)

# 验证:读取并显示
verifyDF = spark.read \
    .format("jdbc") \
    .option("driver", "com.mysql.jdbc.Driver") \
    .option("url", "jdbc:mysql://localhost:3306/spark") \
    .option("dbtable", "student") \
    .option("user", "root") \
    .option("password", "123456") \
    .load()

verifyDF.show()

spark.stop()

写入后 MySQL 中的数据:

复制代码
+------+-----------+--------+------+
| id   | name      | gender | age  |
+------+-----------+--------+------+
| 1    | Xueqian   | F      | 23   |
| 2    | Weiliang  | M      | 24   |
| 3    | Rongcheng | M      | 26   |
| 4    | Guanhua   | M      | 27   |
+------+-----------+--------+------+

第七部分:PySpark 与 pandas 整合(6.7)

7.1 安装 pandas

bash 复制代码
conda activate pyspark
pip install pandas
pip install pyarrow   # 用于高性能数据转换

7.2 pandas 数据结构简介

Series(一维数组)
python 复制代码
import pandas as pd

# 创建 Series(自动索引)
obj = pd.Series([3, 5, 6, 8, 9, 2])
print(obj)

# 创建 Series(指定索引)
obj2 = pd.Series([3, 5, 6, 8, 9, 2], index=['a', 'b', 'c', 'd', 'e', 'f'])
print(obj2)

# 通过索引访问
print(obj2['a'])           # 单个值
print(obj2[['b', 'd', 'f']])  # 多个值
DataFrame(二维表格)
python 复制代码
import pandas as pd

# 从字典创建 DataFrame
data = {
    'sno': ['95001', '95002', '95003', '95004'],
    'name': ['Xiaoming', 'Zhangsan', 'Lisi', 'Wangwu'],
    'sex': ['M', 'F', 'F', 'M'],
    'age': [22, 25, 24, 23]
}
frame = pd.DataFrame(data)
print(frame)

# 指定列顺序
frame = pd.DataFrame(data, columns=['name', 'sno', 'sex', 'age'])
print(frame)

# 指定行索引
frame = pd.DataFrame(data, 
                     columns=['sno', 'name', 'sex', 'age', 'grade'],
                     index=['a', 'b', 'c', 'd'])
print(frame)

# 访问列
print(frame['sno'])
print(frame.name)

# 访问行
print(frame.loc['b'])

7.3 实例1:两种 DataFrame 相互转换

python 复制代码
# coding:utf8
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType
import pandas as pd

if __name__ == '__main__':
    spark = SparkSession \
        .builder \
        .master("local[*]") \
        .appName("Pandas to Spark") \
        .getOrCreate()
    
    # 1. 创建 pandas DataFrame
    pd_df = pd.DataFrame({
        'id': [1, 2, 3, 4, 5],
        'name': ['zhangsan', 'lisi', 'wangwu', 'maliu', 'liuqi'],
        'age': [24, 26, 22, 28, 24]
    })
    
    print("=== pandas DataFrame ===")
    print(pd_df)
    
    # 2. 定义 PySpark DataFrame 的 Schema
    schema = StructType([
        StructField("id", IntegerType(), True),
        StructField("name", StringType(), True),
        StructField("age", IntegerType(), True)
    ])
    
    # 3. pandas DataFrame -> PySpark DataFrame
    spark_df = spark.createDataFrame(pd_df, schema)
    
    print("=== PySpark DataFrame ===")
    spark_df.show()
    
    # 4. PySpark DataFrame -> pandas DataFrame
    pd_df2 = spark_df.toPandas()
    
    print("=== 转换回 pandas DataFrame ===")
    print(pd_df2)
    
    spark.stop()

7.4 实例2:使用自定义聚合函数(pandas_udf)

数据文件 people_data.txt 内容:

复制代码
1,zhangsan,22
2,lisi,26
3,wangwu,28
4,maliu,24

完整代码:

python 复制代码
# coding:utf8
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType, StringType, StructField, StructType, Row
import pandas as pd

if __name__ == '__main__':
    spark = SparkSession \
        .builder \
        .master("local[*]") \
        .appName("Pandas UDF Example") \
        .getOrCreate()
    
    # 读取数据
    peopleRDD = spark.sparkContext \
        .textFile("file:///home/hadoop/people_data.txt") \
        .map(lambda line: line.split(",")) \
        .map(lambda p: Row(id=int(p[0]), name=p[1], age=int(p[2])))
    
    # 定义 Schema
    schema = StructType([
        StructField("id", IntegerType(), True),
        StructField("name", StringType(), True),
        StructField("age", IntegerType(), True)
    ])
    
    # 创建 DataFrame
    peopleDF = peopleRDD.toDF(schema)
    peopleDF.show()
    
    # 步骤1:使用装饰器定义 pandas UDF
    @F.pandas_udf(IntegerType())
    def my_sum(a: pd.Series) -> int:
        """计算年龄总和"""
        return a.sum()
    
    # 步骤2:使用 register 方法注册 UDF
    spark.udf.register("sum_func", my_sum)
    
    # 方式1:在 DSL 风格中使用
    print("=== DSL 风格使用 UDF ===")
    sumDF = peopleDF.select(my_sum("age").alias("total_age"))
    sumDF.show()
    
    # 方式2:在 SQL 风格中使用
    print("=== SQL 风格使用 UDF ===")
    peopleDF.createOrReplaceTempView("people")
    sumDF2 = spark.sql("SELECT sum_func(age) as total_age FROM people")
    sumDF2.show()
    
    spark.stop()

第八部分:综合实例 - 电影评分数据分析(6.8)

数据集说明

电影评分数据集 ratings.dat,每行包含4个字段,用 :: 分隔:

  • 用户ID(User ID)
  • 电影ID(Movie ID)
  • 评分(Rating,1-5分)
  • 时间戳(Timestamp)

数据示例:

复制代码
1::1193::5::978300760
1::661::3::978302109
1::914::3::978301968
2::1193::4::978300760
2::607::3::978302109

完整程序 MovieRating.py

python 复制代码
# coding:utf8
from pyspark.sql.types import IntegerType, StringType, StructField, StructType
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

if __name__ == '__main__':
    spark = SparkSession \
        .builder \
        .master("local[*]") \
        .appName("Movie Rating Analysis") \
        .getOrCreate()
    
    # ========== 1. 读取数据集 ==========
    print("=" * 50)
    print("1. 数据集如下:")
    print("=" * 50)
    
    filePath = "file:///home/hadoop/ratings.dat"
    
    # 定义 Schema
    schema = StructType([
        StructField("user_id", StringType(), True),
        StructField("movie_id", StringType(), True),
        StructField("rating", IntegerType(), True),
        StructField("ts", StringType(), True)
    ])
    
    # 读取 CSV 文件,分隔符为 "::"
    ratingsDF = spark.read.format("csv") \
        .option("sep", "::") \
        .option("header", False) \
        .option("encoding", "utf-8") \
        .schema(schema) \
        .load(filePath)
    
    ratingsDF.show(10)
    ratingsDF.createOrReplaceTempView("ratings")
    
    # ========== 2. 求每个用户的平均打分 ==========
    print("=" * 50)
    print("2. 求每个用户的平均打分")
    print("=" * 50)
    
    # DSL 风格
    ratingsDF.groupBy("user_id") \
        .avg("rating") \
        .withColumnRenamed("avg(rating)", "avg_rating") \
        .withColumn("avg_rating", F.round("avg_rating", 3)) \
        .orderBy("avg_rating", ascending=False) \
        .show(10)
    
    # ========== 3. 求每部电影的平均打分 ==========
    print("=" * 50)
    print("3. 求每部电影的平均打分")
    print("=" * 50)
    
    # SQL 风格
    spark.sql("""
        SELECT movie_id, 
               ROUND(AVG(rating), 3) AS avg_rating 
        FROM ratings 
        GROUP BY movie_id 
        ORDER BY avg_rating DESC
    """).show(10)
    
    # ========== 4. 查询大于平均打分的电影的数量 ==========
    print("=" * 50)
    print("4. 查询大于平均打分的电影的数量")
    print("=" * 50)
    
    # 先计算整体平均分
    avg_rating = ratingsDF.select(F.avg("rating")).collect()[0][0]
    print(f"整体平均分: {avg_rating:.3f}")
    
    # 统计高于平均分的电影数量
    high_rating_movies = ratingsDF \
        .filter(ratingsDF["rating"] > avg_rating) \
        .select("movie_id") \
        .distinct() \
        .count()
    
    print(f"大于平均打分的电影数量: {high_rating_movies}")
    
    # 另一种写法
    movieCount = ratingsDF.where(ratingsDF["rating"] > avg_rating).count()
    print(f"大于平均打分的评分记录数量: {movieCount}")
    
    # ========== 5. 查询高分(大于3分)电影中打分次数最多的用户 ==========
    print("=" * 50)
    print("5. 查询高分(大于3分)电影中打分次数最多的用户,并给出此人打分的平均值")
    print("=" * 50)
    
    # 找出高分电影中打分次数最多的用户
    user_id = ratingsDF.where(ratingsDF["rating"] > 3) \
        .groupBy("user_id") \
        .count() \
        .withColumnRenamed("count", "high_rating_count") \
        .orderBy("high_rating_count", ascending=False) \
        .limit(1) \
        .first()["user_id"]
    
    print(f"高分电影中打分次数最多的用户ID: {user_id}")
    
    # 计算该用户的平均打分
    ratingsDF.filter(ratingsDF["user_id"] == user_id) \
        .select(F.round(F.avg("rating"), 3).alias("avg_rating")) \
        .show()
    
    # ========== 6. 查询每个用户的平均打分、最低打分和最高打分 ==========
    print("=" * 50)
    print("6. 查询每个用户的平均打分、最低打分和最高打分")
    print("=" * 50)
    
    ratingsDF.groupBy("user_id") \
        .agg(
            F.round(F.avg("rating"), 3).alias("avg_rating"),
            F.min("rating").alias("min_rating"),
            F.max("rating").alias("max_rating")
        ) \
        .orderBy("avg_rating", ascending=False) \
        .show(10)
    
    # ========== 7. 查询打分次数超过100次的电影的平均分排行榜TOP10 ==========
    print("=" * 50)
    print("7. 查询打分次数超过100次的电影的平均分排行榜TOP10")
    print("=" * 50)
    
    ratingsDF.groupBy("movie_id") \
        .agg(
            F.count("movie_id").alias("rating_count"),
            F.round(F.avg("rating"), 3).alias("avg_rating")
        ) \
        .where("rating_count > 100") \
        .orderBy("avg_rating", ascending=False) \
        .limit(10) \
        .show()
    
    # 使用 SQL 风格的写法
    print("\n--- SQL 风格实现相同功能 ---")
    spark.sql("""
        SELECT movie_id, 
               COUNT(*) AS rating_count, 
               ROUND(AVG(rating), 3) AS avg_rating 
        FROM ratings 
        GROUP BY movie_id 
        HAVING rating_count > 100 
        ORDER BY avg_rating DESC 
        LIMIT 10
    """).show()
    
    spark.stop()

各题目输出说明

题号 题目 输出内容
1 显示数据集 前10行数据
2 每个用户的平均打分 用户ID + 平均分(降序)
3 每部电影的平均打分 电影ID + 平均分(降序)
4 大于平均分的电影数量 统计数字
5 高分电影中打分最多的用户 用户ID + 该用户平均分
6 每个用户的打分统计 用户ID + 平均分 + 最低分 + 最高分
7 热门电影的TOP10 电影ID + 评分次数 + 平均分

总结

本章核心知识点:

  1. SparkSession:Spark SQL 的统一入口
  2. DataFrame:带 Schema 的分布式数据集,比 RDD 更简洁高效
  3. 数据源:支持 Parquet、JSON、CSV、Text、JDBC 等
  4. 两种操作风格
    • DSL 风格:df.select().filter().groupBy()
    • SQL 风格:df.createTempView() + spark.sql()
  5. RDD 转 DataFrame:反射推断 或 编程定义 Schema
  6. JDBC 读写:MySQL 等关系型数据库
  7. PySpark + pandas:小数据用 pandas,大数据用 PySpark,两者可互转
  8. UDF:用户自定义函数,支持普通 UDF 和 pandas_udf

相关推荐
智慧景区与市集主理人6 小时前
五一乡村文旅增收难?巨有科技大数据双赋能破局突围
大数据·科技
TechubNews9 小时前
新火集团首席经济学家付鹏演讲——2026 年是 Crypto 加入到 FICC 资产配置框架元年
大数据·人工智能
河阿里9 小时前
SQL数据库:五大范式(NF)
数据库·sql·oracle
Elastic 中国社区官方博客10 小时前
为 Elastic Cloud Serverless 和 Elasticsearch 引入统一的 API 密钥
大数据·运维·elasticsearch·搜索引擎·云原生·serverless
CS创新实验室13 小时前
CS实验室行业报告:机器人领域就业分析报告
大数据·人工智能·机器人
Irene199113 小时前
SQL 中日期的特殊性总结(格式符严格要求全大写)
sql
花椒技术14 小时前
从区间锁到行锁:一次高并发写入死锁治理实战
后端·sql
LinuxGeek102416 小时前
Kylin-Server-V11、openEuler-22.03和openEuler-24.03的MySQL 9.7.0版本正式发布
大数据·mysql·kylin
容智信息16 小时前
国家级算力底座+企业级智能体:容智Agent OS 获选入驻移动云能中心,联手赋能千行百业
大数据·人工智能·自然语言处理·智慧城市