From 46705ba5d9f2bb9155c818d15b20b0e6e2082c53 Mon Sep 17 00:00:00 2001 From: Andre Henriques Date: Thu, 19 Oct 2023 11:42:38 +0100 Subject: [PATCH] chore: started working on #32 --- logic/models/train/train.go | 139 ++++++++++++++++++++---------------- 1 file changed, 78 insertions(+), 61 deletions(-) diff --git a/logic/models/train/train.go b/logic/models/train/train.go index d5d24dc..5a1bb10 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "math" "net/http" "os" "os/exec" @@ -17,7 +18,7 @@ import ( . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" ) -const EPOCH_PER_RUN = 20; +const EPOCH_PER_RUN = 20 const MAX_EPOCH = 100 func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) { @@ -411,6 +412,78 @@ func removeFailedDataPoints(db *sql.DB, model *BaseModel) (err error) { return } +// This generates a definition +func generateDefinition(c *Context, model *BaseModel, number_of_classes int, complexity int) *Error { + var err error = nil + failed := func() *Error { + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) + // TODO improve this response + return c.Error500(err) + } + + + def_id, err := MakeDefenition(c.Db, model.Id, 0) + if err != nil { + return failed() + } + + // Note the shape for now is no used + err = MakeLayer(c.Db, def_id, 1, LAYER_INPUT, fmt.Sprintf("%d,%d,1", model.Width, model.Height)) + if err != nil { + return failed() + } + + if complexity == 0 { + + err = MakeLayer(c.Db, def_id, 4, LAYER_FLATTEN, "") + if err != nil { + return failed() + } + + loop := int(math.Log2(float64(number_of_classes))) + for i := 0; i < loop; i++ { + err = MakeLayer(c.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop - i))) + if err != nil { + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) + // TODO improve this response + return c.Error500(err) + } + } + + } else { + c.Logger.Error("Unkown complexity", "complexity", complexity) + return failed() + } + + err = ModelDefinitionUpdateStatus(c, def_id, MODEL_DEFINITION_STATUS_INIT) + if err != nil { + return failed() + } + + return nil +} + +func generateDefinitions(c *Context, model *BaseModel, number_of_models int) *Error { + cls, err := model_classes.ListClasses(c.Db, model.Id) + if err != nil { + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) + // TODO improve this response + return c.Error500(err) + } + + err = removeFailedDataPoints(c.Db, model) + if err != nil { + return c.Error500(err) + } + + for i := 0; i < number_of_models; i++ { + // TODO handle incrisea the complexity + generateDefinition(c, model, len(cls), 0) + } + + return nil +} + func handleTrain(handle *Handle) { handle.Post("/models/train", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { if !CheckAuthLevel(1, w, r, c) { @@ -473,67 +546,11 @@ func handleTrain(handle *Handle) { return ErrorCode(nil, 400, c.AddMap(nil)) } - cls, err := model_classes.ListClasses(handle.Db, model.Id) - if err != nil { - ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) - // TODO improve this response - return c.Error500(err) - } + full_error := generateDefinitions(c, model, number_of_models) + if full_error != nil { + return full_error + } - err = removeFailedDataPoints(c.Db, model) - if err != nil { - return c.Error500(err) - } - - var fid string - for i := 0; i < number_of_models; i++ { - def_id, err := MakeDefenition(handle.Db, model.Id, accuracy) - if err != nil { - ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) - // TODO improve this response - return c.Error500(err) - } - - if fid == "" { - fid = def_id - } - - // TODO change shape of it depends on the type of the image - err = MakeLayer(handle.Db, def_id, 1, LAYER_INPUT, fmt.Sprintf("%d,%d,1", model.Width, model.Height)) - if err != nil { - ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) - // TODO improve this response - return c.Error500(err) - } - err = MakeLayer(handle.Db, def_id, 4, LAYER_FLATTEN, fmt.Sprintf("%d,1", len(cls))) - if err != nil { - ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) - // TODO improve this response - return c.Error500(err) - } - err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d,1", len(cls)*3)) - if err != nil { - ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) - // TODO improve this response - return c.Error500(err) - } - // Using sparce - err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d, 1", len(cls))) - if err != nil { - ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) - // TODO improve this response - return c.Error500(err) - } - - err = ModelDefinitionUpdateStatus(c, def_id, MODEL_DEFINITION_STATUS_INIT) - if err != nil { - ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) - // TODO improve this response - return c.Error500(err) - } - } - - // TODO start training with id fid go trainModel(c, model) ModelUpdateStatus(c, model.Id, TRAINING)