Build a Fine-Tuned Neural Network with TensorFlow's Keras API
In this episode, we'll demonstrate how to fine-tune a pre-trained model to classify images as either cats or dogs.
VGG16 and ImageNet
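The pre-trained model we'll be working with to classify images of cats and dogs is called VGG16. It is the architecture behind the submission that took first place in the localization track and second place in the classification track of the 2014 ImageNet competition.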
In the ImageNet competition, multiple teams compete to build a model that best classifies images from the ImageNet library. The classification task uses a subset of ImageNet containing more than a million training images belonging to 1000 different categories.
We'll import this VGG16 model and then fine-tune it using Keras. The fine-tuned model will not classify images into the 1000 categories it was trained on; instead, it will only classify images as either cats or dogs.
Note that dogs and cats were included in the ImageNet library on which VGG16 was originally trained. Therefore, the model has already learned the features of cats and dogs, and the fine-tuning we'll do on this model will be very minimal. In later episodes, we'll do more involved fine-tuning and use transfer learning to classify data that is completely different from what was included in the original training set.
To understand fine-tuning and transfer learning on a fundamental level, check out the corresponding episode in the Deep Learning Fundamentals course.
VGG16 Preprocessing
Let's first check out a batch of training data using the plotting function we brought in previously.
imgs, labels = next(train_batches)
plotImages(imgs)
print(labels)
[[1. 0.]
[0. 1.]
[1. 0.]
[0. 1.]
[0. 1.]
[1. 0.]
[0. 1.]
[0. 1.]
[0. 1.]
[1. 0.]]
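Each row of labels is a one-hot vector. A minimal NumPy sketch of turning these rows back into class names follows. It assumes that index 0 corresponds to cat and index 1 to dog, which matches the alphabetical class ordering Keras uses when classes are taken from subdirectory names; check train_batches.class_indices to confirm the mapping for your own data. The helper decode_one_hot is hypothetical.

```python
import numpy as np

# Assumed class order: index 0 -> "cat", index 1 -> "dog" (verify via train_batches.class_indices)
CLASS_NAMES = ["cat", "dog"]

def decode_one_hot(labels, class_names=CLASS_NAMES):
    """Map each one-hot row to the class name at the position of its 1."""
    indices = np.argmax(np.asarray(labels), axis=1)
    return [class_names[i] for i in indices]

# The label batch printed above
labels = np.array([[1., 0.], [0., 1.], [1., 0.], [0., 1.], [0., 1.],
                   [1., 0.], [0., 1.], [0., 1.], [0., 1.], [1., 0.]])
print(decode_one_hot(labels))
```

Under that assumption, the printed batch contains four cats and six dogs.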
When we previously inspected these images, we briefly discussed that the color data was skewed as a result of preprocessing the images with the tf.keras.applications.vgg16.preprocess_input function.
To understand what preprocessing is needed for images that will be passed to a VGG16 model, we can look at the VGG16 paper.
In Section 2.1 (Architecture) of the paper, the authors state, "The only preprocessing we do is subtracting the mean RGB value, computed on the training set, from each pixel."
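This is the preprocessing that was applied to the original training data, so this is how we need to process images before passing them to VGG16 or to a fine-tuned VGG16 model.
This processing is what makes the underlying color data look distorted. Keras's preprocess_input for VGG16 also reorders the color channels from RGB to BGR before subtracting the per-channel ImageNet means, which adds to the effect.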
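To see why the colors look off, here is a NumPy sketch of the same kind of transformation. It assumes the behavior the Keras documentation describes for VGG16's preprocess_input (the "caffe" mode): channels are reordered from RGB to BGR, then the per-channel ImageNet means are subtracted, with no scaling. The function name vgg16_style_preprocess is hypothetical; in practice, use tf.keras.applications.vgg16.preprocess_input itself.

```python
import numpy as np

# Per-channel ImageNet means in B, G, R order (the values used by Keras' caffe-style preprocessing)
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)

def vgg16_style_preprocess(images_rgb):
    """Sketch of VGG16 preprocessing: RGB -> BGR, then zero-center each channel."""
    x = np.asarray(images_rgb, dtype=np.float32)
    x = x[..., ::-1]                 # reverse the channel axis: RGB -> BGR
    return x - IMAGENET_MEAN_BGR     # subtract the mean from every pixel, per channel

# A 1x2 "image": one pure-red pixel and one pixel equal to the mean color
img = np.array([[[255.0, 0.0, 0.0], [123.68, 116.779, 103.939]]])
print(vgg16_style_preprocess(img))
```

The pure-red pixel becomes roughly [-103.9, -116.8, 131.3], with its largest value now in the last channel, and the mean-colored pixel becomes all zeros. Values like these, plotted as if they were ordinary RGB, no longer resemble the original colors.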
Building a fine-tuned model
Now, let's begin building our model. First, be sure that you still have all the imports that we brought in a couple episodes back when we began our work on CNNs.
Next, we'll import the VGG16 model from Keras. Note, an internet connection is needed to download this model.
vgg16_model = tf.keras.applications.vgg16.VGG16()
The original trained VGG16 model, along with its saved weights and other parameters, is now downloaded onto our machine.
We can check out a summary of the model just to see what the architecture looks like.
vgg16_model.summary()
Model: "vgg16"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 224, 224, 3)] 0
_________________________________________________________________
block1_conv1 (Conv2D) (None, 224, 224, 64) 1792
_________________________________________________________________
block1_conv2 (Conv2D) (None, 224, 224, 64) 36928
_________________________________________________________________
block1_pool (MaxPooling2D) (None, 112, 112, 64) 0
_________________________________________________________________
block2_conv1 (Conv2D) (None, 112, 112, 128) 73856
_________________________________________________________________
block2_conv2 (Conv2D) (None, 112, 112, 128) 147584
_________________________________________________________________
block2_pool (MaxPooling2D) (None, 56, 56, 128) 0
_________________________________________________________________
block3_conv1 (Conv2D) (None, 56, 56, 256) 295168
_________________________________________________________________
block3_conv2 (Conv2D) (None, 56, 56, 256) 590080
_________________________________________________________________
block3_conv3 (Conv2D) (None, 56, 56, 256) 590080
_________________________________________________________________
block3_pool (MaxPooling2D) (None, 28, 28, 256) 0
_________________________________________________________________
block4_conv1 (Conv2D) (None, 28, 28, 512) 1180160
_________________________________________________________________
block4_conv2 (Conv2D) (None, 28, 28, 512) 2359808
_________________________________________________________________
block4_conv3 (Conv2D) (None, 28, 28, 512) 2359808
_________________________________________________________________
block4_pool (MaxPooling2D) (None, 14, 14, 512) 0
_________________________________________________________________
block5_conv1 (Conv2D) (None, 14, 14, 512) 2359808
_________________________________________________________________
block5_conv2 (Conv2D) (None, 14, 14, 512) 2359808
_________________________________________________________________
block5_conv3 (Conv2D) (None, 14, 14, 512) 2359808
_________________________________________________________________
block5_pool (MaxPooling2D) (None, 7, 7, 512) 0
_________________________________________________________________
flatten (Flatten) (None, 25088) 0
_________________________________________________________________
fc1 (Dense) (None, 4096) 102764544
_________________________________________________________________
fc2 (Dense) (None, 4096) 16781312
_________________________________________________________________
predictions (Dense) (None, 1000) 4097000
=================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
_________________________________________________________________
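The Param # column can be checked by hand. A Conv2D layer with a k × k kernel has (k × k × in_channels + 1) × filters parameters (the +1 is each filter's bias), and a Dense layer has (inputs + 1) × units. This short pure-Python sketch applies those formulas to VGG16's configuration (3 × 3 kernels throughout, filter counts taken from the summary above) and reproduces the total.

```python
def conv2d_params(kernel, in_ch, filters):
    # kernel*kernel*in_ch weights per filter, plus one bias per filter
    return (kernel * kernel * in_ch + 1) * filters

def dense_params(inputs, units):
    # one weight per input per unit, plus one bias per unit
    return (inputs + 1) * units

# Filters per conv layer for each of VGG16's five blocks, as listed in the summary
blocks = [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]]

conv_total = 0
in_ch = 3  # RGB input
for block in blocks:
    for filters in block:
        conv_total += conv2d_params(3, in_ch, filters)
        in_ch = filters  # the next conv layer sees this layer's output channels

flat = 7 * 7 * 512  # block5_pool output, flattened -> 25088
dense_total = (dense_params(flat, 4096)      # fc1
               + dense_params(4096, 4096)    # fc2
               + dense_params(4096, 1000))   # predictions

print(conv_total, dense_total, conv_total + dense_total)  # → 14714688 123642856 138357544
```

Almost 90% of the parameters sit in the three Dense layers, and fc1 alone holds over 100 million of them.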
Recall how much simpler the CNN from the last episode was. VGG16 is much more complex and has many more layers than our previous model.
Notice that the last Dense layer of VGG16 has 1000 outputs. These outputs correspond to the 1000 categories in the ImageNet library.
Since we're only going to classify two categories, cats and dogs, we need to modify this model so that it classifies only those two categories.
Before we do that, note that the Keras models we've worked with so far in this series have been of type Sequential. If we check the type of vgg16_model, we see that it is of type Model, which comes from the Keras Functional API.
type(vgg16_model)
tensorflow.python.keras.engine.training.Model
We haven't yet worked with the more sophisticated Functional API, although we will in later episodes when we use the MobileNet model.
For now, we'll convert the Functional model to a Sequential model, since that will be easier to work with given our current knowledge.
We first create a new model of type Sequential. We then iterate over each of the layers in vgg16_model, except for the last layer, and add each layer to the new Sequential model.
from tensorflow.keras.models import Sequential

model = Sequential()
for layer in vgg16_model.layers[:-1]:  # every layer except the 1000-way output layer
    model.add(layer)
Now we have replicated the entire vgg16_model (excluding the output layer) in a new Sequential model, which we've named model.
Next, we'll iterate over each of the layers in our new Sequential model and set them to be non-trainable. This freezes the weights and other trainable parameters in each layer so that they will not be updated when we later pass in our images of cats and dogs.
for layer in model.layers:
    layer.trainable = False  # freeze this layer's weights
We don't want to retrain these layers because, as mentioned earlier, cats and dogs were already included in the original ImageNet library, so VGG16 already does a nice job of classifying these categories. We only want to modify the model so that the output layer knows how to classify cats and dogs and nothing else. Therefore, we don't want any retraining to occur on the earlier layers.
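Conceptually, freezing means the optimizer skips those parameters on every update. In Keras, the training step computes and applies gradients only for the model's trainable weights. The toy sketch below is not Keras code; it only illustrates the effect with one plain gradient-descent step over a hypothetical dictionary of layers, where a frozen "base" stays put and a trainable "head" moves.

```python
import numpy as np

def sgd_step(layers, grads, lr=0.1):
    """Apply one gradient-descent step, skipping layers marked as non-trainable."""
    for name, layer in layers.items():
        if layer["trainable"]:
            layer["w"] = layer["w"] - lr * grads[name]
    return layers

# Hypothetical two-layer model: a frozen "base" and a trainable "head"
layers = {
    "base": {"w": np.ones(3), "trainable": False},
    "head": {"w": np.ones(2), "trainable": True},
}
grads = {"base": np.full(3, 5.0), "head": np.full(2, 5.0)}

sgd_step(layers, grads)
print(layers["base"]["w"], layers["head"]["w"])  # base unchanged, head moved by -lr * grad
```

Even though the frozen layer receives a nonzero gradient here, its weights are never touched, which is the behavior we want for the VGG16 layers.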
Next, we add our new output layer, consisting of only 2 nodes that correspond to cat and dog. This output layer will be the only trainable layer in the model.
from tensorflow.keras.layers import Dense
model.add(Dense(units=2, activation='softmax'))
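The new Dense(units=2, activation='softmax') layer takes the 4096 features produced by fc2, computes a weighted sum plus a bias for each of its 2 units, and turns the two scores into probabilities with softmax. Here is a NumPy sketch of that forward computation, with small random weights standing in for the ones training will learn. The function names and the initialization are illustrative assumptions, not Keras internals.

```python
import numpy as np

def softmax(z):
    # Subtracting the row max improves numerical stability without changing the result
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def dense_softmax(features, W, b):
    """Forward pass of a Dense layer with softmax activation: softmax(xW + b)."""
    return softmax(features @ W + b)

rng = np.random.default_rng(0)
W = rng.normal(scale=0.01, size=(4096, 2))  # one column of weights per output unit
b = np.zeros(2)                              # one bias per output unit

features = rng.normal(size=(5, 4096))        # stand-in for a batch of 5 fc2 outputs
probs = dense_softmax(features, W, b)

print(probs.shape, W.size + b.size)  # (5, 2) 8194
```

Each row of probs sums to 1 (a cat probability and a dog probability), and the weight and bias counts add up to the 8,194 parameters that model.summary() reports for this layer.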
We can now check out a summary of our model and see that everything is the same as in the original vgg16_model, except that the output layer now has only 2 nodes rather than 1000, and the number of trainable parameters has decreased drastically because we froze all the parameters in the earlier layers.
model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
block1_conv1 (Conv2D) (None, 224, 224, 64) 1792
_________________________________________________________________
block1_conv2 (Conv2D) (None, 224, 224, 64) 36928
_________________________________________________________________
block1_pool (MaxPooling2D) (None, 112, 112, 64) 0
_________________________________________________________________
block2_conv1 (Conv2D) (None, 112, 112, 128) 73856
_________________________________________________________________
block2_conv2 (Conv2D) (None, 112, 112, 128) 147584
_________________________________________________________________
block2_pool (MaxPooling2D) (None, 56, 56, 128) 0
_________________________________________________________________
block3_conv1 (Conv2D) (None, 56, 56, 256) 295168
_________________________________________________________________
block3_conv2 (Conv2D) (None, 56, 56, 256) 590080
_________________________________________________________________
block3_conv3 (Conv2D) (None, 56, 56, 256) 590080
_________________________________________________________________
block3_pool (MaxPooling2D) (None, 28, 28, 256) 0
_________________________________________________________________
block4_conv1 (Conv2D) (None, 28, 28, 512) 1180160
_________________________________________________________________
block4_conv2 (Conv2D) (None, 28, 28, 512) 2359808
_________________________________________________________________
block4_conv3 (Conv2D) (None, 28, 28, 512) 2359808
_________________________________________________________________
block4_pool (MaxPooling2D) (None, 14, 14, 512) 0
_________________________________________________________________
block5_conv1 (Conv2D) (None, 14, 14, 512) 2359808
_________________________________________________________________
block5_conv2 (Conv2D) (None, 14, 14, 512) 2359808
_________________________________________________________________
block5_conv3 (Conv2D) (None, 14, 14, 512) 2359808
_________________________________________________________________
block5_pool (MaxPooling2D) (None, 7, 7, 512) 0
_________________________________________________________________
flatten (Flatten) (None, 25088) 0
_________________________________________________________________
fc1 (Dense) (None, 4096) 102764544
_________________________________________________________________
fc2 (Dense) (None, 4096) 16781312
_________________________________________________________________
dense_1 (Dense) (None, 2) 8194
=================================================================
Total params: 134,268,738
Trainable params: 8,194
Non-trainable params: 134,260,544
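These totals follow directly from swapping the output layer. A worked check using the Dense parameter formula (inputs + 1) × units and the VGG16 total of 138,357,544:

```latex
\begin{aligned}
\text{dense\_1} &= (4096 + 1) \times 2 = 8{,}194 \\
\text{predictions} &= (4096 + 1) \times 1000 = 4{,}097{,}000 \\
\text{Total} &= 138{,}357{,}544 - 4{,}097{,}000 + 8{,}194 = 134{,}268{,}738 \\
\text{Non-trainable} &= 134{,}268{,}738 - 8{,}194 = 134{,}260{,}544
\end{aligned}
```

In other words, the non-trainable count is exactly the parameter count of VGG16 without its original 1000-way output layer.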
In the next episode, we'll see how we can train this modified model on our images of cats and dogs.