aboutsummaryrefslogtreecommitdiff
path: root/dga/scikit-learn_tests/test_script.py
diff options
context:
space:
mode:
authorLuca Deri <deri@ntop.org>2024-10-26 21:15:36 +0200
committerLuca Deri <deri@ntop.org>2024-10-26 21:15:36 +0200
commitf5d903caadb00b3e2f68c74cf9da7a19cf4545f7 (patch)
treeaac33900efd9dc38b3cdb15e563055428b3765b0 /dga/scikit-learn_tests/test_script.py
parent0fb30c857d3f54546e8de61cd5234c2860474369 (diff)
Moved new DGA code
Diffstat (limited to 'dga/scikit-learn_tests/test_script.py')
-rw-r--r--dga/scikit-learn_tests/test_script.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/dga/scikit-learn_tests/test_script.py b/dga/scikit-learn_tests/test_script.py
new file mode 100644
index 000000000..4ded249f8
--- /dev/null
+++ b/dga/scikit-learn_tests/test_script.py
@@ -0,0 +1,23 @@
+import joblib
+from sklearn.neural_network import MLPClassifier
+from sklearn.metrics import classification_report, accuracy_score
+import time
+
+mlp = joblib.load('mlp_model.joblib')
+X_test = joblib.load('X_test.joblib')
+y_test = joblib.load('y_test.joblib')
+label_encoder = joblib.load('label_encoder.joblib')
+
+# Perform prediction
+start = time.time()
+y_pred = mlp.predict(X_test)
+print(f"Prediction time: {time.time()-start:.2f} seconds")
+
+# Evaluate the model
+accuracy = accuracy_score(y_test, y_pred)
+report = classification_report(y_test, y_pred, target_names=label_encoder.classes_)
+
+# Print the results
+print(f"Accuracy: {accuracy:.4f}")
+print("\nClassification Report:")
+print(report)