feat: started working on the head part of the split models
This commit is contained in:
parent
508d43bc2f
commit
b5a28a0bdb
@ -37,6 +37,62 @@ func ReadJPG(scope *op.Scope, imagePath string, channels int64) *image.Image {
|
|||||||
return image.Scale(0, 255)
|
return image.Scale(0, 255)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func runModelNormal(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, err error) {
|
||||||
|
order = 0
|
||||||
|
err = nil
|
||||||
|
|
||||||
|
tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
|
||||||
|
|
||||||
|
results := tf_model.Exec([]tf.Output{
|
||||||
|
tf_model.Op("StatefulPartitionedCall", 0),
|
||||||
|
}, map[tf.Output]*tf.Tensor{
|
||||||
|
tf_model.Op("serving_default_rescaling_input", 0): inputImage,
|
||||||
|
})
|
||||||
|
|
||||||
|
var vmax float32 = 0.0
|
||||||
|
var predictions = results[0].Value().([][]float32)[0]
|
||||||
|
|
||||||
|
for i, v := range predictions {
|
||||||
|
if v > vmax {
|
||||||
|
order = i
|
||||||
|
vmax = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, err error) {
|
||||||
|
|
||||||
|
err = nil
|
||||||
|
order = 0
|
||||||
|
|
||||||
|
base_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "base", "model"), []string{"serve"}, nil)
|
||||||
|
|
||||||
|
//results := base_model.Exec([]tf.Output{
|
||||||
|
base_model.Exec([]tf.Output{
|
||||||
|
base_model.Op("StatefulPartitionedCall", 0),
|
||||||
|
}, map[tf.Output]*tf.Tensor{
|
||||||
|
//base_model.Op("serving_default_rescaling_input", 0): inputImage,
|
||||||
|
base_model.Op("serving_default_input_1", 0): inputImage,
|
||||||
|
})
|
||||||
|
|
||||||
|
type head struct {
|
||||||
|
Id string
|
||||||
|
Range_start int
|
||||||
|
}
|
||||||
|
|
||||||
|
heads, err := GetDbMultitple[head](c, "exp_model_head where def_id=$1;", def_id)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO runthe head model
|
||||||
|
|
||||||
|
c.Logger.Info("Got", "heads", len(heads))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func handleRun(handle *Handle) {
|
func handleRun(handle *Handle) {
|
||||||
handle.Post("/models/run", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
handle.Post("/models/run", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||||||
if !CheckAuthLevel(1, w, r, c) {
|
if !CheckAuthLevel(1, w, r, c) {
|
||||||
@ -90,27 +146,22 @@ func handleRun(handle *Handle) {
|
|||||||
return ErrorCode(nil, 400, c.AddMap(nil))
|
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
definitions_rows, err := handle.Db.Query("select id from model_definition where model_id=$1;", model.Id)
|
def := JustId{}
|
||||||
if err != nil {
|
err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id)
|
||||||
return Error500(err)
|
if err == NotFoundError {
|
||||||
}
|
|
||||||
defer definitions_rows.Close()
|
|
||||||
|
|
||||||
if !definitions_rows.Next() {
|
|
||||||
// TODO improve this
|
// TODO improve this
|
||||||
fmt.Printf("Could not find definition\n")
|
fmt.Printf("Could not find definition\n")
|
||||||
return ErrorCode(nil, 400, c.AddMap(nil))
|
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||||
}
|
} else if err != nil {
|
||||||
|
|
||||||
var def_id string
|
|
||||||
if err = definitions_rows.Scan(&def_id); err != nil {
|
|
||||||
return Error500(err)
|
return Error500(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def_id := def.Id
|
||||||
|
|
||||||
// TODO create a database table with tasks
|
// TODO create a database table with tasks
|
||||||
run_path := path.Join("/tmp", model.Id, "runs")
|
run_path := path.Join("/tmp", model.Id, "runs")
|
||||||
os.MkdirAll(run_path, os.ModePerm)
|
os.MkdirAll(run_path, os.ModePerm)
|
||||||
img_path := path.Join(run_path, "img." + model.Format)
|
img_path := path.Join(run_path, "img."+model.Format)
|
||||||
|
|
||||||
img_file, err := os.Create(img_path)
|
img_file, err := os.Create(img_path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -119,74 +170,75 @@ func handleRun(handle *Handle) {
|
|||||||
defer img_file.Close()
|
defer img_file.Close()
|
||||||
img_file.Write(file)
|
img_file.Write(file)
|
||||||
|
|
||||||
if !testImgForModel(c, model, img_path) {
|
if !testImgForModel(c, model, img_path) {
|
||||||
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
|
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
|
||||||
"Model": model,
|
"Model": model,
|
||||||
"NotFound": false,
|
"NotFound": false,
|
||||||
"Result": nil,
|
"Result": nil,
|
||||||
"ImageError": true,
|
"ImageError": true,
|
||||||
}))
|
}))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
root := tg.NewRoot()
|
root := tg.NewRoot()
|
||||||
|
|
||||||
var tf_img *image.Image = nil
|
var tf_img *image.Image = nil
|
||||||
|
|
||||||
switch model.Format {
|
|
||||||
case "png":
|
|
||||||
tf_img = ReadPNG(root, img_path, int64(model.ImageMode))
|
|
||||||
case "jpeg":
|
|
||||||
tf_img = ReadJPG(root, img_path, int64(model.ImageMode))
|
|
||||||
default:
|
|
||||||
panic("Not sure what to do with '" + model.Format + "'")
|
|
||||||
}
|
|
||||||
|
|
||||||
exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{})
|
switch model.Format {
|
||||||
inputImage, err:= tf.NewTensor(exec_results[0].Value())
|
case "png":
|
||||||
if err != nil {
|
tf_img = ReadPNG(root, img_path, int64(model.ImageMode))
|
||||||
return Error500(err)
|
case "jpeg":
|
||||||
}
|
tf_img = ReadJPG(root, img_path, int64(model.ImageMode))
|
||||||
|
default:
|
||||||
|
panic("Not sure what to do with '" + model.Format + "'")
|
||||||
|
}
|
||||||
|
|
||||||
tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
|
exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{})
|
||||||
|
inputImage, err := tf.NewTensor(exec_results[0].Value())
|
||||||
|
if err != nil {
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
results := tf_model.Exec([]tf.Output{
|
vi := -1
|
||||||
tf_model.Op("StatefulPartitionedCall", 0),
|
|
||||||
}, map[tf.Output]*tf.Tensor{
|
|
||||||
tf_model.Op("serving_default_rescaling_input", 0): inputImage,
|
|
||||||
})
|
|
||||||
|
|
||||||
var vmax float32 = 0.0
|
if model.ModelType == 2 {
|
||||||
vi := 0
|
c.Logger.Info("Running model normal", "model", model.Id, "def", def_id)
|
||||||
var predictions = results[0].Value().([][]float32)[0]
|
vi, err = runModelExp(c, model, def_id, inputImage)
|
||||||
|
if err != nil {
|
||||||
for i, v := range predictions {
|
return c.Error500(err);
|
||||||
if v > vmax {
|
}
|
||||||
vi = i
|
} else {
|
||||||
vmax = v
|
c.Logger.Info("Running model normal", "model", model.Id, "def", def_id)
|
||||||
|
vi, err = runModelNormal(c, model, def_id, inputImage)
|
||||||
|
if err != nil {
|
||||||
|
return c.Error500(err);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
os.RemoveAll(run_path)
|
os.RemoveAll(run_path)
|
||||||
|
|
||||||
rows, err := handle.Db.Query("select name from model_classes where model_id=$1 and class_order=$2;", model.Id, vi)
|
rows, err := handle.Db.Query("select name from model_classes where model_id=$1 and class_order=$2;", model.Id, vi)
|
||||||
if err != nil { return Error500(err) }
|
if err != nil {
|
||||||
if !rows.Next() {
|
return Error500(err)
|
||||||
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
|
}
|
||||||
"Model": model,
|
if !rows.Next() {
|
||||||
"NotFound": true,
|
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
|
||||||
"Result": nil,
|
"Model": model,
|
||||||
}))
|
"NotFound": true,
|
||||||
return nil
|
"Result": nil,
|
||||||
}
|
}))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var name string
|
var name string
|
||||||
if err = rows.Scan(&name); err != nil { return nil }
|
if err = rows.Scan(&name); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
|
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
|
||||||
"Model": model,
|
"Model": model,
|
||||||
"Result": name,
|
"Result": name,
|
||||||
}))
|
}))
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -784,23 +784,31 @@ func trainModelExp(c *Context, model *BaseModel) {
|
|||||||
failed("Failed to split the model")
|
failed("Failed to split the model")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// There should only be one def availabale
|
||||||
|
def := JustId{}
|
||||||
|
|
||||||
|
if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the base model
|
||||||
|
c.Logger.Warn("Removing base model for", "model", model.Id, "def", def.Id)
|
||||||
|
os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model"))
|
||||||
|
os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model.keras"))
|
||||||
|
|
||||||
ModelUpdateStatus(c, model.Id, READY)
|
ModelUpdateStatus(c, model.Id, READY)
|
||||||
}
|
}
|
||||||
|
|
||||||
func splitModel(c *Context, model *BaseModel) (err error) {
|
func splitModel(c *Context, model *BaseModel) (err error) {
|
||||||
|
|
||||||
type Def struct {
|
def := JustId{}
|
||||||
Id string
|
|
||||||
}
|
|
||||||
|
|
||||||
def := Def{}
|
|
||||||
|
|
||||||
if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
|
if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
head := Def{}
|
head := JustId{}
|
||||||
|
|
||||||
if err = GetDBOnce(c, &head, "exp_model_head where def_id=$1", def.Id); err != nil {
|
if err = GetDBOnce(c, &head, "exp_model_head where def_id=$1", def.Id); err != nil {
|
||||||
return
|
return
|
||||||
@ -887,8 +895,6 @@ func splitModel(c *Context, model *BaseModel) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
func removeFailedDataPoints(c *Context, model *BaseModel) (err error) {
|
func removeFailedDataPoints(c *Context, model *BaseModel) (err error) {
|
||||||
rows, err := c.Db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id=mdp.class_id where mc.model_id=$1 and mdp.status=-1;", model.Id)
|
rows, err := c.Db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id=mdp.class_id where mc.model_id=$1 and mdp.status=-1;", model.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -6,10 +6,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type BaseModel struct {
|
type BaseModel struct {
|
||||||
Name string
|
Name string
|
||||||
Status int
|
Status int
|
||||||
Id string
|
Id string
|
||||||
|
|
||||||
|
ModelType int
|
||||||
ImageMode int
|
ImageMode int
|
||||||
Width int
|
Width int
|
||||||
Height int
|
Height int
|
||||||
@ -54,7 +55,7 @@ const (
|
|||||||
var ModelNotFoundError = errors.New("Model not found error")
|
var ModelNotFoundError = errors.New("Model not found error")
|
||||||
|
|
||||||
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
|
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
|
||||||
rows, err := db.Query("select name, status, id, width, height, color_mode, format from models where id=$1;", id)
|
rows, err := db.Query("select name, status, id, width, height, color_mode, format, model_type from models where id=$1;", id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -66,7 +67,7 @@ func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
|
|||||||
|
|
||||||
base = &BaseModel{}
|
base = &BaseModel{}
|
||||||
var colorMode string
|
var colorMode string
|
||||||
err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height, &colorMode, &base.Format)
|
err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height, &colorMode, &base.Format, &base.ModelType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -183,14 +184,64 @@ func MyParseForm(r *http.Request) (vs url.Values, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type JustId struct { Id string }
|
||||||
|
|
||||||
type Generic struct{ reflect.Type }
|
type Generic struct{ reflect.Type }
|
||||||
|
|
||||||
var NotFoundError = errors.New("Not found")
|
var NotFoundError = errors.New("Not found")
|
||||||
|
|
||||||
|
func GetDbMultitple[T interface{}](c *Context, tablename string, args ...any) ([]*T, error) {
|
||||||
|
t := reflect.TypeFor[T]()
|
||||||
|
nargs := t.NumField()
|
||||||
|
|
||||||
|
query := ""
|
||||||
|
|
||||||
|
for i := 0; i < nargs; i += 1 {
|
||||||
|
query += strings.ToLower(t.Field(i).Name) + ","
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the last comma
|
||||||
|
query = query[0 : len(query)-1]
|
||||||
|
|
||||||
|
rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
list := []*T{}
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
item := new(T)
|
||||||
|
if err = mapRow(item, rows, nargs); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
list = append(list, item)
|
||||||
|
}
|
||||||
|
|
||||||
|
return list, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) {
|
||||||
|
err = nil
|
||||||
|
|
||||||
|
val := reflect.Indirect(reflect.ValueOf(store))
|
||||||
|
scan_args := make([]interface{}, nargs);
|
||||||
|
for i := 0; i < nargs; i++ {
|
||||||
|
valueField := val.Field(i)
|
||||||
|
scan_args[i] = valueField.Addr().Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rows.Scan(scan_args...)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) error {
|
func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) error {
|
||||||
|
|
||||||
t := reflect.TypeOf(store).Elem()
|
t := reflect.TypeOf(store).Elem()
|
||||||
|
|
||||||
nargs := t.NumField()
|
nargs := t.NumField()
|
||||||
|
|
||||||
query := ""
|
query := ""
|
||||||
@ -199,10 +250,10 @@ func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) err
|
|||||||
query += strings.ToLower(t.Field(i).Name) + ","
|
query += strings.ToLower(t.Field(i).Name) + ","
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Remove the last comma
|
||||||
query = query[0 : len(query)-1]
|
query = query[0 : len(query)-1]
|
||||||
|
|
||||||
rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
|
rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -212,6 +263,7 @@ func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) err
|
|||||||
return NotFoundError
|
return NotFoundError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = nil
|
||||||
|
|
||||||
val := reflect.ValueOf(store).Elem()
|
val := reflect.ValueOf(store).Elem()
|
||||||
scan_args := make([]interface{}, nargs);
|
scan_args := make([]interface{}, nargs);
|
||||||
@ -227,3 +279,4 @@ func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) err
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user