package models_train import ( "fmt" "net/http" "os" "path" "strings" "text/template" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" ) /*import ( tf "github.com/galeone/tensorflow/tensorflow/go" tg "github.com/galeone/tfgo" "github.com/galeone/tfgo/image" "github.com/galeone/tfgo/image/filter" "github.com/galeone/tfgo/image/padding" )*/ func getDir() string { dir, err := os.Getwd() if err != nil { panic(err) } return dir } func shapeToSize(shape string) string { split := strings.Split(shape, ",") return strings.Join(split[:len(split) - 1], ",") } func handleTest(handle *Handle) { handle.Post("/models/train/test", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { if !CheckAuthLevel(1, w, r, c) { return nil } id, err := GetIdFromUrl(r, "id") if err != nil { return ErrorCode(err, 400, c.AddMap(nil)) } rows, err := handle.Db.Query("select mc.name, mdp.file_path from model_classes as mc join model_data_point as mdp on mdp.class_id = mc.id where mdp.model_mode = 1 and mc.model_id = $1 limit 1;", id) if err != nil { return Error500(err) } defer rows.Close() if !rows.Next() { return Error500(err) } var name string var file_path string err = rows.Scan(&name, &file_path) if err != nil { return Error500(err) } file_path = strings.Replace(file_path, "file://", "", 1) img_path := path.Join("savedData", id, "data", "training", name, file_path) fmt.Printf("%s\n", img_path) definitions, err := handle.Db.Query("select id from model_definition where model_id=$1 and status=2 limit 1;", id) if err != nil { return Error500(err) } defer definitions.Close() if !definitions.Next() { fmt.Println("Did not find definition") return Error500(nil) } var definition_id string if err = definitions.Scan(&definition_id); err != nil { return Error500(err) } layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id) if err != nil { return Error500(err) } defer layers.Close() type layerrow struct { LayerType int Shape string } got := []layerrow{} for layers.Next() { var row = layerrow{} if err = layers.Scan(&row.LayerType, &row.Shape); err != nil { return Error500(err) } row.Shape = shapeToSize(row.Shape) got = append(got, row) } // Generate folder err = os.MkdirAll(path.Join("/tmp", id), os.ModePerm) if err != nil { return Error500(err) } f, err := os.Create(path.Join("/tmp", id, "run.py")) if err != nil { return Error500(err) } defer f.Close() fmt.Printf("Using path: %s\n", path.Join("/tmp", id, "run.py")) tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py") if err != nil { return Error500(err) } if err = tmpl.Execute(f, AnyMap{ "Layers": got, "Size": got[0].Shape, "DataDir": path.Join(getDir(), "savedData", id, "data", "training"), }); err != nil { return Error500(err) } w.Write([]byte("Done")) return nil }) }