395 lines
10 KiB
Go
395 lines
10 KiB
Go
package models
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path"
|
|
|
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
|
|
|
tf "github.com/galeone/tensorflow/tensorflow/go"
|
|
"github.com/galeone/tensorflow/tensorflow/go/op"
|
|
tg "github.com/galeone/tfgo"
|
|
"github.com/galeone/tfgo/image"
|
|
)
|
|
|
|
func ReadPNG(scope *op.Scope, imagePath string, channels int64) *image.Image {
|
|
scope = tg.NewScope(scope)
|
|
contents := op.ReadFile(scope.SubScope("ReadFile"), op.Const(scope.SubScope("filename"), imagePath))
|
|
output := op.DecodePng(scope.SubScope("DecodePng"), contents, op.DecodePngChannels(channels))
|
|
output = op.ExpandDims(scope.SubScope("ExpandDims"), output, op.Const(scope.SubScope("axis"), []int32{0}))
|
|
image := &image.Image{
|
|
Tensor: tg.NewTensor(scope, output)}
|
|
return image.Scale(0, 255)
|
|
}
|
|
|
|
func ReadJPG(scope *op.Scope, imagePath string, channels int64) *image.Image {
|
|
scope = tg.NewScope(scope)
|
|
contents := op.ReadFile(scope.SubScope("ReadFile"), op.Const(scope.SubScope("filename"), imagePath))
|
|
output := op.DecodePng(scope.SubScope("DecodeJpeg"), contents, op.DecodePngChannels(channels))
|
|
output = op.ExpandDims(scope.SubScope("ExpandDims"), output, op.Const(scope.SubScope("axis"), []int32{0}))
|
|
image := &image.Image{
|
|
Tensor: tg.NewTensor(scope, output)}
|
|
return image.Scale(0, 255)
|
|
}
|
|
|
|
func runModelNormal(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, confidence float32, err error) {
|
|
order = 0
|
|
err = nil
|
|
|
|
tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
|
|
|
|
results := tf_model.Exec([]tf.Output{
|
|
tf_model.Op("StatefulPartitionedCall", 0),
|
|
}, map[tf.Output]*tf.Tensor{
|
|
tf_model.Op("serving_default_rescaling_input", 0): inputImage,
|
|
})
|
|
|
|
var vmax float32 = 0.0
|
|
var predictions = results[0].Value().([][]float32)[0]
|
|
|
|
for i, v := range predictions {
|
|
if v > vmax {
|
|
order = i
|
|
vmax = v
|
|
}
|
|
}
|
|
|
|
confidence = vmax
|
|
|
|
return
|
|
}
|
|
|
|
func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, confidence float32, err error) {
|
|
|
|
err = nil
|
|
order = 0
|
|
|
|
base_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "base", "model"), []string{"serve"}, nil)
|
|
|
|
//results := base_model.Exec([]tf.Output{
|
|
base_results := base_model.Exec([]tf.Output{
|
|
base_model.Op("StatefulPartitionedCall", 0),
|
|
}, map[tf.Output]*tf.Tensor{
|
|
//base_model.Op("serving_default_rescaling_input", 0): inputImage,
|
|
base_model.Op("serving_default_input_1", 0): inputImage,
|
|
})
|
|
|
|
type head struct {
|
|
Id string
|
|
Range_start int
|
|
}
|
|
|
|
heads, err := GetDbMultitple[head](c, "exp_model_head where def_id=$1;", def_id)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
var vmax float32 = 0.0
|
|
|
|
for _, element := range heads {
|
|
head_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "head", element.Id, "model"), []string{"serve"}, nil)
|
|
|
|
results := head_model.Exec([]tf.Output{
|
|
head_model.Op("StatefulPartitionedCall", 0),
|
|
}, map[tf.Output]*tf.Tensor{
|
|
head_model.Op("serving_default_input_2", 0): base_results[0],
|
|
})
|
|
|
|
var predictions = results[0].Value().([][]float32)[0]
|
|
|
|
for i, v := range predictions {
|
|
if v > vmax {
|
|
order = element.Range_start + i
|
|
vmax = v
|
|
}
|
|
}
|
|
}
|
|
|
|
// TODO runthe head model
|
|
confidence = vmax
|
|
|
|
c.Logger.Info("Got", "heads", len(heads))
|
|
return
|
|
}
|
|
|
|
func handleRun(handle *Handle) {
|
|
handle.Post("/models/run", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
|
if !CheckAuthLevel(1, w, r, c) {
|
|
return nil
|
|
}
|
|
if c.Mode == JSON {
|
|
|
|
read_form, err := r.MultipartReader()
|
|
if err != nil {
|
|
// TODO improve message
|
|
return ErrorCode(nil, 400, nil)
|
|
}
|
|
|
|
var id string
|
|
var file []byte
|
|
|
|
for {
|
|
part, err_part := read_form.NextPart()
|
|
if err_part == io.EOF {
|
|
break
|
|
} else if err_part != nil {
|
|
return c.JsonBadRequest("Invalid multipart data")
|
|
}
|
|
if part.FormName() == "id" {
|
|
buf := new(bytes.Buffer)
|
|
buf.ReadFrom(part)
|
|
id = buf.String()
|
|
}
|
|
if part.FormName() == "file" {
|
|
buf := new(bytes.Buffer)
|
|
buf.ReadFrom(part)
|
|
file = buf.Bytes()
|
|
}
|
|
}
|
|
|
|
model, err := GetBaseModel(handle.Db, id)
|
|
if err == ModelNotFoundError {
|
|
return c.JsonBadRequest("Models not found")
|
|
} else if err != nil {
|
|
return c.Error500(err)
|
|
}
|
|
|
|
if model.Status != READY {
|
|
return c.JsonBadRequest("Model not ready to run images")
|
|
}
|
|
|
|
def := JustId{}
|
|
err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id)
|
|
if err == NotFoundError {
|
|
return c.JsonBadRequest("Could not find definition")
|
|
} else if err != nil {
|
|
return c.Error500(err)
|
|
}
|
|
|
|
def_id := def.Id
|
|
|
|
// TODO create a database table with tasks
|
|
run_path := path.Join("/tmp", model.Id, "runs")
|
|
os.MkdirAll(run_path, os.ModePerm)
|
|
img_path := path.Join(run_path, "img."+model.Format)
|
|
|
|
img_file, err := os.Create(img_path)
|
|
if err != nil {
|
|
return c.Error500(err)
|
|
}
|
|
defer img_file.Close()
|
|
img_file.Write(file)
|
|
|
|
if !testImgForModel(c, model, img_path) {
|
|
return c.JsonBadRequest("Provided image does not match the model")
|
|
}
|
|
|
|
root := tg.NewRoot()
|
|
|
|
var tf_img *image.Image = nil
|
|
|
|
switch model.Format {
|
|
case "png":
|
|
tf_img = ReadPNG(root, img_path, int64(model.ImageMode))
|
|
case "jpeg":
|
|
tf_img = ReadJPG(root, img_path, int64(model.ImageMode))
|
|
default:
|
|
panic("Not sure what to do with '" + model.Format + "'")
|
|
}
|
|
|
|
exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{})
|
|
inputImage, err := tf.NewTensor(exec_results[0].Value())
|
|
if err != nil {
|
|
return c.Error500(err)
|
|
}
|
|
|
|
vi := -1
|
|
var confidence float32 = 0
|
|
|
|
if model.ModelType == 2 {
|
|
c.Logger.Info("Running model normal", "model", model.Id, "def", def_id)
|
|
vi, confidence, err = runModelExp(c, model, def_id, inputImage)
|
|
if err != nil {
|
|
return c.Error500(err)
|
|
}
|
|
} else {
|
|
c.Logger.Info("Running model normal", "model", model.Id, "def", def_id)
|
|
vi, confidence, err = runModelNormal(c, model, def_id, inputImage)
|
|
if err != nil {
|
|
return c.Error500(err)
|
|
}
|
|
}
|
|
|
|
os.RemoveAll(run_path)
|
|
|
|
rows, err := handle.Db.Query("select name from model_classes where model_id=$1 and class_order=$2;", model.Id, vi)
|
|
if err != nil {
|
|
return c.Error500(err)
|
|
}
|
|
if !rows.Next() {
|
|
return c.SendJSON(nil)
|
|
}
|
|
|
|
var name string
|
|
if err = rows.Scan(&name); err != nil {
|
|
return c.Error500(err)
|
|
}
|
|
|
|
returnValue := struct {
|
|
Class string `json:"class"`
|
|
Confidence float32 `json:"confidence"`
|
|
}{
|
|
Class: name,
|
|
Confidence: confidence,
|
|
}
|
|
|
|
return c.SendJSON(returnValue)
|
|
}
|
|
|
|
read_form, err := r.MultipartReader()
|
|
if err != nil {
|
|
// TODO improve message
|
|
return ErrorCode(nil, 400, nil)
|
|
}
|
|
|
|
var id string
|
|
var file []byte
|
|
|
|
for {
|
|
part, err_part := read_form.NextPart()
|
|
if err_part == io.EOF {
|
|
break
|
|
} else if err_part != nil {
|
|
return &Error{Code: http.StatusBadRequest}
|
|
}
|
|
if part.FormName() == "id" {
|
|
buf := new(bytes.Buffer)
|
|
buf.ReadFrom(part)
|
|
id = buf.String()
|
|
}
|
|
if part.FormName() == "file" {
|
|
buf := new(bytes.Buffer)
|
|
buf.ReadFrom(part)
|
|
file = buf.Bytes()
|
|
}
|
|
}
|
|
|
|
model, err := GetBaseModel(handle.Db, id)
|
|
if err == ModelNotFoundError {
|
|
return ErrorCode(nil, http.StatusNotFound, AnyMap{
|
|
"NotFoundMessage": "Model not found",
|
|
"GoBackLink": "/models",
|
|
})
|
|
} else if err != nil {
|
|
return Error500(err)
|
|
}
|
|
|
|
if model.Status != READY {
|
|
// TODO improve this
|
|
return ErrorCode(nil, 400, c.AddMap(nil))
|
|
}
|
|
|
|
def := JustId{}
|
|
err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id)
|
|
if err == NotFoundError {
|
|
// TODO improve this
|
|
fmt.Printf("Could not find definition\n")
|
|
return ErrorCode(nil, 400, c.AddMap(nil))
|
|
} else if err != nil {
|
|
return Error500(err)
|
|
}
|
|
|
|
def_id := def.Id
|
|
|
|
// TODO create a database table with tasks
|
|
run_path := path.Join("/tmp", model.Id, "runs")
|
|
os.MkdirAll(run_path, os.ModePerm)
|
|
img_path := path.Join(run_path, "img."+model.Format)
|
|
|
|
img_file, err := os.Create(img_path)
|
|
if err != nil {
|
|
return Error500(err)
|
|
}
|
|
defer img_file.Close()
|
|
img_file.Write(file)
|
|
|
|
if !testImgForModel(c, model, img_path) {
|
|
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
|
|
"Model": model,
|
|
"NotFound": false,
|
|
"Result": nil,
|
|
"ImageError": true,
|
|
}))
|
|
return nil
|
|
}
|
|
|
|
root := tg.NewRoot()
|
|
|
|
var tf_img *image.Image = nil
|
|
|
|
switch model.Format {
|
|
case "png":
|
|
tf_img = ReadPNG(root, img_path, int64(model.ImageMode))
|
|
case "jpeg":
|
|
tf_img = ReadJPG(root, img_path, int64(model.ImageMode))
|
|
default:
|
|
panic("Not sure what to do with '" + model.Format + "'")
|
|
}
|
|
|
|
exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{})
|
|
inputImage, err := tf.NewTensor(exec_results[0].Value())
|
|
if err != nil {
|
|
return Error500(err)
|
|
}
|
|
|
|
vi := -1
|
|
var confidence float32 = 0
|
|
|
|
if model.ModelType == 2 {
|
|
c.Logger.Info("Running model normal", "model", model.Id, "def", def_id)
|
|
vi, confidence, err = runModelExp(c, model, def_id, inputImage)
|
|
if err != nil {
|
|
return c.Error500(err)
|
|
}
|
|
} else {
|
|
c.Logger.Info("Running model normal", "model", model.Id, "def", def_id)
|
|
vi, confidence, err = runModelNormal(c, model, def_id, inputImage)
|
|
if err != nil {
|
|
return c.Error500(err)
|
|
}
|
|
}
|
|
|
|
os.RemoveAll(run_path)
|
|
|
|
rows, err := handle.Db.Query("select name from model_classes where model_id=$1 and class_order=$2;", model.Id, vi)
|
|
if err != nil {
|
|
return Error500(err)
|
|
}
|
|
if !rows.Next() {
|
|
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
|
|
"Model": model,
|
|
"NotFound": true,
|
|
"Result": nil,
|
|
"Confidence": confidence,
|
|
}))
|
|
return nil
|
|
}
|
|
|
|
var name string
|
|
if err = rows.Scan(&name); err != nil {
|
|
return nil
|
|
}
|
|
|
|
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
|
|
"Model": model,
|
|
"Result": name,
|
|
}))
|
|
return nil
|
|
})
|
|
}
|