From bca44f9ba5863a6923da3398b0a099f34e64ccd5 Mon Sep 17 00:00:00 2001 From: Andre Henriques Date: Thu, 8 Feb 2024 18:20:58 +0000 Subject: [PATCH] chore: work on the expandable models --- auto_reload.sh | 4 +- logic/models/edit.go | 61 +++++++++++++++++++-- logic/models/train/train.go | 35 ++++++------ logic/utils/handler.go | 1 + logic/utils/utils.go | 105 ++++++++++++++++++++++++++---------- 5 files changed, 156 insertions(+), 50 deletions(-) diff --git a/auto_reload.sh b/auto_reload.sh index d59f1f1..11ef7d2 100755 --- a/auto_reload.sh +++ b/auto_reload.sh @@ -2,4 +2,6 @@ cd $(dirname "$0") -go run . +go run . || true + +while true; true; end diff --git a/logic/models/edit.go b/logic/models/edit.go index 1c3b335..256d270 100644 --- a/logic/models/edit.go +++ b/logic/models/edit.go @@ -24,7 +24,7 @@ func handleEdit(handle *Handle) { } // TODO handle admin users - rows, err := handle.Db.Query("select name, status, width, height, color_mode, format from models where id=$1 and user_id=$2;", id, c.User.Id) + rows, err := handle.Db.Query("select name, status, width, height, color_mode, format, model_type from models where id=$1 and user_id=$2;", id, c.User.Id) if err != nil { return Error500(err) } @@ -45,12 +45,13 @@ func handleEdit(handle *Handle) { Height *int Color_mode *string Format string + Type int } var model rowmodel = rowmodel{} model.Id = id - err = rows.Scan(&model.Name, &model.Status, &model.Width, &model.Height, &model.Color_mode, &model.Format) + err = rows.Scan(&model.Name, &model.Status, &model.Width, &model.Height, &model.Color_mode, &model.Format, &model.Type) if err != nil { return Error500(err) } @@ -124,6 +125,7 @@ func handleEdit(handle *Handle) { } type layerdef struct { + id string LayerType int Shape string } @@ -132,7 +134,7 @@ func handleEdit(handle *Handle) { for _, def := range defs { if def.Status == MODEL_DEFINITION_STATUS_TRAINING { - rows, err := c.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", def.Id) + rows, err := c.Db.Query("select id, layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", def.Id) if err != nil { return c.Error500(err) } @@ -140,12 +142,63 @@ func handleEdit(handle *Handle) { for rows.Next() { var layerdef layerdef - err = rows.Scan(&layerdef.LayerType, &layerdef.Shape) + err = rows.Scan(&layerdef.id, &layerdef.LayerType, &layerdef.Shape) if err != nil { return c.Error500(err) } layers = append(layers, layerdef) } + + if model.Type == 2 { + + type lastLayerType struct { + Id string + Range_start int + Range_end int + } + + var lastLayer lastLayerType + + err := GetDBOnce(c, &lastLayer, "exp_model_head where def_id=$1 and status=3;", def.Id) + if err != nil { + return c.Error500(err) + } + + c.Logger.Info("res", "id", lastLayer.Id, "start", lastLayer.Range_start, "end", lastLayer.Range_end) + + layers = append(layers, layerdef{ + id: lastLayer.Id, + LayerType: LAYER_DENSE, + Shape: fmt.Sprintf("%d, 1", lastLayer.Range_end-lastLayer.Range_start), + }) + + /* + lastLayer, err := c.Db.Query("select id, range_start, range_end from exp_model_head where def_id=$1 and status=3;", def.Id) + if err != nil { + return c.Error500(err) + } + defer lastLayer.Close() + + if !lastLayer.Next() { + c.Logger.Info("Could not find the model head for", "def_id", def.Id) + continue + } + + head_id, range_start, range_end := "", 0, 0 + err = lastLayer.Scan(&head_id, &range_start, &range_end) + + if err != nil { + return c.Error500(err) + } + + layers = append(layers, layerdef{ + id: head_id, + LayerType: LAYER_DENSE, + Shape: fmt.Sprintf("%d, 1", range_end-range_start), + }) + */ + } + break } } diff --git a/logic/models/train/train.go b/logic/models/train/train.go index df63f13..ef83c76 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -216,7 +216,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load // Get untrained models heads // Status = 2 (INIT) - rows, err := c.Db.Query("select id, range_start, range_end from exp_model_head where def_id=$1 and status = 2", definition_id) + rows, err := c.Db.Query("select id, range_start, range_end from exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id) if err != nil { return } @@ -231,7 +231,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load exp := ExpHead{} if rows.Next() { - if err = rows.Scan(&exp.id, &exp.start, &exp.end); err == nil { + if err = rows.Scan(&exp.id, &exp.start, &exp.end); err != nil { return } } else { @@ -246,7 +246,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load return } - UpdateStatus(c, "exp_model_head", exp.id, MODEL_DEFINITION_STATUS_TRANIED) + UpdateStatus(c, "exp_model_head", exp.id, MODEL_DEFINITION_STATUS_TRAINING) layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id) if err != nil { @@ -949,7 +949,7 @@ func generateDefinitions(c *Context, model *BaseModel, target_accuracy int, numb } func CreateExpModelHead(c *Context, def_id string, range_start int, range_end int, status ModelDefinitionStatus) (id string, err error) { - rows, err := c.Db.Query("insert into exp_model_head (def_id, range_start, range_end) values ($1, $2, $3, $4) returning id", def_id, range_start, range_end, status) + rows, err := c.Db.Query("insert into exp_model_head (def_id, range_start, range_end, status) values ($1, $2, $3, $4) returning id", def_id, range_start, range_end, status) if err != nil { return @@ -977,6 +977,8 @@ func ExpModelHeadUpdateStatus(db *sql.DB, id string, status ModelDefinitionStatu // This generates a definition func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy int, number_of_classes int, complexity int) *Error { + c.Logger.Info("Generating expandable new definition for model", "id", model.Id, "complexity", complexity) + var err error = nil failed := func() *Error { ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) @@ -1018,9 +1020,14 @@ func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy // Create the blocks loop := int((math.Log(float64(model.Width)) / math.Log(float64(10)))) - if loop == 0 { - loop = 1 - } + + if model.Width < 50 && model.Height < 50 { + loop = 0 + } + + log.Info("Size of the simple block", "loop", loop) + + //loop = max(loop, 3) for i := 0; i < loop; i++ { err = MakeLayerExpandable(c.Db, def_id, order, LAYER_SIMPLE_BLOCK, "", 1) @@ -1045,9 +1052,10 @@ func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy order++ loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2) - if loop == 0 { - loop = 1 - } + + log.Info("Size of the dense layers", "loop", loop) + + // loop = max(loop, 3) for i := 0; i < loop; i++ { err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i))) @@ -1056,7 +1064,7 @@ func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy return failed() } } - + _, err = CreateExpModelHead(c, def_id, 0, number_of_classes-1, MODEL_DEFINITION_STATUS_INIT) if err != nil { return failed() @@ -1131,11 +1139,6 @@ func handleTrain(handle *Handle) { if model_type_form == "expandable" { model_type_id = 2 - c.Logger.Warn("TODO: handle expandable") - return c.Error400(nil, "TODO: handle expandable!", w, "/models/edit.html", "train-model-card", AnyMap{ - "HasData": true, - "ErrorMessage": "TODO: handle expandable!", - }) } else if model_type_form != "simple" { return c.Error400(nil, "Invalid model type!", w, "/models/edit.html", "train-model-card", AnyMap{ "HasData": true, diff --git a/logic/utils/handler.go b/logic/utils/handler.go index a03b137..6e37a1a 100644 --- a/logic/utils/handler.go +++ b/logic/utils/handler.go @@ -423,6 +423,7 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request) var token *string logger := log.NewWithOptions(os.Stdout, log.Options{ + ReportCaller: true, ReportTimestamp: true, TimeFormat: time.Kitchen, Prefix: r.URL.Path, diff --git a/logic/utils/utils.go b/logic/utils/utils.go index 982f1a2..41e0089 100644 --- a/logic/utils/utils.go +++ b/logic/utils/utils.go @@ -7,7 +7,9 @@ import ( "mime" "net/http" "net/url" + "reflect" "strconv" + "strings" "github.com/google/uuid" ) @@ -17,37 +19,37 @@ func CheckEmpty(f url.Values, path string) bool { } func CheckNumber(f url.Values, path string, number *int) bool { - if CheckEmpty(f, path) { - fmt.Println("here", path) - fmt.Println(f.Get(path)) - return false - } - n, err := strconv.Atoi(f.Get(path)) - if err != nil { - fmt.Println(err) - return false - } - *number = n - return true + if CheckEmpty(f, path) { + fmt.Println("here", path) + fmt.Println(f.Get(path)) + return false + } + n, err := strconv.Atoi(f.Get(path)) + if err != nil { + fmt.Println(err) + return false + } + *number = n + return true } func CheckFloat64(f url.Values, path string, number *float64) bool { - if CheckEmpty(f, path) { - fmt.Println("here", path) - fmt.Println(f.Get(path)) - return false - } - n, err := strconv.ParseFloat(f.Get(path), 64) - if err != nil { - fmt.Println(err) - return false - } - *number = n - return true + if CheckEmpty(f, path) { + fmt.Println("here", path) + fmt.Println(f.Get(path)) + return false + } + n, err := strconv.ParseFloat(f.Get(path), 64) + if err != nil { + fmt.Println(err) + return false + } + *number = n + return true } func CheckId(f url.Values, path string) bool { - return !CheckEmpty(f, path) && IsValidUUID(f.Get(path)) + return !CheckEmpty(f, path) && IsValidUUID(f.Get(path)) } func IsValidUUID(u string) bool { @@ -57,19 +59,19 @@ func IsValidUUID(u string) bool { func GetIdFromUrl(r *http.Request, target string) (string, error) { if !r.URL.Query().Has(target) { - return "", errors.New("Query does not have " + target) + return "", errors.New("Query does not have " + target) } id := r.URL.Query().Get("id") if len(id) == 0 { - return "", errors.New("Query is empty for " + target) + return "", errors.New("Query is empty for " + target) } if !IsValidUUID(id) { - return "", errors.New("Value of query is not a valid uuid for " + target) + return "", errors.New("Value of query is not a valid uuid for " + target) } - return id, nil + return id, nil } type maxBytesReader struct { @@ -180,3 +182,48 @@ func MyParseForm(r *http.Request) (vs url.Values, err error) { } return } + +type Generic struct{ reflect.Type } + +var NotFoundError = errors.New("Not found") + +func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) error { + + t := reflect.TypeOf(store).Elem() + + nargs := t.NumField() + + query := "" + + for i := 0; i < nargs; i += 1 { + query += strings.ToLower(t.Field(i).Name) + "," + } + + query = query[0 : len(query)-1] + + rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...) + + if err != nil { + return err + } + defer rows.Close() + + if !rows.Next() { + return NotFoundError + } + + + val := reflect.ValueOf(store).Elem() + scan_args := make([]interface{}, nargs); + for i := 0; i < nargs; i++ { + valueField := val.Field(i) + scan_args[i] = valueField.Addr().Interface() + } + + err = rows.Scan(scan_args...) + if err != nil { + return err + } + + return nil +}