fyp/logic/models/run.go

150 lines
3.8 KiB
Go
Raw Normal View History

2023-09-27 21:20:39 +01:00
package models
import (
"bytes"
"fmt"
"io"
"net/http"
"os"
"path"
2023-09-29 13:27:43 +01:00
"strconv"
2023-09-27 21:20:39 +01:00
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
tf "github.com/galeone/tensorflow/tensorflow/go"
2023-09-29 13:27:43 +01:00
"github.com/galeone/tensorflow/tensorflow/go/op"
2023-09-27 21:20:39 +01:00
tg "github.com/galeone/tfgo"
"github.com/galeone/tfgo/image"
)
2023-09-29 13:27:43 +01:00
// ReadPNG adds operations to a fresh sub-scope of scope that read the PNG file
// at imagePath, decode it with the requested number of channels, prepend a
// batch axis (ExpandDims on axis 0) and pass the result through
// image.Scale(0, 255). Nothing is read from disk until the graph is executed.
//
// NOTE(review): confirm what tfgo's Image.Scale does to the value range (e.g.
// whether it min-max normalises each image) matches what the model was
// trained on.
func ReadPNG(scope *op.Scope, imagePath string, channels int64) *image.Image {
	s := tg.NewScope(scope)

	fileName := op.Const(s.SubScope("filename"), imagePath)
	raw := op.ReadFile(s.SubScope("ReadFile"), fileName)

	decoded := op.DecodePng(s.SubScope("DecodePng"), raw, op.DecodePngChannels(channels))

	batchAxis := op.Const(s.SubScope("axis"), []int32{0})
	batched := op.ExpandDims(s.SubScope("ExpandDims"), decoded, batchAxis)

	img := &image.Image{Tensor: tg.NewTensor(s, batched)}
	return img.Scale(0, 255)
}
2023-09-27 21:20:39 +01:00
func handleRun(handle *Handle) {
handle.Post("/models/run", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
if !CheckAuthLevel(1, w, r, c) {
return nil
}
if c.Mode == JSON {
// TODO improve message
return ErrorCode(nil, 400, nil)
}
read_form, err := r.MultipartReader()
if err != nil {
// TODO improve message
return ErrorCode(nil, 400, nil)
}
var id string
var file []byte
for {
part, err_part := read_form.NextPart()
if err_part == io.EOF {
break
} else if err_part != nil {
return &Error{Code: http.StatusBadRequest}
}
if part.FormName() == "id" {
buf := new(bytes.Buffer)
buf.ReadFrom(part)
id = buf.String()
}
if part.FormName() == "file" {
buf := new(bytes.Buffer)
buf.ReadFrom(part)
file = buf.Bytes()
}
}
model, err := GetBaseModel(handle.Db, id)
if err == ModelNotFoundError {
return ErrorCode(nil, http.StatusNotFound, AnyMap{
"NotFoundMessage": "Model not found",
"GoBackLink": "/models",
})
} else if err != nil {
return Error500(err)
}
if model.Status != READY {
// TODO improve this
return ErrorCode(nil, 400, c.AddMap(nil))
}
definitions_rows, err := handle.Db.Query("select id from model_definition where model_id=$1;", model.Id)
2023-09-28 12:16:36 +01:00
if err != nil {
return Error500(err)
2023-09-27 21:20:39 +01:00
}
defer definitions_rows.Close()
if !definitions_rows.Next() {
2023-09-28 12:16:36 +01:00
// TODO improve this
fmt.Printf("Could not find definition\n")
return ErrorCode(nil, 400, c.AddMap(nil))
2023-09-27 21:20:39 +01:00
}
var def_id string
if err = definitions_rows.Scan(&def_id); err != nil {
2023-09-28 12:16:36 +01:00
return Error500(err)
2023-09-27 21:20:39 +01:00
}
// TODO create a database table with tasks
run_path := path.Join("/tmp", model.Id, "runs")
os.MkdirAll(run_path, os.ModePerm)
img_file, err := os.Create(path.Join(run_path, "img.png"))
if err != nil {
return Error500(nil)
}
defer img_file.Close()
img_file.Write(file)
root := tg.NewRoot()
2023-09-29 13:27:43 +01:00
tf_img := ReadPNG(root, path.Join(run_path, "img.png"), 3)
exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{})
2023-09-28 12:16:36 +01:00
inputImage, err:= tf.NewTensor(exec_results[0].Value())
2023-09-27 21:20:39 +01:00
if err != nil {
return Error500(err)
}
2023-09-28 12:16:36 +01:00
tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
2023-09-27 21:20:39 +01:00
results := tf_model.Exec([]tf.Output{
tf_model.Op("StatefulPartitionedCall", 0),
}, map[tf.Output]*tf.Tensor{
2023-09-28 12:16:36 +01:00
tf_model.Op("serving_default_rescaling_input", 0): inputImage,
2023-09-27 21:20:39 +01:00
})
2023-09-29 13:27:43 +01:00
var vmax float32 = 0.0
vi := 0
var predictions = results[0].Value().([][]float32)[0]
for i, v := range predictions {
if v > vmax {
vi = i
vmax = v
}
}
os.RemoveAll(run_path)
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
"Model": model,
"Result": strconv.Itoa(vi),
}))
2023-09-27 21:20:39 +01:00
return nil
})
}