Distributed Deep Learning in TensorFlow

May 05, 2023

Training deep learning models on large datasets, or training very large models, can take a long time on a single device. Distributed deep learning addresses this by spreading the work across multiple devices and machines. TensorFlow is a good framework for exploring the topic, because it offers built-in support for distributed training through the tf.distribute package.

In this article, I'll dive into distributed deep learning in TensorFlow, covering model and data parallelism strategies. We'll explore synchronous and asynchronous training strategies and walk through practical examples that you can adapt for your own projects.

In the following sections, we will look at these strategies in detail, understand their inner workings, and analyze their suitability for different use cases. By the end, you will have a good understanding of TensorFlow's distributed learning strategies and will be well prepared to implement them in your projects.

Distributed learning strategies in TensorFlow

Distributed learning is an important aspect of training deep learning models on large data sets, as it allows us to share the computational load across multiple devices or even clusters of devices. TensorFlow, being a popular and versatile deep learning framework, offers us the tf.distribute package, which is equipped with various strategies to seamlessly implement distributed learning.

Synchronous learning strategies

In synchronous strategies, every replica finishes its step before the model is updated, and all replicas apply the same aggregated update, so every copy of the model stays identical. TensorFlow offers several synchronous strategies; here we look at three of them: MirroredStrategy, MultiWorkerMirroredStrategy, and CentralStorageStrategy.

MirroredStrategy

MirroredStrategy is TensorFlow's standard strategy for synchronous data-parallel training on a single machine with multiple devices, usually GPUs. It creates a replica of the model on each device. In every step, each replica processes a different slice of the global batch and computes gradients independently. The gradients are then combined across devices with an all-reduce operation, and every replica applies the same update to its copy of the variables, which keeps all copies in sync.
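To make the "compute locally, combine, apply once" idea concrete, here is a framework-free sketch in plain Python. It is a simulation, not how MirroredStrategy is implemented internally: it trains a one-parameter linear model y = w * x with squared loss on two simulated replicas. Each replica computes the gradient on its shard of the global batch, the gradients are averaged (the role the all-reduce plays), and every replica applies the same update, so the copies of w never diverge. With equal shard sizes, the averaged gradient equals the gradient on the whole batch.

```python
# Simulated synchronous data parallelism (plain Python, no TensorFlow).
# Model: y_hat = w * x, loss = mean((w * x - y) ** 2) over a batch.

def grad(w, xs, ys):
    """Gradient of the mean squared error with respect to w on one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def split(seq, n):
    """Split a global batch into n equal contiguous shards (len(seq) must divide by n)."""
    size = len(seq) // n
    return [seq[i * size:(i + 1) * size] for i in range(n)]

def sync_step(replica_ws, xs, ys, lr):
    """One synchronous step: local gradients -> average ("all-reduce") -> same update everywhere."""
    n = len(replica_ws)
    x_shards, y_shards = split(xs, n), split(ys, n)
    local_grads = [grad(w, xb, yb) for w, xb, yb in zip(replica_ws, x_shards, y_shards)]
    avg_grad = sum(local_grads) / n  # stands in for the all-reduce
    return [w - lr * avg_grad for w in replica_ws]

xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]  # generated by y = 2x
replicas = [0.0, 0.0]      # two replicas start with identical weights
for _ in range(50):
    replicas = sync_step(replicas, xs, ys, lr=0.05)
print(replicas)  # both entries are equal and close to 2.0

# With equal-sized shards, the averaged gradient matches the full-batch gradient.
w_test = 0.5
full = grad(w_test, xs, ys)
averaged = sum(grad(w_test, xb, yb) for xb, yb in zip(split(xs, 2), split(ys, 2))) / 2
print(full, averaged)  # -22.5 -22.5
```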

Consider an example. We will train a deep residual network (ResNet) for image classification, built from several residual blocks.

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, Add, MaxPooling2D, GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model

# Define the residual block
def residual_block(x, filters, strides=1):
    shortcut = x

    x = Conv2D(filters, kernel_size=(3, 3), strides=strides, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(filters, kernel_size=(3, 3), strides=1, padding='same')(x)
    x = BatchNormalization()(x)

    if strides != 1:
        shortcut = Conv2D(filters, kernel_size=(1, 1), strides=strides, padding='same')(shortcut)
        shortcut = BatchNormalization()(shortcut)

    x = Add()([x, shortcut])
    x = Activation('relu')(x)

    return x

# Define the ResNet model
def create_resnet_model(input_shape, num_classes):
    inputs = Input(shape=input_shape)

    x = Conv2D(64, kernel_size=(7, 7), strides=2, padding='same')(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(3, 3), strides=2, padding='same')(x)

    x = residual_block(x, filters=64)
    x = residual_block(x, filters=64)

    x = residual_block(x, filters=128, strides=2)
    x = residual_block(x, filters=128)

    x = residual_block(x, filters=256, strides=2)
    x = residual_block(x, filters=256)

    x = residual_block(x, filters=512, strides=2)
    x = residual_block(x, filters=512)

    x = GlobalAveragePooling2D()(x)
    outputs = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs=inputs, outputs=outputs)

    return model

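# Instantiate the MirroredStrategy (by default it uses all GPUs visible on this machine)
strategy = tf.distribute.MirroredStrategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

input_shape = (224, 224, 3)
num_classes = 10

# Placeholder input pipeline with random data; replace it with your own dataset.
# The batch size here is the global batch, which MirroredStrategy splits across replicas.
global_batch_size = 32 * strategy.num_replicas_in_sync

def make_dataset(num_samples):
    images = tf.random.uniform((num_samples,) + input_shape)
    labels = tf.one_hot(tf.random.uniform((num_samples,), maxval=num_classes, dtype=tf.int32), num_classes)
    return tf.data.Dataset.from_tensor_slices((images, labels)).batch(global_batch_size)

train_dataset = make_dataset(128)
val_dataset = make_dataset(32)

# Create the ResNet model and compile it within the strategy scope
with strategy.scope():
    resnet_model = create_resnet_model(input_shape, num_classes)
    resnet_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Train the ResNet model using the strategy
resnet_model.fit(train_dataset, epochs=10, validation_data=val_dataset)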

In this example, we first define residual_block, the building block of the ResNet architecture, and then stack eight of these blocks into a ResNet-18-style model. We instantiate MirroredStrategy, which by default uses all GPUs visible on the machine, and create and compile the model inside strategy.scope(), so that its variables are mirrored on every replica. Calling fit() then splits each global batch across the GPUs and keeps the replicas in sync after every step.
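Because the dataset's batch size is the global batch size, a common pattern is to choose a per-replica batch size and multiply it by the number of replicas. The plain-Python sketch below shows that arithmetic; the helper names are hypothetical, and it assumes you pick a global batch that divides evenly across replicas.

```python
def global_batch_size(per_replica_batch, num_replicas):
    """Global batch to feed the dataset so each replica sees per_replica_batch examples."""
    return per_replica_batch * num_replicas

def per_replica_batch(global_batch, num_replicas):
    """Examples each replica processes per step, assuming an even split."""
    if global_batch % num_replicas:
        raise ValueError('choose a global batch size divisible by the number of replicas')
    return global_batch // num_replicas

# Example: 4 GPUs, 32 images per GPU per step
gb = global_batch_size(32, 4)
print(gb, per_replica_batch(gb, 4))  # 128 32
```

In TensorFlow code, the replica count comes from strategy.num_replicas_in_sync, for example train_dataset.batch(32 * strategy.num_replicas_in_sync).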

MultiWorkerMirroredStrategy

MultiWorkerMirroredStrategy extends MirroredStrategy to synchronous training across multiple workers (machines), each with one or more devices. Gradients are aggregated across all devices on all workers with collective all-reduce operations. This strategy is particularly useful when you need to scale your training process beyond a single machine.

In this example, we will use the same ResNet model as before, but train it using MultiWorkerMirroredStrategy. This allows us to distribute the training process across multiple machines, each with multiple GPUs.

import os
import json
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, Add, MaxPooling2D, GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model

# Define the residual block and create_resnet_model functions as shown in the previous example

# Describe the cluster. Every worker runs this same script; only 'index' differs
# (0 on the first machine, 1 on the second). Each address needs a port
# (12345 is an arbitrary example port).
worker_addresses = ['192.168.1.100:12345', '192.168.1.101:12345']
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': worker_addresses
    },
    'task': {'type': 'worker', 'index': 0}
})

# Instantiate the MultiWorkerMirroredStrategy (TF_CONFIG must be set before this line)
strategy = tf.distribute.MultiWorkerMirroredStrategy()

# Create the ResNet model and compile it within the strategy scope
with strategy.scope():
    input_shape = (224, 224, 3)
    num_classes = 10
    resnet_model = create_resnet_model(input_shape, num_classes)
    resnet_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# train_dataset and val_dataset must be tf.data.Dataset objects yielding batches of
# (image, one-hot label) pairs; Keras shards them across the workers automatically.
# Train the ResNet model using the strategy
resnet_model.fit(train_dataset, epochs=10, validation_data=val_dataset)

The model code is the same as in the MirroredStrategy example. The difference is the cluster setup: we list every worker's address (in host:port form) in the TF_CONFIG environment variable, together with the type and index of the current task. The same script runs on every machine, with task index 0 on the first worker, 1 on the second, and so on, and TF_CONFIG must be set before the strategy is created. MultiWorkerMirroredStrategy then trains the ResNet model synchronously on all GPUs of all machines.
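Since the per-machine configurations differ only in the task index, it can help to generate TF_CONFIG from a single cluster description. The helper below is a hypothetical convenience function, not part of TensorFlow; it only builds the JSON string in the format TF_CONFIG expects (a "cluster" dictionary plus this process's "task"), and the host:port values are placeholders.

```python
import json

def make_tf_config(worker_addresses, index):
    """Build the TF_CONFIG JSON for the worker at position `index` (hypothetical helper)."""
    if not 0 <= index < len(worker_addresses):
        raise ValueError(f'index {index} out of range for {len(worker_addresses)} workers')
    return json.dumps({
        'cluster': {'worker': list(worker_addresses)},
        'task': {'type': 'worker', 'index': index},
    })

workers = ['192.168.1.100:12345', '192.168.1.101:12345']  # placeholder host:port pairs
for i in range(len(workers)):
    print(make_tf_config(workers, i))
```

On machine i you would set os.environ['TF_CONFIG'] = make_tf_config(workers, i) before creating the strategy. In this setup, worker 0 acts as the chief and takes on extra duties such as saving checkpoints.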

CentralStorageStrategy

CentralStorageStrategy is another synchronous strategy, intended for a single machine with multiple GPUs. Unlike MirroredStrategy and MultiWorkerMirroredStrategy, it does not mirror the variables: it keeps a single copy of them on one central device, usually the CPU (if the machine has only one GPU, the variables are placed on that GPU). The computation is still replicated across the local GPUs: each GPU computes gradients on its share of the batch, and the gradients are aggregated and applied to the centrally stored variables.

In this example, we train the same ResNet model using CentralStorageStrategy, which is still part of the tf.distribute.experimental namespace.

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, Add, MaxPooling2D, GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model

# Define the residual block and create_resnet_model functions as shown in the previous examples

# Instantiate the CentralStorageStrategy (variables on the CPU, computation on each local GPU)
strategy = tf.distribute.experimental.CentralStorageStrategy()

# Create the ResNet model and compile it within the strategy scope
with strategy.scope():
    input_shape = (224, 224, 3)
    num_classes = 10
    resnet_model = create_resnet_model(input_shape, num_classes)
    resnet_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# train_dataset and val_dataset must be tf.data.Dataset objects yielding batches of
# (image, one-hot label) pairs
# Train the ResNet model using the strategy
resnet_model.fit(train_dataset, epochs=10, validation_data=val_dataset)

The model definition is the same as in the previous examples; only the strategy changes. CentralStorageStrategy can be useful when GPU memory is a concern, because the variables are stored once on the CPU instead of being duplicated on every GPU. The trade-off is that variables and gradients have to move between the CPU and the GPUs on every step, which can make training slower than with MirroredStrategy.
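To put the memory argument in numbers, here is a rough back-of-the-envelope sketch in plain Python. It assumes float32 variables and that the optimizer's slot variables (two per weight for Adam) are stored alongside the weights, and it ignores activations, gradients, and the temporary copies a GPU reads during each step. The 10-million-parameter model is hypothetical; this is an estimate, not a measurement of TensorFlow's actual allocations.

```python
# Rough estimate of where variable memory lives (plain Python, no TensorFlow).
# Counts only weights plus optimizer slots; activations, gradients and the
# temporary copies a GPU reads each step are ignored.

def variable_bytes(num_params, optimizer_slots=2, bytes_per_value=4):
    """float32 weights plus optimizer slot variables (Adam keeps two per weight)."""
    return num_params * (1 + optimizer_slots) * bytes_per_value

def variable_placement(num_params, num_gpus, strategy):
    """Bytes of variables stored on each device under the given strategy."""
    total = variable_bytes(num_params)
    gpus = [f'gpu:{i}' for i in range(num_gpus)]
    if strategy == 'mirrored':
        # One full copy per replica
        return {'cpu': 0, **{gpu: total for gpu in gpus}}
    if strategy == 'central_storage':
        if num_gpus == 1:
            # With a single GPU, the variables are placed on that GPU
            return {'cpu': 0, gpus[0]: total}
        # One copy on the CPU, none on the GPUs
        return {'cpu': total, **{gpu: 0 for gpu in gpus}}
    raise ValueError(f'unknown strategy: {strategy}')

# Hypothetical model with 10 million parameters on a 4-GPU machine
print(variable_placement(10_000_000, 4, 'mirrored'))         # every GPU holds 120000000 bytes
print(variable_placement(10_000_000, 4, 'central_storage'))  # the CPU holds 120000000 bytes, each GPU 0
```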

Asynchronous learning strategies

Asynchronous strategies let workers update the model parameters independently, without waiting for other workers to finish their computations. TensorFlow offers ParameterServerStrategy for asynchronous data-parallel training; it can also shard large variables across several parameter servers, which is a limited form of model parallelism.

ParameterServerStrategy

ParameterServerStrategy splits the cluster into parameter servers, which store the model variables, and workers, which compute gradients. In TensorFlow 2 there is also a coordinator (the "chief" task) that runs the training program and dispatches training steps to the workers. Each worker reads the latest parameters from the parameter servers, computes gradients on its own data, and sends the updates back to the parameter servers, without waiting for the other workers.

In this example, we train the same ResNet model using ParameterServerStrategy. Besides the strategy itself, we now need to describe the cluster and start a server on every worker and parameter server.
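To see what "asynchronous" means in practice, here is a deterministic plain-Python simulation (not TensorFlow's implementation). A parameter-server object holds w; each simulated worker pulls the current value, computes a gradient on its own data, and pushes it back whenever it finishes. Because one worker is slower in the invented event schedule, some of its gradients are computed against a value of w that others have already updated; the staleness log records how many updates happened in between.

```python
# Deterministic simulation of asynchronous parameter-server training (plain Python).
# Model: y_hat = w * x with squared error; the true weight is 2.0.

class ParameterServer:
    def __init__(self, w):
        self.w = w
        self.version = 0  # number of updates applied so far

    def pull(self):
        return self.w, self.version

    def push(self, gradient, lr):
        self.w -= lr * gradient  # applied immediately, without waiting for other workers
        self.version += 1

def gradient(w, x, y):
    return 2 * (w * x - y) * x

ps = ParameterServer(w=0.0)
data = {'worker0': (1.0, 2.0), 'worker1': (2.0, 4.0)}  # one (x, y) example per worker
snapshots = {}
staleness_log = []

# Invented event order: worker1 is slower, so worker0 pushes twice per worker1 push.
schedule = ['pull:worker0', 'pull:worker1', 'push:worker0', 'pull:worker0',
            'push:worker0', 'push:worker1'] * 20

for event in schedule:
    action, worker = event.split(':')
    if action == 'pull':
        snapshots[worker] = ps.pull()
    else:
        w_seen, version_seen = snapshots[worker]
        x, y = data[worker]
        staleness_log.append(ps.version - version_seen)  # updates missed since the pull
        ps.push(gradient(w_seen, x, y), lr=0.05)

print(round(ps.w, 4), max(staleness_log))  # 2.0 2
```

In real ParameterServerStrategy training, staleness comes from actual timing differences between machines; the sketch only shows why an asynchronous update can be based on an outdated w and still converge when the learning rate is small enough.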

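import os
import json
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, Add, MaxPooling2D, GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model

# Define the residual block and create_resnet_model functions as shown in the previous examples

# Cluster configuration: one chief (the coordinator), two parameter servers and four workers.
# Every machine uses the same 'cluster' dict; only 'task' differs.
cluster = {
    'chief': ['chief0.example.com:2222'],
    'ps': ['ps0.example.com:2222', 'ps1.example.com:2222'],
    'worker': ['worker0.example.com:2222', 'worker1.example.com:2222', 'worker2.example.com:2222', 'worker3.example.com:2222']
}
task_type = 'chief'  # 'worker' or 'ps' on the other machines
task_index = 0       # index of this task within its type
os.environ['TF_CONFIG'] = json.dumps({'cluster': cluster, 'task': {'type': task_type, 'index': task_index}})
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver(rpc_layer='grpc')

# Workers and parameter servers only start a server and wait for the coordinator
if cluster_resolver.task_type in ('worker', 'ps'):
    server = tf.distribute.Server(
        cluster_resolver.cluster_spec(),
        job_name=cluster_resolver.task_type,
        task_index=cluster_resolver.task_id,
        protocol=cluster_resolver.rpc_layer,
        start=True)
    server.join()

# The coordinator instantiates the ParameterServerStrategy and drives training
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)

# Create the ResNet model and compile it within the strategy scope
with strategy.scope():
    input_shape = (224, 224, 3)
    num_classes = 10
    resnet_model = create_resnet_model(input_shape, num_classes)
    resnet_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Model.fit with ParameterServerStrategy takes a DatasetCreator wrapping a function that
# builds the dataset on each worker, and it requires an explicit steps_per_epoch.
def dataset_fn(input_context):
    batch_size = input_context.get_per_replica_batch_size(64)
    # Placeholder data: one random (image, one-hot label) pair repeated forever
    image = tf.random.uniform(input_shape)
    label = tf.one_hot(0, num_classes)
    return tf.data.Dataset.from_tensors((image, label)).repeat().batch(batch_size).prefetch(2)

# Train the ResNet model using the strategy
resnet_model.fit(tf.keras.utils.experimental.DatasetCreator(dataset_fn), epochs=10, steps_per_epoch=100)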

The model architecture is the same as in the previous examples, but three things change. First, the cluster has three roles (a chief that coordinates, parameter servers, and workers), and every machine gets the same cluster description plus its own task type and index. Second, workers and parameter servers only start a tf.distribute.Server and wait, while the chief creates the ParameterServerStrategy, builds the model in its scope, and calls fit(). Third, Model.fit needs the input wrapped in a DatasetCreator and an explicit steps_per_epoch. Because workers do not wait for one another, this strategy scales to many workers and tolerates slow or failed ones, at the cost of gradients sometimes being computed on slightly stale parameters. Large variables, such as embedding tables, can also be sharded across the parameter servers by passing a variable_partitioner to the strategy.
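ParameterServerStrategy can split a large variable, such as an embedding table, into shards stored on different parameter servers. The plain-Python sketch below illustrates the idea by dividing a variable's rows into contiguous, nearly equal ranges and looking up which shard owns a given row. The even split is an illustrative assumption, not a copy of TensorFlow's partitioning rule; in TensorFlow the shard count comes from the partitioner you pass, and the strategy handles placement and lookups for you.

```python
def shard_rows(num_rows, num_shards):
    """Split num_rows into num_shards contiguous ranges whose sizes differ by at most 1."""
    base, extra = divmod(num_rows, num_shards)
    ranges, start = [], 0
    for i in range(num_shards):
        size = base + (1 if i < extra else 0)
        ranges.append((start, start + size))
        start += size
    return ranges

def shard_for_row(row, ranges):
    """Return the index of the shard (parameter server) that stores a given row."""
    for shard, (lo, hi) in enumerate(ranges):
        if lo <= row < hi:
            return shard
    raise IndexError(row)

ranges = shard_rows(10, 3)
print(ranges)                    # [(0, 4), (4, 7), (7, 10)]
print(shard_for_row(5, ranges))  # 1
```

In TensorFlow you would instead pass a partitioner object, for example tf.distribute.experimental.partitioners.FixedShardsPartitioner(num_shards=2), as the variable_partitioner argument of tf.distribute.experimental.ParameterServerStrategy.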

Choosing the Right Strategy

Selecting the most suitable distributed learning strategy in TensorFlow depends on various factors, including the scale of your deep learning tasks, the available hardware resources, and the communication overhead between devices or workers. Here are some guidelines to help you choose between synchronous and asynchronous strategies based on specific use cases:

  1. If you have a single machine with multiple GPUs, consider MirroredStrategy, as it provides data parallelism with little communication overhead, since the all-reduce runs between GPUs in the same machine.

  2. If you need to scale training across multiple machines, each with one or more devices, MultiWorkerMirroredStrategy is an excellent synchronous choice.

  3. If GPU memory on a single multi-GPU machine is tight, CentralStorageStrategy might be a suitable option, as it keeps a single copy of the variables on the CPU instead of one per GPU.

  4. If you want to scale to many machines, tolerate slow or failed workers, or shard very large variables across servers, and you can accept asynchronous (possibly stale) updates, ParameterServerStrategy is the asynchronous option.
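The guidelines above can be condensed into a small decision function. It is a hypothetical helper that only encodes this article's rules of thumb; it returns the name of a tf.distribute strategy class rather than creating it, so it runs without TensorFlow.

```python
def suggest_strategy(num_machines, gpus_per_machine, gpu_memory_tight=False, async_ok=False):
    """Encode the rules of thumb above; returns a tf.distribute class name (hypothetical helper)."""
    if num_machines < 1 or gpus_per_machine < 0:
        raise ValueError('need at least one machine and a non-negative GPU count')
    if num_machines > 1:
        # Several machines: asynchronous parameter-server training if stale updates are
        # acceptable, otherwise synchronous all-reduce across workers.
        return 'ParameterServerStrategy' if async_ok else 'MultiWorkerMirroredStrategy'
    if gpus_per_machine > 1:
        # One machine, several GPUs: central variables if GPU memory is tight, else mirror them.
        return 'CentralStorageStrategy' if gpu_memory_tight else 'MirroredStrategy'
    return 'single device: no distribution strategy needed'

for args, kwargs in [((1, 4), {}), ((1, 4), {'gpu_memory_tight': True}),
                     ((3, 8), {}), ((3, 8), {'async_ok': True})]:
    print(args, kwargs, '->', suggest_strategy(*args, **kwargs))
```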

In this article, we delved into distributed deep learning in TensorFlow, focusing on data parallelism. We examined the synchronous strategies MirroredStrategy, MultiWorkerMirroredStrategy, and CentralStorageStrategy, as well as the asynchronous ParameterServerStrategy, which can also shard large variables across servers. Through practical examples, we showed how to implement these strategies in TensorFlow and discussed the factors to consider when choosing the right one for your use case.

You now have a solid understanding of TensorFlow's distributed learning strategies and can confidently apply them to your projects. So, go ahead and explore the tf.distribute package, experiment with different strategies, and optimize your deep learning tasks.

Anton Emelianov

CTO (Chief Technology Officer)
