author     Toni Uhlig <matzeton@googlemail.com>       2022-10-07 17:55:17 +0200
committer  lns <matzeton@googlemail.com>              2022-10-09 18:31:45 +0200
commit     4654faf38128f4e793d654c78eee3c5b8d226bbf (patch)
tree       dac93c001fbf5dafa9b28908d35ec93f9567c8af /examples/py-machine-learning/sklearn-ml.py
parent     b7a17d62c73a0be53ee3ce2940e623ebe4a1252c (diff)
Improved py-machine-learning example.
* c-analysed: fixed quoting bug
* nDPId: fixed invalid iat storing/serialisation
* nDPId: free data analysis after event was sent
Signed-off-by: Toni Uhlig <matzeton@googlemail.com>
Signed-off-by: lns <matzeton@googlemail.com>
Diffstat (limited to 'examples/py-machine-learning/sklearn-ml.py')
-rwxr-xr-x  examples/py-machine-learning/sklearn-ml.py  142
1 files changed, 97 insertions, 45 deletions
diff --git a/examples/py-machine-learning/sklearn-ml.py b/examples/py-machine-learning/sklearn-ml.py
index 301f4e907..86a3bae44 100755
--- a/examples/py-machine-learning/sklearn-ml.py
+++ b/examples/py-machine-learning/sklearn-ml.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
-# pip3 install -U scikit-learn scipy matplotlib
-
+import csv
+import numpy
 import os
 import sklearn
 import sklearn.ensemble
@@ -14,26 +14,56 @@ sys.path.append(sys.base_prefix + '/share/nDPId')
 import nDPIsrvd
 from nDPIsrvd import nDPIsrvdSocket, TermColor
 
-class RFC(sklearn.ensemble.RandomForestClassifier):
-    def __init__(self, max_samples):
-        self.max_samples = max_samples
-        self.samples_x = []
-        self.samples_y = []
-        super().__init__(verbose=1, n_estimators=1000, max_samples=max_samples)
-
-    def addSample(self, x, y):
-        self.samples_x += x
-        self.samples_y += y
-
-    def fit(self):
-        if len(self.samples_x) != self.max_samples or \
-           len(self.samples_y) != self.max_samples:
-            return False
-
-        super().fit(self.samples_x, self.samples_y)
-        self.samples_x = []
-        self.samples_y = []
-        return True
+
+N_DIRS = 0
+N_BINS = 0
+
+ENABLE_FEATURE_IAT = True
+ENABLE_FEATURE_PKTLEN = True
+ENABLE_FEATURE_DIRS = True
+ENABLE_FEATURE_BINS = True
+
+def getFeatures(json):
+    return [json['flow_src_packets_processed'],
+            json['flow_dst_packets_processed'],
+            json['flow_src_tot_l4_payload_len'],
+            json['flow_dst_tot_l4_payload_len']]
+
+def getFeaturesFromArray(json, expected_len=0):
+    if type(json) is str:
+        dirs = numpy.fromstring(json, sep=',', dtype=int)
+        dirs = numpy.asarray(dirs, dtype=int).tolist()
+    elif type(json) is list:
+        dirs = json
+    else:
+        raise TypeError('Invalid type: {}.'.format(type(json)))
+
+    if expected_len > 0 and len(dirs) != expected_len:
+        raise RuntimeError('Invalid array length; Expected {}, Got {}.'.format(expected_len, len(dirs)))
+
+    return dirs
+
+def getRelevantFeaturesCSV(line):
+    return [
+        getFeatures(line) + \
+        getFeaturesFromArray(line['iat_data'], N_DIRS - 1) if ENABLE_FEATURE_IAT is True else [] + \
+        getFeaturesFromArray(line['pktlen_data'], N_DIRS) if ENABLE_FEATURE_PKTLEN is True else [] + \
+        getFeaturesFromArray(line['directions'], N_DIRS) if ENABLE_FEATURE_DIRS is True else [] + \
+        getFeaturesFromArray(line['bins_c_to_s'], N_BINS) if ENABLE_FEATURE_BINS is True else [] + \
+        getFeaturesFromArray(line['bins_s_to_c'], N_BINS) if ENABLE_FEATURE_BINS is True else [] + \
+        []
+    ]
+
+def getRelevantFeaturesJSON(line):
+    return [
+        getFeatures(line) + \
+        getFeaturesFromArray(line['data_analysis']['iat']['data'], N_DIRS - 1) if ENABLE_FEATURE_IAT is True else [] + \
+        getFeaturesFromArray(line['data_analysis']['pktlen']['data'], N_DIRS) if ENABLE_FEATURE_PKTLEN is True else [] + \
+        getFeaturesFromArray(line['data_analysis']['directions'], N_DIRS) if ENABLE_FEATURE_DIRS is True else [] + \
+        getFeaturesFromArray(line['data_analysis']['bins']['c_to_s'], N_BINS) if ENABLE_FEATURE_BINS is True else [] + \
+        getFeaturesFromArray(line['data_analysis']['bins']['s_to_c'], N_BINS) if ENABLE_FEATURE_BINS is True else [] + \
+        []
+    ]
 
 def onJsonLineRecvd(json_dict, instance, current_flow, global_user_data):
     if 'flow_event_name' not in json_dict:
@@ -48,42 +78,64 @@ def onJsonLineRecvd(json_dict, instance, current_flow, global_user_data):
 
     #print(json_dict)
 
-    features = [[]]
-    features[0] += json_dict['data_analysis']['bins']['c_to_s']
-    features[0] += json_dict['data_analysis']['bins']['s_to_c']
-    #print(features)
+    model, = global_user_data
 
-    out = ''
-    rfc = global_user_data
     try:
-        out += '[Predict: {}]'.format(rfc.predict(features)[0])
-    except sklearn.exceptions.NotFittedError:
-        pass
-
-    # TLS.DoH_DoT
-    if json_dict['ndpi']['proto'].startswith('TLS.') is not True and \
-       json_dict['ndpi']['proto'] != 'TLS':
-        rfc.addSample(features, [0])
-    else:
-        rfc.addSample(features, [1])
-
-    if rfc.fit() is True:
-        out += '*** FIT *** '
-        out += '[{}]'.format(json_dict['ndpi']['proto'])
-    print(out)
+        print('DPI Engine detected: "{}", Prediction: "{}"'.format(
+              json_dict['ndpi']['proto'], model.predict(getRelevantFeaturesJSON(json_dict))))
+    except Exception as err:
+        print('Got exception `{}\'\nfor json: {}'.format(err, json_dict))
 
     return True
 
+
 if __name__ == '__main__':
     argparser = nDPIsrvd.defaultArgumentParser()
+    argparser.add_argument('--csv', action='store', required=True,
+                           help='Input CSV file generated with nDPIsrvd-analysed.')
+    argparser.add_argument('--proto-class', action='store', required=True,
+                           help='nDPId protocol class of interest, used for training and prediction. Example: tls.youtube')
+    argparser.add_argument('--enable-iat', action='store', default=True,
+                           help='Use packet (I)nter (A)rrival (T)ime for learning and prediction.')
+    argparser.add_argument('--enable-pktlen', action='store', default=False,
+                           help='Use layer 4 packet lengths for learning and prediction.')
+    argparser.add_argument('--enable-dirs', action='store', default=True,
+                           help='Use packet directions for learning and prediction.')
+    argparser.add_argument('--enable-bins', action='store', default=True,
+                           help='Use packet length distribution for learning and prediction.')
    args = argparser.parse_args()
     address = nDPIsrvd.validateAddress(args)
 
+    ENABLE_FEATURE_IAT = args.enable_iat
+    ENABLE_FEATURE_PKTLEN = args.enable_pktlen
+    ENABLE_FEATURE_DIRS = args.enable_dirs
+    ENABLE_FEATURE_BINS = args.enable_bins
+
     sys.stderr.write('Recv buffer size: {}\n'.format(nDPIsrvd.NETWORK_BUFFER_MAX_SIZE))
     sys.stderr.write('Connecting to {} ..\n'.format(address[0]+':'+str(address[1]) if type(address) is tuple else address))
 
-    rfc = RFC(10)
+    sys.stderr.write('Learning via CSV..\n')
+    with open(args.csv, newline='\n') as csvfile:
+        reader = csv.DictReader(csvfile, delimiter=',', quotechar='"')
+        X = list()
+        y = list()
+
+        for line in reader:
+            N_DIRS = len(getFeaturesFromArray(line['directions']))
+            N_BINS = len(getFeaturesFromArray(line['bins_c_to_s']))
+            break
+
+        for line in reader:
+            try:
+                X += getRelevantFeaturesCSV(line)
+                y += [1 if line['proto'].lower().startswith(args.proto_class) is True else 0]
+            except RuntimeError as err:
+                print('Error: `{}\'\non line: {}'.format(err, line))
+
+        model = sklearn.ensemble.RandomForestClassifier()
+        model.fit(X, y)
 
+    sys.stderr.write('Predicting realtime traffic..\n')
     nsock = nDPIsrvdSocket()
     nsock.connect(address)
-    nsock.loop(onJsonLineRecvd, None, rfc)
+    nsock.loop(onJsonLineRecvd, None, (model,))
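Not part of the commit: a minimal offline sketch of how the CSV-trained classifier above could be sanity-checked with a held-out split before pointing it at live traffic. It assumes a CSV with the same columns the script reads ('proto', 'bins_c_to_s', 'bins_s_to_c', with comma-separated bin values) and, for brevity, uses only the two bin distributions as features; the CSV path and protocol class taken from the command line are placeholders.

# Hypothetical standalone validation sketch; column names follow the diff above.
import csv
import sys

import numpy
import sklearn.ensemble
import sklearn.metrics
import sklearn.model_selection

csv_path    = sys.argv[1]          # e.g. flows.csv (placeholder)
proto_class = sys.argv[2].lower()  # e.g. 'tls.youtube'

X, y = [], []
with open(csv_path, newline='\n') as csvfile:
    for row in csv.DictReader(csvfile, delimiter=',', quotechar='"'):
        # Concatenate both packet length distributions into one feature vector.
        bins = numpy.fromstring(row['bins_c_to_s'] + ',' + row['bins_s_to_c'],
                                sep=',', dtype=int).tolist()
        X.append(bins)
        y.append(1 if row['proto'].lower().startswith(proto_class) else 0)

# Hold out a quarter of the flows to get a rough accuracy estimate.
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
    X, y, test_size=0.25, random_state=42)

model = sklearn.ensemble.RandomForestClassifier()
model.fit(X_train, y_train)
print('Accuracy: {:.3f}'.format(
    sklearn.metrics.accuracy_score(y_test, model.predict(X_test))))

Holding out a test set this way gives a rough indication of whether the chosen feature set separates the target protocol class at all, which the live prediction loop in the script cannot show on its own.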