package models_train import ( "fmt" "io" "net/http" "os" "os/exec" "path" "strconv" "strings" "text/template" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" ) /*import ( tf "github.com/galeone/tensorflow/tensorflow/go" tg "github.com/galeone/tfgo" "github.com/galeone/tfgo/image" "github.com/galeone/tfgo/image/filter" "github.com/galeone/tfgo/image/padding" )*/ func getDir() string { dir, err := os.Getwd() if err != nil { panic(err) } return dir } func shapeToSize(shape string) string { split := strings.Split(shape, ",") return strings.Join(split[:len(split) - 1], ",") } func handleTest(handle *Handle) { handle.Post("/models/train/test", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { if !CheckAuthLevel(1, w, r, c) { return nil } id, err := GetIdFromUrl(r, "id") if err != nil { return ErrorCode(err, 400, c.AddMap(nil)) } rows, err := handle.Db.Query("select mc.name, mdp.file_path from model_classes as mc join model_data_point as mdp on mdp.class_id = mc.id where mdp.model_mode = 1 and mc.model_id = $1 limit 1;", id) if err != nil { return Error500(err) } defer rows.Close() if !rows.Next() { return Error500(err) } var name string var file_path string err = rows.Scan(&name, &file_path) if err != nil { return Error500(err) } file_path = strings.Replace(file_path, "file://", "", 1) img_path := path.Join("savedData", id, "data", "training", name, file_path) fmt.Printf("%s\n", img_path) definitions, err := handle.Db.Query("select id from model_definition where model_id=$1 and status=2 limit 1;", id) if err != nil { return Error500(err) } defer definitions.Close() if !definitions.Next() { fmt.Println("Did not find definition") return Error500(nil) } var definition_id string if err = definitions.Scan(&definition_id); err != nil { return Error500(err) } layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id) if err != nil { return Error500(err) } defer layers.Close() type layerrow struct { LayerType int Shape string } got := []layerrow{} for layers.Next() { var row = layerrow{} if err = layers.Scan(&row.LayerType, &row.Shape); err != nil { return Error500(err) } row.Shape = shapeToSize(row.Shape) got = append(got, row) } // Generate folder run_path := path.Join("/tmp", id, "defs", definition_id) err = os.MkdirAll(run_path, os.ModePerm) if err != nil { return Error500(err) } f, err := os.Create(path.Join(run_path, "run.py")) if err != nil { return Error500(err) } defer f.Close() fmt.Printf("Using path: %s\n", run_path) tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py") if err != nil { return Error500(err) } if err = tmpl.Execute(f, AnyMap{ "Layers": got, "Size": got[0].Shape, "DataDir": path.Join(getDir(), "savedData", id, "data", "training"), }); err != nil { return Error500(err) } cmd := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)) _, err = cmd.Output() if err != nil { return Error500(err) } result_path := path.Join("savedData", id, "defs", definition_id) os.MkdirAll(result_path, os.ModePerm) err = exec.Command("cp", path.Join(run_path, "model.keras"), path.Join(result_path, "model.keras")).Run() if err != nil { return Error500(err) } accuracy_file, err := os.Open(path.Join(run_path, "accuracy.val")) if err != nil { return Error500(err) } defer accuracy_file.Close() accuracy_file_bytes, err := io.ReadAll(accuracy_file) if err != nil { return Error500(err) } accuracy, err := strconv.ParseFloat(string(accuracy_file_bytes), 64) if err != nil { return Error500(err) } w.Write([]byte(strconv.FormatFloat(accuracy, 'f', -1, 64))) return nil }) }