118
logic/models/run.go
Normal file
118
logic/models/run.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
|
||||
tf "github.com/galeone/tensorflow/tensorflow/go"
|
||||
tg "github.com/galeone/tfgo"
|
||||
"github.com/galeone/tfgo/image"
|
||||
)
|
||||
|
||||
func handleRun(handle *Handle) {
|
||||
handle.Post("/models/run", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||||
if !CheckAuthLevel(1, w, r, c) {
|
||||
return nil
|
||||
}
|
||||
if c.Mode == JSON {
|
||||
// TODO improve message
|
||||
return ErrorCode(nil, 400, nil)
|
||||
}
|
||||
|
||||
read_form, err := r.MultipartReader()
|
||||
if err != nil {
|
||||
// TODO improve message
|
||||
return ErrorCode(nil, 400, nil)
|
||||
}
|
||||
|
||||
var id string
|
||||
var file []byte
|
||||
|
||||
for {
|
||||
part, err_part := read_form.NextPart()
|
||||
if err_part == io.EOF {
|
||||
break
|
||||
} else if err_part != nil {
|
||||
return &Error{Code: http.StatusBadRequest}
|
||||
}
|
||||
if part.FormName() == "id" {
|
||||
buf := new(bytes.Buffer)
|
||||
buf.ReadFrom(part)
|
||||
id = buf.String()
|
||||
}
|
||||
if part.FormName() == "file" {
|
||||
buf := new(bytes.Buffer)
|
||||
buf.ReadFrom(part)
|
||||
file = buf.Bytes()
|
||||
}
|
||||
}
|
||||
|
||||
model, err := GetBaseModel(handle.Db, id)
|
||||
if err == ModelNotFoundError {
|
||||
return ErrorCode(nil, http.StatusNotFound, AnyMap{
|
||||
"NotFoundMessage": "Model not found",
|
||||
"GoBackLink": "/models",
|
||||
})
|
||||
} else if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
if model.Status != READY {
|
||||
// TODO improve this
|
||||
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||
}
|
||||
|
||||
definitions_rows, err := handle.Db.Query("select id from model_definition where model_id=$1;", model.Id)
|
||||
if !definitions_rows.Next() {
|
||||
// TODO improve this
|
||||
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||
}
|
||||
defer definitions_rows.Close()
|
||||
|
||||
if !definitions_rows.Next() {
|
||||
return Error500(nil)
|
||||
}
|
||||
|
||||
var def_id string
|
||||
if err = definitions_rows.Scan(&def_id); err != nil {
|
||||
return Error500(nil)
|
||||
}
|
||||
|
||||
// TODO create a database table with tasks
|
||||
run_path := path.Join("/tmp", model.Id, "runs")
|
||||
os.MkdirAll(run_path, os.ModePerm)
|
||||
|
||||
img_file, err := os.Create(path.Join(run_path, "img.png"))
|
||||
if err != nil {
|
||||
return Error500(nil)
|
||||
}
|
||||
defer img_file.Close()
|
||||
img_file.Write(file)
|
||||
|
||||
root := tg.NewRoot()
|
||||
tf_img := image.Read(root, path.Join(run_path, "img.png"), 1)
|
||||
tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
|
||||
|
||||
tf_img_tensor, err := tf.NewTensor(tf_img.Value())
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
results := tf_model.Exec([]tf.Output{
|
||||
tf_model.Op("StatefulPartitionedCall", 0),
|
||||
}, map[tf.Output]*tf.Tensor{
|
||||
tf_model.Op("serving_default_inputs_input", 0): tf_img_tensor,
|
||||
})
|
||||
|
||||
predictions := results[0]
|
||||
fmt.Println(predictions.Value())
|
||||
return nil
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user