chore: started working on #32
This commit is contained in:
parent
2c3539b81a
commit
46705ba5d9
@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
@ -17,7 +18,7 @@ import (
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
const EPOCH_PER_RUN = 20;
|
||||
const EPOCH_PER_RUN = 20
|
||||
const MAX_EPOCH = 100
|
||||
|
||||
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
|
||||
@ -411,6 +412,78 @@ func removeFailedDataPoints(db *sql.DB, model *BaseModel) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// This generates a definition
|
||||
func generateDefinition(c *Context, model *BaseModel, number_of_classes int, complexity int) *Error {
|
||||
var err error = nil
|
||||
failed := func() *Error {
|
||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
|
||||
def_id, err := MakeDefenition(c.Db, model.Id, 0)
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
|
||||
// Note the shape for now is no used
|
||||
err = MakeLayer(c.Db, def_id, 1, LAYER_INPUT, fmt.Sprintf("%d,%d,1", model.Width, model.Height))
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
|
||||
if complexity == 0 {
|
||||
|
||||
err = MakeLayer(c.Db, def_id, 4, LAYER_FLATTEN, "")
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
|
||||
loop := int(math.Log2(float64(number_of_classes)))
|
||||
for i := 0; i < loop; i++ {
|
||||
err = MakeLayer(c.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop - i)))
|
||||
if err != nil {
|
||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return c.Error500(err)
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
c.Logger.Error("Unkown complexity", "complexity", complexity)
|
||||
return failed()
|
||||
}
|
||||
|
||||
err = ModelDefinitionUpdateStatus(c, def_id, MODEL_DEFINITION_STATUS_INIT)
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateDefinitions(c *Context, model *BaseModel, number_of_models int) *Error {
|
||||
cls, err := model_classes.ListClasses(c.Db, model.Id)
|
||||
if err != nil {
|
||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
err = removeFailedDataPoints(c.Db, model)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
for i := 0; i < number_of_models; i++ {
|
||||
// TODO handle incrisea the complexity
|
||||
generateDefinition(c, model, len(cls), 0)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleTrain(handle *Handle) {
|
||||
handle.Post("/models/train", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||||
if !CheckAuthLevel(1, w, r, c) {
|
||||
@ -473,67 +546,11 @@ func handleTrain(handle *Handle) {
|
||||
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||
}
|
||||
|
||||
cls, err := model_classes.ListClasses(handle.Db, model.Id)
|
||||
if err != nil {
|
||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return c.Error500(err)
|
||||
full_error := generateDefinitions(c, model, number_of_models)
|
||||
if full_error != nil {
|
||||
return full_error
|
||||
}
|
||||
|
||||
err = removeFailedDataPoints(c.Db, model)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
var fid string
|
||||
for i := 0; i < number_of_models; i++ {
|
||||
def_id, err := MakeDefenition(handle.Db, model.Id, accuracy)
|
||||
if err != nil {
|
||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
if fid == "" {
|
||||
fid = def_id
|
||||
}
|
||||
|
||||
// TODO change shape of it depends on the type of the image
|
||||
err = MakeLayer(handle.Db, def_id, 1, LAYER_INPUT, fmt.Sprintf("%d,%d,1", model.Width, model.Height))
|
||||
if err != nil {
|
||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return c.Error500(err)
|
||||
}
|
||||
err = MakeLayer(handle.Db, def_id, 4, LAYER_FLATTEN, fmt.Sprintf("%d,1", len(cls)))
|
||||
if err != nil {
|
||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return c.Error500(err)
|
||||
}
|
||||
err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d,1", len(cls)*3))
|
||||
if err != nil {
|
||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return c.Error500(err)
|
||||
}
|
||||
// Using sparce
|
||||
err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d, 1", len(cls)))
|
||||
if err != nil {
|
||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
err = ModelDefinitionUpdateStatus(c, def_id, MODEL_DEFINITION_STATUS_INIT)
|
||||
if err != nil {
|
||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return c.Error500(err)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO start training with id fid
|
||||
go trainModel(c, model)
|
||||
|
||||
ModelUpdateStatus(c, model.Id, TRAINING)
|
||||
|
Loading…
Reference in New Issue
Block a user