From b5a28a0bdb6fdde7fe0d3e736fddbfa2c3221710 Mon Sep 17 00:00:00 2001 From: Andre Henriques Date: Wed, 14 Feb 2024 15:11:45 +0000 Subject: [PATCH] feat: started working on the head part of the split models --- logic/models/run.go | 188 +++++++++++++++++++++++------------- logic/models/train/train.go | 22 +++-- logic/models/utils/types.go | 11 ++- logic/utils/utils.go | 59 ++++++++++- 4 files changed, 196 insertions(+), 84 deletions(-) diff --git a/logic/models/run.go b/logic/models/run.go index 2503b17..5e9844a 100644 --- a/logic/models/run.go +++ b/logic/models/run.go @@ -37,6 +37,62 @@ func ReadJPG(scope *op.Scope, imagePath string, channels int64) *image.Image { return image.Scale(0, 255) } +func runModelNormal(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, err error) { + order = 0 + err = nil + + tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil) + + results := tf_model.Exec([]tf.Output{ + tf_model.Op("StatefulPartitionedCall", 0), + }, map[tf.Output]*tf.Tensor{ + tf_model.Op("serving_default_rescaling_input", 0): inputImage, + }) + + var vmax float32 = 0.0 + var predictions = results[0].Value().([][]float32)[0] + + for i, v := range predictions { + if v > vmax { + order = i + vmax = v + } + } + + return +} + +func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, err error) { + + err = nil + order = 0 + + base_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "base", "model"), []string{"serve"}, nil) + + //results := base_model.Exec([]tf.Output{ + base_model.Exec([]tf.Output{ + base_model.Op("StatefulPartitionedCall", 0), + }, map[tf.Output]*tf.Tensor{ + //base_model.Op("serving_default_rescaling_input", 0): inputImage, + base_model.Op("serving_default_input_1", 0): inputImage, + }) + + type head struct { + Id string + Range_start int + } + + heads, err := GetDbMultitple[head](c, "exp_model_head where def_id=$1;", def_id) + if err != nil { + return + } + + // TODO runthe head model + + c.Logger.Info("Got", "heads", len(heads)) + return +} + func handleRun(handle *Handle) { handle.Post("/models/run", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { if !CheckAuthLevel(1, w, r, c) { @@ -90,27 +146,22 @@ func handleRun(handle *Handle) { return ErrorCode(nil, 400, c.AddMap(nil)) } - definitions_rows, err := handle.Db.Query("select id from model_definition where model_id=$1;", model.Id) - if err != nil { - return Error500(err) - } - defer definitions_rows.Close() - - if !definitions_rows.Next() { + def := JustId{} + err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id) + if err == NotFoundError { // TODO improve this - fmt.Printf("Could not find definition\n") + fmt.Printf("Could not find definition\n") return ErrorCode(nil, 400, c.AddMap(nil)) - } - - var def_id string - if err = definitions_rows.Scan(&def_id); err != nil { + } else if err != nil { return Error500(err) } + def_id := def.Id + // TODO create a database table with tasks run_path := path.Join("/tmp", model.Id, "runs") os.MkdirAll(run_path, os.ModePerm) - img_path := path.Join(run_path, "img." + model.Format) + img_path := path.Join(run_path, "img."+model.Format) img_file, err := os.Create(img_path) if err != nil { @@ -119,74 +170,75 @@ func handleRun(handle *Handle) { defer img_file.Close() img_file.Write(file) - if !testImgForModel(c, model, img_path) { - LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{ - "Model": model, - "NotFound": false, - "Result": nil, - "ImageError": true, - })) - return nil - } + if !testImgForModel(c, model, img_path) { + LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{ + "Model": model, + "NotFound": false, + "Result": nil, + "ImageError": true, + })) + return nil + } root := tg.NewRoot() - var tf_img *image.Image = nil - - switch model.Format { - case "png": - tf_img = ReadPNG(root, img_path, int64(model.ImageMode)) - case "jpeg": - tf_img = ReadJPG(root, img_path, int64(model.ImageMode)) - default: - panic("Not sure what to do with '" + model.Format + "'") - } + var tf_img *image.Image = nil - exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{}) - inputImage, err:= tf.NewTensor(exec_results[0].Value()) - if err != nil { - return Error500(err) - } + switch model.Format { + case "png": + tf_img = ReadPNG(root, img_path, int64(model.ImageMode)) + case "jpeg": + tf_img = ReadJPG(root, img_path, int64(model.ImageMode)) + default: + panic("Not sure what to do with '" + model.Format + "'") + } - tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil) + exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{}) + inputImage, err := tf.NewTensor(exec_results[0].Value()) + if err != nil { + return Error500(err) + } - results := tf_model.Exec([]tf.Output{ - tf_model.Op("StatefulPartitionedCall", 0), - }, map[tf.Output]*tf.Tensor{ - tf_model.Op("serving_default_rescaling_input", 0): inputImage, - }) + vi := -1 - var vmax float32 = 0.0 - vi := 0 - var predictions = results[0].Value().([][]float32)[0] - - for i, v := range predictions { - if v > vmax { - vi = i - vmax = v + if model.ModelType == 2 { + c.Logger.Info("Running model normal", "model", model.Id, "def", def_id) + vi, err = runModelExp(c, model, def_id, inputImage) + if err != nil { + return c.Error500(err); + } + } else { + c.Logger.Info("Running model normal", "model", model.Id, "def", def_id) + vi, err = runModelNormal(c, model, def_id, inputImage) + if err != nil { + return c.Error500(err); } } - os.RemoveAll(run_path) + os.RemoveAll(run_path) - rows, err := handle.Db.Query("select name from model_classes where model_id=$1 and class_order=$2;", model.Id, vi) - if err != nil { return Error500(err) } - if !rows.Next() { - LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{ - "Model": model, - "NotFound": true, - "Result": nil, - })) - return nil - } + rows, err := handle.Db.Query("select name from model_classes where model_id=$1 and class_order=$2;", model.Id, vi) + if err != nil { + return Error500(err) + } + if !rows.Next() { + LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{ + "Model": model, + "NotFound": true, + "Result": nil, + })) + return nil + } - var name string - if err = rows.Scan(&name); err != nil { return nil } + var name string + if err = rows.Scan(&name); err != nil { + return nil + } - LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{ - "Model": model, - "Result": name, - })) + LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{ + "Model": model, + "Result": name, + })) return nil }) } diff --git a/logic/models/train/train.go b/logic/models/train/train.go index bcc99b9..144cd57 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -784,23 +784,31 @@ func trainModelExp(c *Context, model *BaseModel) { failed("Failed to split the model") return } + + // There should only be one def availabale + def := JustId{} + + if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil { + return + } + + // Remove the base model + c.Logger.Warn("Removing base model for", "model", model.Id, "def", def.Id) + os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model")) + os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model.keras")) ModelUpdateStatus(c, model.Id, READY) } func splitModel(c *Context, model *BaseModel) (err error) { - type Def struct { - Id string - } - - def := Def{} + def := JustId{} if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil { return } - head := Def{} + head := JustId{} if err = GetDBOnce(c, &head, "exp_model_head where def_id=$1", def.Id); err != nil { return @@ -887,8 +895,6 @@ func splitModel(c *Context, model *BaseModel) (err error) { return } - - func removeFailedDataPoints(c *Context, model *BaseModel) (err error) { rows, err := c.Db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id=mdp.class_id where mc.model_id=$1 and mdp.status=-1;", model.Id) if err != nil { diff --git a/logic/models/utils/types.go b/logic/models/utils/types.go index faa88a4..d212273 100644 --- a/logic/models/utils/types.go +++ b/logic/models/utils/types.go @@ -6,10 +6,11 @@ import ( ) type BaseModel struct { - Name string - Status int - Id string + Name string + Status int + Id string + ModelType int ImageMode int Width int Height int @@ -54,7 +55,7 @@ const ( var ModelNotFoundError = errors.New("Model not found error") func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) { - rows, err := db.Query("select name, status, id, width, height, color_mode, format from models where id=$1;", id) + rows, err := db.Query("select name, status, id, width, height, color_mode, format, model_type from models where id=$1;", id) if err != nil { return } @@ -66,7 +67,7 @@ func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) { base = &BaseModel{} var colorMode string - err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height, &colorMode, &base.Format) + err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height, &colorMode, &base.Format, &base.ModelType) if err != nil { return nil, err } diff --git a/logic/utils/utils.go b/logic/utils/utils.go index 41e0089..ec7e9e4 100644 --- a/logic/utils/utils.go +++ b/logic/utils/utils.go @@ -1,6 +1,7 @@ package utils import ( + "database/sql" "errors" "fmt" "io" @@ -183,14 +184,64 @@ func MyParseForm(r *http.Request) (vs url.Values, err error) { return } +type JustId struct { Id string } + type Generic struct{ reflect.Type } var NotFoundError = errors.New("Not found") +func GetDbMultitple[T interface{}](c *Context, tablename string, args ...any) ([]*T, error) { + t := reflect.TypeFor[T]() + nargs := t.NumField() + + query := "" + + for i := 0; i < nargs; i += 1 { + query += strings.ToLower(t.Field(i).Name) + "," + } + + // Remove the last comma + query = query[0 : len(query)-1] + + rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := []*T{} + + for rows.Next() { + item := new(T) + if err = mapRow(item, rows, nargs); err != nil { + return nil, err + } + list = append(list, item) + } + + return list, nil +} + +func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) { + err = nil + + val := reflect.Indirect(reflect.ValueOf(store)) + scan_args := make([]interface{}, nargs); + for i := 0; i < nargs; i++ { + valueField := val.Field(i) + scan_args[i] = valueField.Addr().Interface() + } + + err = rows.Scan(scan_args...) + if err != nil { + return + } + + return nil +} + func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) error { - t := reflect.TypeOf(store).Elem() - nargs := t.NumField() query := "" @@ -199,10 +250,10 @@ func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) err query += strings.ToLower(t.Field(i).Name) + "," } + // Remove the last comma query = query[0 : len(query)-1] rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...) - if err != nil { return err } @@ -212,6 +263,7 @@ func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) err return NotFoundError } + err = nil val := reflect.ValueOf(store).Elem() scan_args := make([]interface{}, nargs); @@ -227,3 +279,4 @@ func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) err return nil } +