chore: #22 Added code to run through

This commit is contained in:
Andre Henriques 2023-09-28 12:16:36 +01:00
parent 194277297f
commit dfa62118de
4 changed files with 15 additions and 10 deletions

BIN
a.out Executable file

Binary file not shown.

View File

@ -18,6 +18,7 @@ func HandleModels (handle *Handle) {
model_classes.HandleList(handle) model_classes.HandleList(handle)
// Train endpoints // Train endpoints
handleRun(handle)
models_train.HandleTrainEndpoints(handle) models_train.HandleTrainEndpoints(handle)
} }

View File

@ -70,19 +70,20 @@ func handleRun(handle *Handle) {
} }
definitions_rows, err := handle.Db.Query("select id from model_definition where model_id=$1;", model.Id) definitions_rows, err := handle.Db.Query("select id from model_definition where model_id=$1;", model.Id)
if !definitions_rows.Next() { if err != nil {
// TODO improve this return Error500(err)
return ErrorCode(nil, 400, c.AddMap(nil))
} }
defer definitions_rows.Close() defer definitions_rows.Close()
if !definitions_rows.Next() { if !definitions_rows.Next() {
return Error500(nil) // TODO improve this
fmt.Printf("Could not find definition\n")
return ErrorCode(nil, 400, c.AddMap(nil))
} }
var def_id string var def_id string
if err = definitions_rows.Scan(&def_id); err != nil { if err = definitions_rows.Scan(&def_id); err != nil {
return Error500(nil) return Error500(err)
} }
// TODO create a database table with tasks // TODO create a database table with tasks
@ -97,18 +98,21 @@ func handleRun(handle *Handle) {
img_file.Write(file) img_file.Write(file)
root := tg.NewRoot() root := tg.NewRoot()
tf_img := image.Read(root, path.Join(run_path, "img.png"), 1) tf_img := image.Read(root, path.Join(run_path, "img.png"), 3)
tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
tf_img_tensor, err := tf.NewTensor(tf_img.Value()) batch := tg.Batchify(root, []tf.Output{tf_img.Value()})
exec_results := tg.Exec(root, []tf.Output{batch}, nil, &tf.SessionOptions{})
inputImage, err:= tf.NewTensor(exec_results[0].Value())
if err != nil { if err != nil {
return Error500(err) return Error500(err)
} }
tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
results := tf_model.Exec([]tf.Output{ results := tf_model.Exec([]tf.Output{
tf_model.Op("StatefulPartitionedCall", 0), tf_model.Op("StatefulPartitionedCall", 0),
}, map[tf.Output]*tf.Tensor{ }, map[tf.Output]*tf.Tensor{
tf_model.Op("serving_default_inputs_input", 0): tf_img_tensor, tf_model.Op("serving_default_rescaling_input", 0): inputImage,
}) })
predictions := results[0] predictions := results[0]

View File

@ -303,7 +303,7 @@
Image File Image File
</span> </span>
</button> </button>
<input id="file" name="file" type="file" required accept="application/zip" /> <input id="file" name="file" type="file" required accept="image/png" />
</div> </div>
</fieldset> </fieldset>
<button> <button>