TensorFlow Keras implementation of Learning to tokenize in Vision Transformers
Full credit goes to Sayak Paul and Aritra Roy Gosthipaty for this work.
Intended uses & limitations
Vision Transformers (Dosovitskiy et al.) and many other Transformer-based architectures (Liu et al., Yuan et al., etc.) have shown strong results in image recognition. The following provides a brief overview of the components involved in the Vision Transformer architecture for image classification:
- Extract small patches from input images.
- Linearly project those patches.
- Add positional embeddings to these linear projections.
- Run these projections through a series of Transformer (Vaswani et al.) blocks.
- Finally, take the representation from the final Transformer block and attach a classification head.

If we take 224x224 images and extract 16x16 patches, we get a total of 196 patches (also called tokens) per image, since (224 / 16)^2 = 196. The number of patches grows with the input resolution, leading to a higher memory footprint. Could we use a reduced number of patches without compromising performance? Ryoo et al. investigate this question in TokenLearner: Adaptive Space-Time Tokenization for Videos. They introduce a novel module called TokenLearner that can reduce the number of patches used by a Vision Transformer (ViT) in an adaptive manner. With TokenLearner incorporated into the standard ViT architecture, they are able to reduce the amount of compute (measured in FLOPS) used by the model.

In this example, we implement the TokenLearner module and demonstrate its performance with a mini ViT on the CIFAR-10 dataset (a condensed code sketch of the patch pipeline and the TokenLearner module follows the reference list below). We make use of the following references:
- Official TokenLearner code
- Image Classification with ViTs on keras.io
- TokenLearner slides from NeurIPS 2021
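As promised above, here is a condensed sketch of the patch embedding pipeline (steps 1 through 3) and the TokenLearner module, adapted from the official keras.io example. The configuration constants (IMAGE_SIZE, PATCH_SIZE, PROJECTION_DIM, NUM_TOKENS) and the helper names (PatchEmbedding, token_learner) are illustrative assumptions for CIFAR-10-sized inputs, not necessarily the exact settings behind this checkpoint:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Illustrative settings for CIFAR-10-sized inputs (assumptions, not the
# exact training configuration of this checkpoint).
IMAGE_SIZE = 32
PATCH_SIZE = 4
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2  # 8 x 8 = 64 patches
PROJECTION_DIM = 128
NUM_TOKENS = 8  # how many tokens TokenLearner distills the 64 patches into


class PatchEmbedding(layers.Layer):
    """Steps 1-3: a strided convolution both extracts and linearly projects
    patches in one pass; learned positional embeddings are then added."""

    def __init__(self, patch_size, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = layers.Conv2D(
            filters=projection_dim,
            kernel_size=patch_size,
            strides=patch_size,
            padding="valid",
        )
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, images):
        patches = self.projection(images)  # (B, H/P, W/P, D)
        patches = tf.reshape(patches, (-1, self.num_patches, self.projection_dim))
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        return patches + self.position_embedding(positions)  # (B, N, D)


def token_learner(inputs, number_of_tokens=NUM_TOKENS):
    """Learn `number_of_tokens` spatial attention maps and use them to pool
    the patch grid into that many adaptive tokens."""
    # Fold the token sequence back into a 2D feature grid.
    h = int(inputs.shape[1] ** 0.5)
    x = layers.Reshape((h, h, inputs.shape[-1]))(inputs)  # (B, H, W, C)
    x = layers.LayerNormalization(epsilon=1e-6)(x)

    # A small conv stack emits one sigmoid attention map per output token.
    attention_maps = keras.Sequential(
        [
            layers.Conv2D(number_of_tokens, 3, activation=tf.nn.gelu,
                          padding="same", use_bias=False),
            layers.Conv2D(number_of_tokens, 3, activation=tf.nn.gelu,
                          padding="same", use_bias=False),
            layers.Conv2D(number_of_tokens, 3, activation=tf.nn.gelu,
                          padding="same", use_bias=False),
            layers.Conv2D(number_of_tokens, 3, activation="sigmoid",
                          padding="same", use_bias=False),
            layers.Reshape((-1, number_of_tokens)),  # (B, H*W, num_tokens)
            layers.Permute((2, 1)),                  # (B, num_tokens, H*W)
        ]
    )(x)

    # Weight the features by each attention map, then average over space:
    # H*W patch tokens become `number_of_tokens` learned tokens.
    features = layers.Reshape((1, -1, inputs.shape[-1]))(inputs)  # (B, 1, H*W, C)
    attended = attention_maps[..., tf.newaxis] * features  # (B, T, H*W, C)
    return tf.reduce_mean(attended, axis=2)  # (B, number_of_tokens, C)


# Usage inside a functional model (Transformer blocks omitted for brevity):
inputs = keras.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
x = PatchEmbedding(PATCH_SIZE, NUM_PATCHES, PROJECTION_DIM)(inputs)
# ... Transformer blocks; the paper inserts TokenLearner about halfway ...
x = token_learner(x)  # (B, NUM_TOKENS, PROJECTION_DIM)
```

Because TokenLearner replaces the H*W patch tokens with just NUM_TOKENS learned tokens, every Transformer block after it attends over far fewer tokens, which is where the FLOPS savings come from.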
Training and evaluation data
The model is trained and evaluated on the CIFAR-10 dataset, as noted in the overview above.
Training procedure
Training hyperparameters
The following hyperparameters were used during training:
| Hyperparameter | Value |
|---|---|
| name | AdamW |
| learning_rate | 0.001 |
| decay | 0.0 |
| beta_1 | 0.9 |
| beta_2 | 0.999 |
| epsilon | 1e-07 |
| amsgrad | False |
| weight_decay | 0.0001 |
| exclude_from_weight_decay | None |
| training_precision | float32 |
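For completeness, this configuration can be reconstructed roughly as sketched below. The use of `tfa.optimizers.AdamW` from TensorFlow Addons is an assumption about packaging (on TF >= 2.11, `tf.keras.optimizers.AdamW` works analogously), and the placeholder model stands in for the mini ViT classifier:

```python
import tensorflow as tf
import tensorflow_addons as tfa  # supplies AdamW on TF versions before 2.11

# Values rounded from their float32 serializations in the table above;
# decay=0.0 is Keras's legacy learning-rate-decay default and is omitted.
optimizer = tfa.optimizers.AdamW(
    learning_rate=1e-3,
    weight_decay=1e-4,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-7,
    amsgrad=False,
)

# `model` stands in for the mini ViT classifier; any keras.Model works here.
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])  # placeholder
model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
```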