心电数据大模型尝试 -- QRS波检测

最近闲暇之余,尝试了通过大模型训练进行心电数据的分析,通过结合AI工具进行研究,感觉得到的结果比几年前的要高了一些。

当前的主要内容是进行QRS波检测的研究,通过MIT-BIH的心律失常数据库获取训练集、验证集和测试集数据,采用CNN+LSTM的模型,基于pytorch框架进行训练。

部署测试采用基于云上的进行,购买了一个腾讯云服务,在云上搭建了一个简单的后端预测推理,以及前端Web显示,开放端口后可通过IP地址访问前端Web,将数据转为约定的JSON格式后上传进行推理预测后显示。

数据选取

选用MIT-BIH心律失常数据库中的数据,去掉其中的起搏信号数据(4例起搏数据),在剩余的44例数据中,随机打乱后选取30例作为训练数据,5例作为验证数据,9例作为测试数据。

为了兼容目前静态常用的500Hz采样率,所有的数据均由360Hz重采样为250Hz。

数据与标签

数据采用重采样为250Hz后的原始数据,标签为与数据长度相等的0/1数据列表,在QRS波左右各50ms,总共100ms的范围为1,其他地方为0。

数据片段的长度为2秒钟共500点,为了增加数据量,采用步长为1秒进行截取,同时为了减少通道数据冗余,第2各通道的截取起始点与第一个通道有0.5秒钟的偏移。

数据预处理

采用一个带通滤波器对数据进行滤波,保留0.5~40Hz的成分,并进行去均值的处理。

此处未进行常规的去均值和标准差归一化处理,以保留信号幅度的相对信息。

python 复制代码
class QRSDetectPreprocessing:
    def __init__(self, fs=250) -> None:
        self.fs = fs
        self._get_filter_params(fs=fs)

    def set_fs(self, fs: int):
        self.fs = fs
        self._get_filter_params(fs=fs)

    def _get_filter_params(self, fs, low=0.5, high=40, order=4):
        self.b, self.a = butter(order, [low/(fs/2), high/(fs/2)], btype='band')

    def _bandpass_filter(self, signal):
        return filtfilt(self.b, self.a, signal)

    def _multileads_bandpass_filter(self, signal):
        sig_f = np.zeros_like(signal)
        for ch in range(signal.shape[0]):
            sig_f[ch, :] = self._bandpass_filter(signal[ch, :])
        return sig_f

    def _remove_mean(self, signal, axis=None):
        if signal.ndim == 1:
            axis = 0
        elif axis is None:
            axis = 1
        return signal - np.mean(signal, axis=axis, keepdims=True)

    def preprocess_signal(self, signal, n_channels=1):
        if not isinstance(signal, np.ndarray):
            signal = np.asarray(signal)
        
        if n_channels == 1:
            signal = self._bandpass_filter(signal)
            signal = self._remove_mean(signal)
        else:
            if signal.shape[0] > signal.shape[1]:
                raise ValueError(
                    f"Input signal shape is {signal.shape}, expected (n_channels, n_samples) with n_channels < n_samples. "
                    "Please check your input."
                )
            signal = self._multileads_bandpass_filter(signal)
            signal = self._remove_mean(signal)
        return signal

模型设计

模型设计借助GPT,设计了一个CNN+LSTM的结构

python 复制代码
class QRS_CNN_BiLSTM(nn.Module):
    def __init__(self,
                 input_channels=1,
                 cnn_channels=32,
                 lstm_hidden_size=64,
                 lstm_layers=1,
                 kernel_sizes=kernel_size):
        super(QRS_CNN_BiLSTM, self).__init__()
        assert len(kernel_sizes) == 3, "kernel_sizes 必须是长度为 3 的列表"

        # 根据 kernel_sizes 计算 padding
        paddings = [k // 2 for k in kernel_sizes]

        # CNN 特征提取模块,用 Sequential 封装
        self.cnn = nn.Sequential(
            nn.Conv1d(input_channels, cnn_channels,
                      kernel_size=kernel_sizes[0],
                      padding=paddings[0]),
            nn.BatchNorm1d(cnn_channels),
            nn.ReLU(),

            nn.Conv1d(cnn_channels, cnn_channels,
                      kernel_size=kernel_sizes[1],
                      padding=paddings[1]),
            nn.BatchNorm1d(cnn_channels),
            nn.ReLU(),

            nn.Conv1d(cnn_channels, cnn_channels,
                      kernel_size=kernel_sizes[2],
                      padding=paddings[2]),
            nn.BatchNorm1d(cnn_channels),
            nn.ReLU()
        )

        # 双向 LSTM
        self.lstm = nn.LSTM(input_size=cnn_channels,
                            hidden_size=lstm_hidden_size,
                            num_layers=lstm_layers,
                            batch_first=True,
                            bidirectional=True)

        # 输出层:逐点输出
        self.fc = nn.Linear(lstm_hidden_size * 2, 1)

    def forward(self, x):
        # x: (B, 1, seq_len)
        x = self.cnn(x)             # (B, cnn_channels, seq_len)
        x = x.permute(0, 2, 1)      # (B, seq_len, cnn_channels)

        lstm_out, _ = self.lstm(x)  # (B, seq_len, hidden*2)
        logits = self.fc(lstm_out)  # (B, seq_len, 1)
        logits = logits.squeeze(-1) # (B, seq_len)

        return logits

评价指标

评价指标采用实际QRS波检测的表现进行,即统计正确检测个数,漏检个数及误检个数,然后计算得到F1。

训练中保存F1最高的模型结果,将结果用于测试集上进行测试,以及后面的静态数据推理。

测试结果

训练结束后,通过加载保存的最佳模型,在测试集数据上进行测试,总体来说结果还可以。

在9个测试数据上,取得0.999717的F1,正确检出19412个QRS波,误检3个,漏检了8个。在正确检出的QRS波中,与参考QRS波位置距离的平均值为0.469ms,标准差为5.486ms。

实际应用测试

在腾讯云服务器上实现两个工程,一是QRS波检出模型和心电质量检测模型推理的部署,另一个是心电图Web显示。通过浏览器输入公网地址打开Web前端,然后上传JSON格式的12通道心电数据,通过对各通道进行质量分析后,在高质量通道上进行QRS波检测,最后综合各通道的QRS波得到最后的检测结果。流程如下:

  1. 浏览器输入公网地址:http://1.12.228.183,进入心电Web页面
  1. 点击页面"上传JSON"按钮,选择心电JSON格式上传,此时显示心电样式
  1. 点击"预测"按钮,此时显示预测得到的结果,目前暂时仅完成了波形质量检测和QRS波检测两个模型。信号质量包含"高/中/低"三类,图上分别以"绿/橙/红"显示;QRS波检测结果在图中上方以倒三角的方式标注。

总结

尝试了通过MIT-BIH数据库进行QRS波深度学习的训练与测试,并通过前端Web测试模型在静态12通道心电数据的效果。从结果来看,目前的模型的表现还较为良好,不过在实际的应用中,还需要多走一些数据,将误检或漏检的数据加入模型中,提高模型的泛化能力。

相关推荐
铮铭5 小时前
【论文阅读】OpenDriveVLA:基于大型视觉语言动作模型的端到端自动驾驶
人工智能·机器学习·自动驾驶
一碗白开水一7 小时前
【第30话:路径规划】自动驾驶中Hybrid A星(A*)搜索算法的详细推导及代码示例
人工智能·算法·机器学习·计算机视觉·数学建模·自动驾驶
audyxiao0018 小时前
一文可视化分析2025年8月arXiv机器学习前沿热点
人工智能·机器学习·arxiv
胖达不服输8 小时前
「日拱一码」098 机器学习可解释——PDP分析
人工智能·机器学习·机器学习可解释·pdp分析·部分依赖图
美码师9 小时前
向量那点事儿
机器学习
AI小云9 小时前
【机器学习与实战】分类与聚类算法:KNN鸢尾花分类
机器学习
憨憨爱编程10 小时前
机器学习-多因子线性回归
人工智能·机器学习·线性回归
悟乙己10 小时前
机器学习(MLOps)系统在线部署的基本指南
人工智能·机器学习·模型部署·mlops
林文韬32717 小时前
语义精炼技巧生成对抗网络(3)基于Wasserstein GAN 的特征生成
深度学习·机器学习·生成对抗网络