- Import all necessary modules:
import os
import numpy as np
import tarfile
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import get_file
- Define an alias to the tf.data.experimental.AUTOTUNE option, which we'll use later:

AUTOTUNE = tf.data.experimental.AUTOTUNE
- Define a function to create a residual module in the ResNet architecture. Let's start by specifying the function signature and implementing the first block:

def residual_module(data,
                    filters,
                    stride,
                    reduce=False,
                    reg=0.0001,
                    bn_eps=2e-5,
                    bn_momentum=0.9):
    # First block: pre-activation (BN -> ReLU), then a 1x1
    # convolution with a quarter of the filters.
    bn_1 = BatchNormalization(axis=-1,
                              epsilon=bn_eps,
                              momentum=bn_momentum)(data)
    act_1 = ReLU()(bn_1)
    conv_1 = Conv2D(filters=int(filters / 4.),
                    kernel_size=(1, 1),
                    use_bias=False,
                    kernel_regularizer=l2(reg))(act_1)
Let's now implement the second and third blocks:

    bn_2 = BatchNormalization(axis=-1,
                              epsilon=bn_eps,
                              momentum=bn_momentum)(conv_1)
    act_2 = ReLU()(bn_2)
    conv_2 = Conv2D(filters=int(filters / 4.),
                    kernel_size=(3, 3),
                    strides=stride,
                    padding='same',
                    use_bias=False,
                    kernel_regularizer=l2(reg))(act_2)

    bn_3 = BatchNormalization(axis=-1,
                              epsilon=bn_eps,
                              momentum=bn_momentum)(conv_2)
    act_3 = ReLU()(bn_3)
    conv_3 = Conv2D(filters=filters,
                    kernel_size=(1, 1),
                    use_bias=False,
                    kernel_regularizer=l2(reg))(act_3)
If reduce=True, we apply a 1x1 convolution to the shortcut branch so that its dimensions match those of the main branch; otherwise, the shortcut is simply the identity (the input data):

    shortcut = data  # Identity shortcut by default.
    if reduce:
        shortcut = Conv2D(filters=filters,
                          kernel_size=(1, 1),
                          strides=stride,
                          use_bias=False,
                          kernel_regularizer=l2(reg))(act_1)
Finally, we add the shortcut and the third block together and return the result as the module's output:

    x = Add()([conv_3, shortcut])

    return x
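To make sure the module is wired correctly, here's a minimal sanity check (a sketch, not part of the recipe itself) that pushes a dummy input through residual_module and prints the resulting output shape:

# Sketch only: reduce=True lets the module change both the spatial
# size (via the stride) and the channel count (via filters).
inputs = Input(shape=(32, 32, 64))
outputs = residual_module(inputs, filters=128, stride=(2, 2), reduce=True)
print(Model(inputs, outputs).output_shape)  # Expected: (None, 16, 16, 128)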
- Define a function to build a custom ResNet network:

def build_resnet(input_shape,
                 classes,
                 stages,
                 filters,
                 reg=1e-3,
                 bn_eps=2e-5,
                 bn_momentum=0.9):
    inputs = Input(shape=input_shape)
    x = BatchNormalization(axis=-1,
                           epsilon=bn_eps,
                           momentum=bn_momentum)(inputs)
    x = Conv2D(filters[0], (3, 3),
               use_bias=False,
               padding='same',
               kernel_regularizer=l2(reg))(x)

    for i in range(len(stages)):
        # Downsample at the start of every stage except the first.
        stride = (1, 1) if i == 0 else (2, 2)
        x = residual_module(data=x,
                            filters=filters[i + 1],
                            stride=stride,
                            reduce=True,
                            bn_eps=bn_eps,
                            bn_momentum=bn_momentum)

        for j in range(stages[i] - 1):
            x = residual_module(data=x,
                                filters=filters[i + 1],
                                stride=(1, 1),
                                bn_eps=bn_eps,
                                bn_momentum=bn_momentum)

    x = BatchNormalization(axis=-1,
                           epsilon=bn_eps,
                           momentum=bn_momentum)(x)
    x = ReLU()(x)
    x = AveragePooling2D((8, 8))(x)
    x = Flatten()(x)
    x = Dense(classes, kernel_regularizer=l2(reg))(x)
    x = Softmax()(x)

    return Model(inputs, x, name='resnet')
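As a quick sanity check (the stage and filter values below are illustrative only, not the ones we'll train with later), we can build a small variant and inspect it with summary(), which prints every layer's output shape along with the total parameter count:

# Sketch: a small ResNet just to verify the builder works end to end.
tiny_model = build_resnet(input_shape=(32, 32, 3),
                          classes=10,
                          stages=(3, 3, 3),
                          filters=(16, 16, 32, 64))
tiny_model.summary()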
- Define a function to load an image and its one-hot encoded label, based on its file path:

def load_image_and_label(image_path, target_size=(32, 32)):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, channels=3)
    image = tf.image.convert_image_dtype(image, np.float32)
    image -= CINIC_MEAN_RGB  # Mean normalize.
    image = tf.image.resize(image, target_size)

    # The class name is the parent folder of the image file.
    label = tf.strings.split(image_path, os.path.sep)[-2]
    label = (label == CINIC_10_CLASSES)  # One-hot encode.
    label = tf.dtypes.cast(label, tf.float32)

    return image, label
- Define a function to create a tf.data.Dataset instance of images and labels from a glob-like pattern that refers to the folder where the images are:

def prepare_dataset(data_pattern, shuffle=False):
    dataset = (tf.data.Dataset
               .list_files(data_pattern)
               .map(load_image_and_label,
                    num_parallel_calls=AUTOTUNE)
               .batch(BATCH_SIZE))

    if shuffle:
        dataset = dataset.shuffle(BUFFER_SIZE)

    return dataset.prefetch(BATCH_SIZE)
- Define the mean RGB values of the CINIC-10 dataset, which are used in the load_image_and_label() function to mean normalize the images (this information is available on the official CINIC-10 site):

CINIC_MEAN_RGB = np.array([0.47889522, 0.47227842, 0.43047404])
- Define the classes of the CINIC-10 dataset:

CINIC_10_CLASSES = ['airplane', 'automobile', 'bird', 'cat',
                    'deer', 'dog', 'frog', 'horse', 'ship',
                    'truck']
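With the classes in place, we can illustrate how load_image_and_label() derives a label: tf.strings.split() breaks the path on the separator, the second-to-last component is the class folder, and comparing it against CINIC_10_CLASSES produces a boolean one-hot vector. A small sketch with a made-up path:

# Illustration only; the path below is hypothetical.
path = tf.constant(os.path.sep.join(['data', 'train', 'dog', '0001.png']))
label = tf.strings.split(path, os.path.sep)[-2]           # b'dog'
one_hot = tf.cast(label == CINIC_10_CLASSES, tf.float32)
print(one_hot.numpy())  # [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]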
- Download and extract the CINIC-10 dataset to the ~/.keras/datasets directory:

DATASET_URL = ('https://datashare.is.ed.ac.uk/bitstream/handle/'
               '10283/3192/CINIC-10.tar.gz?'
               'sequence=4&isAllowed=y')
DATA_NAME = 'cinic10'
FILE_EXTENSION = 'tar.gz'
FILE_NAME = '.'.join([DATA_NAME, FILE_EXTENSION])

downloaded_file_location = get_file(origin=DATASET_URL,
                                    fname=FILE_NAME,
                                    extract=False)

data_directory, _ = (downloaded_file_location
                     .rsplit(os.path.sep, maxsplit=1))
data_directory = os.path.sep.join([data_directory,
                                   DATA_NAME])
tar = tarfile.open(downloaded_file_location)

if not os.path.exists(data_directory):
    tar.extractall(data_directory)
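A quick way to confirm the extraction worked (a sketch; the exact contents may vary) is to list the dataset directory, which should contain the train, test, and valid subfolders:

print(sorted(os.listdir(data_directory)))  # e.g. ['test', 'train', 'valid']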
- Define the glob-like patterns to the train, test, and validation subsets:
train_pattern = os.path.sep.join(
[data_directory, 'train/*/*.png'])
test_pattern = os.path.sep.join(
[data_directory, 'test/*/*.png'])
valid_pattern = os.path.sep.join(
[data_directory, 'valid/*/*.png'])
- Prepare the datasets:

BATCH_SIZE = 128
BUFFER_SIZE = 1024
train_dataset = prepare_dataset(train_pattern,
                                shuffle=True)
test_dataset = prepare_dataset(test_pattern)
valid_dataset = prepare_dataset(valid_pattern)
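Optionally, we can pull a single batch to confirm the pipeline produces the shapes we expect (a sketch, assuming the dataset downloaded and extracted correctly):

for images, labels in train_dataset.take(1):
    print(images.shape)  # (128, 32, 32, 3)
    print(labels.shape)  # (128, 10)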
- Build, compile, and train a ResNet model. Because this is a time-consuming process, we'll save a version of the model after each epoch, using the ModelCheckpoint() callback:

model = build_resnet(input_shape=(32, 32, 3),
                     classes=10,
                     stages=(9, 9, 9),
                     filters=(64, 64, 128, 256),
                     reg=5e-3)
model.compile(loss='categorical_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

model_checkpoint_callback = ModelCheckpoint(
    filepath='./model.{epoch:02d}-{val_accuracy:.2f}.hdf5',
    save_weights_only=False,
    monitor='val_accuracy')

EPOCHS = 100
model.fit(train_dataset,
          validation_data=valid_dataset,
          epochs=EPOCHS,
          callbacks=[model_checkpoint_callback])
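Because save_weights_only=False stores the full model in each checkpoint, training can also be resumed if it gets interrupted. A sketch, assuming a checkpoint from epoch 20 exists on disk (the filename below is hypothetical):

model = load_model('model.20-0.65.hdf5')  # Hypothetical checkpoint file.
model.fit(train_dataset,
          validation_data=valid_dataset,
          epochs=EPOCHS,
          initial_epoch=20,  # Resume the epoch count where we left off.
          callbacks=[model_checkpoint_callback])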
- Load the best model (in this case, model.38-0.72.hdf5) and evaluate it on the test set:

model = load_model('model.38-0.72.hdf5')
result = model.evaluate(test_dataset)
print(f'Test accuracy: {result[1]}')
This prints the following:
Test accuracy: 0.71956664
Let's learn how it all works in the next section.
A residual module comprises two branches: the first one is the skip connection, also known as the shortcut branch, which is simply the identity of the input (unless reduce=True, in which case a 1x1 convolution adapts its dimensions). The second, or main, branch is composed of three convolution blocks: a 1x1 with a quarter of the filters, a 3x3 one, also with a quarter of the filters, and finally another 1x1, which uses all the filters. The shortcut and main branches are added together at the end using the Add() layer.
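Note that Add() sums the two branches element-wise, which is why both must produce tensors of identical shape; it does not stack them. A quick sketch contrasting it with Concatenate():

a = tf.ones((1, 8, 8, 64))
b = tf.ones((1, 8, 8, 64))
print(Add()([a, b]).shape)          # (1, 8, 8, 64): element-wise sum.
print(Concatenate()([a, b]).shape)  # (1, 8, 8, 128): channels stacked.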