diff --git a/logic/models/train/train.go b/logic/models/train/train.go index 68a5b23..076f843 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -159,6 +159,7 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura "DataDir": path.Join(getDir(), "savedData", model.Id, "data"), "RunPath": run_path, "ColorMode": model.ImageMode, + "Model": model, }); err != nil { return } @@ -166,7 +167,7 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura // Run the command out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).Output() if err != nil { - fmt.Println(string(out)) + c.Logger.Debug(string(out)) return } diff --git a/views/py/python_model_template.py b/views/py/python_model_template.py index 6c19c50..50ed4b7 100644 --- a/views/py/python_model_template.py +++ b/views/py/python_model_template.py @@ -32,7 +32,13 @@ def pathToLabel(path): #return tf.strings.as_string([path]) def decode_image(img): + {{ if eq .Model.Format "png" }} img = tf.io.decode_png(img, channels={{.ColorMode}}) + {{ else if eq .Model.Format "jpeg" }} + img = tf.io.decode_jpeg(img, channels={{.ColorMode}}) + {{ else }} + ERROR + {{ end }} return tf.image.resize(img, image_size) def process_path(path):