author     Toni Uhlig <matzeton@googlemail.com>  2022-10-14 08:49:25 +0200
committer  Toni Uhlig <matzeton@googlemail.com>  2022-10-15 11:59:39 +0200
commit     6292102f93086d2d61de874640f0b87c89c02b44 (patch)
tree       76966b4d496f0b2db93603171c7632bacadb8fc2 /examples
parent     80f84488340d681211ecd34c60a974c7500f9ee5 (diff)
py-machine-learning: load and save trained models

 * added link to a pre-trained model

Signed-off-by: Toni Uhlig <matzeton@googlemail.com>
Diffstat (limited to 'examples')
-rw-r--r--  examples/README.md                                      |  3
-rw-r--r--  examples/py-machine-learning/requirements.txt           |  2
-rwxr-xr-x  examples/py-machine-learning/sklearn-random-forest.py   | 93
3 files changed, 63 insertions(+), 35 deletions(-)
diff --git a/examples/README.md b/examples/README.md
index 9a1d368e0..71b7b8204 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -43,6 +43,9 @@ This way you should get 9 different classification classes.
 You may notice that some classes e.g. TLS protocol classifications may have a higher false-negative rate.
 Unfortunately, I can not provide any datasets due to some privacy concerns.
+But you can use a [pre-trained model](https://drive.google.com/file/d/1KEwbP-Gx7KJr54wNoa63I56VI4USCAPL/view?usp=sharing) with `--load-model`, which loads it via python-joblib.
+Please send me your CSV files to improve the model. I will treat those files as confidential.
+They'll only be used for the training process and purged afterwards.
 
 ## py-flow-dashboard
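
The README text above points at a pre-trained model that can be restored with python-joblib. As orientation only, consuming such a model could look roughly like the sketch below; the file name and the all-zero placeholder sample are assumptions, not part of this patch, and real inputs must use the same feature layout (directions, bins, optional IAT/packet-length features) that sklearn-random-forest.py was trained with.

```python
import joblib

# Hypothetical path; point this at the downloaded pre-trained model file.
model = joblib.load('ndpid-random-forest.joblib')

# Placeholder input: one sample with as many features as the model expects.
sample = [[0.0] * model.n_features_in_]
print(model.predict(sample))
```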
diff --git a/examples/py-machine-learning/requirements.txt b/examples/py-machine-learning/requirements.txt
index 672559647..33cfad38c 100644
--- a/examples/py-machine-learning/requirements.txt
+++ b/examples/py-machine-learning/requirements.txt
@@ -1,3 +1,5 @@
+joblib
+tensorflow
 scikit-learn
 scipy
 matplotlib
diff --git a/examples/py-machine-learning/sklearn-random-forest.py b/examples/py-machine-learning/sklearn-random-forest.py
index 2c4a2251b..ada238e94 100755
--- a/examples/py-machine-learning/sklearn-random-forest.py
+++ b/examples/py-machine-learning/sklearn-random-forest.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 
 import csv
+import joblib
 import matplotlib.pyplot
 import numpy
 import os
@@ -177,7 +178,11 @@ def isProtoClass(proto_class, line):
 if __name__ == '__main__':
     argparser = nDPIsrvd.defaultArgumentParser()
-    argparser.add_argument('--csv', action='store', required=True,
+    argparser.add_argument('--load-model', action='store',
+                           help='Load a pre-trained model file.')
+    argparser.add_argument('--save-model', action='store',
+                           help='Save the trained model to a file.')
+    argparser.add_argument('--csv', action='store',
                            help='Input CSV file generated with nDPIsrvd-analysed.')
     argparser.add_argument('--proto-class', action='append', required=True,
                            help='nDPId protocol class of interest used for training and prediction. ' +
@@ -211,6 +216,14 @@ if __name__ == '__main__':
     args = argparser.parse_args()
     address = nDPIsrvd.validateAddress(args)
 
+    if args.csv is None and args.load_model is None:
+        sys.stderr.write('{}: Either `--csv` or `--load-model` is required!\n'.format(sys.argv[0]))
+        sys.exit(1)
+
+    if args.csv is None and args.generate_feature_importance is True:
+        sys.stderr.write('{}: `--generate-feature-importance` requires `--csv`.\n'.format(sys.argv[0]))
+        sys.exit(1)
+
     ENABLE_FEATURE_IAT = args.enable_iat
     ENABLE_FEATURE_PKTLEN = args.enable_pktlen
     ENABLE_FEATURE_DIRS = args.disable_dirs is False
@@ -222,40 +235,50 @@ if __name__ == '__main__':
     for i in range(len(args.proto_class)):
         args.proto_class[i] = args.proto_class[i].lower()
 
-    sys.stderr.write('Learning via CSV..\n')
-    with open(args.csv, newline='\n') as csvfile:
-        reader = csv.DictReader(csvfile, delimiter=',', quotechar='"')
-        X = list()
-        y = list()
-
-        for line in reader:
-            N_DIRS = len(getFeaturesFromArray(line['directions']))
-            N_BINS = len(getFeaturesFromArray(line['bins_c_to_s']))
-            break
-
-        for line in reader:
-            try:
-                X += getRelevantFeaturesCSV(line)
-                y += [isProtoClass(args.proto_class, line['proto'])]
-            except RuntimeError as err:
-                print('Error: `{}\'\non line: {}'.format(err, line))
-
-        sys.stderr.write('CSV data set contains {} entries.\n'.format(len(X)))
-
-        model = sklearn.ensemble.RandomForestClassifier(bootstrap=False,
-                                                        class_weight = args.sklearn_class_weight,
-                                                        n_jobs = args.sklearn_jobs,
-                                                        n_estimators = args.sklearn_estimators,
-                                                        verbose = args.sklearn_verbosity,
-                                                        min_samples_leaf = args.sklearn_min_samples_leaf,
-                                                        max_features = args.sklearn_max_features
-                                                       )
-        sys.stderr.write('Training model..\n')
-        model.fit(X, y)
-
-        if args.generate_feature_importance is True:
-            sys.stderr.write('Generating feature importance .. this may take some time')
-            plotPermutatedImportance(model, X, y)
+    if args.load_model is not None:
+        sys.stderr.write('Loading model from {}\n'.format(args.load_model))
+        model = joblib.load(args.load_model)
+
+    if args.csv is not None:
+        sys.stderr.write('Learning via CSV..\n')
+        with open(args.csv, newline='\n') as csvfile:
+            reader = csv.DictReader(csvfile, delimiter=',', quotechar='"')
+            X = list()
+            y = list()
+
+            for line in reader:
+                N_DIRS = len(getFeaturesFromArray(line['directions']))
+                N_BINS = len(getFeaturesFromArray(line['bins_c_to_s']))
+                break
+
+            for line in reader:
+                try:
+                    X += getRelevantFeaturesCSV(line)
+                    y += [isProtoClass(args.proto_class, line['proto'])]
+                except RuntimeError as err:
+                    print('Error: `{}\'\non line: {}'.format(err, line))
+
+            sys.stderr.write('CSV data set contains {} entries.\n'.format(len(X)))
+
+            if args.load_model is None:
+                model = sklearn.ensemble.RandomForestClassifier(bootstrap=False,
+                                                                class_weight = args.sklearn_class_weight,
+                                                                n_jobs = args.sklearn_jobs,
+                                                                n_estimators = args.sklearn_estimators,
+                                                                verbose = args.sklearn_verbosity,
+                                                                min_samples_leaf = args.sklearn_min_samples_leaf,
+                                                                max_features = args.sklearn_max_features
+                                                               )
+            sys.stderr.write('Training model..\n')
+            model.fit(X, y)
+
+            if args.generate_feature_importance is True:
+                sys.stderr.write('Generating feature importance .. this may take some time\n')
+                plotPermutatedImportance(model, X, y)
+
+    if args.save_model is not None:
+        sys.stderr.write('Saving model to {}\n'.format(args.save_model))
+        joblib.dump(model, args.save_model)
 
     print('Map[*] -> [0]')
     for x in range(len(args.proto_class)):
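
The persistence added by `--save-model` and `--load-model` is plain joblib serialization of the fitted estimator. A minimal, self-contained sketch of that round trip, with a made-up file name and toy training data purely for illustration:

```python
import joblib
import sklearn.ensemble

# Train a small forest on toy data, then persist it (what `--save-model` does).
model = sklearn.ensemble.RandomForestClassifier(n_estimators=10)
model.fit([[0, 1], [1, 0]], [0, 1])
joblib.dump(model, 'model.joblib')

# A later run can restore the fitted model and skip training (what `--load-model` does).
restored = joblib.load('model.joblib')
print(restored.predict([[0, 1], [1, 0]]))
```

With this patch, a typical workflow would be to train once from a CSV and persist the result via `--csv` plus `--save-model`, then start subsequent runs with `--load-model` only; the file names above are placeholders, not defaults of the script.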