aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xexamples/py-machine-learning/keras-autoencoder.py30
1 files changed, 17 insertions, 13 deletions
diff --git a/examples/py-machine-learning/keras-autoencoder.py b/examples/py-machine-learning/keras-autoencoder.py
index fc3115012..a99cc1b2d 100755
--- a/examples/py-machine-learning/keras-autoencoder.py
+++ b/examples/py-machine-learning/keras-autoencoder.py
@@ -29,12 +29,12 @@ import nDPIsrvd
from nDPIsrvd import nDPIsrvdSocket, TermColor
INPUT_SIZE = nDPIsrvd.nDPId_PACKETS_PLEN_MAX
-LATENT_SIZE = 16
-TRAINING_SIZE = 8192
-EPOCH_COUNT = 50
-BATCH_SIZE = 512
-LEARNING_RATE = 0.0000001
-ES_PATIENCE = 10
+LATENT_SIZE = 8
+TRAINING_SIZE = 512
+EPOCH_COUNT = 3
+BATCH_SIZE = 16
+LEARNING_RATE = 0.000001
+ES_PATIENCE = 3
PLOT = False
PLOT_HISTORY = 100
TENSORBOARD = False
@@ -164,8 +164,12 @@ def keras_worker(load_model, save_model, shared_shutdown_event, shared_training_
sys.stderr.flush()
encoder, _, autoencoder = get_autoencoder()
autoencoder.summary()
- tensorboard = TensorBoard(log_dir=TB_LOGPATH, histogram_freq=1)
+ additional_callbacks = []
+ if TENSORBOARD is True:
+ tensorboard = TensorBoard(log_dir=TB_LOGPATH, histogram_freq=1)
+ additional_callbacks += [tensorboard]
early_stopping = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=ES_PATIENCE, restore_best_weights=True, start_from_epoch=0, verbose=0, mode='auto')
+ additional_callbacks += [early_stopping]
shared_training_event.clear()
try:
@@ -188,7 +192,7 @@ def keras_worker(load_model, save_model, shared_shutdown_event, shared_training_
tmp, tmp, epochs=EPOCH_COUNT, batch_size=BATCH_SIZE,
validation_split=0.2,
shuffle=True,
- callbacks=[tensorboard, early_stopping]
+ callbacks=[additional_callbacks]
)
reconstructed_data = autoencoder.predict(tmp)
mse = np.mean(np.square(tmp - reconstructed_data))
@@ -295,15 +299,15 @@ if __name__ == '__main__':
help='Load a pre-trained model file.')
argparser.add_argument('--save-model', action='store',
help='Save the trained model to a file.')
- argparser.add_argument('--training-size', action='store', default=TRAINING_SIZE,
+ argparser.add_argument('--training-size', action='store', type=int, default=TRAINING_SIZE,
help='Set the amount of captured packets required to start the training phase.')
- argparser.add_argument('--batch-size', action='store', default=BATCH_SIZE,
+ argparser.add_argument('--batch-size', action='store', type=int, default=BATCH_SIZE,
help='Set the batch size used for the training phase.')
- argparser.add_argument('--learning-rate', action='store', default=LEARNING_RATE,
+ argparser.add_argument('--learning-rate', action='store', type=float, default=LEARNING_RATE,
help='Set the (initial) learning rate for the optimizer.')
argparser.add_argument('--plot', action='store_true', default=PLOT,
help='Show some model metrics using pyplot.')
- argparser.add_argument('--plot-history', action='store', default=PLOT_HISTORY,
+ argparser.add_argument('--plot-history', action='store', type=int, default=PLOT_HISTORY,
help='Set the history size of Line plots. Requires --plot')
argparser.add_argument('--tensorboard', action='store_true', default=TENSORBOARD,
help='Enable TensorBoard compatible logging callback.')
@@ -313,7 +317,7 @@ if __name__ == '__main__':
help='Use SGD optimizer instead of Adam.')
argparser.add_argument('--use-kldiv', action='store_true', default=VAE_USE_KLDIV,
help='Use Kullback-Leibler loss function instead of Mean-Squared-Error.')
- argparser.add_argument('--patience', action='store', default=ES_PATIENCE,
+ argparser.add_argument('--patience', action='store', type=int, default=ES_PATIENCE,
help='Epoch value for EarlyStopping. This value forces VAE fitting to if no improvment achieved.')
args = argparser.parse_args()
address = nDPIsrvd.validateAddress(args)