
DeiT Vision Transformer (Transfer-Learning)

Training data-efficient image transformers & distillation through attention

Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou

"Recently, neural networks purely based on attention were shown to address image understanding tasks such as image classification. These highperforming vision transformers are pre-trained with hundreds of millions of images using a large infrastructure, thereby limiting their adoption. In this work, we produce competitive convolution-free transformers by training on Imagenet only. We train them on a single computer in less than 3 days. Our reference vision transformer (86M parameters) achieves top-1 accuracy of 83.1% (single-crop) on ImageNet with no external data. More importantly, we introduce a teacher-student strategy specific to transformers. It relies on a distillation token ensuring that the student learns from the teacher through attention. We show the interest of this token-based distillation, especially when using a convnet as a teacher. This leads us to report results competitive with convnets for both Imagenet (where we obtain up to 85.2% accuracy) and when transferring to other tasks."

!pip install transformers
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    ConfusionMatrixDisplay
)
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import (
Conv2D,
BatchNormalization,
LayerNormalization,
Dense,
Input,
Embedding,
MultiHeadAttention,
Layer,
Add,
Resizing,
Rescaling,
Permute,
Flatten,
RandomFlip,
RandomRotation,
RandomContrast,
RandomBrightness
)
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import CategoricalAccuracy, TopKCategoricalAccuracy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import image_dataset_from_directory

from transformers import DeiTConfig, TFDeiTModel
SEED = 42
LABELS = ['Gladiolus', 'Adenium', 'Alpinia_Purpurata', 'Alstroemeria', 'Amaryllis', 'Anthurium_Andraeanum', 'Antirrhinum', 'Aquilegia', 'Billbergia_Pyramidalis', 'Cattleya', 'Cirsium', 'Coccinia_Grandis', 'Crocus', 'Cyclamen', 'Dahlia', 'Datura_Metel', 'Dianthus_Barbatus', 'Digitalis', 'Echinacea_Purpurea', 'Echinops_Bannaticus', 'Fritillaria_Meleagris', 'Gaura', 'Gazania', 'Gerbera', 'Guzmania', 'Helianthus_Annuus', 'Iris_Pseudacorus', 'Leucanthemum', 'Malvaceae', 'Narcissus_Pseudonarcissus', 'Nerine', 'Nymphaea_Tetragona', 'Paphiopedilum', 'Passiflora', 'Pelargonium', 'Petunia', 'Platycodon_Grandiflorus', 'Plumeria', 'Poinsettia', 'Primula', 'Protea_Cynaroides', 'Rose', 'Rudbeckia', 'Strelitzia_Reginae', 'Tropaeolum_Majus', 'Tussilago', 'Viola', 'Zantedeschia_Aethiopica']
NLABELS = len(LABELS)
BATCH_SIZE = 32
SIZE = 224
EPOCHS = 40
LR = 5e-6 # default 0.001
HIDDEN_SIZE = 768 # default 768
NHEADS = 8 # default 12
NLAYERS = 4 # default 12

Dataset

train_directory = '../dataset/Flower_Dataset/split/train'
test_directory = '../dataset/Flower_Dataset/split/val'
train_dataset = image_dataset_from_directory(
    train_directory,
    labels='inferred',
    label_mode='categorical',
    class_names=LABELS,
    color_mode='rgb',
    batch_size=BATCH_SIZE,
    image_size=(SIZE, SIZE),
    shuffle=True,
    seed=SEED,
    interpolation='bilinear',
    follow_links=False,
    crop_to_aspect_ratio=False
)

# Found 9206 files belonging to 48 classes.
test_dataset = image_dataset_from_directory(
    test_directory,
    labels='inferred',
    label_mode='categorical',
    class_names=LABELS,
    color_mode='rgb',
    batch_size=BATCH_SIZE,
    image_size=(SIZE, SIZE),
    shuffle=True,
    seed=SEED
)

# Found 3090 files belonging to 48 classes.
data_augmentation = Sequential([
    RandomRotation(factor=0.25),
    RandomFlip(mode='horizontal'),
    RandomContrast(factor=0.1),
    RandomBrightness(0.1)
], name="img_augmentation")
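An optional sanity check before training: run one batch through the augmentation stack and plot a few results to confirm the transforms look reasonable.

# plot 9 augmented images from the first training batch
for images, _ in train_dataset.take(1):
    augmented = data_augmentation(images, training=True)
    plt.figure(figsize=(8, 8))
    for i in range(9):
        plt.subplot(3, 3, i + 1)
        plt.imshow(tf.clip_by_value(augmented[i], 0.0, 255.0) / 255.0)
        plt.axis('off')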
resize_rescale_reshape = Sequential([
    Resizing(SIZE, SIZE),
    Rescaling(1./255),
    # the transformer expects channels-first images of shape (3, 224, 224)
    Permute((3, 1, 2))
])
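The Permute layer is easy to get wrong, so a quick shape check helps: a channels-last batch should come out channels-first, the layout TFDeiTModel expects.

dummy = tf.zeros((1, SIZE, SIZE, 3))
print(resize_rescale_reshape(dummy).shape)  # expected: (1, 3, 224, 224)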
training_dataset = (
    train_dataset
    .map(lambda image, label: (data_augmentation(image), label))
    .prefetch(tf.data.AUTOTUNE)
)


testing_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)
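Peeking at one batch confirms the shapes the model will receive, given the batch size of 32 and the 48 classes.

for images, labels in training_dataset.take(1):
    print(images.shape, labels.shape)  # expected: (32, 224, 224, 3) (32, 48)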

DeiT Model

# Initializing a DeiT deit-base-distilled-patch16-224 style configuration
configuration = DeiTConfig(
    image_size=SIZE,
    hidden_size=HIDDEN_SIZE,
    num_attention_heads=NHEADS,
    num_hidden_layers=NLAYERS
)

# Initializing a model with random weights from the deit-base-distilled-patch16-224 style configuration
# base_model = TFDeiTModel(configuration)

# use pretrained weights for the model instead
# NOTE: NHEADS=8 and NLAYERS=4 are smaller than the checkpoint's defaults (12/12),
# so only the first NLAYERS encoder blocks are loaded and from_pretrained warns
# about the unused checkpoint weights
base_model = TFDeiTModel.from_pretrained("facebook/deit-base-distilled-patch16-224", config=configuration)

# Accessing the model configuration
configuration = base_model.config
configuration
inputs = Input(shape=(SIZE, SIZE, 3))
# random image augmentation
data_aug = data_augmentation(inputs)
x = resize_rescale_reshape(data_aug)
# [0] is the last_hidden_state; [:, 0, :] keeps the class-token embedding
x = base_model.deit(x)[0][:, 0, :]
output = Dense(NLABELS, activation='softmax')(x)

deit_model = Model(inputs=inputs, outputs=output)
deit_model.summary()
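The head above classifies from the class token alone. Distilled DeiT checkpoints also carry a distillation token at sequence position 1, and the paper fuses both heads at inference; a hypothetical variant (not the model trained below) would average the two embeddings:

# average class token (position 0) and distillation token (position 1)
hidden = base_model.deit(resize_rescale_reshape(data_augmentation(inputs)))[0]
fused = (hidden[:, 0, :] + hidden[:, 1, :]) / 2.0
fused_model = Model(inputs=inputs, outputs=Dense(NLABELS, activation='softmax')(fused))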
# testing the model before fine-tuning - the Dense head is still randomly
# initialized, so the probabilities below are close to arbitrary
test_image = cv.imread('../dataset/snapshots/Viola_Tricolor.jpg')  # OpenCV loads BGR
test_image = cv.cvtColor(cv.resize(test_image, (SIZE, SIZE)), cv.COLOR_BGR2RGB)
deit_model(tf.expand_dims(test_image, axis=0))
# numpy= array([[1.0963462e-01, 4.4628163e-03, 2.7227099e-03, 3.9012067e-02,
# 1.2207581e-02, 3.4460202e-02, 2.3577355e-03, 3.5261197e-03,
# 1.7803181e-02, 1.0567555e-02, 1.5943516e-02, 4.0797489e-03,
# 7.1987398e-03, 9.5541059e-04, 4.2675242e-02, 1.5655500e-04,
# 1.1215543e-02, 1.4889235e-02, 1.8372904e-01, 7.0088580e-03,
# 3.1637046e-03, 1.4315472e-03, 8.3367303e-03, 1.5427665e-03,
# 1.9941023e-02, 9.9778855e-03, 5.6907861e-03, 1.7462631e-03,
# 3.6991950e-02, 1.3322993e-02, 5.4029688e-02, 4.0368687e-02,
# 6.1121010e-03, 7.9112053e-03, 7.2245464e-02, 8.8621033e-03,
# 2.1858371e-03, 3.0036021e-02, 2.7811823e-02, 7.0134280e-03,
# 6.1850133e-03, 1.8044524e-02, 2.3036957e-02, 1.6069075e-02,
# 2.3161862e-02, 2.9986592e-03, 1.0242336e-02, 1.6933089e-02]],
# dtype=float32)
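A raw 48-way probability vector is hard to read; a small helper (my own, assuming LABELS matches the training class order) turns it into the top-k (label, probability) pairs.

def top_k_labels(probs, k=5):
    # probs: model output of shape (1, NLABELS)
    values, indices = tf.math.top_k(probs[0], k=k)
    return [(LABELS[int(i)], float(v)) for i, v in zip(indices.numpy(), values.numpy())]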

Model Training

loss_function = CategoricalCrossentropy()
metrics = [CategoricalAccuracy(name='accuracy')]
deit_model.compile(
    optimizer=Adam(learning_rate=LR),
    loss=loss_function,
    metrics=metrics
)
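Not used in the run below, but with 40 epochs it can pay to stop early on a stalling validation loss and keep the best weights seen; the checkpoint path here is a placeholder.

callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5,
                                     restore_best_weights=True),
    tf.keras.callbacks.ModelCheckpoint('../saved_model/deit_best',
                                       monitor='val_accuracy',
                                       save_best_only=True)
]
# then pass callbacks=callbacks to deit_model.fit()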
deit_history = deit_model.fit(
    training_dataset,
    validation_data=testing_dataset,
    epochs=EPOCHS,
    verbose=1
)

# loss: 0.3592
# accuracy: 0.8945
# val_loss: 0.7199
# val_accuracy: 0.7900

Model Evaluation

deit_model.evaluate(testing_dataset)
# loss: 0.7199 - accuracy: 0.7900
plt.plot(deit_history.history['loss'])
plt.plot(deit_history.history['val_loss'])
plt.title('Model Loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train_loss', 'val_loss'])

plt.savefig('assets/DeiT_01.webp', bbox_inches='tight')

Figure: training vs. validation loss over 40 epochs.

plt.plot(deit_history.history['accuracy'])
plt.plot(deit_history.history['val_accuracy'])
plt.title('Model Accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train_accuracy', 'val_accuracy'])

plt.savefig('assets/DeiT_02.webp', bbox_inches='tight')

Figure: training vs. validation accuracy over 40 epochs.

test_image_bgr = cv.imread('../dataset/snapshots/Viola_Tricolor.jpg')
test_image_resized = cv.resize(test_image_bgr, (SIZE, SIZE))
test_image_rgb = cv.cvtColor(test_image_resized, cv.COLOR_BGR2RGB)
img = tf.constant(test_image_rgb, dtype=tf.float32)
img = tf.expand_dims(img, axis=0)

probs = deit_model(img).numpy()
label = LABELS[tf.argmax(probs, axis=1).numpy()[0]]

print(label, str(probs[0]))

plt.imshow(test_image_rgb)
plt.title(label)
plt.axis('off')

plt.savefig('assets/DeiT_Prediction_01.webp', bbox_inches='tight')

Figure: predicted label for the Viola_Tricolor.jpg snapshot.

test_image_bgr = cv.imread('../dataset/snapshots/Strelitzia.jpg')
test_image_resized = cv.resize(test_image_bgr, (SIZE, SIZE))
test_image_rgb = cv.cvtColor(test_image_resized, cv.COLOR_BGR2RGB)
img = tf.constant(test_image_rgb, dtype=tf.float32)
img = tf.expand_dims(img, axis=0)

probs = deit_model(img).numpy()
label = LABELS[tf.argmax(probs, axis=1).numpy()[0]]

print(label, str(probs[0]))

plt.imshow(test_image_rgb)
plt.title(label)
plt.axis('off')

plt.savefig('assets/DeiT_Prediction_02.webp', bbox_inches='tight')

Figure: predicted label for the Strelitzia.jpg snapshot.

test_image_bgr = cv.imread('../dataset/snapshots/Water_Lilly.jpg')
test_image_resized = cv.resize(test_image_bgr, (SIZE, SIZE))
test_image_rgb = cv.cvtColor(test_image_resized, cv.COLOR_BGR2RGB)
img = tf.constant(test_image_rgb, dtype=tf.float32)
img = tf.expand_dims(img, axis=0)

probs = deit_model(img).numpy()
label = LABELS[tf.argmax(probs, axis=1).numpy()[0]]

print(label, str(probs[0]))

plt.imshow(test_image_rgb)
plt.title(label)
plt.axis('off')

plt.savefig('assets/DeiT_Prediction_03.webp', bbox_inches='tight')

Figure: predicted label for the Water_Lilly.jpg snapshot.
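The three snapshot cells above repeat the same load-convert-predict steps; a small helper (the name is my own) keeps further experiments shorter.

def predict_snapshot(path):
    # load a BGR snapshot with OpenCV, convert to RGB, and predict its label
    image_bgr = cv.imread(path)
    image_rgb = cv.cvtColor(cv.resize(image_bgr, (SIZE, SIZE)), cv.COLOR_BGR2RGB)
    batch = tf.expand_dims(tf.constant(image_rgb, dtype=tf.float32), axis=0)
    probs = deit_model(batch).numpy()
    return LABELS[int(np.argmax(probs, axis=1)[0])]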

plt.figure(figsize=(16,16))

for images, labels in testing_dataset.take(1):
    for i in range(16):
        ax = plt.subplot(4, 4, i + 1)
        true = "True: " + LABELS[tf.argmax(labels[i], axis=0).numpy()]
        pred = "Predicted: " + LABELS[
            tf.argmax(deit_model(tf.expand_dims(images[i], axis=0)).numpy(), axis=1).numpy()[0]
        ]
        plt.title(true + "\n" + pred)
        plt.imshow(images[i] / 255.)
        plt.axis('off')

plt.savefig('assets/DeiT_03.webp', bbox_inches='tight')

Figure: true vs. predicted labels for 16 validation images.

y_pred = []
y_test = []

for img, label in testing_dataset:
    y_pred.append(deit_model(img).numpy())
    y_test.append(label.numpy())
# drop the last (possibly partial) batch so the per-batch arrays stack cleanly
conf_mtx = ConfusionMatrixDisplay(
    confusion_matrix=confusion_matrix(
        np.argmax(y_test[:-1], axis=-1).flatten(),
        np.argmax(y_pred[:-1], axis=-1).flatten()
    ),
    display_labels=LABELS
)

fig, ax = plt.subplots(figsize=(16,12))
conf_mtx.plot(ax=ax, cmap='plasma', include_values=True, xticks_rotation='vertical')

plt.savefig('assets/DeiT_04.webp', bbox_inches='tight')

Figure: confusion matrix over the validation set.
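classification_report was imported at the top but never called; with the predictions already collected, per-class precision and recall take one extra line (again dropping the last, partial batch).

y_true_flat = np.argmax(np.concatenate(y_test[:-1]), axis=-1)
y_pred_flat = np.argmax(np.concatenate(y_pred[:-1]), axis=-1)
print(classification_report(y_true_flat, y_pred_flat, target_names=LABELS))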

Saving the Model

tf.keras.saving.save_model(
    deit_model, '../saved_model/deit_model', overwrite=True, save_format='tf'
)
# restore the model
restored_model = tf.keras.saving.load_model('../saved_model/deit_model')
# Check its architecture
restored_model.summary()
restored_model.evaluate(testing_dataset)
# loss: 0.5184 - accuracy: 0.7840 - topk_accuracy: 0.9394
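A quick parity check, assuming img from the last snapshot cell is still in scope: the restored model should reproduce the original model's probabilities.

print(np.allclose(deit_model(img).numpy(), restored_model(img).numpy(), atol=1e-5))  # expected: True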