SM3算法Python实现(无第三方库)

一、SM3算法介绍

SM3算法是中国国家密码管理局(OSCCA)于2010年发布的商用密码散列函数标准,属于我国自主设计的密码算法体系之一 ,标准文档下载地址为:SM3密码杂凑算法 。SM3算法输出长度为256位(32字节),与SHA-256类似,但采用了更适合国内安全需求的优化结构。SM3基于Merkle-Damgård迭代结构,通过填充、消息分组、扩展和压缩等步骤处理输入数据,确保任意长度的消息都能生成固定长度的摘要。作为我国密码行业标准(GM/T 0004-2012),SM3在政务、金融、物联网等领域广泛应用,是我国信息安全国产化的重要支撑。

SM3算法的核心流程包括消息填充、消息扩展和压缩函数三部分。首先,输入数据会被填充至512位的整数倍,并附加长度信息。随后,消息分组通过扩展算法生成132个32位字,供压缩函数使用。压缩函数采用64轮非线性迭代运算,结合与、或、异或、模加法等操作,并引入多个常量进行混淆,确保雪崩效应(微小输入变化导致输出巨大差异)。SM3的设计在安全性和效率上取得平衡,能够有效抵抗碰撞攻击、长度扩展攻击等威胁。

二、Python代码实现

SM3算法的Python实现如下所示,大致可以分为三个部分:

1. 初始化与基础运算函数

SM3 类初始化时设定默认的初始向量 IV(8个32位常量)和字符编码方式(ASCII/UTF-8/GBK)。核心辅助函数包括:

  • cshift_left(x, l):实现32位整数的循环左移,确保位移后仍为32位。

  • Tj(j):根据轮数返回常量(前16轮为0x79cc4519,后48轮为0x7a879d8a)。

  • FFjGGj:布尔函数,分别用于压缩函数中的非线性运算,前16轮使用异或逻辑,后48轮改用与/或逻辑增强扩散性。

  • P0P1:置换函数,通过循环左移和异或操作打乱数据(P0用于压缩末步,P1用于消息扩展)。

2. 消息填充与分组处理

padding(msg) 方法将输入消息按SM3标准填充为512位的整数倍:

  1. 数据转换:支持字符串、整数或字节流输入,统一转为字节序列。

  2. 填充规则:末尾添加0x80,补零至长度满足 (消息长度 + 64) % 512 = 0,最后64位写入原始消息长度的二进制表示。

  3. 分组输出:返回16字(32位/字)的块列表,每块包含填充后的数据,用于后续压缩。

3. 压缩函数与摘要生成

CF(V, B):压缩函数的核心,处理单个512位分组:

  • 消息扩展:将16字的输入块B扩展为68字(W0)和64字(W1),通过P1置换增强非线性。

  • 64轮迭代:每轮更新8个工作变量(A-H),结合SS1/SS2位移、TT1/TT2混合运算及P0置换,实现高强度混淆。

compression(msg):驱动流程,调用padding分块后,逐块应用CF压缩,最终将8个状态变量拼接为256位哈希值,以字节形式输出。

python 复制代码
class SM3:
    def __init__(self, encoding='ascii'):  # encodine:ascii/utf-8/gbk
        self.IV = [0x7380166f, 0x4914b2b9, 0x172442d7, 0xda8a0600, 0xa96f30bc, 0x163138aa, 0xe38dee4d, 0xb0fb0e4e]
        self.encoding = encoding

    def cshift_left(self, x, l):
        while l >= 32:
            l -= 32
        x = x & 0xffffffff
        bin_x = '{:032b}'.format(x)
        bin_x = bin_x[l:] + bin_x[:l]
        return int(bin_x, 2)

    def Tj(self, j):
        if (j < 16):
            return 0x79cc4519
        else:
            return 0x7a879d8a

    def FFj(self, x, y, z, j):
        if j < 16:
            return x ^ y ^ z
        else:
            return (x & y) | (x & z) | (y & z)

    def GGj(self, x, y, z, j):
        if j < 16:
            return x ^ y ^ z
        else:
            return (x & y) | (~x & z)

    def P0(self, x):
        return x ^ self.cshift_left(x, 9) ^ self.cshift_left(x, 17)

    def P1(self, x):
        return x ^ self.cshift_left(x, 15) ^ self.cshift_left(x, 23)

    def padding(self, msg):
        msg_len = len(msg)
        msg_blen = msg_len << 3
        m, n = msg_len >> 2, msg_len & 3
        block = []
        one_block = []
        if type(msg) == type(''):
            bt_msg = msg.encode(encoding=self.encoding, errors='strict')
        elif type(msg) == type(0):
            bt_msg = msg.to_bytes((msg.bit_length() + 7) // 8, "big")
        else:
            bt_msg = msg
        for i in range(m):
            wd = bt_msg[0] << 24 | bt_msg[1] << 16 | bt_msg[2] << 8 | bt_msg[3]
            one_block.append(wd)
            bt_msg = bt_msg[4:]
            if i & 15 == 15:
                block.append(one_block.copy())
                one_block.clear()
        if n == 0:
            new_wd = 0x80 << 24
        elif n == 1:
            new_wd = bt_msg[0] << 24 | 0x80 << 16
        elif n == 2:
            new_wd = bt_msg[0] << 24 | bt_msg[1] << 16 | 0x80 << 8
        else:
            new_wd = bt_msg[0] << 24 | bt_msg[1] << 16 | bt_msg[2] << 8 | 0x80
        one_block.append(new_wd)
        ob_len = len(one_block)
        if ob_len <= 14:
            for i in range(14 - ob_len):
                one_block.append(0)
            one_block.append(msg_blen >> 32)
            one_block.append(msg_blen & 0xffffffff)
            block.append(one_block.copy())
        else:
            for i in range(16 - ob_len):
                one_block.append(0)
            block.append(one_block.copy())
            one_block = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, msg_blen >> 32, msg_blen & 0xffffffff]
            block.append(one_block.copy())
        return block

    def CF(self, V, B):
        W0, W1 = B.copy(), []
        for i in range(16, 68):
            wd = self.P1(W0[i - 16] ^ W0[i - 9] ^ self.cshift_left(W0[i - 3], 15)) ^ self.cshift_left(W0[i - 13], 7) ^ \
                 W0[i - 6]
            W0.append(wd)
        for i in range(64):
            W1.append(W0[i] ^ W0[i + 4])
        A, B, C, D, E, F, G, H = V
        for i in range(64):
            SS1 = (self.cshift_left(self.cshift_left(A, 12) + E + self.cshift_left(self.Tj(i), i), 7)) & 0xffffffff
            SS2 = SS1 ^ self.cshift_left(A, 12)
            TT1 = (self.FFj(A, B, C, i) + D + SS2 + W1[i]) & 0xffffffff
            TT2 = (self.GGj(E, F, G, i) + H + SS1 + W0[i]) & 0xffffffff
            D = C
            C = self.cshift_left(B, 9)
            B = A
            A = TT1
            H = G
            G = self.cshift_left(F, 19)
            F = E
            E = self.P0(TT2)
        return A, B, C, D, E, F, G, H

    def compression(self, msg):
        block = self.padding(msg)
        V = self.IV
        for bi in block:
            res = self.CF(V, bi)
            for i in range(8):
                V[i] = V[i] ^ res[i]
        res = b''.join(v.to_bytes(4, "big") for v in V)
        return res

三、正确性验证

我们编写了下面的代码进行正确性验证,这是SM3标准文档中的两个测试向量:

python 复制代码
def test_case1():
    msg = "abc"
    expected_hash=0x66c7f0f462eeedd9d1f2d46bdc10e4e24167c4875cf2f7a2297da02b8f4ba8e0
    test_hash_bytes = SM3().compression(msg)
    test_hash=int.from_bytes(test_hash_bytes, 'big')
    if test_hash == expected_hash:
        print('测试用例1通过.')
    else:
        print('测试用例1失败')

def test_case2():
    msg = "abcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd"
    expected_hash=0xdebe9ff92275b8a138604889c18e5a4d6fdb70e5387e5765293dcba39c0c5732
    test_hash_bytes = SM3().compression(msg)
    test_hash=int.from_bytes(test_hash_bytes, 'big')
    if test_hash == expected_hash:
        print('测试用例2通过.')
    else:
        print('测试用例2失败')

if __name__ == '__main__':
    test_case1()
    test_case2()

需要注意示例2中512比特消息为16进制数表示,我们的测试用例2中的字符"abcd"对应的ASCII码正是"0x61, 0x62, 0x63, 0x64"。运行代码后,输出如下图,我们的代码正确运行。

为了方便读者直接运行,我们的完整代码如下:

python 复制代码
class SM3:
    def __init__(self, encoding='ascii'):  # encodine:ascii/utf-8/gbk
        self.IV = [0x7380166f, 0x4914b2b9, 0x172442d7, 0xda8a0600, 0xa96f30bc, 0x163138aa, 0xe38dee4d, 0xb0fb0e4e]
        self.encoding = encoding

    def cshift_left(self, x, l):
        while l >= 32:
            l -= 32
        x = x & 0xffffffff
        bin_x = '{:032b}'.format(x)
        bin_x = bin_x[l:] + bin_x[:l]
        return int(bin_x, 2)

    def Tj(self, j):
        if (j < 16):
            return 0x79cc4519
        else:
            return 0x7a879d8a

    def FFj(self, x, y, z, j):
        if j < 16:
            return x ^ y ^ z
        else:
            return (x & y) | (x & z) | (y & z)

    def GGj(self, x, y, z, j):
        if j < 16:
            return x ^ y ^ z
        else:
            return (x & y) | (~x & z)

    def P0(self, x):
        return x ^ self.cshift_left(x, 9) ^ self.cshift_left(x, 17)

    def P1(self, x):
        return x ^ self.cshift_left(x, 15) ^ self.cshift_left(x, 23)

    def padding(self, msg):
        msg_len = len(msg)
        msg_blen = msg_len << 3
        m, n = msg_len >> 2, msg_len & 3
        block = []
        one_block = []
        if type(msg) == type(''):
            bt_msg = msg.encode(encoding=self.encoding, errors='strict')
        elif type(msg) == type(0):
            bt_msg = msg.to_bytes((msg.bit_length() + 7) // 8, "big")
        else:
            bt_msg = msg
        for i in range(m):
            wd = bt_msg[0] << 24 | bt_msg[1] << 16 | bt_msg[2] << 8 | bt_msg[3]
            one_block.append(wd)
            bt_msg = bt_msg[4:]
            if i & 15 == 15:
                block.append(one_block.copy())
                one_block.clear()
        if n == 0:
            new_wd = 0x80 << 24
        elif n == 1:
            new_wd = bt_msg[0] << 24 | 0x80 << 16
        elif n == 2:
            new_wd = bt_msg[0] << 24 | bt_msg[1] << 16 | 0x80 << 8
        else:
            new_wd = bt_msg[0] << 24 | bt_msg[1] << 16 | bt_msg[2] << 8 | 0x80
        one_block.append(new_wd)
        ob_len = len(one_block)
        if ob_len <= 14:
            for i in range(14 - ob_len):
                one_block.append(0)
            one_block.append(msg_blen >> 32)
            one_block.append(msg_blen & 0xffffffff)
            block.append(one_block.copy())
        else:
            for i in range(16 - ob_len):
                one_block.append(0)
            block.append(one_block.copy())
            one_block = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, msg_blen >> 32, msg_blen & 0xffffffff]
            block.append(one_block.copy())
        return block

    def CF(self, V, B):
        W0, W1 = B.copy(), []
        for i in range(16, 68):
            wd = self.P1(W0[i - 16] ^ W0[i - 9] ^ self.cshift_left(W0[i - 3], 15)) ^ self.cshift_left(W0[i - 13], 7) ^ \
                 W0[i - 6]
            W0.append(wd)
        for i in range(64):
            W1.append(W0[i] ^ W0[i + 4])
        A, B, C, D, E, F, G, H = V
        for i in range(64):
            SS1 = (self.cshift_left(self.cshift_left(A, 12) + E + self.cshift_left(self.Tj(i), i), 7)) & 0xffffffff
            SS2 = SS1 ^ self.cshift_left(A, 12)
            TT1 = (self.FFj(A, B, C, i) + D + SS2 + W1[i]) & 0xffffffff
            TT2 = (self.GGj(E, F, G, i) + H + SS1 + W0[i]) & 0xffffffff
            D = C
            C = self.cshift_left(B, 9)
            B = A
            A = TT1
            H = G
            G = self.cshift_left(F, 19)
            F = E
            E = self.P0(TT2)
        return A, B, C, D, E, F, G, H

    def compression(self, msg):
        block = self.padding(msg)
        V = self.IV
        for bi in block:
            res = self.CF(V, bi)
            for i in range(8):
                V[i] = V[i] ^ res[i]
        res = b''.join(v.to_bytes(4, "big") for v in V)
        return res

def test_case1():
    msg = "abc"
    expected_hash=0x66c7f0f462eeedd9d1f2d46bdc10e4e24167c4875cf2f7a2297da02b8f4ba8e0
    test_hash_bytes = SM3().compression(msg)
    test_hash=int.from_bytes(test_hash_bytes, 'big')
    if test_hash == expected_hash:
        print('测试用例1通过.')
    else:
        print('测试用例1失败')

def test_case2():
    msg = "abcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd"
    expected_hash=0xdebe9ff92275b8a138604889c18e5a4d6fdb70e5387e5765293dcba39c0c5732
    test_hash_bytes = SM3().compression(msg)
    test_hash=int.from_bytes(test_hash_bytes, 'big')
    if test_hash == expected_hash:
        print('测试用例2通过.')
    else:
        print('测试用例2失败')

if __name__ == '__main__':
    test_case1()
    test_case2()
相关推荐
aq55356001 分钟前
PHP vs Python:30秒看懂核心区别
开发语言·python·php
我是无敌小恐龙3 分钟前
Java SE 零基础入门Day01 超详细笔记(开发前言+环境搭建+基础语法)
java·开发语言·人工智能·opencv·spring·机器学习
m0_3776182333 分钟前
Redis怎样应对大规模集群的重启风暴_分批次重启节点并等待集群状态恢复绿灯后再继续操作
jvm·数据库·python
码云数智-大飞40 分钟前
零基础微信小程序制作平台哪个好
开发语言
心态与习惯1 小时前
Julia 初探,及与 C++,Java,Python 的比较
java·c++·python·julia·比较
py有趣1 小时前
力扣热门100题之不同路径
算法·leetcode
神仙别闹1 小时前
基于 MATLAB 实现的 DCT 域的信息隐藏
开发语言·matlab
ZC跨境爬虫1 小时前
3D 地球卫星轨道可视化平台开发 Day8(分步渲染200颗卫星+ 前端分页控制)
前端·python·3d·重构·html
_日拱一卒1 小时前
LeetCode:25K个一组翻转链表
算法·leetcode·链表
techdashen1 小时前
Go 标准库 JSON 包迎来重大升级:encoding/json/v2 实验版来了
开发语言·golang·json