Image labeling vs classification models
Comparing the loss functions of label and classification models
- Image Classification vs Image Labeling
- CrossEntropyLoss vs BCEWithLogitsLoss
- Exploring CrossEntropyLoss
- Exploring BCEWithLogitsLoss
- Discussion
The 'hello world' example for introducing deep learning based computer vision often involves classifying images as 🐶 or 🐱. An alternative approach to classifying images is to instead apply labels. This is usually introduced in the context of multi-label classification i.e. where an image can have more than one label. In this blog post I discuss some of the differences between these two approaches, specifically the difference in loss functions, and how these two approaches might work better depending on the application. The post starts with a conceptual overview of the differences between these two approaches, before showing the different loss functions and then moving to a practical example of training these two different types of model.
In a classification model, an input can have only one label. Whether there are a few potential classes or a hundred, it is assumed that the input belongs to exactly one of them. With a model that applies labels this is not true: an input can have one, multiple, or no labels.
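As a rough, purely illustrative sketch (plain Python, with made-up class and label names), the difference shows up in how the targets are encoded: a classification target is a single class index, while a labeling target is a binary vector with one slot per possible label.
# classification: one class index per input (exactly one class applies)
classification_targets = [0, 2, 1]   # e.g. 0='cat', 1='dog', 2='horse'
# labeling: one binary vector per input (any number of labels can apply)
labeling_targets = [
    [1, 0, 0],   # only the first label applies
    [1, 0, 1],   # two labels apply
    [0, 0, 0],   # no labels apply
]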
Sorting through family photos
We can use an analogy to illustrate the difference between these two approaches. Let's say you were sorting through some old family photographs. You might "classify" the photos into one (and only one) of two photo albums, depending on whether they are black-and-white or colour. This would be comparable to using a classification model since each photo will go into exactly one of these two albums - a photo cannot be both simultaneously colour and black-and-white, and it cannot be neither colour nor black-and-white.
You may at the same time also want to make it easier to find photos of particular people in your family. You could do this by assigning labels to each photo, indicating or "tagging" the family members who appear in the photo. In this case, a photo may have one label (a photo of your sister), more than one label (a photo of your sister and aunt), or it may have no labels (a photograph of a landscape taken on a holiday). This would be analogous to a multi-label classification model.
The choice between using a model which performs classification or a model which assigns labels should be considered in relation to the role your model has. It is also useful to look a little more closely at how these different types of models work under the hood.
CrossEntropyLoss vs BCEWithLogitsLoss
When we create a model which does classification or applies labels, the main distinction, if we are using the same data, is that they use different loss functions.
A classification model will use a variant of Cross Entropy Loss, whilst the label model will use BCE with Logits Loss. We'll see how this is inferred by fastai below, but for now take my word for it...
Let's take a look at a snippet of the Pytorch docs for each of these loss functions
CrossEntropyLoss
This criterion combines nn.LogSoftmax() and nn.NLLLoss() in one single class. It is useful when training a classification problem with C classes. If provided, the optional argument weight should be a 1D Tensor assigning weight to each of the classes. This is particularly useful when you have an unbalanced training set. Read more
BCEWithLogitsLoss
This loss combines a Sigmoid layer and the BCELoss in one single class. This version is more numerically stable than using a plain Sigmoid followed by a BCELoss as, by combining the operations into one layer, we take advantage of the log-sum-exp trick for numerical stability. Read more
Let's see what these do to some activations. First we'll import required packages
import torch.nn as nn
import numpy as np
import torch
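With those imported, we can quickly sanity-check the 'combines two operations in one class' claims from the docs above, using some random values (the shapes and targets here are arbitrary, just for illustration):
logits = torch.randn(4, 3)
class_targets = torch.tensor([0, 2, 1, 1])
label_targets = torch.tensor([[1., 0., 1.], [0., 0., 0.], [1., 1., 0.], [0., 1., 0.]])
# CrossEntropyLoss vs LogSoftmax followed by NLLLoss
ce = nn.CrossEntropyLoss()(logits, class_targets)
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), class_targets)
# BCEWithLogitsLoss vs sigmoid followed by BCELoss
bce_logits = nn.BCEWithLogitsLoss()(logits, label_targets)
bce = nn.BCELoss()(torch.sigmoid(logits), label_targets)
torch.isclose(ce, nll), torch.isclose(bce_logits, bce)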
We can create some fake activations. To keep things simple for now, we'll just consider one output with three classes.
one_act = torch.randn((1, 3))
one_act
We would like to interpret these activations as probabilities for each of the three classes. Let's see what they sum to.
one_act.sum()
We can see that these activations don't sum to 1. If we want our image input to belong to only one class, then the labels are mutually exclusive of each other, i.e. if one label's probability is higher, another needs to be lower, and the probabilities need to add up to 1. Going back to the PyTorch explanation of CrossEntropyLoss, we see that one component is nn.LogSoftmax(). What is particularly relevant here is the 'softmax' part. Let's see what this does to our activations.
softmax_acts = torch.softmax(one_act, dim=1)
softmax_acts
You can probably already see how this has changed the nature of these activations. Let's call sum on these outputs again.
softmax_acts.sum()
We now have a sum of 1! We can now treat this as the probability of an input image belonging to a particular class. We could then call argmax to find out which class the model is most confident about and use that as our prediction.
softmax_acts.argmax(dim=1)
One potential issue with using a classification model is that it doesn't account for ambiguity in the labels very well.
What is softmax doing?
Digging into what softmax does in a little bit more detail will show what is going on here.
First let's see what softmax actually does. I'll skip the LaTeX formula from Wikipedia because it makes it look much scarier than the Python code example:
a = [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0]
np.exp(a) / np.sum(np.exp(a))
This is much easier for me to parse compared to the Greek. Let's look at the different parts. Working with one set of activations again:
one_act
Starting from np.exp(a), we can do this in PyTorch like:
one_act.exp()
We can convert the rest of the numpy code as follows
one_act.exp().sum(dim=1)
Putting it all together we get
(one_act.exp() / one_act.exp().sum(dim=1)).sum(dim=1)
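We can also sanity-check that this hand-rolled version matches the built-in torch.softmax we used earlier:
# the manual exp-and-normalise version should match torch.softmax
torch.allclose(one_act.exp() / one_act.exp().sum(dim=1), softmax_acts)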
This seems to work as expected, i.e. we get probabilities that sum to 1. To make it clearer what's going on though, it's useful to look a little more closely at the difference that using exp makes. Let's import the standard Python version of exp and check the docs.
from math import exp
doc(exp)
What difference does using the exponent make? We'll use a small array of values to keep things simple.
x = np.array([1,2,4,1])
x
Now if we want these to be converted to probabilities for different classes we need them to sum to 1. We could just do this by dividing each element by the sum.
x/x.sum()
We can confirm these add to 1.
(x/x.sum()).sum()
Now this seems to work to get us probabilities for each class. Let's compare doing the same thing but using exp to create exponents of the inputs.
np.exp(x)/np.sum(np.exp(x))
Again we get an array of probabilities; let's confirm they add to one.
(np.exp(x) / np.sum(np.exp(x))).sum()
Now let's compare the two sets of probabilities side by side:
np.exp(x) / np.sum(np.exp(x)), x / x.sum()
Other than the difference in the number of decimal places shown, you will probably notice that when we use the exponent, the probability for one class gets pushed much higher. Index 2 is about 0.8 when we use exp and only 0.5 when we don't use the exponent. This is the important difference here. By using the magic properties of $e$ we 'push' one probability to be higher than the others.
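To make this 'pushing' effect a bit more concrete, here is a small illustrative sketch (the values are made up): as the gap between two raw activations grows, softmax concentrates the probability mass far more aggressively than plain normalisation does.
# compare plain normalisation with softmax as the gap between two activations grows
for gap in [1, 2, 4]:
    acts = np.array([1.0, 1.0 + gap])
    print(gap, acts / acts.sum(), np.exp(acts) / np.exp(acts).sum())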
This property is useful when we have a clear distinction between classes. If we were predicting handwritten digits there (should) only be one correct answer. In this case having one class prediction being pushed much higher would be a good thing.
If, however, we have labels which are more ambiguous, this would be a less desirable property. Even if we try to capture ambiguity by using the raw probabilities of the labels rather than taking the argmax value, the numerical properties of the softmax function mean that it is likely that one label value will be pushed higher than the others.
We'll look at a practical example later on to illustrate this. Let's now quickly compare our other loss function, starting again from the same activations.
one_act
As a reminder, the sigmoid function can be plotted as an S-shaped curve.
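A quick sketch to reproduce the plot, assuming matplotlib is available:
import matplotlib.pyplot as plt
# plot sigmoid over a range of input values
xs = torch.linspace(-6, 6, 100)
plt.plot(xs, torch.sigmoid(xs))
plt.title('sigmoid')
plt.show()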
You'll probably be familiar with sigmoid as one of the potential activation functions you can use in a neural network. The property we care about here is that it squishes inputs into a value between 0 and 1. Let's do this for our activations.
torch.sigmoid(one_act)
We can see that all our values have been pushed between 0 and 1. However, we can also see they don't sum to 1.
torch.sigmoid(one_act).sum()
What we have here is a probability for each label which is independent of the probabilities of the other labels. The sigmoid function makes sure the activation for each label becomes a probability, but it doesn't make sure that all of the label probabilities sum to 1. Looking at a practical example using fastai might illustrate this difference.
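Before we do, note one practical consequence: with independent per-label probabilities, predictions usually come from thresholding each label separately rather than taking an argmax. A minimal sketch (the 0.5 threshold is just an arbitrary choice here):
# each label is predicted independently if its probability clears the threshold
torch.sigmoid(one_act) > 0.5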
We'll work with some images taken from 19th century books; the specific images in this case don't matter too much.
We'll import fastai and then put images from two folders 'building' and 'coat' into a Pandas DataFrame.
from fastai.vision.all import *
files = get_image_files('data/cv_workshop_exercise_data/', folders=['building', 'coat'])
df = pd.DataFrame(files.items, columns=['fname'])
df['class_label'] = df['fname'].apply(lambda x: x.parts[2])
df['class_label'].value_counts()
We can see we have two possible classes, building and coat. First we'll load these into fastai as a classification model.
dls_classification = ImageDataLoaders.from_df(df, fn_col='fname', valid_pct=0.4, label_col='class_label', item_tfms=Resize(128, ResizeMethod.Squish), bs=8, num_workers=0)
dls_classification.show_batch()
You'll see that building refers to a building, whilst coat refers to a coat of arms. Let's now create a learner with this data.
learn = cnn_learner(dls_classification, resnet18, metrics=[accuracy, F1Score()])
Often if we pass fastai a DataLoaders it will be able to infer the correct loss function based on this data. We can access this using the loss_func attribute.
learn.loss_func
As promised this is a variant on the CrossEntropyLoss we saw earlier. Let's now fit it for a bit.
learn.fit(5)
Now that we have a model, we'll grab the predictions.
acts, _ = learn.get_preds()
acts
These are the predicted probabilities for each class; let's confirm they all sum to 1.
acts.sum(dim=1)
If we look at the maximum probability for each prediction, we'll see they tend to be high.
acts.max(dim=1)[0]
Looking at the mean, max and min:
acts.max(dim=1)[0].mean(), acts.max(dim=1)[0].max(), acts.max(dim=1)[0].min()
This is desirable if the input we are trying to label does neatly fit the categories, but if we are trying to label something which is more ambiguous then this might be less useful. A particular case where this certainty might not be so helpful is when your model may face out-of-domain images, i.e. see things it hasn't seen before and for which none of the classes it is trying to predict should apply. Let's load a new dataset of images of people.
people = get_image_files('data/cv_workshop_exercise_data/', folders='people')
people
PILImage.create(people[5])