176 lines
4.2 KiB
Go
176 lines
4.2 KiB
Go
package models_train
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"os/exec"
|
|
"path"
|
|
"strconv"
|
|
"strings"
|
|
"text/template"
|
|
|
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
|
)
|
|
|
|
/*import (
|
|
tf "github.com/galeone/tensorflow/tensorflow/go"
|
|
tg "github.com/galeone/tfgo"
|
|
"github.com/galeone/tfgo/image"
|
|
"github.com/galeone/tfgo/image/filter"
|
|
"github.com/galeone/tfgo/image/padding"
|
|
)*/
|
|
|
|
// getDir returns the current working directory of the process.
// It panics with the error from os.Getwd if the directory cannot be
// determined.
func getDir() string {
	wd, err := os.Getwd()
	if err == nil {
		return wd
	}
	panic(err)
}
|
|
|
|
// shapeToSize drops the last comma-separated component of shape and returns
// everything before the final comma (for example "28,28,1" -> "28,28").
// A shape without any comma yields "".
// NOTE(review): presumably the dropped component is the channel count of a
// layer shape — confirm against model_definition_layer.shape.
func shapeToSize(shape string) string {
	cut := strings.LastIndex(shape, ",")
	if cut < 0 {
		return ""
	}
	return shape[:cut]
}
|
|
|
|
func handleTest(handle *Handle) {
|
|
handle.Post("/models/train/test", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
|
if !CheckAuthLevel(1, w, r, c) {
|
|
return nil
|
|
}
|
|
|
|
id, err := GetIdFromUrl(r, "id")
|
|
if err != nil {
|
|
return ErrorCode(err, 400, c.AddMap(nil))
|
|
}
|
|
|
|
rows, err := handle.Db.Query("select mc.name, mdp.file_path from model_classes as mc join model_data_point as mdp on mdp.class_id = mc.id where mdp.model_mode = 1 and mc.model_id = $1 limit 1;", id)
|
|
if err != nil {
|
|
return Error500(err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
if !rows.Next() {
|
|
return Error500(err)
|
|
}
|
|
|
|
var name string
|
|
var file_path string
|
|
err = rows.Scan(&name, &file_path)
|
|
if err != nil {
|
|
return Error500(err)
|
|
}
|
|
|
|
file_path = strings.Replace(file_path, "file://", "", 1)
|
|
|
|
img_path := path.Join("savedData", id, "data", "training", name, file_path)
|
|
|
|
fmt.Printf("%s\n", img_path)
|
|
|
|
definitions, err := handle.Db.Query("select id from model_definition where model_id=$1 and status=2 limit 1;", id)
|
|
if err != nil {
|
|
return Error500(err)
|
|
}
|
|
defer definitions.Close()
|
|
|
|
if !definitions.Next() {
|
|
fmt.Println("Did not find definition")
|
|
return Error500(nil)
|
|
}
|
|
|
|
var definition_id string
|
|
|
|
if err = definitions.Scan(&definition_id); err != nil {
|
|
return Error500(err)
|
|
}
|
|
|
|
layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
|
if err != nil {
|
|
return Error500(err)
|
|
}
|
|
defer layers.Close()
|
|
|
|
type layerrow struct {
|
|
LayerType int
|
|
Shape string
|
|
}
|
|
|
|
got := []layerrow{}
|
|
|
|
for layers.Next() {
|
|
var row = layerrow{}
|
|
if err = layers.Scan(&row.LayerType, &row.Shape); err != nil {
|
|
return Error500(err)
|
|
}
|
|
row.Shape = shapeToSize(row.Shape)
|
|
got = append(got, row)
|
|
}
|
|
|
|
// Generate folder
|
|
|
|
run_path := path.Join("/tmp", id, "defs", definition_id)
|
|
|
|
err = os.MkdirAll(run_path, os.ModePerm)
|
|
if err != nil {
|
|
return Error500(err)
|
|
}
|
|
|
|
f, err := os.Create(path.Join(run_path, "run.py"))
|
|
if err != nil {
|
|
return Error500(err)
|
|
}
|
|
defer f.Close()
|
|
|
|
fmt.Printf("Using path: %s\n", run_path)
|
|
|
|
tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py")
|
|
if err != nil {
|
|
return Error500(err)
|
|
}
|
|
|
|
if err = tmpl.Execute(f, AnyMap{
|
|
"Layers": got,
|
|
"Size": got[0].Shape,
|
|
"DataDir": path.Join(getDir(), "savedData", id, "data", "training"),
|
|
}); err != nil {
|
|
return Error500(err)
|
|
}
|
|
|
|
cmd := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path))
|
|
_, err = cmd.Output()
|
|
if err != nil {
|
|
return Error500(err)
|
|
}
|
|
|
|
result_path := path.Join("savedData", id, "defs", definition_id)
|
|
|
|
os.MkdirAll(result_path, os.ModePerm)
|
|
|
|
err = exec.Command("cp", path.Join(run_path, "model.keras"), path.Join(result_path, "model.keras")).Run()
|
|
if err != nil {
|
|
return Error500(err)
|
|
}
|
|
|
|
accuracy_file, err := os.Open(path.Join(run_path, "accuracy.val"))
|
|
if err != nil {
|
|
return Error500(err)
|
|
}
|
|
defer accuracy_file.Close()
|
|
|
|
accuracy_file_bytes, err := io.ReadAll(accuracy_file)
|
|
if err != nil {
|
|
return Error500(err)
|
|
}
|
|
|
|
accuracy, err := strconv.ParseFloat(string(accuracy_file_bytes), 64)
|
|
if err != nil {
|
|
return Error500(err)
|
|
}
|
|
|
|
w.Write([]byte(strconv.FormatFloat(accuracy, 'f', -1, 64)))
|
|
return nil
|
|
})
|
|
}
|