通过LLM多轮对话生成单元测试用例

通过LLM多轮对话生成单元测试用例

在采用 随机生成pytorch算子测试序列且保证算子参数合法 这种方法之前,曾通过本文的方法生成算子组合测试用例。目前所测LLM生成的代码均会出现BUG,且多次交互后仍不能解决.也许随着LLM的更新,这个问题会得到解决.记录备用。

代码

python 复制代码
import re
import os
import logging
import random
import numpy as np
import os
import re
import traceback
import subprocess
import tempfile
import copy
import requests
import json

import os
os.environ['MKL_THREADING_LAYER'] = 'GNU'
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'

os.environ["QIANFAN_AK"] = ""
os.environ["QIANFAN_SK"] = ""
os.environ['DASHSCOPE_API_KEY'] = 'sk-'
os.environ['MOONSHOT_API_KEY']="sk-"
os.environ['SPARKAI_APP_ID'] = ''
os.environ['SPARKAI_API_SECRET'] = ''
os.environ['SPARKAI_API_KEY'] = ''
os.environ['SPARKAI_DOMAIN'] = 'generalv3.5'
os.environ['ZhipuAI_API_KEY'] = ''
os.environ['YI_API_KEY']=""

logger = logging.getLogger('llm_logger')
logger.setLevel(logging.DEBUG)  # 设置日志级别
 
# 创建一个handler,用于写入日志文件
log_file = 'llm_opt.log'
file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.DEBUG)
 
# 创建一个handler,用于将日志输出到控制台
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
 
# 设置日志格式
formatter = logging.Formatter('%(message)s')
file_handler.setFormatter(formatter)
console_handler.setFormatter(formatter)
 
# 将handlers添加到logger
logger.addHandler(file_handler)
logger.addHandler(console_handler)

system_prompt="你是一位pytorch专家,现在需要编写各种测试程序,挖掘算子的潜在BUG"

question =f'''
背景描述:
1.为了测试pytorch不同算子组合时的精度是否正常,需要构建module级别的测试用例
2.尤其需要关注unsqueeze,repeat,permute,transpose,reshape,expand,view等维度变换算子的各种组合
3.以及在这些组合之后添加其它io或计算类的算子如(contiguous,matmul,mul,concat等)

需求:
1.你一次生成一个测试用例(pytorch module及测例),只包含cpu计算
2.之后,我会从的回复中提取出python代码,执行并将结果反馈给你
3.你根据我的反馈,预测性地生成下一个测试用例
4.我们通过多次交互,最大程度地挖掘出潜在的BUG

约束:
1.所有测试用例的代码放在一个```python ```中,方便提取
2.为了防止shape不匹配,建议在forward中计算shape,并根据当前的shape合理地设置下一个算子的参数
3.你每次提供的代码都必须是完整的,不要添加任何注释
4.测试代码只输出成功、失败或抛异常,不需要输出任何多余信息
5.特别需要注意矩阵乘维度是否匹配

如果你明白我的意思,请直接输出第一个测试用例
'''

def extract_and_run_python_code(markdown_text):
    pattern = re.compile(r'```python\n([^```].*?)\n```', re.DOTALL)
    code_blocks = pattern.findall(markdown_text)
    if len(code_blocks)==0:
        return "没有找到Python代码块。"
    results = []
    for code in code_blocks:
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as temp_file:
                temp_file.write(code.encode())
                temp_filename = temp_file.name
            result = subprocess.run(['python3', temp_filename], capture_output=True, text=True)    
            output=f"{result.stderr}{result.stdout}"
            results.append(output)
        except Exception as e:
            error_message = f"error:{traceback.format_exc()}"
            results.append(error_message)        
        finally:
            os.remove(temp_filename)
    return "".join(results)

class LLMInfer(object):
    def __init__(self, system_prompt,question,history_len=5):
        self.system_prompt=system_prompt
        self.question=question    
        self.history_len=history_len   
    def infer(self,user_input=None):
        pass    
    def reset(self):
        pass

class dashscope_llm(LLMInfer):
    def __init__(self, system_prompt, question):
        super().__init__(system_prompt, question)
        import dashscope
        dashscope.api_key=os.environ['DASHSCOPE_API_KEY'] 
        self.history=[]
        self.history.append({'role': 'system', 'content': self.system_prompt})
        self.history.append({'role': 'user', 'content': self.question})		
        
    def reset(self):
        if len(self.history)>self.history_len:
            self.history=self.history[:2] + self.history[-3:]

    def infer(self,user_input=None):
        from dashscope import Generation
        from http import HTTPStatus          
        if user_input:
            self.history.append({'role': 'user', 'content': user_input})
        response = Generation.call(model="qwen-plus", 
                                   messages=self.history,
                                   result_format='message')
        if response.status_code == HTTPStatus.OK:
            role=response.output.choices[0]['message']['role']
            content=response.output.choices[0]['message']['content']
            self.history.append({'role': role,'content': content})
            return content
        else:
            return None

class moonshot_llm(LLMInfer):
    def __init__(self, system_prompt, question):
        super().__init__(system_prompt, question)
        '''
        pip install --upgrade 'openai>=1.0'
        '''
        from openai import OpenAI
        self.client = OpenAI(
            api_key = os.environ['MOONSHOT_API_KEY'],
            base_url = "https://api.moonshot.cn/v1",
        )
        self.history=[]
        self.history.append({'role': 'system', 'content': self.system_prompt})
        self.history.append({'role': 'user', 'content': self.question})		
        
    def reset(self):
        if len(self.history)>self.history_len:
            self.history=self.history[:2] + self.history[-3:]

    def infer(self,user_input=None):      
        if user_input:
            self.history.append({'role': 'user', 'content': user_input})
        completion = self.client.chat.completions.create(
            model="moonshot-v1-128k",
            messages=self.history,
            temperature=0.3,
            top_p=0.1
        )
        role="assistant"
        content=completion.choices[0].message.content
        self.history.append({'role': role,'content': content})
        return content

class qianfan_llm(LLMInfer):
    def __init__(self, system_prompt, question):
        super().__init__(system_prompt, question)
        '''
        pip3 install qianfan
        '''
        self.history=[]
        #self.history.append({'role': 'system', 'content': self.system_prompt})
        self.history.append({'role': 'user', 'content': self.question})		
        
    def reset(self):
        if len(self.history)>self.history_len:
            self.history=self.history[:1] + self.history[-2:]

    def infer(self,user_input=None):    
        import qianfan  
        if user_input:
            self.history.append({'role': 'user', 'content': user_input})
        response = qianfan.ChatCompletion().do(endpoint="completions_pro", messages=self.history,
                                                temperature=0.7, top_p=0.8, penalty_score=1,                                             
                                                disable_search=False, enable_citation=False)
        role="assistant"
        content=response.body["result"]
        self.history.append({'role': role,'content': content})
        return content

class sparkai_llm(LLMInfer):
    def __init__(self, system_prompt, question):
        super().__init__(system_prompt, question)
        '''
        pip3 install --upgrade spark_ai_python
        '''
        from sparkai.llm.llm import ChatSparkLLM
        from sparkai.core.messages import ChatMessage
        self.spark = ChatSparkLLM(
            spark_api_url='wss://spark-api.xf-yun.com/v3.5/chat',
            spark_app_id=os.environ['SPARKAI_APP_ID'],
            spark_api_key=os.environ['SPARKAI_API_KEY'],
            spark_api_secret=os.environ['SPARKAI_API_SECRET'],
            spark_llm_domain=os.environ['SPARKAI_DOMAIN'],
            streaming=False,        
            temperature=0.1
        )
        self.history=[]
        self.history.append(ChatMessage(role="system",content=self.system_prompt))
        self.history.append(ChatMessage(role="user",content=self.question))
        
    def reset(self):
        if len(self.history)>self.history_len:
            self.history=self.history[:2] + self.history[-3:]

    def infer(self,user_input=None):    
        from sparkai.core.messages import ChatMessage
        from sparkai.llm.llm import ChunkPrintHandler
        if user_input:
            self.history.append(ChatMessage(role="user",content=user_input))        
        handler = ChunkPrintHandler()
        response = self.spark.generate([self.history], callbacks=[handler])
        self.history.append(response.generations[0][0].message)
        return response.generations[0][0].text


class zhipuai_llm(LLMInfer):
    def __init__(self, system_prompt, question):
        super().__init__(system_prompt, question)
        '''
        pip install zhipuai
        '''
        from zhipuai import ZhipuAI
        self.client = ZhipuAI(api_key=os.environ['ZhipuAI_API_KEY'])
        self.history=[]
        self.history.append({'role': 'system', 'content': self.system_prompt})
        self.history.append({'role': 'user', 'content': self.question})		
        
    def reset(self):
        if len(self.history)>self.history_len:
            self.history=self.history[:2] + self.history[-3:]

    def infer(self,user_input=None):      
        if user_input:
            self.history.append({'role': 'user', 'content': user_input})
        completion = self.client.chat.completions.create(
            model="glm-4",
            messages=self.history,
            temperature=0.3,
            top_p=0.1
        )
        role="assistant"
        content=completion.choices[0].message.content
        self.history.append({'role': role,'content': content})
        return content

class yi_llm(LLMInfer):
    def __init__(self, system_prompt, question):
        super().__init__(system_prompt, question)
        '''
        pip install --upgrade 'openai>=1.0'
        '''
        from openai import OpenAI
        self.client = OpenAI(
            api_key = os.environ['YI_API_KEY'],
            base_url = "https://api.lingyiwanwu.com/v1",
        )
        self.history=[]
        self.history.append({'role': 'system', 'content': self.system_prompt})
        self.history.append({'role': 'user', 'content': self.question})		
        
    def reset(self):
        if len(self.history)>self.history_len:
            self.history=self.history[:2] + self.history[-3:]

    def infer(self,user_input=None):      
        if user_input:
            self.history.append({'role': 'user', 'content': user_input})
        completion = self.client.chat.completions.create(
            model="yi-large",
            messages=self.history,
            temperature=0.3,
            top_p=0.1
        )
        role="assistant"
        content=completion.choices[0].message.content
        self.history.append({'role': role,'content': content})
        return content

llms=[dashscope_llm,moonshot_llm,qianfan_llm,sparkai_llm,zhipuai_llm,yi_llm]
for llm in llms:
    logger.info(f" ---------------------------------- {llm.__name__} ---------------------------------- ")
    llm=llm(system_prompt,question)
    response = llm.infer()
    for i in range(15):
        llm.reset()
        logger.info(f" ---------------------------------- 第{i}轮 ---------------------------------- ")
        result=None
        logger.info("####### bot #######")
        logger.info(f"{response}")
        if response:
            result=f"{extract_and_run_python_code(response)}"     
            logger.info("####### user #######")
            logger.info(f"{result}")
        response=llm.infer(result)
相关推荐
q_q王4 小时前
‌FunASR‌阿里开源的语音识别工具
python·大模型·llm·语音识别
pedestrian_h6 小时前
Spring AI 开发本地deepseek对话快速上手笔记
java·spring boot·笔记·llm·ollama·deepseek
浪淘沙jkp7 小时前
AI大模型学习二十、利用Dify+deepseekR1 使用知识库搭建初中英语学习智能客服机器人
人工智能·llm·embedding·agent·知识库·dify·deepseek
程序员小远16 小时前
自动化测试与功能测试详解
自动化测试·软件测试·python·功能测试·测试工具·职场和发展·测试用例
HuggingFace21 小时前
大模型评估排障指南 | 关于可复现性
大模型·llm
AI大模型顾潇1 天前
[特殊字符] 本地部署DeepSeek大模型:安全加固与企业级集成方案
数据库·人工智能·安全·大模型·llm·微调·llama
十里清风2 天前
LLM量化方法:ZeroQuant、LLM.int8()、SmoothQuant、GPTQ、AWQ
llm
知来者逆2 天前
在与大语言模型交互中的礼貌现象:技术影响、社会行为与文化意义的多维度探讨
人工智能·深度学习·语言模型·自然语言处理·llm
SHIPKING3933 天前
【Prompt工程—文生图】案例大全
llm·prompt·文生图
水煮蛋不加蛋3 天前
AutoGen 框架解析:微软开源的多人 Agent 协作新范式
人工智能·microsoft·ai·开源·大模型·llm·agent