Validation, regularization, callbacks


Coding Tutorial
In [45]:
import tensorflow as tf
print(tf.__version__)
2.0.0

Validation, regularisation and callbacks


Validation sets

Load the data

In [46]:
# Load the diabetes dataset

from sklearn.datasets import load_diabetes

diabetes_dataset = load_diabetes()
print(diabetes_dataset['DESCR'])
.. _diabetes_dataset:

Diabetes dataset
----------------

Ten baseline variables, age, sex, body mass index, average blood
pressure, and six blood serum measurements were obtained for each of n =
442 diabetes patients, as well as the response of interest, a
quantitative measure of disease progression one year after baseline.

**Data Set Characteristics:**

  :Number of Instances: 442

  :Number of Attributes: First 10 columns are numeric predictive values

  :Target: Column 11 is a quantitative measure of disease progression one year after baseline

  :Attribute Information:
      - Age
      - Sex
      - Body mass index
      - Average blood pressure
      - S1
      - S2
      - S3
      - S4
      - S5
      - S6

Note: Each of these 10 feature variables have been mean centered and scaled by the standard deviation times `n_samples` (i.e. the sum of squares of each column totals 1).

Source URL:
http://www4.stat.ncsu.edu/~boos/var.select/diabetes.html

For more information see:
Bradley Efron, Trevor Hastie, Iain Johnstone and Robert Tibshirani (2004) "Least Angle Regression," Annals of Statistics (with discussion), 407-499.
(http://web.stanford.edu/~hastie/Papers/LARS/LeastAngle_2002.pdf)
In [47]:
# Save the input and target variables
#print(diabetes_dataset.keys())

data = diabetes_dataset["data"]
target = diabetes_dataset["target"]
In [48]:
# Normalise the target data (this will make the training curves clearer)

targets = (target - target.mean(axis=0)) / target.std()
targets
Out[48]:
array([-1.47194752e-02, -1.00165882e+00, -1.44579915e-01, ...,
       -2.61454310e-01,  8.81317557e-01, -1.23540761e+00])
In [49]:
# Split the data into train and test sets
from sklearn.model_selection import train_test_split

train_data, test_data, train_targets, test_targets = train_test_split(data, targets, test_size=0.1)
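Note that train_test_split shuffles the data with a new random seed on each run, so the exact split (and the numbers reported below) will vary between executions. A minimal sketch of a reproducible split, assuming we are happy to fix an arbitrary seed value such as 42:

# Reproducible train/test split (sketch; the seed value 42 is an arbitrary choice)
train_data, test_data, train_targets, test_targets = train_test_split(
    data, targets, test_size=0.1, random_state=42)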

Train a feedforward neural network model

In [50]:
print(train_data.shape)
print(test_data.shape)
print(train_targets.shape)
print(test_targets.shape)
(397, 10)
(45, 10)
(397,)
(45,)
In [51]:
# Build the model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

def get_model():
    model = Sequential([
        Dense(128, activation="relu", input_shape=(train_data.shape[1],)),
        Dense(128, activation="relu"),
        Dense(128, activation="relu"),
        Dense(128, activation="relu"),
        Dense(128, activation="relu"),
        Dense(128, activation="relu"),
        Dense(1)
    ])
    return model

model = get_model()
In [52]:
# Print the model summary
model.summary()
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_19 (Dense)             (None, 128)               1408      
_________________________________________________________________
dense_20 (Dense)             (None, 128)               16512     
_________________________________________________________________
dense_21 (Dense)             (None, 128)               16512     
_________________________________________________________________
dense_22 (Dense)             (None, 128)               16512     
_________________________________________________________________
dense_23 (Dense)             (None, 128)               16512     
_________________________________________________________________
dense_24 (Dense)             (None, 128)               16512     
_________________________________________________________________
dense_25 (Dense)             (None, 1)                 129       
=================================================================
Total params: 84,097
Trainable params: 84,097
Non-trainable params: 0
_________________________________________________________________
In [53]:
# Compile the model

model.compile(optimizer="adam", loss="mse", metrics=["mae"])
In [54]:
# Train the model, with some of the data reserved for validation
history = model.fit(train_data, train_targets, epochs=100, validation_split=0.15, batch_size=64, verbose=False)
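The validation_split argument carves the validation set off the end of the training arrays before shuffling. If you want full control over which examples are used for validation, you can pass them explicitly instead; a minimal sketch, assuming hypothetical val_data and val_targets arrays produced by an earlier split:

# Sketch: explicit validation data instead of validation_split
# val_data and val_targets are assumed to come from a separate split
history = model.fit(train_data, train_targets, epochs=100, batch_size=64,
                    validation_data=(val_data, val_targets), verbose=False)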
In [55]:
# Evaluate the model on the test set

model.evaluate(test_data, test_targets, verbose=2)
45/1 - 0s - loss: 0.9737 - mae: 0.7681
Out[55]:
[0.9119720167583889, 0.7680769]

Plot the learning curves

In [56]:
import matplotlib.pyplot as plt
%matplotlib inline
In [57]:
# Plot the training and validation loss

plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Loss vs. epochs')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'], loc='upper right')
plt.show()
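Because the model was compiled with an "mae" metric, the same history object also records the training and validation MAE. A short sketch of the corresponding plot (the history key may appear as 'mae' or 'mean_absolute_error' depending on the Keras version):

# Sketch: plot the training and validation MAE from the same history object
plt.plot(history.history['mae'])
plt.plot(history.history['val_mae'])
plt.title('MAE vs. epochs')
plt.ylabel('MAE')
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'], loc='upper right')
plt.show()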

Model regularisation

Adding regularisation with weight decay and dropout

In [58]:
from tensorflow.keras.layers import Dropout
from tensorflow.keras import regularizers
In [74]:
def get_regularised_model(wd, rate):
    model = Sequential([
        Dense(128, kernel_regularizer=regularizers.l2(wd), activation="relu", input_shape=(train_data.shape[1],)),
        Dropout(rate),
        Dense(128, kernel_regularizer=regularizers.l2(wd), activation="relu"),
        Dropout(rate),
        Dense(128, kernel_regularizer=regularizers.l2(wd), activation="relu"),
        Dropout(rate),
        Dense(128, kernel_regularizer=regularizers.l2(wd), activation="relu"),
        Dropout(rate),
        Dense(128, kernel_regularizer=regularizers.l2(wd), activation="relu"),
        Dropout(rate),
        Dense(128, kernel_regularizer=regularizers.l2(wd), activation="relu"),
        Dropout(rate),
        Dense(1)
    ])
    return model
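Since the six regularised hidden layers are identical, the same architecture can also be written more compactly with a loop. A minimal sketch equivalent to get_regularised_model above (the function name get_regularised_model_loop is introduced here just for illustration):

# Sketch: equivalent model built in a loop, one Dense + Dropout block per hidden layer
def get_regularised_model_loop(wd, rate, n_hidden=6, units=128):
    model = Sequential()
    model.add(Dense(units, kernel_regularizer=regularizers.l2(wd),
                    activation="relu", input_shape=(train_data.shape[1],)))
    model.add(Dropout(rate))
    for _ in range(n_hidden - 1):
        model.add(Dense(units, kernel_regularizer=regularizers.l2(wd), activation="relu"))
        model.add(Dropout(rate))
    model.add(Dense(1))
    return model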
In [75]:
# Re-build the model with weight decay and dropout layers

model = get_regularised_model(1e-5, 0.3)
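Note that the Dropout layers are only stochastic during training; at evaluation and prediction time they pass their inputs through unchanged. You can see this by calling the model directly with the training flag set; a small sketch:

# Sketch: dropout is only active when the model is called with training=True
sample = tf.constant(train_data[:1], dtype=tf.float32)
print(model(sample, training=False).numpy())  # dropout disabled (inference behaviour)
print(model(sample, training=True).numpy())   # dropout enabled; repeated calls will differ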
In [76]:
# Compile the model
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
In [77]:
# Train the model, with some of the data reserved for validation

history = model.fit(train_data, train_targets, epochs=100, validation_split=0.15, batch_size=64, verbose=False)
In [78]:
# Evaluate the model on the test set

model.evaluate(test_data, test_targets, verbose=2)
45/1 - 0s - loss: 0.6757 - mae: 0.6096
Out[78]:
[0.6208774553404914, 0.6095614]

Plot the learning curves

In [79]:
# Plot the training and validation loss

import matplotlib.pyplot as plt

plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Loss vs. epochs')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'], loc='upper right')
plt.show()

Introduction to callbacks

Example training callback

In [92]:
# Write a custom callback
from tensorflow.keras.callbacks import Callback

class TrainingCallback(Callback):
    def on_train_begin(self, logs=None):
        print("Starting training....")

    def on_epoch_begin(self, epoch, logs=None):
        print(f"Starting epoch {epoch}")

    def on_train_batch_begin(self, batch, logs=None):
        print(f"Training: Starting batch {batch}")

    def on_train_batch_end(self, batch, logs=None):
        print(f"Training: Finished batch {batch}")

    def on_epoch_end(self, epoch, logs=None):
        print(f"Finished epoch {epoch}")

    def on_train_end(self, logs=None):
        print("Finished training!")
In [93]:
# Re-build the model
model = get_regularised_model(1e-5, 0.3)
In [94]:
# Compile the model

model.compile(optimizer="adam", loss="mse")

Train the model with the callback

In [98]:
# Train the model, with some of the data reserved for validation

model.fit(train_data, train_targets, epochs=3, batch_size=128, verbose=False, callbacks=[TrainingCallback()])
Starting training....
Starting epoch 0
Training: Starting batch 0
Training: Finished batch 0
Training: Starting batch 1
Training: Finished batch 1
Training: Starting batch 2
Training: Finished batch 2
Training: Starting batch 3
Training: Finished batch 3
Finished epoch 0
Starting epoch 1
Training: Starting batch 0
Training: Finished batch 0
Training: Starting batch 1
Training: Finished batch 1
Training: Starting batch 2
Training: Finished batch 2
Training: Starting batch 3
Training: Finished batch 3
Finished epoch 1
Starting epoch 2
Training: Starting batch 0
Training: Finished batch 0
Training: Starting batch 1
Training: Finished batch 1
Training: Starting batch 2
Training: Finished batch 2
Training: Starting batch 3
Training: Finished batch 3
Finished epoch 2
Finished training!
Out[98]:
<tensorflow.python.keras.callbacks.History at 0x7fd0ec0b0d68>
In [105]:
# Define a testing callback, then evaluate the model with it
class TestingCallback(Callback):
    def on_test_begin(self, logs=None):
        print("Starting testing....")

    def on_test_batch_begin(self, batch, logs=None):
        print(f"Testing: Starting batch {batch}")

    def on_test_batch_end(self, batch, logs=None):
        print(f"Testing: Finished batch {batch}")

    def on_test_end(self, logs=None):
        print("Finished testing!")

model.evaluate(test_data, test_targets, verbose=False, callbacks=[TestingCallback()])
Starting testing....
Testing: Starting batch 0
Testing: Finished batch 0
Testing: Starting batch 1
Testing: Finished batch 1
Finished testing!
Out[105]:
0.5224372175004747
In [106]:
# Make predictions with the model
class PredictionCallback(Callback):
    def on_predict_begin(self, logs=None):
        print("Starting Prediction....")

    def on_predict_batch_begin(self, batch, logs=None):
        print(f"Prediction: Starting batch {batch}")

    def on_predict_batch_end(self, batch, logs=None):
        print(f"Prediction: Finished batch {batch}")

    def on_predict_end(self, logs=None):
        print("Finished Prediction!")

model.predict(test_data, verbose=False, callbacks=[PredictionCallback()])
Starting Prediction....
Prediction: Starting batch 0
Prediction: Finished batch 0
Prediction: Starting batch 1
Prediction: Finished batch 1
Finished Prediction!
Out[106]:
array([[-0.70453364],
       [-0.66237   ],
       [-0.46678713],
       [-0.00217101],
       [ 0.81033176],
       [ 0.09889701],
       [ 0.9591879 ],
       [-0.41632694],
       [-0.6790302 ],
       [-0.44720244],
       [ 0.32216445],
       [ 1.0040126 ],
       [ 0.05457034],
       [ 0.77189565],
       [-0.62541866],
       [-0.6614995 ],
       [-0.4295209 ],
       [-0.5600192 ],
       [-0.32411316],
       [ 1.4033808 ],
       [-0.4459252 ],
       [-0.5172542 ],
       [-0.67107993],
       [ 0.6741319 ],
       [-0.42273706],
       [-0.49320617],
       [ 0.70178944],
       [ 1.4118713 ],
       [ 0.99328846],
       [-0.27199495],
       [-0.08221228],
       [-0.55037   ],
       [ 0.88309705],
       [-0.06700692],
       [ 0.8299056 ],
       [ 0.598157  ],
       [ 0.19678356],
       [-0.03356161],
       [-0.64796686],
       [-0.6102506 ],
       [-0.7447535 ],
       [ 0.73226035],
       [ 0.2279578 ],
       [-0.5152449 ],
       [-0.47076586]], dtype=float32)

Early stopping / patience

Re-train the models with early stopping

In [129]:
# Re-train the unregularised model

unregularised_model = get_model()
unregularised_model.compile(optimizer="adam", loss="mse")
unreg_history = unregularised_model.fit(train_data, train_targets, epochs=100,
                                        validation_split=0.15, batch_size=64,
                                        callbacks=[tf.keras.callbacks.EarlyStopping(
                                            monitor="val_loss", min_delta=0.01, patience=10, mode="min")])
Train on 337 samples, validate on 60 samples
Epoch 1/100
337/337 [==============================] - 1s 3ms/sample - loss: 0.9819 - val_loss: 0.9647
Epoch 2/100
337/337 [==============================] - 0s 491us/sample - loss: 0.8789 - val_loss: 0.7893
Epoch 3/100
337/337 [==============================] - 0s 573us/sample - loss: 0.6763 - val_loss: 0.5589
Epoch 4/100
337/337 [==============================] - 0s 317us/sample - loss: 0.5579 - val_loss: 0.5193
Epoch 5/100
337/337 [==============================] - 0s 579us/sample - loss: 0.5062 - val_loss: 0.4961
Epoch 6/100
337/337 [==============================] - 0s 313us/sample - loss: 0.4938 - val_loss: 0.5053
Epoch 7/100
337/337 [==============================] - 0s 324us/sample - loss: 0.4892 - val_loss: 0.5020
Epoch 8/100
337/337 [==============================] - 0s 574us/sample - loss: 0.4663 - val_loss: 0.4756
Epoch 9/100
337/337 [==============================] - 0s 317us/sample - loss: 0.4434 - val_loss: 0.4807
Epoch 10/100
337/337 [==============================] - 0s 579us/sample - loss: 0.4329 - val_loss: 0.4794
Epoch 11/100
337/337 [==============================] - 0s 314us/sample - loss: 0.4264 - val_loss: 0.4951
Epoch 12/100
337/337 [==============================] - 0s 571us/sample - loss: 0.4266 - val_loss: 0.4820
Epoch 13/100
337/337 [==============================] - 0s 313us/sample - loss: 0.4314 - val_loss: 0.4869
Epoch 14/100
337/337 [==============================] - 0s 574us/sample - loss: 0.4106 - val_loss: 0.4776
Epoch 15/100
337/337 [==============================] - 0s 318us/sample - loss: 0.4009 - val_loss: 0.4756
Epoch 16/100
337/337 [==============================] - 0s 324us/sample - loss: 0.3895 - val_loss: 0.4944
Epoch 17/100
337/337 [==============================] - 0s 573us/sample - loss: 0.3853 - val_loss: 0.5071
Epoch 18/100
337/337 [==============================] - 0s 319us/sample - loss: 0.3838 - val_loss: 0.5035
In [130]:
# Evaluate the model on the test set
unregularised_model.evaluate(test_data, test_targets, verbose=2)
45/1 - 0s - loss: 0.6237
Out[130]:
0.5756108389960395
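The EarlyStopping callback above stops training once val_loss has failed to improve by at least min_delta for patience consecutive epochs. The weights kept are those from the final epoch, not necessarily the best one; setting restore_best_weights=True rolls the model back to the epoch with the lowest monitored value. A minimal sketch of the same criterion with that option:

# Sketch: the same early stopping criterion, but restoring the best val_loss weights
early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", min_delta=0.01,
                                                  patience=10, mode="min",
                                                  restore_best_weights=True)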
In [133]:
# Re-train the regularised model

regularised_model = get_regularised_model(1e-8, 0.2)
regularised_model.compile(optimizer="adam", loss="mse")
reg_history = regularised_model.fit(train_data, train_targets, epochs=100,
                                    validation_split=0.15, batch_size=64,
                                    callbacks=[tf.keras.callbacks.EarlyStopping(
                                        monitor="val_loss", min_delta=0.01, patience=10, mode="min")])
Train on 337 samples, validate on 60 samples
Epoch 1/100
337/337 [==============================] - 2s 7ms/sample - loss: 0.9924 - val_loss: 1.0138
Epoch 2/100
337/337 [==============================] - 0s 339us/sample - loss: 0.9582 - val_loss: 0.9155
Epoch 3/100
337/337 [==============================] - 0s 581us/sample - loss: 0.8535 - val_loss: 0.6993
Epoch 4/100
337/337 [==============================] - 0s 585us/sample - loss: 0.6555 - val_loss: 0.5455
Epoch 5/100
337/337 [==============================] - 0s 583us/sample - loss: 0.5879 - val_loss: 0.5361
Epoch 6/100
337/337 [==============================] - 0s 320us/sample - loss: 0.5272 - val_loss: 0.5162
Epoch 7/100
337/337 [==============================] - 0s 582us/sample - loss: 0.5791 - val_loss: 0.5240
Epoch 8/100
337/337 [==============================] - 0s 584us/sample - loss: 0.5386 - val_loss: 0.5033
Epoch 9/100
337/337 [==============================] - 0s 576us/sample - loss: 0.5371 - val_loss: 0.5020
Epoch 10/100
337/337 [==============================] - 0s 326us/sample - loss: 0.5240 - val_loss: 0.4894
Epoch 11/100
337/337 [==============================] - 0s 589us/sample - loss: 0.5010 - val_loss: 0.5069
Epoch 12/100
337/337 [==============================] - 0s 584us/sample - loss: 0.4600 - val_loss: 0.4947
Epoch 13/100
337/337 [==============================] - 0s 582us/sample - loss: 0.4833 - val_loss: 0.4918
Epoch 14/100
337/337 [==============================] - 0s 577us/sample - loss: 0.4717 - val_loss: 0.4914
Epoch 15/100
337/337 [==============================] - 0s 333us/sample - loss: 0.4818 - val_loss: 0.4865
Epoch 16/100
337/337 [==============================] - 0s 576us/sample - loss: 0.4840 - val_loss: 0.4876
Epoch 17/100
337/337 [==============================] - 0s 587us/sample - loss: 0.4632 - val_loss: 0.4934
Epoch 18/100
337/337 [==============================] - 0s 575us/sample - loss: 0.4826 - val_loss: 0.4994
Epoch 19/100
337/337 [==============================] - 0s 332us/sample - loss: 0.4336 - val_loss: 0.5023
Epoch 20/100
337/337 [==============================] - 0s 577us/sample - loss: 0.4438 - val_loss: 0.5098
In [134]:
# Evaluate the model on the test set

regularised_model.evaluate(test_data, test_targets, verbose=2)
45/1 - 0s - loss: 0.6192
Out[134]:
0.549280321598053

Plot the learning curves

In [135]:
# Plot the training and validation loss

import matplotlib.pyplot as plt

fig = plt.figure(figsize=(12, 5))

fig.add_subplot(121)

plt.plot(unreg_history.history['loss'])
plt.plot(unreg_history.history['val_loss'])
plt.title('Unregularised model: loss vs. epochs')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'], loc='upper right')

fig.add_subplot(122)

plt.plot(reg_history.history['loss'])
plt.plot(reg_history.history['val_loss'])
plt.title('Regularised model: loss vs. epochs')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'], loc='upper right')

plt.show()
In [ ]: