aboutsummaryrefslogtreecommitdiff
path: root/examples/py-machine-learning/sklearn-ml.py
diff options
context:
space:
mode:
authorToni Uhlig <matzeton@googlemail.com>2022-09-22 19:07:08 +0200
committerToni Uhlig <matzeton@googlemail.com>2022-09-22 19:07:08 +0200
commit9a28475bba88b711b7075b58473b7e5b5df1f393 (patch)
tree73cdf56320f14b5fe0fbfb2e930cf7ea025f9117 /examples/py-machine-learning/sklearn-ml.py
parent28971cd7647a79253000fb33e52b5d2129e5ba62 (diff)
Improved flown analyse event:
* store packet directions * merged direction based IATs * merged direction based PKTLENs Signed-off-by: Toni Uhlig <matzeton@googlemail.com>
Diffstat (limited to 'examples/py-machine-learning/sklearn-ml.py')
-rwxr-xr-xexamples/py-machine-learning/sklearn-ml.py89
1 files changed, 89 insertions, 0 deletions
diff --git a/examples/py-machine-learning/sklearn-ml.py b/examples/py-machine-learning/sklearn-ml.py
new file mode 100755
index 000000000..301f4e907
--- /dev/null
+++ b/examples/py-machine-learning/sklearn-ml.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python3
+
+# pip3 install -U scikit-learn scipy matplotlib
+
+import os
+import sklearn
+import sklearn.ensemble
+import sys
+
+sys.path.append(os.path.dirname(sys.argv[0]) + '/../../dependencies')
+sys.path.append(os.path.dirname(sys.argv[0]) + '/../share/nDPId')
+sys.path.append(os.path.dirname(sys.argv[0]))
+sys.path.append(sys.base_prefix + '/share/nDPId')
+import nDPIsrvd
+from nDPIsrvd import nDPIsrvdSocket, TermColor
+
+class RFC(sklearn.ensemble.RandomForestClassifier):
+ def __init__(self, max_samples):
+ self.max_samples = max_samples
+ self.samples_x = []
+ self.samples_y = []
+ super().__init__(verbose=1, n_estimators=1000, max_samples=max_samples)
+
+ def addSample(self, x, y):
+ self.samples_x += x
+ self.samples_y += y
+
+ def fit(self):
+ if len(self.samples_x) != self.max_samples or \
+ len(self.samples_y) != self.max_samples:
+ return False
+
+ super().fit(self.samples_x, self.samples_y)
+ self.samples_x = []
+ self.samples_y = []
+ return True
+
+def onJsonLineRecvd(json_dict, instance, current_flow, global_user_data):
+ if 'flow_event_name' not in json_dict:
+ return True
+ if json_dict['flow_event_name'] != 'analyse':
+ return True
+
+ if 'ndpi' not in json_dict:
+ return True
+ if 'proto' not in json_dict['ndpi']:
+ return True
+
+ #print(json_dict)
+
+ features = [[]]
+ features[0] += json_dict['data_analysis']['bins']['c_to_s']
+ features[0] += json_dict['data_analysis']['bins']['s_to_c']
+ #print(features)
+
+ out = ''
+ rfc = global_user_data
+ try:
+ out += '[Predict: {}]'.format(rfc.predict(features)[0])
+ except sklearn.exceptions.NotFittedError:
+ pass
+
+ # TLS.DoH_DoT
+ if json_dict['ndpi']['proto'].startswith('TLS.') is not True and \
+ json_dict['ndpi']['proto'] != 'TLS':
+ rfc.addSample(features, [0])
+ else:
+ rfc.addSample(features, [1])
+
+ if rfc.fit() is True:
+ out += '*** FIT *** '
+ out += '[{}]'.format(json_dict['ndpi']['proto'])
+ print(out)
+
+ return True
+
+if __name__ == '__main__':
+ argparser = nDPIsrvd.defaultArgumentParser()
+ args = argparser.parse_args()
+ address = nDPIsrvd.validateAddress(args)
+
+ sys.stderr.write('Recv buffer size: {}\n'.format(nDPIsrvd.NETWORK_BUFFER_MAX_SIZE))
+ sys.stderr.write('Connecting to {} ..\n'.format(address[0]+':'+str(address[1]) if type(address) is tuple else address))
+
+ rfc = RFC(10)
+
+ nsock = nDPIsrvdSocket()
+ nsock.connect(address)
+ nsock.loop(onJsonLineRecvd, None, rfc)