Fixed the model not training and running forever
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
tf "github.com/galeone/tensorflow/tensorflow/go"
|
||||
"github.com/galeone/tensorflow/tensorflow/go/op"
|
||||
tg "github.com/galeone/tfgo"
|
||||
@@ -19,6 +20,7 @@ func ReadPNG(scope *op.Scope, imagePath string, channels int64) *image.Image {
|
||||
contents := op.ReadFile(scope.SubScope("ReadFile"), op.Const(scope.SubScope("filename"), imagePath))
|
||||
output := op.DecodePng(scope.SubScope("DecodePng"), contents, op.DecodePngChannels(channels))
|
||||
output = op.ExpandDims(scope.SubScope("ExpandDims"), output, op.Const(scope.SubScope("axis"), []int32{0}))
|
||||
output = op.ExpandDims(scope.SubScope("Stack"), output, op.Const(scope.SubScope("axis"), []int32{1}))
|
||||
image := &image.Image{
|
||||
Tensor: tg.NewTensor(scope, output)}
|
||||
return image.Scale(0, 255)
|
||||
@@ -29,6 +31,7 @@ func ReadJPG(scope *op.Scope, imagePath string, channels int64) *image.Image {
|
||||
contents := op.ReadFile(scope.SubScope("ReadFile"), op.Const(scope.SubScope("filename"), imagePath))
|
||||
output := op.DecodePng(scope.SubScope("DecodeJpeg"), contents, op.DecodePngChannels(channels))
|
||||
output = op.ExpandDims(scope.SubScope("ExpandDims"), output, op.Const(scope.SubScope("axis"), []int32{0}))
|
||||
output = op.ExpandDims(scope.SubScope("Stack"), output, op.Const(scope.SubScope("axis"), []int32{1}))
|
||||
image := &image.Image{
|
||||
Tensor: tg.NewTensor(scope, output)}
|
||||
return image.Scale(0, 255)
|
||||
@@ -49,6 +52,8 @@ func runModelNormal(base BasePack, model *BaseModel, def_id string, inputImage *
|
||||
var vmax float32 = 0.0
|
||||
var predictions = results[0].Value().([][]float32)[0]
|
||||
|
||||
log.Info("preds", "preds", predictions)
|
||||
|
||||
for i, v := range predictions {
|
||||
if v > vmax {
|
||||
order = i
|
||||
@@ -62,10 +67,13 @@ func runModelNormal(base BasePack, model *BaseModel, def_id string, inputImage *
|
||||
}
|
||||
|
||||
func runModelExp(base BasePack, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, confidence float32, err error) {
|
||||
log := base.GetLogger()
|
||||
|
||||
err = nil
|
||||
order = 0
|
||||
|
||||
log.Info("Running base")
|
||||
|
||||
base_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "base", "model"), []string{"serve"}, nil)
|
||||
|
||||
//results := base_model.Exec([]tf.Output{
|
||||
@@ -86,7 +94,7 @@ func runModelExp(base BasePack, model *BaseModel, def_id string, inputImage *tf.
|
||||
return
|
||||
}
|
||||
|
||||
base.GetLogger().Info("test", "count", len(heads))
|
||||
log.Info("Running heads", "heads", heads)
|
||||
|
||||
var vmax float32 = 0.0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user