Multi-task Deep Learning

Introduction

Most machine learning projects only require a model to perform a single specific task. For example, suppose we were to classify emails as spam or not spam. We would develop a classifier such that for a given email it would return the appropriate label. But what if we needed to not only determine whether the email was spam, but also extract specific information from it, like any people or places mentioned? We could use two separate models: one for spam classification and another for extracting people and places from the email. But this may dramatically increase inference time, especially if both models are quite large, which is usually the case in deep learning. Modern solutions to natural language processing problems use large pre-trained transformers with smaller heads for specific tasks such as text classification. The vast majority of the weights in these models are found in the backbone transformer. So if we were to use two such models for the original problem, we would need the same input to pass through tens of millions of the same weights twice. Naturally, we may instead design a single model architecture that is capable of performing both tasks simultaneously.

In this post, we will develop a neural network capable of performing two tasks on the same input simultaneously, noting along the way what considerations the added complexity requires. We will then develop another model, this time using information from one head to inform the other. We'll conclude with some final thoughts on this exciting sub-field of machine learning and the value it provides to practical problems.

One Model, Many Tasks

Let's now design a neural network to address the problem outlined at the start of this post: given an email, can we 1) determine whether it is spam and 2) extract the names of any people or places mentioned in it? In terms of machine learning tasks, we can frame these two concerns as text classification and named entity recognition (by way of token classification), respectively. We'll assume in this example that the tasks are independent of one another, so we do not need to worry about feeding information from one head to the next.

We'll start by looking at what two separate models may look like.

[Diagram: two separate models, each with its own BERT backbone and a task-specific head]

In PyTorch these models may look like this:

# Separate models for email analysis
from torch import nn
from transformers import BertModel


class SpamClassifier(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.backbone = BertModel.from_pretrained("bert-base-uncased")
        self.classifier = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 2),
        )

    def forward(self, input_ids, attention_mask=None):
        # Sequence-level task: classify using the [CLS] token embedding
        outputs = self.backbone(input_ids, attention_mask=attention_mask)
        cls_embedding = outputs.last_hidden_state[:, 0]
        return self.classifier(cls_embedding)


class EmailNERModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.backbone = BertModel.from_pretrained("bert-base-uncased")
        self.classifier = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, input_ids, attention_mask=None):
        # Token-level task: classify every token embedding
        outputs = self.backbone(input_ids, attention_mask=attention_mask)
        return self.classifier(outputs.last_hidden_state)

Note the parameter count for these models: more than 99% of the parameters in each model sit in the backbone. Since both models accept the same input, email text in this case, and architecturally share the vast majority of their parameters, this is a prime situation in which to employ multi-task learning.
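
To see this for yourself, you can count the parameters directly. A quick sketch (the input size of 768 matches bert-base-uncased's hidden size; the hidden size of 256 is arbitrary):

# What fraction of the spam model's parameters live in the backbone?
model = SpamClassifier(input_size=768, hidden_size=256)
backbone_params = sum(p.numel() for p in model.backbone.parameters())
total_params = sum(p.numel() for p in model.parameters())
print(f"Backbone share of parameters: {backbone_params / total_params:.2%}")  # well over 99%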

Reduced latency, simplified management, and potentially improved accuracy are the three most significant practical benefits of multi-task learning. The reduction in latency comes directly from deduplicating parameters: the fewer parameters a model has, the less computation is required at inference time. By collapsing the two models into one, we also make training and deployment easier since there is one less model to manage. Accuracy may improve as well: since we are giving one model more information about our problem by way of additional labels, it's possible to find better overall weights for the model.

For most fine-tuning problems, the simplest implementation of a multi-task learning model is to have a single backbone with multiple heads, where each head corresponds to a specific task, like in the diagram below:

[Diagram: a single shared BERT backbone feeding two task-specific heads]

The corresponding PyTorch code would look like this:

# Multi-task model: one shared backbone, one head per task
class MultiTaskModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.backbone = BertModel.from_pretrained("bert-base-uncased")
        self.classifier = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 2),
        )
        self.ner = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, input_ids, attention_mask=None):
        outputs = self.backbone(input_ids, attention_mask=attention_mask)
        # Spam head sees the [CLS] embedding; NER head sees every token embedding
        spam_logits = self.classifier(outputs.last_hidden_state[:, 0])
        ner_logits = self.ner(outputs.last_hidden_state)
        return spam_logits, ner_logits

While this model has more parameters than either of the individual models above, it has significantly fewer than the two separate models combined. Below is a table comparing the parameter counts and memory footprints of the two configurations:

[Table: parameter count and memory footprint of the two separate models vs. the single multi-task model]
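
Those numbers are straightforward to compute for your own configuration. Here is a minimal sketch, assuming the classes defined above, fp32 weights, and illustrative sizes (a 768-dimensional backbone output, a 256-unit hidden layer, and nine NER labels):

# Parameter and memory comparison: two separate models vs. one multi-task model
separate_models = [SpamClassifier(768, 256), EmailNERModel(768, 256, num_classes=9)]
multi_task = MultiTaskModel(768, 256, num_classes=9)

separate_params = sum(p.numel() for m in separate_models for p in m.parameters())
multi_task_params = sum(p.numel() for p in multi_task.parameters())

# fp32 weights take 4 bytes per parameter
print(f"Two separate models: {separate_params:,} params (~{separate_params * 4 / 1e6:.0f} MB)")
print(f"Multi-task model:    {multi_task_params:,} params (~{multi_task_params * 4 / 1e6:.0f} MB)")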

Taming the Hydra

Multi-task learning does not come without its challenges. Certain decisions, particularly around training, must be made with care in order to reap the full benefit of these models: their overall efficacy is a direct result of seemingly small choices made during training.

The two most important decisions when developing a multi-task model are 1) which weights to unfreeze during training and 2) how to properly formulate the loss function. In fact, how you handle the first decision directly influences the second. Most multi-task architectures like the ones we're considering really only present two choices on which weights to freeze during training.

The first option is to freeze the backbone and only update the weights in each head during training. This is a great option when the pre-trained backbone is already suitable for your task and can simply function as a feature extractor for the individual heads. It also greatly reduces the number of trainable parameters, so training will be relatively quick. If you keep the backbone frozen, the individual task heads can be trained independently of each other: the training loop is essentially that of training two smaller models whose inputs are embedded by the same pre-trained network, with no real concern about one head affecting the other. The obvious downside, however, is that if the pre-trained backbone is not sufficiently related to the domain you are fine-tuning for, then freezing the vast majority of the weights limits the model's ability to fit the data. Below is a brief sketch of the training loop for this first configuration:

# Frozen backbone: each head is trained independently on the shared embeddings
for inputs, spam_labels, ner_labels in dataset:
    with torch.no_grad():  # the backbone acts as a fixed feature extractor
        embedded_input = backbone(inputs).last_hidden_state
    # Task 1: spam classification on the [CLS] embedding
    y_hat_task_1 = head1(embedded_input[:, 0])
    task_1_loss = cross_entropy(y_hat_task_1, spam_labels)
    task_1_loss.backward()
    head1_optimizer.step()
    head1_optimizer.zero_grad()
    # Task 2: token-level NER on every token embedding
    y_hat_task_2 = head2(embedded_input)
    task_2_loss = cross_entropy(y_hat_task_2.transpose(1, 2), ner_labels)
    task_2_loss.backward()
    head2_optimizer.step()
    head2_optimizer.zero_grad()

The other option is to unfreeze all of the weights of the network so that every parameter is updated during training. This allows the model to better fit the data, but presents a new problem: how to balance the losses. When all of the weights are unfrozen, the two heads are no longer independent, since the loss from one may change the backbone weights, which in turn changes the input to the other head. There are a couple of ways to address this. The most naive solution is to treat the losses of the two heads equally and backpropagate on their sum. Let's look at this specifically in the context of our email problem. The basic task of both the spam classification and named entity recognition heads is classification, so we can use the cross entropy loss for both.
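
Here is a minimal sketch of that naive joint training step, assuming the MultiTaskModel defined earlier and a dataloader that yields input ids, an attention mask, and both tasks' labels for each batch (the dataloader and the hyper-parameter values are illustrative):

import torch
import torch.nn.functional as F

model = MultiTaskModel(input_size=768, hidden_size=256, num_classes=9)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)  # every weight is trainable

for input_ids, attention_mask, spam_labels, ner_labels in dataloader:
    spam_logits, ner_logits = model(input_ids, attention_mask)
    spam_loss = F.cross_entropy(spam_logits, spam_labels)
    # Token-level loss: cross_entropy expects logits shaped (batch, classes, seq_len)
    ner_loss = F.cross_entropy(ner_logits.transpose(1, 2), ner_labels)
    loss = spam_loss + ner_loss  # equal weighting of the two tasks
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()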

This naive approach may work if the following are true:

  1. The loss functions are on the same scale. We don't want one loss to naturally dominate the other, since the model would then be encouraged to largely ignore the weaker head. In this case we are using cross entropy for both heads, which may suggest that neither loss will dominate, but recall that cross entropy has a lower bound of zero and no finite upper bound, and its typical magnitude grows with the number of labels (for example, an untrained two-way spam head starts near ln 2 ≈ 0.69, while a head with nine NER labels starts near ln 9 ≈ 2.2 per token).

  2. The two heads are of equal importance. This is usually the case unless other complexities come into play. For example, if the output of one head should inform another, then it is more important that the first head be as accurate as possible so as not to propagate errors.

A more sophisticated approach that mitigates the issues outlined above is to introduce a hyper-parameter to balance the two losses against each other. For two heads, as in the email problem, we only need a single hyper-parameter:

loss_total = loss_spam + λ · loss_NER

Introducing a new hyper-parameter is never ideal since it requires careful tuning, but if you understand the bounds of your loss functions then it is usually clear what range to search over. The lambda parameter plays two important roles in the loss function above: it scales the second loss relative to the first, and in doing so it represents the importance of one task relative to the other.
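
In code, the only change from the naive training step sketched above is the weighting on the second term:

lam = 0.5  # balances the NER loss against the spam loss; tune on a validation set
loss = spam_loss + lam * ner_loss
loss.backward()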

The choice of formulation depends largely on the nature of the data. In general, if a good backbone model already embeds the inputs sufficiently well, then keeping the backbone frozen and updating the task heads individually is by far the easiest option. But if the backbone also needs to be tuned to the data to achieve the desired results, then introducing the lambda hyper-parameter into the combined loss function will prove easier than trying to determine a static scaling factor by hand.

Wrapping Up

We started with two distinct models containing millions of duplicated parameters and finished with a single model capable of performing multiple tasks, potentially with better accuracy and only a small trade-off in training difficulty. Multi-task learning may not be a standard tool in the data scientist's toolbox, but I hope to have shown that it deserves to be. Machine learning products are proliferating and are often asked to extract many data points from the same input, and multi-task models offer a natural solution.
