【从零开始】13. 数据增强(Data Augmentation)

书接上回,上一章我们简单地过了一遍如何获取训练数据并通过代码将数据保存到 Elasticsearch 中了,但是保存下来后发现数据量还是不够(总数据量 22w+),这个时候可以选择采用数据增强技术来补充训练数据。

同理,这次我也是在 Modelscope、和鲸、飞桨、天池等多个平台获取跟中药相关的开放数据集,由于这次数据集不再是直接的问答数据而是一些内容详情,因此无法直接使用。为了方便数据整理我先按分类将数据集放入最为熟悉的 MySQL 中进行数据"初加工"。之前的文章中有讲过如何处理:

【AutoML】AutoKeras 数据清洗与简单提纯

处理起来大同小异,就不再叙述了。之后再通过 Python 脚本提交给大模型进行数据增强了。伪代码如下所示:

python 复制代码
...

api = ApiUtil()
myclient_local = MysqlUtil()

class MysqlToTurning:

    def setup_data(self,params,source_table,batch_size,call_func):
        """
        数据处理公共函数
        """
        ...
     
def _thread_genrate_turning(prompt_str, insert_batch, source_table):
    """
    查询大模型并组装返回训练用问答对函数
    """
    ...

def setup_ancient_books_data(results, source_table):
    """
    处理中药古籍数据函数
    """
    ...

def setup_herbal_medicines_data(results, source_table):
    """
    处理中药基础数据函数
    """
    ...

def setup_medicine_formulas_data(results, source_table):
    """
    处理中药方剂数据函数
    """
    ...

def setup_pharmacopeia_data(results, source_table):
    """
    处理药典、法律法规数据函数
    """
    ...

mtt = MysqlToTurning()
    
if __name__ == "__main__":
    threads = []

    # 设定 mysql 中表与字段的对应关系,以及动态调用那个处理函数
    params_array = [
        {
            "table_name":"my_ancient_books", # mysql 表
            "table_fields": ["book_name","book_author","book_dynasty","book_release","book_content"], # 表中字段
            "function_call": setup_ancient_books_data # 对应的处理函数
        },
        {
            "table_name":"my_herbal_medicines",
            "table_fields": ["variety_name","variety_data_type","variety_source","variety_content"],
            "function_call": setup_herbal_medicines_data
        },
        {
            "table_name":"my_medicine_formulas",
            "table_fields": ["formula_name","formula_data_type","formula_source","formulas_content"],
            "function_call": setup_medicine_formulas_data
        },
        {
            "table_name":"my_pharmacopeia",
            "table_fields": ["pharmacopeiacol_name","pharmacopeiacol_data_type","pharmacopeia_source","pharmacopeiacol_content"],
            "function_call": setup_pharmacopeia_data
        }
    ]
    
    # 创建多线程进行处理
    for params in params_array:
        thread =threading.Thread(
                target=mtt.setup_data, 
                args=(params["table_fields"],params["table_name"],5,params["function_call"]),
                daemon=True
                )
        threads.append(thread)
        thread.start()
        
    # 等待多线程结束,不然程序会立刻关闭
    for thread in threads:
        thread.join()

多线程先调用 setup_data 函数,并根据 setup_data 函数中的传参中的"function_call"字段决定调用那个特殊函数,而在 setup_ancient_books_data、setup_herbal_medicines_data、setup_medicine_formulas_data和setup_pharmacopeia_data 函数中都会调用 _thread_genrate_turning 函数来进行大模型访问,之后大模型就会返回对应的数据集,解析数据集后进行收集,并最终由 setup_data 函数进行输出插入操作,至此完成整个处理过程。

接下来将一个一个函数进行分析。

setup_data

python 复制代码
def setup_data(self,params,source_table,batch_size,call_func):
    
    # 动态的 sql 语句拼接
    search_sql = f"SELECT `id`,`{'`,`'.join(params)}` FROM {source_table} WHERE `generate_flag` = 0 LIMIT {batch_size}"
    
    # 动态的 insert 语句 
    insert_turning_sql = "insert into my_fine_turning_datas(`content`,`sources`) values(%s,%s)"
    
    # 根据动态 sql 进行查询
    result_resp = myclient_local.query_by_list(search_sql)
    while len(result_resp) > 0:
        
        # 若发现存在未转换的数据,则调用指定处理函数进行处理
        # 以中药古籍为例,call_func 参数传入的是 setup_ancient_books_data 函数,则这里处理的就是 setup_ancient_books_data 函数
        # 处理完成后将得到更新 id 值和批量插入数据
        update_id_batch,insert_batch = call_func(result_resp,source_table)
        
        # 若存在需要批量插入数据
        if insert_batch:    
            # 则进行批量插入
            insert_counter = myclient_local.batch_save_or_update(insert_turning_sql, insert_batch)
            
            # 插入完成则更新原数据表中状态
            if insert_counter > 0:
                update_id_batch_str = "','".join(update_id_batch)
                update_sql = f"update {source_table} set `generate_flag` = 1 where `id` in ('{update_id_batch_str}')"
                update_counter,_ = myclient_local.save_or_update(update_sql)
                logger.info(f"共生成{insert_counter}条数据,并更新了"{source_table}"表{update_counter}条数据的状态")   
        
        # 进行第二批次的查询看是否还有未处理数据
        result_resp = myclient_local.query_by_list(search_sql) 

setup_ancient_books_data(样例)

python 复制代码
def setup_ancient_books_data(results, source_table):
    
    insert_batch = []
    update_id_batch = []
    search_threads = []
    
    # 遍历查询数据集并获取字段数据
    for id, book_name, book_author, book_dynasty, book_release, book_content in results:
        # 收集需要更新的 id
        update_id_batch.append(str(id))
        
        # 组装中文 json 字符串
        montage = f"{{\"书名\":\"{book_name}\",\"朝代\":\"{book_dynasty}\",\"作者\":\"{book_author}\",\"出版时间\":\"{book_release}\",\"书中内容\":\"{book_content}\"}}"
        
        # 将 json 字符串传入 get_question_and_answer_prompts 函数组装大模型 prompt 字符串
        prompt_str = CU.get_question_and_answer_prompts(montage)
        
        # 多线程启动调用大模型接口处理
        search_thread = threading.Thread(
            target=_thread_genrate_turning,
            args=(prompt_str, insert_batch, source_table),
            daemon=True
            )
        search_threads.append(search_thread)
        search_thread.start()
    
    # 等待一个批次所有处理完成
    for search_thread in search_threads:
        search_thread.join()
    
    # 返回更新 id 数组和训练数据集
    return update_id_batch, insert_batch

如上图所示,setup_ancient_books_data 函数调用了 get_question_and_answer_prompts 函数来组装 prompt 字符串,该函数内容如下:

python 复制代码
def get_question_and_answer_prompts(message):
    return f"""
        **输入信息:**
        【内容片段】:{message}

        **任务指令:**
        请根据以上【内容片段】提炼10个独特且多样化的中药问答对。
        问答内容必须完全基于【内容片段】中的信息。
        问题设计应覆盖:核心理论、方剂药材、诊治方法、性味归经、功效作用、历史背景、适用病症或人群、作用机制或原理、术语解释等。
        
        **输出格式:**
        [
            {{"question": "问题1","answer": "答案1"}},
            {{"question": "问题2","answer": "答案2"}},
            {{"question": "问题3","answer": "答案3"}},
            {{"question": "问题4","answer": "答案4"}},
            {{"question": "问题5","answer": "答案5"}},
            {{"question": "问题6","answer": "答案6"}},
            {{"question": "问题7","answer": "答案7"}},
            {{"question": "问题8","answer": "答案8"}},
            {{"question": "问题9","answer": "答案9"}},
            {{"question": "问题10","answer": "答案10"}}
        ]
    """

除了 get_question_and_answer_prompts 外,我们发现在启动多线程的时候还调用了一个名为 _thread_genrate_turning 的函数,

python 复制代码
def _thread_genrate_turning(prompt_str, insert_batch, source_table):
    # 为了保证返回内容符合要求这里采用了 while 循环作为重试机制
    while True:
        # chat_with_sync 函数是调用第三方接口的封装函数
        resp = api.chat_with_sync(model_params, prompt_str)
        try:   
            # 得到大模型的返回后
            resp_array = json.loads(resp)
            if resp_array:
                # 整理数据集
                insert_batch.extend((json.dumps(resp, ensure_ascii=False), source_table) for resp in resp_array)
                break
            else:
                # 增加停顿时间避免访问过快
                time.sleep(random.randint(1, 5))
        except Exception:
            # 增加停顿时间避免访问过快
            time.sleep(random.randint(1, 5))

其他函数如 setup_herbal_medicines_data、setup_medicine_formulas_data 和 setup_pharmacopeia_data 基本都是大同小异,只是字段不同需要单独拿出来处理而已。

通过以上的方式我们就能够快速地采用多线程同时处理多个类型的数据了。这里又要感谢硅基流动给予我这种贫苦开发者福利。上述的 chat_with_sync 函数调用的就是免费的硅基流动接口。配置如下:

yaml 复制代码
silicon:
  agent:
    generate_qaa:
      prompt: "你是一位精通中药学的知识专家,负责从古籍中提炼知识并生成问答。

        你的任务是根据我提供的【古籍名称】和【内容片段】,生成10个独特且多样化的问答对。

        请确保每个问答都紧密围绕【内容片段】,且答案可直接从片段中得出或合理推断。

        注意:你的输出**必须是一个只包含JSON数组的文本**。不允许包含任何额外的文字、注释、解释、Markdown代码块(如```json```)、或任何非JSON内容。"
      model: THUDM/GLM-4-9B-0414
      options:
        max_tokens: 6144
        temperature: 0.4
        top_p: 0.95
        frequency_penalty: 0.1

这里调用的是免费的 THUDM/GLM-4-9B-0414 模型。虽然返回内容没有 Qwen3 返回的内容详实,但胜在速度上占优(毕竟有大量的数据需要转换),因此选择使用它。

到这里就结束了吗?No。

在做数据增强的过程中会生成多条重复的记录出来(大模型机制 + kv cache的作用下,在大批量生成时难免会出现这种情况)。这个时候就需要补充多一个函数用于定时处理冗余数据,如下图:

python 复制代码
def clean_dulpicate_data(self):
    search_sql = f"""
            SELECT content, MIN(id) AS id
            FROM {CU.TMP_MYSQL_TURNING_DATA_TABLE}
            GROUP BY content
            HAVING COUNT(1) > 1
            LIMIT 1000
        """

    # sql 采用分页获取的方式查询(避免一次性加载太多数据到内存)
    results = myclient_local.query_by_list(search_sql)
    while len(results) > 0:
        # 从第二个冗余数据的 id 开始组装
        ids_result = [row[1] for row in results]

        # 批量删除冗余数据
        delete_sql = f"DELETE FROM {CU.TMP_MYSQL_TURNING_DATA_TABLE} WHERE id IN ({', '.join(map(str, ids_result))})"
        counter, _ = myclient_local.save_or_update(delete_sql)
        if counter > 0:
            logger.info(f"Deleted {counter} duplicate records from {CU.TMP_MYSQL_TURNING_DATA_TABLE}.")

            # 下一个批次查询
            results = myclient_local.query_by_list(search_sql)

但接下来的问题来了。在执行的 main 函数中我已经采用了多线程启动并等待直至所有线程完成,难道我这个删除重复数据的函数执行在之后执行吗?这样就起不到定时器分批处理的效果了。既然这样就将 main 函数稍微改造一下吧。

python 复制代码
mtt = MysqlToTurning()

# 定时器设置每两个小时执行一次数据清理工作
schedule.every(2).hours.do(mtt.clean_dulpicate_data)
    
if __name__ == "__main__":
    threads = []
    params_array = [...]
    
    # 数据增强多线程处理
    for params in params_array:
        thread =threading.Thread(
                target=mtt.setup_data, 
                args=(params["table_fields"],params["table_name"],5,params["function_call"]),
                daemon=True
                )
        threads.append(thread)
        thread.start()
        
    # 创建函数 run_scheduler 用于指定线程使用,用于监听是否到点执行定时器
    def run_scheduler():
        
        while True:
            schedule.run_pending()
            time.sleep(1)

    # 将 run_scheduler 当做另一个线程执行
    scheduler_thread = threading.Thread(target=run_scheduler, daemon=True)
    scheduler_thread.start()
    
    # 在这个地方等待数据增强线程结束
    for thread in threads:
        thread.join()
    
    # 在数据增强线程结束后再等待定时器线程结束
    scheduler_thread.join()

如上图所示,两组线程之间是嵌套关系,这样就能够同时执行两组不同操作的线程了...

在经过 5 x 24 小时的不间断运行,共生成了 200w+ 的训练数据(一开始访问过快了,后面调低一点速度了)。

数据生成完毕后需要将这部分后续用于训练的数据追加到原问答数据集存放的 Elasticsearch 中,如下图:

python 复制代码
def mysql_turning_data_to_es(self):
    # 分批查询 sql    
    search_sql = f"""
        select id,content,sources
        from {CU.TMP_MYSQL_TURNING_DATA_TABLE}
        where generate_flag = 0
        limit 10000
    """
    counter = 1
    results = self.mysql.query_by_list(search_sql)

    # 循环进行多批次插入
    while len(results)>0:
        logger.info(f"Processing batch {counter}")
        update_ids = [] # 可转换数组
        update_not_ids = [] # 无法转换数组
        qa_array = []

        for id, content, sources in results:
            try:
                # 尝试将数据转换为 json 对象
                json_content = json.loads(content)

                # 如果转换成功,则放入批处理数组中,若转换过程中出错,则将 id 记录到无法转换数组中
                qa_array.append({"question": json_content["question"], "answer": json_content["answer"],"data_source": sources})

                # 将 id 放入可转换数组中
                update_ids.append(id)
            except:
                update_not_ids.append(id)

        # 若批处理数组有数据
        if qa_array:
            # 批量保存到 es
            self._save_to_es(qa_array)

            # 并且更新可转换数组中 id 记录的状态为 1
            if update_ids:
                update_sql = f"""
                    update {CU.TMP_MYSQL_TURNING_DATA_TABLE}
                    set generate_flag = 1
                    where id in ({','.join(map(str, update_ids))})
                """
                self.mysql.save_or_update(update_sql)

        # 若无法转换数组中存在 id,则对这些 id 进行状态 2 的更新
        if update_not_ids:
            update_sql = f"""
                    update {CU.TMP_MYSQL_TURNING_DATA_TABLE}
                    set generate_flag = 2
                    where id in ({','.join(map(str, update_not_ids))})
                """
            self.mysql.save_or_update(update_sql)    

        # 为下一次循环查询是否还存在没有处理的记录
        results = self.mysql.query_by_list(search_sql)
        counter += 1

以上代码均发布到 brain-mix 项目中,欢迎各位的指导。

gitee: gitee.com/yzh0623/bra...

github:github.com/yzh0623/bra...

下一章将继续讲解数据评分筛选处理,敬请留意。

(未完待续...)

相关推荐
好易学·数据结构3 小时前
可视化图解算法60: 矩阵最长递增路径
数据结构·算法·leetcode·力扣·递归·回溯算法·牛客
封奚泽优3 小时前
班级互动小程序(Python)
python·deepseek
大锦终3 小时前
【算法】栈专题
数据结构·c++·算法·leetcode
天选之女wow3 小时前
【代码随想录算法训练营——Day6(Day5周日休息)】哈希表——242.有效的字母异位词、349.两个数组的交集、202.快乐数、1.两数之和
数据结构·算法·leetcode·散列表
MediaTea3 小时前
Python:math 库函数手册(双曲函数)
开发语言·python
JJJJ_iii3 小时前
【左程云算法07】队列和栈-链表数组实现
数据结构
枫叶V3 小时前
用 FastAPI 实现大文件分片上传与断点续传(含可运行示例与客户端脚本,仅供参考)
python·fastapi
薛定谔的算法3 小时前
JavaScript队列实现详解:从基础到性能优化
javascript·数据结构·算法