From 4985c8aa14430f58af24267bc2c9f030c3741f9b Mon Sep 17 00:00:00 2001
From: Andre Henriques
Date: Mon, 19 Feb 2024 12:00:30 +0000
Subject: [PATCH] feat: closes #57

---
 logic/models/run.go | 23 ++++++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/logic/models/run.go b/logic/models/run.go
index 5e9844a..decb7c5 100644
--- a/logic/models/run.go
+++ b/logic/models/run.go
@@ -70,7 +70,7 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
 	base_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "base", "model"), []string{"serve"}, nil)
 
 	//results := base_model.Exec([]tf.Output{
-	base_model.Exec([]tf.Output{
+	base_results := base_model.Exec([]tf.Output{
 		base_model.Op("StatefulPartitionedCall", 0),
 	}, map[tf.Output]*tf.Tensor{
 		//base_model.Op("serving_default_rescaling_input", 0): inputImage,
@@ -87,6 +87,27 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
 		return
 	}
 
+	var vmax float32 = 0.0
+
+	for _, element := range heads {
+		head_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "head", element.Id, "model"), []string{"serve"}, nil)
+
+		results := head_model.Exec([]tf.Output{
+			head_model.Op("StatefulPartitionedCall", 0),
+		}, map[tf.Output]*tf.Tensor{
+			head_model.Op("serving_default_input_2", 0): base_results[0],
+		})
+
+		var predictions = results[0].Value().([][]float32)[0]
+
+		for i, v := range predictions {
+			if v > vmax {
+				order = element.Range_start + i
+				vmax = v
+			}
+		}
+	}
+
 	// TODO runthe head model
 	c.Logger.Info("Got", "heads", len(heads))
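
For reference, the selection logic the second hunk adds reduces to a global argmax over per-head confidence vectors: each head covers a contiguous range of classes starting at its Range_start, and the winning head-local index is offset into the global class space. Below is a minimal, self-contained sketch of that pattern under stated assumptions; the Head type, head names, and confidence values are hypothetical stand-ins for the project's real types, and the tfgo model loading and Exec calls are elided.

// Sketch: global argmax across multiple classifier heads.
// Head, the head names, and the confidence values are hypothetical
// stand-ins; in the patch, each confidence slice comes from
// results[0].Value().([][]float32)[0] after running a head model.
package main

import "fmt"

type Head struct {
	Id          string
	Range_start int // first global class index this head covers
}

func main() {
	heads := []Head{
		{Id: "head-a", Range_start: 0}, // classes 0..2
		{Id: "head-b", Range_start: 3}, // classes 3..5
	}

	// Stand-in for the per-head model outputs.
	confidences := map[string][]float32{
		"head-a": {0.10, 0.20, 0.05},
		"head-b": {0.02, 0.90, 0.08},
	}

	order := -1      // global index of the best class seen so far
	var vmax float32 // its confidence

	for _, element := range heads {
		for i, v := range confidences[element.Id] {
			if v > vmax {
				// Offset the head-local index into the global class space.
				order = element.Range_start + i
				vmax = v
			}
		}
	}

	fmt.Printf("predicted class %d (confidence %.2f)\n", order, vmax)
	// prints: predicted class 4 (confidence 0.90)
}

Running the sketch prints class 4, since head-b's local index 1 offsets to 3 + 1; the patch does the same with vmax carried across heads so only the single highest-confidence class survives.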