diff options
author | Toni Uhlig <matzeton@googlemail.com> | 2022-10-10 15:40:25 +0200 |
---|---|---|
committer | Toni Uhlig <matzeton@googlemail.com> | 2022-10-10 16:44:12 +0200 |
commit | 20fc74f52742e5d512723d4f5fe314626e4a92f3 (patch) | |
tree | 70fa1fd99a1d4cf08e827f3f4030abbe30832840 /examples | |
parent | 2ede930eec0aceb292687351ed520784c060380c (diff) |
Improved py-machine-learning example.
Signed-off-by: Toni Uhlig <matzeton@googlemail.com>
Diffstat (limited to 'examples')
-rw-r--r-- | examples/README.md | 7 | ||||
-rw-r--r-- | examples/py-machine-learning/requirements.txt | 1 | ||||
-rwxr-xr-x | examples/py-machine-learning/sklearn-ml.py | 112 |
3 files changed, 97 insertions, 23 deletions
diff --git a/examples/README.md b/examples/README.md
index 4a5b9b339..b378f26ae 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -42,6 +42,13 @@ Prints prettyfied information about flow events.
 
 Use sklearn together with CSVs created with **c-analysed** to train and predict DPI detections.
 
+Try it with: `./examples/py-machine-learning/sklearn-ml.py --csv ./ndpi-analysed.csv --proto-class tls.youtube --proto-class tls.github --proto-class tls.spotify --proto-class tls.facebook --proto-class tls.instagram --proto-class tls.doh_dot --proto-class quic --proto-class icmp`
+
+This way you should get 9 different classification classes.
+You may notice that some classes e.g. TLS protocol classifications may have a higher false-negative rate.
+
+Unfortunately, I can not provide any datasets due to some privacy concerns.
+
 ## py-flow-dashboard
 
 A realtime web based graph using Plotly/Dash.
diff --git a/examples/py-machine-learning/requirements.txt b/examples/py-machine-learning/requirements.txt
index e29be016c..672559647 100644
--- a/examples/py-machine-learning/requirements.txt
+++ b/examples/py-machine-learning/requirements.txt
@@ -2,3 +2,4 @@ scikit-learn
 scipy
 matplotlib
 numpy
+pandas
diff --git a/examples/py-machine-learning/sklearn-ml.py b/examples/py-machine-learning/sklearn-ml.py
index 86a3bae44..d5d08b281 100755
--- a/examples/py-machine-learning/sklearn-ml.py
+++ b/examples/py-machine-learning/sklearn-ml.py
@@ -1,11 +1,15 @@
 #!/usr/bin/env python3
 
 import csv
+import matplotlib.pyplot
 import numpy
 import os
+import pandas
 import sklearn
 import sklearn.ensemble
+import sklearn.inspection
 import sys
+import time
 
 sys.path.append(os.path.dirname(sys.argv[0]) + '/../../dependencies')
 sys.path.append(os.path.dirname(sys.argv[0]) + '/../share/nDPId')
@@ -44,26 +48,63 @@ def getFeaturesFromArray(json, expected_len=0):
     return dirs
 
 def getRelevantFeaturesCSV(line):
-    return [
-        getFeatures(line) + \
-        getFeaturesFromArray(line['iat_data'], N_DIRS - 1) if ENABLE_FEATURE_IAT is True else [] + \
-        getFeaturesFromArray(line['pktlen_data'], N_DIRS) if ENABLE_FEATURE_PKTLEN is True else [] + \
-        getFeaturesFromArray(line['directions'], N_DIRS) if ENABLE_FEATURE_DIRS is True else [] + \
-        getFeaturesFromArray(line['bins_c_to_s'], N_BINS) if ENABLE_FEATURE_BINS is True else [] + \
-        getFeaturesFromArray(line['bins_s_to_c'], N_BINS) if ENABLE_FEATURE_BINS is True else [] + \
-        []
-    ]
+    ret = list()
+    ret.extend(getFeatures(line));
+    if ENABLE_FEATURE_IAT is True:
+        ret.extend(getFeaturesFromArray(line['iat_data'], N_DIRS - 1))
+    if ENABLE_FEATURE_PKTLEN is True:
+        ret.extend(getFeaturesFromArray(line['pktlen_data'], N_DIRS))
+    if ENABLE_FEATURE_DIRS is True:
+        ret.extend(getFeaturesFromArray(line['directions'], N_DIRS))
+    if ENABLE_FEATURE_BINS is True:
+        ret.extend(getFeaturesFromArray(line['bins_c_to_s'], N_BINS))
+        ret.extend(getFeaturesFromArray(line['bins_s_to_c'], N_BINS))
+    return [ret]
 
 def getRelevantFeaturesJSON(line):
-    return [
-        getFeatures(line) + \
-        getFeaturesFromArray(line['data_analysis']['iat']['data'], N_DIRS - 1) if ENABLE_FEATURE_IAT is True else [] + \
-        getFeaturesFromArray(line['data_analysis']['pktlen']['data'], N_DIRS) if ENABLE_FEATURE_PKTLEN is True else [] + \
-        getFeaturesFromArray(line['data_analysis']['directions'], N_DIRS) if ENABLE_FEATURE_DIRS is True else [] + \
-        getFeaturesFromArray(line['data_analysis']['bins']['c_to_s'], N_BINS) if ENABLE_FEATURE_BINS is True else [] + \
-        getFeaturesFromArray(line['data_analysis']['bins']['s_to_c'], N_BINS) if ENABLE_FEATURE_BINS is True else [] + \
-        []
-    ]
+    ret = list()
+    ret.extend(getFeatures(line))
+    if ENABLE_FEATURE_IAT is True:
+        ret.extend(getFeaturesFromArray(line['data_analysis']['iat']['data'], N_DIRS - 1))
+    if ENABLE_FEATURE_PKTLEN is True:
+        ret.extend(getFeaturesFromArray(line['data_analysis']['pktlen']['data'], N_DIRS))
+    if ENABLE_FEATURE_DIRS is True:
+        ret.extend(getFeaturesFromArray(line['data_analysis']['directions'], N_DIRS))
+    if ENABLE_FEATURE_BINS is True:
+        ret.extend(getFeaturesFromArray(line['data_analysis']['bins']['c_to_s'], N_BINS))
+        ret.extend(getFeaturesFromArray(line['data_analysis']['bins']['s_to_c'], N_BINS) )
+    return [ret]
+
+def getRelevantFeatureNames():
+    names = list()
+    names.extend(['flow_src_packets_processed', 'flow_dst_packets_processed',
+                  'flow_src_tot_l4_payload_len', 'flow_dst_tot_l4_payload_len'])
+    if ENABLE_FEATURE_IAT is True:
+        for x in range(N_DIRS - 1):
+            names.append('iat_{}'.format(x))
+    if ENABLE_FEATURE_PKTLEN is True:
+        for x in range(N_DIRS):
+            names.append('pktlen_{}'.format(x))
+    if ENABLE_FEATURE_DIRS is True:
+        for x in range(N_DIRS):
+            names.append('dirs_{}'.format(x))
+    if ENABLE_FEATURE_BINS is True:
+        for x in range(N_BINS):
+            names.append('bins_c_to_s_{}'.format(x))
+        for x in range(N_BINS):
+            names.append('bins_s_to_c_{}'.format(x))
+    return names
+
+def plotPermutatedImportance(model, X, y):
+    result = sklearn.inspection.permutation_importance(model, X, y, n_repeats=10, random_state=42, n_jobs=-1)
+    forest_importances = pandas.Series(result.importances_mean, index=getRelevantFeatureNames())
+
+    fig, ax = matplotlib.pyplot.subplots()
+    forest_importances.plot.bar(yerr=result.importances_std, ax=ax)
+    ax.set_title("Feature importances using permutation on full model")
+    ax.set_ylabel("Mean accuracy decrease")
+    fig.tight_layout()
+    matplotlib.pyplot.show()
 
 def onJsonLineRecvd(json_dict, instance, current_flow, global_user_data):
     if 'flow_event_name' not in json_dict:
@@ -81,20 +122,34 @@ def onJsonLineRecvd(json_dict, instance, current_flow, global_user_data):
     model, = global_user_data
 
     try:
-        print('DPI Engine detected: "{}", Prediction: "{}"'.format(
-              json_dict['ndpi']['proto'], model.predict(getRelevantFeaturesJSON(json_dict))))
+        X = getRelevantFeaturesJSON(json_dict)
+        y = model.predict(X)
+        s = model.score(X, y)
+        p = model.predict_log_proba(X)
+        print('DPI Engine detected: {:>24}, Prediction: {:>3}, Score: {}, Probabilities: {}'.format(
+              '"' + str(json_dict['ndpi']['proto']) + '"', '"' + str(y) + '"', s, p[0]))
     except Exception as err:
         print('Got exception `{}\'\nfor json: {}'.format(err, json_dict))
 
     return True
 
+def isProtoClass(proto_class, line):
+    s = line.lower()
+
+    for x in range(len(proto_class)):
+        if s.startswith(proto_class[x].lower()) is True:
+            return x + 1
+
+    return 0
 
 if __name__ == '__main__':
     argparser = nDPIsrvd.defaultArgumentParser()
     argparser.add_argument('--csv', action='store', required=True,
                            help='Input CSV file generated with nDPIsrvd-analysed.')
-    argparser.add_argument('--proto-class', action='store', required=True,
-                           help='nDPId protocol class of interest, used for training and prediction. Example: tls.youtube')
+    argparser.add_argument('--proto-class', action='append', required=True,
+                           help='nDPId protocol class of interest used for training and prediction. Can be specified multiple times. Example: tls.youtube')
+    argparser.add_argument('--generate-feature-importance', action='store_true',
+                           help='Generates the permutated feature importance with matplotlib.')
     argparser.add_argument('--enable-iat', action='store', default=True,
                            help='Use packet (I)nter (A)rrival (T)ime for learning and prediction.')
     argparser.add_argument('--enable-pktlen', action='store', default=False,
@@ -114,6 +169,9 @@ if __name__ == '__main__':
     sys.stderr.write('Recv buffer size: {}\n'.format(nDPIsrvd.NETWORK_BUFFER_MAX_SIZE))
     sys.stderr.write('Connecting to {} ..\n'.format(address[0]+':'+str(address[1]) if type(address) is tuple else address))
 
+    numpy.set_printoptions(formatter={'float_kind': "{:.1f}".format}, sign=' ')
+    numpy.seterr(divide = 'ignore')
+
     sys.stderr.write('Learning via CSV..\n')
     with open(args.csv, newline='\n') as csvfile:
         reader = csv.DictReader(csvfile, delimiter=',', quotechar='"')
@@ -128,13 +186,21 @@ if __name__ == '__main__':
         for line in reader:
             try:
                 X += getRelevantFeaturesCSV(line)
-                y += [1 if line['proto'].lower().startswith(args.proto_class) is True else 0]
+                y += [isProtoClass(args.proto_class, line['proto'])]
             except RuntimeError as err:
                 print('Error: `{}\'\non line: {}'.format(err, line))
 
     model = sklearn.ensemble.RandomForestClassifier()
     model.fit(X, y)
 
+    if args.generate_feature_importance is True:
+        sys.stderr.write('Generating feature importance .. this may take some time')
+        plotPermutatedImportance(model, X, y)
+
+    print('Map[*] -> [0]')
+    for x in range(len(args.proto_class)):
+        print('Map["{}"] -> [{}]'.format(args.proto_class[x], x + 1))
+
     sys.stderr.write('Predicting realtime traffic..\n')
     nsock = nDPIsrvdSocket()
     nsock.connect(address)