feat: closes #39
This commit is contained in:
@@ -434,7 +434,20 @@
|
||||
{{/* TODO improve this */}}
|
||||
Training the model...<br/>
|
||||
{{/* TODO Add progress status on definitions */}}
|
||||
{{/* TODO Add aility to stop training */}}
|
||||
{{ range .Defs}}
|
||||
<div>
|
||||
<div>
|
||||
{{.Status}}
|
||||
</div>
|
||||
<div>
|
||||
{{.EpochProgress}}
|
||||
</div>
|
||||
<div>
|
||||
{{.Accuracy}}
|
||||
</div>
|
||||
</div>
|
||||
{{ end }}
|
||||
{{/* TODO Add ability to stop training */}}
|
||||
</div>
|
||||
{{/* Model Ready */}}
|
||||
{{ else if (eq .Model.Status 5)}}
|
||||
|
||||
@@ -4,6 +4,14 @@ import pandas as pd
|
||||
from tensorflow import keras
|
||||
from tensorflow.data import AUTOTUNE
|
||||
from keras import layers, losses, optimizers
|
||||
import requests
|
||||
|
||||
class NotifyServerCallback(tf.keras.callbacks.Callback):
|
||||
def on_epoch_begin(self, epoch, *args, **kwargs):
|
||||
if (epoch % 5) == 0:
|
||||
# TODO change this
|
||||
requests.get(f'http://localhost:8000/model/epoch/update?model_id={{.Model.Id}}&epoch={epoch}&definition={{.DefId}}')
|
||||
|
||||
|
||||
DATA_DIR = "{{ .DataDir }}"
|
||||
image_size = ({{ .Size }})
|
||||
@@ -26,11 +34,15 @@ DATA_DIR_PREPARE = DATA_DIR + "/"
|
||||
#based on https://www.tensorflow.org/tutorials/load_data/images
|
||||
def pathToLabel(path):
|
||||
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
|
||||
{{ if eq .Model.Format "png" }}
|
||||
path = tf.strings.regex_replace(path, ".png", "")
|
||||
{{ else if eq .Model.Format "jpeg" }}
|
||||
path = tf.strings.regex_replace(path, ".jpg", "")
|
||||
path = tf.strings.regex_replace(path, ".jpeg", "")
|
||||
path = tf.strings.regex_replace(path, ".png", "")
|
||||
{{ else }}
|
||||
ERROR
|
||||
{{ end }}
|
||||
return table.lookup(tf.strings.as_string([path]))
|
||||
#return tf.strings.as_string([path])
|
||||
|
||||
def decode_image(img):
|
||||
{{ if eq .Model.Format "png" }}
|
||||
@@ -100,7 +112,7 @@ model.compile(
|
||||
optimizer=tf.keras.optimizers.Adam(),
|
||||
metrics=['accuracy'])
|
||||
|
||||
his = model.fit(dataset, validation_data= dataset_validation, epochs=50)
|
||||
his = model.fit(dataset, validation_data= dataset_validation, epochs=50, callbacks=[NotifyServerCallback()])
|
||||
|
||||
acc = his.history["accuracy"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user