Training Vision Transformers with Only 2040 Images: Implementation

Paper: https://arxiv.org/pdf/2201.10728v1.pdf

Github: https://github.com/niranjankrishna-acad/Training-Vision-Transformers-with-Only-2040-Images

Architecture

The architecture of the model is quite straightforward; it's the losses that are more complicated. For the network, we use a vision transformer combined with a linear layer to predict the probability of each class. We'll use the vit-pytorch library to initialize the vision transformer.
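Here is a minimal sketch of the network. The ViT hyperparameters (image size, patch size, depth, and so on) are illustrative assumptions, not values from the paper; the model returns both a feature vector (for the contrastive loss) and class logits (for instance discrimination).

```python
import torch.nn as nn
from vit_pytorch import ViT

class Network(nn.Module):
    def __init__(self, num_classes, feature_dim=192):
        super().__init__()
        # Use the ViT's own head to emit a feature vector of size
        # `feature_dim`, then map features to class logits ourselves.
        # All sizes below are assumptions for the sketch.
        self.vit = ViT(
            image_size=224,
            patch_size=16,
            num_classes=feature_dim,  # used as a feature extractor here
            dim=192,
            depth=9,
            heads=3,
            mlp_dim=384,
        )
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        features = self.vit(x)              # (batch, feature_dim)
        logits = self.classifier(features)  # (batch, num_classes)
        return features, logits
```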

Loss

There are two losses in the original paper.

  1. Instance Discrimination Loss
  2. Contrastive Learning Loss

Instance Discrimination Loss

The instance discrimination loss \(L_{InsDis}\) is defined as follows:

\[L_{InsDis} = -\sum_{i=1}^{N}\sum_{c=1}^{N} y_{c}^{i} \log{P_{c}^{i}}\]

where \(i\) runs over the \(N\) instances passed to the network (in batches) and \(c\) runs over the classes. Instance discrimination treats each image as its own class, which is why the inner sum also runs to \(N\).
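In code, this is just cross-entropy with the instance index as the target. A minimal sketch, assuming the network's logits span the \(N\) instance-classes:

```python
import torch.nn.functional as F

# Instance discrimination loss: each image is its own class, so this is
# ordinary cross-entropy with the instance index used as the label.
def instance_discrimination_loss(logits, instance_ids):
    # logits: (batch, N) scores over the N instance-classes
    # instance_ids: (batch,) dataset index of each image
    return F.cross_entropy(logits, instance_ids)
```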

Contrastive Learning Loss

The contrastive learning loss \(L_{CN}\) is defined as follows:

\[L_{CN} = -\sum_{i=1}^{N} z_{iA}^{T} z_{iB} + \sum_{i=1}^{N} \log\left(e^{z_{iA}^{T} z_{iB}} + \sum_{z^{-}} e^{z_{iA}^{T} z^{-}}\right)\]

Here \(z_{iA}\) and \(z_{iB}\) are features extracted from two augmented versions of the image \(x_{i}\), and \(z^{-}\) ranges over negative samples. Let's write an augmentation layer first.
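A minimal sketch is below. The particular transforms (random resized crop, flip, color jitter) are common contrastive-learning choices assumed here, not necessarily the paper's exact augmentation recipe.

```python
import torch.nn as nn
from torchvision import transforms

class Augment(nn.Module):
    """Produces the two augmented views x_A and x_B of a batch."""

    def __init__(self, size=224):
        super().__init__()
        # Transform choices are assumptions of this sketch. Note that
        # each call samples one set of random parameters, so in this
        # simple version the same crop/flip applies across the batch.
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
        ])

    def forward(self, x):
        # Two independent augmentations of the same batch.
        return self.transform(x), self.transform(x)
```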

Next up is the loss class.
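A minimal sketch of \(L_{CN}\), assuming the other images' features in the batch serve as the negatives \(z^{-}\) (a standard in-batch-negatives assumption); features are L2-normalized, a common stabilizing choice, and the loss is averaged over the batch rather than summed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    def forward(self, z_a, z_b):
        # z_a, z_b: (batch, dim) features of the two augmented views
        z_a = F.normalize(z_a, dim=1)
        z_b = F.normalize(z_b, dim=1)
        sim = z_a @ z_b.T           # (batch, batch) pairwise similarities
        positives = sim.diagonal()  # z_iA^T z_iB
        # -z_iA^T z_iB + log( e^{z_iA^T z_iB} + sum_j e^{z_iA^T z_jB} ),
        # where the logsumexp over row i covers the positive plus the
        # in-batch negatives, matching the formula above.
        return (-positives + torch.logsumexp(sim, dim=1)).mean()
```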

Preprocessing

The preprocessing part is pretty straightforward. We define a custom dataset class to feed the data into a data loader, and we enumerate all 17 classes of the flower dataset.
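A minimal sketch, assuming the Oxford Flowers-17 layout in which all 1360 JPEGs sit in one directory in class order, 80 images per class; the class-name list and its ordering are also assumptions of this sketch.

```python
import os
from PIL import Image
from torch.utils.data import Dataset

# The 17 flower classes, assumed to follow the dataset's file ordering.
CLASSES = [
    "daffodil", "snowdrop", "lily valley", "bluebell", "crocus",
    "iris", "tigerlily", "tulip", "fritillary", "sunflower",
    "daisy", "colts' foot", "dandelion", "cowslip", "buttercup",
    "windflower", "pansy",
]

class FlowersDataset(Dataset):
    def __init__(self, root, transform=None):
        # Assumed layout: all JPEGs in one directory, in class order.
        self.paths = sorted(
            os.path.join(root, f)
            for f in os.listdir(root) if f.endswith(".jpg")
        )
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        image = Image.open(self.paths[idx]).convert("RGB")
        if self.transform:
            image = self.transform(image)
        label = idx // 80  # 80 images per class, stored contiguously
        # idx doubles as the instance id for the instance
        # discrimination loss.
        return image, label, idx
```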

Training

The training part is pretty straightforward as well: iterate through the batches, add up the two losses, and backpropagate, as sketched below.
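A minimal sketch of the loop, wiring together the pieces above. The optimizer choice, learning rate, batch size, epoch count, and dataset path are all assumptions.

```python
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

device = "cuda" if torch.cuda.is_available() else "cpu"

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # fixed size so batching works
    transforms.ToTensor(),
])
dataset = FlowersDataset("17flowers/jpg", transform=transform)  # path is an assumption
loader = DataLoader(dataset, batch_size=64, shuffle=True)

# Instance discrimination needs one logit per instance in the dataset.
model = Network(num_classes=len(dataset)).to(device)
augment = Augment()
contrastive = ContrastiveLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for epoch in range(100):
    for images, labels, instance_ids in loader:
        images, instance_ids = images.to(device), instance_ids.to(device)
        view_a, view_b = augment(images)
        feat_a, logits_a = model(view_a)
        feat_b, _ = model(view_b)
        # Add up the two losses and backpropagate.
        loss = instance_discrimination_loss(logits_a, instance_ids) \
            + contrastive(feat_a, feat_b)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```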

Conclusion

Well, that’s it for now. We’ve implemented the paper “Training Vision Transformers with Only 2040 Images” in PyTorch, and it seems to work. Thanks for reading.

