# TensorFlow Image Classifier
## Xception - Fine-Tuning (new classification head trained on a frozen, ImageNet-pretrained Xception backbone)
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from sklearn.metrics import (
classification_report,
confusion_matrix,
ConfusionMatrixDisplay)
import seaborn as sns
import tensorflow as tf
from tensorflow.io import TFRecordWriter
from tensorflow.keras import Sequential
from tensorflow.keras.callbacks import (
Callback,
CSVLogger,
EarlyStopping,
LearningRateScheduler,
ModelCheckpoint
)
from tensorflow.keras.layers import (
Layer,
GlobalAveragePooling2D,
Conv2D,
MaxPool2D,
Dense,
Flatten,
InputLayer,
BatchNormalization,
Input,
Dropout,
RandomFlip,
RandomRotation,
RandomContrast,
RandomBrightness,
Resizing,
Rescaling
)
from tensorflow.keras.losses import BinaryCrossentropy, CategoricalCrossentropy, SparseCategoricalCrossentropy
from tensorflow.keras.metrics import CategoricalAccuracy, TopKCategoricalAccuracy, SparseCategoricalAccuracy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.regularizers import L2, L1
from tensorflow.keras.utils import image_dataset_from_directory
from tensorflow.train import Feature, Features, Example, BytesList, Int64List
# --- Input pipeline --------------------------------------------------------
BATCH = 32   # images per batch
SIZE = 224   # images are resized to SIZE x SIZE pixels
SEED = 42    # seed for dataset shuffling

# --- Optimisation ----------------------------------------------------------
EPOCHS = 20
LR = 0.001   # Adam learning rate

# --- Convolution / regularisation settings ---------------------------------
# NOTE(review): these appear unused by the Xception model in this notebook.
FILTERS = 6
KERNEL = 3
STRIDES = 1
REGRATE = 0.0
POOL = 2
DORATE = 0.05

# --- Classes ---------------------------------------------------------------
# Order matters: it is passed as `class_names` when loading the datasets, so
# index i of a prediction vector corresponds to LABELS[i].
LABELS = [
    'Gladiolus',
    'Adenium', 'Alpinia_Purpurata', 'Alstroemeria', 'Amaryllis',
    'Anthurium_Andraeanum', 'Antirrhinum', 'Aquilegia',
    'Billbergia_Pyramidalis',
    'Cattleya', 'Cirsium', 'Coccinia_Grandis', 'Crocus', 'Cyclamen',
    'Dahlia', 'Datura_Metel', 'Dianthus_Barbatus', 'Digitalis',
    'Echinacea_Purpurea', 'Echinops_Bannaticus',
    'Fritillaria_Meleagris',
    'Gaura', 'Gazania', 'Gerbera', 'Guzmania',
    'Helianthus_Annuus',
    'Iris_Pseudacorus',
    'Leucanthemum',
    'Malvaceae',
    'Narcissus_Pseudonarcissus', 'Nerine', 'Nymphaea_Tetragona',
    'Paphiopedilum', 'Passiflora', 'Pelargonium', 'Petunia',
    'Platycodon_Grandiflorus', 'Plumeria', 'Poinsettia', 'Primula',
    'Protea_Cynaroides',
    'Rose', 'Rudbeckia',
    'Strelitzia_Reginae',
    'Tropaeolum_Majus', 'Tussilago',
    'Viola',
    'Zantedeschia_Aethiopica',
]
NLABELS = len(LABELS)

# --- Classification head widths ---------------------------------------------
DENSE1 = 1024
DENSE2 = 128
## Dataset
train_directory = '../dataset/Flower_Dataset/split/train'
test_directory = '../dataset/Flower_Dataset/split/val'

# Loader options shared by both splits: labels inferred from sub-directory
# names in LABELS order and one-hot encoded ('categorical'), RGB images
# resized to SIZE x SIZE, batched and shuffled with a fixed seed.
_shared_dataset_options = dict(
    labels='inferred',
    label_mode='categorical',
    class_names=LABELS,
    color_mode='rgb',
    batch_size=BATCH,
    image_size=(SIZE, SIZE),
    shuffle=True,
    seed=SEED,
)

train_dataset = image_dataset_from_directory(
    train_directory,
    interpolation='bilinear',
    follow_links=False,
    crop_to_aspect_ratio=False,
    **_shared_dataset_options,
)
# Found 9206 files belonging to 48 classes.

test_dataset = image_dataset_from_directory(test_directory, **_shared_dataset_options)
# Found 3090 files belonging to 48 classes.
# Random image augmentation for training.
#
# This pipeline is embedded in the model itself (it is applied to the model
# input when the network is built). There, Keras runs it during fit() and
# skips it at inference. It must therefore NOT also be mapped over the
# training dataset. Doing both augmented every training image twice:
# RandomRotation(0.25) spans +/-90 degrees, so two passes compounded to
# +/-180 degrees, and contrast/brightness jitter was likewise applied twice.
data_augmentation = Sequential(
    [
        RandomRotation(factor=0.25),
        RandomFlip(mode='horizontal'),
        RandomContrast(factor=0.1),
        RandomBrightness(0.1),
    ],
    name="img_augmentation",
)

# Input pipelines: only prefetching here; augmentation happens inside the model.
training_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
testing_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)
## Building the Xception TF Model
# Transfer learning: the ImageNet-pretrained Xception convolutional base
# without its original classifier (include_top=False). Its weights are
# frozen, so only the new classification head below is trained.
backbone = tf.keras.applications.Xception(
    input_shape=(SIZE, SIZE, 3),
    include_top=False,
    weights="imagenet",
)
backbone.trainable = False

# Named `inputs` rather than `input` so the Python builtin is not shadowed.
inputs = Input(shape=(SIZE, SIZE, 3))

# Random image augmentation. Keras preprocessing layers of this kind are
# only active in training mode, so prediction/evaluation see clean images.
data_aug = data_augmentation(inputs)

# Pre-trained Xception weights require inputs scaled from [0, 255] to
# [-1, +1]; the Rescaling layer computes `(inputs * scale) + offset`.
scale_layer = Rescaling(scale=1 / 127.5, offset=-1)
data_aug_scaled = scale_layer(data_aug)

# training=False keeps the backbone's BatchNormalization layers in inference
# mode, even if the backbone is later unfrozen for fine-tuning.
x = backbone(data_aug_scaled, training=False)
x = GlobalAveragePooling2D()(x)
x = Dropout(0.2)(x)  # regularise the pooled features
x = Dense(DENSE1, activation='relu')(x)
x = BatchNormalization()(x)
x = Dense(DENSE2, activation='relu')(x)
output = Dense(NLABELS, activation='softmax')(x)  # one probability per class

xception_model = Model(inputs, output)
xception_model.summary()
# Keep only the best model seen so far, judged by validation accuracy.
checkpoint_callback = ModelCheckpoint(
    filepath='../best_weights',
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    verbose=1,
)

# Stop when validation accuracy has not improved for 10 epochs and roll the
# model back to the best weights observed.
# NOTE: neither callback is currently passed to `fit()` in this notebook.
early_stopping_callback = EarlyStopping(
    monitor='val_accuracy',
    patience=10,
    restore_best_weights=True,
)
# Labels are one-hot vectors, so use categorical cross-entropy. Accuracy is
# reported under the name 'accuracy', which Keras exposes for the validation
# data as 'val_accuracy' (the key the callbacks monitor).
loss_function = CategoricalCrossentropy()
metrics = [CategoricalAccuracy(name='accuracy')]

xception_model.compile(optimizer=Adam(learning_rate=LR), loss=loss_function, metrics=metrics)
## Model Training
# Train the classification head and validate on the held-out split each epoch.
# NOTE: checkpoint_callback / early_stopping_callback are not passed here, so
# all EPOCHS epochs run and nothing is checkpointed. To enable them, add
# `callbacks=[checkpoint_callback, early_stopping_callback]`.
xception_history = xception_model.fit(
    training_dataset,
    epochs=EPOCHS,
    validation_data=testing_dataset,
    verbose=1,
)
# Metrics recorded from a previous run (presumably the final epoch):
#   loss: 0.5019  accuracy: 0.8426  val_loss: 0.6363  val_accuracy: 0.8188
## Model Evaluation
xception_model.evaluate(testing_dataset)
# loss: 0.6363 - accuracy: 0.8188


def plot_training_curve(history, metric, save_path):
    """Plot training and validation curves of one metric on a new figure and save it.

    Args:
        history: the ``History`` object returned by ``Model.fit``.
        metric: key in ``history.history`` (e.g. ``'loss'``); the validation
            curve is read from ``'val_' + metric``.
        save_path: output image path passed to ``plt.savefig``.
    """
    # A fresh figure per metric. Previously both metrics were drawn onto the
    # same axes, so the accuracy image also contained the loss curves.
    plt.figure()
    plt.plot(history.history[metric])
    plt.plot(history.history[f'val_{metric}'])
    plt.title(f'Model {metric.capitalize()}')
    plt.ylabel(metric)
    plt.xlabel('epoch')
    plt.legend([f'train_{metric}', f'val_{metric}'])
    plt.savefig(save_path, bbox_inches='tight')


plot_training_curve(xception_history, 'loss', 'assets/Xception_01.webp')
plot_training_curve(xception_history, 'accuracy', 'assets/Xception_02.webp')
def predict_and_plot(model, image_path, save_path):
    """Classify one image file, print the result and save the image titled with the label.

    The image is read with OpenCV (BGR), resized to SIZE x SIZE, converted
    to RGB and fed to the model as float32 pixel values in [0, 255]. The
    model's own Rescaling layer normalises them, as during training.

    Args:
        model: a trained Keras classifier with NLABELS softmax outputs.
        image_path: path of the image to classify.
        save_path: where to save the titled image via ``plt.savefig``.

    Returns:
        ``(label, probs)``: the predicted class name and the 1-D array of
        class probabilities.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    image_bgr = cv.imread(image_path)
    if image_bgr is None:
        # cv.imread signals a missing/unreadable file by returning None
        # instead of raising; fail here with a clear message.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image_rgb = cv.cvtColor(cv.resize(image_bgr, (SIZE, SIZE)), cv.COLOR_BGR2RGB)

    batch = tf.expand_dims(tf.constant(image_rgb, dtype=tf.float32), axis=0)
    probs = model(batch, training=False).numpy()[0]
    label = LABELS[int(np.argmax(probs))]
    print(label, str(probs))

    # Draw on a new figure. Previously the image was drawn onto whatever
    # figure was still current (the training-curve plot).
    plt.figure()
    plt.imshow(image_rgb)
    plt.title(label)
    plt.axis('off')
    plt.savefig(save_path, bbox_inches='tight')
    return label, probs


for snapshot_path, prediction_path in (
    ('../dataset/snapshots/Viola_Tricolor.jpg', 'assets/Xception_Prediction_01.webp'),
    ('../dataset/snapshots/Strelitzia.jpg', 'assets/Xception_Prediction_02.webp'),
    ('../dataset/snapshots/Water_Lilly.jpg', 'assets/Xception_Prediction_03.webp'),
):
    predict_and_plot(xception_model, snapshot_path, prediction_path)
# Grid of up to 16 validation images with their true and predicted labels.
plt.figure(figsize=(16, 16))
for images, labels in testing_dataset.take(1):
    # One forward pass for the whole batch instead of one model call per image.
    batch_probs = xception_model(images, training=False).numpy()
    true_idx = np.argmax(labels.numpy(), axis=-1)
    pred_idx = np.argmax(batch_probs, axis=-1)
    # Guard against batches smaller than the 4x4 grid (e.g. BATCH < 16).
    for i in range(min(16, len(true_idx))):
        plt.subplot(4, 4, i + 1)
        plt.title(
            "True: " + LABELS[true_idx[i]] + "\n"
            + "Predicted: " + LABELS[pred_idx[i]]
        )
        # Pixels are float values in [0, 255]; scale to [0, 1] for imshow.
        plt.imshow(images[i] / 255.)
        plt.axis('off')
plt.savefig('assets/Xception_FT_03.webp', bbox_inches='tight')
# Collect per-batch predictions and one-hot labels from ONE pass over the
# validation set. The dataset reshuffles on every iteration, so predictions
# and labels must come from the same loop to stay aligned.
y_pred = []
y_test = []
for batch_images, batch_labels in testing_dataset:
    y_pred.append(xception_model(batch_images, training=False).numpy())
    y_test.append(batch_labels.numpy())

# Concatenate all batches, including the final partial batch. The previous
# `[:-1]` slicing silently dropped it (18 of the 3090 validation images).
true_classes = np.argmax(np.concatenate(y_test), axis=-1)
pred_classes = np.argmax(np.concatenate(y_pred), axis=-1)

conf_mtx = ConfusionMatrixDisplay(
    confusion_matrix=confusion_matrix(
        true_classes,
        pred_classes,
        # Fix the matrix to NLABELS x NLABELS so it always matches
        # display_labels, even if some class is never seen or predicted.
        labels=np.arange(NLABELS),
    ),
    display_labels=LABELS,
)
fig, ax = plt.subplots(figsize=(16, 12))
conf_mtx.plot(ax=ax, cmap='plasma', include_values=True, xticks_rotation='vertical')
plt.savefig('assets/Xception_FT_04.webp', bbox_inches='tight')
## Saving the Model
# Export the trained model in TensorFlow SavedModel format ('tf'), then
# reload it to check that architecture and weights round-trip.
saved_model_path = '../saved_model/xception_model'
tf.keras.saving.save_model(
    xception_model,
    saved_model_path,
    overwrite=True,
    save_format='tf',
)

restored_model2 = tf.keras.saving.load_model(saved_model_path)
restored_model2.summary()

# In the recorded run this matched the original model's evaluation:
restored_model2.evaluate(testing_dataset)
# loss: 0.6363 - accuracy: 0.8188