Doris sql语句转换为sqlite

python 复制代码
import re
import string
from typing import Optional, Dict, List, Tuple

def convert_sqlite_group_concat(sql: str) -> str:
    """
    通用转换函数:修复转义符、截取长度、分隔符匹配问题,全场景兼容
    """
    # ---------------------- 步骤1:SQL预处理 ----------------------
    processed_sql = sql.strip()
    processed_sql = processed_sql.lower()
    processed_sql = processed_sql.replace('\\', '')
    
    # 双引号分隔符转单引号
    processed_sql = re.sub(
        r'group_concat\(distinct\s*(\w+)\s*,\s*"([^"]+)"\)',
        r"group_concat(distinct \1, '\2')",
        processed_sql
    )
    
    # 去除多余空格
    processed_sql = re.sub(r'\s+(?=[^()]*\))', ' ', processed_sql)
    processed_sql = re.sub(r'\s+', ' ', processed_sql)
    processed_sql = re.sub(r',\s+', ',', processed_sql)
    processed_sql = re.sub(r'\s+,', ',', processed_sql)
    
    # 移除注释
    processed_sql = re.sub(r'--.*?$', '', processed_sql, flags=re.MULTILINE)
    
    # 去除括号前后空格
    processed_sql = re.sub(r'\s*\(\s*', '(', processed_sql)
    processed_sql = re.sub(r'\s*\)\s*', ')', processed_sql)

    print("===== 预处理后的SQL(调试) =====")
    print(processed_sql)
    print("=================================\n")
    # 新增:将MySQL的date_format转换为SQLite的strftime
    # date_format(日期表达式, 格式字符串) -> strftime(格式字符串, 日期表达式)
    processed_sql = re.sub(
        r'date_format\s*\(\s*([^,]+?)\s*,\s*([^)]+?)\s*\)',
        r'strftime(\2, \1)',
        processed_sql
    )
    
    # 双引号分隔符转单引号
    processed_sql = re.sub(
        r'group_concat\(distinct\s*(\w+)\s*,\s*"([^"]+)"\)',
        r"group_concat(distinct \1, '\2')",
        processed_sql
    )
    print("===== 预处理后的SQL(调试) =====")
    print(processed_sql)
    print("=================================\n")
    
    # ---------------------- 步骤2:改进的子句提取逻辑 ----------------------
    other_clauses = {
        'order_by': None,
        'limit': None,
        'offset': None,
        'having': None
    }
    
    # 首先提取LIMIT和OFFSET(它们通常在最后)
    limit_match = re.search(r'limit\s+(\d+)(?:\s+offset\s+(\d+))?\s*$', processed_sql, re.IGNORECASE)
    if limit_match:
        other_clauses['limit'] = limit_match.group(1)
        if limit_match.group(2):
            other_clauses['offset'] = limit_match.group(2)
        print(f"【调试】匹配到LIMIT:{other_clauses['limit']}, OFFSET:{other_clauses['offset']}")
    
    # 提取ORDER BY(需要排除窗口函数中的ORDER BY)
    # 使用更精确的方法:找到GROUP BY之后的所有内容,然后从中提取ORDER BY
    group_by_match = re.search(r'group\s+by\s+[^)]+(?=\s*(?:order\s+by|having|limit|$))', processed_sql, re.IGNORECASE)
    if group_by_match:
        group_by_end = group_by_match.end()
        remaining_sql = processed_sql[group_by_end:].strip()
        
        # 从剩余部分提取ORDER BY
        order_by_match = re.search(r'order\s+by\s+(.+?)(?=\s*(?:limit|$))', remaining_sql, re.IGNORECASE)
        if order_by_match:
            other_clauses['order_by'] = order_by_match.group(1).strip()
            print(f"【调试】匹配到ORDER BY子句:{other_clauses['order_by']}")
    
    # 提取HAVING子句
    having_match = re.search(r'having\s+(.+?)(?=\s*(?:order\s+by|limit|$))', processed_sql, re.IGNORECASE)
    if having_match:
        other_clauses['having'] = having_match.group(1).strip()
        print(f"【调试】匹配到HAVING子句:{other_clauses['having']}")
    
    # 核心SQL是GROUP BY之前的部分
    group_by_start = re.search(r'group\s+by', processed_sql, re.IGNORECASE)
    if group_by_start:
        core_sql = processed_sql[:group_by_start.start()].strip()
    else:
        core_sql = processed_sql
    
    print(f"【调试】核心SQL:{core_sql}")

    # ---------------------- 步骤3:提取SELECT字段和GROUP BY字段 ----------------------
    # 查找FROM关键字的位置
    from_match = re.search(r'\bfrom\b', core_sql, re.IGNORECASE)
    if not from_match:
        raise ValueError("解析失败:未找到FROM关键字")
    
    from_pos = from_match.start()
    select_part = core_sql[:from_pos].strip()
    
    # 移除SELECT关键字
    if not select_part.lower().startswith('select'):
        raise ValueError("解析失败:不是SELECT语句")
    
    select_fields_str = select_part[6:].strip()  # 移除"select"
    print(f"【调试】SELECT字段部分:{select_fields_str}")
    
    # 使用新的方法解析SELECT字段
    select_fields, group_concat_info = parse_select_fields(select_fields_str)
    
    print(f"【调试】普通SELECT字段(字段名, 别名):{select_fields}")
    print(f"【调试】GROUP_CONCAT信息:{group_concat_info}")
    
    # 提取表名
    table_match = re.search(r'from\s+(\w+)', core_sql, re.IGNORECASE)
    if not table_match:
        raise ValueError("解析失败:未找到FROM子句中的表名")
    
    table_name = table_match.group(1)
    print(f"【调试】表名:{table_name}")
    
    # 提取WHERE条件
    where_match = re.search(r'where\s+(.+?)(?=\s*group\s+by)', core_sql, re.IGNORECASE)
    where_clause = where_match.group(1).strip() if where_match else ""
    if where_clause:
        print(f"【调试】WHERE条件:{where_clause}")
    
    # 提取GROUP BY字段
    group_by_match = re.search(r'group\s+by\s+(.+?)(?=\s*(?:order\s+by|having|limit|$))', processed_sql, re.IGNORECASE)
    if not group_by_match:
        raise ValueError("解析失败:未找到GROUP BY子句")
    
    group_by_raw = group_by_match.group(1).strip()
    group_by_fields = []
    for field in group_by_raw.split(','):
        field = field.strip()
        field_name = re.sub(r'\s+.*$', '', field)
        group_by_fields.append(field_name)
    
    print(f"【调试】GROUP BY字段:{group_by_fields}")
    
    # ---------------------- 步骤4:智能提取真实表字段 ----------------------
    all_used_fields = set()
    
    # SQL关键字和函数名黑名单
    sql_keywords = {
        'as', 'select', 'from', 'where', 'group', 'by', 'order', 'limit', 'offset', 'having',
        'count', 'sum', 'round', 'max', 'cast', 'substring', 'group_concat', 'distinct',
        'decimal', 'datetime', 'int', 'varchar', 'text', 'date', 'time', 'timestamp',
        'over', 'partition', 'window'  # 添加窗口函数相关关键字
    }
    
    # 从普通SELECT字段中提取真实字段
    for field_name, alias in select_fields:
        # 使用改进的字段提取方法
        real_fields = extract_real_table_fields(field_name, sql_keywords)
        all_used_fields.update(real_fields)
    
    # 添加GROUP_CONCAT字段
    all_used_fields.add(group_concat_info['concat_col'])
    
    # 添加GROUP BY字段
    for group_field in group_by_fields:
        all_used_fields.add(group_field)
    
    print(f"【调试】所有使用的真实字段:{sorted(all_used_fields)}")
    
    # ---------------------- 步骤5:生成最终SQL ----------------------
    # 构建内层查询的字段列表(只包含真实表字段)
    inner_select_fields = list(all_used_fields)
    
    # 构建外层查询的SELECT字段
    outer_select_fields = []
    for field_name, alias in select_fields:
        # 正确处理带别名的字段
        outer_select_fields.append(f"{field_name} as {alias}")

    # 添加转换后的GROUP_CONCAT字段
    group_concat_expr = f"SUBSTRING(GROUP_CONCAT({group_concat_info['concat_col']}, '{group_concat_info['delimiter']}'), {group_concat_info['sub_start']}, {group_concat_info['sub_length']})"
    outer_select_fields.append(f"{group_concat_expr} as {group_concat_info['alias']}")
    
    outer_select_clause = ', '.join(outer_select_fields)
    
    # 内层查询的GROUP BY字段
    inner_group_by_fields = list(set([group_concat_info['concat_col']] + group_by_fields))
    inner_group_by_clause = ', '.join(inner_group_by_fields)
    
    # 外层GROUP BY
    outer_group_by_clause = ', '.join(group_by_fields)
    
    # 构建其他子句
    where_clause_sql = f"WHERE {where_clause}" if where_clause else ""
    order_by_sql = f"ORDER BY {other_clauses['order_by']}" if other_clauses['order_by'] else ""
    having_sql = f"HAVING {other_clauses['having']}" if other_clauses['having'] else ""
    
    limit_sql = ""
    if other_clauses['limit']:
        if other_clauses['offset']:
            limit_sql = f"LIMIT {other_clauses['limit']} OFFSET {other_clauses['offset']}"
        else:
            limit_sql = f"LIMIT {other_clauses['limit']}"
    
    # 构建完整SQL
    final_sql = f"""
SELECT 
    {outer_select_clause}
FROM (
    SELECT {', '.join(inner_select_fields)}
    FROM {table_name}
    {where_clause_sql}
    GROUP BY {inner_group_by_clause}
) AS temp
GROUP BY {outer_group_by_clause}
{having_sql}
{order_by_sql}
{limit_sql}
""".strip()

    final_sql = re.sub(r'\n\s*\n', '\n', final_sql)
    return final_sql

def extract_real_table_fields(field_expression: str, keyword_blacklist: set) -> List[str]:
    """
    从字段表达式中提取真实的表字段,排除SQL关键字和函数名
    """
    real_fields = set()
    
    # 移除字符串常量(单引号或双引号内容)
    field_expr_no_strings = re.sub(r'[\'\"][^\'\"]*[\'\"]', '', field_expression)
    
    # 移除数字(包括小数)
    field_expr_no_numbers = re.sub(r'\b\d+\.?\d*\b', '', field_expr_no_strings)
    
    # 提取所有单词
    words = re.findall(r'\b[a-zA-Z_]\w*\b', field_expr_no_numbers)
    
    for word in words:
        word_lower = word.lower()
        # 排除SQL关键字、函数名和数字
        if (word_lower not in keyword_blacklist and 
            not word.isdigit() and
            len(word) > 1):  # 排除单字母(通常是变量名)
            real_fields.add(word)
    
    return list(real_fields)

def parse_select_fields(select_fields_str: str):
    """
    解析SELECT字段,分离普通字段和GROUP_CONCAT字段
    """
    select_fields = []
    group_concat_info = {}
    
    # 先分割所有字段
    all_fields = smart_split_fields(select_fields_str)
    
    for field in all_fields:
        field = field.strip()
        if not field:
            continue
            
        # 检查是否是GROUP_CONCAT字段
        concat_match = re.search(
            r'substring\(\s*group_concat\(distinct\s*(\w+)\s*,\s*\'([^\']+)\'\)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)\s*(?:as\s+(\w+))?',
            field, 
            re.IGNORECASE
        )
        
        if not concat_match:
            # 尝试不带substring的GROUP_CONCAT
            concat_match = re.search(
                r'group_concat\(distinct\s*(\w+)\s*,\s*\'([^\']+)\'\)\s*(?:as\s+(\w+))?',
                field, 
                re.IGNORECASE
            )
            if concat_match:
                group_concat_info['concat_col'] = concat_match.group(1)
                group_concat_info['delimiter'] = concat_match.group(2)
                group_concat_info['sub_start'] = '1'
                group_concat_info['sub_length'] = '256'
                group_concat_info['alias'] = concat_match.group(3) if concat_match.group(3) else 'group_concat_result'
                continue
        
        if concat_match:
            group_concat_info['concat_col'] = concat_match.group(1)
            group_concat_info['delimiter'] = concat_match.group(2)
            group_concat_info['sub_start'] = concat_match.group(3)
            group_concat_info['sub_length'] = concat_match.group(4)
            group_concat_info['alias'] = concat_match.group(5) if concat_match.lastindex >= 5 else 'group_concat_result'
            continue
        
        # 如果不是GROUP_CONCAT字段,则作为普通字段处理
        field_name, alias = parse_field_with_alias(field)
        select_fields.append((field_name, alias))
    
    if not group_concat_info:
        raise ValueError("解析失败:未找到GROUP_CONCAT字段")
    
    return select_fields, group_concat_info

def smart_split_fields(fields_str: str) -> List[str]:
    """
    智能分割SELECT字段,处理复杂表达式
    """
    fields = []
    current_field = ""
    paren_depth = 0
    bracket_depth = 0  # 处理方括号
    
    i = 0
    while i < len(fields_str):
        char = fields_str[i]
        
        if char == '(':
            paren_depth += 1
            current_field += char
        elif char == ')':
            paren_depth -= 1
            current_field += char
            if paren_depth < 0:
                paren_depth = 0
        elif char == '[':
            bracket_depth += 1
            current_field += char
        elif char == ']':
            bracket_depth -= 1
            current_field += char
            if bracket_depth < 0:
                bracket_depth = 0
        elif char == ',' and paren_depth == 0 and bracket_depth == 0:
            if current_field.strip():
                fields.append(current_field.strip())
            current_field = ""
        else:
            current_field += char
        
        i += 1
    
    if current_field.strip():
        fields.append(current_field.strip())
    
    return fields

def parse_field_with_alias(field_str: str) -> Tuple[str, str]:
    """
    解析字段表达式和别名 - 修复版,正确处理无空格as和复杂表达式
    """
    field_str = field_str.strip()
    
    # 处理特殊情况:如果字段以"as"结尾,说明是复杂表达式但没有别名
    if re.search(r'\bas\s*$', field_str, re.IGNORECASE):
        field_str = re.sub(r'\s+as\s*$', '', field_str, flags=re.IGNORECASE).strip()
        return field_str, field_str
    
    # 查找最后一个"as"关键字(允许前后无空格)
    as_pattern = re.compile(r'\s*as\s*', re.IGNORECASE)
    as_matches = list(as_pattern.finditer(field_str))
    
    if as_matches:
        last_as_match = as_matches[-1]
        alias_part = field_str[last_as_match.end():].strip()
        
        # 检查别名部分是否有效
        if alias_part and re.match(r'^[\w\u4e00-\u9fff()()]+$', alias_part):
            field_expr = field_str[:last_as_match.start()].strip()
            
            # 检查字段表达式是否完整(括号平衡)
            if field_expr.count('(') == field_expr.count(')'):
                return field_expr, alias_part
    
    # 对于没有明确AS关键字的简单字段
    if not re.search(r'[(),]', field_str):
        parts = re.split(r'\s+', field_str.strip())
        if len(parts) == 2 and re.match(r'^[\w\u4e00-\u9fff()()]+$', parts[1]) and not parts[1].isdigit():
            return parts[0], parts[1]
    
    # 默认:整个表达式作为字段名,同时作为别名
    return field_str, field_str

# 测试用例
if __name__ == "__main__":
    original_sql5 = """
SELECT region as a, count(*) as g, round(sum(cast(ee as decimal(20,2))),2)/10000 as 总额(万元), country b, city as c, population as d, max(cast(sj as datetime)) as h,
       SUBSTRING(GROUP_CONCAT(DISTINCT landmark, '|'),1,300)  as f
FROM geography 
GROUP BY region, country
order by 总额(万元)
desc limit 20;
"""
    original_sql6 = """
SELECT region as a, count(*) as g, round(sum(cast(ee as decimal(20,2))),2)/10000 as 总额(万元), country b, city as c, population as d, max(cast(sj as datetime)) as h,
       SUBSTRING(GROUP_CONCAT(DISTINCT landmark, '|'),1,300)  as f, count(*) over (partition by a order by cast(sj as datetime)) as i, count(distinct date_format(cast(sj as datetime),'%y-%m-%d %h')) as 天数
FROM geography 
GROUP BY region, country
order by 总额(万元)
desc limit 20;
"""

    print(f"\n{'='*20} 测试用例5 {'='*20}")
    try:
        new_sql = convert_sqlite_group_concat(original_sql5)
        print("\n✅ 转换成功!最终SQL:")
        print(new_sql)
    except Exception as e:
        print(f"\n❌ 转换失败:{e}")
        import traceback
        traceback.print_exc()
    
    print(f"\n{'='*20} 测试用例6 {'='*20}")
    try:
        new_sql = convert_sqlite_group_concat(original_sql6)
        print("\n✅ 转换成功!最终SQL:")
        print(new_sql)
    except Exception as e:
        print(f"\n❌ 转换失败:{e}")
        import traceback
        traceback.print_exc()
    
    print('='*50)
相关推荐
冰暮流星2 小时前
sql语言之having语句使用
java·数据库·sql
麦聪聊数据2 小时前
从数据采集到 API 市场的完整技术链路
数据库·sql·低代码·微服务
he___H2 小时前
jvm48-96回
java·jvm·性能优化
知识即是力量ol16 小时前
口语八股——MySQL 核心原理系列(终篇):SQL优化篇、日志与主从复制篇、高级特性篇、面试回答技巧总结
sql·mysql·面试·核心原理
yixin12320 小时前
【玩转全栈】----Django基本配置和介绍
数据库·django·sqlite
he___H21 小时前
jvm前15回
jvm
不剪发的Tony老师1 天前
FlySpeed:一款通用的SQL查询工具
数据库·sql
ℳ₯㎕ddzོꦿ࿐1 天前
[特殊字符] 【踩坑记录】没调 startPage(),SQL 却被自动分页了?
数据库·sql