diff --git a/.air.toml b/.air.toml index 619275b..4d1231b 100644 --- a/.air.toml +++ b/.air.toml @@ -7,7 +7,7 @@ tmp_dir = "tmp" bin = "./tmp/main" cmd = "go build -o ./tmp/main ." delay = 0 - exclude_dir = ["assets", "tmp", "vendor", "testData", "savedData"] + exclude_dir = ["assets", "tmp", "vendor", "testData", "savedData", "tensorflow"] exclude_file = [] exclude_regex = ["_test.go"] exclude_unchanged = false diff --git a/go.mod b/go.mod index e50c8c5..403bf49 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,10 @@ module git.andr3h3nriqu3s.com/andr3/fyp go 1.20 require ( + github.com/galeone/tensorflow/tensorflow/go v0.0.0-20221023090153-6b7fa0680c3e // indirect + github.com/galeone/tfgo v0.0.0-20230715013254-16113111dc99 // indirect github.com/google/uuid v1.3.1 // indirect github.com/lib/pq v1.10.9 // indirect golang.org/x/crypto v0.13.0 // indirect + google.golang.org/protobuf v1.28.1 // indirect ) diff --git a/go.sum b/go.sum index 64b7ec0..37aced0 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,16 @@ +github.com/galeone/tensorflow/tensorflow/go v0.0.0-20221023090153-6b7fa0680c3e h1:9+2AEFZymTi25FIIcDwuzcOPH04z9+fV6XeLiGORPDI= +github.com/galeone/tensorflow/tensorflow/go v0.0.0-20221023090153-6b7fa0680c3e/go.mod h1:TelZuq26kz2jysARBwOrTv16629hyUsHmIoj54QqyFo= +github.com/galeone/tfgo v0.0.0-20230715013254-16113111dc99 h1:8Bt1P/zy1gb37L4n8CGgp1qmFwBV5729kxVfj0sqhJk= +github.com/galeone/tfgo v0.0.0-20230715013254-16113111dc99/go.mod h1:3YgYBeIX42t83uP27Bd4bSMxTnQhSbxl0pYSkCDB1tc= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= +google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= diff --git a/logic/models/add.go b/logic/models/add.go index 852a2d1..7ce4e11 100644 --- a/logic/models/add.go +++ b/logic/models/add.go @@ -12,6 +12,7 @@ import ( "path" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils" ) func loadBaseImage(handle *Handle, id string) { @@ -21,7 +22,7 @@ func loadBaseImage(handle *Handle, id string) { // TODO better logging fmt.Println(err) fmt.Printf("Failed to read image for model with id %s\n", id) - modelUpdateStatus(handle, id, -1) + ModelUpdateStatus(handle, id, -1) return } defer infile.Close() @@ -31,7 +32,7 @@ func loadBaseImage(handle *Handle, id string) { // TODO better logging fmt.Println(err) fmt.Printf("Failed to load image for model with id %s\n", id) - modelUpdateStatus(handle, id, -1) + ModelUpdateStatus(handle, id, -1) return } if format != "png" { @@ -67,7 +68,7 @@ func loadBaseImage(handle *Handle, id string) { fmt.Println("Other so assuming color") } - modelUpdateStatus(handle, id, -1) + ModelUpdateStatus(handle, id, -1) return } @@ -77,7 +78,7 @@ func loadBaseImage(handle *Handle, id string) { // TODO better logging fmt.Println(err) fmt.Printf("Could not update model\n") - modelUpdateStatus(handle, id, -1) + ModelUpdateStatus(handle, id, -1) return } } diff --git a/logic/models/classes/main.go b/logic/models/classes/main.go index 2e24078..fd43b2c 100644 --- a/logic/models/classes/main.go +++ b/logic/models/classes/main.go @@ -33,6 +33,17 @@ func ListClasses(db *sql.DB, model_id string) (cls []ModelClass, err error) { return } +func ModelHasDataPoints(db *sql.DB, model_id string) (result bool, err error) { + result = false + rows, err := db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 limit 1;", model_id) + if err != nil { + return + } + defer rows.Close() + + return rows.Next(), nil +} + var ClassAlreadyExists = errors.New("Class aready exists") func CreateClass(db *sql.DB, model_id string, name string) (id string, err error) { diff --git a/logic/models/data.go b/logic/models/data.go index b7c963a..875434a 100644 --- a/logic/models/data.go +++ b/logic/models/data.go @@ -14,6 +14,7 @@ import ( model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils" ) func InsertIfNotPresent(ss []string, s string) []string { @@ -31,7 +32,7 @@ func processZipFile(handle *Handle, id string) { reader, err := zip.OpenReader(path.Join("savedData", id, "base_data.zip")) if err != nil { // TODO add msg to error - modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) + ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) fmt.Printf("Faield to proccess zip file failed to open reader\n") fmt.Println(err) return @@ -51,7 +52,7 @@ func processZipFile(handle *Handle, id string) { if paths[0] != "training" && paths[0] != "testing" { fmt.Printf("Invalid file '%s' TODO add msg to response!!!\n", file.Name) - modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) + ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) return } @@ -66,7 +67,7 @@ func processZipFile(handle *Handle, id string) { fmt.Printf("testing and training are diferent\n") fmt.Println(testing) fmt.Println(training) - modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) + ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) return } @@ -79,21 +80,21 @@ func processZipFile(handle *Handle, id string) { err = os.MkdirAll(dir_path, os.ModePerm) if err != nil { fmt.Printf("Failed to create dir %s\n", dir_path) - modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) + ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) return } dir_path = path.Join(base_path, "testing", name) err = os.MkdirAll(dir_path, os.ModePerm) if err != nil { fmt.Printf("Failed to create dir %s\n", dir_path) - modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) + ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) return } id, err := model_classes.CreateClass(handle.Db, id, name) if err != nil { fmt.Printf("Failed to create class '%s' on db\n", name) - modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) + ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) return } ids[name] = id @@ -108,14 +109,14 @@ func processZipFile(handle *Handle, id string) { f, err := os.Create(file_path) if err != nil { fmt.Printf("Could not create file %s\n", file_path) - modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) + ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) return } defer f.Close() data, err := reader.Open(file.Name) if err != nil { fmt.Printf("Could not create file %s\n", file_path) - modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) + ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) return } defer data.Close() @@ -135,13 +136,13 @@ func processZipFile(handle *Handle, id string) { if err != nil { fmt.Printf("Failed to add data point for %s\n", id) fmt.Println(err) - modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) + ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) return } } fmt.Printf("Added data to model '%s'!\n", id) - modelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING) + ModelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING) } func handleDataUpload(handle *Handle) { @@ -182,7 +183,7 @@ func handleDataUpload(handle *Handle) { } } - _, err = getBaseModel(handle.Db, id) + _, err = GetBaseModel(handle.Db, id) if err == ModelNotFoundError { return ErrorCode(nil, http.StatusNotFound, AnyMap{ "NotFoundMessage": "Model not found", @@ -203,7 +204,7 @@ func handleDataUpload(handle *Handle) { f.Write(file) - modelUpdateStatus(handle, id, PREPARING_ZIP_FILE) + ModelUpdateStatus(handle, id, PREPARING_ZIP_FILE) go processZipFile(handle, id) @@ -230,7 +231,7 @@ func handleDataUpload(handle *Handle) { id := f.Get("id") - model, err := getBaseModel(handle.Db, id) + model, err := GetBaseModel(handle.Db, id) if err == ModelNotFoundError { return ErrorCode(nil, http.StatusNotFound, AnyMap{ "NotFoundMessage": "Model not found", @@ -260,7 +261,7 @@ func handleDataUpload(handle *Handle) { return Error500(err) } - modelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING) + ModelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING) Redirect("/models/edit?id="+id, c.Mode, w, r) return nil }) diff --git a/logic/models/delete.go b/logic/models/delete.go index 3e19d46..09f2018 100644 --- a/logic/models/delete.go +++ b/logic/models/delete.go @@ -7,6 +7,7 @@ import ( "path" "strconv" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" ) @@ -71,7 +72,7 @@ func handleDelete(handle *Handle) { }) } - var model BaseModel = BaseModel{} + model := BaseModel{} model.Id = id err = rows.Scan(&model.Name, &model.Status) @@ -80,6 +81,8 @@ func handleDelete(handle *Handle) { } switch model.Status { + case FAILED_PREPARING_TRAINING: + fallthrough case FAILED_PREPARING: deleteModel(handle, id, w, c, model) return nil diff --git a/logic/models/edit.go b/logic/models/edit.go index 957cbc7..81c4574 100644 --- a/logic/models/edit.go +++ b/logic/models/edit.go @@ -5,6 +5,7 @@ import ( "net/http" model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" ) @@ -70,6 +71,11 @@ func handleEdit(handle *Handle) { case CONFIRM_PRE_TRAINING: cls, err := model_classes.ListClasses(handle.Db, id) + if err != nil { + return Error500(err) + } + + has_data, err := model_classes.ModelHasDataPoints(handle.Db, id) if err != nil { return Error500(err) } @@ -77,7 +83,10 @@ func handleEdit(handle *Handle) { LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{ "Model": model, "Classes": cls, + "HasData": has_data, })) + case TRAINING: + fallthrough case PREPARING_ZIP_FILE: LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{ "Model": model, diff --git a/logic/models/index.go b/logic/models/index.go index f62542f..fa0258b 100644 --- a/logic/models/index.go +++ b/logic/models/index.go @@ -2,6 +2,7 @@ package models import ( model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes" + models_train "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" ) @@ -15,5 +16,8 @@ func HandleModels (handle *Handle) { handleDataUpload(handle) model_classes.HandleList(handle) + + // Train endpoints + models_train.HandleTrainEndpoints(handle) } diff --git a/logic/models/train/main.go b/logic/models/train/main.go new file mode 100644 index 0000000..966ec71 --- /dev/null +++ b/logic/models/train/main.go @@ -0,0 +1,13 @@ +package models_train + +import ( + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" +) + +func HandleTrainEndpoints(handle *Handle) { + handleTrain(handle) + handleRest(handle) + + //TODO remove + handleTest(handle) +} diff --git a/logic/models/train/reset.go b/logic/models/train/reset.go new file mode 100644 index 0000000..5fb67e2 --- /dev/null +++ b/logic/models/train/reset.go @@ -0,0 +1,58 @@ +package models_train + +import ( + "net/http" + + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" +) + +func handleRest(handle *Handle) { + handle.Delete("/models/train/reset", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { + if !CheckAuthLevel(1, w, r, c) { + return nil + } + if c.Mode == JSON { + panic("handle JSON /models/train/reset") + } + + f, err := MyParseForm(r) + if err != nil { + // TODO improve response + return ErrorCode(nil, 400, c.AddMap(nil)) + } + + if !CheckId(f, "id") { + // TODO improve response + return ErrorCode(nil, 400, c.AddMap(nil)) + } + + id := f.Get("id") + + model, err := GetBaseModel(handle.Db, id) + if err == ModelNotFoundError { + return ErrorCode(nil, http.StatusNotFound, AnyMap{ + "NotFoundMessage": "Model not found", + "GoBackLink": "/models", + }) + } else if err != nil { + // TODO improve response + return Error500(err) + } + + if model.Status != FAILED_PREPARING_TRAINING { + // TODO improve response + return ErrorCode(nil, 400, c.AddMap(nil)) + } + + _, err = handle.Db.Exec("delete from model_definition where model_id=$1", model.Id) + if err != nil { + // TODO improve response + return Error500(err) + } + + ModelUpdateStatus(handle, model.Id, CONFIRM_PRE_TRAINING) + Redirect("/models/edit?id=" + model.Id, c.Mode, w, r) + return nil + }) +} diff --git a/logic/models/train/tensorflow-test.go b/logic/models/train/tensorflow-test.go new file mode 100644 index 0000000..c645ae9 --- /dev/null +++ b/logic/models/train/tensorflow-test.go @@ -0,0 +1,139 @@ +package models_train + +import ( + "fmt" + "net/http" + "os" + "path" + "strings" + "text/template" + + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" +) + +/*import ( + tf "github.com/galeone/tensorflow/tensorflow/go" + tg "github.com/galeone/tfgo" + "github.com/galeone/tfgo/image" + "github.com/galeone/tfgo/image/filter" + "github.com/galeone/tfgo/image/padding" +)*/ + +func getDir() string { + dir, err := os.Getwd() + if err != nil { + panic(err) + } + return dir +} + +func shapeToSize(shape string) string { + split := strings.Split(shape, ",") + return strings.Join(split[:len(split) - 1], ",") +} + +func handleTest(handle *Handle) { + handle.Post("/models/train/test", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { + if !CheckAuthLevel(1, w, r, c) { + return nil + } + + id, err := GetIdFromUrl(r, "id") + if err != nil { + return ErrorCode(err, 400, c.AddMap(nil)) + } + + rows, err := handle.Db.Query("select mc.name, mdp.file_path from model_classes as mc join model_data_point as mdp on mdp.class_id = mc.id where mdp.model_mode = 1 and mc.model_id = $1 limit 1;", id) + if err != nil { + return Error500(err) + } + defer rows.Close() + + if !rows.Next() { + return Error500(err) + } + + var name string + var file_path string + err = rows.Scan(&name, &file_path) + if err != nil { + return Error500(err) + } + + file_path = strings.Replace(file_path, "file://", "", 1) + + img_path := path.Join("savedData", id, "data", "training", name, file_path) + + fmt.Printf("%s\n", img_path) + + definitions, err := handle.Db.Query("select id from model_definition where model_id=$1 and status=2 limit 1;", id) + if err != nil { + return Error500(err) + } + defer definitions.Close() + + if !definitions.Next() { + fmt.Println("Did not find definition") + return Error500(nil) + } + + var definition_id string + + if err = definitions.Scan(&definition_id); err != nil { + return Error500(err) + } + + layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id) + if err != nil { + return Error500(err) + } + defer layers.Close() + + type layerrow struct { + LayerType int + Shape string + } + + got := []layerrow{} + + for layers.Next() { + var row = layerrow{} + if err = layers.Scan(&row.LayerType, &row.Shape); err != nil { + return Error500(err) + } + row.Shape = shapeToSize(row.Shape) + got = append(got, row) + } + + // Generate folder + + err = os.MkdirAll(path.Join("/tmp", id), os.ModePerm) + if err != nil { + return Error500(err) + } + + f, err := os.Create(path.Join("/tmp", id, "run.py")) + if err != nil { + return Error500(err) + } + defer f.Close() + + fmt.Printf("Using path: %s\n", path.Join("/tmp", id, "run.py")) + + tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py") + if err != nil { + return Error500(err) + } + + if err = tmpl.Execute(f, AnyMap{ + "Layers": got, + "Size": got[0].Shape, + "DataDir": path.Join(getDir(), "savedData", id, "data", "training"), + }); err != nil { + return Error500(err) + } + + w.Write([]byte("Done")) + return nil + }) +} diff --git a/logic/models/train/train.go b/logic/models/train/train.go new file mode 100644 index 0000000..28fd457 --- /dev/null +++ b/logic/models/train/train.go @@ -0,0 +1,158 @@ +package models_train + +import ( + "database/sql" + "errors" + "fmt" + "net/http" + + model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" +) + +func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) { + id = "" + _, err = db.Exec("insert into model_definition (model_id, target_accuracy) values ($1, $2);", model_id, target_accuracy) + if err != nil { + return + } + + rows, err := db.Query("select id from model_definition where model_id=$1 order by created_on DESC;", model_id) + if err != nil { + return + } + defer rows.Close() + + if !rows.Next() { + return id, errors.New("Something wrong!") + } + + err = rows.Scan(&id) + if err != nil { + return + } + + return +} + +type ModelDefinitionStatus int + +const ( + MODEL_DEFINITION_STATUS_PRE_INIT ModelDefinitionStatus = 1 + MODEL_DEFINITION_STATUS_INIT = 2 + MODEL_DEFINITION_STATUS_TRAINING = 3 + MODEL_DEFINITION_STATUS_TRANIED = 4 + MODEL_DEFINITION_STATUS_READY = 5 +) + +func ModelDefinitionUpdateStatus(handle *Handle, id string, status ModelDefinitionStatus) (err error) { + _, err = handle.Db.Exec("update model_definition set status = $1 where id = $2", status, id) + return +} + +func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type int, shape string) (err error) { + _, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape) values ($1, $2, $3, $4)", def_id, layer_order, layer_type, shape) + return +} + +func handleTrain(handle *Handle) { + handle.Post("/models/train", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { + if !CheckAuthLevel(1, w, r, c) { + return nil + } + if c.Mode == JSON { + panic("TODO /models/train JSON") + } + + r.ParseForm() + f := r.Form + + number_of_models := 0 + accuracy := 0 + + if !CheckId(f, "id") || CheckEmpty(f, "model_type") || !CheckNumber(f, "number_of_models", &number_of_models) || !CheckNumber(f, "accuracy", &accuracy) { + fmt.Println( + !CheckId(f, "id"), CheckEmpty(f, "model_type"), !CheckNumber(f, "number_of_models", &number_of_models), !CheckNumber(f, "accuracy", &accuracy), + ) + // TODO improve this response + return ErrorCode(nil, 400, c.AddMap(nil)) + } + + id := f.Get("id") + model_type := f.Get("model_type") + // Its not used rn + _ = model_type + + model, err := GetBaseModel(handle.Db, id) + if err == ModelNotFoundError { + return ErrorCode(nil, http.StatusNotFound, c.AddMap(AnyMap{ + "NotFoundMessage": "Model not found", + "GoBackLink": "/models", + })) + } else if err != nil { + // TODO improve this response + return Error500(err) + } + + if model.Status != CONFIRM_PRE_TRAINING { + // TODO improve this response + return ErrorCode(nil, 400, c.AddMap(nil)) + } + + cls, err := model_classes.ListClasses(handle.Db, model.Id) + if err != nil { + ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING) + // TODO improve this response + return Error500(err) + } + + + var fid string + for i := 0; i < number_of_models; i++ { + def_id, err := MakeDefenition(handle.Db, model.Id, accuracy) + if err != nil { + ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING) + // TODO improve this response + return Error500(err) + } + + if fid == "" { + fid = def_id + } + + // TODO change shape of it depends on the type of the image + err = MakeLayer(handle.Db, def_id, 1, 1, fmt.Sprintf("%d,%d,1", model.Width, model.Height)) + if err != nil { + ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING) + // TODO improve this response + return Error500(err) + } + err = MakeLayer(handle.Db, def_id, 4, 3, fmt.Sprintf("%d,1", len(cls))) + if err != nil { + ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING) + // TODO improve this response + return Error500(err) + } + err = MakeLayer(handle.Db, def_id, 5, 2, fmt.Sprintf("%d,1", len(cls))) + if err != nil { + ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING) + // TODO improve this response + return Error500(err) + } + + err = ModelDefinitionUpdateStatus(handle, def_id, MODEL_DEFINITION_STATUS_INIT) + if err != nil { + ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING) + // TODO improve this response + return Error500(err) + } + } + + // TODO start training with id fid + + ModelUpdateStatus(handle, model.Id, TRAINING) + Redirect("/models/edit?id=" + model.Id, c.Mode, w, r) + return nil + }) +} diff --git a/logic/models/types.go b/logic/models/types.go deleted file mode 100644 index 9e27214..0000000 --- a/logic/models/types.go +++ /dev/null @@ -1,43 +0,0 @@ -package models - -import ( - "database/sql" - "errors" -) - -type BaseModel struct { - Name string - Status int - Id string -} - -const ( - FAILED_PREPARING_ZIP_FILE = -2 - FAILED_PREPARING = -1 - - PREPARING = 1 - CONFIRM_PRE_TRAINING = 2 - PREPARING_ZIP_FILE = 3 -) - -var ModelNotFoundError = errors.New("Model not found error") - -func getBaseModel(db *sql.DB, id string) (base *BaseModel, err error) { - rows, err := db.Query("select name, status, id from models where id=$1;", id) - if err != nil { - return - } - defer rows.Close() - - if !rows.Next() { - return nil, ModelNotFoundError - } - - base = &BaseModel{} - err = rows.Scan(&base.Name, &base.Status, &base.Id) - if err != nil { - return nil, err - } - - return -} diff --git a/logic/models/utils/types.go b/logic/models/utils/types.go new file mode 100644 index 0000000..5c71d4d --- /dev/null +++ b/logic/models/utils/types.go @@ -0,0 +1,49 @@ +package models_utils + +import ( + "database/sql" + "errors" +) + +type BaseModel struct { + Name string + Status int + Id string + + Width int + Height int +} + +const ( + FAILED_TRAINING = -4 + FAILED_PREPARING_TRAINING = -3 + FAILED_PREPARING_ZIP_FILE = -2 + FAILED_PREPARING = -1 + + PREPARING = 1 + CONFIRM_PRE_TRAINING = 2 + PREPARING_ZIP_FILE = 3 + TRAINING = 4 +) + +var ModelNotFoundError = errors.New("Model not found error") + +func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) { + rows, err := db.Query("select name, status, id, width, height from models where id=$1;", id) + if err != nil { + return + } + defer rows.Close() + + if !rows.Next() { + return nil, ModelNotFoundError + } + + base = &BaseModel{} + err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height) + if err != nil { + return nil, err + } + + return +} diff --git a/logic/models/utils.go b/logic/models/utils/utils.go similarity index 67% rename from logic/models/utils.go rename to logic/models/utils/utils.go index ff1667e..97e0fe7 100644 --- a/logic/models/utils.go +++ b/logic/models/utils/utils.go @@ -1,4 +1,4 @@ -package models +package models_utils import ( "fmt" @@ -6,7 +6,8 @@ import ( . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" ) -func modelUpdateStatus(handle *Handle, id string, status int) { +// TODO make this return and caller handle error +func ModelUpdateStatus(handle *Handle, id string, status int) { _, err := handle.Db.Exec("update models set status = $1 where id = $2", status, id) if err != nil { fmt.Println("Failed to update model status") diff --git a/logic/utils/utils.go b/logic/utils/utils.go index 44c4f63..7e80215 100644 --- a/logic/utils/utils.go +++ b/logic/utils/utils.go @@ -2,10 +2,12 @@ package utils import ( "errors" + "fmt" "io" "mime" "net/http" "net/url" + "strconv" "github.com/google/uuid" ) @@ -14,6 +16,21 @@ func CheckEmpty(f url.Values, path string) bool { return !f.Has(path) || f.Get(path) == "" } +func CheckNumber(f url.Values, path string, number *int) bool { + if CheckEmpty(f, path) { + fmt.Println("here", path) + fmt.Println(f.Get(path)) + return false + } + n, err := strconv.Atoi(f.Get(path)) + if err != nil { + fmt.Println(err) + return false + } + *number = n + return true +} + func CheckId(f url.Values, path string) bool { return !CheckEmpty(f, path) && IsValidUUID(f.Get(path)) } diff --git a/sql/models.sql b/sql/models.sql index eb15296..239a626 100644 --- a/sql/models.sql +++ b/sql/models.sql @@ -1,5 +1,6 @@ -- drop table if exists model_data_point; --- drop table if exists model_defenitions; +-- drop table if exists model_definition_layer; +-- drop table if exists model_definition; -- drop table if exists models; create table if not exists models ( id uuid primary key default gen_random_uuid(), @@ -32,3 +33,37 @@ create table if not exists model_data_point ( -- 2 testing model_mode integer default 1 ); + +-- drop table if exists model_definition; +-- drop table if exists model_definition; +create table if not exists model_definition ( + id uuid primary key default gen_random_uuid(), + model_id uuid references models (id) on delete cascade, + accuracy integer default 0, + target_accuracy integer not null, + epoch integer default 0, + -- TODO add max epoch + -- 1: Pre Init + -- 2: Init + -- 3: Training + -- 4: Tranied + -- 5: Ready + status integer default 1, + created_on timestamp default current_timestamp +); + +-- drop table if exists model_definition_layer; +create table if not exists model_definition_layer ( + id uuid primary key default gen_random_uuid(), + def_id uuid references model_definition (id) on delete cascade, + layer_order integer not null, + -- 1: input + -- 2: dense + -- 3: flatten + -- TODO add conv + layer_type integer not null, + -- ei 28,28,1 + -- a 28x28 grayscale image + shape text not null +); + diff --git a/views/models/edit.html b/views/models/edit.html index a792ef8..e846395 100644 --- a/views/models/edit.html +++ b/views/models/edit.html @@ -252,8 +252,39 @@ {{ end }} {{ define "train-model-card" }} -
{{ end }} @@ -313,6 +344,29 @@ {{/* TODO improve this */}} Processing zip file... + {{/* FAILED TO Prepare for training */}} + {{ else if (eq .Model.Status -3)}} + {{ template "base-model-card" . }} + + {{ template "delete-model-card" . }} + {{ else if (eq .Model.Status 4)}} + {{ template "base-model-card" . }} +