package models

import (
	"bytes"
	"fmt"
	"io"
	"os"
	"path"

	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"

	tf "github.com/galeone/tensorflow/tensorflow/go"
	"github.com/galeone/tensorflow/tensorflow/go/op"
	tg "github.com/galeone/tfgo"
	"github.com/galeone/tfgo/image"
)

// ReadPNG adds graph nodes under scope that read the PNG file at imagePath,
// decode it with the requested number of channels, prepend a batch axis of
// size 1 and pass the result through image.Scale(0, 255).
//
// Nothing is read from disk until the graph is executed (e.g. with tg.Exec).
func ReadPNG(scope *op.Scope, imagePath string, channels int64) *image.Image {
	scope = tg.NewScope(scope)
	contents := op.ReadFile(scope.SubScope("ReadFile"), op.Const(scope.SubScope("filename"), imagePath))
	output := op.DecodePng(scope.SubScope("DecodePng"), contents, op.DecodePngChannels(channels))
	// Prepend a batch dimension of size 1.
	output = op.ExpandDims(scope.SubScope("ExpandDims"), output, op.Const(scope.SubScope("axis"), []int32{0}))
	img := &image.Image{Tensor: tg.NewTensor(scope, output)}
	return img.Scale(0, 255)
}

// ReadJPG is the JPEG counterpart of ReadPNG: it adds graph nodes under scope
// that read the JPEG file at imagePath, decode it with the requested number of
// channels, prepend a batch axis of size 1 and pass the result through
// image.Scale(0, 255).
func ReadJPG(scope *op.Scope, imagePath string, channels int64) *image.Image {
	scope = tg.NewScope(scope)
	contents := op.ReadFile(scope.SubScope("ReadFile"), op.Const(scope.SubScope("filename"), imagePath))
	// Use the JPEG decoder (the previous version decoded JPEG files with DecodePng).
	output := op.DecodeJpeg(scope.SubScope("DecodeJpeg"), contents, op.DecodeJpegChannels(channels))
	// Prepend a batch dimension of size 1.
	output = op.ExpandDims(scope.SubScope("ExpandDims"), output, op.Const(scope.SubScope("axis"), []int32{0}))
	img := &image.Image{Tensor: tg.NewTensor(scope, output)}
	return img.Scale(0, 255)
}

// firstBatchPredictions returns the score vector of the first batch element of
// a model's first output. It returns an error, instead of panicking, when the
// output is missing, is not a [][]float32, or has an empty batch.
func firstBatchPredictions(results []*tf.Tensor) ([]float32, error) {
	if len(results) == 0 {
		return nil, fmt.Errorf("model returned no outputs")
	}
	batch, ok := results[0].Value().([][]float32)
	if !ok {
		return nil, fmt.Errorf("unexpected model output type %T", results[0].Value())
	}
	if len(batch) == 0 {
		return nil, fmt.Errorf("model returned an empty batch")
	}
	return batch[0], nil
}

// runModelNormal loads the saved model at
// savedData/<model.Id>/defs/<def_id>/model, feeds inputImage to its
// "serving_default_rescaling_input" node and returns the index of the
// highest score (order) together with that score (confidence).
//
// An empty score vector yields order 0 and confidence 0. err is non-nil when
// the model output does not have the expected [][]float32 shape.
func runModelNormal(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, confidence float32, err error) {
	tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)

	results := tf_model.Exec([]tf.Output{
		tf_model.Op("StatefulPartitionedCall", 0),
	}, map[tf.Output]*tf.Tensor{
		tf_model.Op("serving_default_rescaling_input", 0): inputImage,
	})

	predictions, err := firstBatchPredictions(results)
	if err != nil {
		return 0, 0, err
	}

	for i, v := range predictions {
		// Seed the arg-max with the first element so the result is correct even
		// when every score is <= 0.
		if i == 0 || v > confidence {
			order = i
			confidence = v
		}
	}
	return order, confidence, nil
}

// runModelExp runs an "expandable" model: a shared base model followed by one
// head model per row of exp_model_head for def_id. The base model's output is
// fed to every head. The global arg-max across all heads is returned. A head's
// local class index is offset by its Range_start to give order.
//
// With no heads, order is 0 and confidence is 0. err is non-nil if the heads
// cannot be loaded from the database or a head's output has an unexpected shape.
func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, confidence float32, err error) {
	base_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "base", "model"), []string{"serve"}, nil)

	base_results := base_model.Exec([]tf.Output{
		base_model.Op("StatefulPartitionedCall", 0),
	}, map[tf.Output]*tf.Tensor{
		base_model.Op("serving_default_input_1", 0): inputImage,
	})

	type head struct {
		Id          string
		Range_start int
	}

	heads, err := GetDbMultitple[head](c, "exp_model_head where def_id=$1;", def_id)
	if err != nil {
		return 0, 0, err
	}

	c.Logger.Info("Running exp model heads", "count", len(heads))

	// found tracks whether any score has been seen yet, so the arg-max is not
	// biased towards 0 when every score is <= 0.
	found := false
	for _, element := range heads {
		head_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "head", element.Id, "model"), []string{"serve"}, nil)

		results := head_model.Exec([]tf.Output{
			head_model.Op("StatefulPartitionedCall", 0),
		}, map[tf.Output]*tf.Tensor{
			head_model.Op("serving_default_head_input", 0): base_results[0],
		})

		predictions, perr := firstBatchPredictions(results)
		if perr != nil {
			return 0, 0, fmt.Errorf("head %s: %w", element.Id, perr)
		}

		for i, v := range predictions {
			c.Logger.Info("predictions", "class", i, "preds", v)
			if !found || v > confidence {
				found = true
				order = element.Range_start + i
				confidence = v
			}
		}
	}

	c.Logger.Info("Got", "heads", len(heads), "order", order, "vmax", confidence)
	return order, confidence, nil
}

// handleRun registers POST /models/run.
//
// The handler requires auth level 1 and expects a multipart body with an "id"
// field (the model id) and a "file" field (the image bytes). For a model in
// one of the READY* states, it classifies the image with the model's
// definition. It responds with {"class": <name>, "confidence": <score>}, or
// JSON null when no class row matches the predicted order.
func handleRun(handle *Handle) {
	handle.Post("/models/run", func(c *Context) *Error {
		if !c.CheckAuthLevel(1) {
			return nil
		}

		read_form, err := c.R.MultipartReader()
		if err != nil {
			return c.JsonBadRequest("Invalid multipart body")
		}

		var id string
		var file []byte
		for {
			part, err_part := read_form.NextPart()
			if err_part == io.EOF {
				break
			} else if err_part != nil {
				return c.JsonBadRequest("Invalid multipart data")
			}

			name := part.FormName()
			if name != "id" && name != "file" {
				// Unread parts are drained by the next NextPart call.
				continue
			}

			var buf bytes.Buffer
			if _, err_read := buf.ReadFrom(part); err_read != nil {
				return c.JsonBadRequest("Invalid multipart data")
			}
			if name == "id" {
				id = buf.String()
			} else {
				file = buf.Bytes()
			}
		}

		model, err := GetBaseModel(handle.Db, id)
		if err == ModelNotFoundError {
			return c.JsonBadRequest("Models not found")
		} else if err != nil {
			return c.Error500(err)
		}

		if model.Status != READY &&
			model.Status != READY_RETRAIN &&
			model.Status != READY_RETRAIN_FAILED &&
			model.Status != READY_ALTERATION &&
			model.Status != READY_ALTERATION_FAILED {
			return c.JsonBadRequest("Model not ready to run images")
		}

		def := JustId{}
		err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id)
		if err == NotFoundError {
			return c.JsonBadRequest("Could not find definition")
		} else if err != nil {
			return c.Error500(err)
		}

		def_id := def.Id

		// TODO create a database table with tasks
		// Each request gets its own directory, so concurrent runs against the
		// same model cannot overwrite or delete each other's image.
		runs_root := path.Join("/tmp", model.Id, "runs")
		if err = os.MkdirAll(runs_root, os.ModePerm); err != nil {
			return c.Error500(err)
		}
		run_path, err := os.MkdirTemp(runs_root, "run-")
		if err != nil {
			return c.Error500(err)
		}
		// Best-effort cleanup on every return path, including errors.
		defer os.RemoveAll(run_path)

		img_path := path.Join(run_path, "img."+model.Format)
		if err = os.WriteFile(img_path, file, 0o600); err != nil {
			return c.Error500(err)
		}

		if !testImgForModel(c, model, img_path) {
			return c.JsonBadRequest("Provided image does not match the model")
		}

		root := tg.NewRoot()

		var tf_img *image.Image
		switch model.Format {
		case "png":
			tf_img = ReadPNG(root, img_path, int64(model.ImageMode))
		case "jpeg":
			tf_img = ReadJPG(root, img_path, int64(model.ImageMode))
		default:
			// Report an unsupported stored format as a server error instead of
			// panicking inside the request handler.
			return c.Error500(fmt.Errorf("unsupported model image format %q", model.Format))
		}

		exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{})
		inputImage, err := tf.NewTensor(exec_results[0].Value())
		if err != nil {
			return c.Error500(err)
		}

		var vi int
		var confidence float32
		if model.ModelType == 2 {
			c.Logger.Info("Running model exp", "model", model.Id, "def", def_id)
			vi, confidence, err = runModelExp(c, model, def_id, inputImage)
		} else {
			c.Logger.Info("Running model normal", "model", model.Id, "def", def_id)
			vi, confidence, err = runModelNormal(c, model, def_id, inputImage)
		}
		if err != nil {
			return c.Error500(err)
		}

		rows, err := handle.Db.Query("select name from model_classes where model_id=$1 and class_order=$2;", model.Id, vi)
		if err != nil {
			return c.Error500(err)
		}
		// Release the connection back to the pool.
		defer rows.Close()

		if !rows.Next() {
			// Distinguish "no matching class" from an iteration error.
			if err = rows.Err(); err != nil {
				return c.Error500(err)
			}
			return c.SendJSON(nil)
		}

		var name string
		if err = rows.Scan(&name); err != nil {
			return c.Error500(err)
		}

		returnValue := struct {
			Class      string  `json:"class"`
			Confidence float32 `json:"confidence"`
		}{
			Class:      name,
			Confidence: confidence,
		}

		return c.SendJSON(returnValue)
	})
}