Breast Histopathology Image Segmentation Part 2
- Part 1: Data Inspection and Pre-processing
- Part 2: Weights, Data Augmentations and Generators
- Part 3: Model creation based on a pre-trained and a custom model
- Part 4: Train our model to fit the dataset
- Part 5: Evaluate the performance of your trained model
- Part 6: Running Predictions
Based on Breast Histopathology Images by Paul Mooney.
Invasive Ductal Carcinoma (IDC) is the most common subtype of all breast cancers. To assign an aggressiveness grade to a whole mount sample, pathologists typically focus on the regions which contain the IDC. As a result, one of the common pre-processing steps for automatic aggressiveness grading is to delineate the exact regions of IDC inside of a whole mount slide.
Can recurring breast cancer be spotted with AI tech? - BBC News
- Citation: Deep learning for digital pathology image analysis: A comprehensive tutorial with selected use cases
- Dataset: 198,738 IDC(negative) image patches; 78,786 IDC(positive) image patches
Skewed Datasets
Before we can start the training, we have to provide a few helper functions. These are identical for the custom model and the pre-trained ResNet50:
./train_CustomModel_32_conv_20k.py ./train_ResNet50_32_20k.py
Adding Weights to balance Data Classes
Get number of image files from a path:
import glob

# Method to get the number of files given a path
def retrieveNumberOfFiles(path):
    # Recursively collect all PNG files below the given path and count them
    list1 = []
    for file_name in glob.iglob(path + '/**/*.png', recursive=True):
        list1.append(file_name)
    return len(list1)
# Defining the paths to the training, validation, and testing directories
trainPath = config.TRAIN_PATH
valPath = config.VAL_PATH
testPath = config.TEST_PATH
# Checking for the total number of images
totalTrain = retrieveNumberOfFiles(config.TRAIN_PATH)
totalVal = retrieveNumberOfFiles(config.VAL_PATH)
totalTest = retrieveNumberOfFiles(config.TEST_PATH)
Get list of image files from a path:
# Defining a method to get the list of files given a path
def getAllFiles(path):
    list1 = []
    for file_name in glob.iglob(path + '/**/*.png', recursive=True):
        list1.append(file_name)
    return list1
# Retrieving all files from train directory
allTrainFiles = getAllFiles(config.TRAIN_PATH)
Get the number of benign and malignant images and create a weight for each class. This helps to prevent the model from becoming biased toward the dominant class (as we have seen in Part 1, the number of benign cases is far greater than the number of malignant ones):
# Calculate the number of training images in each class and store the class weights in a dictionary
import os
from tensorflow.keras.utils import to_categorical

# The class label (0 = IDC negative, 1 = IDC positive) is the name of each file's parent directory
trainLabels = [int(p.split(os.path.sep)[-2]) for p in allTrainFiles]
trainLabels = to_categorical(trainLabels)
classSumTotals = trainLabels.sum(axis=0)
classWeight = dict()

# Loop over all classes and calculate the class weights
for i in range(0, len(classSumTotals)):
    classWeight[i] = classSumTotals.max() / classSumTotals[i]
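To make the effect concrete, using the overall patch counts quoted above (198,738 negative vs. 78,786 positive; the counts in the actual training split will be somewhat lower), the loop would yield roughly:
# Illustrative values only, computed from the full dataset counts rather than the training split
# classSumTotals ≈ [198738., 78786.]
# classWeight    ≈ {0: 1.0, 1: 2.52}   # 198738 / 78786 ≈ 2.52
The malignant class therefore contributes about 2.5 times as much to the loss as the benign class; the dictionary is later passed to model.fit() via its class_weight argument (see the sketch at the end of this part).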
Plotting Training Progress
A method to plot the training accuracy and loss so we can visualize how the training is progressing:
# Defining a method to plot training and validation accuracy and loss
# H        - the History object returned by model.fit
# N        - number of training epochs
# plotPath - where to store the output file
import matplotlib.pyplot as plt
import numpy as np

def training_plot(H, N, plotPath):
    plt.style.use("ggplot")
    plt.figure()
    plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
    plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
    plt.plot(np.arange(0, N), H.history["accuracy"], label="train_acc")
    plt.plot(np.arange(0, N), H.history["val_accuracy"], label="val_acc")
    plt.title("Training Loss and Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Loss/Accuracy")
    plt.legend(loc="lower left")
    plt.savefig(plotPath)
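Once training has finished (Part 4), the helper can be called on the History object returned by model.fit; the epoch count and output path here are just placeholder values:
# Hypothetical call - H is returned by model.fit in Part 4
training_plot(H, 40, "plots/training_progress.png")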
Data Augmentation
To improve the performance of an ML model, the data it is trained on should be large and diverse. Augmentation is used to increase the amount of data by adding slightly modified copies of existing samples to the training dataset.
In the case of images this can be achieved by padding, random rotation, re-scaling, vertical and horizontal flipping, translation, cropping, zooming, darkening/brightening, adding noise or colour modifications.
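Most of these operations map directly onto arguments of the Keras ImageDataGenerator. The brightness and colour modifications mentioned above are not used in the training scripts below, but as a small sketch they would look like this (the value ranges are arbitrary examples):
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Sketch only - brightness/colour augmentation is not part of the training scripts below
colourAug = ImageDataGenerator(
    brightness_range=(0.8, 1.2),  # randomly darken/brighten images by up to 20%
    channel_shift_range=20.0)     # apply random shifts to the colour channels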
./train_CustomModel_32_conv_20k.py ./train_ResNet50_32_20k.py
Training augmentation:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Initialize the training data augmentation object
## rescale scales the input pixels from [0, 255] to [0, 1]
## rotation_range is a value in degrees (0-180), a range within which to randomly rotate pictures
## zoom_range is for randomly zooming inside pictures
## width_shift and height_shift are ranges (as a fraction of total width or height) within which to randomly translate pictures vertically or horizontally
## shear_range is for randomly applying shearing transformations
## horizontal_flip and vertical_flip are for randomly flipping images horizontally and vertically, respectively
## fill_mode is the strategy used for filling in newly created pixels, which can appear after a rotation or a width/height shift
trainAug = ImageDataGenerator(
    rescale=1 / 255.0,
    rotation_range=30,
    zoom_range=0.15,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.15,
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode="nearest")
But only use re-scaling for the validation augmentation:
# Initialize the validation data augmentation object
valAug = ImageDataGenerator(rescale=1/255.0)
Data Generators
Use data generators like the Keras ImageDataGenerator to limit the amount of memory your dataset occupies. Data generators allow you to augment your data in real time while your model is training, so only the current batches need to be loaded into GPU memory while the next ones are still being generated in parallel by the CPU. The augmentation objects defined above are wrapped with flow_from_directory to stream batches directly from the image directories:
# Initialize the training generator
trainGen = trainAug.flow_from_directory(
    trainPath,
    class_mode="categorical",
    target_size=(48, 48),
    color_mode="rgb",
    shuffle=True,
    batch_size=config.BATCH_SIZE)
# Initialize the validation generator
valGen = valAug.flow_from_directory(
    valPath,
    class_mode="categorical",
    target_size=(48, 48),
    color_mode="rgb",
    shuffle=False,
    batch_size=config.BATCH_SIZE)
# Initialize the testing generator
testGen = valAug.flow_from_directory(
    testPath,
    class_mode="categorical",
    target_size=(48, 48),
    color_mode="rgb",
    shuffle=False,
    batch_size=config.BATCH_SIZE)
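With the class weights and the three generators in place, everything needed for training is ready. As a rough, hypothetical sketch of how they come together (the actual model is built in Part 3 and trained in Part 4; the epoch count is a placeholder):
# Sketch only - "model" is created in Part 3; 40 epochs is a placeholder value
H = model.fit(
    x=trainGen,
    steps_per_epoch=totalTrain // config.BATCH_SIZE,
    validation_data=valGen,
    validation_steps=totalVal // config.BATCH_SIZE,
    class_weight=classWeight,
    epochs=40)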