对算子shape相关的属性值自动化处理

一.泛化调试

将单单单的grouptype 为2时的2种场景的进行,自动化处理并验证

复制代码
m, n, k, g = map(int, input("输入m(输出行数/高度), n(输出列数/宽度), k(内积维度/中间维), g(分组数) 参数:").split())

attr_value5 = []
if g == 1:
    attr_value5 = [k]
else:
    for i in range(g):
        value = round(i * k / (g - 1))# 生成从0到k的均匀递增序列
        attr_value5.append(value)
    attr_value5[-1] = k
print(f"\n属性值在第5种情况时:\n分成g段,升序累计为k",attr_value5)

attr_value6 = []
if g == 1:
    attr_value6 = [k]
else:
    for i in range(g):
        attr_value4.append(0)
    attr_value4[-1] = k
print(f"\n属性值在第6种情况时:\n分成g段,最大为k",attr_value6)

二.补充场景

补充其他单单单的场景,如grouptype 为0时的代码,并将各场景和代码归并简化,运行后验证无误

复制代码
# 算子形状处
m, n, k, g = map(int, input("输入m(输出行数/高度), n(输出列数/宽度), k(内积维度/中间维), g(分组数) 参数:").split())

#当case为1是 分段g个,最大m
# 若输入 256 32 41 32 如[0, 37, 37, 82, 137, 137, 137, 137, 139, 139, 139, 155, 155, 160, 170, 180, 190, 195, 199, 200, 201, 202, 206, 206, 206, 255, 256, 256, 256, 256, 256, 256]
attr_value1 = []
if g == 1:
    attr_value1 = [m]
else:
    for i in range(g):
        value = round(i * m / (g - 1))  # 生成从0到k的均匀递增序列
        attr_value1.append(value)
    attr_value1[-1] = m  # 确保最后一个元素为k
print(f"\n属性值在第1种情况时: 分成{g}段,升序累计为{m}\n",attr_value1)

#当case为2是 分段g个,累计m
# 若输入 256 32 41 32 如[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 256]
attr_value2 = []
if g == 1:
    attr_value2 = [m]
else:
    # for i in range(g):
    #     attr_value2.append(0)
    # attr_value2[-1] = k
    attr_value2 = [0] * (g - 1) + [m] #精简代码
print(f"\n属性值在第2种情况时: 分成{g}段,升序累计为{m}\n",attr_value2)

#当case为3是  分段g个,最大k
#若输入 256 32 41 32 如[0, 0, 7, 8, 8, 9, 10, 11, 13, 19, 21, 25, 26,27, 29, 30, 33, 35, 38, 38, 38, 38, 39, 39, 40, 40, 40, 40, 40, 41, 41, 41]
attr_value3 = []
if g == 1:
    attr_value3 = [k]
else:
    for i in range(g):
        value = round(i * k / (g - 1))# 生成从0到k的均匀递增序列
        attr_value3.append(value)
    attr_value3[-1] = k    # 确保最后一个元素为k
print(f"\n属性值在第3种情况时: 分成{g}段,升序累计为{k}\n",attr_value3)

#当case为4是 分段g个,累计k
#若输入 256 32 41 32  如[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 41]
attr_value4 = []
if g == 1:
    attr_value4 = [k]
else:
    attr_value4 = [0] * (g - 1) + [k]
print(f"\n属性值在第4种情况时: 分成{g}段,最大为{k}\n",attr_value4)

三.简化代码

发现代码还是有些冗余,将多场景的按多分支优化,简化代码

复制代码
m, n, k, g = map(int, input("输入m(输出行数/高度), n(输出列数/宽度), k(内积维度/中间维), g(分组数) 参数:").split())

def gen_attr(val, mode):
    return [val] if g == 1 else [round(i * val / (g-1)) for i in range(g-1)] + [val] if mode else [0]*(g-1) + [val]

cases = [
    (1, m, True, f"分成{g}段,升序累计为{m}"),
    (2, m, False, f"分成{g}段,累计为{m}"),
    (3, k, True, f"分成{g}段,升序累计为{k}"),
    (4, k, False, f"分成{g}段,最大为{k}")
]

for i, (num, val, mode, desc) in enumerate(cases, 1):
    attr = gen_attr(val, mode)
    print(f"\n属性值在第{num}种情况时: {desc}")
    print(attr)

四.同输入输出形状的自动化合并

将之前的shape的自动化一起输出,因为这是2不同维度的,所以可以相互组合形成笛卡尔积,故将此相关情况均枚举,待自定义组合使用

复制代码
m, n, k, g = map(int, input("输入m(输出行数/高度), n(输出列数/宽度), k(内积维度/中间维), g(分组数) 参数:").split())
def generate_case(case_num):
    """生成指定分支的配置输出"""
    # 基础配置
    x_shape = [m, k] if case_num <= 4 else [k, m]
    weight_shape = [g, n, k] if case_num in (1, 2) else [g, k, n] if case_num in (3, 4) else [k, n]
    bias_shape = [g, n] if case_num in (1, 3) else [1]
    output_shape = [m, n] if case_num <= 4 else [g, m, n]

    # 分支特定配置
    desc = [
        ("groupType=0 groupListType=0 转置2 偏移", f"上限{m}"),
        ("groupType=0 groupListType=0 转置2 无偏移", f"上限{m}"),
        ("groupType=0 groupListType=1 转置0 偏移", f"上限{m}"),
        ("groupType=0 groupListType=1 转置0 无偏移", f"上限{m}"),
        ("groupType=2 groupListType=0 转置1 无偏移", f"上限{k}"),
        ("groupType=2 groupListType=1 转置1 无偏移", f"上限{k}")
    ][case_num - 1]

    # 生成配置字符串
    config = f"[[{x_shape}], [{weight_shape}], [{bias_shape}], [{g}], " + \
             ", ".join(["[1]"] * 8) + "]"

    return f"\n{desc[0]} {desc[1]}: \n{config}\n[[{output_shape}], [[0]], [[0]]]"

# 输出所有分支
for case in range(1, 7):
    print(generate_case(case))

def gen_attr(val, mode):
    return [val] if g == 1 else [round(i * val / (g-1)) for i in range(g-1)] + [val] if mode else [0]*(g-1) + [val]

cases = [
    (1, m, True, f"分成{g}段,升序累计为{m}"),
    (2, m, False, f"分成{g}段,累计为{m}"),
    (3, k, True, f"分成{g}段,升序累计为{k}"),
    (4, k, False, f"分成{g}段,最大为{k}")
]

for i, (num, val, mode, desc) in enumerate(cases, 1):
    attr = gen_attr(val, mode)
    print(f"\n属性值在第{num}种情况时: {desc}")
    print(attr)

整理不易,诚望各位看官点赞 收藏 评论 予以支持,这将成为我持续更新的动力源泉。若您在阅览时存有异议或建议,敬请留言指正批评,让我们携手共同学习,共同进取,吾辈自当相互勉励!

相关推荐
WoY20202 小时前
本地PyCharm配置远程服务器上的python环境
服务器·python·pycharm
tzjly2 小时前
JSON数据一键导入SQL Server
python
高山上有一只小老虎2 小时前
小红的推荐系统
java·算法
冰西瓜6002 小时前
贪心(一)——从动态规划到贪心 算法设计与分析 国科大
算法·贪心算法·动态规划
javachen__2 小时前
341-十道经典程序设计题目
数据结构·c++·算法
一分半心动3 小时前
清理C盘的python脚本
开发语言·python
natide3 小时前
表示/嵌入差异-7-间隔/边际对齐(Alignment Margin)
人工智能·深度学习·算法·机器学习·自然语言处理·知识图谱
毅炼3 小时前
hot100打卡——day08
java·数据结构·算法·leetcode·深度优先
l1t3 小时前
DeepSeek总结的算法 X 与舞蹈链文章
前端·javascript·算法