diff --git a/logic/models/run.go b/logic/models/run.go index df88ce9..533df6d 100644 --- a/logic/models/run.go +++ b/logic/models/run.go @@ -110,7 +110,7 @@ func handleRun(handle *Handle) { root := tg.NewRoot() - tf_img := ReadPNG(root, path.Join(run_path, "img.png"), 3) + tf_img := ReadPNG(root, path.Join(run_path, "img.png"), int64(model.ImageMode)) exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{}) inputImage, err:= tf.NewTensor(exec_results[0].Value()) @@ -130,8 +130,6 @@ func handleRun(handle *Handle) { vi := 0 var predictions = results[0].Value().([][]float32)[0] - fmt.Println(predictions) - for i, v := range predictions { if v > vmax { vi = i @@ -155,7 +153,6 @@ func handleRun(handle *Handle) { var name string if err = rows.Scan(&name); err != nil { return nil } - LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{ "Model": model, "Result": name, diff --git a/logic/models/train/train.go b/logic/models/train/train.go index a2f6374..d12cea7 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -19,27 +19,12 @@ import ( func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) { id = "" - _, err = db.Exec("insert into model_definition (model_id, target_accuracy) values ($1, $2);", model_id, target_accuracy) - if err != nil { - return - } - - rows, err := db.Query("select id from model_definition where model_id=$1 order by created_on DESC;", model_id) - if err != nil { - return - } - defer rows.Close() - - if !rows.Next() { - return id, errors.New("Something wrong!") - } - + rows, err := db.Query("insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;", model_id, target_accuracy) + if err != nil { return } + defer rows.Close() + if !rows.Next() { return id, errors.New("Something wrong!") } err = rows.Scan(&id) - if err != nil { - return - } - - return + return } type ModelDefinitionStatus int @@ -103,7 +88,7 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int, return } -func trainDefinition(handle *Handle, model_id string, definition_id string) (accuracy float64, err error) { +func trainDefinition(handle *Handle, model *BaseModel, definition_id string) (accuracy float64, err error) { accuracy = 0 layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id) if err != nil { @@ -128,14 +113,14 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc } // Generate run folder - run_path := path.Join("/tmp", model_id, "defs", definition_id) + run_path := path.Join("/tmp", model.Id, "defs", definition_id) err = os.MkdirAll(run_path, os.ModePerm) if err != nil { return } - _, err = generateCvs(handle, run_path, model_id) + _, err = generateCvs(handle, run_path, model.Id) if err != nil { return } // Create python script @@ -153,8 +138,9 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc if err = tmpl.Execute(f, AnyMap{ "Layers": got, "Size": got[0].Shape, - "DataDir": path.Join(getDir(), "savedData", model_id, "data"), + "DataDir": path.Join(getDir(), "savedData", model.Id, "data"), "RunPath": run_path, + "ColorMode": model.ImageMode, }); err != nil { return } @@ -167,7 +153,7 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc } // Copy result around - result_path := path.Join("savedData", model_id, "defs", definition_id) + result_path := path.Join("savedData", model.Id, "defs", definition_id) if err = os.MkdirAll(result_path, os.ModePerm); err != nil { return @@ -235,7 +221,7 @@ func trainModel(handle *Handle, model *BaseModel) { } for _, def := range definitions { - accuracy, err := trainDefinition(handle, model.Id, def.id) + accuracy, err := trainDefinition(handle, model, def.id) if err != nil { fmt.Printf("Failed to train definition!Err:\n") fmt.Println(err) diff --git a/logic/models/utils/types.go b/logic/models/utils/types.go index da8d1b8..1d6cc32 100644 --- a/logic/models/utils/types.go +++ b/logic/models/utils/types.go @@ -10,8 +10,9 @@ type BaseModel struct { Status int Id string - Width int - Height int + ImageMode int + Width int + Height int } const ( @@ -30,7 +31,7 @@ const ( var ModelNotFoundError = errors.New("Model not found error") func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) { - rows, err := db.Query("select name, status, id, width, height from models where id=$1;", id) + rows, err := db.Query("select name, status, id, width, height, color_mode from models where id=$1;", id) if err != nil { return } @@ -41,10 +42,17 @@ func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) { } base = &BaseModel{} - err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height) + var colorMode string + err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height, &colorMode) if err != nil { return nil, err } + switch colorMode { + case "greyscale": + base.ImageMode = 1 + default: + panic("unkown color mode") + } return } diff --git a/views/py/python_model_template.py b/views/py/python_model_template.py index 8fd087e..6c19c50 100644 --- a/views/py/python_model_template.py +++ b/views/py/python_model_template.py @@ -32,10 +32,7 @@ def pathToLabel(path): #return tf.strings.as_string([path]) def decode_image(img): - # channels were reduced to 1 since image is grayscale - # TODO chnage channel number based if grayscale - img = tf.io.decode_png(img, channels=3) - + img = tf.io.decode_png(img, channels={{.ColorMode}}) return tf.image.resize(img, image_size) def process_path(path):