python
复制代码
import json
from modelscope.msdatasets import MsDataset
def load_all_coig_cqia_subdatasets(dataset_name: str = "AI-ModelScope/COIG-CQIA"):
"""
加载 COIG-CQIA 数据集的所有子数据集
参考 train_sft.py 中的加载方法
Args:
dataset_name: 数据集名称
Returns:
所有子数据集的数据列表
"""
# 使用 train_sft.py 中定义的默认子集列表
default_subsets = [
'chinese_traditional',
'coig_pc',
'exam',
'finance',
'douban',
'human_value',
'logi_qa',
'ruozhiba',
'segmentfault',
'wiki',
'wikihow',
'xhs',
'zhihu'
]
print(f"使用默认数据集: {dataset_name}")
print(f"加载子集: {default_subsets}")
all_data = []
# 加载每个子数据集
for subset in default_subsets:
try:
print(f"正在加载子集: {subset}...")
subset_ds = MsDataset.load(dataset_name, subset_name=subset, split='train')
# 转换为HuggingFace Dataset格式(如果还不是)
if hasattr(subset_ds, 'to_hf_dataset'):
subset_ds = subset_ds.to_hf_dataset()
# 转换为列表
subset_data = list(subset_ds)
print(f"子集 {subset} 加载完成,大小: {len(subset_data)}")
all_data.extend(subset_data)
except Exception as e:
print(f"警告: 加载子集 {subset} 失败: {e}")
continue
if not all_data:
raise ValueError("所有子集加载失败,请检查数据集名称和网络连接")
print(f"总共加载了 {len(all_data)} 条数据")
return all_data
def convert_to_alpaca_format(example):
"""
将 COIG-CQIA 数据转换为 Alpaca 格式
Alpaca 格式: {"instruction": str, "input": str, "output": str}
COIG-CQIA 数据格式可能包含以下字段:
- instruction/question/prompt: 指令或问题
- input/context: 输入或上下文
- output/response/answer: 输出或回答
"""
# 处理不同类型的数据
if isinstance(example, str):
# 如果是字符串,尝试解析为 JSON
try:
example = json.loads(example)
except json.JSONDecodeError:
# 如果无法解析为 JSON,将整个字符串作为 instruction
return {
"instruction": example,
"input": "",
"output": ""
}
# 如果不是字典,尝试转换为字典或使用字符串表示
if not isinstance(example, dict):
# 如果是列表,尝试处理
if isinstance(example, list):
if len(example) > 0:
# 取第一个元素作为 instruction,最后一个作为 output
instruction = str(example[0]) if example else ""
output = str(example[-1]) if len(example) > 1 else ""
return {
"instruction": instruction,
"input": "",
"output": output
}
else:
return {
"instruction": "",
"input": "",
"output": ""
}
else:
# 其他类型,直接转换为字符串作为 instruction
return {
"instruction": str(example),
"input": "",
"output": ""
}
# 现在 example 应该是字典了
# 尝试多种可能的字段名来获取 instruction
instruction = (
example.get('instruction') or
example.get('question') or
example.get('prompt') or
example.get('input') or
example.get('query') or
''
)
# 尝试多种可能的字段名来获取 input
input_text = (
example.get('input') or
example.get('context') or
example.get('history') or
''
)
# 如果 instruction 和 input 都为空,尝试从 conversation 中提取
if not instruction and not input_text:
conversation = example.get('conversation') or example.get('messages') or []
if isinstance(conversation, list) and len(conversation) > 0:
# 取第一条消息作为 instruction
first_msg = conversation[0]
if isinstance(first_msg, dict):
instruction = first_msg.get('content') or first_msg.get('text') or ''
elif isinstance(first_msg, str):
instruction = first_msg
# 尝试多种可能的字段名来获取 output
output = (
example.get('output') or
example.get('response') or
example.get('answer') or
example.get('target') or
''
)
# 如果 output 为空,尝试从 conversation 中提取最后一条消息
if not output:
conversation = example.get('conversation') or example.get('messages') or []
if isinstance(conversation, list) and len(conversation) > 1:
# 取最后一条消息作为 output
last_msg = conversation[-1]
if isinstance(last_msg, dict):
output = last_msg.get('content') or last_msg.get('text') or ''
elif isinstance(last_msg, str):
output = last_msg
# 如果仍然没有 instruction,使用整个 example 的字符串表示
if not instruction and not input_text:
# 尝试从其他字段构建
for key in ['text', 'content', 'message']:
if key in example:
instruction = str(example[key])
break
return {
"instruction": instruction if instruction else "",
"input": input_text if input_text else "",
"output": output if output else ""
}
# 下载所有子数据集
print("=" * 60)
print("开始下载 COIG-CQIA 数据集的所有子数据集")
print("=" * 60)
dataset_name = "AI-ModelScope/COIG-CQIA"
all_data = load_all_coig_cqia_subdatasets(dataset_name)
if not all_data:
print("警告: 未能加载任何数据,请检查数据集名称和网络连接")
exit(1)
# 转换所有数据
print("\n" + "=" * 60)
print("正在转换为 Alpaca 格式...")
print("=" * 60)
# 打印第一条数据的类型和结构,用于调试
if all_data:
print(f"第一条数据的类型: {type(all_data[0])}")
if isinstance(all_data[0], dict):
print(f"第一条数据的键: {list(all_data[0].keys())[:10]}") # 只显示前10个键
elif isinstance(all_data[0], str):
print(f"第一条数据的前100个字符: {all_data[0][:100]}")
print()
alpaca_data = []
for i, example in enumerate(all_data):
try:
alpaca_item = convert_to_alpaca_format(example)
# 只保留有效的数据(至少要有 instruction 或 output)
if alpaca_item.get('instruction') or alpaca_item.get('output'):
alpaca_data.append(alpaca_item)
if (i + 1) % 100 == 0:
print(f"已处理 {i + 1}/{len(all_data)} 条数据")
except Exception as e:
print(f"处理第 {i} 条数据时出错: {e}")
print(f"数据类型: {type(example)}, 数据内容: {str(example)[:200]}")
continue
print(f"\n转换完成,共 {len(alpaca_data)} 条有效数据")
# 保存为 JSONL 文件
output_file_jsonl = 'coig_cqia_alpaca.jsonl'
print(f"\n正在保存到 {output_file_jsonl}...")
with open(output_file_jsonl, 'w', encoding='utf-8') as f:
for item in alpaca_data:
json.dump(item, f, ensure_ascii=False)
f.write('\n')
print(f"数据已成功保存到 {output_file_jsonl}")
print(f"文件包含 {len(alpaca_data)} 条记录")
# 保存为 JSON 文件(数组格式)
output_file_json = 'coig_cqia_alpaca.json'
print(f"\n正在保存到 {output_file_json}...")
with open(output_file_json, 'w', encoding='utf-8') as f:
json.dump(alpaca_data, f, ensure_ascii=False, indent=2)
print(f"数据已成功保存到 {output_file_json}")
print(f"文件包含 {len(alpaca_data)} 条记录")