2023-09-26 20:15:28 +01:00
|
|
|
import tensorflow as tf
|
|
|
|
import random
|
|
|
|
from tensorflow import keras
|
|
|
|
from keras import layers, losses, optimizers
|
|
|
|
|
|
|
|
seed = random.randint(0, 100000000)
|
|
|
|
|
|
|
|
batch_size = 100
|
|
|
|
|
|
|
|
dataset = keras.utils.image_dataset_from_directory(
|
|
|
|
"{{ .DataDir }}",
|
|
|
|
color_mode="rgb",
|
|
|
|
validation_split=0.2,
|
2023-09-29 13:27:43 +01:00
|
|
|
label_mode='categorical',
|
2023-09-26 20:15:28 +01:00
|
|
|
seed=seed,
|
2023-09-29 13:27:43 +01:00
|
|
|
shuffle=True,
|
2023-09-26 20:15:28 +01:00
|
|
|
subset="training",
|
|
|
|
image_size=({{ .Size }}),
|
|
|
|
batch_size=batch_size)
|
|
|
|
|
|
|
|
dataset_validation = keras.utils.image_dataset_from_directory(
|
|
|
|
"{{ .DataDir }}",
|
|
|
|
color_mode="rgb",
|
|
|
|
validation_split=0.2,
|
2023-09-29 13:27:43 +01:00
|
|
|
label_mode='categorical',
|
2023-09-26 20:15:28 +01:00
|
|
|
seed=seed,
|
2023-09-29 13:27:43 +01:00
|
|
|
shuffle=True,
|
2023-09-26 20:15:28 +01:00
|
|
|
subset="validation",
|
|
|
|
image_size=({{ .Size }}),
|
|
|
|
batch_size=batch_size)
|
|
|
|
|
|
|
|
model = keras.Sequential([
|
|
|
|
{{- range .Layers }}
|
|
|
|
{{- if eq .LayerType 1}}
|
|
|
|
layers.Rescaling(1./255),
|
|
|
|
{{- else if eq .LayerType 2 }}
|
2023-09-27 13:55:29 +01:00
|
|
|
layers.Dense({{ .Shape }}, activation="sigmoid"),
|
2023-09-26 20:15:28 +01:00
|
|
|
{{- else if eq .LayerType 3}}
|
|
|
|
layers.Flatten(),
|
|
|
|
{{- else }}
|
|
|
|
ERROR
|
|
|
|
{{- end }}
|
|
|
|
{{- end }}
|
|
|
|
])
|
|
|
|
|
2023-09-27 13:55:29 +01:00
|
|
|
# The datasets are loaded with label_mode='categorical' (one-hot labels),
# hence categorical cross-entropy. Adam runs with its default settings.
# NOTE(review): the optimizer comes from tf.keras while the loss comes from the
# separately imported keras package; confirm both resolve to the same Keras.
model.compile(
    loss=losses.CategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'],
)

# Train for a fixed 50 epochs and evaluate on the held-out subset after each
# epoch. The returned History object is kept for the accuracy report below.
his = model.fit(
    dataset,
    epochs=50,
    validation_data=dataset_validation,
)
|
2023-09-26 20:15:28 +01:00
|
|
|
|
2023-09-27 13:55:29 +01:00
|
|
|
# Per-epoch training accuracy recorded by model.fit.
# NOTE(review): this is the *training* accuracy (history key "accuracy").
# The validation figure would be under "val_accuracy"; confirm which one the
# reader of accuracy.val expects.
acc = his.history["accuracy"]

# Write only the final epoch's value to accuracy.val in the working directory.
# The context manager closes the handle and flushes the value even if the
# write raises. The original open/write/close sequence leaked the handle on error.
with open("accuracy.val", "w", encoding="utf-8") as f:
    f.write(str(acc[-1]))
|
|
|
|
|
2023-09-27 21:20:39 +01:00
|
|
|
# Export the trained model in TensorFlow SavedModel format to a directory
# named "model", relative to the current working directory.
tf.saved_model.save(obj=model, export_dir="model")
|