diff --git a/logic/models/add.go b/logic/models/add.go index db00819..5b573e0 100644 --- a/logic/models/add.go +++ b/logic/models/add.go @@ -37,7 +37,6 @@ func loadBaseImage(c *Context, id string) { switch format { case "png": case "jpeg": - break default: // TODO better logging fmt.Printf("Found unkown format '%s'\n", format) diff --git a/logic/models/data.go b/logic/models/data.go index a7cf3fe..8aab454 100644 --- a/logic/models/data.go +++ b/logic/models/data.go @@ -126,7 +126,7 @@ func processZipFile(c *Context, model *BaseModel) { return } - file_path := path.Join(base_path, data_point_id+".png") + file_path := path.Join(base_path, data_point_id + "." + model.Format) f, err := os.Create(file_path) if err != nil { fmt.Printf("Could not create file %s\n", file_path) diff --git a/logic/models/run.go b/logic/models/run.go index 3b51d15..2503b17 100644 --- a/logic/models/run.go +++ b/logic/models/run.go @@ -27,6 +27,16 @@ func ReadPNG(scope *op.Scope, imagePath string, channels int64) *image.Image { return image.Scale(0, 255) } +func ReadJPG(scope *op.Scope, imagePath string, channels int64) *image.Image { + scope = tg.NewScope(scope) + contents := op.ReadFile(scope.SubScope("ReadFile"), op.Const(scope.SubScope("filename"), imagePath)) + output := op.DecodePng(scope.SubScope("DecodeJpeg"), contents, op.DecodePngChannels(channels)) + output = op.ExpandDims(scope.SubScope("ExpandDims"), output, op.Const(scope.SubScope("axis"), []int32{0})) + image := &image.Image{ + Tensor: tg.NewTensor(scope, output)} + return image.Scale(0, 255) +} + func handleRun(handle *Handle) { handle.Post("/models/run", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { if !CheckAuthLevel(1, w, r, c) { @@ -100,15 +110,16 @@ func handleRun(handle *Handle) { // TODO create a database table with tasks run_path := path.Join("/tmp", model.Id, "runs") os.MkdirAll(run_path, os.ModePerm) + img_path := path.Join(run_path, "img." + model.Format) - img_file, err := os.Create(path.Join(run_path, "img.png")) + img_file, err := os.Create(img_path) if err != nil { return Error500(err) } defer img_file.Close() img_file.Write(file) - if !testImgForModel(c, model, path.Join(run_path, "img.png")) { + if !testImgForModel(c, model, img_path) { LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{ "Model": model, "NotFound": false, @@ -120,7 +131,16 @@ func handleRun(handle *Handle) { root := tg.NewRoot() - tf_img := ReadPNG(root, path.Join(run_path, "img.png"), int64(model.ImageMode)) + var tf_img *image.Image = nil + + switch model.Format { + case "png": + tf_img = ReadPNG(root, img_path, int64(model.ImageMode)) + case "jpeg": + tf_img = ReadJPG(root, img_path, int64(model.ImageMode)) + default: + panic("Not sure what to do with '" + model.Format + "'") + } exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{}) inputImage, err:= tf.NewTensor(exec_results[0].Value()) diff --git a/views/py/python_model_template.py b/views/py/python_model_template.py index 50ed4b7..6edcaf2 100644 --- a/views/py/python_model_template.py +++ b/views/py/python_model_template.py @@ -27,6 +27,7 @@ DATA_DIR_PREPARE = DATA_DIR + "/" def pathToLabel(path): path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "") path = tf.strings.regex_replace(path, ".jpg", "") + path = tf.strings.regex_replace(path, ".jpeg", "") path = tf.strings.regex_replace(path, ".png", "") return table.lookup(tf.strings.as_string([path])) #return tf.strings.as_string([path])