feat: closes #22
This commit is contained in:
@@ -7,15 +7,27 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
|
||||
tf "github.com/galeone/tensorflow/tensorflow/go"
|
||||
"github.com/galeone/tensorflow/tensorflow/go/op"
|
||||
tg "github.com/galeone/tfgo"
|
||||
"github.com/galeone/tfgo/image"
|
||||
)
|
||||
|
||||
func ReadPNG(scope *op.Scope, imagePath string, channels int64) *image.Image {
|
||||
scope = tg.NewScope(scope)
|
||||
contents := op.ReadFile(scope.SubScope("ReadFile"), op.Const(scope.SubScope("filename"), imagePath))
|
||||
output := op.DecodePng(scope.SubScope("DecodePng"), contents, op.DecodePngChannels(channels))
|
||||
output = op.ExpandDims(scope.SubScope("ExpandDims"), output, op.Const(scope.SubScope("axis"), []int32{0}))
|
||||
image := &image.Image{
|
||||
Tensor: tg.NewTensor(scope, output)}
|
||||
return image.Scale(0, 255)
|
||||
}
|
||||
|
||||
func handleRun(handle *Handle) {
|
||||
handle.Post("/models/run", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||||
if !CheckAuthLevel(1, w, r, c) {
|
||||
@@ -98,10 +110,10 @@ func handleRun(handle *Handle) {
|
||||
img_file.Write(file)
|
||||
|
||||
root := tg.NewRoot()
|
||||
tf_img := image.Read(root, path.Join(run_path, "img.png"), 3)
|
||||
|
||||
batch := tg.Batchify(root, []tf.Output{tf_img.Value()})
|
||||
exec_results := tg.Exec(root, []tf.Output{batch}, nil, &tf.SessionOptions{})
|
||||
|
||||
tf_img := ReadPNG(root, path.Join(run_path, "img.png"), 3)
|
||||
|
||||
exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{})
|
||||
inputImage, err:= tf.NewTensor(exec_results[0].Value())
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
@@ -115,8 +127,23 @@ func handleRun(handle *Handle) {
|
||||
tf_model.Op("serving_default_rescaling_input", 0): inputImage,
|
||||
})
|
||||
|
||||
predictions := results[0]
|
||||
fmt.Println(predictions.Value())
|
||||
var vmax float32 = 0.0
|
||||
vi := 0
|
||||
var predictions = results[0].Value().([][]float32)[0]
|
||||
|
||||
for i, v := range predictions {
|
||||
if v > vmax {
|
||||
vi = i
|
||||
vmax = v
|
||||
}
|
||||
}
|
||||
|
||||
os.RemoveAll(run_path)
|
||||
|
||||
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
|
||||
"Model": model,
|
||||
"Result": strconv.Itoa(vi),
|
||||
}))
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -156,9 +156,8 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
os.RemoveAll(run_path)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user