feat: closes #57

This commit is contained in:
Andre Henriques 2024-02-19 12:00:30 +00:00
parent b5a28a0bdb
commit 4985c8aa14

View File

@@ -70,7 +70,7 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
base_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "base", "model"), []string{"serve"}, nil)
//results := base_model.Exec([]tf.Output{
base_results := base_model.Exec([]tf.Output{
base_model.Op("StatefulPartitionedCall", 0),
}, map[tf.Output]*tf.Tensor{
//base_model.Op("serving_default_rescaling_input", 0): inputImage,
@@ -87,6 +87,27 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
return
}
var vmax float32 = 0.0
for _, element := range heads {
head_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "head", element.Id, "model"), []string{"serve"}, nil)
results := head_model.Exec([]tf.Output{
head_model.Op("StatefulPartitionedCall", 0),
}, map[tf.Output]*tf.Tensor{
head_model.Op("serving_default_input_2", 0): base_results[0],
})
var predictions = results[0].Value().([][]float32)[0]
for i, v := range predictions {
if v > vmax {
order = element.Range_start + i
vmax = v
}
}
}
// TODO run the head model
c.Logger.Info("Got", "heads", len(heads))