最近闲暇之余,尝试了通过大模型训练进行心电数据的分析,通过结合AI工具进行研究,感觉得到的结果比几年前的要高了一些。
当前的主要内容是进行QRS波检测的研究,通过MIT-BIH的心律失常数据库获取训练集、验证集和测试集数据,采用CNN+LSTM的模型,基于pytorch框架进行训练。
部署测试采用基于云上的进行,购买了一个腾讯云服务,在云上搭建了一个简单的后端预测推理,以及前端Web显示,开放端口后可通过IP地址访问前端Web,将数据转为约定的JSON格式后上传进行推理预测后显示。
数据选取
选用MIT-BIH心律失常数据库中的数据,去掉其中的起搏信号数据(4例起搏数据),在剩余的44例数据中,随机打乱后选取30例作为训练数据,5例作为验证数据,9例作为测试数据。
为了兼容目前静态常用的500Hz采样率,所有的数据均由360Hz重采样为250Hz。
数据与标签
数据采用重采样为250Hz后的原始数据,标签为与数据长度相等的0/1数据列表,在QRS波左右各50ms,总共100ms的范围为1,其他地方为0。
数据片段的长度为2秒钟共500点,为了增加数据量,采用步长为1秒进行截取,同时为了减少通道数据冗余,第2各通道的截取起始点与第一个通道有0.5秒钟的偏移。
数据预处理
采用一个带通滤波器对数据进行滤波,保留0.5~40Hz的成分,并进行去均值的处理。
此处未进行常规的去均值和标准差归一化处理,以保留信号幅度的相对信息。
python
class QRSDetectPreprocessing:
def __init__(self, fs=250) -> None:
self.fs = fs
self._get_filter_params(fs=fs)
def set_fs(self, fs: int):
self.fs = fs
self._get_filter_params(fs=fs)
def _get_filter_params(self, fs, low=0.5, high=40, order=4):
self.b, self.a = butter(order, [low/(fs/2), high/(fs/2)], btype='band')
def _bandpass_filter(self, signal):
return filtfilt(self.b, self.a, signal)
def _multileads_bandpass_filter(self, signal):
sig_f = np.zeros_like(signal)
for ch in range(signal.shape[0]):
sig_f[ch, :] = self._bandpass_filter(signal[ch, :])
return sig_f
def _remove_mean(self, signal, axis=None):
if signal.ndim == 1:
axis = 0
elif axis is None:
axis = 1
return signal - np.mean(signal, axis=axis, keepdims=True)
def preprocess_signal(self, signal, n_channels=1):
if not isinstance(signal, np.ndarray):
signal = np.asarray(signal)
if n_channels == 1:
signal = self._bandpass_filter(signal)
signal = self._remove_mean(signal)
else:
if signal.shape[0] > signal.shape[1]:
raise ValueError(
f"Input signal shape is {signal.shape}, expected (n_channels, n_samples) with n_channels < n_samples. "
"Please check your input."
)
signal = self._multileads_bandpass_filter(signal)
signal = self._remove_mean(signal)
return signal
模型设计
模型设计借助GPT,设计了一个CNN+LSTM的结构
python
class QRS_CNN_BiLSTM(nn.Module):
def __init__(self,
input_channels=1,
cnn_channels=32,
lstm_hidden_size=64,
lstm_layers=1,
kernel_sizes=kernel_size):
super(QRS_CNN_BiLSTM, self).__init__()
assert len(kernel_sizes) == 3, "kernel_sizes 必须是长度为 3 的列表"
# 根据 kernel_sizes 计算 padding
paddings = [k // 2 for k in kernel_sizes]
# CNN 特征提取模块,用 Sequential 封装
self.cnn = nn.Sequential(
nn.Conv1d(input_channels, cnn_channels,
kernel_size=kernel_sizes[0],
padding=paddings[0]),
nn.BatchNorm1d(cnn_channels),
nn.ReLU(),
nn.Conv1d(cnn_channels, cnn_channels,
kernel_size=kernel_sizes[1],
padding=paddings[1]),
nn.BatchNorm1d(cnn_channels),
nn.ReLU(),
nn.Conv1d(cnn_channels, cnn_channels,
kernel_size=kernel_sizes[2],
padding=paddings[2]),
nn.BatchNorm1d(cnn_channels),
nn.ReLU()
)
# 双向 LSTM
self.lstm = nn.LSTM(input_size=cnn_channels,
hidden_size=lstm_hidden_size,
num_layers=lstm_layers,
batch_first=True,
bidirectional=True)
# 输出层:逐点输出
self.fc = nn.Linear(lstm_hidden_size * 2, 1)
def forward(self, x):
# x: (B, 1, seq_len)
x = self.cnn(x) # (B, cnn_channels, seq_len)
x = x.permute(0, 2, 1) # (B, seq_len, cnn_channels)
lstm_out, _ = self.lstm(x) # (B, seq_len, hidden*2)
logits = self.fc(lstm_out) # (B, seq_len, 1)
logits = logits.squeeze(-1) # (B, seq_len)
return logits
评价指标
评价指标采用实际QRS波检测的表现进行,即统计正确检测个数,漏检个数及误检个数,然后计算得到F1。
训练中保存F1最高的模型结果,将结果用于测试集上进行测试,以及后面的静态数据推理。
测试结果
训练结束后,通过加载保存的最佳模型,在测试集数据上进行测试,总体来说结果还可以。
在9个测试数据上,取得0.999717的F1,正确检出19412个QRS波,误检3个,漏检了8个。在正确检出的QRS波中,与参考QRS波位置距离的平均值为0.469ms,标准差为5.486ms。

实际应用测试
在腾讯云服务器上实现两个工程,一是QRS波检出模型和心电质量检测模型推理的部署,另一个是心电图Web显示。通过浏览器输入公网地址打开Web前端,然后上传JSON格式的12通道心电数据,通过对各通道进行质量分析后,在高质量通道上进行QRS波检测,最后综合各通道的QRS波得到最后的检测结果。流程如下:
- 浏览器输入公网地址:http://1.12.228.183,进入心电Web页面

- 点击页面"上传JSON"按钮,选择心电JSON格式上传,此时显示心电样式

- 点击"预测"按钮,此时显示预测得到的结果,目前暂时仅完成了波形质量检测和QRS波检测两个模型。信号质量包含"高/中/低"三类,图上分别以"绿/橙/红"显示;QRS波检测结果在图中上方以倒三角的方式标注。

总结
尝试了通过MIT-BIH数据库进行QRS波深度学习的训练与测试,并通过前端Web测试模型在静态12通道心电数据的效果。从结果来看,目前的模型的表现还较为良好,不过在实际的应用中,还需要多走一些数据,将误检或漏检的数据加入模型中,提高模型的泛化能力。