Tf Image Classifier
DeiT Vision Transformer (Transfer-Learning)
Training data-efficient image transformers & distillation through attention
Hugo Touvron,
Matthieu Cord,
Matthijs Douze,
Francisco Massa,
Alexandre Sablayrolles,
Hervé Jégou
"Recently, neural networks purely based on attention were shown to address image understanding tasks such as image classification. These high-performing vision transformers are pre-trained with hundreds of millions of images using a large infrastructure, thereby limiting their adoption. In this work, we produce competitive convolution-free transformers by training on Imagenet only. We train them on a single computer in less than 3 days. Our reference vision transformer (86M parameters) achieves top-1 accuracy of 83.1% (single-crop) on ImageNet with no external data. More importantly, we introduce a teacher-student strategy specific to transformers. It relies on a distillation token ensuring that the student learns from the teacher through attention. We show the interest of this token-based distillation, especially when using a convnet as a teacher. This leads us to report results competitive with convnets for both Imagenet (where we obtain up to 85.2% accuracy) and when transferring to other tasks."
!pip install transformers
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import (
classification_report,
confusion_matrix,
ConfusionMatrixDisplay)
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import (
Conv2D,
BatchNormalization,
LayerNormalization,
Dense,
Input,
Embedding,
MultiHeadAttention,
Layer,
Add,
Resizing,
Rescaling,
Permute,
Flatten,
RandomFlip,
RandomRotation,
RandomContrast,
RandomBrightness
)
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import CategoricalAccuracy, TopKCategoricalAccuracy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import image_dataset_from_directory
from transformers import DeiTConfig, TFDeiTModel
# Seed passed to the dataset loaders so shuffling is reproducible.
SEED = 42

# Class names handed to the dataset loaders as `class_names`; the position of
# a name in this list is the class index used when decoding predictions, so
# the order must not change.
LABELS = [
    'Gladiolus',
    'Adenium',
    'Alpinia_Purpurata',
    'Alstroemeria',
    'Amaryllis',
    'Anthurium_Andraeanum',
    'Antirrhinum',
    'Aquilegia',
    'Billbergia_Pyramidalis',
    'Cattleya',
    'Cirsium',
    'Coccinia_Grandis',
    'Crocus',
    'Cyclamen',
    'Dahlia',
    'Datura_Metel',
    'Dianthus_Barbatus',
    'Digitalis',
    'Echinacea_Purpurea',
    'Echinops_Bannaticus',
    'Fritillaria_Meleagris',
    'Gaura',
    'Gazania',
    'Gerbera',
    'Guzmania',
    'Helianthus_Annuus',
    'Iris_Pseudacorus',
    'Leucanthemum',
    'Malvaceae',
    'Narcissus_Pseudonarcissus',
    'Nerine',
    'Nymphaea_Tetragona',
    'Paphiopedilum',
    'Passiflora',
    'Pelargonium',
    'Petunia',
    'Platycodon_Grandiflorus',
    'Plumeria',
    'Poinsettia',
    'Primula',
    'Protea_Cynaroides',
    'Rose',
    'Rudbeckia',
    'Strelitzia_Reginae',
    'Tropaeolum_Majus',
    'Tussilago',
    'Viola',
    'Zantedeschia_Aethiopica',
]
NLABELS = len(LABELS)

# Input pipeline / training schedule.
BATCH_SIZE = 32
SIZE = 224  # images are loaded as SIZE x SIZE
EPOCHS = 40
LR = 5e-6  # default 0.001

# Transformer architecture overrides passed to the DeiT configuration.
HIDDEN_SIZE = 768  # default 768
NHEADS = 8  # default 12
NLAYERS = 4  # default 12
Dataset
# Locations of the pre-split training and validation images.
train_directory = '../dataset/Flower_Dataset/split/train'
test_directory = '../dataset/Flower_Dataset/split/val'

# Settings shared by both loaders: labels are inferred from the directory
# layout and encoded categorically in LABELS order; images are loaded as RGB,
# resized to SIZE x SIZE, batched and shuffled with a fixed seed.
_loader_options = dict(
    labels='inferred',
    label_mode='categorical',
    class_names=LABELS,
    color_mode='rgb',
    batch_size=BATCH_SIZE,
    image_size=(SIZE, SIZE),
    shuffle=True,
    seed=SEED,
)

train_dataset = image_dataset_from_directory(
    train_directory,
    interpolation='bilinear',
    follow_links=False,
    crop_to_aspect_ratio=False,
    **_loader_options,
)
# Found 9206 files belonging to 48 classes.

test_dataset = image_dataset_from_directory(test_directory, **_loader_options)
# Found 3090 files belonging to 48 classes.
# Random image perturbations: rotation, horizontal flip, contrast and
# brightness jitter, applied in that order.
augmentation_layers = [
    RandomRotation(factor=0.25),
    RandomFlip(mode='horizontal'),
    RandomContrast(factor=0.1),
    RandomBrightness(0.1),
]
data_augmentation = Sequential(augmentation_layers, name="img_augmentation")

# Deterministic preprocessing in front of the transformer: resize to
# SIZE x SIZE, scale pixel values by 1/255, then move the channel axis first.
resize_rescale_reshape = Sequential()
resize_rescale_reshape.add(Resizing(SIZE, SIZE))
resize_rescale_reshape.add(Rescaling(1./255))
# transformer expects image shape (3,224,224)
resize_rescale_reshape.add(Permute((3,1,2)))
# Final input pipelines for fit/evaluate.
#
# No augmentation is mapped over the training set here: `data_augmentation`
# is already the first stage of the model graph, so mapping it over the
# dataset as well would perturb every training image twice (e.g. two stacked
# random rotations). Both pipelines therefore only add prefetching so that
# input loading overlaps with model execution.
training_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
testing_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)
DeiT Model
# DeiT configuration using the image size and the architecture overrides
# defined in the constants section.
configuration = DeiTConfig(
    num_hidden_layers=NLAYERS,
    num_attention_heads=NHEADS,
    hidden_size=HIDDEN_SIZE,
    image_size=SIZE,
)

# Start from the published distilled DeiT-base checkpoint rather than from
# randomly initialised weights.
# NOTE(review): NHEADS / NLAYERS differ from the defaults noted next to the
# constants (12 / 12); how the checkpoint weights map onto this reduced
# configuration is decided by `from_pretrained` - confirm this is intended.
pretrained_checkpoint = "facebook/deit-base-distilled-patch16-224"
base_model = TFDeiTModel.from_pretrained(pretrained_checkpoint, config=configuration)

# From here on, use the configuration object attached to the loaded model.
configuration = base_model.config
configuration
# Assemble the classifier: augmentation -> preprocessing -> DeiT backbone ->
# softmax head over the NLABELS classes.
# The input tensor is named `inputs` so the Python builtin `input()` is not
# shadowed for the rest of the session.
inputs = Input(shape=(SIZE,SIZE,3))
# random image augmentation
data_aug = data_augmentation(inputs)
x = resize_rescale_reshape(data_aug)
# Take the first element of the backbone output and keep only the token at
# sequence position 0 of every sample as the image representation.
x = base_model.deit(x)[0][:,0,:]
output = Dense(NLABELS, activation='softmax')(x)
deit_model = Model(inputs=inputs, outputs=output)
deit_model.summary()
# Smoke-test the assembled (not yet fine-tuned) model on a single snapshot.
# The image is prepared the same way as in the prediction cells further down:
# OpenCV loads BGR, so it is converted to RGB and cast to float32 before being
# batched. `training=False` keeps the random augmentation layers inactive.
# The array in the comment below is sample output from an earlier run; exact
# values will differ.
test_image = cv.imread('../dataset/snapshots/Viola_Tricolor.jpg')
if test_image is None:
    # cv.imread signals a missing/unreadable file by returning None.
    raise FileNotFoundError('Could not read ../dataset/snapshots/Viola_Tricolor.jpg')
test_image = cv.resize(test_image, (SIZE, SIZE))
test_image = cv.cvtColor(test_image, cv.COLOR_BGR2RGB)
deit_model(
    tf.expand_dims(tf.constant(test_image, dtype=tf.float32), axis = 0),
    training=False
)
# numpy= array([[1.0963462e-01, 4.4628163e-03, 2.7227099e-03, 3.9012067e-02,
# 1.2207581e-02, 3.4460202e-02, 2.3577355e-03, 3.5261197e-03,
# 1.7803181e-02, 1.0567555e-02, 1.5943516e-02, 4.0797489e-03,
# 7.1987398e-03, 9.5541059e-04, 4.2675242e-02, 1.5655500e-04,
# 1.1215543e-02, 1.4889235e-02, 1.8372904e-01, 7.0088580e-03,
# 3.1637046e-03, 1.4315472e-03, 8.3367303e-03, 1.5427665e-03,
# 1.9941023e-02, 9.9778855e-03, 5.6907861e-03, 1.7462631e-03,
# 3.6991950e-02, 1.3322993e-02, 5.4029688e-02, 4.0368687e-02,
# 6.1121010e-03, 7.9112053e-03, 7.2245464e-02, 8.8621033e-03,
# 2.1858371e-03, 3.0036021e-02, 2.7811823e-02, 7.0134280e-03,
# 6.1850133e-03, 1.8044524e-02, 2.3036957e-02, 1.6069075e-02,
# 2.3161862e-02, 2.9986592e-03, 1.0242336e-02, 1.6933089e-02]],
# dtype=float32)
Model Training
# Training setup: categorical cross-entropy on the softmax outputs, tracked
# with plain categorical accuracy, optimised by Adam at learning rate LR.
loss_function = CategoricalCrossentropy()
metrics = [CategoricalAccuracy(name='accuracy')]
optimizer = Adam(learning_rate=LR)

deit_model.compile(optimizer=optimizer, loss=loss_function, metrics=metrics)

# Train for EPOCHS epochs, validating on the held-out split after each one;
# the returned History object is used for the curves plotted later.
deit_history = deit_model.fit(
    training_dataset,
    epochs=EPOCHS,
    validation_data=testing_dataset,
    verbose=1,
)
# loss: 0.3592
# accuracy: 0.8945
# val_loss: 0.7199
# val_accuracy: 0.7900
Model Evaluation
deit_model.evaluate(testing_dataset)
# loss: 0.7199 - accuracy: 0.7900

# Plot the training/validation curve of each tracked quantity and save it.
# Every chart gets its own figure: without `plt.figure()` the accuracy curves
# would be drawn on top of the loss curves whenever this code runs outside a
# notebook cell boundary (e.g. as a plain script).
for metric, title, output_path in (
    ('loss', 'Model Loss', 'assets/DeiT_01.webp'),
    ('accuracy', 'Model Accuracy', 'assets/DeiT_02.webp'),
):
    plt.figure()
    plt.plot(deit_history.history[metric])
    plt.plot(deit_history.history['val_' + metric])
    plt.title(title)
    plt.ylabel(metric)
    plt.xlabel('epoch')
    plt.legend(['train_' + metric, 'val_' + metric])
    plt.savefig(output_path, bbox_inches='tight')
def _predict_and_plot(image_path, output_path):
    """Classify one image file, print the result and save an annotated plot.

    Args:
        image_path: Path of the image to classify.
        output_path: Where the figure (image titled with the predicted
            label) is written.

    Returns:
        The predicted label (an entry of LABELS).

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    image_bgr = cv.imread(image_path)
    if image_bgr is None:
        # cv.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError(f'Could not read image: {image_path}')

    # Resize to the model's input size, then convert OpenCV's BGR channel
    # order to RGB.
    image_resized = cv.resize(image_bgr, (SIZE, SIZE))
    image_rgb = cv.cvtColor(image_resized, cv.COLOR_BGR2RGB)

    # Add the batch axis and run the model explicitly in inference mode so the
    # random augmentation layers inside it stay inactive.
    batch = tf.expand_dims(tf.constant(image_rgb, dtype=tf.float32), axis=0)
    probs = deit_model(batch, training=False).numpy()
    label = LABELS[tf.argmax(probs, axis=1).numpy()[0]]
    print(label, str(probs[0]))

    # Start a fresh figure so successive predictions do not draw on top of
    # one another when run outside a notebook.
    plt.figure()
    plt.imshow(image_rgb)
    plt.title(label)
    plt.axis('off')
    plt.savefig(output_path, bbox_inches='tight')
    return label


# Classify the three sample snapshots.
for image_path, output_path in (
    ('../dataset/snapshots/Viola_Tricolor.jpg', 'assets/DeiT_Prediction_01.webp'),
    ('../dataset/snapshots/Strelitzia.jpg', 'assets/DeiT_Prediction_02.webp'),
    ('../dataset/snapshots/Water_Lilly.jpg', 'assets/DeiT_Prediction_03.webp'),
):
    _predict_and_plot(image_path, output_path)
# Show a 4x4 grid of test images from the first batch, each titled with its
# true and predicted label.
plt.figure(figsize=(16,16))
for images, labels in testing_dataset.take(1):
    # Never index past the end of the batch, even if it holds fewer than 16
    # images.
    n_shown = min(16, int(tf.shape(images)[0]))
    # One batched forward pass in inference mode instead of one model call per
    # image; `training=False` keeps the random augmentation layers inactive.
    predicted = tf.argmax(
        deit_model(images[:n_shown], training=False), axis=1
    ).numpy()
    actual = tf.argmax(labels[:n_shown], axis=1).numpy()
    for i in range(n_shown):
        plt.subplot(4,4,i+1)
        true = "True: " + LABELS[actual[i]]
        pred = "Predicted: " + LABELS[predicted[i]]
        plt.title(
            true + "\n" + pred
        )
        # The image is divided by 255 before display.
        plt.imshow(images[i]/255.)
        plt.axis('off')
plt.savefig('assets/DeiT_03.webp', bbox_inches='tight')
# Collect predicted and true class indices for the whole test set.
# Each batch is reduced to class indices immediately and the per-batch arrays
# are concatenated afterwards, so the final (possibly smaller) batch is
# included instead of being dropped to keep the stacked arrays rectangular.
y_pred = []
y_test = []
for img, label in testing_dataset:
    # Inference mode: keep the random augmentation layers inactive.
    y_pred.append(np.argmax(deit_model(img, training=False).numpy(), axis=-1))
    y_test.append(np.argmax(label.numpy(), axis=-1))
conf_mtx = ConfusionMatrixDisplay(
    confusion_matrix=confusion_matrix(
        np.concatenate(y_test),
        np.concatenate(y_pred),
        # Fix the class set so the matrix always has one row/column per entry
        # of LABELS, matching `display_labels` even if a class never occurs.
        labels=list(range(NLABELS))
    ),
    display_labels=LABELS
)
fig, ax = plt.subplots(figsize=(16,12))
conf_mtx.plot(ax=ax, cmap='plasma', include_values=True, xticks_rotation='vertical')
plt.savefig('assets/DeiT_04.webp', bbox_inches='tight')
Saving the Model
# Persist the trained model, replacing any previous export at the same path.
saved_model_path = '../saved_model/deit_model'
tf.keras.saving.save_model(
    deit_model,
    saved_model_path,
    save_format='tf',
    overwrite=True,
)

# Load the export back, then inspect its architecture and re-run the
# evaluation on the test split.
restored_model = tf.keras.saving.load_model(saved_model_path)
restored_model.summary()
restored_model.evaluate(testing_dataset)
# loss: 0.5184 - accuracy: 0.7840 - topk_accuracy: 0.9394
# NOTE(review): the recorded output above lists a `topk_accuracy` metric that
# the compile step in this file does not configure and differs from the
# evaluation logged earlier - presumably captured from a different run; confirm.