@@ -8,7 +8,7 @@ import requests
|
||||
|
||||
class NotifyServerCallback(tf.keras.callbacks.Callback):
|
||||
def on_epoch_end(self, epoch, log, *args, **kwargs):
|
||||
requests.get(f'http://localhost:8000/model/epoch/update?model_id={{.Model.Id}}&epoch={epoch}&accuracy={log["accuracy"]}&definition={{.DefId}}')
|
||||
requests.get(f'http://localhost:8000/model/epoch/update?model_id={{.Model.Id}}&epoch={epoch + 1}&accuracy={log["accuracy"]}&definition={{.DefId}}')
|
||||
|
||||
|
||||
DATA_DIR = "{{ .DataDir }}"
|
||||
@@ -160,7 +160,9 @@ model.compile(
|
||||
optimizer=tf.keras.optimizers.Adam(),
|
||||
metrics=['accuracy'])
|
||||
|
||||
his = model.fit(dataset, validation_data= dataset_validation, epochs={{.EPOCH_PER_RUN}}, callbacks=[NotifyServerCallback()], use_multiprocessing = True)
|
||||
his = model.fit(dataset, validation_data= dataset_validation, epochs={{.EPOCH_PER_RUN}}, callbacks=[
|
||||
NotifyServerCallback(),
|
||||
tf.keras.callbacks.EarlyStopping("loss", mode="min", patience=5)], use_multiprocessing = True)
|
||||
|
||||
acc = his.history["accuracy"]
|
||||
|
||||
@@ -169,5 +171,5 @@ f.write(str(acc[-1]))
|
||||
f.close()
|
||||
|
||||
|
||||
tf.saved_model.save(model, "model")
|
||||
model.save("model.keras")
|
||||
tf.saved_model.save(model, "{{ .SaveModelPath }}/model")
|
||||
model.save("{{ .SaveModelPath }}/model.keras")
|
||||
|
||||
Reference in New Issue
Block a user