import matplotlib.pyplot as plt
# NOTE(review): `layers` comes from the standalone `keras` package while `keras`
# comes from `tensorflow`; these must resolve to the same Keras installation
# (true when tf.keras is backed by the installed `keras`) — confirm versions.
from keras import layers
from tensorflow import keras

from init import train_dataset, validation_dataset, data_augmentation
# Functional-API convnet for single-output (binary) image classification.
# NOTE(review): input size is hard-coded to 180x180 RGB; it must match the
# image size used to build train_dataset / validation_dataset in `init`.
inputs = keras.Input(shape=(180, 180, 3))

# data_augmentation is defined in `init` (not visible here); presumably Keras
# random-augmentation layers that only act in training mode — confirm.
x = data_augmentation(inputs)

# Scale pixel values by 1/255 (assumes raw 0-255 pixel values — confirm in init).
x = layers.Rescaling(1./255)(x)

# Four conv + 2x2 max-pool stages with doubling filter counts.
for filters in (32, 64, 128, 256):
    x = layers.Conv2D(filters=filters, kernel_size=3, activation="relu")(x)
    x = layers.MaxPool2D(pool_size=2)(x)

# Final conv stage has no pooling; its feature map is flattened directly.
# With "valid" 3x3 convs and 2x2 pools the spatial size goes
# 180→178→89→87→43→41→20→18→9→7, so Flatten yields 7*7*256 features
# (assuming data_augmentation keeps the 180x180 shape).
x = layers.Conv2D(filters=256, kernel_size=3, activation="relu")(x)
x = layers.Flatten()(x)

# Dropout regularizes the dense head.
x = layers.Dropout(0.5)(x)

# One sigmoid unit: probability of the positive class (pairs with
# binary_crossentropy used at compile time).
outputs = layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
# Print layer shapes and parameter counts before training.
model.summary()

# Binary classification setup: cross-entropy loss on the sigmoid output,
# RMSprop optimizer, accuracy reported per epoch.
model.compile(
    optimizer="rmsprop",
    loss="binary_crossentropy",
    metrics=["accuracy"],
)

# Write a checkpoint only when validation loss improves, so the file on disk
# always holds the best-so-far model.
# NOTE(review): "convent" looks like a typo for "convnet"; left unchanged in
# case other code loads this exact path.
checkpoint = keras.callbacks.ModelCheckpoint(
    filepath="convent_from_scratch.keras",
    monitor="val_loss",
    save_best_only=True,
)
callbacks = [checkpoint]

# Train for a fixed 80 epochs; the returned History object records the
# per-epoch training and validation metrics used for the plots.
history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=80,
    callbacks=callbacks,
)

# NOTE(review): this saves the weights left after the final epoch, not the
# best-val_loss checkpoint above. The .h5 extension selects the legacy HDF5 format.
model.save('model.h5')
def _plot_train_val_curve(train_values, val_values, metric, path):
    """Plot the training and validation curves for one metric and save them to *path*.

    Args:
        train_values: Per-epoch values of ``metric`` on the training data.
        val_values: Per-epoch values of ``metric`` on the validation data.
        metric: Metric name used in the legend labels and the title.
        path: Output image file passed to ``Figure.savefig``.
    """
    # Number epochs from 1 (the first recorded value is epoch 1), not 0.
    epoch_numbers = range(1, len(train_values) + 1)

    fig, ax = plt.subplots()
    ax.plot(epoch_numbers, train_values, 'r', label=f'Training {metric}')
    ax.plot(epoch_numbers, val_values, 'b', label=f'validation {metric}')
    ax.set_title(f'Training and validation {metric}')
    ax.set_xlabel('Epoch')
    ax.legend(loc='best')
    fig.savefig(path)
    # The script only writes image files, so free each figure once it is saved.
    plt.close(fig)


# Per-epoch metric histories recorded by model.fit ("accuracy" comes from the
# compile metrics; the "val_" keys exist because validation_data was passed).
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

_plot_train_val_curve(acc, val_acc, 'accuracy', 'accuracy_plot.png')
_plot_train_val_curve(loss, val_loss, 'loss', 'loss_plot.png')