aboutsummaryrefslogtreecommitdiff
path: root/dga/tensorflow_tests/test_script.py
diff options
context:
space:
mode:
authorLuca Deri <deri@ntop.org>2024-10-28 12:55:18 +0100
committerLuca Deri <deri@ntop.org>2024-10-28 12:55:18 +0100
commitecd3c734d00671a4fe5ac1713422dae55f2bad2f (patch)
tree60d57b74974f2e3515e6c16a4a6e1b925bffd2cd /dga/tensorflow_tests/test_script.py
parentfecc378e0426cbad42da636bb075dadb3fb24e61 (diff)
Rename
Diffstat (limited to 'dga/tensorflow_tests/test_script.py')
-rw-r--r--dga/tensorflow_tests/test_script.py22
1 files changed, 22 insertions, 0 deletions
diff --git a/dga/tensorflow_tests/test_script.py b/dga/tensorflow_tests/test_script.py
new file mode 100644
index 000000000..5c946c8cf
--- /dev/null
+++ b/dga/tensorflow_tests/test_script.py
@@ -0,0 +1,22 @@
+import tensorflow as tf
+import joblib
+import numpy as np
+from sklearn.metrics import classification_report, accuracy_score
+
+# Load the model
+model = tf.keras.models.load_model("dga_model.keras")
+X_test, y_test = joblib.load("test_data.pkl")
+label_encoder = joblib.load("label_encoder.pkl")
+tokenizer = joblib.load("tokenizer.pkl")
+
+# Make predictions on the test set
+y_pred = (model.predict(X_test) > 0.5).astype("int32").flatten()
+
+# Calculate accuracy
+accuracy = accuracy_score(y_test, y_pred)
+print(f"Accuracy: {accuracy:.4f}")
+
+# Generate the classification report
+report = classification_report(y_test, y_pred, target_names=label_encoder.classes_)
+print("\nClassification Report:")
+print(report)