编写一个 Python 函数,将 clusters.csv
文件,格式为:第一列为编号,第二列为聚类的代表序列,第三列为所有的其它同源序列,逗号隔开(TM-align蛋白质聚类数据格式转化-CSDN博客),划分为 train
, valid
, 和 test
数据集。通常,我们可以根据一定的比例(比如 70% 用于 train
,15% 用于 valid
,15% 用于 test
)进行划分。
以下是一个将 clusters.csv
划分为三个数据集的示例函数:
import csv
import random
import os
def split_dataset(input_csv, train_csv, valid_csv, test_csv, train_ratio=0.7, valid_ratio=0.15, test_ratio=0.15, seed=42):
"""
将clusters.csv文件划分为train, valid, test数据集。
:param input_csv: 输入的csv文件路径
:param train_csv: 输出的训练集csv文件路径
:param valid_csv: 输出的验证集csv文件路径
:param test_csv: 输出的测试集csv文件路径
:param train_ratio: 训练集比例,默认0.7
:param valid_ratio: 验证集比例,默认0.15
:param test_ratio: 测试集比例,默认0.15
:param seed: 随机种子,确保划分结果可复现
"""
# 检查比例是否为1
assert train_ratio + valid_ratio + test_ratio == 1, "训练、验证和测试集的比例必须加起来等于1"
# 设置随机种子
random.seed(seed)
# 读取原始数据
with open(input_csv, 'r') as csvfile:
reader = list(csv.reader(csvfile))
header = reader[0] # 读取标题行
data = reader[1:] # 读取数据部分
# 打乱数据
random.shuffle(data)
# 计算划分的索引
total_size = len(data)
train_size = int(total_size * train_ratio)
valid_size = int(total_size * valid_ratio)
train_data = data[:train_size]
valid_data = data[train_size:train_size + valid_size]
test_data = data[train_size + valid_size:]
# 定义一个辅助函数来写csv文件
def write_csv(output_csv, data):
with open(output_csv, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(header) # 写入标题行
writer.writerows(data) # 写入数据
# 写入三个文件
write_csv(train_csv, train_data)
write_csv(valid_csv, valid_data)
write_csv(test_csv, test_data)
print(f"数据集已成功划分:\n训练集: {len(train_data)} 条记录\n验证集: {len(valid_data)} 条记录\n测试集: {len(test_data)} 条记录")
# 调用函数进行数据集划分
split_dataset(
input_csv='clusters.csv',
train_csv='train.csv',
valid_csv='valid.csv',
test_csv='test.csv'
)
解释:
-
函数参数:
input_csv
: 输入的clusters.csv
文件路径。train_csv
,valid_csv
,test_csv
: 输出的训练集、验证集、测试集文件路径。train_ratio
,valid_ratio
,test_ratio
: 数据集划分的比例,默认是 70% 训练集,15% 验证集,15% 测试集。seed
: 随机种子,确保每次划分的结果一致。
-
逻辑:
- 读取
clusters.csv
文件并将数据打乱。 - 按照指定比例计算每个数据集的大小。
- 将数据分别写入
train.csv
、valid.csv
和test.csv
文件中。
- 读取
-
使用:
- 将原始的
clusters.csv
文件作为输入,输出train.csv
、valid.csv
和test.csv
。
- 将原始的