diff options
author | Toni Uhlig <matzeton@googlemail.com> | 2022-10-14 08:49:25 +0200 |
---|---|---|
committer | Toni Uhlig <matzeton@googlemail.com> | 2022-10-15 11:59:39 +0200 |
commit | 6292102f93086d2d61de874640f0b87c89c02b44 (patch) | |
tree | 76966b4d496f0b2db93603171c7632bacadb8fc2 | |
parent | 80f84488340d681211ecd34c60a974c7500f9ee5 (diff) |
py-machine-learning: load and save trained models
* added link to a pre-trained model
Signed-off-by: Toni Uhlig <matzeton@googlemail.com>
-rw-r--r-- | examples/README.md | 3 | ||||
-rw-r--r-- | examples/py-machine-learning/requirements.txt | 2 | ||||
-rwxr-xr-x | examples/py-machine-learning/sklearn-random-forest.py | 93 |
3 files changed, 63 insertions, 35 deletions
diff --git a/examples/README.md b/examples/README.md index 9a1d368e0..71b7b8204 100644 --- a/examples/README.md +++ b/examples/README.md @@ -43,6 +43,9 @@ This way you should get 9 different classification classes. You may notice that some classes e.g. TLS protocol classifications may have a higher false-negative rate. Unfortunately, I can not provide any datasets due to some privacy concerns. +But you can use a [pre-trained model](https://drive.google.com/file/d/1KEwbP-Gx7KJr54wNoa63I56VI4USCAPL/view?usp=sharing) with `--load-model` using python-joblib. +Please send me your CSV files to improve the model. I will treat those files confidential. +They'll only be used for the training process and purged afterwards. ## py-flow-dashboard diff --git a/examples/py-machine-learning/requirements.txt b/examples/py-machine-learning/requirements.txt index 672559647..33cfad38c 100644 --- a/examples/py-machine-learning/requirements.txt +++ b/examples/py-machine-learning/requirements.txt @@ -1,3 +1,5 @@ +joblib +tensorflow scikit-learn scipy matplotlib diff --git a/examples/py-machine-learning/sklearn-random-forest.py b/examples/py-machine-learning/sklearn-random-forest.py index 2c4a2251b..ada238e94 100755 --- a/examples/py-machine-learning/sklearn-random-forest.py +++ b/examples/py-machine-learning/sklearn-random-forest.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import csv +import joblib import matplotlib.pyplot import numpy import os @@ -177,7 +178,11 @@ def isProtoClass(proto_class, line): if __name__ == '__main__': argparser = nDPIsrvd.defaultArgumentParser() - argparser.add_argument('--csv', action='store', required=True, + argparser.add_argument('--load-model', action='store', + help='Load a pre-trained model file.') + argparser.add_argument('--save-model', action='store', + help='Save the trained model to a file.') + argparser.add_argument('--csv', action='store', help='Input CSV file generated with nDPIsrvd-analysed.') argparser.add_argument('--proto-class', action='append', required=True, help='nDPId protocol class of interest used for training and prediction. ' + @@ -211,6 +216,14 @@ if __name__ == '__main__': args = argparser.parse_args() address = nDPIsrvd.validateAddress(args) + if args.csv is None and args.load_model is None: + sys.stderr.write('{}: Either `--csv` or `--load-model` required!\n'.format(sys.argv[0])) + sys.exit(1) + + if args.csv is None and args.generate_feature_importance is True: + sys.stderr.write('{}: `--generate-feature-importance` requires `--csv`.\n'.format(sys.argv[0])) + sys.exit(1) + ENABLE_FEATURE_IAT = args.enable_iat ENABLE_FEATURE_PKTLEN = args.enable_pktlen ENABLE_FEATURE_DIRS = args.disable_dirs is False @@ -222,40 +235,50 @@ if __name__ == '__main__': for i in range(len(args.proto_class)): args.proto_class[i] = args.proto_class[i].lower() - sys.stderr.write('Learning via CSV..\n') - with open(args.csv, newline='\n') as csvfile: - reader = csv.DictReader(csvfile, delimiter=',', quotechar='"') - X = list() - y = list() - - for line in reader: - N_DIRS = len(getFeaturesFromArray(line['directions'])) - N_BINS = len(getFeaturesFromArray(line['bins_c_to_s'])) - break - - for line in reader: - try: - X += getRelevantFeaturesCSV(line) - y += [isProtoClass(args.proto_class, line['proto'])] - except RuntimeError as err: - print('Error: `{}\'\non line: {}'.format(err, line)) - - sys.stderr.write('CSV data set contains {} entries.\n'.format(len(X))) - - model = sklearn.ensemble.RandomForestClassifier(bootstrap=False, - class_weight = args.sklearn_class_weight, - n_jobs = args.sklearn_jobs, - n_estimators = args.sklearn_estimators, - verbose = args.sklearn_verbosity, - min_samples_leaf = args.sklearn_min_samples_leaf, - max_features = args.sklearn_max_features - ) - sys.stderr.write('Training model..\n') - model.fit(X, y) - - if args.generate_feature_importance is True: - sys.stderr.write('Generating feature importance .. this may take some time') - plotPermutatedImportance(model, X, y) + if args.load_model is not None: + sys.stderr.write('Loading model from {}\n'.format(args.load_model)) + model = joblib.load(args.load_model) + + if args.csv is not None: + sys.stderr.write('Learning via CSV..\n') + with open(args.csv, newline='\n') as csvfile: + reader = csv.DictReader(csvfile, delimiter=',', quotechar='"') + X = list() + y = list() + + for line in reader: + N_DIRS = len(getFeaturesFromArray(line['directions'])) + N_BINS = len(getFeaturesFromArray(line['bins_c_to_s'])) + break + + for line in reader: + try: + X += getRelevantFeaturesCSV(line) + y += [isProtoClass(args.proto_class, line['proto'])] + except RuntimeError as err: + print('Error: `{}\'\non line: {}'.format(err, line)) + + sys.stderr.write('CSV data set contains {} entries.\n'.format(len(X))) + + if args.load_model is None: + model = sklearn.ensemble.RandomForestClassifier(bootstrap=False, + class_weight = args.sklearn_class_weight, + n_jobs = args.sklearn_jobs, + n_estimators = args.sklearn_estimators, + verbose = args.sklearn_verbosity, + min_samples_leaf = args.sklearn_min_samples_leaf, + max_features = args.sklearn_max_features + ) + sys.stderr.write('Training model..\n') + model.fit(X, y) + + if args.generate_feature_importance is True: + sys.stderr.write('Generating feature importance .. this may take some time\n') + plotPermutatedImportance(model, X, y) + + if args.save_model is not None: + sys.stderr.write('Saving model to {}\n'.format(args.save_model)) + joblib.dump(model, args.save_model) print('Map[*] -> [0]') for x in range(len(args.proto_class)): |