feat: closes #40
This commit is contained in:
@@ -93,6 +93,10 @@ val_ds = list_ds.take(val_size)
|
||||
dataset = prepare_dataset(train_ds)
|
||||
dataset_validation = prepare_dataset(val_ds)
|
||||
|
||||
|
||||
{{ if .LoadPrev }}
|
||||
model = tf.keras.saving.load_model('{{.LastModelRunPath}}')
|
||||
{{ else }}
|
||||
model = keras.Sequential([
|
||||
{{- range .Layers }}
|
||||
{{- if eq .LayerType 1}}
|
||||
@@ -106,13 +110,14 @@ model = keras.Sequential([
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
])
|
||||
{{ end }}
|
||||
|
||||
model.compile(
|
||||
loss=losses.SparseCategoricalCrossentropy(),
|
||||
optimizer=tf.keras.optimizers.Adam(),
|
||||
metrics=['accuracy'])
|
||||
|
||||
his = model.fit(dataset, validation_data= dataset_validation, epochs=50, callbacks=[NotifyServerCallback()])
|
||||
his = model.fit(dataset, validation_data= dataset_validation, epochs={{.EPOCH_PER_RUN}}, callbacks=[NotifyServerCallback()])
|
||||
|
||||
acc = his.history["accuracy"]
|
||||
|
||||
@@ -120,6 +125,6 @@ f = open("accuracy.val", "w")
|
||||
f.write(str(acc[-1]))
|
||||
f.close()
|
||||
|
||||
tf.saved_model.save(model, "model")
|
||||
|
||||
# model.save("model.keras", save_format="tf")
|
||||
tf.saved_model.save(model, "model")
|
||||
model.save("model.keras")
|
||||
|
||||
Reference in New Issue
Block a user