feat: closes #57
This commit is contained in:
parent
b5a28a0bdb
commit
4985c8aa14
@ -70,7 +70,7 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
|
||||
base_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "base", "model"), []string{"serve"}, nil)
|
||||
|
||||
//results := base_model.Exec([]tf.Output{
|
||||
base_model.Exec([]tf.Output{
|
||||
base_results := base_model.Exec([]tf.Output{
|
||||
base_model.Op("StatefulPartitionedCall", 0),
|
||||
}, map[tf.Output]*tf.Tensor{
|
||||
//base_model.Op("serving_default_rescaling_input", 0): inputImage,
|
||||
@ -87,6 +87,27 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
|
||||
return
|
||||
}
|
||||
|
||||
var vmax float32 = 0.0
|
||||
|
||||
for _, element := range heads {
|
||||
head_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "head", element.Id, "model"), []string{"serve"}, nil)
|
||||
|
||||
results := head_model.Exec([]tf.Output{
|
||||
head_model.Op("StatefulPartitionedCall", 0),
|
||||
}, map[tf.Output]*tf.Tensor{
|
||||
head_model.Op("serving_default_input_2", 0): base_results[0],
|
||||
})
|
||||
|
||||
var predictions = results[0].Value().([][]float32)[0]
|
||||
|
||||
for i, v := range predictions {
|
||||
if v > vmax {
|
||||
order = element.Range_start + i
|
||||
vmax = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO runthe head model
|
||||
|
||||
c.Logger.Info("Got", "heads", len(heads))
|
||||
|
Loading…
Reference in New Issue
Block a user