aboutsummaryrefslogtreecommitdiff
path: root/tests/dga/ml_tests/scikit-learn_tests/test_script.py
diff options
context:
space:
mode:
authorYellowMan <giuseppedipalma002@gmail.com>2024-10-26 19:04:20 +0200
committerYellowMan <giuseppedipalma002@gmail.com>2024-10-26 19:04:20 +0200
commit551941ea4dd52b616efa75b875ceeb4fc807b063 (patch)
tree616e3842d3c46149290ae8d313bdd4417c5e8477 /tests/dga/ml_tests/scikit-learn_tests/test_script.py
parent3b1286ab03b0c9223ac208a01868d8c1f6c0ae00 (diff)
ml tests for dga detection
Diffstat (limited to 'tests/dga/ml_tests/scikit-learn_tests/test_script.py')
-rw-r--r--tests/dga/ml_tests/scikit-learn_tests/test_script.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/tests/dga/ml_tests/scikit-learn_tests/test_script.py b/tests/dga/ml_tests/scikit-learn_tests/test_script.py
new file mode 100644
index 000000000..4ded249f8
--- /dev/null
+++ b/tests/dga/ml_tests/scikit-learn_tests/test_script.py
@@ -0,0 +1,23 @@
+import joblib
+from sklearn.neural_network import MLPClassifier
+from sklearn.metrics import classification_report, accuracy_score
+import time
+
+mlp = joblib.load('mlp_model.joblib')
+X_test = joblib.load('X_test.joblib')
+y_test = joblib.load('y_test.joblib')
+label_encoder = joblib.load('label_encoder.joblib')
+
+# Perform prediction
+start = time.time()
+y_pred = mlp.predict(X_test)
+print(f"Prediction time: {time.time()-start:.2f} seconds")
+
+# Evaluate the model
+accuracy = accuracy_score(y_test, y_pred)
+report = classification_report(y_test, y_pred, target_names=label_encoder.classes_)
+
+# Print the results
+print(f"Accuracy: {accuracy:.4f}")
+print("\nClassification Report:")
+print(report)