chore: work on the expandable models

This commit is contained in:
Andre Henriques 2024-02-08 18:20:58 +00:00
parent 6a0ac457d7
commit bca44f9ba5
5 changed files with 156 additions and 50 deletions

View File

@ -2,4 +2,6 @@
cd $(dirname "$0") cd $(dirname "$0")
go run . go run . || true
while true; true; end

View File

@ -24,7 +24,7 @@ func handleEdit(handle *Handle) {
} }
// TODO handle admin users // TODO handle admin users
rows, err := handle.Db.Query("select name, status, width, height, color_mode, format from models where id=$1 and user_id=$2;", id, c.User.Id) rows, err := handle.Db.Query("select name, status, width, height, color_mode, format, model_type from models where id=$1 and user_id=$2;", id, c.User.Id)
if err != nil { if err != nil {
return Error500(err) return Error500(err)
} }
@ -45,12 +45,13 @@ func handleEdit(handle *Handle) {
Height *int Height *int
Color_mode *string Color_mode *string
Format string Format string
Type int
} }
var model rowmodel = rowmodel{} var model rowmodel = rowmodel{}
model.Id = id model.Id = id
err = rows.Scan(&model.Name, &model.Status, &model.Width, &model.Height, &model.Color_mode, &model.Format) err = rows.Scan(&model.Name, &model.Status, &model.Width, &model.Height, &model.Color_mode, &model.Format, &model.Type)
if err != nil { if err != nil {
return Error500(err) return Error500(err)
} }
@ -124,6 +125,7 @@ func handleEdit(handle *Handle) {
} }
type layerdef struct { type layerdef struct {
id string
LayerType int LayerType int
Shape string Shape string
} }
@ -132,7 +134,7 @@ func handleEdit(handle *Handle) {
for _, def := range defs { for _, def := range defs {
if def.Status == MODEL_DEFINITION_STATUS_TRAINING { if def.Status == MODEL_DEFINITION_STATUS_TRAINING {
rows, err := c.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", def.Id) rows, err := c.Db.Query("select id, layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", def.Id)
if err != nil { if err != nil {
return c.Error500(err) return c.Error500(err)
} }
@ -140,12 +142,63 @@ func handleEdit(handle *Handle) {
for rows.Next() { for rows.Next() {
var layerdef layerdef var layerdef layerdef
err = rows.Scan(&layerdef.LayerType, &layerdef.Shape) err = rows.Scan(&layerdef.id, &layerdef.LayerType, &layerdef.Shape)
if err != nil { if err != nil {
return c.Error500(err) return c.Error500(err)
} }
layers = append(layers, layerdef) layers = append(layers, layerdef)
} }
if model.Type == 2 {
type lastLayerType struct {
Id string
Range_start int
Range_end int
}
var lastLayer lastLayerType
err := GetDBOnce(c, &lastLayer, "exp_model_head where def_id=$1 and status=3;", def.Id)
if err != nil {
return c.Error500(err)
}
c.Logger.Info("res", "id", lastLayer.Id, "start", lastLayer.Range_start, "end", lastLayer.Range_end)
layers = append(layers, layerdef{
id: lastLayer.Id,
LayerType: LAYER_DENSE,
Shape: fmt.Sprintf("%d, 1", lastLayer.Range_end-lastLayer.Range_start),
})
/*
lastLayer, err := c.Db.Query("select id, range_start, range_end from exp_model_head where def_id=$1 and status=3;", def.Id)
if err != nil {
return c.Error500(err)
}
defer lastLayer.Close()
if !lastLayer.Next() {
c.Logger.Info("Could not find the model head for", "def_id", def.Id)
continue
}
head_id, range_start, range_end := "", 0, 0
err = lastLayer.Scan(&head_id, &range_start, &range_end)
if err != nil {
return c.Error500(err)
}
layers = append(layers, layerdef{
id: head_id,
LayerType: LAYER_DENSE,
Shape: fmt.Sprintf("%d, 1", range_end-range_start),
})
*/
}
break break
} }
} }

View File

@ -216,7 +216,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
// Get untrained models heads // Get untrained models heads
// Status = 2 (INIT) // Status = 2 (INIT)
rows, err := c.Db.Query("select id, range_start, range_end from exp_model_head where def_id=$1 and status = 2", definition_id) rows, err := c.Db.Query("select id, range_start, range_end from exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
if err != nil { if err != nil {
return return
} }
@ -231,7 +231,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
exp := ExpHead{} exp := ExpHead{}
if rows.Next() { if rows.Next() {
if err = rows.Scan(&exp.id, &exp.start, &exp.end); err == nil { if err = rows.Scan(&exp.id, &exp.start, &exp.end); err != nil {
return return
} }
} else { } else {
@ -246,7 +246,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
return return
} }
UpdateStatus(c, "exp_model_head", exp.id, MODEL_DEFINITION_STATUS_TRANIED) UpdateStatus(c, "exp_model_head", exp.id, MODEL_DEFINITION_STATUS_TRAINING)
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id) layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
if err != nil { if err != nil {
@ -949,7 +949,7 @@ func generateDefinitions(c *Context, model *BaseModel, target_accuracy int, numb
} }
func CreateExpModelHead(c *Context, def_id string, range_start int, range_end int, status ModelDefinitionStatus) (id string, err error) { func CreateExpModelHead(c *Context, def_id string, range_start int, range_end int, status ModelDefinitionStatus) (id string, err error) {
rows, err := c.Db.Query("insert into exp_model_head (def_id, range_start, range_end) values ($1, $2, $3, $4) returning id", def_id, range_start, range_end, status) rows, err := c.Db.Query("insert into exp_model_head (def_id, range_start, range_end, status) values ($1, $2, $3, $4) returning id", def_id, range_start, range_end, status)
if err != nil { if err != nil {
return return
@ -977,6 +977,8 @@ func ExpModelHeadUpdateStatus(db *sql.DB, id string, status ModelDefinitionStatu
// This generates a definition // This generates a definition
func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy int, number_of_classes int, complexity int) *Error { func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy int, number_of_classes int, complexity int) *Error {
c.Logger.Info("Generating expandable new definition for model", "id", model.Id, "complexity", complexity)
var err error = nil var err error = nil
failed := func() *Error { failed := func() *Error {
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
@ -1018,9 +1020,14 @@ func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy
// Create the blocks // Create the blocks
loop := int((math.Log(float64(model.Width)) / math.Log(float64(10)))) loop := int((math.Log(float64(model.Width)) / math.Log(float64(10))))
if loop == 0 {
loop = 1 if model.Width < 50 && model.Height < 50 {
} loop = 0
}
log.Info("Size of the simple block", "loop", loop)
//loop = max(loop, 3)
for i := 0; i < loop; i++ { for i := 0; i < loop; i++ {
err = MakeLayerExpandable(c.Db, def_id, order, LAYER_SIMPLE_BLOCK, "", 1) err = MakeLayerExpandable(c.Db, def_id, order, LAYER_SIMPLE_BLOCK, "", 1)
@ -1045,9 +1052,10 @@ func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy
order++ order++
loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2) loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2)
if loop == 0 {
loop = 1 log.Info("Size of the dense layers", "loop", loop)
}
// loop = max(loop, 3)
for i := 0; i < loop; i++ { for i := 0; i < loop; i++ {
err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i))) err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)))
@ -1056,7 +1064,7 @@ func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy
return failed() return failed()
} }
} }
_, err = CreateExpModelHead(c, def_id, 0, number_of_classes-1, MODEL_DEFINITION_STATUS_INIT) _, err = CreateExpModelHead(c, def_id, 0, number_of_classes-1, MODEL_DEFINITION_STATUS_INIT)
if err != nil { if err != nil {
return failed() return failed()
@ -1131,11 +1139,6 @@ func handleTrain(handle *Handle) {
if model_type_form == "expandable" { if model_type_form == "expandable" {
model_type_id = 2 model_type_id = 2
c.Logger.Warn("TODO: handle expandable")
return c.Error400(nil, "TODO: handle expandable!", w, "/models/edit.html", "train-model-card", AnyMap{
"HasData": true,
"ErrorMessage": "TODO: handle expandable!",
})
} else if model_type_form != "simple" { } else if model_type_form != "simple" {
return c.Error400(nil, "Invalid model type!", w, "/models/edit.html", "train-model-card", AnyMap{ return c.Error400(nil, "Invalid model type!", w, "/models/edit.html", "train-model-card", AnyMap{
"HasData": true, "HasData": true,

View File

@ -423,6 +423,7 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request)
var token *string var token *string
logger := log.NewWithOptions(os.Stdout, log.Options{ logger := log.NewWithOptions(os.Stdout, log.Options{
ReportCaller: true,
ReportTimestamp: true, ReportTimestamp: true,
TimeFormat: time.Kitchen, TimeFormat: time.Kitchen,
Prefix: r.URL.Path, Prefix: r.URL.Path,

View File

@ -7,7 +7,9 @@ import (
"mime" "mime"
"net/http" "net/http"
"net/url" "net/url"
"reflect"
"strconv" "strconv"
"strings"
"github.com/google/uuid" "github.com/google/uuid"
) )
@ -17,37 +19,37 @@ func CheckEmpty(f url.Values, path string) bool {
} }
func CheckNumber(f url.Values, path string, number *int) bool { func CheckNumber(f url.Values, path string, number *int) bool {
if CheckEmpty(f, path) { if CheckEmpty(f, path) {
fmt.Println("here", path) fmt.Println("here", path)
fmt.Println(f.Get(path)) fmt.Println(f.Get(path))
return false return false
} }
n, err := strconv.Atoi(f.Get(path)) n, err := strconv.Atoi(f.Get(path))
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return false return false
} }
*number = n *number = n
return true return true
} }
func CheckFloat64(f url.Values, path string, number *float64) bool { func CheckFloat64(f url.Values, path string, number *float64) bool {
if CheckEmpty(f, path) { if CheckEmpty(f, path) {
fmt.Println("here", path) fmt.Println("here", path)
fmt.Println(f.Get(path)) fmt.Println(f.Get(path))
return false return false
} }
n, err := strconv.ParseFloat(f.Get(path), 64) n, err := strconv.ParseFloat(f.Get(path), 64)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return false return false
} }
*number = n *number = n
return true return true
} }
func CheckId(f url.Values, path string) bool { func CheckId(f url.Values, path string) bool {
return !CheckEmpty(f, path) && IsValidUUID(f.Get(path)) return !CheckEmpty(f, path) && IsValidUUID(f.Get(path))
} }
func IsValidUUID(u string) bool { func IsValidUUID(u string) bool {
@ -57,19 +59,19 @@ func IsValidUUID(u string) bool {
func GetIdFromUrl(r *http.Request, target string) (string, error) { func GetIdFromUrl(r *http.Request, target string) (string, error) {
if !r.URL.Query().Has(target) { if !r.URL.Query().Has(target) {
return "", errors.New("Query does not have " + target) return "", errors.New("Query does not have " + target)
} }
id := r.URL.Query().Get("id") id := r.URL.Query().Get("id")
if len(id) == 0 { if len(id) == 0 {
return "", errors.New("Query is empty for " + target) return "", errors.New("Query is empty for " + target)
} }
if !IsValidUUID(id) { if !IsValidUUID(id) {
return "", errors.New("Value of query is not a valid uuid for " + target) return "", errors.New("Value of query is not a valid uuid for " + target)
} }
return id, nil return id, nil
} }
type maxBytesReader struct { type maxBytesReader struct {
@ -180,3 +182,48 @@ func MyParseForm(r *http.Request) (vs url.Values, err error) {
} }
return return
} }
type Generic struct{ reflect.Type }
var NotFoundError = errors.New("Not found")
func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) error {
t := reflect.TypeOf(store).Elem()
nargs := t.NumField()
query := ""
for i := 0; i < nargs; i += 1 {
query += strings.ToLower(t.Field(i).Name) + ","
}
query = query[0 : len(query)-1]
rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
if err != nil {
return err
}
defer rows.Close()
if !rows.Next() {
return NotFoundError
}
val := reflect.ValueOf(store).Elem()
scan_args := make([]interface{}, nargs);
for i := 0; i < nargs; i++ {
valueField := val.Field(i)
scan_args[i] = valueField.Addr().Interface()
}
err = rows.Scan(scan_args...)
if err != nil {
return err
}
return nil
}