diff options
author | Toni Uhlig <matzeton@googlemail.com> | 2022-09-22 19:07:08 +0200 |
---|---|---|
committer | Toni Uhlig <matzeton@googlemail.com> | 2022-09-22 19:07:08 +0200 |
commit | 9a28475bba88b711b7075b58473b7e5b5df1f393 (patch) | |
tree | 73cdf56320f14b5fe0fbfb2e930cf7ea025f9117 /examples/py-machine-learning/sklearn-ml.py | |
parent | 28971cd7647a79253000fb33e52b5d2129e5ba62 (diff) |
Improved flown analyse event:
* store packet directions
* merged direction based IATs
* merged direction based PKTLENs
Signed-off-by: Toni Uhlig <matzeton@googlemail.com>
Diffstat (limited to 'examples/py-machine-learning/sklearn-ml.py')
-rwxr-xr-x | examples/py-machine-learning/sklearn-ml.py | 89 |
1 files changed, 89 insertions, 0 deletions
diff --git a/examples/py-machine-learning/sklearn-ml.py b/examples/py-machine-learning/sklearn-ml.py new file mode 100755 index 000000000..301f4e907 --- /dev/null +++ b/examples/py-machine-learning/sklearn-ml.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 + +# pip3 install -U scikit-learn scipy matplotlib + +import os +import sklearn +import sklearn.ensemble +import sys + +sys.path.append(os.path.dirname(sys.argv[0]) + '/../../dependencies') +sys.path.append(os.path.dirname(sys.argv[0]) + '/../share/nDPId') +sys.path.append(os.path.dirname(sys.argv[0])) +sys.path.append(sys.base_prefix + '/share/nDPId') +import nDPIsrvd +from nDPIsrvd import nDPIsrvdSocket, TermColor + +class RFC(sklearn.ensemble.RandomForestClassifier): + def __init__(self, max_samples): + self.max_samples = max_samples + self.samples_x = [] + self.samples_y = [] + super().__init__(verbose=1, n_estimators=1000, max_samples=max_samples) + + def addSample(self, x, y): + self.samples_x += x + self.samples_y += y + + def fit(self): + if len(self.samples_x) != self.max_samples or \ + len(self.samples_y) != self.max_samples: + return False + + super().fit(self.samples_x, self.samples_y) + self.samples_x = [] + self.samples_y = [] + return True + +def onJsonLineRecvd(json_dict, instance, current_flow, global_user_data): + if 'flow_event_name' not in json_dict: + return True + if json_dict['flow_event_name'] != 'analyse': + return True + + if 'ndpi' not in json_dict: + return True + if 'proto' not in json_dict['ndpi']: + return True + + #print(json_dict) + + features = [[]] + features[0] += json_dict['data_analysis']['bins']['c_to_s'] + features[0] += json_dict['data_analysis']['bins']['s_to_c'] + #print(features) + + out = '' + rfc = global_user_data + try: + out += '[Predict: {}]'.format(rfc.predict(features)[0]) + except sklearn.exceptions.NotFittedError: + pass + + # TLS.DoH_DoT + if json_dict['ndpi']['proto'].startswith('TLS.') is not True and \ + json_dict['ndpi']['proto'] != 'TLS': + rfc.addSample(features, [0]) + else: + rfc.addSample(features, [1]) + + if rfc.fit() is True: + out += '*** FIT *** ' + out += '[{}]'.format(json_dict['ndpi']['proto']) + print(out) + + return True + +if __name__ == '__main__': + argparser = nDPIsrvd.defaultArgumentParser() + args = argparser.parse_args() + address = nDPIsrvd.validateAddress(args) + + sys.stderr.write('Recv buffer size: {}\n'.format(nDPIsrvd.NETWORK_BUFFER_MAX_SIZE)) + sys.stderr.write('Connecting to {} ..\n'.format(address[0]+':'+str(address[1]) if type(address) is tuple else address)) + + rfc = RFC(10) + + nsock = nDPIsrvdSocket() + nsock.connect(address) + nsock.loop(onJsonLineRecvd, None, rfc) |