Contents
[2.1 Loading data](#2.1 Loading data)
[Reading from ML libraries (parallel reads not supported)](#Reading from ML libraries (parallel reads not supported))
[2.2 Transforming data](#2.2 Transforming data)
[Transforming batches](#Transforming batches)
[Shuffling rows](#Shuffling rows)
[Repartitioning data](#Repartitioning data)
[2.3 Consuming data](#2.3 Consuming data)
[1) Iterating over rows](#1) Iterating over rows)
[2.4 Saving data](#2.4 Saving data)
Today we'll walk through data operations in Ray; the API is quite concise.
1. Overview
Basic code
from typing import Dict
import numpy as np
import ray
# Create datasets from on-disk files, Python objects, and cloud storage like S3.
ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
# Apply functions to transform data. Ray Data executes transformations in parallel.
def compute_area(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    length = batch["petal length (cm)"]
    width = batch["petal width (cm)"]
    batch["petal area (cm^2)"] = length * width
    return batch
transformed_ds = ds.map_batches(compute_area)
# Iterate over batches of data.
for batch in transformed_ds.iter_batches(batch_size=4):
    print(batch)
# Save dataset contents to on-disk files or cloud storage.
transformed_ds.write_parquet("local:///tmp/iris/")
With ray.data you can conveniently read files from local disk, Python objects, and cloud storage such as S3, and at the end write the results back to the cloud.
Core APIs:
- Simple transformations (map_batches())
- Global and grouped aggregations (groupby()) (see the sketch after this list)
- Shuffling operations (random_shuffle(), sort(), repartition())
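Aggregation and sorting are not demonstrated later in this post, so here is a minimal sketch on a small made-up in-memory dataset (the column names and values are only for illustration):

import ray

# A tiny made-up dataset, just for illustration.
ds = ray.data.from_items([
    {"variety": "setosa", "sepal_length": 5.1},
    {"variety": "setosa", "sepal_length": 4.9},
    {"variety": "virginica", "sepal_length": 6.3},
])
# Global aggregation: mean over the whole dataset.
print(ds.mean("sepal_length"))
# Grouped aggregation: row count per variety.
print(ds.groupby("variety").count().take_all())
# Sorting rows by a column.
print(ds.sort("sepal_length").take_all())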
2. Core concepts
2.1 Loading data
Reading from S3
import ray
# Load a CSV file
ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv")
print(ds.schema())
ds.show(limit=1)
# Load a Parquet file
ds = ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet")
# Load images
ds = ray.data.read_images("s3://anonymous@ray-example-data/batoidea/JPEGImages/")
# Load text
ds = ray.data.read_text("s3://anonymous@ray-example-data/this.txt")
# Load binary files
ds = ray.data.read_binary_files("s3://anonymous@ray-example-data/documents")
# Load TFRecords
ds = ray.data.read_tfrecords("s3://anonymous@ray-example-data/iris.tfrecords")
Reading from local disk:
ds = ray.data.read_parquet("local:///tmp/iris.parquet")
Reading compressed files:
ds = ray.data.read_csv(
    "s3://anonymous@ray-example-data/iris.csv.gz",
    arrow_open_stream_args={"compression": "gzip"},
)
Other ways to create datasets
import ray
import numpy as np
import pandas as pd
import pyarrow as pa

# From Python objects
ds = ray.data.from_items([
    {"food": "spam", "price": 9.34},
    {"food": "ham", "price": 5.37},
    {"food": "eggs", "price": 0.94}
])
ds = ray.data.from_items([1, 2, 3, 4, 5])
# From NumPy
array = np.ones((3, 2, 2))
ds = ray.data.from_numpy(array)
# From pandas
df = pd.DataFrame({
    "food": ["spam", "ham", "eggs"],
    "price": [9.34, 5.37, 0.94]
})
ds = ray.data.from_pandas(df)
# From PyArrow
table = pa.table({
    "food": ["spam", "ham", "eggs"],
    "price": [9.34, 5.37, 0.94]
})
ds = ray.data.from_arrow(table)
Reading distributed data (Spark)
import ray
import raydp
spark = raydp.init_spark(app_name="Spark -> Datasets Example",
                         num_executors=2,
                         executor_cores=2,
                         executor_memory="500MB")
df = spark.createDataFrame([(i, str(i)) for i in range(10000)], ["col1", "col2"])
ds = ray.data.from_spark(df)
ds.show(3)
Reading from ML libraries (parallel reads not supported)
import ray.data
from datasets import load_dataset
# Read from Hugging Face (parallel reads not supported)
hf_ds = load_dataset("wikitext", "wikitext-2-raw-v1")
ray_ds = ray.data.from_huggingface(hf_ds["train"])
ray_ds.take(2)
# Read from TensorFlow (parallel reads not supported)
import ray
import tensorflow_datasets as tfds
tf_ds, _ = tfds.load("cifar10", split=["train", "test"])
ds = ray.data.from_tf(tf_ds)
print(ds)
Reading from SQL
import mysql.connector
import ray
def create_connection():
    return mysql.connector.connect(
        user="admin",
        password=...,
        host="example-mysql-database.c2c2k1yfll7o.us-west-2.rds.amazonaws.com",
        connection_timeout=30,
        database="example",
    )
# Get all movies
dataset = ray.data.read_sql("SELECT * FROM movie", create_connection)
# Get movies after the year 1980
dataset = ray.data.read_sql(
    "SELECT title, score FROM movie WHERE year >= 1980", create_connection
)
# Get the number of movies per year
dataset = ray.data.read_sql(
    "SELECT year, COUNT(*) FROM movie GROUP BY year", create_connection
)
Ray can also read from BigQuery and MongoDB; this post won't cover them in detail, but a minimal MongoDB sketch follows for reference.
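This is only a sketch: the URI, database, and collection names below are placeholders, and ray.data.read_mongo additionally requires the pymongoarrow package to be installed.

import ray

# Placeholder connection details; replace them with your own MongoDB deployment.
ds = ray.data.read_mongo(
    uri="mongodb://username:password@mongodb0.example.com:27017/?authSource=admin",
    database="my_db",      # hypothetical database name
    collection="movie",    # hypothetical collection name
    # Optional: an aggregation pipeline executed server-side before reading.
    pipeline=[{"$match": {"year": {"$gte": 1980}}}],
)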
2.2 Transforming data
Transformations are lazy by default: they are not executed until you iterate over, save, or inspect the dataset, as the short sketch below shows.
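Here is a small sketch of the lazy behavior; the identity lambda below is just a stand-in transform:

import ray

ds = ray.data.range(1000)
# This only adds the step to the plan; nothing is executed yet.
ds = ds.map_batches(lambda batch: batch)
# Execution is triggered when the data is consumed or materialized.
print(ds.take(1))       # runs the plan for the rows that are needed
ds = ds.materialize()   # or force full execution and cache the result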
map
import os
from typing import Any, Dict
import ray
def parse_filename(row: Dict[str, Any]) -> Dict[str, Any]:
    row["filename"] = os.path.basename(row["path"])
    return row

ds = (
    ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple", include_paths=True)
    .map(parse_filename)
)
flat_map
from typing import Any, Dict, List
import ray
def duplicate_row(row: Dict[str, Any]) -> List[Dict[str, Any]]:
    return [row] * 2

print(
    ray.data.range(3)
    .flat_map(duplicate_row)
    .take_all()
)
# Output:
# [{'id': 0}, {'id': 0}, {'id': 1}, {'id': 1}, {'id': 2}, {'id': 2}]
# Each original row has become two rows
Transforming batches
from typing import Dict
import numpy as np
import ray
def increase_brightness(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    batch["image"] = np.clip(batch["image"] + 4, 0, 255)
    return batch

# batch_format specifies the batch type; it can be omitted
ds = (
    ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
    .map_batches(increase_brightness, batch_format="numpy")
)
If the setup is expensive, use a callable class instead of a function: the class is instantiated once per worker and reused across batches, so the expensive initialization does not rerun on every call (classes are stateful, while functions are stateless).
The parallelism can be given as a (min, max) range so the actor pool can scale freely within it, as in the sketch below.
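Below is a minimal sketch of the callable-class form. It assumes a Ray version in which map_batches accepts a concurrency=(min, max) argument (older releases express the same thing with compute=ray.data.ActorPoolStrategy); the class name and the brightness offset are made up for illustration.

from typing import Dict
import numpy as np
import ray

class BrightnessIncreaser:
    def __init__(self):
        # Expensive setup (e.g. loading a model) runs once per actor,
        # not once per batch.
        self.offset = 4

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        batch["image"] = np.clip(batch["image"] + self.offset, 0, 255)
        return batch

ds = (
    ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
    # The actor pool scales between 1 and 4 workers.
    .map_batches(BrightnessIncreaser, concurrency=(1, 4))
)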
Shuffling rows
import ray
ds = (
    ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
    .random_shuffle()
)
Repartitioning data
import ray
ds = ray.data.range(10000, parallelism=1000)
# Repartition the data into 100 blocks. Since shuffle=False, Ray Data will minimize
# data movement during this operation by merging adjacent blocks.
ds = ds.repartition(100, shuffle=False).materialize()
# Repartition the data into 200 blocks, and force a full data shuffle.
# This operation will be more expensive
ds = ds.repartition(200, shuffle=True).materialize()
2.3 Consuming data
1) Iterating over rows
import ray
ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv")
for row in ds.iter_rows():
    print(row)
2) Iterating over batches
NumPy, pandas, Torch, and TensorFlow each use a different API to iterate over batches.
# numpy
import ray
ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
for batch in ds.iter_batches(batch_size=2, batch_format="numpy"):
    print(batch)
# pandas
import ray
ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv")
for batch in ds.iter_batches(batch_size=2, batch_format="pandas"):
    print(batch)
# torch
import ray
ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
for batch in ds.iter_torch_batches(batch_size=2):
    print(batch)
# tf
import ray
ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv")
tf_dataset = ds.to_tf(
    feature_columns="sepal length (cm)",
    label_columns="target",
    batch_size=2
)
for features, labels in tf_dataset:
    print(features, labels)
3) Shuffling while iterating over batches
Just add the local_shuffle_buffer_size argument when iterating over batches.
This is not a global shuffle, but it performs better.
import ray
ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
for batch in ds.iter_batches(
    batch_size=2,
    batch_format="numpy",
    local_shuffle_buffer_size=250,
):
    print(batch)
4)为分布式并行训练分割数据
import ray
@ray.remote
class Worker:
    def train(self, data_iterator):
        for batch in data_iterator.iter_batches(batch_size=8):
            pass
ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv")
workers = [Worker.remote() for _ in range(4)]
shards = ds.streaming_split(n=4, equal=True)
ray.get([w.train.remote(s) for w, s in zip(workers, shards)])
2.4 Saving data
Saving files
This works much like saving files in pandas; the only difference is that saving to local files requires the local:// prefix.
Note: without the local:// prefix, Ray writes different partitions of the data to different nodes.
import ray
ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
# local
ds.write_parquet("local:///tmp/iris/")
# s3
ds.write_parquet("s3://my-bucket/my-folder")
Changing the number of partitions
import os
import ray
ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
ds.repartition(2).write_csv("/tmp/two_files/")
print(os.listdir("/tmp/two_files/"))
Converting data to Python objects
import ray
ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
df = ds.to_pandas()
print(df)
Converting data to distributed data (Spark)
import ray
import raydp
spark = raydp.init_spark(
    app_name="example",
    num_executors=1,
    executor_cores=4,
    executor_memory="512M"
)
ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
df = ds.to_spark(spark)