diff options
-rwxr-xr-x | examples/py-machine-learning/keras-autoencoder.py | 30 |
1 files changed, 17 insertions, 13 deletions
diff --git a/examples/py-machine-learning/keras-autoencoder.py b/examples/py-machine-learning/keras-autoencoder.py index fc3115012..a99cc1b2d 100755 --- a/examples/py-machine-learning/keras-autoencoder.py +++ b/examples/py-machine-learning/keras-autoencoder.py @@ -29,12 +29,12 @@ import nDPIsrvd from nDPIsrvd import nDPIsrvdSocket, TermColor INPUT_SIZE = nDPIsrvd.nDPId_PACKETS_PLEN_MAX -LATENT_SIZE = 16 -TRAINING_SIZE = 8192 -EPOCH_COUNT = 50 -BATCH_SIZE = 512 -LEARNING_RATE = 0.0000001 -ES_PATIENCE = 10 +LATENT_SIZE = 8 +TRAINING_SIZE = 512 +EPOCH_COUNT = 3 +BATCH_SIZE = 16 +LEARNING_RATE = 0.000001 +ES_PATIENCE = 3 PLOT = False PLOT_HISTORY = 100 TENSORBOARD = False @@ -164,8 +164,12 @@ def keras_worker(load_model, save_model, shared_shutdown_event, shared_training_ sys.stderr.flush() encoder, _, autoencoder = get_autoencoder() autoencoder.summary() - tensorboard = TensorBoard(log_dir=TB_LOGPATH, histogram_freq=1) + additional_callbacks = [] + if TENSORBOARD is True: + tensorboard = TensorBoard(log_dir=TB_LOGPATH, histogram_freq=1) + additional_callbacks += [tensorboard] early_stopping = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=ES_PATIENCE, restore_best_weights=True, start_from_epoch=0, verbose=0, mode='auto') + additional_callbacks += [early_stopping] shared_training_event.clear() try: @@ -188,7 +192,7 @@ def keras_worker(load_model, save_model, shared_shutdown_event, shared_training_ tmp, tmp, epochs=EPOCH_COUNT, batch_size=BATCH_SIZE, validation_split=0.2, shuffle=True, - callbacks=[tensorboard, early_stopping] + callbacks=[additional_callbacks] ) reconstructed_data = autoencoder.predict(tmp) mse = np.mean(np.square(tmp - reconstructed_data)) @@ -295,15 +299,15 @@ if __name__ == '__main__': help='Load a pre-trained model file.') argparser.add_argument('--save-model', action='store', help='Save the trained model to a file.') - argparser.add_argument('--training-size', action='store', default=TRAINING_SIZE, + argparser.add_argument('--training-size', action='store', type=int, default=TRAINING_SIZE, help='Set the amount of captured packets required to start the training phase.') - argparser.add_argument('--batch-size', action='store', default=BATCH_SIZE, + argparser.add_argument('--batch-size', action='store', type=int, default=BATCH_SIZE, help='Set the batch size used for the training phase.') - argparser.add_argument('--learning-rate', action='store', default=LEARNING_RATE, + argparser.add_argument('--learning-rate', action='store', type=float, default=LEARNING_RATE, help='Set the (initial) learning rate for the optimizer.') argparser.add_argument('--plot', action='store_true', default=PLOT, help='Show some model metrics using pyplot.') - argparser.add_argument('--plot-history', action='store', default=PLOT_HISTORY, + argparser.add_argument('--plot-history', action='store', type=int, default=PLOT_HISTORY, help='Set the history size of Line plots. Requires --plot') argparser.add_argument('--tensorboard', action='store_true', default=TENSORBOARD, help='Enable TensorBoard compatible logging callback.') @@ -313,7 +317,7 @@ if __name__ == '__main__': help='Use SGD optimizer instead of Adam.') argparser.add_argument('--use-kldiv', action='store_true', default=VAE_USE_KLDIV, help='Use Kullback-Leibler loss function instead of Mean-Squared-Error.') - argparser.add_argument('--patience', action='store', default=ES_PATIENCE, + argparser.add_argument('--patience', action='store', type=int, default=ES_PATIENCE, help='Epoch value for EarlyStopping. This value forces VAE fitting to if no improvment achieved.') args = argparser.parse_args() address = nDPIsrvd.validateAddress(args) |