对算子shape相关的属性值自动化处理

一.泛化调试

将单单单的grouptype 为2时的2种场景的进行,自动化处理并验证

复制代码
m, n, k, g = map(int, input("输入m(输出行数/高度), n(输出列数/宽度), k(内积维度/中间维), g(分组数) 参数:").split())

attr_value5 = []
if g == 1:
    attr_value5 = [k]
else:
    for i in range(g):
        value = round(i * k / (g - 1))# 生成从0到k的均匀递增序列
        attr_value5.append(value)
    attr_value5[-1] = k
print(f"\n属性值在第5种情况时:\n分成g段,升序累计为k",attr_value5)

attr_value6 = []
if g == 1:
    attr_value6 = [k]
else:
    for i in range(g):
        attr_value4.append(0)
    attr_value4[-1] = k
print(f"\n属性值在第6种情况时:\n分成g段,最大为k",attr_value6)

二.补充场景

补充其他单单单的场景,如grouptype 为0时的代码,并将各场景和代码归并简化,运行后验证无误

复制代码
# 算子形状处
m, n, k, g = map(int, input("输入m(输出行数/高度), n(输出列数/宽度), k(内积维度/中间维), g(分组数) 参数:").split())

#当case为1是 分段g个,最大m
# 若输入 256 32 41 32 如[0, 37, 37, 82, 137, 137, 137, 137, 139, 139, 139, 155, 155, 160, 170, 180, 190, 195, 199, 200, 201, 202, 206, 206, 206, 255, 256, 256, 256, 256, 256, 256]
attr_value1 = []
if g == 1:
    attr_value1 = [m]
else:
    for i in range(g):
        value = round(i * m / (g - 1))  # 生成从0到k的均匀递增序列
        attr_value1.append(value)
    attr_value1[-1] = m  # 确保最后一个元素为k
print(f"\n属性值在第1种情况时: 分成{g}段,升序累计为{m}\n",attr_value1)

#当case为2是 分段g个,累计m
# 若输入 256 32 41 32 如[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 256]
attr_value2 = []
if g == 1:
    attr_value2 = [m]
else:
    # for i in range(g):
    #     attr_value2.append(0)
    # attr_value2[-1] = k
    attr_value2 = [0] * (g - 1) + [m] #精简代码
print(f"\n属性值在第2种情况时: 分成{g}段,升序累计为{m}\n",attr_value2)

#当case为3是  分段g个,最大k
#若输入 256 32 41 32 如[0, 0, 7, 8, 8, 9, 10, 11, 13, 19, 21, 25, 26,27, 29, 30, 33, 35, 38, 38, 38, 38, 39, 39, 40, 40, 40, 40, 40, 41, 41, 41]
attr_value3 = []
if g == 1:
    attr_value3 = [k]
else:
    for i in range(g):
        value = round(i * k / (g - 1))# 生成从0到k的均匀递增序列
        attr_value3.append(value)
    attr_value3[-1] = k    # 确保最后一个元素为k
print(f"\n属性值在第3种情况时: 分成{g}段,升序累计为{k}\n",attr_value3)

#当case为4是 分段g个,累计k
#若输入 256 32 41 32  如[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 41]
attr_value4 = []
if g == 1:
    attr_value4 = [k]
else:
    attr_value4 = [0] * (g - 1) + [k]
print(f"\n属性值在第4种情况时: 分成{g}段,最大为{k}\n",attr_value4)

三.简化代码

发现代码还是有些冗余,将多场景的按多分支优化,简化代码

复制代码
m, n, k, g = map(int, input("输入m(输出行数/高度), n(输出列数/宽度), k(内积维度/中间维), g(分组数) 参数:").split())

def gen_attr(val, mode):
    return [val] if g == 1 else [round(i * val / (g-1)) for i in range(g-1)] + [val] if mode else [0]*(g-1) + [val]

cases = [
    (1, m, True, f"分成{g}段,升序累计为{m}"),
    (2, m, False, f"分成{g}段,累计为{m}"),
    (3, k, True, f"分成{g}段,升序累计为{k}"),
    (4, k, False, f"分成{g}段,最大为{k}")
]

for i, (num, val, mode, desc) in enumerate(cases, 1):
    attr = gen_attr(val, mode)
    print(f"\n属性值在第{num}种情况时: {desc}")
    print(attr)

四.同输入输出形状的自动化合并

将之前的shape的自动化一起输出,因为这是2不同维度的,所以可以相互组合形成笛卡尔积,故将此相关情况均枚举,待自定义组合使用

复制代码
m, n, k, g = map(int, input("输入m(输出行数/高度), n(输出列数/宽度), k(内积维度/中间维), g(分组数) 参数:").split())
def generate_case(case_num):
    """生成指定分支的配置输出"""
    # 基础配置
    x_shape = [m, k] if case_num <= 4 else [k, m]
    weight_shape = [g, n, k] if case_num in (1, 2) else [g, k, n] if case_num in (3, 4) else [k, n]
    bias_shape = [g, n] if case_num in (1, 3) else [1]
    output_shape = [m, n] if case_num <= 4 else [g, m, n]

    # 分支特定配置
    desc = [
        ("groupType=0 groupListType=0 转置2 偏移", f"上限{m}"),
        ("groupType=0 groupListType=0 转置2 无偏移", f"上限{m}"),
        ("groupType=0 groupListType=1 转置0 偏移", f"上限{m}"),
        ("groupType=0 groupListType=1 转置0 无偏移", f"上限{m}"),
        ("groupType=2 groupListType=0 转置1 无偏移", f"上限{k}"),
        ("groupType=2 groupListType=1 转置1 无偏移", f"上限{k}")
    ][case_num - 1]

    # 生成配置字符串
    config = f"[[{x_shape}], [{weight_shape}], [{bias_shape}], [{g}], " + \
             ", ".join(["[1]"] * 8) + "]"

    return f"\n{desc[0]} {desc[1]}: \n{config}\n[[{output_shape}], [[0]], [[0]]]"

# 输出所有分支
for case in range(1, 7):
    print(generate_case(case))

def gen_attr(val, mode):
    return [val] if g == 1 else [round(i * val / (g-1)) for i in range(g-1)] + [val] if mode else [0]*(g-1) + [val]

cases = [
    (1, m, True, f"分成{g}段,升序累计为{m}"),
    (2, m, False, f"分成{g}段,累计为{m}"),
    (3, k, True, f"分成{g}段,升序累计为{k}"),
    (4, k, False, f"分成{g}段,最大为{k}")
]

for i, (num, val, mode, desc) in enumerate(cases, 1):
    attr = gen_attr(val, mode)
    print(f"\n属性值在第{num}种情况时: {desc}")
    print(attr)

整理不易,诚望各位看官点赞 收藏 评论 予以支持,这将成为我持续更新的动力源泉。若您在阅览时存有异议或建议,敬请留言指正批评,让我们携手共同学习,共同进取,吾辈自当相互勉励!

相关推荐
sxtyjty12 小时前
AtCoder Beginner Contest 450 G题题解
数学·算法·期望
m0_7301151112 小时前
用户认证与授权:使用JWT保护你的API
jvm·数据库·python
ccLianLian12 小时前
数论·快速幂和逆元
数据结构·算法
kaisun6412 小时前
树莓派4B上使用INMP441麦克风进行语音识别:从I2S配置到Python环境搭建全记录
python·语音识别·树莓派
没头脑的男大12 小时前
华为题目152乘积最大子数组
算法·华为
Yeats_Liao12 小时前
华为开源自研AI框架昇思MindSpore应用案例:WaveNet实现音乐生成
人工智能·深度学习·算法·机器学习·边缘计算
七夜zippoe12 小时前
Python 3.12+ 新特性深度解析:类型系统与性能革命
android·网络·python·类型系统·性能革命·3.12+
_饭团12 小时前
C 语言数据存储全解析:原反补码、大小端与 IEEE 754 浮点数
c语言·数据结构·算法·leetcode·面试·蓝桥杯·学习方法
2401_8732046512 小时前
C++与Docker集成开发
开发语言·c++·算法
j_xxx404_12 小时前
力扣--分治(归并排序)算法题II:计算右侧小于当前元素的个数,翻转对(无痛通关困难题)
开发语言·数据结构·c++·算法·leetcode