Implementation of LSTM-Based Atrial Fibrillation Detection from ECG Signals

AI

Apr 09, 2023

In the article "Classification of ECG Signals Using Long Short-Term Memory Networks", we proposed a method for detecting atrial fibrillation (AF) in ECG signals using long short-term memory (LSTM) networks, resampling, and time-frequency moment features. In this follow-up article, we walk through the implementation of that method, including code for data preprocessing, feature extraction, LSTM classifier training, and model performance evaluation.

The method aims to address class imbalance and improve classification performance by combining resampling with time-frequency moment features. The code in this article is written in Python and relies on NumPy, SciPy, scikit-learn, imbalanced-learn, and TensorFlow.

Data Preprocessing

To preprocess the raw ECG signals, we apply a band-pass filter to remove noise and baseline wander. We then segment the filtered signals into non-overlapping windows.

import numpy as np
from scipy.signal import butter, filtfilt

def bandpass_filter(data, lowcut, highcut, fs, order=5):
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    return filtfilt(b, a, data)

def preprocess_data(ecg_data, window_size, fs=1000):
    # Apply band-pass filter
    filtered_ecg = bandpass_filter(ecg_data, lowcut=0.5, highcut=50, fs=fs, order=5)
    
    # Segment the signal into windows
    windows = []
    for i in range(0, len(filtered_ecg) - window_size, window_size):
        windows.append(filtered_ecg[i:i + window_size])
    return np.array(windows)

Feature Extraction

We compute two time-frequency moment features, instantaneous frequency (IF) and group delay (GD), for each segmented window.

from scipy.signal import hilbert, spectrogram

def instantaneous_frequency(signal):
    # The derivative of the analytic signal's unwrapped phase gives the
    # instantaneous frequency (in cycles per sample; multiply by fs for Hz)
    analytic_signal = hilbert(signal)
    instantaneous_phase = np.unwrap(np.angle(analytic_signal))
    return np.diff(instantaneous_phase) / (2.0 * np.pi)

def group_delay(window, fs):
    # Use a complex spectrogram so the phase information is preserved
    freqs, times, Sxx = spectrogram(window, fs, mode='complex')
    # Group delay is the negative derivative of the unwrapped phase
    # with respect to frequency
    phase = np.unwrap(np.angle(Sxx), axis=0)
    # Average over time to obtain a fixed-length 1-D feature vector
    return -np.diff(phase, axis=0).mean(axis=1)

def extract_features(windows):
    features = []
    for window in windows:
        if_feature = instantaneous_frequency(window)
        gd_feature = group_delay(window, fs=1000)
        features.append(np.hstack((if_feature, gd_feature)))
    return np.array(features)

Resampling

We apply oversampling and undersampling to the training data to correct the class imbalance between AF and normal-rhythm windows.

from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler

def balance_data(X, y):
    # Oversample the minority class (typically AF) part-way to limit
    # duplication, then undersample the majority class to the same size
    ros = RandomOverSampler(sampling_strategy=0.5, random_state=42)
    rus = RandomUnderSampler(sampling_strategy=1.0, random_state=42)
    X_resampled, y_resampled = ros.fit_resample(X, y)
    X_resampled, y_resampled = rus.fit_resample(X_resampled, y_resampled)
    return X_resampled, y_resampled
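
As a quick sanity check, here is a minimal usage example on dummy data; the label encoding (0 = normal rhythm, 1 = AF) and the class counts are illustrative assumptions:

import numpy as np
from collections import Counter

# Dummy imbalanced dataset: 900 normal windows, 100 AF windows
# (the 0 = normal, 1 = AF encoding is assumed for illustration)
X_dummy = np.random.randn(1000, 64)
y_dummy = np.array([0] * 900 + [1] * 100)

X_bal, y_bal = balance_data(X_dummy, y_dummy)
print(Counter(y_dummy))  # Counter({0: 900, 1: 100})
print(Counter(y_bal))    # Counter({0: 450, 1: 450})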

LSTM Classifier

We implement and train the LSTM classifier using TensorFlow.

import tensorflow as tf

def build_lstm_classifier(input_shape, num_classes):
    model = tf.keras.Sequential([
        tf.keras.layers.LSTM(128, input_shape=input_shape, return_sequences=True),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.LSTM(64),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(num_classes, activation='softmax')
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model

def train_lstm_classifier(model, X_train, y_train, batch_size, epochs):
    history = model.fit(X_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.2)
    return history

Model Evaluation

We evaluate the performance of the LSTM classifier on a test dataset.

def evaluate_lstm_classifier(model, X_test, y_test):
    loss, accuracy = model.evaluate(X_test, y_test)
    return loss, accuracy
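
Because accuracy alone can be misleading when the test set itself is imbalanced, it is also worth inspecting per-class metrics. Below is a minimal sketch using scikit-learn, assuming the same 0 = normal, 1 = AF label encoding:

from sklearn.metrics import classification_report, confusion_matrix

def detailed_evaluation(model, X_test, y_test):
    # Predicted class is the argmax over the softmax outputs
    y_pred = np.argmax(model.predict(X_test), axis=1)
    # For a screening task, sensitivity (recall on the AF class) and
    # specificity matter more than raw accuracy
    print(confusion_matrix(y_test, y_pred))
    print(classification_report(y_test, y_pred, target_names=['Normal', 'AF']))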

Putting It All Together

We preprocess the data, extract features, balance the dataset, train the LSTM classifier, and evaluate the model's performance.

# Imports for the end-to-end script
import numpy as np
from sklearn.model_selection import train_test_split

# Load ECG data and per-window labels (load_ecg_data is a placeholder
# for your own data-loading routine)
ecg_data, labels = load_ecg_data()

# Preprocess ECG data
windows = preprocess_data(ecg_data, window_size=1000, fs=1000)

# Extract features
features = extract_features(windows)

# Split data into training and test sets (stratified to preserve the
# class ratio in the test set)
X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2, random_state=42, stratify=labels)

# Balance the training data (the test set is left untouched)
X_train_balanced, y_train_balanced = balance_data(X_train, y_train)

# Add a channel dimension so the inputs match the LSTM's expected
# (timesteps, features) shape
X_train_balanced = X_train_balanced[..., np.newaxis]
X_test = X_test[..., np.newaxis]

# Build and train the LSTM classifier
input_shape = (X_train_balanced.shape[1], 1)
num_classes = 2
model = build_lstm_classifier(input_shape, num_classes)
history = train_lstm_classifier(model, X_train_balanced, y_train_balanced, batch_size=32, epochs=100)

# Evaluate the LSTM classifier on the test dataset
loss, accuracy = evaluate_lstm_classifier(model, X_test, y_test)
print(f'Test loss: {loss}, Test accuracy: {accuracy}')

Further Improvements and Extensions

Although the provided code demonstrates the core concepts of our LSTM-based atrial fibrillation detection method, several improvements and extensions could further enhance the classifier's performance and generalizability.

  1. Hyperparameter Tuning

The choice of hyperparameters, such as the number of LSTM layers, the number of hidden units, and the dropout rate, can significantly impact the model's performance. Using techniques like grid search or random search, we can systematically explore different hyperparameter combinations to find the optimal configuration for our task.
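
As a minimal sketch, a manual grid search can be wired around the training loop; the grid values, the shortened epoch count, and the reuse of input_shape, X_train_balanced, and y_train_balanced from the script above are all illustrative assumptions:

import itertools

param_grid = {
    'units': [64, 128],
    'dropout': [0.3, 0.5],
    'batch_size': [32, 64],
}

best_acc, best_params = 0.0, None
for units, dropout, batch_size in itertools.product(*param_grid.values()):
    candidate = tf.keras.Sequential([
        tf.keras.layers.LSTM(units, input_shape=input_shape),
        tf.keras.layers.Dropout(dropout),
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ])
    candidate.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])
    history = candidate.fit(X_train_balanced, y_train_balanced,
                            batch_size=batch_size, epochs=20,
                            validation_split=0.2, verbose=0)
    best_val = max(history.history['val_accuracy'])
    if best_val > best_acc:
        best_acc, best_params = best_val, (units, dropout, batch_size)

print(f'Best validation accuracy {best_acc:.3f} with {best_params}')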

  2. Data Augmentation

Data augmentation techniques, such as adding noise, time-shifting, or time-warping, can be employed to artificially increase the size and diversity of the training dataset. This can help improve the model's generalization capabilities and its ability to handle variations in real-world ECG signals.
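
Here is a minimal sketch of such augmentation, applied to the windows array produced earlier; the noise level and shift range are illustrative and should be kept small enough that the perturbed signals remain physiologically plausible:

def augment_window(window, rng, noise_std=0.01, max_shift=50):
    # Additive Gaussian noise plus a circular time shift (np.roll wraps
    # the window edges, a cheap approximation of a real time shift)
    noisy = window + rng.normal(0.0, noise_std, size=window.shape)
    shift = int(rng.integers(-max_shift, max_shift + 1))
    return np.roll(noisy, shift)

rng = np.random.default_rng(42)
augmented_windows = np.array([augment_window(w, rng) for w in windows])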

  3. Alternative Deep Learning Architectures

While LSTM networks are well-suited for time-series data, other deep learning architectures, such as Convolutional Neural Networks (CNNs) or transformers, may also yield good results for ECG signal classification. Exploring these alternative architectures and comparing their performance can provide insights into the best model for atrial fibrillation detection.
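
For example, a small 1-D CNN baseline can be built with the same Keras interface as the LSTM classifier; the layer sizes here are illustrative, not tuned:

def build_cnn_classifier(input_shape, num_classes):
    model = tf.keras.Sequential([
        tf.keras.layers.Conv1D(32, kernel_size=7, activation='relu',
                               input_shape=input_shape),
        tf.keras.layers.MaxPooling1D(2),
        tf.keras.layers.Conv1D(64, kernel_size=5, activation='relu'),
        tf.keras.layers.GlobalAveragePooling1D(),
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model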

  4. Multi-Task Learning

Instead of focusing solely on atrial fibrillation detection, the model can be extended to detect multiple types of cardiac arrhythmias simultaneously. By leveraging multi-task learning, the model can learn shared representations across different arrhythmia types, potentially improving its overall performance.
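
A minimal sketch of such a model using the Keras functional API; the second "rhythm" head and its class count are hypothetical and require a dataset with corresponding rhythm-type annotations:

def build_multitask_classifier(input_shape, num_rhythm_classes):
    inputs = tf.keras.Input(shape=input_shape)
    # Shared LSTM trunk learns a representation used by both heads
    shared = tf.keras.layers.Dropout(0.5)(tf.keras.layers.LSTM(128)(inputs))

    af_output = tf.keras.layers.Dense(2, activation='softmax', name='af')(shared)
    rhythm_output = tf.keras.layers.Dense(num_rhythm_classes, activation='softmax',
                                          name='rhythm')(shared)

    model = tf.keras.Model(inputs, [af_output, rhythm_output])
    model.compile(optimizer='adam',
                  loss={'af': 'sparse_categorical_crossentropy',
                        'rhythm': 'sparse_categorical_crossentropy'},
                  metrics={'af': 'accuracy', 'rhythm': 'accuracy'})
    return model

Training this model requires both label arrays, e.g. model.fit(X, {'af': y_af, 'rhythm': y_rhythm}, ...).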

In this article, we provided a detailed description and code implementation of our LSTM-based atrial fibrillation detection method, which incorporates resampling and time-frequency moment features. The code serves as a foundation for further experimentation and can be adapted to other ECG classification tasks or deep learning architectures. By exploring the techniques above and extending the model to detect multiple cardiac arrhythmias, we aim to contribute to robust and accurate ECG classifiers that support earlier intervention and better patient care.

Anton Emelianov

CTO (Chief Technology Officer)
