diff --git a/logic/models/train/tensorflow-test.go b/logic/models/train/tensorflow-test.go index c645ae9..e6d0a77 100644 --- a/logic/models/train/tensorflow-test.go +++ b/logic/models/train/tensorflow-test.go @@ -2,9 +2,12 @@ package models_train import ( "fmt" + "io" "net/http" "os" + "os/exec" "path" + "strconv" "strings" "text/template" @@ -107,18 +110,20 @@ func handleTest(handle *Handle) { // Generate folder - err = os.MkdirAll(path.Join("/tmp", id), os.ModePerm) + run_path := path.Join("/tmp", id, "defs", definition_id) + + err = os.MkdirAll(run_path, os.ModePerm) if err != nil { return Error500(err) } - f, err := os.Create(path.Join("/tmp", id, "run.py")) + f, err := os.Create(path.Join(run_path, "run.py")) if err != nil { return Error500(err) } defer f.Close() - fmt.Printf("Using path: %s\n", path.Join("/tmp", id, "run.py")) + fmt.Printf("Using path: %s\n", run_path) tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py") if err != nil { @@ -133,7 +138,38 @@ func handleTest(handle *Handle) { return Error500(err) } - w.Write([]byte("Done")) + cmd := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)) + _, err = cmd.Output() + if err != nil { + return Error500(err) + } + + result_path := path.Join("savedData", id, "defs", definition_id) + + os.MkdirAll(result_path, os.ModePerm) + + err = exec.Command("cp", path.Join(run_path, "model.keras"), path.Join(result_path, "model.keras")).Run() + if err != nil { + return Error500(err) + } + + accuracy_file, err := os.Open(path.Join(run_path, "accuracy.val")) + if err != nil { + return Error500(err) + } + defer accuracy_file.Close() + + accuracy_file_bytes, err := io.ReadAll(accuracy_file) + if err != nil { + return Error500(err) + } + + accuracy, err := strconv.ParseFloat(string(accuracy_file_bytes), 64) + if err != nil { + return Error500(err) + } + + w.Write([]byte(strconv.FormatFloat(accuracy, 'f', -1, 64))) return nil }) } diff --git a/logic/models/train/train.go b/logic/models/train/train.go index 28fd457..4542524 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -39,6 +39,7 @@ func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string type ModelDefinitionStatus int const ( + MODEL_DEFINITION_STATUS_FAILED_TRAINING = -3 MODEL_DEFINITION_STATUS_PRE_INIT ModelDefinitionStatus = 1 MODEL_DEFINITION_STATUS_INIT = 2 MODEL_DEFINITION_STATUS_TRAINING = 3 @@ -56,6 +57,41 @@ func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type int, shape return } +func trainModel(handle *Handle, model *BaseModel) { + definitionsRows, err := handle.Db.Query("select id from model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT) + if err != nil { + fmt.Printf("Failed to trainModel!Err:\n") + fmt.Println(err) + ModelUpdateStatus(handle, model.Id, FAILED_TRAINING) + return + } + defer definitionsRows.Close() + + definitions := []string{} + + for definitionsRows.Next() { + var id string + if err = definitionsRows.Scan(&id); err != nil { + fmt.Printf("Failed to trainModel!Err:\n") + fmt.Println(err) + ModelUpdateStatus(handle, model.Id, FAILED_TRAINING) + return + } + definitions = append(definitions, id) + } + + if len(definitions) == 0 { + fmt.Printf("Failed to trainModel!Err:\n") + fmt.Println(err) + ModelUpdateStatus(handle, model.Id, FAILED_TRAINING) + return + } + + for _, def_id := range definitions { + _ = def_id + } +} + func handleTrain(handle *Handle) { handle.Post("/models/train", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { if !CheckAuthLevel(1, w, r, c) { @@ -151,6 +187,8 @@ func handleTrain(handle *Handle) { // TODO start training with id fid + go trainModel(handle, model) + ModelUpdateStatus(handle, model.Id, TRAINING) Redirect("/models/edit?id=" + model.Id, c.Mode, w, r) return nil