chore: work on the expandable models
This commit is contained in:
parent
6a0ac457d7
commit
bca44f9ba5
@ -2,4 +2,6 @@
|
|||||||
|
|
||||||
cd $(dirname "$0")
|
cd $(dirname "$0")
|
||||||
|
|
||||||
go run .
|
go run . || true
|
||||||
|
|
||||||
|
while true; true; end
|
||||||
|
@ -24,7 +24,7 @@ func handleEdit(handle *Handle) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO handle admin users
|
// TODO handle admin users
|
||||||
rows, err := handle.Db.Query("select name, status, width, height, color_mode, format from models where id=$1 and user_id=$2;", id, c.User.Id)
|
rows, err := handle.Db.Query("select name, status, width, height, color_mode, format, model_type from models where id=$1 and user_id=$2;", id, c.User.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Error500(err)
|
return Error500(err)
|
||||||
}
|
}
|
||||||
@ -45,12 +45,13 @@ func handleEdit(handle *Handle) {
|
|||||||
Height *int
|
Height *int
|
||||||
Color_mode *string
|
Color_mode *string
|
||||||
Format string
|
Format string
|
||||||
|
Type int
|
||||||
}
|
}
|
||||||
|
|
||||||
var model rowmodel = rowmodel{}
|
var model rowmodel = rowmodel{}
|
||||||
model.Id = id
|
model.Id = id
|
||||||
|
|
||||||
err = rows.Scan(&model.Name, &model.Status, &model.Width, &model.Height, &model.Color_mode, &model.Format)
|
err = rows.Scan(&model.Name, &model.Status, &model.Width, &model.Height, &model.Color_mode, &model.Format, &model.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Error500(err)
|
return Error500(err)
|
||||||
}
|
}
|
||||||
@ -124,6 +125,7 @@ func handleEdit(handle *Handle) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type layerdef struct {
|
type layerdef struct {
|
||||||
|
id string
|
||||||
LayerType int
|
LayerType int
|
||||||
Shape string
|
Shape string
|
||||||
}
|
}
|
||||||
@ -132,7 +134,7 @@ func handleEdit(handle *Handle) {
|
|||||||
|
|
||||||
for _, def := range defs {
|
for _, def := range defs {
|
||||||
if def.Status == MODEL_DEFINITION_STATUS_TRAINING {
|
if def.Status == MODEL_DEFINITION_STATUS_TRAINING {
|
||||||
rows, err := c.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", def.Id)
|
rows, err := c.Db.Query("select id, layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", def.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.Error500(err)
|
return c.Error500(err)
|
||||||
}
|
}
|
||||||
@ -140,12 +142,63 @@ func handleEdit(handle *Handle) {
|
|||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var layerdef layerdef
|
var layerdef layerdef
|
||||||
err = rows.Scan(&layerdef.LayerType, &layerdef.Shape)
|
err = rows.Scan(&layerdef.id, &layerdef.LayerType, &layerdef.Shape)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.Error500(err)
|
return c.Error500(err)
|
||||||
}
|
}
|
||||||
layers = append(layers, layerdef)
|
layers = append(layers, layerdef)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if model.Type == 2 {
|
||||||
|
|
||||||
|
type lastLayerType struct {
|
||||||
|
Id string
|
||||||
|
Range_start int
|
||||||
|
Range_end int
|
||||||
|
}
|
||||||
|
|
||||||
|
var lastLayer lastLayerType
|
||||||
|
|
||||||
|
err := GetDBOnce(c, &lastLayer, "exp_model_head where def_id=$1 and status=3;", def.Id)
|
||||||
|
if err != nil {
|
||||||
|
return c.Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Logger.Info("res", "id", lastLayer.Id, "start", lastLayer.Range_start, "end", lastLayer.Range_end)
|
||||||
|
|
||||||
|
layers = append(layers, layerdef{
|
||||||
|
id: lastLayer.Id,
|
||||||
|
LayerType: LAYER_DENSE,
|
||||||
|
Shape: fmt.Sprintf("%d, 1", lastLayer.Range_end-lastLayer.Range_start),
|
||||||
|
})
|
||||||
|
|
||||||
|
/*
|
||||||
|
lastLayer, err := c.Db.Query("select id, range_start, range_end from exp_model_head where def_id=$1 and status=3;", def.Id)
|
||||||
|
if err != nil {
|
||||||
|
return c.Error500(err)
|
||||||
|
}
|
||||||
|
defer lastLayer.Close()
|
||||||
|
|
||||||
|
if !lastLayer.Next() {
|
||||||
|
c.Logger.Info("Could not find the model head for", "def_id", def.Id)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
head_id, range_start, range_end := "", 0, 0
|
||||||
|
err = lastLayer.Scan(&head_id, &range_start, &range_end)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return c.Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
layers = append(layers, layerdef{
|
||||||
|
id: head_id,
|
||||||
|
LayerType: LAYER_DENSE,
|
||||||
|
Shape: fmt.Sprintf("%d, 1", range_end-range_start),
|
||||||
|
})
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -216,7 +216,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
|||||||
// Get untrained models heads
|
// Get untrained models heads
|
||||||
|
|
||||||
// Status = 2 (INIT)
|
// Status = 2 (INIT)
|
||||||
rows, err := c.Db.Query("select id, range_start, range_end from exp_model_head where def_id=$1 and status = 2", definition_id)
|
rows, err := c.Db.Query("select id, range_start, range_end from exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -231,7 +231,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
|||||||
exp := ExpHead{}
|
exp := ExpHead{}
|
||||||
|
|
||||||
if rows.Next() {
|
if rows.Next() {
|
||||||
if err = rows.Scan(&exp.id, &exp.start, &exp.end); err == nil {
|
if err = rows.Scan(&exp.id, &exp.start, &exp.end); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -246,7 +246,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
UpdateStatus(c, "exp_model_head", exp.id, MODEL_DEFINITION_STATUS_TRANIED)
|
UpdateStatus(c, "exp_model_head", exp.id, MODEL_DEFINITION_STATUS_TRAINING)
|
||||||
|
|
||||||
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -949,7 +949,7 @@ func generateDefinitions(c *Context, model *BaseModel, target_accuracy int, numb
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CreateExpModelHead(c *Context, def_id string, range_start int, range_end int, status ModelDefinitionStatus) (id string, err error) {
|
func CreateExpModelHead(c *Context, def_id string, range_start int, range_end int, status ModelDefinitionStatus) (id string, err error) {
|
||||||
rows, err := c.Db.Query("insert into exp_model_head (def_id, range_start, range_end) values ($1, $2, $3, $4) returning id", def_id, range_start, range_end, status)
|
rows, err := c.Db.Query("insert into exp_model_head (def_id, range_start, range_end, status) values ($1, $2, $3, $4) returning id", def_id, range_start, range_end, status)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@ -977,6 +977,8 @@ func ExpModelHeadUpdateStatus(db *sql.DB, id string, status ModelDefinitionStatu
|
|||||||
|
|
||||||
// This generates a definition
|
// This generates a definition
|
||||||
func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy int, number_of_classes int, complexity int) *Error {
|
func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy int, number_of_classes int, complexity int) *Error {
|
||||||
|
c.Logger.Info("Generating expandable new definition for model", "id", model.Id, "complexity", complexity)
|
||||||
|
|
||||||
var err error = nil
|
var err error = nil
|
||||||
failed := func() *Error {
|
failed := func() *Error {
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
||||||
@ -1018,9 +1020,14 @@ func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy
|
|||||||
|
|
||||||
// Create the blocks
|
// Create the blocks
|
||||||
loop := int((math.Log(float64(model.Width)) / math.Log(float64(10))))
|
loop := int((math.Log(float64(model.Width)) / math.Log(float64(10))))
|
||||||
if loop == 0 {
|
|
||||||
loop = 1
|
if model.Width < 50 && model.Height < 50 {
|
||||||
}
|
loop = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Size of the simple block", "loop", loop)
|
||||||
|
|
||||||
|
//loop = max(loop, 3)
|
||||||
|
|
||||||
for i := 0; i < loop; i++ {
|
for i := 0; i < loop; i++ {
|
||||||
err = MakeLayerExpandable(c.Db, def_id, order, LAYER_SIMPLE_BLOCK, "", 1)
|
err = MakeLayerExpandable(c.Db, def_id, order, LAYER_SIMPLE_BLOCK, "", 1)
|
||||||
@ -1045,9 +1052,10 @@ func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy
|
|||||||
order++
|
order++
|
||||||
|
|
||||||
loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2)
|
loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2)
|
||||||
if loop == 0 {
|
|
||||||
loop = 1
|
log.Info("Size of the dense layers", "loop", loop)
|
||||||
}
|
|
||||||
|
// loop = max(loop, 3)
|
||||||
|
|
||||||
for i := 0; i < loop; i++ {
|
for i := 0; i < loop; i++ {
|
||||||
err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)))
|
err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)))
|
||||||
@ -1056,7 +1064,7 @@ func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy
|
|||||||
return failed()
|
return failed()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = CreateExpModelHead(c, def_id, 0, number_of_classes-1, MODEL_DEFINITION_STATUS_INIT)
|
_, err = CreateExpModelHead(c, def_id, 0, number_of_classes-1, MODEL_DEFINITION_STATUS_INIT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return failed()
|
return failed()
|
||||||
@ -1131,11 +1139,6 @@ func handleTrain(handle *Handle) {
|
|||||||
|
|
||||||
if model_type_form == "expandable" {
|
if model_type_form == "expandable" {
|
||||||
model_type_id = 2
|
model_type_id = 2
|
||||||
c.Logger.Warn("TODO: handle expandable")
|
|
||||||
return c.Error400(nil, "TODO: handle expandable!", w, "/models/edit.html", "train-model-card", AnyMap{
|
|
||||||
"HasData": true,
|
|
||||||
"ErrorMessage": "TODO: handle expandable!",
|
|
||||||
})
|
|
||||||
} else if model_type_form != "simple" {
|
} else if model_type_form != "simple" {
|
||||||
return c.Error400(nil, "Invalid model type!", w, "/models/edit.html", "train-model-card", AnyMap{
|
return c.Error400(nil, "Invalid model type!", w, "/models/edit.html", "train-model-card", AnyMap{
|
||||||
"HasData": true,
|
"HasData": true,
|
||||||
|
@ -423,6 +423,7 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request)
|
|||||||
var token *string
|
var token *string
|
||||||
|
|
||||||
logger := log.NewWithOptions(os.Stdout, log.Options{
|
logger := log.NewWithOptions(os.Stdout, log.Options{
|
||||||
|
ReportCaller: true,
|
||||||
ReportTimestamp: true,
|
ReportTimestamp: true,
|
||||||
TimeFormat: time.Kitchen,
|
TimeFormat: time.Kitchen,
|
||||||
Prefix: r.URL.Path,
|
Prefix: r.URL.Path,
|
||||||
|
@ -7,7 +7,9 @@ import (
|
|||||||
"mime"
|
"mime"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
@ -17,37 +19,37 @@ func CheckEmpty(f url.Values, path string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CheckNumber(f url.Values, path string, number *int) bool {
|
func CheckNumber(f url.Values, path string, number *int) bool {
|
||||||
if CheckEmpty(f, path) {
|
if CheckEmpty(f, path) {
|
||||||
fmt.Println("here", path)
|
fmt.Println("here", path)
|
||||||
fmt.Println(f.Get(path))
|
fmt.Println(f.Get(path))
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
n, err := strconv.Atoi(f.Get(path))
|
n, err := strconv.Atoi(f.Get(path))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
*number = n
|
*number = n
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckFloat64(f url.Values, path string, number *float64) bool {
|
func CheckFloat64(f url.Values, path string, number *float64) bool {
|
||||||
if CheckEmpty(f, path) {
|
if CheckEmpty(f, path) {
|
||||||
fmt.Println("here", path)
|
fmt.Println("here", path)
|
||||||
fmt.Println(f.Get(path))
|
fmt.Println(f.Get(path))
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
n, err := strconv.ParseFloat(f.Get(path), 64)
|
n, err := strconv.ParseFloat(f.Get(path), 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
*number = n
|
*number = n
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckId(f url.Values, path string) bool {
|
func CheckId(f url.Values, path string) bool {
|
||||||
return !CheckEmpty(f, path) && IsValidUUID(f.Get(path))
|
return !CheckEmpty(f, path) && IsValidUUID(f.Get(path))
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsValidUUID(u string) bool {
|
func IsValidUUID(u string) bool {
|
||||||
@ -57,19 +59,19 @@ func IsValidUUID(u string) bool {
|
|||||||
|
|
||||||
func GetIdFromUrl(r *http.Request, target string) (string, error) {
|
func GetIdFromUrl(r *http.Request, target string) (string, error) {
|
||||||
if !r.URL.Query().Has(target) {
|
if !r.URL.Query().Has(target) {
|
||||||
return "", errors.New("Query does not have " + target)
|
return "", errors.New("Query does not have " + target)
|
||||||
}
|
}
|
||||||
|
|
||||||
id := r.URL.Query().Get("id")
|
id := r.URL.Query().Get("id")
|
||||||
if len(id) == 0 {
|
if len(id) == 0 {
|
||||||
return "", errors.New("Query is empty for " + target)
|
return "", errors.New("Query is empty for " + target)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !IsValidUUID(id) {
|
if !IsValidUUID(id) {
|
||||||
return "", errors.New("Value of query is not a valid uuid for " + target)
|
return "", errors.New("Value of query is not a valid uuid for " + target)
|
||||||
}
|
}
|
||||||
|
|
||||||
return id, nil
|
return id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type maxBytesReader struct {
|
type maxBytesReader struct {
|
||||||
@ -180,3 +182,48 @@ func MyParseForm(r *http.Request) (vs url.Values, err error) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Generic struct{ reflect.Type }
|
||||||
|
|
||||||
|
var NotFoundError = errors.New("Not found")
|
||||||
|
|
||||||
|
func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) error {
|
||||||
|
|
||||||
|
t := reflect.TypeOf(store).Elem()
|
||||||
|
|
||||||
|
nargs := t.NumField()
|
||||||
|
|
||||||
|
query := ""
|
||||||
|
|
||||||
|
for i := 0; i < nargs; i += 1 {
|
||||||
|
query += strings.ToLower(t.Field(i).Name) + ","
|
||||||
|
}
|
||||||
|
|
||||||
|
query = query[0 : len(query)-1]
|
||||||
|
|
||||||
|
rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
if !rows.Next() {
|
||||||
|
return NotFoundError
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
val := reflect.ValueOf(store).Elem()
|
||||||
|
scan_args := make([]interface{}, nargs);
|
||||||
|
for i := 0; i < nargs; i++ {
|
||||||
|
valueField := val.Field(i)
|
||||||
|
scan_args[i] = valueField.Addr().Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rows.Scan(scan_args...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user