feat: started working on the head part of the split models
This commit is contained in:
@@ -37,6 +37,62 @@ func ReadJPG(scope *op.Scope, imagePath string, channels int64) *image.Image {
|
||||
return image.Scale(0, 255)
|
||||
}
|
||||
|
||||
func runModelNormal(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, err error) {
|
||||
order = 0
|
||||
err = nil
|
||||
|
||||
tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
|
||||
|
||||
results := tf_model.Exec([]tf.Output{
|
||||
tf_model.Op("StatefulPartitionedCall", 0),
|
||||
}, map[tf.Output]*tf.Tensor{
|
||||
tf_model.Op("serving_default_rescaling_input", 0): inputImage,
|
||||
})
|
||||
|
||||
var vmax float32 = 0.0
|
||||
var predictions = results[0].Value().([][]float32)[0]
|
||||
|
||||
for i, v := range predictions {
|
||||
if v > vmax {
|
||||
order = i
|
||||
vmax = v
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, err error) {
|
||||
|
||||
err = nil
|
||||
order = 0
|
||||
|
||||
base_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "base", "model"), []string{"serve"}, nil)
|
||||
|
||||
//results := base_model.Exec([]tf.Output{
|
||||
base_model.Exec([]tf.Output{
|
||||
base_model.Op("StatefulPartitionedCall", 0),
|
||||
}, map[tf.Output]*tf.Tensor{
|
||||
//base_model.Op("serving_default_rescaling_input", 0): inputImage,
|
||||
base_model.Op("serving_default_input_1", 0): inputImage,
|
||||
})
|
||||
|
||||
type head struct {
|
||||
Id string
|
||||
Range_start int
|
||||
}
|
||||
|
||||
heads, err := GetDbMultitple[head](c, "exp_model_head where def_id=$1;", def_id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO runthe head model
|
||||
|
||||
c.Logger.Info("Got", "heads", len(heads))
|
||||
return
|
||||
}
|
||||
|
||||
func handleRun(handle *Handle) {
|
||||
handle.Post("/models/run", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||||
if !CheckAuthLevel(1, w, r, c) {
|
||||
@@ -90,27 +146,22 @@ func handleRun(handle *Handle) {
|
||||
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||
}
|
||||
|
||||
definitions_rows, err := handle.Db.Query("select id from model_definition where model_id=$1;", model.Id)
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
defer definitions_rows.Close()
|
||||
|
||||
if !definitions_rows.Next() {
|
||||
def := JustId{}
|
||||
err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id)
|
||||
if err == NotFoundError {
|
||||
// TODO improve this
|
||||
fmt.Printf("Could not find definition\n")
|
||||
fmt.Printf("Could not find definition\n")
|
||||
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||
}
|
||||
|
||||
var def_id string
|
||||
if err = definitions_rows.Scan(&def_id); err != nil {
|
||||
} else if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
def_id := def.Id
|
||||
|
||||
// TODO create a database table with tasks
|
||||
run_path := path.Join("/tmp", model.Id, "runs")
|
||||
os.MkdirAll(run_path, os.ModePerm)
|
||||
img_path := path.Join(run_path, "img." + model.Format)
|
||||
img_path := path.Join(run_path, "img."+model.Format)
|
||||
|
||||
img_file, err := os.Create(img_path)
|
||||
if err != nil {
|
||||
@@ -119,74 +170,75 @@ func handleRun(handle *Handle) {
|
||||
defer img_file.Close()
|
||||
img_file.Write(file)
|
||||
|
||||
if !testImgForModel(c, model, img_path) {
|
||||
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
|
||||
"Model": model,
|
||||
"NotFound": false,
|
||||
"Result": nil,
|
||||
"ImageError": true,
|
||||
}))
|
||||
return nil
|
||||
}
|
||||
if !testImgForModel(c, model, img_path) {
|
||||
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
|
||||
"Model": model,
|
||||
"NotFound": false,
|
||||
"Result": nil,
|
||||
"ImageError": true,
|
||||
}))
|
||||
return nil
|
||||
}
|
||||
|
||||
root := tg.NewRoot()
|
||||
|
||||
var tf_img *image.Image = nil
|
||||
|
||||
switch model.Format {
|
||||
case "png":
|
||||
tf_img = ReadPNG(root, img_path, int64(model.ImageMode))
|
||||
case "jpeg":
|
||||
tf_img = ReadJPG(root, img_path, int64(model.ImageMode))
|
||||
default:
|
||||
panic("Not sure what to do with '" + model.Format + "'")
|
||||
}
|
||||
var tf_img *image.Image = nil
|
||||
|
||||
exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{})
|
||||
inputImage, err:= tf.NewTensor(exec_results[0].Value())
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
switch model.Format {
|
||||
case "png":
|
||||
tf_img = ReadPNG(root, img_path, int64(model.ImageMode))
|
||||
case "jpeg":
|
||||
tf_img = ReadJPG(root, img_path, int64(model.ImageMode))
|
||||
default:
|
||||
panic("Not sure what to do with '" + model.Format + "'")
|
||||
}
|
||||
|
||||
tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
|
||||
exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{})
|
||||
inputImage, err := tf.NewTensor(exec_results[0].Value())
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
results := tf_model.Exec([]tf.Output{
|
||||
tf_model.Op("StatefulPartitionedCall", 0),
|
||||
}, map[tf.Output]*tf.Tensor{
|
||||
tf_model.Op("serving_default_rescaling_input", 0): inputImage,
|
||||
})
|
||||
vi := -1
|
||||
|
||||
var vmax float32 = 0.0
|
||||
vi := 0
|
||||
var predictions = results[0].Value().([][]float32)[0]
|
||||
|
||||
for i, v := range predictions {
|
||||
if v > vmax {
|
||||
vi = i
|
||||
vmax = v
|
||||
if model.ModelType == 2 {
|
||||
c.Logger.Info("Running model normal", "model", model.Id, "def", def_id)
|
||||
vi, err = runModelExp(c, model, def_id, inputImage)
|
||||
if err != nil {
|
||||
return c.Error500(err);
|
||||
}
|
||||
} else {
|
||||
c.Logger.Info("Running model normal", "model", model.Id, "def", def_id)
|
||||
vi, err = runModelNormal(c, model, def_id, inputImage)
|
||||
if err != nil {
|
||||
return c.Error500(err);
|
||||
}
|
||||
}
|
||||
|
||||
os.RemoveAll(run_path)
|
||||
os.RemoveAll(run_path)
|
||||
|
||||
rows, err := handle.Db.Query("select name from model_classes where model_id=$1 and class_order=$2;", model.Id, vi)
|
||||
if err != nil { return Error500(err) }
|
||||
if !rows.Next() {
|
||||
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
|
||||
"Model": model,
|
||||
"NotFound": true,
|
||||
"Result": nil,
|
||||
}))
|
||||
return nil
|
||||
}
|
||||
rows, err := handle.Db.Query("select name from model_classes where model_id=$1 and class_order=$2;", model.Id, vi)
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
if !rows.Next() {
|
||||
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
|
||||
"Model": model,
|
||||
"NotFound": true,
|
||||
"Result": nil,
|
||||
}))
|
||||
return nil
|
||||
}
|
||||
|
||||
var name string
|
||||
if err = rows.Scan(&name); err != nil { return nil }
|
||||
var name string
|
||||
if err = rows.Scan(&name); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
|
||||
"Model": model,
|
||||
"Result": name,
|
||||
}))
|
||||
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
|
||||
"Model": model,
|
||||
"Result": name,
|
||||
}))
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user