feat: closes #40

This commit is contained in:
2023-10-19 10:44:13 +01:00
parent f163e25fba
commit 2c3539b81a
4 changed files with 184 additions and 105 deletions

View File

@@ -93,6 +93,10 @@ val_ds = list_ds.take(val_size)
dataset = prepare_dataset(train_ds)
dataset_validation = prepare_dataset(val_ds)
{{ if .LoadPrev }}
model = tf.keras.saving.load_model('{{.LastModelRunPath}}')
{{ else }}
model = keras.Sequential([
{{- range .Layers }}
{{- if eq .LayerType 1}}
@@ -106,13 +110,14 @@ model = keras.Sequential([
{{- end }}
{{- end }}
])
{{ end }}
model.compile(
loss=losses.SparseCategoricalCrossentropy(),
optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy'])
his = model.fit(dataset, validation_data= dataset_validation, epochs=50, callbacks=[NotifyServerCallback()])
his = model.fit(dataset, validation_data= dataset_validation, epochs={{.EPOCH_PER_RUN}}, callbacks=[NotifyServerCallback()])
acc = his.history["accuracy"]
@@ -120,6 +125,6 @@ f = open("accuracy.val", "w")
f.write(str(acc[-1]))
f.close()
tf.saved_model.save(model, "model")
# model.save("model.keras", save_format="tf")
tf.saved_model.save(model, "model")
model.save("model.keras")