StandardCAS-NSTID commited on
Commit
065e0c7
1 Parent(s): 1766e29

Create Estallie_Trainer.py

Browse files
Files changed (1) hide show
  1. Estallie_Trainer.py +60 -0
Estallie_Trainer.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tensorflow as tf
3
+ from tensorflow.keras.preprocessing.image import ImageDataGenerator
4
+
5
+ # Define constants
6
+ IMAGE_SIZE = (512, 512)
7
+ BATCH_SIZE = 4
8
+ EPOCHS = 10
9
+ TRAIN_DIR = 'T'
10
+ VALID_DIR = 'T'
11
+ MODEL_PATH = 'nsfw_classifier.h5'
12
+
13
+ # Create an image data generator for training data
14
+ train_datagen = ImageDataGenerator(rescale=1./255)
15
+ train_generator = train_datagen.flow_from_directory(
16
+ TRAIN_DIR,
17
+ target_size=IMAGE_SIZE,
18
+ batch_size=BATCH_SIZE,
19
+ class_mode='binary')
20
+
21
+ # Create an image data generator for validation data
22
+ valid_datagen = ImageDataGenerator(rescale=1./255)
23
+ valid_generator = valid_datagen.flow_from_directory(
24
+ VALID_DIR,
25
+ target_size=IMAGE_SIZE,
26
+ batch_size=BATCH_SIZE,
27
+ class_mode='binary')
28
+
29
+ # Check if the model already exists
30
+ if os.path.exists(MODEL_PATH):
31
+ print("Loading existing model")
32
+ model = tf.keras.models.load_model(MODEL_PATH)
33
+ else:
34
+ print("Creating new model")
35
+ # Define the model
36
+ model = tf.keras.models.Sequential([
37
+ tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3)),
38
+ tf.keras.layers.MaxPooling2D(2, 2),
39
+ tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
40
+ tf.keras.layers.MaxPooling2D(2, 2),
41
+ tf.keras.layers.Flatten(),
42
+ tf.keras.layers.Dense(512, activation='relu'),
43
+ tf.keras.layers.Dense(1, activation='sigmoid')
44
+ ])
45
+
46
+ # Compile the model
47
+ model.compile(loss='binary_crossentropy',
48
+ optimizer='adam',
49
+ metrics=['accuracy'])
50
+
51
+ # Train the model
52
+ history = model.fit(
53
+ train_generator,
54
+ steps_per_epoch=train_generator.samples // BATCH_SIZE,
55
+ epochs=EPOCHS,
56
+ validation_data=valid_generator,
57
+ validation_steps=valid_generator.samples // BATCH_SIZE)
58
+
59
+ # Save the model
60
+ model.save(MODEL_PATH)