Spark 4.0 新特性Python Data Source API 快速上手

1. 什么是 Python Data Source API

Python Data Source API 是 Spark 4.0 引入的新能力,它允许开发者在 Python 中直接实现自定义数据源和数据写出逻辑。换句话说,你可以像实现一个插件一样,为 Spark 增加新的读取来源和写出目标,而不必先写 Scala/Java 版本。

它支持的能力包括:

  • 批处理读取
  • 批处理写入
  • 流式读取
  • 流式写入

官方给出的能力映射关系很清晰:

  • 批处理读:实现 reader()
  • 批处理写:实现 writer()
  • 流式读:实现 streamReader()simpleStreamReader()
  • 流式写:实现 streamWriter()

2. 一个最简单的自定义数据源

先看一个最小示例:定义一个只返回两行数据的数据源。

python 复制代码
from typing import Iterator, Tuple

from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

class SimpleDataSource(DataSource):
    @classmethod
    def name(cls) -> str:
        return "simple"

    def schema(self) -> StructType:
        return StructType([
            StructField("name", StringType()),
            StructField("age", IntegerType())
        ])

    def reader(self, schema: StructType) -> DataSourceReader:
        return SimpleDataSourceReader()

class SimpleDataSourceReader(DataSourceReader):
    def read(self, partition: InputPartition) -> Iterator[Tuple]:
        yield ("Alice", 20)
        yield ("Bob", 30)

注册后即可直接读取:

python 复制代码
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.dataSource.register(SimpleDataSource)

spark.read.format("simple").load().show()

输出结果:

text 复制代码
+-----+---+
| name|age|
+-----+---+
|Alice| 20|
|  Bob| 30|
+-----+---+

这个例子说明了 Python Data Source API 的核心思路:定义名字、定义 schema、实现 reader,然后注册给 Spark 使用。

3. DataSource 的基本结构

一个完整的 Python Data Source,通常要继承 DataSource,然后根据你的能力实现对应方法。官方示例中给出了一份更完整的定义:

python 复制代码
from typing import Union

from pyspark.sql.datasource import (
    DataSource,
    DataSourceReader,
    DataSourceStreamReader,
    DataSourceStreamWriter,
    DataSourceWriter
)
from pyspark.sql.types import StructType

class FakeDataSource(DataSource):
    @classmethod
    def name(cls) -> str:
        return "fake"

    def schema(self) -> Union[StructType, str]:
        return "name string, date string, zipcode string, state string"

    def reader(self, schema: StructType) -> DataSourceReader:
        return FakeDataSourceReader(schema, self.options)

    def writer(self, schema: StructType, overwrite: bool) -> DataSourceWriter:
        return FakeDataSourceWriter(self.options)

    def streamReader(self, schema: StructType) -> DataSourceStreamReader:
        return FakeStreamReader(schema, self.options)

    def streamWriter(self, schema: StructType, overwrite: bool) -> DataSourceStreamWriter:
        return FakeStreamWriter(self.options)

这段代码已经覆盖了:

  • 批量读
  • 批量写
  • 流式读
  • 流式写

如果你只需要其中一种能力,只实现对应方法即可。

4. 批处理 Reader 怎么写

官方示例里的 FakeDataSourceReader 使用 faker 库按 schema 生成假数据。

python 复制代码
from typing import Dict
from pyspark.sql.datasource import DataSourceReader
from pyspark.sql.types import StructType

class FakeDataSourceReader(DataSourceReader):
    def __init__(self, schema: StructType, options: Dict[str, str]):
        self.schema = schema
        self.options = options

    def read(self, partition):
        from faker import Faker
        fake = Faker()

        num_rows = int(self.options.get("numRows", 3))
        for _ in range(num_rows):
            row = []
            for field in self.schema.fields:
                value = getattr(fake, field.name)()
                row.append(value)
            yield tuple(row)

这里有两个关键点:

  1. options 里的值全部都是字符串,所以要自己做类型转换
  2. 依赖库 faker 是在方法内部导入的,这和后面要说的序列化要求有关

5. 批处理 Writer 怎么写

批处理写出要实现 DataSourceWriter。官方示例里,Writer 会统计每个分区写出的记录数,并在 commit() 中输出总行数。

python 复制代码
from dataclasses import dataclass
from typing import Iterator, List

from pyspark.sql import Row
from pyspark.sql.datasource import DataSourceWriter, WriterCommitMessage

@dataclass
class SimpleCommitMessage(WriterCommitMessage):
    partition_id: int
    count: int

class FakeDataSourceWriter(DataSourceWriter):
    def write(self, rows: Iterator[Row]) -> SimpleCommitMessage:
        from pyspark import TaskContext

        context = TaskContext.get()
        partition_id = context.partitionId()
        cnt = sum(1 for _ in rows)
        return SimpleCommitMessage(partition_id=partition_id, count=cnt)

    def commit(self, messages: List[SimpleCommitMessage]) -> None:
        total_count = sum(message.count for message in messages)
        print(f"Total number of rows: {total_count}")

    def abort(self, messages: List[SimpleCommitMessage]) -> None:
        failed_count = sum(message is None for message in messages)
        print(f"Number of failed tasks: {failed_count}")

这类写法很适合:

  • 统计落盘条数
  • 汇总分区写出结果
  • 自定义成功/失败处理逻辑

6. 流式 Reader 怎么写

6.1 标准流式 Reader

如果你的流式数据源需要自己管理 offset、分区规划、提交状态,可以实现 DataSourceStreamReader。官方示例中每个 micro-batch 生成 2 行数据。

python 复制代码
from typing import Iterator, Tuple
from pyspark.sql.datasource import DataSourceStreamReader, InputPartition

class RangePartition(InputPartition):
    def __init__(self, start: int, end: int):
        self.start = start
        self.end = end

class FakeStreamReader(DataSourceStreamReader):
    def __init__(self, schema, options):
        self.current = 0

    def initialOffset(self) -> dict:
        return {"offset": 0}

    def latestOffset(self) -> dict:
        self.current += 2
        return {"offset": self.current}

    def partitions(self, start: dict, end: dict) -> list[InputPartition]:
        return [RangePartition(start["offset"], end["offset"])]

    def commit(self, end: dict) -> None:
        pass

    def read(self, partition) -> Iterator[Tuple]:
        for i in range(partition.start, partition.end):
            yield (i, str(i))

6.2 简化版流式 Reader

如果数据源吞吐量不高,也不需要复杂分区,可以实现 SimpleDataSourceStreamReader。官方说明里明确说:streamReader()simpleStreamReader() 二选一即可,前者优先。

python 复制代码
from typing import Iterator, Tuple
from pyspark.sql.datasource import SimpleDataSourceStreamReader

class FakeSimpleStreamReader(SimpleDataSourceStreamReader):
    def initialOffset(self) -> dict:
        return {"offset": 0}

    def read(self, start: dict) -> Tuple[Iterator[Tuple], dict]:
        start_idx = start["offset"]
        it = iter([(i,) for i in range(start_idx, start_idx + 2)])
        return (it, {"offset": start_idx + 2})

    def readBetweenOffsets(self, start: dict, end: dict) -> Iterator[Tuple]:
        return iter([(i,) for i in range(start["offset"], end["offset"])])

    def commit(self, end: dict) -> None:
        pass

7. 流式 Writer 怎么写

流式写出对应的是 DataSourceStreamWriter。官方示例里,它把每个 micro-batch 的元信息写到本地路径中。

python 复制代码
from typing import Iterator, List, Optional
from pyspark.sql import Row
from pyspark.sql.datasource import DataSourceStreamWriter, WriterCommitMessage

class SimpleCommitMessage(WriterCommitMessage):
    partition_id: int
    count: int

class FakeStreamWriter(DataSourceStreamWriter):
    def __init__(self, options):
        self.options = options
        self.path = self.options.get("path")
        assert self.path is not None

    def write(self, iterator: Iterator[Row]) -> WriterCommitMessage:
        from pyspark import TaskContext
        context = TaskContext.get()
        partition_id = context.partitionId()
        cnt = 0
        for _ in iterator:
            cnt += 1
        return SimpleCommitMessage(partition_id=partition_id, count=cnt)

    def commit(self, messages: List[Optional[SimpleCommitMessage]], batchId: int) -> None:
        status = dict(num_partitions=len(messages), rows=sum(m.count for m in messages))
        with open(os.path.join(self.path, f"{batchId}.json"), "a") as file:
            file.write(json.dumps(status) + "\n")

    def abort(self, messages: List[Optional[SimpleCommitMessage]], batchId: int) -> None:
        with open(os.path.join(self.path, f"{batchId}.txt"), "w") as file:
            file.write(f"failed in batch {batchId}")

这个模式的核心是:

  • write() 负责分区级写出
  • commit() 负责 batch 成功后的汇总处理
  • abort() 负责 batch 失败后的兜底逻辑

8. 序列化要求一定要注意

官方专门强调:DataSourceDataSourceReaderDataSourceWriterDataSourceStreamReaderDataSourceStreamWriter 以及它们的方法,都必须能够被 pickle 序列化。

因此有一个很重要的实践:

方法内部用到的库,尽量在方法内部导入。

比如官方示例中,TaskContext 不是在文件开头导入,而是在 read()write() 方法内部导入:

python 复制代码
def read(self, partition):
    from pyspark import TaskContext
    context = TaskContext.get()

这是为了避免序列化问题。

9. 如何注册和使用 Python Data Source

9.1 注册

python 复制代码
spark.dataSource.register(FakeDataSource)

9.2 读取

使用默认 schema:

python 复制代码
spark.read.format("fake").load().show()

使用自定义 schema:

python 复制代码
spark.read.format("fake").schema("name string, company string").load().show()

传入 option:

python 复制代码
spark.read.format("fake").option("numRows", 5).load().show()

这些都是官方给出的直接用法。

9.3 写出

写出时要指定 mode(),支持 appendoverwrite

python 复制代码
df = spark.range(0, 10, 1, 5)
df.write.format("fake").mode("append").save()

10. 在 Structured Streaming 里使用

注册之后,同一个 Python Data Source 也可以用于 readStream()writeStream()

10.1 作为流式 source

python 复制代码
query = spark.readStream.format("fake").load() \
    .writeStream.format("console").start()

10.2 同时作为流式 source 和 sink

python 复制代码
query = spark.readStream.format("fake").load() \
    .writeStream.format("fake").start("/output_path")

这意味着你可以完全用 Python 打通一个自定义流式读写链路。

11. Arrow Batch 支持:性能提升的重点

官方文档最后一部分特别强调了一点:Python Data Source Reader 支持直接产出 Arrow Batch,这能显著提升性能,甚至在大数据场景下带来一个数量级的提升。

实现方式很简单:在 read() 方法里直接 yield pyarrow.RecordBatch

python 复制代码
from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
from pyspark.sql import SparkSession
import pyarrow as pa

class ArrowBatchDataSource(DataSource):
    @classmethod
    def name(cls):
        return "arrowbatch"

    def schema(self):
        return "key int, value string"

    def reader(self, schema: str):
        return ArrowBatchDataSourceReader(schema, self.options)

class ArrowBatchDataSourceReader(DataSourceReader):
    def __init__(self, schema, options):
        self.schema = schema
        self.options = options

    def read(self, partition):
        keys = pa.array([1, 2, 3, 4, 5], type=pa.int32())
        values = pa.array(["one", "two", "three", "four", "five"], type=pa.string())
        schema = pa.schema([("key", pa.int32()), ("value", pa.string())])
        record_batch = pa.RecordBatch.from_arrays([keys, values], schema=schema)
        yield record_batch

    def partitions(self):
        return [InputPartition(i) for i in range(1)]

spark = SparkSession.builder.appName("ArrowBatchExample").getOrCreate()
spark.dataSource.register(ArrowBatchDataSource)

df = spark.read.format("arrowbatch").load()
df.show()

这部分非常值得关注,因为它说明 Python Data Source API 不只是"能用",而且在配合 Arrow 后可以更快。

12. 使用时的几个注意点

官方文档还给了几条很实用的说明:

12.1 名称冲突时,内置和 Scala/Java Data Source 优先

如果 Python Data Source 和内置数据源或 Scala/Java 数据源同名,默认优先解析后者。因此定义名字时,尽量避免冲突。

12.2 可以重复注册,后注册会覆盖前注册

这对调试很方便,但也意味着线上环境要注意命名和注册顺序。

12.3 支持自动注册

官方提到,如果你在顶层模块里把数据源导出为 DefaultSource,并且模块名前缀是 pyspark_,就可以自动注册。

13. 总结

Python Data Source API 的价值很直接:它让 Spark 4.0 开始真正具备了"用 Python 写自定义数据源"的能力。以前这类扩展大多要落到 JVM 侧,现在很多批处理和流处理场景都可以直接用 Python 完成。它最重要的几个亮点是:

  • 批处理读写可扩展
  • 流处理读写可扩展
  • spark.read / write / readStream / writeStream 体系自然融合
  • 支持 Arrow Batch 提升性能

如果你在做自定义 Connector、测试数据源、内部平台适配层,或者想快速接入非标准数据来源,这个 API 会非常值得深入。

相关推荐
王小义笔记3 小时前
大模型微调步骤与精髓总结
python·大模型·llm
源码之家4 小时前
计算机毕业设计:Python汽车销量数据采集分析可视化系统 Flask框架 requests爬虫 可视化 车辆 大数据 机器学习 hadoop(建议收藏)✅
大数据·爬虫·python·django·flask·课程设计·美食
Roselind_Yi4 小时前
【吴恩达2026 Agentic AI】面试向+项目实战(含面试题+项目案例)-2
人工智能·python·机器学习·面试·职场和发展·langchain·agent
2401_827499994 小时前
python核心语法01-数据存储与运算
java·数据结构·python
一直会游泳的小猫4 小时前
ClaudeCode完整学习指南
python·ai编程·claude code·claude code指南
第一程序员4 小时前
Python与容器化:Docker和Kubernetes实战
python·github
JaydenAI4 小时前
[RAG在LangChain中的实现-04]常用的向量存储和基于向量存储的检索器
python·langchain·ai编程
Roselind_Yi4 小时前
【吴恩达2026 Agentic AI】面试向+项目实战(含面试题+项目案例)-1
人工智能·python·面试·职场和发展·langchain·gpt-3·agent
Alan GEO实施教练4 小时前
专利申请是否找代理机构:核心考量与决策逻辑拆解
大数据·人工智能·python