# Spectrum / predict.py
# Uploaded by nilekhet ("Upload 6 files", commit b743670).
import sys
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import matplotlib.pyplot as plt
from lime import lime_image
from skimage.segmentation import mark_boundaries
def explain_instance(image_path, model, num_features=5, num_samples=1000, *,
                     image_size=(200, 200), num_classes=119, explainer=None):
    """Run LIME on a single image and return the resulting explanation.

    The image is loaded from ``image_path``, resized to ``image_size`` and
    scaled by 1/255 before being handed to LIME. NOTE(review): this assumes
    the model was trained on images preprocessed the same way - confirm
    against the training script.

    Args:
        image_path: Path of the image file to explain.
        model: Keras model; its ``predict`` method is used as LIME's
            classifier function on batches of perturbed images.
        num_features: Forwarded to LIME's ``num_features``.
        num_samples: Number of perturbed images LIME generates.
        image_size: ``(height, width)`` the image is resized to on load.
        num_classes: Forwarded to LIME as ``top_labels``, i.e. how many of
            the highest-scoring labels get an explanation.
        explainer: A ``lime_image.LimeImageExplainer`` to use; a new
            default-configured one is created when omitted.

    Returns:
        The explanation object returned by
        ``LimeImageExplainer.explain_instance``.
    """
    # Previously these values were read from globals that only exist when the
    # file runs as a script; taking them as parameters makes the function
    # usable on import as well.
    if explainer is None:
        explainer = lime_image.LimeImageExplainer()

    img = load_img(image_path, target_size=image_size)
    img_array = img_to_array(img) / 255

    # hide_color=0 paints "switched off" superpixels black in the perturbed
    # samples.
    explanation = explainer.explain_instance(
        img_array,
        model.predict,
        top_labels=num_classes,
        hide_color=0,
        num_samples=num_samples,
        num_features=num_features,
    )
    return explanation
if __name__ == "__main__":
    # Script entry point: classify one image, print the predicted class, then
    # show the original image next to its LIME explanation.
    from pathlib import Path

    if len(sys.argv) not in (2, 3):
        print("Usage: predict.py image_path [train_dir]")
        sys.exit(1)
    image_path = sys.argv[1]
    # Optional directory the model was trained from (one sub-directory per
    # class). It is only used to map the predicted index to a class name.
    train_dir = sys.argv[2] if len(sys.argv) == 3 else None

    # Kept as module-level names because explain_instance may read them as
    # globals.
    image_size = (200, 200)
    model_path = "malware_classifier_lime.h5"
    model = load_model(model_path)
    num_classes = 119
    explainer = lime_image.LimeImageExplainer()

    # Load and preprocess the image once; reused for prediction and display.
    img = load_img(image_path, target_size=image_size)
    img_array = img_to_array(img) / 255

    # Make a prediction on a batch of one image.
    prediction = model.predict(np.expand_dims(img_array, axis=0))
    predicted_class = int(np.argmax(prediction))

    # The original code used an undefined `train_generator`. Keras'
    # flow_from_directory assigns class indices in sorted order of the
    # sub-directory names, so the same mapping is rebuilt from train_dir.
    class_name = "unknown"
    if train_dir is not None:
        class_names = sorted(p.name for p in Path(train_dir).iterdir() if p.is_dir())
        if len(class_names) != num_classes:
            print(f"Warning: found {len(class_names)} class directories in "
                  f"{train_dir}, expected {num_classes}", file=sys.stderr)
        if predicted_class < len(class_names):
            class_name = class_names[predicted_class]

    # Print before plt.show(), which blocks until the window is closed.
    print(f"Predicted class: {predicted_class}, Class name: {class_name}")

    explanation = explain_instance(image_path, model)
    # top_labels[0] is the label LIME ranked highest for this image.
    temp, mask = explanation.get_image_and_mask(
        explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=False
    )

    plt.figure(figsize=(10, 5))

    # Display the original image
    plt.subplot(1, 2, 1)
    plt.imshow(img_array)
    plt.title("Original Image")
    plt.axis("off")

    # Display the LIME explanation
    plt.subplot(1, 2, 2)
    plt.imshow(mark_boundaries(temp, mask))
    plt.title("LIME Explanation")
    plt.axis("off")

    plt.show()