chore: #22 Added code to run through
This commit is contained in:
parent
194277297f
commit
dfa62118de
@ -18,6 +18,7 @@ func HandleModels (handle *Handle) {
|
|||||||
model_classes.HandleList(handle)
|
model_classes.HandleList(handle)
|
||||||
|
|
||||||
// Train endpoints
|
// Train endpoints
|
||||||
|
handleRun(handle)
|
||||||
models_train.HandleTrainEndpoints(handle)
|
models_train.HandleTrainEndpoints(handle)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,19 +70,20 @@ func handleRun(handle *Handle) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
definitions_rows, err := handle.Db.Query("select id from model_definition where model_id=$1;", model.Id)
|
definitions_rows, err := handle.Db.Query("select id from model_definition where model_id=$1;", model.Id)
|
||||||
if !definitions_rows.Next() {
|
if err != nil {
|
||||||
// TODO improve this
|
return Error500(err)
|
||||||
return ErrorCode(nil, 400, c.AddMap(nil))
|
|
||||||
}
|
}
|
||||||
defer definitions_rows.Close()
|
defer definitions_rows.Close()
|
||||||
|
|
||||||
if !definitions_rows.Next() {
|
if !definitions_rows.Next() {
|
||||||
return Error500(nil)
|
// TODO improve this
|
||||||
|
fmt.Printf("Could not find definition\n")
|
||||||
|
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
var def_id string
|
var def_id string
|
||||||
if err = definitions_rows.Scan(&def_id); err != nil {
|
if err = definitions_rows.Scan(&def_id); err != nil {
|
||||||
return Error500(nil)
|
return Error500(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO create a database table with tasks
|
// TODO create a database table with tasks
|
||||||
@ -97,18 +98,21 @@ func handleRun(handle *Handle) {
|
|||||||
img_file.Write(file)
|
img_file.Write(file)
|
||||||
|
|
||||||
root := tg.NewRoot()
|
root := tg.NewRoot()
|
||||||
tf_img := image.Read(root, path.Join(run_path, "img.png"), 1)
|
tf_img := image.Read(root, path.Join(run_path, "img.png"), 3)
|
||||||
tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
|
|
||||||
|
|
||||||
tf_img_tensor, err := tf.NewTensor(tf_img.Value())
|
batch := tg.Batchify(root, []tf.Output{tf_img.Value()})
|
||||||
|
exec_results := tg.Exec(root, []tf.Output{batch}, nil, &tf.SessionOptions{})
|
||||||
|
inputImage, err:= tf.NewTensor(exec_results[0].Value())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Error500(err)
|
return Error500(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
|
||||||
|
|
||||||
results := tf_model.Exec([]tf.Output{
|
results := tf_model.Exec([]tf.Output{
|
||||||
tf_model.Op("StatefulPartitionedCall", 0),
|
tf_model.Op("StatefulPartitionedCall", 0),
|
||||||
}, map[tf.Output]*tf.Tensor{
|
}, map[tf.Output]*tf.Tensor{
|
||||||
tf_model.Op("serving_default_inputs_input", 0): tf_img_tensor,
|
tf_model.Op("serving_default_rescaling_input", 0): inputImage,
|
||||||
})
|
})
|
||||||
|
|
||||||
predictions := results[0]
|
predictions := results[0]
|
||||||
|
@ -303,7 +303,7 @@
|
|||||||
Image File
|
Image File
|
||||||
</span>
|
</span>
|
||||||
</button>
|
</button>
|
||||||
<input id="file" name="file" type="file" required accept="application/zip" />
|
<input id="file" name="file" type="file" required accept="image/png" />
|
||||||
</div>
|
</div>
|
||||||
</fieldset>
|
</fieldset>
|
||||||
<button>
|
<button>
|
||||||
|
Loading…
Reference in New Issue
Block a user