Implementation of LSTM-Based Atrial Fibrillation Detection from ECG Signals
Apr 09, 2023
In the article "Classification of ECG Signals Using Long-Term Memory Networks", we proposed a new method to detect atrial fibrillation (AF) from ECG signals using long short-term memory (LSTM) networks, oversampling, and time-frequency features. In this article, we describe the implementation in detail, including code for data preprocessing, feature extraction, LSTM classifier training, and model evaluation.
This article presents the implementation of our LSTM-based atrial fibrillation detection method. The method aims to reduce the classification bias caused by class imbalance and to improve performance by combining oversampling with time-frequency features. The code is written in Python and uses NumPy, SciPy, scikit-learn, imbalanced-learn, and TensorFlow.
Data Preprocessing
To preprocess the raw ECG signals, we apply a 0.5 to 50 Hz band-pass filter to remove baseline wander and high-frequency noise. We then segment the filtered signal into non-overlapping windows of fixed length.
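import numpy as np
from scipy.signal import butter, sosfiltfilt

def bandpass_filter(data, lowcut, highcut, fs, order=5):
    # Normalize the cutoff frequencies by the Nyquist frequency
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    # Second-order sections are numerically more stable than (b, a)
    # coefficients for very low cutoffs such as 0.5 Hz at fs = 1000 Hz
    sos = butter(order, [low, high], btype='band', output='sos')
    # Zero-phase filtering (forward and backward pass), so no phase distortion
    return sosfiltfilt(sos, data)

def preprocess_data(ecg_data, window_size, fs=1000):
    # Apply a 0.5-50 Hz band-pass filter
    filtered_ecg = bandpass_filter(ecg_data, lowcut=0.5, highcut=50, fs=fs, order=5)
    # Segment the signal into non-overlapping windows; the "+ 1" keeps the last
    # window when the signal length is an exact multiple of window_size
    windows = []
    for i in range(0, len(filtered_ecg) - window_size + 1, window_size):
        windows.append(filtered_ecg[i:i + window_size])
    return np.array(windows)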
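The sketch below shows what the preprocessing does. It runs the 0.5 to 50 Hz, order-5 Butterworth band-pass on a synthetic signal, so no real ECG is involved. The signal has three assumed parts: a 10 Hz component to keep, a 0.1 Hz drift that stands in for baseline wander, and 150 Hz noise. It then segments the result with a reshape. The hypothetical `segment` helper is a vectorized alternative to a windowing loop: it returns every complete non-overlapping window and drops a trailing partial one.

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt

fs = 1000                                   # sampling rate in Hz, as in preprocess_data
t = np.arange(30 * fs) / fs                 # 30 s of samples (exactly 30000 points)
in_band = np.sin(2 * np.pi * 10 * t)        # component the filter should keep (10 Hz)
wander = 2.0 * np.sin(2 * np.pi * 0.1 * t)  # synthetic baseline wander (0.1 Hz)
noise = 0.5 * np.sin(2 * np.pi * 150 * t)   # synthetic high-frequency noise (150 Hz)
raw = in_band + wander + noise

# 0.5-50 Hz Butterworth band-pass as second-order sections, applied forward and backward
sos = butter(5, [0.5, 50], btype='band', fs=fs, output='sos')
filtered = sosfiltfilt(sos, raw)

def segment(x, window_size):
    """Split x into non-overlapping windows; a trailing partial window is dropped."""
    n_windows = len(x) // window_size
    return x[:n_windows * window_size].reshape(n_windows, window_size)

windows = segment(filtered, window_size=1000)
print(windows.shape)  # (30, 1000)

# Away from the edges (where filter transients have decayed) the output
# should be very close to the 10 Hz component alone
core = slice(10 * fs, 20 * fs)
print(np.max(np.abs(filtered[core] - in_band[core])))
```

The edge region is excluded from the comparison because the slowest filter poles near 0.5 Hz need a few seconds to settle.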
Feature Extraction
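For each window, we compute two time-frequency features. The first is the instantaneous frequency (IF), obtained from the phase of the analytic signal. The second is the group delay (GD), the negative derivative of the spectral phase with respect to frequency.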
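import numpy as np
from scipy.signal import hilbert, stft

def instantaneous_frequency(x, fs):
    # Unwrapped phase of the analytic signal
    analytic_signal = hilbert(x)
    instantaneous_phase = np.unwrap(np.angle(analytic_signal))
    # Phase increment in cycles per sample, scaled by fs to give Hz
    return np.diff(instantaneous_phase) / (2.0 * np.pi) * fs

def group_delay(x, fs, nperseg=256):
    # Complex STFT: a power spectrogram is real and non-negative, so its phase would be all zeros
    freqs, times, Zxx = stft(x, fs=fs, nperseg=nperseg)
    # Group delay = -d(phase)/d(omega), taken along the frequency axis (axis 0)
    phase = np.unwrap(np.angle(Zxx), axis=0)
    d_omega = 2.0 * np.pi * (freqs[1] - freqs[0])
    gd = -np.diff(phase, axis=0) / d_omega  # seconds, shape (n_freqs - 1, n_frames)
    # Flatten to 1-D so it can be concatenated with the 1-D IF feature
    return gd.ravel()

def extract_features(windows, fs=1000):
    features = []
    for window in windows:
        if_feature = instantaneous_frequency(window, fs)
        gd_feature = group_delay(window, fs)
        features.append(np.hstack((if_feature, gd_feature)))
    return np.array(features)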
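Both features can be sanity-checked on signals whose answer is known. In the sketch below, a 50 Hz cosine sampled at 1 kHz should have an instantaneous frequency of 50 Hz once the per-sample phase increment (in cycles) is multiplied by the sampling rate. An impulse delayed by 100 samples should have a group delay of 100 / fs = 0.1 s at every frequency. As a simplification, the group delay here comes from the phase of a single FFT rather than a spectrogram or STFT, so that the expected value is exact. The function names `if_hz` and `fft_group_delay` are hypothetical.

```python
import numpy as np
from scipy.signal import hilbert

def if_hz(x, fs):
    """Instantaneous frequency in Hz from the unwrapped phase of the analytic signal."""
    phase = np.unwrap(np.angle(hilbert(x)))
    return np.diff(phase) / (2.0 * np.pi) * fs

def fft_group_delay(x, fs):
    """Group delay in seconds, -d(phase)/d(omega), from the FFT phase of a finite signal."""
    spectrum = np.fft.rfft(x)
    phase = np.unwrap(np.angle(spectrum))
    omega = 2.0 * np.pi * np.fft.rfftfreq(len(x), d=1.0 / fs)  # angular frequency in rad/s
    return -np.diff(phase) / np.diff(omega)

fs = 1000
t = np.arange(1000) / fs
tone = np.cos(2 * np.pi * 50 * t)  # exactly 50 cycles in the window
impulse = np.zeros(1000)
impulse[100] = 1.0                 # unit impulse delayed by 100 samples

print(np.median(if_hz(tone, fs)))        # approximately 50.0
print(fft_group_delay(impulse, fs)[:3])  # approximately [0.1 0.1 0.1]
```

A pure delay of n0 samples has a linear phase of -omega * n0 / fs, which is why the group delay is constant. Real ECG windows give frequency-dependent values.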
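Oversampling and Undersampling
To reduce the classifier's bias toward the majority class, we randomly oversample the minority class and then randomly undersample the majority class, which produces a balanced training set. The test set is left unchanged.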
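from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler
from sklearn.utils import shuffle

def balance_data(X, y, over_ratio=0.5, random_state=42):
    # Oversample the minority class to over_ratio x the majority count (binary labels only).
    # over_ratio=0.5 is an assumed value and must exceed the current minority/majority ratio;
    # oversampling fully to 1:1 would leave nothing for the undersampler to do.
    ros = RandomOverSampler(sampling_strategy=over_ratio, random_state=random_state)
    X_resampled, y_resampled = ros.fit_resample(X, y)
    # Undersample the majority class down to a 1:1 ratio
    rus = RandomUnderSampler(sampling_strategy=1.0, random_state=random_state)
    X_resampled, y_resampled = rus.fit_resample(X_resampled, y_resampled)
    # The undersampler returns samples grouped by class; shuffle so Keras' validation_split does not hold out a single class
    return shuffle(X_resampled, y_resampled, random_state=random_state)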
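The random resampling done by imbalanced-learn can be sketched in plain NumPy, which makes the arithmetic visible. In the hypothetical helper below, the minority class is first oversampled with replacement to half the majority count (the 0.5 ratio is an assumption). The majority class is then undersampled without replacement to match, and the result is shuffled. The shuffle matters for training: Keras' `validation_split` takes the last fraction of the arrays before any shuffling, so arrays grouped by class would give a single-class validation set.

```python
import numpy as np

def random_resample(X, y, over_ratio=0.5, seed=0):
    """Binary random over- then under-sampling; returns a shuffled, 1:1 balanced copy."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    if len(classes) != 2:
        raise ValueError("this sketch handles binary labels only")
    order = np.argsort(counts)
    minority, majority = classes[order[0]], classes[order[1]]
    idx_min = np.flatnonzero(y == minority)
    idx_maj = np.flatnonzero(y == majority)

    # Oversample the minority class with replacement up to over_ratio * majority count
    n_min = max(len(idx_min), int(over_ratio * len(idx_maj)))
    extra = rng.choice(idx_min, size=n_min - len(idx_min), replace=True)
    idx_min = np.concatenate([idx_min, extra])

    # Undersample the majority class without replacement down to the new minority count
    idx_maj = rng.choice(idx_maj, size=len(idx_min), replace=False)

    # Shuffle so the two classes are interleaved rather than grouped
    idx = rng.permutation(np.concatenate([idx_min, idx_maj]))
    return X[idx], y[idx]

# Toy data: 90 majority samples (label 0) and 10 minority samples (label 1)
X = np.arange(100, dtype=float).reshape(100, 1)
y = np.array([0] * 90 + [1] * 10)
X_bal, y_bal = random_resample(X, y)
print(np.bincount(y_bal))  # [45 45]
```

With 90 and 10 samples, the minority class grows from 10 to 45 by duplication and the majority class shrinks from 90 to 45. Compared with oversampling alone, this duplicates fewer minority samples.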
LSTM Classifier
We implement and train the LSTM classifier using TensorFlow.
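import tensorflow as tf

def build_lstm_classifier(input_shape, num_classes):
    # input_shape = (timesteps, channels); inputs must be 3-D: (batch, timesteps, channels)
    model = tf.keras.Sequential([
        tf.keras.Input(shape=input_shape),
        tf.keras.layers.LSTM(128, return_sequences=True),  # pass the full sequence to the next LSTM
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.LSTM(64),                          # keep only the last hidden state
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(num_classes, activation='softmax')
    ])
    # Integer class labels, hence sparse categorical cross-entropy
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model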
def train_lstm_classifier(model, X_train, y_train, batch_size, epochs):
    # validation_split holds out the last 20% of the arrays as passed in (before shuffling)
    history = model.fit(X_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.2)
    return history
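To make the architecture concrete, the NumPy sketch below counts the trainable parameters of the LSTM stack above and runs one step of a single LSTM cell. It follows the Keras layout: one kernel, one recurrent kernel, and one bias per layer, gates ordered input, forget, cell, output, and a sigmoid recurrent activation. The weights are random and purely illustrative. For `input_shape = (T, 1)` the parameter count does not depend on T, because the same weights are reused at every time step.

```python
import numpy as np

def lstm_params(input_dim, units):
    # Four gates, each with input weights, recurrent weights and a bias
    return 4 * (units * (input_dim + units) + units)

def dense_params(input_dim, units):
    return input_dim * units + units

# LSTM(128) on 1 channel -> LSTM(64) -> Dense(2); Dropout layers have no weights
total = lstm_params(1, 128) + lstm_params(128, 64) + dense_params(64, 2)
print(total)  # 116098

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_cell_step(x_t, h_prev, c_prev, kernel, recurrent_kernel, bias):
    """One LSTM time step; x_t: (batch, input_dim), h_prev and c_prev: (batch, units)."""
    z = x_t @ kernel + h_prev @ recurrent_kernel + bias  # (batch, 4 * units)
    z_i, z_f, z_c, z_o = np.split(z, 4, axis=-1)
    i, f, o = sigmoid(z_i), sigmoid(z_f), sigmoid(z_o)    # input, forget, output gates
    c_t = f * c_prev + i * np.tanh(z_c)                    # updated cell state
    h_t = o * np.tanh(c_t)                                 # new hidden state
    return h_t, c_t

rng = np.random.default_rng(0)
input_dim, units, batch = 1, 128, 4
kernel = rng.normal(scale=0.1, size=(input_dim, 4 * units))
recurrent_kernel = rng.normal(scale=0.1, size=(units, 4 * units))
bias = np.zeros(4 * units)
h, c = np.zeros((batch, units)), np.zeros((batch, units))
x_t = rng.normal(size=(batch, input_dim))
h, c = lstm_cell_step(x_t, h, c, kernel, recurrent_kernel, bias)
print(h.shape)  # (4, 128)
```

The forget gate f decides how much of the previous cell state survives each step. This is what lets the network carry information across the long sequences produced by feature extraction.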
Model Evaluation
We evaluate the performance of the LSTM classifier on a test dataset.
def evaluate_lstm_classifier(model, X_test, y_test):
    # Returns [loss, accuracy] because the model was compiled with metrics=['accuracy']
    loss, accuracy = model.evaluate(X_test, y_test)
    return loss, accuracy
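On imbalanced test data, accuracy can look high even when most AF windows are missed, so it is worth reporting per-class metrics alongside it. Below is a minimal NumPy sketch, assuming binary labels with 1 = AF and predictions taken as the arg-max of the softmax output, for example `model.predict(X_test).argmax(axis=1)`. The labels in the example are made up for illustration and are not results.

```python
import numpy as np

def binary_metrics(y_true, y_pred):
    """Accuracy, sensitivity, specificity, precision and F1 for the positive class (1 = AF)."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    tp = np.sum((y_true == 1) & (y_pred == 1))
    tn = np.sum((y_true == 0) & (y_pred == 0))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    sensitivity = tp / (tp + fn) if tp + fn else 0.0  # recall on AF windows
    specificity = tn / (tn + fp) if tn + fp else 0.0  # recall on non-AF windows
    precision = tp / (tp + fp) if tp + fp else 0.0
    f1 = 2 * precision * sensitivity / (precision + sensitivity) if precision + sensitivity else 0.0
    accuracy = (tp + tn) / len(y_true)
    return {"accuracy": float(accuracy), "sensitivity": float(sensitivity),
            "specificity": float(specificity), "precision": float(precision), "f1": float(f1)}

# 90 non-AF and 10 AF windows, and a degenerate classifier that always predicts non-AF
y_true = np.array([0] * 90 + [1] * 10)
y_pred = np.zeros(100, dtype=int)
print(binary_metrics(y_true, y_pred))
# {'accuracy': 0.9, 'sensitivity': 0.0, 'specificity': 1.0, 'precision': 0.0, 'f1': 0.0}
```

The degenerate classifier reaches 90% accuracy while detecting no AF at all. Sensitivity and F1 expose exactly the bias that resampling is meant to reduce.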
Putting It All Together
We preprocess the data, extract features, balance the dataset, train the LSTM classifier, and evaluate the model's performance.
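import numpy as np
from sklearn.model_selection import train_test_split

# Load ECG data and labels. load_ecg_data() is a placeholder for your own loader: it must
# return a 1-D ECG signal and one integer label in {0, 1} per window
ecg_data, labels = load_ecg_data()
# Preprocess ECG data
windows = preprocess_data(ecg_data, window_size=1000, fs=1000)
# Extract features
features = extract_features(windows)
# Split data into training and test sets, stratified to keep the class proportions in both
X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2, stratify=labels, random_state=42)
# Balance the training data only; the test set keeps its natural class distribution
X_train_balanced, y_train_balanced = balance_data(X_train, y_train)
# LSTM layers expect 3-D input (samples, timesteps, channels), so add a channel axis
X_train_balanced = X_train_balanced[..., np.newaxis]
X_test = X_test[..., np.newaxis]
# Build and train the LSTM classifier
input_shape = (X_train_balanced.shape[1], 1)
num_classes = 2
model = build_lstm_classifier(input_shape, num_classes)
history = train_lstm_classifier(model, X_train_balanced, y_train_balanced, batch_size=32, epochs=100)
# Evaluate the LSTM classifier on the test dataset
loss, accuracy = evaluate_lstm_classifier(model, X_test, y_test)
print(f'Test loss: {loss}, Test accuracy: {accuracy}')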
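The pipeline assumes one label per window, but ECG databases often provide rhythm annotations per sample or per interval. One possible way to derive window labels is to mark a window as AF when more than half of its samples are annotated as AF. The sketch below does this under that assumption. The 0.5 threshold and the `window_labels` name are illustrative choices, not part of the method described above.

```python
import numpy as np

def window_labels(sample_labels, window_size, threshold=0.5):
    """Label each non-overlapping window 1 (AF) if its fraction of AF samples exceeds threshold."""
    sample_labels = np.asarray(sample_labels)
    n_windows = len(sample_labels) // window_size  # trailing partial window is dropped
    per_window = sample_labels[:n_windows * window_size].reshape(n_windows, window_size)
    af_fraction = per_window.mean(axis=1)
    return (af_fraction > threshold).astype(int)

# 5 s at 1 kHz with AF annotated from sample 1600 to sample 3299
fs = 1000
annotations = np.zeros(5 * fs, dtype=int)
annotations[1600:3300] = 1
print(window_labels(annotations, window_size=1000))  # [0 0 1 0 0]
```

Windows 1 and 3 contain 40% and 30% AF samples, so they fall below the threshold. Only the fully annotated window 2 is labeled AF.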
Further Improvements and Extensions
Although the provided code demonstrates the core concepts of our LSTM-based atrial fibrillation detection method, several improvements and extensions could further enhance the classifier's performance and generalizability.
Hyperparameter Tuning
The choice of hyperparameters, such as the number of LSTM layers, the number of hidden units, and the dropout rate, can significantly affect the model's performance. Grid search or random search can systematically evaluate hyperparameter combinations on validation data to find the best-performing configuration for this task.
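Grid search and random search over the hyperparameters named above can be sketched as follows. The search-space values and the `evaluate` callback are hypothetical. In practice, the callback would build and train the LSTM classifier with the given settings and return a validation score such as F1. Here, a toy scoring function stands in so that the sketch runs on its own.

```python
import itertools
import random

SEARCH_SPACE = {
    "lstm_layers": [1, 2, 3],
    "hidden_units": [32, 64, 128, 256],
    "dropout": [0.2, 0.3, 0.5],
}

def grid_search(evaluate, space):
    """Evaluate every combination in the space and return the best (score, config)."""
    names = list(space)
    best_score, best_config = float("-inf"), None
    for values in itertools.product(*(space[name] for name in names)):
        config = dict(zip(names, values))
        score = evaluate(config)
        if score > best_score:
            best_score, best_config = score, config
    return best_score, best_config

def random_search(evaluate, space, n_trials=20, seed=0):
    """Sample n_trials configurations uniformly at random and return the best (score, config)."""
    rng = random.Random(seed)
    best_score, best_config = float("-inf"), None
    for _ in range(n_trials):
        config = {name: rng.choice(values) for name, values in space.items()}
        score = evaluate(config)  # e.g. validation F1 of a model trained with this config
        if score > best_score:
            best_score, best_config = score, config
    return best_score, best_config

# Toy stand-in for "train and validate": best at 2 layers, 128 units, dropout 0.3
def toy_score(config):
    penalty = (abs(config["lstm_layers"] - 2) + abs(config["hidden_units"] - 128) / 64
               + abs(config["dropout"] - 0.3) * 10)
    return 1.0 / (1.0 + penalty)

score, config = grid_search(toy_score, SEARCH_SPACE)
print(score, config)  # 1.0 {'lstm_layers': 2, 'hidden_units': 128, 'dropout': 0.3}
print(random_search(toy_score, SEARCH_SPACE, n_trials=10))
```

Grid search costs one training run per combination (36 here). Random search caps the number of runs, which matters when each evaluation means training an LSTM for many epochs.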
Data Augmentation
Data augmentation techniques, such as adding noise, time-shifting, or time-warping, can be employed to artificially increase the size and diversity of the training dataset. This can help improve the model's generalization capabilities and its ability to handle variations in real-world ECG signals.
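The three augmentations mentioned above can be sketched in NumPy as follows. The noise level, shift range, and warp factor are assumed values that would need tuning. The circular shift used for time-shifting is one of several possible choices; padding is another.

```python
import numpy as np

def add_noise(x, rng, snr_db=20.0):
    """Add white Gaussian noise at the given signal-to-noise ratio (dB)."""
    signal_power = np.mean(x ** 2)
    noise_power = signal_power / (10 ** (snr_db / 10))
    return x + rng.normal(scale=np.sqrt(noise_power), size=x.shape)

def time_shift(x, rng, max_shift=100):
    """Circularly shift the window by a random number of samples in [-max_shift, max_shift]."""
    shift = int(rng.integers(-max_shift, max_shift + 1))
    return np.roll(x, shift)

def time_warp(x, factor):
    """Stretch (factor > 1) or compress (factor < 1) in time, then crop or pad to the original length."""
    n = len(x)
    warped_len = int(round(n * factor))
    # Linearly interpolate the original samples onto a grid with warped_len points
    warped = np.interp(np.linspace(0, n - 1, warped_len), np.arange(n), x)
    if warped_len >= n:
        return warped[:n]
    return np.pad(warped, (0, n - warped_len), mode="edge")

rng = np.random.default_rng(0)
window = np.sin(2 * np.pi * 5 * np.arange(1000) / 1000)
augmented = [add_noise(window, rng), time_shift(window, rng), time_warp(window, 1.1)]
print([a.shape for a in augmented])  # [(1000,), (1000,), (1000,)]
```

Each augmentation keeps the window length, so augmented windows can go through the same feature extraction as the originals. Augmentation should be applied to training windows only.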
Alternative Deep Learning Architectures
While LSTM networks are well-suited for time-series data, other deep learning architectures, such as Convolutional Neural Networks (CNNs) or transformers, may also yield good results for ECG signal classification. Exploring these alternative architectures and comparing their performance can provide insights into the best model for atrial fibrillation detection.
Multi-Task Learning
Instead of focusing solely on atrial fibrillation detection, the model can be extended to detect multiple types of cardiac arrhythmias simultaneously. By leveraging multi-task learning, the model can learn shared representations across different arrhythmia types, potentially improving its overall performance.
In this article, we described and implemented our LSTM-based atrial fibrillation detection method, which combines oversampling with time-frequency features. The code serves as a foundation for further experimentation and can be adapted to other ECG classification tasks or deep learning architectures. By exploring additional techniques and extending the model to detect multiple cardiac arrhythmias, we aim to contribute to robust and accurate ECG signal classifiers that support patient care and early intervention.
Anton Emelianov
CTO (Chief Technology Officer)