Paper: https://arxiv.org/pdf/2201.10728v1.pdf
Github: https://github.com/niranjankrishna-acad/Training-Vision-Transformers-with-Only-2040-Images
Architecture
The architecture of the model is quite straightforward; it's the losses that are more involved. For the network, we use a vision transformer followed by a linear layer that predicts the probability of each class. We'll use the vit-pytorch library to initialize the vision transformer.
import torch
from torch import nn
from vit_pytorch import ViT

class VisionTransformer(nn.Module):
    def __init__(self, img_size, z_dim, num_classes):
        super(VisionTransformer, self).__init__()
        # ViT backbone from vit-pytorch; its "num_classes" head is used here as a z_dim-dimensional embedding
        self.vit = ViT(
            image_size = img_size,
            patch_size = 32,
            num_classes = z_dim,
            dim = 1024,
            depth = 6,
            heads = 16,
            mlp_dim = 2048,
            dropout = 0.1,
            emb_dropout = 0.1
        )
        # Linear classification head mapping the embedding to the flower classes
        self.linear = nn.Linear(in_features=z_dim, out_features=num_classes)

    def forward(self, x):
        x = self.vit(x)
        x = self.linear(x)
        x = nn.functional.softmax(x, dim=-1)
        return x
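Before wiring anything else up, it helps to sanity-check the output shape. Here is a minimal smoke test I added (hypothetical, not from the paper or repo), assuming 224x224 inputs and the hyperparameters above:

model = VisionTransformer(img_size=224, z_dim=10, num_classes=17)
dummy = torch.randn(4, 3, 224, 224)       # a fake batch of four RGB images
probs = model(dummy)                      # shape (4, 17)
print(probs.shape, probs.sum(dim=-1))     # each row of probabilities sums to ~1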
Loss
There are two losses in the original paper.
- Instance Discrimination Loss
- Contrastive Learning Loss
Instance Discrimination Loss
The instance discrimination loss, \(L_{InsDis}\), is defined as follows:
\[L_{InsDis} = - \sum^{N}_{i=1}\sum^{N}_{c=1}y_{c}^{i}\log{P_{c}^{i}}\]
where \(i\) indexes the instances in the batch passed to the network and \(c\) indexes the classes. (With instance discrimination, each training image is effectively treated as its own class.)
class InstanceDiscriminationLoss(nn.Module):
    def __init__(self):
        super(InstanceDiscriminationLoss, self).__init__()

    def forward(self, predictions):
        # Negative log of the predicted probabilities, summed over the batch
        return -torch.sum(torch.log(predictions))
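As a quick check (my own hypothetical example, not from the repo), you can feed the loss some fake softmax outputs and confirm you get back a positive scalar:

loss_fn = InstanceDiscriminationLoss()
dummy_probs = torch.softmax(torch.randn(4, 17), dim=-1)   # pretend predictions for a batch of 4
print(loss_fn(dummy_probs))                               # a positive scalar tensor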
Contrastive Learning Loss
The contrastive learning loss \(L_{CN}\) is defined as follows:
\[L_{CN} = -\sum^{N}_{i=1}z_{iA}^{T}z_{iB} + \sum^{N}_{i=1} \log\left(e^{z_{iA}^{T}z_{iB}} + \sum e^{z_{iA}^{T}z_{i}^{-}}\right)\]
Here \(z_{iA}\) and \(z_{iB}\) are features extracted from two augmented versions of the image \(x_{i}\), and \(z_{i}^{-}\) are the features of the negative samples, i.e. the other images in the batch. Let's write an augmentation layer first.
from torchvision import transforms

class RandomAugmentation(nn.Module):
    def __init__(self):
        super(RandomAugmentation, self).__init__()
        self.augment = transforms.Compose(
            [
                transforms.RandomRotation(35),
                transforms.ColorJitter(),
            ]
        )

    def forward(self, x):
        # Two independently augmented views of the same batch
        x_a = self.augment(x)
        x_b = self.augment(x)
        return x_a, x_b
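A small usage check (hypothetical, not from the repo) confirms that the layer returns two views of the same shape:

aug = RandomAugmentation()
imgs = torch.rand(4, 3, 224, 224)                      # fake image batch with values in [0, 1]
view_a, view_b = aug(imgs)
print(view_a.shape, torch.allclose(view_a, view_b))    # same shape, almost surely different pixels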
Next up is the loss class.
class ContrastiveLearningLoss(nn.Module):
    def __init__(self):
        super(ContrastiveLearningLoss, self).__init__()

    def forward(self, z_a, z_b):
        n = len(z_a)
        # Alignment term: -sum_i z_iA^T z_iB
        alignment = 0
        for i in range(n):
            alignment += -torch.dot(z_a[i], z_b[i])
        # Uniformity term: sum_i log(e^{z_iA^T z_iB} + sum_{j != i} e^{z_iA^T z_jB}),
        # treating the other images in the batch as the negatives z_i^-
        uniformity_loss = 0
        for i in range(n):
            negative_sum = 0
            for j in range(n):
                if i == j:
                    continue
                negative_sum += torch.exp(torch.dot(z_a[i], z_b[j]))
            positive = torch.exp(torch.dot(z_a[i], z_b[i]))
            uniformity_loss += torch.log(positive + negative_sum)
        return alignment + uniformity_loss
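For larger batches, the double Python loop becomes slow. If we assume, as in the loop above, that the negatives \(z_{i}^{-}\) are simply the other images' embeddings in the batch, the same loss can be computed with a single matrix multiply. This vectorized version is my own sketch, not code from the paper or repo:

def contrastive_loss_vectorized(z_a, z_b):
    # sim[i, j] = z_a[i] . z_b[j]; the diagonal holds the positive pairs
    sim = z_a @ z_b.T
    alignment = -sim.diagonal().sum()
    uniformity = torch.logsumexp(sim, dim=1).sum()
    return alignment + uniformity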
Preprocessing
The preprocessing part is pretty straightforward. We define a custom dataset class that loads the images for the data loaders, and we map each of the 17 classes of the flowers dataset to an index.
import os

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image

labels_2_idx = {
    "Buttercup": 0, "Colts Foot": 1, "Daffodil": 2, "Daisy": 3,
    "Dandelion": 4, "Fritilary": 5, "Iris": 6, "Pansy": 7,
    "Sunflower": 8, "Windflower": 9, "Snowdrop": 10, "Lily Valley": 11,
    "Bluebell": 12, "Crocus": 13, "Tiger lily": 14, "Tulip": 15,
    "Cowslip": 16
}

idx_2_labels = {}
for key, value in labels_2_idx.items():
    idx_2_labels[value] = key

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # files.txt lists one image filename per line
        self.img_labels = open(annotations_file, "r").read().splitlines()
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels[idx])
        image = read_image(img_path)
        # The 17-flowers dataset stores 80 consecutive images per class,
        # so the class index is the image index divided by 80
        label = idx // 80
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

dataset = CustomImageDataset(
    "data/files.txt",
    "data",
    transform=transforms.Resize((224, 224)),  # the ViT above expects 224x224 inputs
)
train_set, val_set = torch.utils.data.random_split(dataset, [1360 - 136, 136])

train_dataloader = DataLoader(train_set, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_set, batch_size=64, shuffle=True)
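As a quick sanity check (my addition, assuming the 17-flowers images have been extracted under data/ and that data/files.txt lists one filename per line in the dataset's class order), you can pull a single sample and inspect the split sizes:

image, label = dataset[0]
print(image.shape, label, idx_2_labels[label])
print(len(train_set), len(val_set))    # 1224 and 136 with the split above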
Training
The training part is pretty straightforward as well. You just have to iterate through the batches and add up the two losses.
epochs = 12

visionTransformer = nn.DataParallel(VisionTransformer(224, 10, 17))
visionTransformer.cuda()
randomAugmentation = nn.DataParallel(RandomAugmentation())
randomAugmentation.cuda()
instanceDiscriminationLoss = nn.DataParallel(InstanceDiscriminationLoss())
instanceDiscriminationLoss.cuda()
contrastiveLearningLoss = nn.DataParallel(ContrastiveLearningLoss())
contrastiveLearningLoss.cuda()

optimizer = torch.optim.AdamW(visionTransformer.parameters(), lr=0.001)

for epoch in range(epochs):
    aggregate_loss = 0
    for step, batch in enumerate(train_dataloader):
        if step % 10 == 0 and step != 0:
            print("Epoch:{} Step:{} Loss:{:.3f}".format(epoch, step, aggregate_loss / step))
        batch = batch[0].to(torch.float)
        # Class probabilities for the instance discrimination loss
        predictions = visionTransformer(batch)
        # Two augmented views and their normalized ViT embeddings for the contrastive loss
        x_a, x_b = randomAugmentation(batch)
        z_embeddings_a = nn.functional.normalize(visionTransformer.module.vit(x_a), dim=1)
        z_embeddings_b = nn.functional.normalize(visionTransformer.module.vit(x_b), dim=1)
        loss_1 = instanceDiscriminationLoss(predictions).mean()
        loss_2 = contrastiveLearningLoss(z_embeddings_a, z_embeddings_b).mean()
        total_loss = loss_1 + loss_2
        optimizer.zero_grad()
        total_loss.backward()
        aggregate_loss += total_loss.item()
        optimizer.step()
    print("Total Epoch {} Loss : {}".format(epoch, aggregate_loss / len(train_dataloader)))
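Note that the validation loader defined earlier is never touched in the loop above. Here is a minimal evaluation sketch (my addition, assuming the label convention from the dataset class and the softmax output of the model):

# Hypothetical validation pass; not part of the original training loop.
visionTransformer.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in val_dataloader:
        images = images.to(torch.float)
        probs = visionTransformer(images)          # (batch, 17) class probabilities
        preds = probs.argmax(dim=-1).cpu()
        correct += (preds == labels).sum().item()
        total += labels.numel()
print("Validation accuracy: {:.3f}".format(correct / total))
visionTransformer.train()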
Conclusion
Well, that’s it for now. We’ve implemented the paper “Training Vision Transformers with Only 2040 Images” in PyTorch and it seems to work. Thanks for having the patience to read this far.