Image Classification with Pre-trained Computer Vision Models in PyTorch


How good are state-of-the-art Computer Vision models at identifying my cat Ada in a picture? This is, more or less, the question that prompted the exploration presented here. In this article, we will look at a practical example of using the pretrained computer vision models available in PyTorch for inference, to do image classification. In particular, we will use the ResNet model to classify a number of pictures of pets. Note that we will not be performing any type of training (such as finetuning the pretrained model on our own task); instead, we will output one of the classes the model already knows. We will go through each step of the process: pre-processing the image, retrieving predictions from the pretrained model, and analyzing the output. Of course, you can use the same methodology to predict on your own images. The images we want to classify are:

...
...
...
...
...
...


Alright, let's get started and see how we can use ResNet to classify these lovely pets, and also, examine how good the predictions are!

Image Classification

But first, let's briefly cover the task of Image Classification. Image Classification answers the following question:
Given an image -> what does that image represent?
How a model answers this question depends on the type and amount of data the model has been trained on. There are a number of datasets that are used as benchmarks for Image Classification, such as MNIST (for digit recognition), CIFAR-10/100, or ImageNet. For more information about datasets, benchmarks and state-of-the-art models for Image Classification, I recommend checking out the Papers With Code website.

ResNet, the model we will be using in this article, is trained on the ImageNet dataset, which contains 1000 classes, including a number of cat and dog breeds. Metrics of more models on this dataset can be found here. Note that while there are models that perform better than ResNet on ImageNet, the ResNet architecture performs well enough that it should still give good results. The PyTorch implementation of ResNet with 152 layers (which we will be using) has a top-1 error of 21.69% and a top-5 error of 5.94%. The top-1 error is the share of images where the most probable class is wrong; the top-5 error is the share of images where the correct class is not among the five most probable classes.
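
To make the two error figures concrete, here is a minimal sketch of how a top-k error can be computed. The scores and labels are made-up toy values over six hypothetical classes, and topk_error is a helper name introduced here for illustration, not a torchvision function.

```python
import numpy as np

def topk_error(scores, true_labels, k):
    """Fraction of samples whose true label is NOT among the k highest-scoring classes."""
    scores = np.asarray(scores)
    true_labels = np.asarray(true_labels)
    # Indices of the k largest scores in each row (their order does not matter here).
    topk = np.argsort(scores, axis=1)[:, -k:]
    hit = (topk == true_labels[:, None]).any(axis=1)
    return 1.0 - hit.mean()

# Toy scores for 4 samples over 6 hypothetical classes (made-up numbers).
scores = [
    [0.10, 0.60, 0.12, 0.08, 0.06, 0.04],  # true class 1: best guess is right
    [0.30, 0.40, 0.12, 0.08, 0.06, 0.04],  # true class 0: second-best guess
    [0.02, 0.03, 0.15, 0.20, 0.50, 0.10],  # true class 2: third-best guess
    [0.40, 0.30, 0.15, 0.10, 0.04, 0.01],  # true class 5: the least likely class
]
true_labels = [1, 0, 2, 5]

top1 = topk_error(scores, true_labels, k=1)
top5 = topk_error(scores, true_labels, k=5)
print(f"top-1 error: {top1:.2f}, top-5 error: {top5:.2f}")
```

With these toy numbers, only one of the four samples is right on the first guess, while three of the four have the true class somewhere among the five best guesses.
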
Okay, so let's get started with using the ResNet model for inference using the Pytorch implementation!

Imports

First, we need to import a number of packages. We need the torch and torchvision packages to get the pretrained models, numpy to transform the images, and matplotlib to display the images.

  %matplotlib inline 
  import torch
  import torchvision
  import torchvision.transforms as transforms
  import matplotlib.pyplot as plt
  import numpy as np

Initialize pretrained model

Now, we initialize the pretrained computer vision model we will be using: ResNet. Using PyTorch, we can do so in only one line of code, as shown below. ResNet is a state-of-the-art deep learning model trained on ImageNet. For more details about the model itself, you can visit its webpage or read the corresponding paper.

Things to note here:
  1. While we are using ResNet, PyTorch offers a comprehensive list of other computer vision models you can load and experiment with. The entire list of models can be found here, under 'Vision'. They are all loaded in the same way. Some examples you can experiment with include VGGNet, Inception or AlexNet.
  2. ResNet comes in 5 different versions, with 18, 34, 50, 101 and 152 layers. Note that here, we load the biggest one, which contains 152 layers.
  3. We need to make sure to call model.eval() so that the model is in evaluation mode, and not training mode.

  model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet152', pretrained=True)
  model.eval()
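
To see what model.eval() changes, here is a small sketch on a toy network rather than on ResNet itself. The toy uses a Dropout layer as a stand-in because its train/eval difference is easy to observe; in ResNet, the layers affected by the mode are the batch normalization layers. The names toy and x are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in network (not ResNet): Dropout behaves differently in train and eval mode.
toy = nn.Sequential(nn.Linear(8, 64), nn.Dropout(p=0.5), nn.Linear(64, 3))
x = torch.randn(1, 8)

toy.train()  # training mode: dropout randomly zeroes activations on every call
train_a, train_b = toy(x), toy(x)

toy.eval()   # evaluation mode: dropout is disabled, so the output is deterministic
with torch.no_grad():
    eval_a, eval_b = toy(x), toy(x)

print("train mode outputs identical:", torch.equal(train_a, train_b))
print("eval mode outputs identical: ", torch.equal(eval_a, eval_b))
```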

Upload a photo

Great, now that we have initialized the model, we can load a photo. We will be using the PIL library to do this, and numpy to load the pixels.

  from PIL import Image
  import numpy as np

  # create an image object and read its pixels into a NumPy array
  im = Image.open("/content/ada.jpeg")
  pixels = np.asarray(im)

To display the picture, we can simply run:

  display(im)
...
This is the un-processed version of the image. We can look at the initial dimension of the image.

  pixels.shape
Output: (4032, 3024, 3)

Transform Image

Note that, according to the PyTorch page for ResNet, the model expects images in the format (3 x H x W), where H and W are expected to be at least 224. The pixel values have to be scaled to the range [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]. Consequently, we need to transform the image to follow this format and also load the data into a tensor. We follow the methodology suggested by PyTorch:

  def resnet_transform_images(mu, std):
      return transforms.Compose([
          transforms.Resize(256),
          transforms.CenterCrop(224),
          transforms.ToTensor(),
          transforms.Normalize(mean=mu, std=std),
      ])

  def imshow(images, mu, std_var):
    # arrange the batch into a single (C, H, W) image grid
    npimg = torchvision.utils.make_grid(images).numpy()
    # matplotlib expects (H, W, C)
    npimg_transp = np.transpose(npimg, (1, 2, 0))
    npimg_transp = npimg_transp * std_var + mu     # unnormalize
    plt.imshow((npimg_transp * 255).astype(np.uint8))

  mu = [0.485, 0.456, 0.406]
  std = [0.229, 0.224, 0.225]
  img = resnet_transform_images(mu, std)(im)
  img_tensor = img.unsqueeze(0)

  img_tensor.shape
Output: torch.Size([1, 3, 224, 224])
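
To make the last three steps less of a black box, here is a rough NumPy sketch of what ToTensor, Normalize and unsqueeze(0) do to the pixel values. It runs on a random stand-in array instead of the real photo, and it only mirrors the arithmetic; it is not how torchvision implements these transforms.

```python
import numpy as np

mu = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

# Stand-in for an already resized and cropped RGB image: (H, W, C), uint8 in [0, 255].
rng = np.random.default_rng(0)
pixels = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)

# ToTensor: scale to [0, 1] and move channels first, (H, W, C) -> (C, H, W).
chw = np.transpose(pixels.astype(np.float32) / 255.0, (2, 0, 1))

# Normalize: per-channel (x - mean) / std, broadcast over H and W.
normalized = (chw - mu[:, None, None]) / std[:, None, None]

# unsqueeze(0): add a leading batch dimension.
batch = normalized[None, ...]
print(batch.shape)  # (1, 3, 224, 224)

# Inverse, as needed for display: undo the normalization and go back to (H, W, C).
restored = np.transpose(normalized, (1, 2, 0)) * std + mu
restored_uint8 = np.clip(np.rint(restored * 255), 0, 255).astype(np.uint8)
print(np.array_equal(restored_uint8, pixels))
```

The second half is the same un-normalization that the imshow helper applies before plotting.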

We can now see that the dimension of the processed image is (1, 3, 224, 224), which will fit into the ResNet model. That also means that the image has been cropped, and, as we can see, re-centered:

  imshow(img_tensor, mu, std)
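
How much of the photo survives the crop? The arithmetic behind Resize(256) followed by CenterCrop(224) can be sketched in plain Python. The function name below is made up for this sketch, and torchvision's own rounding may differ by a pixel, so treat the numbers as approximate.

```python
def resize_then_center_crop(height, width, resize_to=256, crop=224):
    """Return the resized (H, W), the crop's (top, left) offset, and the kept area fraction."""
    short, long = min(height, width), max(height, width)
    # Resize(256) with a single integer scales the shorter side to 256, keeping the aspect ratio.
    new_long = int(resize_to * long / short)
    new_h, new_w = (new_long, resize_to) if height >= width else (resize_to, new_long)
    # CenterCrop(224) keeps a 224 x 224 window in the middle of the resized image.
    top, left = (new_h - crop) // 2, (new_w - crop) // 2
    kept_fraction = (crop * crop) / (new_h * new_w)
    return (new_h, new_w), (top, left), kept_fraction

# The photo above has shape (4032, 3024, 3), i.e. height 4032 and width 3024.
resized, offset, kept = resize_then_center_crop(4032, 3024)
print(resized, offset, round(kept, 3))  # (341, 256) (58, 16) 0.575
```

For this photo, the crop keeps a bit more than half of the resized image, so anything close to the top or bottom edge of the original is likely to be cut off.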

Get Predictions

Now that we have an image that can fit into the ResNet model, let's use the model to do inference and get the predicted classes.
First, let's download the classes used by ImageNet. We will need these in order to map the classes the model outputs for an image (as numbers) to the descriptive labels they correspond to. Note that this command will download a file named imagenet_classes.txt into the current folder.

  !wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

Now let's read this file to obtain the classes ResNet was trained to predict. Note that these are the classes coming from ImageNet.

  with open("imagenet_classes.txt", "r") as f:
    categories = np.array([s.strip() for s in f.readlines()])

  len(categories)
Output: 1000

We can see that we have a total of 1000 classes. Now let's print some sample classes - here are 10 samples:

  np.random.choice(categories, 10, replace=False)
Output: array(['Boston bull', 'Cardigan', 'Labrador retriever', 'velvet', 'leopard', 'gar', 'hand-held computer', 'running shoe', 'apron', 'red-backed sandpiper'])

Now let's see what our model predicts on our initial image!

  def predict(images, model):
    with torch.no_grad():
      outputs = model(images)
      probs =  torch.nn.functional.softmax(outputs, dim=1)
      probs_labels, pred_labels = torch.topk(probs, 5)

      probs_labels = probs_labels.squeeze().numpy()
      pred_labels = pred_labels.squeeze().numpy()
    return pred_labels, probs_labels

  print('Predictions')
  print('=' * 30)
  preds, probs = predict(img_tensor, model)
  for pred, prob in zip(preds, probs):
    # %f prints six decimal places
    print('Category: %s, Probability: %f' % (categories[pred], prob))

Predictions
==============================
Category: tabby, Probability: 0.470609
Category: Egyptian cat, Probability: 0.406148
Category: tiger cat, Probability: 0.081435
Category: photocopier, Probability: 0.006049
Category: printer, Probability: 0.004371
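
If the softmax and top-k steps inside predict feel opaque, here is the same idea in plain NumPy. The logits are invented for illustration (they are not the real ResNet outputs for this picture), and the label list is just a handful of class names from this article.

```python
import numpy as np

def softmax(logits):
    # Subtracting the maximum keeps the exponentials numerically stable.
    exp = np.exp(logits - logits.max())
    return exp / exp.sum()

# Hypothetical raw scores (logits) for six classes; made-up values.
labels = np.array(["tabby", "Egyptian cat", "tiger cat", "photocopier", "printer", "laptop"])
logits = np.array([4.0, 3.8, 2.2, -0.4, -0.7, -1.5])

probs = softmax(logits)               # non-negative values that sum to 1
top = np.argsort(probs)[::-1][:3]     # indices of the 3 largest probabilities

for i in top:
    print(f"Category: {labels[i]}, Probability: {probs[i]:.6f}")
```

The model's raw outputs can be any real numbers; softmax only rescales them so they can be read as probabilities, and the ranking of the classes stays the same.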

It looks like for this image, the top choices are different cat breeds, so the model is working pretty well! While my cat Ada is not a tabby or an Egyptian cat, and only metaphorically a tiger cat, the fact that the top 3 predicted classes are cat-related is encouraging. Let's try predicting a few more pictures!
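
When predicting on several pictures, one option is to stack the pre-processed images into a single batch and run the model once. The sketch below assumes this approach; to stay runnable without downloading weights, it uses random tensors in place of real images and a tiny stand-in classifier with random weights in place of ResNet, so its predictions are meaningless and only the shapes matter.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in classifier with random weights (an assumption for this sketch, not ResNet):
# it accepts (N, 3, 224, 224) input and returns (N, 1000) scores.
stand_in_model = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(3, 1000))
stand_in_model.eval()

# Three random tensors standing in for three pre-processed images of shape (3, 224, 224).
images = [torch.randn(3, 224, 224) for _ in range(3)]
batch = torch.stack(images)  # shape (3, 3, 224, 224)

with torch.no_grad():
    probs = torch.nn.functional.softmax(stand_in_model(batch), dim=1)
    top_probs, top_idx = torch.topk(probs, 5, dim=1)

# One row of 5 class indices (and 5 probabilities) per image.
print(batch.shape, top_idx.shape)
```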
...
Predictions
=================
Category: laptop, Probability: 0.504977
Category: notebook, Probability: 0.246912
Category: computer keyboard, Probability: 0.129945
Category: space bar, Probability: 0.049168
Category: mouse, Probability: 0.020669
...
Predictions
=================
Category: Egyptian cat, Probability: 0.220818
Category: tabby, Probability: 0.093764
Category: tiger cat, Probability: 0.043862
Category: purse, Probability: 0.037481
Category: sleeping bag, Probability: 0.034918
...
Predictions
=================
Category: Egyptian cat, Probability: 0.935708
Category: tiger cat, Probability: 0.031417
Category: tabby, Probability: 0.030745
Category: lynx, Probability: 0.001550
Category: remote control, Probability: 0.000174
...
Predictions
=================
Category: tabby, Probability: 0.467902
Category: Egyptian cat, Probability: 0.423090
Category: tiger cat, Probability: 0.049405
Category: quilt, Probability: 0.017489
Category: remote control, Probability: 0.006777
...
Predictions
=================
Category: Egyptian cat, Probability: 0.930395
Category: Siamese cat, Probability: 0.026431
Category: tabby, Probability: 0.007882
Category: quilt, Probability: 0.007549
Category: tiger cat, Probability: 0.004132
...
Predictions
=================
Category: meerkat, Probability: 0.547480
Category: mongoose, Probability: 0.138101
Category: baboon, Probability: 0.104021
Category: fox squirrel, Probability: 0.103073
Category: macaque, Probability: 0.022226

Analysis

Correct predictions: Okay, so we can see that the predictions are largely in the correct ballpark. For four of the six pictures (the second through the fifth), the model outputs different cat breeds, with 'Egyptian cat' and 'tabby' as the most commonly predicted ones. While cat connoisseurs might disagree that these are indeed 'Egyptian' or 'tabby' cats, it is possible that these are the closest classes available to the model. If we wanted to know all the cat breeds the model could have output, we would have to search through the categories for cat breeds, to see what the model might have missed.
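
Such a search through the categories could be sketched as a simple keyword filter. The sample_categories list below is only a handful of labels that appear in this article, standing in for the full categories array, and the keyword list is my own guess at what counts as cat-related.

```python
def find_labels(categories, keywords):
    """Return the labels that contain any of the keywords (case-insensitive substring match)."""
    keywords = [k.lower() for k in keywords]
    return [name for name in categories if any(k in name.lower() for k in keywords)]

# Small stand-in for the full list of 1000 labels (only labels seen in this article).
sample_categories = [
    "tabby", "tiger cat", "Egyptian cat", "Siamese cat", "lynx",
    "laptop", "meerkat", "fox squirrel", "quilt",
]
# Hypothetical keyword list; adjust it to whatever you consider cat-related.
cat_keywords = ["cat", "tabby", "lynx"]

matches = find_labels(sample_categories, cat_keywords)
print(matches)
```

On the full list, a substring match like this also returns any label that merely contains the letters of a keyword, so the result needs a quick manual check.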

Incorrect predictions: Meanwhile, there are other observations we can make. For instance, for the first picture in the sequence of six above, we can see that the model did not recognize at all that there is a cat in the picture! The model picked up on the laptop in the picture as the top prediction, but completely missed the cat - even in its subsequent class predictions! It is possible that this happened because, during the pre-processing of the picture, the cat's ears got cut out of the image. One hypothesis is that the ears were a key signal the model was using to identify a cat in the image. However, to be more certain about this hypothesis, we would have to carry out further analysis.

Class probabilities: Another interesting thing is looking at the probabilities for the top predicted class. In some cases, such as the third and fifth pictures, the probabilities for 'Egyptian cat' are quite high, above 0.9, showing that the model has high confidence in its predictions. In other instances, such as the second and the fourth pictures, and even the picture we first predicted on, the top probabilities are below 0.5, so the model is not very certain of its predictions. It is possible that lighting, the position of the cat and its features in the picture, or picture cropping also affect these probabilities.
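
A low top-1 probability does not always mean the model is unsure that it is looking at a cat: the probability can simply be split across several cat classes. As a rough check, the sketch below sums the top-5 probabilities reported above over the cat classes. Which labels count as a cat is my own assumption, and since only the top 5 are included, each sum is a lower bound.

```python
# Labels treated as "cat" for this check (an assumption made for this sketch).
cat_labels = {"tabby", "Egyptian cat", "tiger cat", "Siamese cat"}

# Top-5 outputs copied from the predictions listed earlier in this article.
top5 = {
    "Ada (first prediction)": [
        ("tabby", 0.470609), ("Egyptian cat", 0.406148), ("tiger cat", 0.081435),
        ("photocopier", 0.006049), ("printer", 0.004371),
    ],
    "picture 2 of 6": [
        ("Egyptian cat", 0.220818), ("tabby", 0.093764), ("tiger cat", 0.043862),
        ("purse", 0.037481), ("sleeping bag", 0.034918),
    ],
    "picture 4 of 6": [
        ("tabby", 0.467902), ("Egyptian cat", 0.423090), ("tiger cat", 0.049405),
        ("quilt", 0.017489), ("remote control", 0.006777),
    ],
}

# Sum the probabilities of the cat classes for each picture.
cat_mass = {
    name: sum(p for label, p in preds if label in cat_labels)
    for name, preds in top5.items()
}
for name, mass in cat_mass.items():
    print(f"{name}: combined cat probability {mass:.6f}")
```

For Ada's first picture and for the fourth picture, the combined cat probability is well above 0.9 even though no single class reaches 0.5. For the second picture it stays below 0.4, so there the model really is uncertain.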

Image pre-processing: Last but not least, the last picture, which shows a squirrel, actually got a top prediction of 'meerkat', with 'fox squirrel' only in fourth place at a probability of about 0.10! This is plausible given that, during pre-processing, the head of the squirrel got cropped out, and the head might have been a key feature for identification.

Summary

To summarize, in this article, we looked at how to use pretrained Computer Vision models in inference mode, to classify images. We covered:
  1. Brief overview of the Image Classification task
  2. How to initialize a pretrained model in Pytorch
  3. How to pre-process an image for prediction
  4. How to get a prediction using pretrained models
  5. Analysis of predictions

Takeaways

There are also a number of takeaways:
  1. We can load pretrained models from PyTorch and use them for inference with only a few lines of code.
  2. We don't have to train models from scratch if the task a model was trained on matches our own task.
  3. Computer Vision models take their input in a particular format, and we need to pre-process our images to match it.
  4. Even for high-performing models, it's not guaranteed that predictions in practice will always be correct. Image cropping, lighting, occlusions, as well as other features in the picture might affect the quality of predictions.

Next Steps

There are a number of next steps you can take, such as:
  1. Use the methodology in this tutorial to get predictions on your own pictures and desired classes!
  2. Experiment with other computer vision models. The entire list of models available from PyTorch can be found here. They are all loaded in the same way. Some examples you can experiment with include VGGNet, Inception or AlexNet. You can also check out the Papers With Code website for the top performing models on the Image Classification task.
  3. Experiment with other Computer Vision tasks. For a list, check out the Papers With Code website.
