Spark 4.0 新特性Python Data Source API 快速上手

1. 什么是 Python Data Source API

Python Data Source API 是 Spark 4.0 引入的新能力,它允许开发者在 Python 中直接实现自定义数据源和数据写出逻辑。换句话说,你可以像实现一个插件一样,为 Spark 增加新的读取来源和写出目标,而不必先写 Scala/Java 版本。

它支持的能力包括:

  • 批处理读取
  • 批处理写入
  • 流式读取
  • 流式写入

官方给出的能力映射关系很清晰:

  • 批处理读:实现 reader()
  • 批处理写:实现 writer()
  • 流式读:实现 streamReader()simpleStreamReader()
  • 流式写:实现 streamWriter()

2. 一个最简单的自定义数据源

先看一个最小示例:定义一个只返回两行数据的数据源。

python 复制代码
from typing import Iterator, Tuple

from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

class SimpleDataSource(DataSource):
    @classmethod
    def name(cls) -> str:
        return "simple"

    def schema(self) -> StructType:
        return StructType([
            StructField("name", StringType()),
            StructField("age", IntegerType())
        ])

    def reader(self, schema: StructType) -> DataSourceReader:
        return SimpleDataSourceReader()

class SimpleDataSourceReader(DataSourceReader):
    def read(self, partition: InputPartition) -> Iterator[Tuple]:
        yield ("Alice", 20)
        yield ("Bob", 30)

注册后即可直接读取:

python 复制代码
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.dataSource.register(SimpleDataSource)

spark.read.format("simple").load().show()

输出结果:

text 复制代码
+-----+---+
| name|age|
+-----+---+
|Alice| 20|
|  Bob| 30|
+-----+---+

这个例子说明了 Python Data Source API 的核心思路:定义名字、定义 schema、实现 reader,然后注册给 Spark 使用。

3. DataSource 的基本结构

一个完整的 Python Data Source,通常要继承 DataSource,然后根据你的能力实现对应方法。官方示例中给出了一份更完整的定义:

python 复制代码
from typing import Union

from pyspark.sql.datasource import (
    DataSource,
    DataSourceReader,
    DataSourceStreamReader,
    DataSourceStreamWriter,
    DataSourceWriter
)
from pyspark.sql.types import StructType

class FakeDataSource(DataSource):
    @classmethod
    def name(cls) -> str:
        return "fake"

    def schema(self) -> Union[StructType, str]:
        return "name string, date string, zipcode string, state string"

    def reader(self, schema: StructType) -> DataSourceReader:
        return FakeDataSourceReader(schema, self.options)

    def writer(self, schema: StructType, overwrite: bool) -> DataSourceWriter:
        return FakeDataSourceWriter(self.options)

    def streamReader(self, schema: StructType) -> DataSourceStreamReader:
        return FakeStreamReader(schema, self.options)

    def streamWriter(self, schema: StructType, overwrite: bool) -> DataSourceStreamWriter:
        return FakeStreamWriter(self.options)

这段代码已经覆盖了:

  • 批量读
  • 批量写
  • 流式读
  • 流式写

如果你只需要其中一种能力,只实现对应方法即可。

4. 批处理 Reader 怎么写

官方示例里的 FakeDataSourceReader 使用 faker 库按 schema 生成假数据。

python 复制代码
from typing import Dict
from pyspark.sql.datasource import DataSourceReader
from pyspark.sql.types import StructType

class FakeDataSourceReader(DataSourceReader):
    def __init__(self, schema: StructType, options: Dict[str, str]):
        self.schema = schema
        self.options = options

    def read(self, partition):
        from faker import Faker
        fake = Faker()

        num_rows = int(self.options.get("numRows", 3))
        for _ in range(num_rows):
            row = []
            for field in self.schema.fields:
                value = getattr(fake, field.name)()
                row.append(value)
            yield tuple(row)

这里有两个关键点:

  1. options 里的值全部都是字符串,所以要自己做类型转换
  2. 依赖库 faker 是在方法内部导入的,这和后面要说的序列化要求有关

5. 批处理 Writer 怎么写

批处理写出要实现 DataSourceWriter。官方示例里,Writer 会统计每个分区写出的记录数,并在 commit() 中输出总行数。

python 复制代码
from dataclasses import dataclass
from typing import Iterator, List

from pyspark.sql import Row
from pyspark.sql.datasource import DataSourceWriter, WriterCommitMessage

@dataclass
class SimpleCommitMessage(WriterCommitMessage):
    partition_id: int
    count: int

class FakeDataSourceWriter(DataSourceWriter):
    def write(self, rows: Iterator[Row]) -> SimpleCommitMessage:
        from pyspark import TaskContext

        context = TaskContext.get()
        partition_id = context.partitionId()
        cnt = sum(1 for _ in rows)
        return SimpleCommitMessage(partition_id=partition_id, count=cnt)

    def commit(self, messages: List[SimpleCommitMessage]) -> None:
        total_count = sum(message.count for message in messages)
        print(f"Total number of rows: {total_count}")

    def abort(self, messages: List[SimpleCommitMessage]) -> None:
        failed_count = sum(message is None for message in messages)
        print(f"Number of failed tasks: {failed_count}")

这类写法很适合:

  • 统计落盘条数
  • 汇总分区写出结果
  • 自定义成功/失败处理逻辑

6. 流式 Reader 怎么写

6.1 标准流式 Reader

如果你的流式数据源需要自己管理 offset、分区规划、提交状态,可以实现 DataSourceStreamReader。官方示例中每个 micro-batch 生成 2 行数据。

python 复制代码
from typing import Iterator, Tuple
from pyspark.sql.datasource import DataSourceStreamReader, InputPartition

class RangePartition(InputPartition):
    def __init__(self, start: int, end: int):
        self.start = start
        self.end = end

class FakeStreamReader(DataSourceStreamReader):
    def __init__(self, schema, options):
        self.current = 0

    def initialOffset(self) -> dict:
        return {"offset": 0}

    def latestOffset(self) -> dict:
        self.current += 2
        return {"offset": self.current}

    def partitions(self, start: dict, end: dict) -> list[InputPartition]:
        return [RangePartition(start["offset"], end["offset"])]

    def commit(self, end: dict) -> None:
        pass

    def read(self, partition) -> Iterator[Tuple]:
        for i in range(partition.start, partition.end):
            yield (i, str(i))

6.2 简化版流式 Reader

如果数据源吞吐量不高,也不需要复杂分区,可以实现 SimpleDataSourceStreamReader。官方说明里明确说:streamReader()simpleStreamReader() 二选一即可,前者优先。

python 复制代码
from typing import Iterator, Tuple
from pyspark.sql.datasource import SimpleDataSourceStreamReader

class FakeSimpleStreamReader(SimpleDataSourceStreamReader):
    def initialOffset(self) -> dict:
        return {"offset": 0}

    def read(self, start: dict) -> Tuple[Iterator[Tuple], dict]:
        start_idx = start["offset"]
        it = iter([(i,) for i in range(start_idx, start_idx + 2)])
        return (it, {"offset": start_idx + 2})

    def readBetweenOffsets(self, start: dict, end: dict) -> Iterator[Tuple]:
        return iter([(i,) for i in range(start["offset"], end["offset"])])

    def commit(self, end: dict) -> None:
        pass

7. 流式 Writer 怎么写

流式写出对应的是 DataSourceStreamWriter。官方示例里,它把每个 micro-batch 的元信息写到本地路径中。

python 复制代码
from typing import Iterator, List, Optional
from pyspark.sql import Row
from pyspark.sql.datasource import DataSourceStreamWriter, WriterCommitMessage

class SimpleCommitMessage(WriterCommitMessage):
    partition_id: int
    count: int

class FakeStreamWriter(DataSourceStreamWriter):
    def __init__(self, options):
        self.options = options
        self.path = self.options.get("path")
        assert self.path is not None

    def write(self, iterator: Iterator[Row]) -> WriterCommitMessage:
        from pyspark import TaskContext
        context = TaskContext.get()
        partition_id = context.partitionId()
        cnt = 0
        for _ in iterator:
            cnt += 1
        return SimpleCommitMessage(partition_id=partition_id, count=cnt)

    def commit(self, messages: List[Optional[SimpleCommitMessage]], batchId: int) -> None:
        status = dict(num_partitions=len(messages), rows=sum(m.count for m in messages))
        with open(os.path.join(self.path, f"{batchId}.json"), "a") as file:
            file.write(json.dumps(status) + "\n")

    def abort(self, messages: List[Optional[SimpleCommitMessage]], batchId: int) -> None:
        with open(os.path.join(self.path, f"{batchId}.txt"), "w") as file:
            file.write(f"failed in batch {batchId}")

这个模式的核心是:

  • write() 负责分区级写出
  • commit() 负责 batch 成功后的汇总处理
  • abort() 负责 batch 失败后的兜底逻辑

8. 序列化要求一定要注意

官方专门强调:DataSourceDataSourceReaderDataSourceWriterDataSourceStreamReaderDataSourceStreamWriter 以及它们的方法,都必须能够被 pickle 序列化。

因此有一个很重要的实践:

方法内部用到的库,尽量在方法内部导入。

比如官方示例中,TaskContext 不是在文件开头导入,而是在 read()write() 方法内部导入:

python 复制代码
def read(self, partition):
    from pyspark import TaskContext
    context = TaskContext.get()

这是为了避免序列化问题。

9. 如何注册和使用 Python Data Source

9.1 注册

python 复制代码
spark.dataSource.register(FakeDataSource)

9.2 读取

使用默认 schema:

python 复制代码
spark.read.format("fake").load().show()

使用自定义 schema:

python 复制代码
spark.read.format("fake").schema("name string, company string").load().show()

传入 option:

python 复制代码
spark.read.format("fake").option("numRows", 5).load().show()

这些都是官方给出的直接用法。

9.3 写出

写出时要指定 mode(),支持 appendoverwrite

python 复制代码
df = spark.range(0, 10, 1, 5)
df.write.format("fake").mode("append").save()

10. 在 Structured Streaming 里使用

注册之后,同一个 Python Data Source 也可以用于 readStream()writeStream()

10.1 作为流式 source

python 复制代码
query = spark.readStream.format("fake").load() \
    .writeStream.format("console").start()

10.2 同时作为流式 source 和 sink

python 复制代码
query = spark.readStream.format("fake").load() \
    .writeStream.format("fake").start("/output_path")

这意味着你可以完全用 Python 打通一个自定义流式读写链路。

11. Arrow Batch 支持:性能提升的重点

官方文档最后一部分特别强调了一点:Python Data Source Reader 支持直接产出 Arrow Batch,这能显著提升性能,甚至在大数据场景下带来一个数量级的提升。

实现方式很简单:在 read() 方法里直接 yield pyarrow.RecordBatch

python 复制代码
from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
from pyspark.sql import SparkSession
import pyarrow as pa

class ArrowBatchDataSource(DataSource):
    @classmethod
    def name(cls):
        return "arrowbatch"

    def schema(self):
        return "key int, value string"

    def reader(self, schema: str):
        return ArrowBatchDataSourceReader(schema, self.options)

class ArrowBatchDataSourceReader(DataSourceReader):
    def __init__(self, schema, options):
        self.schema = schema
        self.options = options

    def read(self, partition):
        keys = pa.array([1, 2, 3, 4, 5], type=pa.int32())
        values = pa.array(["one", "two", "three", "four", "five"], type=pa.string())
        schema = pa.schema([("key", pa.int32()), ("value", pa.string())])
        record_batch = pa.RecordBatch.from_arrays([keys, values], schema=schema)
        yield record_batch

    def partitions(self):
        return [InputPartition(i) for i in range(1)]

spark = SparkSession.builder.appName("ArrowBatchExample").getOrCreate()
spark.dataSource.register(ArrowBatchDataSource)

df = spark.read.format("arrowbatch").load()
df.show()

这部分非常值得关注,因为它说明 Python Data Source API 不只是"能用",而且在配合 Arrow 后可以更快。

12. 使用时的几个注意点

官方文档还给了几条很实用的说明:

12.1 名称冲突时,内置和 Scala/Java Data Source 优先

如果 Python Data Source 和内置数据源或 Scala/Java 数据源同名,默认优先解析后者。因此定义名字时,尽量避免冲突。

12.2 可以重复注册,后注册会覆盖前注册

这对调试很方便,但也意味着线上环境要注意命名和注册顺序。

12.3 支持自动注册

官方提到,如果你在顶层模块里把数据源导出为 DefaultSource,并且模块名前缀是 pyspark_,就可以自动注册。

13. 总结

Python Data Source API 的价值很直接:它让 Spark 4.0 开始真正具备了"用 Python 写自定义数据源"的能力。以前这类扩展大多要落到 JVM 侧,现在很多批处理和流处理场景都可以直接用 Python 完成。它最重要的几个亮点是:

  • 批处理读写可扩展
  • 流处理读写可扩展
  • spark.read / write / readStream / writeStream 体系自然融合
  • 支持 Arrow Batch 提升性能

如果你在做自定义 Connector、测试数据源、内部平台适配层,或者想快速接入非标准数据来源,这个 API 会非常值得深入。

相关推荐
2301_809204704 小时前
JavaScript中严格模式use-strict对引擎解析的辅助.txt
jvm·数据库·python
zjy277774 小时前
mysql如何选择合适的索引类型_mysql索引设计实战
jvm·数据库·python
Aaswk4 小时前
Java Lambda 表达式与流处理
java·开发语言·python
万邦科技Lafite5 小时前
京东item_get接口实战案例:实时商品价格监控全流程解析
java·开发语言·数据库·python·开放api·淘宝开放平台
Cyber4K6 小时前
【Python专项】进阶语法-系统资源监控与数据采集(1)
开发语言·python·php
苍煜7 小时前
Java开发IO零基础吃透:BIO、NIO、同步异步、阻塞非阻塞
java·python·nio
AllData公司负责人8 小时前
通过Postgresql同步到Doris,全视角演示AllData数据中台核心功能效果,涵盖:数据入湖仓,数据同步,数据处理,数据服务,BI可视化驾驶舱
java·大数据·数据库·数据仓库·人工智能·python·postgresql
Flittly9 小时前
【LangGraph新手村系列】(5)时间旅行:浏览历史、分叉时间线与修改过去
python·langchain
2301_782040459 小时前
CSS Flex布局中如何实现导航栏与Logo的左右分布_利用justify-content- space-between
jvm·数据库·python
yaoxin5211239 小时前
400. Java 文件操作基础 - 使用 Buffered Stream I/O 读取文本文件
java·开发语言·python