chore: work on the expandable models
This commit is contained in:
parent
6a0ac457d7
commit
bca44f9ba5
@ -2,4 +2,6 @@
|
||||
|
||||
cd $(dirname "$0")
|
||||
|
||||
go run .
|
||||
go run . || true
|
||||
|
||||
while true; true; end
|
||||
|
@ -24,7 +24,7 @@ func handleEdit(handle *Handle) {
|
||||
}
|
||||
|
||||
// TODO handle admin users
|
||||
rows, err := handle.Db.Query("select name, status, width, height, color_mode, format from models where id=$1 and user_id=$2;", id, c.User.Id)
|
||||
rows, err := handle.Db.Query("select name, status, width, height, color_mode, format, model_type from models where id=$1 and user_id=$2;", id, c.User.Id)
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
@ -45,12 +45,13 @@ func handleEdit(handle *Handle) {
|
||||
Height *int
|
||||
Color_mode *string
|
||||
Format string
|
||||
Type int
|
||||
}
|
||||
|
||||
var model rowmodel = rowmodel{}
|
||||
model.Id = id
|
||||
|
||||
err = rows.Scan(&model.Name, &model.Status, &model.Width, &model.Height, &model.Color_mode, &model.Format)
|
||||
err = rows.Scan(&model.Name, &model.Status, &model.Width, &model.Height, &model.Color_mode, &model.Format, &model.Type)
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
@ -124,6 +125,7 @@ func handleEdit(handle *Handle) {
|
||||
}
|
||||
|
||||
type layerdef struct {
|
||||
id string
|
||||
LayerType int
|
||||
Shape string
|
||||
}
|
||||
@ -132,7 +134,7 @@ func handleEdit(handle *Handle) {
|
||||
|
||||
for _, def := range defs {
|
||||
if def.Status == MODEL_DEFINITION_STATUS_TRAINING {
|
||||
rows, err := c.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", def.Id)
|
||||
rows, err := c.Db.Query("select id, layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", def.Id)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
@ -140,12 +142,63 @@ func handleEdit(handle *Handle) {
|
||||
|
||||
for rows.Next() {
|
||||
var layerdef layerdef
|
||||
err = rows.Scan(&layerdef.LayerType, &layerdef.Shape)
|
||||
err = rows.Scan(&layerdef.id, &layerdef.LayerType, &layerdef.Shape)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
layers = append(layers, layerdef)
|
||||
}
|
||||
|
||||
if model.Type == 2 {
|
||||
|
||||
type lastLayerType struct {
|
||||
Id string
|
||||
Range_start int
|
||||
Range_end int
|
||||
}
|
||||
|
||||
var lastLayer lastLayerType
|
||||
|
||||
err := GetDBOnce(c, &lastLayer, "exp_model_head where def_id=$1 and status=3;", def.Id)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
c.Logger.Info("res", "id", lastLayer.Id, "start", lastLayer.Range_start, "end", lastLayer.Range_end)
|
||||
|
||||
layers = append(layers, layerdef{
|
||||
id: lastLayer.Id,
|
||||
LayerType: LAYER_DENSE,
|
||||
Shape: fmt.Sprintf("%d, 1", lastLayer.Range_end-lastLayer.Range_start),
|
||||
})
|
||||
|
||||
/*
|
||||
lastLayer, err := c.Db.Query("select id, range_start, range_end from exp_model_head where def_id=$1 and status=3;", def.Id)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
defer lastLayer.Close()
|
||||
|
||||
if !lastLayer.Next() {
|
||||
c.Logger.Info("Could not find the model head for", "def_id", def.Id)
|
||||
continue
|
||||
}
|
||||
|
||||
head_id, range_start, range_end := "", 0, 0
|
||||
err = lastLayer.Scan(&head_id, &range_start, &range_end)
|
||||
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
layers = append(layers, layerdef{
|
||||
id: head_id,
|
||||
LayerType: LAYER_DENSE,
|
||||
Shape: fmt.Sprintf("%d, 1", range_end-range_start),
|
||||
})
|
||||
*/
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -216,7 +216,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
||||
// Get untrained models heads
|
||||
|
||||
// Status = 2 (INIT)
|
||||
rows, err := c.Db.Query("select id, range_start, range_end from exp_model_head where def_id=$1 and status = 2", definition_id)
|
||||
rows, err := c.Db.Query("select id, range_start, range_end from exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -231,7 +231,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
||||
exp := ExpHead{}
|
||||
|
||||
if rows.Next() {
|
||||
if err = rows.Scan(&exp.id, &exp.start, &exp.end); err == nil {
|
||||
if err = rows.Scan(&exp.id, &exp.start, &exp.end); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
@ -246,7 +246,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
||||
return
|
||||
}
|
||||
|
||||
UpdateStatus(c, "exp_model_head", exp.id, MODEL_DEFINITION_STATUS_TRANIED)
|
||||
UpdateStatus(c, "exp_model_head", exp.id, MODEL_DEFINITION_STATUS_TRAINING)
|
||||
|
||||
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
||||
if err != nil {
|
||||
@ -949,7 +949,7 @@ func generateDefinitions(c *Context, model *BaseModel, target_accuracy int, numb
|
||||
}
|
||||
|
||||
func CreateExpModelHead(c *Context, def_id string, range_start int, range_end int, status ModelDefinitionStatus) (id string, err error) {
|
||||
rows, err := c.Db.Query("insert into exp_model_head (def_id, range_start, range_end) values ($1, $2, $3, $4) returning id", def_id, range_start, range_end, status)
|
||||
rows, err := c.Db.Query("insert into exp_model_head (def_id, range_start, range_end, status) values ($1, $2, $3, $4) returning id", def_id, range_start, range_end, status)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
@ -977,6 +977,8 @@ func ExpModelHeadUpdateStatus(db *sql.DB, id string, status ModelDefinitionStatu
|
||||
|
||||
// This generates a definition
|
||||
func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy int, number_of_classes int, complexity int) *Error {
|
||||
c.Logger.Info("Generating expandable new definition for model", "id", model.Id, "complexity", complexity)
|
||||
|
||||
var err error = nil
|
||||
failed := func() *Error {
|
||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
||||
@ -1018,9 +1020,14 @@ func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy
|
||||
|
||||
// Create the blocks
|
||||
loop := int((math.Log(float64(model.Width)) / math.Log(float64(10))))
|
||||
if loop == 0 {
|
||||
loop = 1
|
||||
}
|
||||
|
||||
if model.Width < 50 && model.Height < 50 {
|
||||
loop = 0
|
||||
}
|
||||
|
||||
log.Info("Size of the simple block", "loop", loop)
|
||||
|
||||
//loop = max(loop, 3)
|
||||
|
||||
for i := 0; i < loop; i++ {
|
||||
err = MakeLayerExpandable(c.Db, def_id, order, LAYER_SIMPLE_BLOCK, "", 1)
|
||||
@ -1045,9 +1052,10 @@ func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy
|
||||
order++
|
||||
|
||||
loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2)
|
||||
if loop == 0 {
|
||||
loop = 1
|
||||
}
|
||||
|
||||
log.Info("Size of the dense layers", "loop", loop)
|
||||
|
||||
// loop = max(loop, 3)
|
||||
|
||||
for i := 0; i < loop; i++ {
|
||||
err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)))
|
||||
@ -1056,7 +1064,7 @@ func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy
|
||||
return failed()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
_, err = CreateExpModelHead(c, def_id, 0, number_of_classes-1, MODEL_DEFINITION_STATUS_INIT)
|
||||
if err != nil {
|
||||
return failed()
|
||||
@ -1131,11 +1139,6 @@ func handleTrain(handle *Handle) {
|
||||
|
||||
if model_type_form == "expandable" {
|
||||
model_type_id = 2
|
||||
c.Logger.Warn("TODO: handle expandable")
|
||||
return c.Error400(nil, "TODO: handle expandable!", w, "/models/edit.html", "train-model-card", AnyMap{
|
||||
"HasData": true,
|
||||
"ErrorMessage": "TODO: handle expandable!",
|
||||
})
|
||||
} else if model_type_form != "simple" {
|
||||
return c.Error400(nil, "Invalid model type!", w, "/models/edit.html", "train-model-card", AnyMap{
|
||||
"HasData": true,
|
||||
|
@ -423,6 +423,7 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request)
|
||||
var token *string
|
||||
|
||||
logger := log.NewWithOptions(os.Stdout, log.Options{
|
||||
ReportCaller: true,
|
||||
ReportTimestamp: true,
|
||||
TimeFormat: time.Kitchen,
|
||||
Prefix: r.URL.Path,
|
||||
|
@ -7,7 +7,9 @@ import (
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@ -17,37 +19,37 @@ func CheckEmpty(f url.Values, path string) bool {
|
||||
}
|
||||
|
||||
func CheckNumber(f url.Values, path string, number *int) bool {
|
||||
if CheckEmpty(f, path) {
|
||||
fmt.Println("here", path)
|
||||
fmt.Println(f.Get(path))
|
||||
return false
|
||||
}
|
||||
n, err := strconv.Atoi(f.Get(path))
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return false
|
||||
}
|
||||
*number = n
|
||||
return true
|
||||
if CheckEmpty(f, path) {
|
||||
fmt.Println("here", path)
|
||||
fmt.Println(f.Get(path))
|
||||
return false
|
||||
}
|
||||
n, err := strconv.Atoi(f.Get(path))
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return false
|
||||
}
|
||||
*number = n
|
||||
return true
|
||||
}
|
||||
|
||||
func CheckFloat64(f url.Values, path string, number *float64) bool {
|
||||
if CheckEmpty(f, path) {
|
||||
fmt.Println("here", path)
|
||||
fmt.Println(f.Get(path))
|
||||
return false
|
||||
}
|
||||
n, err := strconv.ParseFloat(f.Get(path), 64)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return false
|
||||
}
|
||||
*number = n
|
||||
return true
|
||||
if CheckEmpty(f, path) {
|
||||
fmt.Println("here", path)
|
||||
fmt.Println(f.Get(path))
|
||||
return false
|
||||
}
|
||||
n, err := strconv.ParseFloat(f.Get(path), 64)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return false
|
||||
}
|
||||
*number = n
|
||||
return true
|
||||
}
|
||||
|
||||
func CheckId(f url.Values, path string) bool {
|
||||
return !CheckEmpty(f, path) && IsValidUUID(f.Get(path))
|
||||
return !CheckEmpty(f, path) && IsValidUUID(f.Get(path))
|
||||
}
|
||||
|
||||
func IsValidUUID(u string) bool {
|
||||
@ -57,19 +59,19 @@ func IsValidUUID(u string) bool {
|
||||
|
||||
func GetIdFromUrl(r *http.Request, target string) (string, error) {
|
||||
if !r.URL.Query().Has(target) {
|
||||
return "", errors.New("Query does not have " + target)
|
||||
return "", errors.New("Query does not have " + target)
|
||||
}
|
||||
|
||||
id := r.URL.Query().Get("id")
|
||||
if len(id) == 0 {
|
||||
return "", errors.New("Query is empty for " + target)
|
||||
return "", errors.New("Query is empty for " + target)
|
||||
}
|
||||
|
||||
if !IsValidUUID(id) {
|
||||
return "", errors.New("Value of query is not a valid uuid for " + target)
|
||||
return "", errors.New("Value of query is not a valid uuid for " + target)
|
||||
}
|
||||
|
||||
return id, nil
|
||||
return id, nil
|
||||
}
|
||||
|
||||
type maxBytesReader struct {
|
||||
@ -180,3 +182,48 @@ func MyParseForm(r *http.Request) (vs url.Values, err error) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type Generic struct{ reflect.Type }
|
||||
|
||||
var NotFoundError = errors.New("Not found")
|
||||
|
||||
func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) error {
|
||||
|
||||
t := reflect.TypeOf(store).Elem()
|
||||
|
||||
nargs := t.NumField()
|
||||
|
||||
query := ""
|
||||
|
||||
for i := 0; i < nargs; i += 1 {
|
||||
query += strings.ToLower(t.Field(i).Name) + ","
|
||||
}
|
||||
|
||||
query = query[0 : len(query)-1]
|
||||
|
||||
rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
return NotFoundError
|
||||
}
|
||||
|
||||
|
||||
val := reflect.ValueOf(store).Elem()
|
||||
scan_args := make([]interface{}, nargs);
|
||||
for i := 0; i < nargs; i++ {
|
||||
valueField := val.Field(i)
|
||||
scan_args[i] = valueField.Addr().Interface()
|
||||
}
|
||||
|
||||
err = rows.Scan(scan_args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user