diff options
| | | |
|---|---|---|
| author | Toni Uhlig <matzeton@googlemail.com> | 2023-08-23 22:56:59 +0200 |
| committer | Toni Uhlig <matzeton@googlemail.com> | 2023-08-23 22:56:59 +0200 |
| commit | 5234f4621b5a7c5764a6f53921f9af0ba9f4c762 (patch) | |
| tree | 836c0c089df69f022175e2a60defc1166d47c76c | |
| parent | 86ac09a8db9d6749adf6e29adc010d6eebc1d88c (diff) | |
keras-autoencoder.py: TensorBoard, SGD optimizer, KLDivergence loss function, EarlyStopping
Signed-off-by: Toni Uhlig <matzeton@googlemail.com>
-rwxr-xr-x | examples/py-machine-learning/keras-autoencoder.py | 96 |
1 files changed, 64 insertions, 32 deletions
diff --git a/examples/py-machine-learning/keras-autoencoder.py b/examples/py-machine-learning/keras-autoencoder.py index 4f9307a6d..6e80b38a5 100755 --- a/examples/py-machine-learning/keras-autoencoder.py +++ b/examples/py-machine-learning/keras-autoencoder.py @@ -4,7 +4,7 @@ import base64 import binascii import datetime as dt import math -import matplotlib.animation as animation +import matplotlib.animation as ani import matplotlib.pyplot as plt import multiprocessing as mp import numpy as np @@ -17,7 +17,9 @@ from tensorflow.keras import models, layers, preprocessing from tensorflow.keras.layers import Embedding, Masking, Input, Dense from tensorflow.keras.models import Model from tensorflow.keras.utils import plot_model -from tensorflow.keras.optimizers import Adam +from tensorflow.keras.losses import MeanSquaredError, KLDivergence +from tensorflow.keras.optimizers import Adam, SGD +from tensorflow.keras.callbacks import TensorBoard, EarlyStopping sys.path.append(os.path.dirname(sys.argv[0]) + '/../../dependencies') sys.path.append(os.path.dirname(sys.argv[0]) + '/../share/nDPId') @@ -28,11 +30,17 @@ from nDPIsrvd import nDPIsrvdSocket, TermColor INPUT_SIZE = nDPIsrvd.nDPId_PACKETS_PLEN_MAX LATENT_SIZE = 16 -TRAINING_SIZE = 1024 +TRAINING_SIZE = 8192 EPOCH_COUNT = 50 -BATCH_SIZE = 256 +BATCH_SIZE = 512 LEARNING_RATE = 0.0000001 +ES_PATIENCE = 10 +PLOT = False PLOT_HISTORY = 100 +TENSORBOARD = False +TB_LOGPATH = 'logs/' + dt.datetime.now().strftime("%Y%m%d-%H%M%S") +VAE_USE_KLDIV = False +VAE_USE_SGD = False def generate_autoencoder(): # TODO: The current model does handle *each* packet separatly. 
@@ -62,11 +70,13 @@ def generate_autoencoder(): encoder = Model(input_e, latent, name='encoder') decoder = Model(input_l, output_i, name='decoder') - return Adam(learning_rate=LEARNING_RATE), Model(input_e, decoder(encoder(input_e)), name='VAE') + return KLDivergence() if VAE_USE_KLDIV else MeanSquaredError(), \ + SGD() if VAE_USE_SGD else Adam(learning_rate=LEARNING_RATE), \ + Model(input_e, decoder(encoder(input_e)), name='VAE') def compile_autoencoder(): - optimizer, autoencoder = generate_autoencoder() - autoencoder.compile(loss='mean_squared_error', optimizer=optimizer, metrics=[]) + loss, optimizer, autoencoder = generate_autoencoder() + autoencoder.compile(loss=loss, optimizer=optimizer, metrics=[]) return autoencoder def get_autoencoder(load_from_file=None): @@ -87,7 +97,7 @@ def onJsonLineRecvd(json_dict, instance, current_flow, global_user_data): json_dict['packet_event_name'] != 'packet-flow': return True - shutdown_event, training_event, padded_pkts = global_user_data + shutdown_event, training_event, padded_pkts, print_dots = global_user_data if shutdown_event.is_set(): return False @@ -120,8 +130,11 @@ def onJsonLineRecvd(json_dict, instance, current_flow, global_user_data): #print(list(buf[0])) if not training_event.is_set(): - sys.stdout.write('.') + sys.stdout.write('.' 
* print_dots) sys.stdout.flush() + print_dots = 1 + else: + print_dots += 1 return True @@ -130,8 +143,8 @@ def nDPIsrvd_worker(address, shared_shutdown_event, shared_training_event, share try: nsock.connect(address) - padded_pkts = list() - nsock.loop(onJsonLineRecvd, None, (shared_shutdown_event, shared_training_event, shared_packet_list)) + print_dots = 1 + nsock.loop(onJsonLineRecvd, None, (shared_shutdown_event, shared_training_event, shared_packet_list, print_dots)) except nDPIsrvd.SocketConnectionBroken as err: sys.stderr.write('\nnDPIsrvd-Worker Socket Error: {}\n'.format(err)) except KeyboardInterrupt: @@ -151,6 +164,8 @@ def keras_worker(load_model, save_model, shared_shutdown_event, shared_training_ sys.stderr.flush() encoder, decoder, autoencoder = get_autoencoder() autoencoder.summary() + tensorboard = TensorBoard(log_dir=TB_LOGPATH, histogram_freq=1) + early_stopping = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=ES_PATIENCE, restore_best_weights=True, start_from_epoch=0, verbose=0, mode='auto') shared_training_event.clear() try: @@ -172,14 +187,15 @@ def keras_worker(load_model, save_model, shared_shutdown_event, shared_training_ history = autoencoder.fit( tmp, tmp, epochs=EPOCH_COUNT, batch_size=BATCH_SIZE, validation_split=0.2, - shuffle=True + shuffle=True, + callbacks=[tensorboard, early_stopping] ) reconstructed_data = autoencoder.predict(tmp) mse = np.mean(np.square(tmp - reconstructed_data)) reconstruction_accuracy = (1.0 / mse) encoded_data = encoder.predict(tmp) latent_activations = encoder.predict(tmp) - shared_plot_queue.put((reconstruction_accuracy, history.history['loss'], encoded_data[:, 0], encoded_data[:, 1], latent_activations)) + shared_plot_queue.put((reconstruction_accuracy, history.history['val_loss'], encoded_data[:, 0], encoded_data[:, 1], latent_activations)) packets.clear() shared_training_event.clear() except KeyboardInterrupt: @@ -234,7 +250,7 @@ def plot_animate(i, shared_plot_queue, ax, xs, ys): 
ax1.set_xlabel('Epoch Count') ax1.set_ylabel('Accuracy') ax2.set_xlabel('Epoch Count') - ax2.set_ylabel('Loss') + ax2.set_ylabel('Validation Loss') ax3.set_title('Latent Space') ax4.set_title('Latent Space Heatmap') ax4.set_xlabel('Latent Dimensions') @@ -247,7 +263,7 @@ def plot_worker(shared_shutdown_event, shared_plot_queue): ax1.set_xlabel('Epoch Count') ax1.set_ylabel('Accuracy') ax2.set_xlabel('Epoch Count') - ax2.set_ylabel('Loss') + ax2.set_ylabel('Validation Loss') ax3.set_title('Latent Space') ax4.set_title('Latent Space Heatmap') ax4.set_xlabel('Latent Dimensions') @@ -258,8 +274,9 @@ def plot_worker(shared_shutdown_event, shared_plot_queue): ys3 = [] ys4 = [] x = 0 - ani = animation.FuncAnimation(fig, plot_animate, fargs=(shared_plot_queue, (ax1, ax2, ax3, ax4), xs, (ys1, ys2, ys3, ys4)), interval=1000, cache_frame_data=False) + a = ani.FuncAnimation(fig, plot_animate, fargs=(shared_plot_queue, (ax1, ax2, ax3, ax4), xs, (ys1, ys2, ys3, ys4)), interval=1000, cache_frame_data=False) plt.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05) + plt.margins(x=0, y=0) plt.show() except Exception as err: sys.stderr.write('\nPlot-Worker Exception: {}\n'.format(str(err))) @@ -279,29 +296,43 @@ if __name__ == '__main__': help='Load a pre-trained model file.') argparser.add_argument('--save-model', action='store', help='Save the trained model to a file.') - argparser.add_argument('--training-size', action='store', type=int, + argparser.add_argument('--training-size', action='store', default=TRAINING_SIZE, help='Set the amount of captured packets required to start the training phase.') - argparser.add_argument('--batch-size', action='store', type=int, + argparser.add_argument('--batch-size', action='store', default=BATCH_SIZE, help='Set the batch size used for the training phase.') - argparser.add_argument('--learning-rate', action='store', type=float, - help='Set the (initial!) 
learning rate for the Adam optimizer.') - argparser.add_argument('--plot', action='store_true', default=False, + argparser.add_argument('--learning-rate', action='store', default=LEARNING_RATE, + help='Set the (initial) learning rate for the optimizer.') + argparser.add_argument('--plot', action='store_true', default=PLOT, help='Show some model metrics using pyplot.') - argparser.add_argument('--plot-history', action='store', type=int, + argparser.add_argument('--plot-history', action='store', default=PLOT_HISTORY, help='Set the history size of Line plots. Requires --plot') + argparser.add_argument('--tensorboard', action='store_true', default=TENSORBOARD, + help='Enable TensorBoard compatible logging callback.') + argparser.add_argument('--tensorboard-logpath', action='store', default=TB_LOGPATH, + help='TensorBoard logging path.') + argparser.add_argument('--use-sgd', action='store_true', default=VAE_USE_SGD, + help='Use SGD optimizer instead of Adam.') + argparser.add_argument('--use-kldiv', action='store_true', default=VAE_USE_KLDIV, + help='Use Kullback-Leibler loss function instead of Mean-Squared-Error.') + argparser.add_argument('--patience', action='store', default=ES_PATIENCE, + help='Epoch value for EarlyStopping. 
This value forces VAE fitting to if no improvment achieved.') args = argparser.parse_args() address = nDPIsrvd.validateAddress(args) - LEARNING_RATE = args.learning_rate if args.learning_rate is not None else LEARNING_RATE - TRAINING_SIZE = args.training_size if args.training_size is not None else TRAINING_SIZE - BATCH_SIZE = args.batch_size if args.batch_size is not None else BATCH_SIZE - if args.plot is False and args.plot_history is not None: - raise RuntimeError('--plot-history requires --plot') - PLOT_HISTORY = args.plot_history if args.plot_history is not None else PLOT_HISTORY + LEARNING_RATE = args.learning_rate + TRAINING_SIZE = args.training_size + BATCH_SIZE = args.batch_size + PLOT = args.plot + PLOT_HISTORY = args.plot_history + TENSORBOARD = args.tensorboard + TB_LOGPATH = args.tensorboard_logpath if args.tensorboard_logpath is not None else TB_LOGPATH + VAE_USE_SGD = args.use_sgd + VAE_USE_KLDIV = args.use_kldiv + ES_PATIENCE = args.patience sys.stderr.write('Recv buffer size: {}\n'.format(nDPIsrvd.NETWORK_BUFFER_MAX_SIZE)) sys.stderr.write('Connecting to {} ..\n'.format(address[0]+':'+str(address[1]) if type(address) is tuple else address)) - sys.stderr.write('PLOT={}, PLOT_HISTORY={}, LEARNING_RATE={}, TRAINING_SIZE={}, BATCH_SIZE={}\n\n'.format(args.plot, PLOT_HISTORY, LEARNING_RATE, TRAINING_SIZE, BATCH_SIZE)) + sys.stderr.write('PLOT={}, PLOT_HISTORY={}, LEARNING_RATE={}, TRAINING_SIZE={}, BATCH_SIZE={}\n\n'.format(PLOT, PLOT_HISTORY, LEARNING_RATE, TRAINING_SIZE, BATCH_SIZE)) mgr = mp.Manager() @@ -332,7 +363,7 @@ if __name__ == '__main__': )) keras_job.start() - if args.plot is True: + if PLOT is True: plot_job = mp.Process(target=plot_worker, args=(shared_shutdown_event, shared_plot_queue)) plot_job.start() @@ -341,9 +372,10 @@ if __name__ == '__main__': except KeyboardInterrupt: print('\nShutting down worker processess..') - if args.plot is True: + if PLOT is True: plot_job.terminate() plot_job.join() nDPIsrvd_job.terminate() 
nDPIsrvd_job.join() - keras_job.join() + keras_job.join(timeout=3) + keras_job.terminate() |