fyp/logic/models/train/tensorflow-test.go

176 lines
4.2 KiB
Go
Raw Normal View History

package models_train
import (
"fmt"
"io"
"net/http"
"os"
"os/exec"
"path"
"strconv"
"strings"
"text/template"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
)
/*import (
tf "github.com/galeone/tensorflow/tensorflow/go"
tg "github.com/galeone/tfgo"
"github.com/galeone/tfgo/image"
"github.com/galeone/tfgo/image/filter"
"github.com/galeone/tfgo/image/padding"
)*/
// getDir reports the current working directory of the process.
//
// It panics with the underlying error when os.Getwd fails, so callers never
// receive an empty or partial path.
func getDir() string {
	wd, err := os.Getwd()
	if err == nil {
		return wd
	}
	panic(err)
}
// shapeToSize removes the final comma-separated component from shape and
// returns everything before the last comma (e.g. "28,28,1" -> "28,28").
// A shape with no comma at all yields "".
// NOTE(review): presumably the dropped component is the channel dimension of
// an image shape — confirm against how model_definition_layer.shape is stored.
func shapeToSize(shape string) string {
	lastComma := strings.LastIndex(shape, ",")
	if lastComma < 0 {
		return ""
	}
	return shape[:lastComma]
}
// handleTest registers the POST /models/train/test route.
//
// For the model whose id is read from the request (GetIdFromUrl(r, "id")),
// the handler:
//  1. looks up one data point with model_mode = 1 and logs its image path;
//  2. picks one model_definition with status = 2 and loads its layers in
//     layer_order, trimming each shape with shapeToSize;
//  3. renders views/py/python_model_template.py into
//     /tmp/<id>/defs/<definition_id>/run.py and runs it with python;
//  4. copies the produced model.keras into savedData/<id>/defs/<definition_id>/;
//  5. reads accuracy.val from the run directory and writes the parsed value
//     to the response as plain text.
//
// Requests that fail CheckAuthLevel(1, ...) get no further processing.
// A malformed id yields a 400; every other failure yields a 500.
// NOTE(review): the meaning of model_mode = 1 and status = 2 is not visible
// here — presumably "training data" and "ready to train"; confirm against
// the schema.
func handleTest(handle *Handle) {
	handle.Post("/models/train/test", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
		if !CheckAuthLevel(1, w, r, c) {
			return nil
		}

		id, err := GetIdFromUrl(r, "id")
		if err != nil {
			return ErrorCode(err, 400, c.AddMap(nil))
		}

		rows, err := handle.Db.Query("select mc.name, mdp.file_path from model_classes as mc join model_data_point as mdp on mdp.class_id = mc.id where mdp.model_mode = 1 and mc.model_id = $1 limit 1;", id)
		if err != nil {
			return Error500(err)
		}
		defer rows.Close()
		if !rows.Next() {
			// Next returns false both for "no rows" and for iteration errors;
			// report whichever happened instead of passing a nil error.
			if err = rows.Err(); err != nil {
				return Error500(err)
			}
			return Error500(fmt.Errorf("model %s has no data point with model_mode 1", id))
		}
		var name string
		var file_path string
		if err = rows.Scan(&name, &file_path); err != nil {
			return Error500(err)
		}
		file_path = strings.Replace(file_path, "file://", "", 1)
		img_path := path.Join("savedData", id, "data", "training", name, file_path)
		fmt.Printf("%s\n", img_path)

		definitions, err := handle.Db.Query("select id from model_definition where model_id=$1 and status=2 limit 1;", id)
		if err != nil {
			return Error500(err)
		}
		defer definitions.Close()
		if !definitions.Next() {
			if err = definitions.Err(); err != nil {
				return Error500(err)
			}
			fmt.Println("Did not find definition")
			return Error500(fmt.Errorf("model %s has no definition with status 2", id))
		}
		var definition_id string
		if err = definitions.Scan(&definition_id); err != nil {
			return Error500(err)
		}

		layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
		if err != nil {
			return Error500(err)
		}
		defer layers.Close()

		// Field names are exported because the Python template reads them.
		type layerrow struct {
			LayerType int
			Shape     string
		}
		got := []layerrow{}
		for layers.Next() {
			var row layerrow
			if err = layers.Scan(&row.LayerType, &row.Shape); err != nil {
				return Error500(err)
			}
			row.Shape = shapeToSize(row.Shape)
			got = append(got, row)
		}
		// Next stops silently on driver errors; surface them.
		if err = layers.Err(); err != nil {
			return Error500(err)
		}
		// The template's input size is taken from the first layer; without any
		// layers got[0] would panic.
		if len(got) == 0 {
			return Error500(fmt.Errorf("model definition %s has no layers", definition_id))
		}

		// Generate the run folder and render the training script into it.
		run_path := path.Join("/tmp", id, "defs", definition_id)
		if err = os.MkdirAll(run_path, os.ModePerm); err != nil {
			return Error500(err)
		}
		fmt.Printf("Using path: %s\n", run_path)

		tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py")
		if err != nil {
			return Error500(err)
		}
		f, err := os.Create(path.Join(run_path, "run.py"))
		if err != nil {
			return Error500(err)
		}
		err = tmpl.Execute(f, AnyMap{
			"Layers":  got,
			"Size":    got[0].Shape,
			"DataDir": path.Join(getDir(), "savedData", id, "data", "training"),
		})
		// Close before running the script (rather than via defer after it has
		// run) and report a failed close instead of ignoring it.
		if cerr := f.Close(); err == nil {
			err = cerr
		}
		if err != nil {
			return Error500(err)
		}

		// Run python directly in run_path instead of through
		// `bash -c "cd <path> && python run.py"`: the path is built from the
		// request id and a DB value and must never be parsed by a shell.
		// CombinedOutput keeps the script's output so failures are diagnosable.
		cmd := exec.Command("python", "run.py")
		cmd.Dir = run_path
		out, err := cmd.CombinedOutput()
		if err != nil {
			return Error500(fmt.Errorf("running %s: %w\n%s", path.Join(run_path, "run.py"), err, out))
		}

		result_path := path.Join("savedData", id, "defs", definition_id)
		if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
			return Error500(err)
		}
		err = exec.Command("cp", path.Join(run_path, "model.keras"), path.Join(result_path, "model.keras")).Run()
		if err != nil {
			return Error500(err)
		}

		accuracy_file, err := os.Open(path.Join(run_path, "accuracy.val"))
		if err != nil {
			return Error500(err)
		}
		defer accuracy_file.Close()
		accuracy_file_bytes, err := io.ReadAll(accuracy_file)
		if err != nil {
			return Error500(err)
		}
		// Trim surrounding whitespace (e.g. a trailing newline written by the
		// script); ParseFloat rejects it otherwise.
		accuracy, err := strconv.ParseFloat(strings.TrimSpace(string(accuracy_file_bytes)), 64)
		if err != nil {
			return Error500(err)
		}
		w.Write([]byte(strconv.FormatFloat(accuracy, 'f', -1, 64)))
		return nil
	})
}