#P4869.第2题-基于LSTM进行室内温度预测

第2题-基于LSTM进行室内温度预测 - problem_ide - CodeFun2000

【LSTM系列·第一篇】彻底搞懂:细胞状态、隐藏状态、候选状态、遗忘门------新手最晕的4个概念,一篇厘清_lstm遗忘门-CSDN博客

python 复制代码
import sys
import numpy as np
import math



def func():
    data = sys.stdin.read().split()
    if not data:
        return

    T = int(data[0])
    B = int(data[1])
    D = int(data[2])
    H = int(data[3])
    idx = 4
    X = np.zeros((T, B, D), dtype=np.float64)


    for t in range(T):
        row = list(map(float, data[idx: idx + B * D]))
        idx += B * D
        X[t] = np.array(row).reshape(B, D)
    
    total_param_per_gate = D*H+H*H+H
    gates = ['i','f','o','g']
    params = {}

    for gate in gates:
        param_vals = list(map(float,data[idx:idx+total_param_per_gate]))
        
        idx += total_param_per_gate 
        Wx = np.array(param_vals[:D*H]).reshape(D,H)
        Wh = np.array(param_vals[D*H:D*H+H*H]).reshape(H,H)
        b = np.array(param_vals[D*H+H*H:]).reshape(H,)
        params[gate] = (Wx,Wh,b)
    
    h_prev = np.zeros((B,H),dtype=np.float64)
    C_prev = np.zeros((B,H),dtype=np.float64)

    all_h = []

    # def sigmoid(x):
    #     x = np.clip(x,-500,500)
    #     return 1/(1+np.exp(-x))
    

    def sigmoid(x):
        x = np.array(x,dtype=float)
        result = np.empty_like(x)
        mask = (x>=0)
        result[mask] = 1/(1+np.exp(-x[mask]))
        result[~mask] = np.exp(x[~mask])/(1+np.exp(x[~mask]))

        return result
    
    for t in range(T):
        x_t = X[t]  # (B,D)
        Wx_i,Wh_i,b_i = params['i']
        Wx_f,Wh_f,b_f = params['f']
        Wx_o,Wh_o,b_o = params['o']
        Wx_g,Wh_g,b_g = params['g']

        i_t = sigmoid(x_t @ Wx_i + h_prev @ Wh_i + b_i)
        f_t = sigmoid(x_t @ Wx_f + h_prev @ Wh_f + b_f)
        o_t = sigmoid(x_t @ Wx_o + h_prev @ Wh_o + b_o)
        g_t = np.tanh(x_t @ Wx_g + h_prev @ Wh_g + b_g)

        C_t = f_t*C_prev+i_t*g_t
        h_t = o_t*np.tanh(C_t)

        all_h.append(h_t.copy())

        h_prev = h_t
        C_prev = C_t

    all_h = np.array(all_h)
    final_C = C_prev

    h_flat = all_h.reshape(-1)
    C_flat = final_C.reshape(-1)

    h_flat = np.round(h_flat,4)
    C_flat = np.round(C_flat,4)

    h_str = ' '.join(f"{x:.4f}" for x in h_flat)
    C_str = ' '.join(f"{x:.4f}" for x in C_flat)

    print(h_str)
    print(C_str)

if __name__ == '__main__':
    func()
相关推荐
IALab-检测行业AI报告生成1 小时前
IACheck 报告AI审核产品更新清单|上周更新(2026.5.4-2026.5.8)
人工智能
Alson_Code1 小时前
Spring Ai Alibaba
java·人工智能·spring
迅利科技1 小时前
CATIA:高端制造的“数字母体”
人工智能·科技·制造
Honey Ro1 小时前
pytorch中的损失函数使用
人工智能·pytorch·深度学习
weixin_435208161 小时前
大模型 Agent 面试高频100题——基础篇
人工智能·深度学习·自然语言处理·面试·职场和发展·aigc
青稞社区.1 小时前
OpenAI 翁家翌:“启发式学习”的强化学习新范式
人工智能·经验分享·学习·agi
QYR-分析1 小时前
全球及中国固定翼无人机光电吊舱行业发展现状与前景分析
人工智能·无人机
扬帆破浪1 小时前
免费开源AI软件.桌面单机版,可移动的AI知识库,察元 AI桌面版:公司只允许装签名应用 给察元AI打企业内部分发包
人工智能·windows·电脑·知识图谱
深度学习lover1 小时前
<数据集>yolo 桃子识别<目标检测>
人工智能·python·yolo·目标检测·计算机视觉·桃子识别