140 lines
3.2 KiB
Go
140 lines
3.2 KiB
Go
|
package models_train
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"net/http"
|
||
|
"os"
|
||
|
"path"
|
||
|
"strings"
|
||
|
"text/template"
|
||
|
|
||
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||
|
)
|
||
|
|
||
|
/*import (
|
||
|
tf "github.com/galeone/tensorflow/tensorflow/go"
|
||
|
tg "github.com/galeone/tfgo"
|
||
|
"github.com/galeone/tfgo/image"
|
||
|
"github.com/galeone/tfgo/image/filter"
|
||
|
"github.com/galeone/tfgo/image/padding"
|
||
|
)*/
|
||
|
|
||
|
// getDir reports the process's current working directory as returned by
// os.Getwd. Failure to determine it is treated as fatal: the function panics
// with the underlying error rather than returning it.
func getDir() string {
	wd, err := os.Getwd()
	if err == nil {
		return wd
	}
	panic(err)
}
|
||
|
|
||
|
// shapeToSize drops the final comma-separated component of shape and returns
// everything before the last comma, e.g. "28,28,1" -> "28,28". A shape that
// contains no comma (including "") yields "".
func shapeToSize(shape string) string {
	cut := strings.LastIndex(shape, ",")
	if cut < 0 {
		return ""
	}
	return shape[:cut]
}
|
||
|
|
||
|
func handleTest(handle *Handle) {
|
||
|
handle.Post("/models/train/test", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||
|
if !CheckAuthLevel(1, w, r, c) {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
id, err := GetIdFromUrl(r, "id")
|
||
|
if err != nil {
|
||
|
return ErrorCode(err, 400, c.AddMap(nil))
|
||
|
}
|
||
|
|
||
|
rows, err := handle.Db.Query("select mc.name, mdp.file_path from model_classes as mc join model_data_point as mdp on mdp.class_id = mc.id where mdp.model_mode = 1 and mc.model_id = $1 limit 1;", id)
|
||
|
if err != nil {
|
||
|
return Error500(err)
|
||
|
}
|
||
|
defer rows.Close()
|
||
|
|
||
|
if !rows.Next() {
|
||
|
return Error500(err)
|
||
|
}
|
||
|
|
||
|
var name string
|
||
|
var file_path string
|
||
|
err = rows.Scan(&name, &file_path)
|
||
|
if err != nil {
|
||
|
return Error500(err)
|
||
|
}
|
||
|
|
||
|
file_path = strings.Replace(file_path, "file://", "", 1)
|
||
|
|
||
|
img_path := path.Join("savedData", id, "data", "training", name, file_path)
|
||
|
|
||
|
fmt.Printf("%s\n", img_path)
|
||
|
|
||
|
definitions, err := handle.Db.Query("select id from model_definition where model_id=$1 and status=2 limit 1;", id)
|
||
|
if err != nil {
|
||
|
return Error500(err)
|
||
|
}
|
||
|
defer definitions.Close()
|
||
|
|
||
|
if !definitions.Next() {
|
||
|
fmt.Println("Did not find definition")
|
||
|
return Error500(nil)
|
||
|
}
|
||
|
|
||
|
var definition_id string
|
||
|
|
||
|
if err = definitions.Scan(&definition_id); err != nil {
|
||
|
return Error500(err)
|
||
|
}
|
||
|
|
||
|
layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
||
|
if err != nil {
|
||
|
return Error500(err)
|
||
|
}
|
||
|
defer layers.Close()
|
||
|
|
||
|
type layerrow struct {
|
||
|
LayerType int
|
||
|
Shape string
|
||
|
}
|
||
|
|
||
|
got := []layerrow{}
|
||
|
|
||
|
for layers.Next() {
|
||
|
var row = layerrow{}
|
||
|
if err = layers.Scan(&row.LayerType, &row.Shape); err != nil {
|
||
|
return Error500(err)
|
||
|
}
|
||
|
row.Shape = shapeToSize(row.Shape)
|
||
|
got = append(got, row)
|
||
|
}
|
||
|
|
||
|
// Generate folder
|
||
|
|
||
|
err = os.MkdirAll(path.Join("/tmp", id), os.ModePerm)
|
||
|
if err != nil {
|
||
|
return Error500(err)
|
||
|
}
|
||
|
|
||
|
f, err := os.Create(path.Join("/tmp", id, "run.py"))
|
||
|
if err != nil {
|
||
|
return Error500(err)
|
||
|
}
|
||
|
defer f.Close()
|
||
|
|
||
|
fmt.Printf("Using path: %s\n", path.Join("/tmp", id, "run.py"))
|
||
|
|
||
|
tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py")
|
||
|
if err != nil {
|
||
|
return Error500(err)
|
||
|
}
|
||
|
|
||
|
if err = tmpl.Execute(f, AnyMap{
|
||
|
"Layers": got,
|
||
|
"Size": got[0].Shape,
|
||
|
"DataDir": path.Join(getDir(), "savedData", id, "data", "training"),
|
||
|
}); err != nil {
|
||
|
return Error500(err)
|
||
|
}
|
||
|
|
||
|
w.Write([]byte("Done"))
|
||
|
return nil
|
||
|
})
|
||
|
}
|