1. 什么是 Python Data Source API
Python Data Source API 是 Spark 4.0 引入的新能力,它允许开发者在 Python 中直接实现自定义数据源和数据写出逻辑。换句话说,你可以像实现一个插件一样,为 Spark 增加新的读取来源和写出目标,而不必先写 Scala/Java 版本。
它支持的能力包括:
- 批处理读取
- 批处理写入
- 流式读取
- 流式写入
官方给出的能力映射关系很清晰:
- 批处理读:实现
reader() - 批处理写:实现
writer() - 流式读:实现
streamReader()或simpleStreamReader() - 流式写:实现
streamWriter()
2. 一个最简单的自定义数据源
先看一个最小示例:定义一个只返回两行数据的数据源。
python
from typing import Iterator, Tuple
from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
from pyspark.sql.types import IntegerType, StringType, StructField, StructType
class SimpleDataSource(DataSource):
@classmethod
def name(cls) -> str:
return "simple"
def schema(self) -> StructType:
return StructType([
StructField("name", StringType()),
StructField("age", IntegerType())
])
def reader(self, schema: StructType) -> DataSourceReader:
return SimpleDataSourceReader()
class SimpleDataSourceReader(DataSourceReader):
def read(self, partition: InputPartition) -> Iterator[Tuple]:
yield ("Alice", 20)
yield ("Bob", 30)
注册后即可直接读取:
python
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
spark.dataSource.register(SimpleDataSource)
spark.read.format("simple").load().show()
输出结果:
text
+-----+---+
| name|age|
+-----+---+
|Alice| 20|
| Bob| 30|
+-----+---+
这个例子说明了 Python Data Source API 的核心思路:定义名字、定义 schema、实现 reader,然后注册给 Spark 使用。
3. DataSource 的基本结构
一个完整的 Python Data Source,通常要继承 DataSource,然后根据你的能力实现对应方法。官方示例中给出了一份更完整的定义:
python
from typing import Union
from pyspark.sql.datasource import (
DataSource,
DataSourceReader,
DataSourceStreamReader,
DataSourceStreamWriter,
DataSourceWriter
)
from pyspark.sql.types import StructType
class FakeDataSource(DataSource):
@classmethod
def name(cls) -> str:
return "fake"
def schema(self) -> Union[StructType, str]:
return "name string, date string, zipcode string, state string"
def reader(self, schema: StructType) -> DataSourceReader:
return FakeDataSourceReader(schema, self.options)
def writer(self, schema: StructType, overwrite: bool) -> DataSourceWriter:
return FakeDataSourceWriter(self.options)
def streamReader(self, schema: StructType) -> DataSourceStreamReader:
return FakeStreamReader(schema, self.options)
def streamWriter(self, schema: StructType, overwrite: bool) -> DataSourceStreamWriter:
return FakeStreamWriter(self.options)
这段代码已经覆盖了:
- 批量读
- 批量写
- 流式读
- 流式写
如果你只需要其中一种能力,只实现对应方法即可。
4. 批处理 Reader 怎么写
官方示例里的 FakeDataSourceReader 使用 faker 库按 schema 生成假数据。
python
from typing import Dict
from pyspark.sql.datasource import DataSourceReader
from pyspark.sql.types import StructType
class FakeDataSourceReader(DataSourceReader):
def __init__(self, schema: StructType, options: Dict[str, str]):
self.schema = schema
self.options = options
def read(self, partition):
from faker import Faker
fake = Faker()
num_rows = int(self.options.get("numRows", 3))
for _ in range(num_rows):
row = []
for field in self.schema.fields:
value = getattr(fake, field.name)()
row.append(value)
yield tuple(row)
这里有两个关键点:
options里的值全部都是字符串,所以要自己做类型转换- 依赖库
faker是在方法内部导入的,这和后面要说的序列化要求有关
5. 批处理 Writer 怎么写
批处理写出要实现 DataSourceWriter。官方示例里,Writer 会统计每个分区写出的记录数,并在 commit() 中输出总行数。
python
from dataclasses import dataclass
from typing import Iterator, List
from pyspark.sql import Row
from pyspark.sql.datasource import DataSourceWriter, WriterCommitMessage
@dataclass
class SimpleCommitMessage(WriterCommitMessage):
partition_id: int
count: int
class FakeDataSourceWriter(DataSourceWriter):
def write(self, rows: Iterator[Row]) -> SimpleCommitMessage:
from pyspark import TaskContext
context = TaskContext.get()
partition_id = context.partitionId()
cnt = sum(1 for _ in rows)
return SimpleCommitMessage(partition_id=partition_id, count=cnt)
def commit(self, messages: List[SimpleCommitMessage]) -> None:
total_count = sum(message.count for message in messages)
print(f"Total number of rows: {total_count}")
def abort(self, messages: List[SimpleCommitMessage]) -> None:
failed_count = sum(message is None for message in messages)
print(f"Number of failed tasks: {failed_count}")
这类写法很适合:
- 统计落盘条数
- 汇总分区写出结果
- 自定义成功/失败处理逻辑
6. 流式 Reader 怎么写
6.1 标准流式 Reader
如果你的流式数据源需要自己管理 offset、分区规划、提交状态,可以实现 DataSourceStreamReader。官方示例中每个 micro-batch 生成 2 行数据。
python
from typing import Iterator, Tuple
from pyspark.sql.datasource import DataSourceStreamReader, InputPartition
class RangePartition(InputPartition):
def __init__(self, start: int, end: int):
self.start = start
self.end = end
class FakeStreamReader(DataSourceStreamReader):
def __init__(self, schema, options):
self.current = 0
def initialOffset(self) -> dict:
return {"offset": 0}
def latestOffset(self) -> dict:
self.current += 2
return {"offset": self.current}
def partitions(self, start: dict, end: dict) -> list[InputPartition]:
return [RangePartition(start["offset"], end["offset"])]
def commit(self, end: dict) -> None:
pass
def read(self, partition) -> Iterator[Tuple]:
for i in range(partition.start, partition.end):
yield (i, str(i))
6.2 简化版流式 Reader
如果数据源吞吐量不高,也不需要复杂分区,可以实现 SimpleDataSourceStreamReader。官方说明里明确说:streamReader() 和 simpleStreamReader() 二选一即可,前者优先。
python
from typing import Iterator, Tuple
from pyspark.sql.datasource import SimpleDataSourceStreamReader
class FakeSimpleStreamReader(SimpleDataSourceStreamReader):
def initialOffset(self) -> dict:
return {"offset": 0}
def read(self, start: dict) -> Tuple[Iterator[Tuple], dict]:
start_idx = start["offset"]
it = iter([(i,) for i in range(start_idx, start_idx + 2)])
return (it, {"offset": start_idx + 2})
def readBetweenOffsets(self, start: dict, end: dict) -> Iterator[Tuple]:
return iter([(i,) for i in range(start["offset"], end["offset"])])
def commit(self, end: dict) -> None:
pass
7. 流式 Writer 怎么写
流式写出对应的是 DataSourceStreamWriter。官方示例里,它把每个 micro-batch 的元信息写到本地路径中。
python
from typing import Iterator, List, Optional
from pyspark.sql import Row
from pyspark.sql.datasource import DataSourceStreamWriter, WriterCommitMessage
class SimpleCommitMessage(WriterCommitMessage):
partition_id: int
count: int
class FakeStreamWriter(DataSourceStreamWriter):
def __init__(self, options):
self.options = options
self.path = self.options.get("path")
assert self.path is not None
def write(self, iterator: Iterator[Row]) -> WriterCommitMessage:
from pyspark import TaskContext
context = TaskContext.get()
partition_id = context.partitionId()
cnt = 0
for _ in iterator:
cnt += 1
return SimpleCommitMessage(partition_id=partition_id, count=cnt)
def commit(self, messages: List[Optional[SimpleCommitMessage]], batchId: int) -> None:
status = dict(num_partitions=len(messages), rows=sum(m.count for m in messages))
with open(os.path.join(self.path, f"{batchId}.json"), "a") as file:
file.write(json.dumps(status) + "\n")
def abort(self, messages: List[Optional[SimpleCommitMessage]], batchId: int) -> None:
with open(os.path.join(self.path, f"{batchId}.txt"), "w") as file:
file.write(f"failed in batch {batchId}")
这个模式的核心是:
write()负责分区级写出commit()负责 batch 成功后的汇总处理abort()负责 batch 失败后的兜底逻辑
8. 序列化要求一定要注意
官方专门强调:DataSource、DataSourceReader、DataSourceWriter、DataSourceStreamReader、DataSourceStreamWriter 以及它们的方法,都必须能够被 pickle 序列化。
因此有一个很重要的实践:
方法内部用到的库,尽量在方法内部导入。
比如官方示例中,TaskContext 不是在文件开头导入,而是在 read() 或 write() 方法内部导入:
python
def read(self, partition):
from pyspark import TaskContext
context = TaskContext.get()
这是为了避免序列化问题。
9. 如何注册和使用 Python Data Source
9.1 注册
python
spark.dataSource.register(FakeDataSource)
9.2 读取
使用默认 schema:
python
spark.read.format("fake").load().show()
使用自定义 schema:
python
spark.read.format("fake").schema("name string, company string").load().show()
传入 option:
python
spark.read.format("fake").option("numRows", 5).load().show()
这些都是官方给出的直接用法。
9.3 写出
写出时要指定 mode(),支持 append 和 overwrite。
python
df = spark.range(0, 10, 1, 5)
df.write.format("fake").mode("append").save()
10. 在 Structured Streaming 里使用
注册之后,同一个 Python Data Source 也可以用于 readStream() 和 writeStream()。
10.1 作为流式 source
python
query = spark.readStream.format("fake").load() \
.writeStream.format("console").start()
10.2 同时作为流式 source 和 sink
python
query = spark.readStream.format("fake").load() \
.writeStream.format("fake").start("/output_path")
这意味着你可以完全用 Python 打通一个自定义流式读写链路。
11. Arrow Batch 支持:性能提升的重点
官方文档最后一部分特别强调了一点:Python Data Source Reader 支持直接产出 Arrow Batch,这能显著提升性能,甚至在大数据场景下带来一个数量级的提升。
实现方式很简单:在 read() 方法里直接 yield pyarrow.RecordBatch。
python
from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
from pyspark.sql import SparkSession
import pyarrow as pa
class ArrowBatchDataSource(DataSource):
@classmethod
def name(cls):
return "arrowbatch"
def schema(self):
return "key int, value string"
def reader(self, schema: str):
return ArrowBatchDataSourceReader(schema, self.options)
class ArrowBatchDataSourceReader(DataSourceReader):
def __init__(self, schema, options):
self.schema = schema
self.options = options
def read(self, partition):
keys = pa.array([1, 2, 3, 4, 5], type=pa.int32())
values = pa.array(["one", "two", "three", "four", "five"], type=pa.string())
schema = pa.schema([("key", pa.int32()), ("value", pa.string())])
record_batch = pa.RecordBatch.from_arrays([keys, values], schema=schema)
yield record_batch
def partitions(self):
return [InputPartition(i) for i in range(1)]
spark = SparkSession.builder.appName("ArrowBatchExample").getOrCreate()
spark.dataSource.register(ArrowBatchDataSource)
df = spark.read.format("arrowbatch").load()
df.show()
这部分非常值得关注,因为它说明 Python Data Source API 不只是"能用",而且在配合 Arrow 后可以更快。
12. 使用时的几个注意点
官方文档还给了几条很实用的说明:
12.1 名称冲突时,内置和 Scala/Java Data Source 优先
如果 Python Data Source 和内置数据源或 Scala/Java 数据源同名,默认优先解析后者。因此定义名字时,尽量避免冲突。
12.2 可以重复注册,后注册会覆盖前注册
这对调试很方便,但也意味着线上环境要注意命名和注册顺序。
12.3 支持自动注册
官方提到,如果你在顶层模块里把数据源导出为 DefaultSource,并且模块名前缀是 pyspark_,就可以自动注册。
13. 总结
Python Data Source API 的价值很直接:它让 Spark 4.0 开始真正具备了"用 Python 写自定义数据源"的能力。以前这类扩展大多要落到 JVM 侧,现在很多批处理和流处理场景都可以直接用 Python 完成。它最重要的几个亮点是:
- 批处理读写可扩展
- 流处理读写可扩展
- 与
spark.read / write / readStream / writeStream体系自然融合 - 支持 Arrow Batch 提升性能
如果你在做自定义 Connector、测试数据源、内部平台适配层,或者想快速接入非标准数据来源,这个 API 会非常值得深入。