chore: started working on #32

This commit is contained in:
Andre Henriques 2023-10-19 11:42:38 +01:00
parent 2c3539b81a
commit 46705ba5d9

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"math"
"net/http" "net/http"
"os" "os"
"os/exec" "os/exec"
@ -17,7 +18,7 @@ import (
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
) )
const EPOCH_PER_RUN = 20; const EPOCH_PER_RUN = 20
const MAX_EPOCH = 100 const MAX_EPOCH = 100
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) { func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
@ -411,6 +412,78 @@ func removeFailedDataPoints(db *sql.DB, model *BaseModel) (err error) {
return return
} }
// This generates a definition
func generateDefinition(c *Context, model *BaseModel, number_of_classes int, complexity int) *Error {
var err error = nil
failed := func() *Error {
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
// TODO improve this response
return c.Error500(err)
}
def_id, err := MakeDefenition(c.Db, model.Id, 0)
if err != nil {
return failed()
}
// Note the shape for now is no used
err = MakeLayer(c.Db, def_id, 1, LAYER_INPUT, fmt.Sprintf("%d,%d,1", model.Width, model.Height))
if err != nil {
return failed()
}
if complexity == 0 {
err = MakeLayer(c.Db, def_id, 4, LAYER_FLATTEN, "")
if err != nil {
return failed()
}
loop := int(math.Log2(float64(number_of_classes)))
for i := 0; i < loop; i++ {
err = MakeLayer(c.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop - i)))
if err != nil {
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
// TODO improve this response
return c.Error500(err)
}
}
} else {
c.Logger.Error("Unkown complexity", "complexity", complexity)
return failed()
}
err = ModelDefinitionUpdateStatus(c, def_id, MODEL_DEFINITION_STATUS_INIT)
if err != nil {
return failed()
}
return nil
}
func generateDefinitions(c *Context, model *BaseModel, number_of_models int) *Error {
cls, err := model_classes.ListClasses(c.Db, model.Id)
if err != nil {
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
// TODO improve this response
return c.Error500(err)
}
err = removeFailedDataPoints(c.Db, model)
if err != nil {
return c.Error500(err)
}
for i := 0; i < number_of_models; i++ {
// TODO handle incrisea the complexity
generateDefinition(c, model, len(cls), 0)
}
return nil
}
func handleTrain(handle *Handle) { func handleTrain(handle *Handle) {
handle.Post("/models/train", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { handle.Post("/models/train", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
if !CheckAuthLevel(1, w, r, c) { if !CheckAuthLevel(1, w, r, c) {
@ -473,67 +546,11 @@ func handleTrain(handle *Handle) {
return ErrorCode(nil, 400, c.AddMap(nil)) return ErrorCode(nil, 400, c.AddMap(nil))
} }
cls, err := model_classes.ListClasses(handle.Db, model.Id) full_error := generateDefinitions(c, model, number_of_models)
if err != nil { if full_error != nil {
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) return full_error
// TODO improve this response
return c.Error500(err)
} }
err = removeFailedDataPoints(c.Db, model)
if err != nil {
return c.Error500(err)
}
var fid string
for i := 0; i < number_of_models; i++ {
def_id, err := MakeDefenition(handle.Db, model.Id, accuracy)
if err != nil {
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
// TODO improve this response
return c.Error500(err)
}
if fid == "" {
fid = def_id
}
// TODO change shape of it depends on the type of the image
err = MakeLayer(handle.Db, def_id, 1, LAYER_INPUT, fmt.Sprintf("%d,%d,1", model.Width, model.Height))
if err != nil {
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
// TODO improve this response
return c.Error500(err)
}
err = MakeLayer(handle.Db, def_id, 4, LAYER_FLATTEN, fmt.Sprintf("%d,1", len(cls)))
if err != nil {
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
// TODO improve this response
return c.Error500(err)
}
err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d,1", len(cls)*3))
if err != nil {
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
// TODO improve this response
return c.Error500(err)
}
// Using sparce
err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d, 1", len(cls)))
if err != nil {
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
// TODO improve this response
return c.Error500(err)
}
err = ModelDefinitionUpdateStatus(c, def_id, MODEL_DEFINITION_STATUS_INIT)
if err != nil {
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
// TODO improve this response
return c.Error500(err)
}
}
// TODO start training with id fid
go trainModel(c, model) go trainModel(c, model)
ModelUpdateStatus(c, model.Id, TRAINING) ModelUpdateStatus(c, model.Id, TRAINING)