package models

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"path"

	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"

	tf "github.com/galeone/tensorflow/tensorflow/go"
	tg "github.com/galeone/tfgo"
	"github.com/galeone/tfgo/image"
)

// handleRun registers the POST /models/run endpoint, which runs one uploaded
// image through a READY model's saved TensorFlow definition.
//
// The request must be a multipart form (JSON mode is rejected with 400) with:
//   - "id":   the model id, resolved with GetBaseModel;
//   - "file": the raw image bytes.
//
// The image is staged at /tmp/<model id>/runs/img.png, read with 3 channels,
// batched, and fed to the SavedModel stored under
// savedData/<model id>/defs/<definition id>/model. The prediction tensor is
// currently only printed to stdout; no response body is written.
func handleRun(handle *Handle) {
	handle.Post("/models/run", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
		if !CheckAuthLevel(1, w, r, c) {
			return nil
		}
		if c.Mode == JSON {
			// TODO improve message
			return ErrorCode(nil, 400, nil)
		}

		read_form, err := r.MultipartReader()
		if err != nil {
			// TODO improve message
			return ErrorCode(nil, 400, nil)
		}

		var id string
		var file []byte

		for {
			part, err_part := read_form.NextPart()
			if err_part == io.EOF {
				break
			} else if err_part != nil {
				return &Error{Code: http.StatusBadRequest}
			}

			switch part.FormName() {
			case "id":
				buf := new(bytes.Buffer)
				if _, err := buf.ReadFrom(part); err != nil {
					// A part that cannot be read fully (e.g. client disconnect)
					// must not be used as a truncated value.
					return &Error{Code: http.StatusBadRequest}
				}
				id = buf.String()
			case "file":
				buf := new(bytes.Buffer)
				if _, err := buf.ReadFrom(part); err != nil {
					return &Error{Code: http.StatusBadRequest}
				}
				file = buf.Bytes()
			}
		}

		// Both fields are required: reject early instead of querying with an
		// empty id or handing an empty image to the decoder.
		if id == "" || len(file) == 0 {
			// TODO improve message
			return ErrorCode(nil, 400, nil)
		}

		model, err := GetBaseModel(handle.Db, id)
		if errors.Is(err, ModelNotFoundError) {
			return ErrorCode(nil, http.StatusNotFound, AnyMap{
				"NotFoundMessage": "Model not found",
				"GoBackLink":      "/models",
			})
		} else if err != nil {
			return Error500(err)
		}

		if model.Status != READY {
			// TODO improve this
			return ErrorCode(nil, 400, c.AddMap(nil))
		}

		// Only the first definition row returned is used.
		definitions_rows, err := handle.Db.Query("select id from model_definition where model_id=$1;", model.Id)
		if err != nil {
			return Error500(err)
		}
		defer definitions_rows.Close()

		if !definitions_rows.Next() {
			// Next also returns false when iteration fails; report that as a
			// server error rather than as a missing definition.
			if err = definitions_rows.Err(); err != nil {
				return Error500(err)
			}
			// TODO improve this
			fmt.Printf("Could not find definition\n")
			return ErrorCode(nil, 400, c.AddMap(nil))
		}

		var def_id string
		if err = definitions_rows.Scan(&def_id); err != nil {
			return Error500(err)
		}

		// TODO create a database table with tasks
		// NOTE(review): every run for the same model shares img.png, so
		// concurrent requests for one model can overwrite each other's input.
		run_path := path.Join("/tmp", model.Id, "runs")
		if err = os.MkdirAll(run_path, os.ModePerm); err != nil {
			return Error500(err)
		}

		// os.WriteFile writes and closes the file before TensorFlow opens it by
		// path, and reports failed/short writes that used to be ignored.
		img_path := path.Join(run_path, "img.png")
		if err = os.WriteFile(img_path, file, 0666); err != nil {
			return Error500(err)
		}

		// Decode the staged image with 3 channels and add a leading batch
		// dimension so it can be fed to the model as a batch of one.
		root := tg.NewRoot()
		tf_img := image.Read(root, img_path, 3)
		batch := tg.Batchify(root, []tf.Output{tf_img.Value()})
		exec_results := tg.Exec(root, []tf.Output{batch}, nil, &tf.SessionOptions{})

		inputImage, err := tf.NewTensor(exec_results[0].Value())
		if err != nil {
			return Error500(err)
		}

		// NOTE(review): the op names below are tied to the signature the model
		// was exported with ("serve" tag, serving_default_rescaling_input ->
		// StatefulPartitionedCall); confirm they match the export code.
		tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)

		results := tf_model.Exec([]tf.Output{
			tf_model.Op("StatefulPartitionedCall", 0),
		}, map[tf.Output]*tf.Tensor{
			tf_model.Op("serving_default_rescaling_input", 0): inputImage,
		})

		predictions := results[0]

		fmt.Println(predictions.Value())

		return nil
	})
}