diff --git a/go.mod b/go.mod index 58241cd..eb658de 100644 --- a/go.mod +++ b/go.mod @@ -1,23 +1,27 @@ module git.andr3h3nriqu3s.com/andr3/fyp -go 1.20 +go 1.21 + +require ( + github.com/charmbracelet/log v0.3.1 + github.com/galeone/tensorflow/tensorflow/go v0.0.0-20240119075110-6ad3cf65adfe + github.com/galeone/tfgo v0.0.0-20230715013254-16113111dc99 + github.com/google/uuid v1.6.0 + github.com/lib/pq v1.10.9 + golang.org/x/crypto v0.18.0 +) require ( github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect - github.com/charmbracelet/lipgloss v0.8.0 // indirect - github.com/charmbracelet/log v0.2.5 // indirect - github.com/galeone/tensorflow/tensorflow/go v0.0.0-20221023090153-6b7fa0680c3e // indirect - github.com/galeone/tfgo v0.0.0-20230715013254-16113111dc99 // indirect + github.com/charmbracelet/lipgloss v0.9.1 // indirect github.com/go-logfmt/logfmt v0.6.0 // indirect - github.com/google/uuid v1.3.1 // indirect - github.com/lib/pq v1.10.9 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect - github.com/mattn/go-isatty v0.0.18 // indirect - github.com/mattn/go-runewidth v0.0.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-runewidth v0.0.15 // indirect github.com/muesli/reflow v0.3.0 // indirect github.com/muesli/termenv v0.15.2 // indirect - github.com/rivo/uniseg v0.2.0 // indirect - golang.org/x/crypto v0.13.0 // indirect - golang.org/x/sys v0.12.0 // indirect - google.golang.org/protobuf v1.28.1 // indirect + github.com/rivo/uniseg v0.4.6 // indirect + golang.org/x/exp v0.0.0-20240119083558-1b970713d09a // indirect + golang.org/x/sys v0.16.0 // indirect + google.golang.org/protobuf v1.32.0 // indirect ) diff --git a/go.sum b/go.sum index 1202f51..755022d 100644 --- a/go.sum +++ b/go.sum @@ -2,10 +2,16 @@ github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiE github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/charmbracelet/lipgloss v0.8.0 h1:IS00fk4XAHcf8uZKc3eHeMUTCxUH6NkaTrdyCQk84RU= github.com/charmbracelet/lipgloss v0.8.0/go.mod h1:p4eYUZZJ/0oXTuCQKFF8mqyKCz0ja6y+7DniDDw5KKU= +github.com/charmbracelet/lipgloss v0.9.1 h1:PNyd3jvaJbg4jRHKWXnCj1akQm4rh8dbEzN1p/u1KWg= +github.com/charmbracelet/lipgloss v0.9.1/go.mod h1:1mPmG4cxScwUQALAAnacHaigiiHB9Pmr+v1VEawJl6I= github.com/charmbracelet/log v0.2.5 h1:1yVvyKCKVV639RR4LIq1iy1Cs1AKxuNO+Hx2LJtk7Wc= github.com/charmbracelet/log v0.2.5/go.mod h1:nQGK8tvc4pS9cvVEH/pWJiZ50eUq1aoXUOjGpXvdD0k= +github.com/charmbracelet/log v0.3.1 h1:TjuY4OBNbxmHWSwO3tosgqs5I3biyY8sQPny/eCMTYw= +github.com/charmbracelet/log v0.3.1/go.mod h1:OR4E1hutLsax3ZKpXbgUqPtTjQfrh1pG3zwHGWuuq8g= github.com/galeone/tensorflow/tensorflow/go v0.0.0-20221023090153-6b7fa0680c3e h1:9+2AEFZymTi25FIIcDwuzcOPH04z9+fV6XeLiGORPDI= github.com/galeone/tensorflow/tensorflow/go v0.0.0-20221023090153-6b7fa0680c3e/go.mod h1:TelZuq26kz2jysARBwOrTv16629hyUsHmIoj54QqyFo= +github.com/galeone/tensorflow/tensorflow/go v0.0.0-20240119075110-6ad3cf65adfe h1:7yELf1NFEwECpXMGowkoftcInMlVtLTCdwWLmxKgzNM= +github.com/galeone/tensorflow/tensorflow/go v0.0.0-20240119075110-6ad3cf65adfe/go.mod h1:TelZuq26kz2jysARBwOrTv16629hyUsHmIoj54QqyFo= github.com/galeone/tfgo v0.0.0-20230715013254-16113111dc99 h1:8Bt1P/zy1gb37L4n8CGgp1qmFwBV5729kxVfj0sqhJk= github.com/galeone/tfgo v0.0.0-20230715013254-16113111dc99/go.mod h1:3YgYBeIX42t83uP27Bd4bSMxTnQhSbxl0pYSkCDB1tc= github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4= @@ -14,15 +20,21 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98= github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU= github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo= @@ -30,12 +42,26 @@ github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1n github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.4.6 h1:Sovz9sDSwbOz9tgUy8JpT+KgCkPYJEN/oYzlJiYTNLg= +github.com/rivo/uniseg v0.4.6/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= +golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= +golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA= +golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= +google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= diff --git a/logic/models/classes/list.go b/logic/models/classes/list.go index 99ac053..d74d38e 100644 --- a/logic/models/classes/list.go +++ b/logic/models/classes/list.go @@ -1,7 +1,6 @@ package model_classes import ( - "fmt" "net/http" "strconv" @@ -54,7 +53,7 @@ func HandleList(handle *Handle) { return Error500(err) } - rows, err := handle.Db.Query("select id, file_path, model_mode, status from model_data_point where class_id=$1 limit 10 offset $2;", id, page * 10) + rows, err := handle.Db.Query("select id, file_path, model_mode, status from model_data_point where class_id=$1 limit 11 offset $2;", id, page * 10) if err != nil { return Error500(err) } @@ -77,31 +76,16 @@ func HandleList(handle *Handle) { } got = append(got, nrow) } + + max_len := min(11, len(got)) - rows_count, err := handle.Db.Query("select count(*) from model_data_point where class_id=$1;", id) - if err != nil { - return Error500(err) - } - defer rows_count.Close() - - if !rows_count.Next() { - fmt.Printf("select count(*) from model_data_point where class_id='%s';\n", id) - return Error500(nil) - } - - count := 0 - err = rows_count.Scan(&count) - if err != nil { - return Error500(err) - } - LoadDefineTemplate(w, "/models/edit.html", "data-model-create-class-table-table", c.AddMap(AnyMap{ - "List": got, - "Count": count, + "List": got[0:max_len], "Page": page, "Id": id, "Name": name, "Model": model, + "ShowNext": len(got) == 11, })) return nil }) diff --git a/logic/models/train/train.go b/logic/models/train/train.go index c471a73..a9db993 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -46,6 +46,11 @@ func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type LayerType, return } +func MakeLayerExpandable(db *sql.DB, def_id string, layer_order int, layer_type LayerType, shape string, exp_type int) (err error) { + _, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape, exp_type) values ($1, $2, $3, $4, $5)", def_id, layer_order, layer_type, shape, exp_type) + return +} + func generateCvs(c *Context, run_path string, model_id string) (count int, err error) { classes, err := c.Db.Query("select count(*) from model_classes where model_id=$1;", model_id) @@ -427,6 +432,205 @@ func trainModel(c *Context, model *BaseModel) { ModelUpdateStatus(c, model.Id, READY) } +func trainModelExp(c *Context, model *BaseModel) { + var err error = nil + + failed := func(msg string) { + c.Logger.Error(msg, "err", err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + } + + definitionsRows, err := c.Db.Query("select id, target_accuracy, epoch from model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id) + if err != nil { + failed("Failed to trainModel!") + return + } + defer definitionsRows.Close() + + var definitions TraingModelRowDefinitions = []TrainModelRow{} + + for definitionsRows.Next() { + var rowv TrainModelRow + rowv.acuracy = 0 + if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy, &rowv.epoch); err != nil { + c.Logger.Error("Failed to train Model Could not read definition from db!Err:") + c.Logger.Error(err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return + } + definitions = append(definitions, rowv) + } + + if len(definitions) == 0 { + c.Logger.Error("No Definitions defined!") + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return + } + + firstRound := true + finished := false + + for { + var toRemove ToRemoveList = []int{} + for i, def := range definitions { + ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING) + accuracy, err := trainDefinition(c, model, def.id, !firstRound) + if err != nil { + c.Logger.Error("Failed to train definition!Err:", "err", err) + ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) + toRemove = append(toRemove, i) + continue + } + def.epoch += EPOCH_PER_RUN + accuracy = accuracy * 100 + def.acuracy = float64(accuracy) + + definitions[i].epoch += EPOCH_PER_RUN + definitions[i].acuracy = accuracy + + if accuracy >= float64(def.target_accuracy) { + c.Logger.Info("Found a definition that reaches target_accuracy!") + _, err = c.Db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.epoch, def.id) + if err != nil { + c.Logger.Error("Failed to train definition!Err:\n", "err", err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return + } + + _, err = c.Db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", MODEL_DEFINITION_STATUS_CANCELD_TRAINING, def.id, model.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) + if err != nil { + c.Logger.Error("Failed to train definition!Err:\n", "err", err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return + } + + finished = true + break + } + + if def.epoch > MAX_EPOCH { + fmt.Printf("Failed to train definition! Accuracy less %f < %d\n", accuracy, def.target_accuracy) + ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) + toRemove = append(toRemove, i) + continue + } + + _, err = c.Db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.epoch, MODEL_DEFINITION_STATUS_PAUSED_TRAINING, def.id) + if err != nil { + c.Logger.Error("Failed to train definition!Err:\n", "err", err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return + } + } + + firstRound = false + if finished { + break + } + + sort.Sort(sort.Reverse(toRemove)) + + c.Logger.Info("Round done", "toRemove", toRemove) + + for _, n := range toRemove { + definitions = remove(definitions, n) + } + + len_def := len(definitions) + + if len_def == 0 { + break + } + + if len_def == 1 { + continue + } + + sort.Sort(sort.Reverse(definitions)) + + acc := definitions[0].acuracy - 20.0 + + c.Logger.Info("Training models, Highest acc", "acc", definitions[0].acuracy, "mod_acc", acc) + + toRemove = []int{} + for i, def := range definitions { + if def.acuracy < acc { + toRemove = append(toRemove, i) + } + } + + c.Logger.Info("Removing due to accuracy", "toRemove", toRemove) + + sort.Sort(sort.Reverse(toRemove)) + for _, n := range toRemove { + c.Logger.Warn("Removing definition not fast enough learning", "n", n) + ModelDefinitionUpdateStatus(c, definitions[n].id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) + definitions = remove(definitions, n) + } + } + + rows, err := c.Db.Query("select id from model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, MODEL_DEFINITION_STATUS_TRANIED) + if err != nil { + c.Logger.Error("DB: failed to read definition") + c.Logger.Error(err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return + } + defer rows.Close() + + if !rows.Next() { + // TODO Make the Model status have a message + c.Logger.Error("All definitions failed to train!") + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return + } + + var id string + if err = rows.Scan(&id); err != nil { + c.Logger.Error("Failed to read id:") + c.Logger.Error(err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return + } + + if _, err = c.Db.Exec("update model_definition set status=$1 where id=$2;", MODEL_DEFINITION_STATUS_READY, id); err != nil { + c.Logger.Error("Failed to update model definition") + c.Logger.Error(err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return + } + + to_delete, err := c.Db.Query("select id from model_definition where status != $1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id) + if err != nil { + c.Logger.Error("Failed to select model_definition to delete") + c.Logger.Error(err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return + } + defer to_delete.Close() + + for to_delete.Next() { + var id string + if to_delete.Scan(&id); err != nil { + c.Logger.Error("Failed to scan the id of a model_definition to delete") + c.Logger.Error(err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return + } + os.RemoveAll(path.Join("savedData", model.Id, "defs", id)) + } + + // TODO Check if returning also works here + if _, err = c.Db.Exec("delete from model_definition where status!=$1 and model_id=$2;", MODEL_DEFINITION_STATUS_READY, model.Id); err != nil { + c.Logger.Error("Failed to delete model_definition") + c.Logger.Error(err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return + } + + ModelUpdateStatus(c, model.Id, READY) +} + func removeFailedDataPoints(c *Context, model *BaseModel) (err error) { rows, err := c.Db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id=mdp.class_id where mc.model_id=$1 and mdp.status=-1;", model.Id) if err != nil { @@ -511,7 +715,7 @@ func generateDefinition(c *Context, model *BaseModel, target_accuracy int, numbe } } - } else if complexity == 1 { + } else if complexity == 1 || complexity == 2 { loop := int((math.Log(float64(model.Width)) / math.Log(float64(10)))) if loop == 0 { @@ -542,39 +746,6 @@ func generateDefinition(c *Context, model *BaseModel, target_accuracy int, numbe return failed() } } - - } else if complexity == 2 { - - loop := int((math.Log(float64(model.Width)) / math.Log(float64(10)))) - if loop == 0 { - loop = 1 - } - for i := 0; i < loop; i++ { - err = MakeLayer(c.Db, def_id, order, LAYER_SIMPLE_BLOCK, "") - order++ - if err != nil { - return failed() - } - } - - err = MakeLayer(c.Db, def_id, order, LAYER_FLATTEN, "") - if err != nil { - return failed() - } - order++ - - loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2) - if loop == 0 { - loop = 1 - } - for i := 0; i < loop; i++ { - err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i))) - order++ - if err != nil { - return failed() - } - } - } else { c.Logger.Error("Unkown complexity", "complexity", complexity) return failed() @@ -625,6 +796,163 @@ func generateDefinitions(c *Context, model *BaseModel, target_accuracy int, numb return nil } +func CreateExpModelHead(c *Context, def_id string, range_start int, range_end int, status ModelDefinitionStatus) (id string, err error) { + rows, err := c.Db.Query("insert into exp_model_head (def_id, range_start, range_end) values ($1, $2, $3, $4) returning id", def_id, range_start, range_end, status) + + if err != nil { + return + } + defer rows.Close() + + if !rows.Next() { + c.Logger.Error("Could not get status of model definition") + err = errors.New("Could not get status of model definition") + return + } + + err = rows.Scan(&id) + if err != nil { + return + } + + return +} + +func ExpModelHeadUpdateStatus(db *sql.DB, id string, status ModelDefinitionStatus) (err error) { + _, err = db.Exec("update model_definition set status = $1 where id = $2", status, id) + return +} + +// This generates a definition +func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy int, number_of_classes int, complexity int) *Error { + var err error = nil + failed := func() *Error { + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) + // TODO improve this response + return c.Error500(err) + } + + if complexity == 0 { + return failed() + } + + def_id, err := MakeDefenition(c.Db, model.Id, target_accuracy) + if err != nil { + return failed() + } + + order := 1 + + width := model.Width + height := model.Height + + // Note the shape of the first layer defines the import size + if complexity == 2 { + // Note the shape for now is no used + width := int(math.Pow(2, math.Floor(math.Log(float64(model.Width))/math.Log(2.0)))) + height := int(math.Pow(2, math.Floor(math.Log(float64(model.Height))/math.Log(2.0)))) + c.Logger.Warn("Complexity 2 creating model with smaller size", "width", width, "height", height) + + } + + err = MakeLayerExpandable(c.Db, def_id, order, LAYER_INPUT, fmt.Sprintf("%d,%d,1", width, height), 1) + + order++ + + // handle the errors inside the pervious if block + if err != nil { + return failed() + } + + // Create the blocks + loop := int((math.Log(float64(model.Width)) / math.Log(float64(10)))) + if loop == 0 { + loop = 1 + } + + for i := 0; i < loop; i++ { + err = MakeLayerExpandable(c.Db, def_id, order, LAYER_SIMPLE_BLOCK, "", 1) + order++ + if err != nil { + return failed() + } + } + + // Flatten the blocks into dense + err = MakeLayerExpandable(c.Db, def_id, order, LAYER_FLATTEN, "", 1) + if err != nil { + return failed() + } + order++ + + // Flatten the blocks into dense + err = MakeLayerExpandable(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes * 2), 1) + if err != nil { + return failed() + } + order++ + + loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2) + if loop == 0 { + loop = 1 + } + + for i := 0; i < loop; i++ { + err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i))) + order++ + if err != nil { + return failed() + } + } + + _, err = CreateExpModelHead(c, def_id, 0, number_of_classes - 1, MODEL_DEFINITION_STATUS_INIT) + if err != nil { + return failed() + } + + err = ModelDefinitionUpdateStatus(c, def_id, MODEL_DEFINITION_STATUS_INIT) + if err != nil { + return failed() + } + + return nil +} + +func generateExpandableDefinitions(c *Context, model *BaseModel, target_accuracy int, number_of_models int) *Error { + cls, err := model_classes.ListClasses(c.Db, model.Id) + if err != nil { + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) + // TODO improve this response + return c.Error500(err) + } + + err = removeFailedDataPoints(c, model) + if err != nil { + return c.Error500(err) + } + + cls_len := len(cls) + + if number_of_models == 1 { + if model.Width > 100 && model.Height > 100 { + generateExpandableDefinition(c, model, target_accuracy, cls_len, 2) + } else { + generateExpandableDefinition(c, model, target_accuracy, cls_len, 1) + } + } else if number_of_models == 3 { + for i := 0; i < number_of_models; i++ { + generateExpandableDefinition(c, model, target_accuracy, cls_len, i) + } + } else { + // TODO handle incrisea the complexity + for i := 0; i < number_of_models; i++ { + generateExpandableDefinition(c, model, target_accuracy, cls_len, 1) + } + } + + return nil +} + func handleTrain(handle *Handle) { handle.Post("/models/train", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { if !CheckAuthLevel(1, w, r, c) { @@ -641,35 +969,27 @@ func handleTrain(handle *Handle) { accuracy := 0 if !CheckId(f, "id") || CheckEmpty(f, "model_type") || !CheckNumber(f, "number_of_models", &number_of_models) || !CheckNumber(f, "accuracy", &accuracy) { - fmt.Println( - !CheckId(f, "id"), CheckEmpty(f, "model_type"), !CheckNumber(f, "number_of_models", &number_of_models), !CheckNumber(f, "accuracy", &accuracy), - ) // TODO improve this response return ErrorCode(nil, 400, c.AddMap(nil)) } id := f.Get("id") - model_type := f.Get("model_type") - // Its not used rn - _ = model_type + model_type_id := 1 + model_type_form := f.Get("model_type") - // TODO check if the model has data - /*rows, err := handle.Db.Query("select mc.name, mdp.file_path from model_classes as mc join model_data_point as mdp on mdp.class_id = mc.id where mdp.model_mode = 1 and mc.model_id = $1 limit 1;", id) - if err != nil { - return Error500(err) + if model_type_form == "expandable" { + model_type_id = 2 + c.Logger.Warn("TODO: handle expandable") + return c.Error400(nil, "TODO: handle expandable!", w, "/models/edit.html", "train-model-card", AnyMap{ + "HasData": true, + "ErrorMessage": "TODO: handle expandable!", + }) + } else if model_type_form != "simple" { + return c.Error400(nil, "Invalid model type!", w, "/models/edit.html", "train-model-card", AnyMap{ + "HasData": true, + "ErrorMessage": "Invalid model type!", + }) } - defer rows.Close() - - if !rows.Next() { - return Error500(err) - } - - var name string - var file_path string - err = rows.Scan(&name, &file_path) - if err != nil { - return Error500(err) - }*/ model, err := GetBaseModel(handle.Db, id) if err == ModelNotFoundError { @@ -687,14 +1007,28 @@ func handleTrain(handle *Handle) { return ErrorCode(nil, 400, c.AddMap(nil)) } - full_error := generateDefinitions(c, model, accuracy, number_of_models) - if full_error != nil { - return full_error - } + if model_type_id == 2 { + full_error := generateExpandableDefinitions(c, model, accuracy, number_of_models) + if full_error != nil { + return full_error + } + } else { + full_error := generateDefinitions(c, model, accuracy, number_of_models) + if full_error != nil { + return full_error + } + } + go trainModel(c, model) - ModelUpdateStatus(c, model.Id, TRAINING) + _, err = c.Db.Exec("update models set status = $1, model_type = $2 where id = $3", TRAINING, model_type_id, model.Id) + if err != nil { + fmt.Println("Failed to update model status") + fmt.Println(err) + // TODO improve this response + return Error500(err) + } Redirect("/models/edit?id="+model.Id, c.Mode, w, r) return nil }) diff --git a/main.go b/main.go index 570caf0..d02aa6b 100644 --- a/main.go +++ b/main.go @@ -13,6 +13,7 @@ import ( ) const ( + clear_db = false host = "localhost" port = 5432 user = "postgres" @@ -35,8 +36,9 @@ func main() { //TODO check if file structure exists to save data handle := NewHandler(db) + // TODO remove this before commiting _, err = db.Exec("update models set status=$1 where status=$2", models_utils.FAILED_TRAINING, models_utils.TRAINING) - if err != nil { + if err != nil && clear_db { log.Warn("Database might not be on") panic(err) } diff --git a/sql/models.sql b/sql/models.sql index aa57d62..05bcb1c 100644 --- a/sql/models.sql +++ b/sql/models.sql @@ -10,6 +10,12 @@ create table if not exists models ( -- -1: failed preparing -- 1: preparing status integer default 1, + + -- Types: + -- 0: Unset + -- 1: simple + -- 2: expandable + model_type integer default 0, width integer, height integer, @@ -70,6 +76,34 @@ create table if not exists model_definition_layer ( layer_type integer not null, -- ei 28,28,1 -- a 28x28 grayscale image - shape text not null + shape text not null, + + -- Type based on the expandability + -- 0: not expandalbe model + -- 1: fixed + -- 2: head + exp_type integer default 0 +); + +-- drop table if exists exp_model_head; +create table if not exists exp_model_head ( + id uuid primary key default gen_random_uuid(), + def_id uuid references model_definition (id) on delete cascade, + -- Start order id to the class that this model satifies inclusive + range_start integer not null, + -- end order id to the class that this model satifies inclusive + range_end integer not null, + + accuracy real default 0, + + -- TODO add max epoch + -- 1: Pre Init + -- 2: Init + -- 3: Training + -- 4: Tranied + -- 5: Ready + status integer default 1, + created_on timestamp default current_timestamp, + epoch_progress integer default 0 ); diff --git a/views/models/add.html b/views/models/add.html index 33334a6..273a561 100644 --- a/views/models/add.html +++ b/views/models/add.html @@ -21,7 +21,7 @@