feat: started working on the head part of the split models

This commit is contained in:
Andre Henriques 2024-02-14 15:11:45 +00:00
parent 508d43bc2f
commit b5a28a0bdb
4 changed files with 196 additions and 84 deletions

View File

@ -37,6 +37,62 @@ func ReadJPG(scope *op.Scope, imagePath string, channels int64) *image.Image {
return image.Scale(0, 255) return image.Scale(0, 255)
} }
func runModelNormal(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, err error) {
order = 0
err = nil
tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
results := tf_model.Exec([]tf.Output{
tf_model.Op("StatefulPartitionedCall", 0),
}, map[tf.Output]*tf.Tensor{
tf_model.Op("serving_default_rescaling_input", 0): inputImage,
})
var vmax float32 = 0.0
var predictions = results[0].Value().([][]float32)[0]
for i, v := range predictions {
if v > vmax {
order = i
vmax = v
}
}
return
}
func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, err error) {
err = nil
order = 0
base_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "base", "model"), []string{"serve"}, nil)
//results := base_model.Exec([]tf.Output{
base_model.Exec([]tf.Output{
base_model.Op("StatefulPartitionedCall", 0),
}, map[tf.Output]*tf.Tensor{
//base_model.Op("serving_default_rescaling_input", 0): inputImage,
base_model.Op("serving_default_input_1", 0): inputImage,
})
type head struct {
Id string
Range_start int
}
heads, err := GetDbMultitple[head](c, "exp_model_head where def_id=$1;", def_id)
if err != nil {
return
}
// TODO runthe head model
c.Logger.Info("Got", "heads", len(heads))
return
}
func handleRun(handle *Handle) { func handleRun(handle *Handle) {
handle.Post("/models/run", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { handle.Post("/models/run", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
if !CheckAuthLevel(1, w, r, c) { if !CheckAuthLevel(1, w, r, c) {
@ -90,23 +146,18 @@ func handleRun(handle *Handle) {
return ErrorCode(nil, 400, c.AddMap(nil)) return ErrorCode(nil, 400, c.AddMap(nil))
} }
definitions_rows, err := handle.Db.Query("select id from model_definition where model_id=$1;", model.Id) def := JustId{}
if err != nil { err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id)
return Error500(err) if err == NotFoundError {
}
defer definitions_rows.Close()
if !definitions_rows.Next() {
// TODO improve this // TODO improve this
fmt.Printf("Could not find definition\n") fmt.Printf("Could not find definition\n")
return ErrorCode(nil, 400, c.AddMap(nil)) return ErrorCode(nil, 400, c.AddMap(nil))
} } else if err != nil {
var def_id string
if err = definitions_rows.Scan(&def_id); err != nil {
return Error500(err) return Error500(err)
} }
def_id := def.Id
// TODO create a database table with tasks // TODO create a database table with tasks
run_path := path.Join("/tmp", model.Id, "runs") run_path := path.Join("/tmp", model.Id, "runs")
os.MkdirAll(run_path, os.ModePerm) os.MkdirAll(run_path, os.ModePerm)
@ -148,29 +199,28 @@ func handleRun(handle *Handle) {
return Error500(err) return Error500(err)
} }
tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil) vi := -1
results := tf_model.Exec([]tf.Output{ if model.ModelType == 2 {
tf_model.Op("StatefulPartitionedCall", 0), c.Logger.Info("Running model normal", "model", model.Id, "def", def_id)
}, map[tf.Output]*tf.Tensor{ vi, err = runModelExp(c, model, def_id, inputImage)
tf_model.Op("serving_default_rescaling_input", 0): inputImage, if err != nil {
}) return c.Error500(err);
}
var vmax float32 = 0.0 } else {
vi := 0 c.Logger.Info("Running model normal", "model", model.Id, "def", def_id)
var predictions = results[0].Value().([][]float32)[0] vi, err = runModelNormal(c, model, def_id, inputImage)
if err != nil {
for i, v := range predictions { return c.Error500(err);
if v > vmax {
vi = i
vmax = v
} }
} }
os.RemoveAll(run_path) os.RemoveAll(run_path)
rows, err := handle.Db.Query("select name from model_classes where model_id=$1 and class_order=$2;", model.Id, vi) rows, err := handle.Db.Query("select name from model_classes where model_id=$1 and class_order=$2;", model.Id, vi)
if err != nil { return Error500(err) } if err != nil {
return Error500(err)
}
if !rows.Next() { if !rows.Next() {
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{ LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
"Model": model, "Model": model,
@ -181,7 +231,9 @@ func handleRun(handle *Handle) {
} }
var name string var name string
if err = rows.Scan(&name); err != nil { return nil } if err = rows.Scan(&name); err != nil {
return nil
}
LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{ LoadDefineTemplate(w, "/models/edit.html", "run-model-card", c.AddMap(AnyMap{
"Model": model, "Model": model,

View File

@ -785,22 +785,30 @@ func trainModelExp(c *Context, model *BaseModel) {
return return
} }
ModelUpdateStatus(c, model.Id, READY) // There should only be one def availabale
} def := JustId{}
func splitModel(c *Context, model *BaseModel) (err error) {
type Def struct {
Id string
}
def := Def{}
if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil { if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
return return
} }
head := Def{} // Remove the base model
c.Logger.Warn("Removing base model for", "model", model.Id, "def", def.Id)
os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model"))
os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model.keras"))
ModelUpdateStatus(c, model.Id, READY)
}
func splitModel(c *Context, model *BaseModel) (err error) {
def := JustId{}
if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
return
}
head := JustId{}
if err = GetDBOnce(c, &head, "exp_model_head where def_id=$1", def.Id); err != nil { if err = GetDBOnce(c, &head, "exp_model_head where def_id=$1", def.Id); err != nil {
return return
@ -887,8 +895,6 @@ func splitModel(c *Context, model *BaseModel) (err error) {
return return
} }
func removeFailedDataPoints(c *Context, model *BaseModel) (err error) { func removeFailedDataPoints(c *Context, model *BaseModel) (err error) {
rows, err := c.Db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id=mdp.class_id where mc.model_id=$1 and mdp.status=-1;", model.Id) rows, err := c.Db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id=mdp.class_id where mc.model_id=$1 and mdp.status=-1;", model.Id)
if err != nil { if err != nil {

View File

@ -10,6 +10,7 @@ type BaseModel struct {
Status int Status int
Id string Id string
ModelType int
ImageMode int ImageMode int
Width int Width int
Height int Height int
@ -54,7 +55,7 @@ const (
var ModelNotFoundError = errors.New("Model not found error") var ModelNotFoundError = errors.New("Model not found error")
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) { func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
rows, err := db.Query("select name, status, id, width, height, color_mode, format from models where id=$1;", id) rows, err := db.Query("select name, status, id, width, height, color_mode, format, model_type from models where id=$1;", id)
if err != nil { if err != nil {
return return
} }
@ -66,7 +67,7 @@ func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
base = &BaseModel{} base = &BaseModel{}
var colorMode string var colorMode string
err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height, &colorMode, &base.Format) err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height, &colorMode, &base.Format, &base.ModelType)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,6 +1,7 @@
package utils package utils
import ( import (
"database/sql"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -183,14 +184,14 @@ func MyParseForm(r *http.Request) (vs url.Values, err error) {
return return
} }
type JustId struct { Id string }
type Generic struct{ reflect.Type } type Generic struct{ reflect.Type }
var NotFoundError = errors.New("Not found") var NotFoundError = errors.New("Not found")
func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) error { func GetDbMultitple[T interface{}](c *Context, tablename string, args ...any) ([]*T, error) {
t := reflect.TypeFor[T]()
t := reflect.TypeOf(store).Elem()
nargs := t.NumField() nargs := t.NumField()
query := "" query := ""
@ -199,10 +200,60 @@ func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) err
query += strings.ToLower(t.Field(i).Name) + "," query += strings.ToLower(t.Field(i).Name) + ","
} }
// Remove the last comma
query = query[0 : len(query)-1] query = query[0 : len(query)-1]
rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...) rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := []*T{}
for rows.Next() {
item := new(T)
if err = mapRow(item, rows, nargs); err != nil {
return nil, err
}
list = append(list, item)
}
return list, nil
}
func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) {
err = nil
val := reflect.Indirect(reflect.ValueOf(store))
scan_args := make([]interface{}, nargs);
for i := 0; i < nargs; i++ {
valueField := val.Field(i)
scan_args[i] = valueField.Addr().Interface()
}
err = rows.Scan(scan_args...)
if err != nil {
return
}
return nil
}
func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) error {
t := reflect.TypeOf(store).Elem()
nargs := t.NumField()
query := ""
for i := 0; i < nargs; i += 1 {
query += strings.ToLower(t.Field(i).Name) + ","
}
// Remove the last comma
query = query[0 : len(query)-1]
rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
if err != nil { if err != nil {
return err return err
} }
@ -212,6 +263,7 @@ func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) err
return NotFoundError return NotFoundError
} }
err = nil
val := reflect.ValueOf(store).Elem() val := reflect.ValueOf(store).Elem()
scan_args := make([]interface{}, nargs); scan_args := make([]interface{}, nargs);
@ -227,3 +279,4 @@ func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) err
return nil return nil
} }