chore: started working on #32
This commit is contained in:
parent
2c3539b81a
commit
46705ba5d9
@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
@ -17,7 +18,7 @@ import (
|
|||||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
const EPOCH_PER_RUN = 20;
|
const EPOCH_PER_RUN = 20
|
||||||
const MAX_EPOCH = 100
|
const MAX_EPOCH = 100
|
||||||
|
|
||||||
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
|
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
|
||||||
@ -411,6 +412,78 @@ func removeFailedDataPoints(db *sql.DB, model *BaseModel) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This generates a definition
|
||||||
|
func generateDefinition(c *Context, model *BaseModel, number_of_classes int, complexity int) *Error {
|
||||||
|
var err error = nil
|
||||||
|
failed := func() *Error {
|
||||||
|
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
||||||
|
// TODO improve this response
|
||||||
|
return c.Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def_id, err := MakeDefenition(c.Db, model.Id, 0)
|
||||||
|
if err != nil {
|
||||||
|
return failed()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note the shape for now is no used
|
||||||
|
err = MakeLayer(c.Db, def_id, 1, LAYER_INPUT, fmt.Sprintf("%d,%d,1", model.Width, model.Height))
|
||||||
|
if err != nil {
|
||||||
|
return failed()
|
||||||
|
}
|
||||||
|
|
||||||
|
if complexity == 0 {
|
||||||
|
|
||||||
|
err = MakeLayer(c.Db, def_id, 4, LAYER_FLATTEN, "")
|
||||||
|
if err != nil {
|
||||||
|
return failed()
|
||||||
|
}
|
||||||
|
|
||||||
|
loop := int(math.Log2(float64(number_of_classes)))
|
||||||
|
for i := 0; i < loop; i++ {
|
||||||
|
err = MakeLayer(c.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop - i)))
|
||||||
|
if err != nil {
|
||||||
|
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
||||||
|
// TODO improve this response
|
||||||
|
return c.Error500(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
c.Logger.Error("Unkown complexity", "complexity", complexity)
|
||||||
|
return failed()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ModelDefinitionUpdateStatus(c, def_id, MODEL_DEFINITION_STATUS_INIT)
|
||||||
|
if err != nil {
|
||||||
|
return failed()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateDefinitions(c *Context, model *BaseModel, number_of_models int) *Error {
|
||||||
|
cls, err := model_classes.ListClasses(c.Db, model.Id)
|
||||||
|
if err != nil {
|
||||||
|
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
||||||
|
// TODO improve this response
|
||||||
|
return c.Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = removeFailedDataPoints(c.Db, model)
|
||||||
|
if err != nil {
|
||||||
|
return c.Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < number_of_models; i++ {
|
||||||
|
// TODO handle incrisea the complexity
|
||||||
|
generateDefinition(c, model, len(cls), 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func handleTrain(handle *Handle) {
|
func handleTrain(handle *Handle) {
|
||||||
handle.Post("/models/train", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
handle.Post("/models/train", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||||||
if !CheckAuthLevel(1, w, r, c) {
|
if !CheckAuthLevel(1, w, r, c) {
|
||||||
@ -473,67 +546,11 @@ func handleTrain(handle *Handle) {
|
|||||||
return ErrorCode(nil, 400, c.AddMap(nil))
|
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
cls, err := model_classes.ListClasses(handle.Db, model.Id)
|
full_error := generateDefinitions(c, model, number_of_models)
|
||||||
if err != nil {
|
if full_error != nil {
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
return full_error
|
||||||
// TODO improve this response
|
}
|
||||||
return c.Error500(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = removeFailedDataPoints(c.Db, model)
|
|
||||||
if err != nil {
|
|
||||||
return c.Error500(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var fid string
|
|
||||||
for i := 0; i < number_of_models; i++ {
|
|
||||||
def_id, err := MakeDefenition(handle.Db, model.Id, accuracy)
|
|
||||||
if err != nil {
|
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
|
||||||
// TODO improve this response
|
|
||||||
return c.Error500(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if fid == "" {
|
|
||||||
fid = def_id
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO change shape of it depends on the type of the image
|
|
||||||
err = MakeLayer(handle.Db, def_id, 1, LAYER_INPUT, fmt.Sprintf("%d,%d,1", model.Width, model.Height))
|
|
||||||
if err != nil {
|
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
|
||||||
// TODO improve this response
|
|
||||||
return c.Error500(err)
|
|
||||||
}
|
|
||||||
err = MakeLayer(handle.Db, def_id, 4, LAYER_FLATTEN, fmt.Sprintf("%d,1", len(cls)))
|
|
||||||
if err != nil {
|
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
|
||||||
// TODO improve this response
|
|
||||||
return c.Error500(err)
|
|
||||||
}
|
|
||||||
err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d,1", len(cls)*3))
|
|
||||||
if err != nil {
|
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
|
||||||
// TODO improve this response
|
|
||||||
return c.Error500(err)
|
|
||||||
}
|
|
||||||
// Using sparce
|
|
||||||
err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d, 1", len(cls)))
|
|
||||||
if err != nil {
|
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
|
||||||
// TODO improve this response
|
|
||||||
return c.Error500(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = ModelDefinitionUpdateStatus(c, def_id, MODEL_DEFINITION_STATUS_INIT)
|
|
||||||
if err != nil {
|
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
|
||||||
// TODO improve this response
|
|
||||||
return c.Error500(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO start training with id fid
|
|
||||||
go trainModel(c, model)
|
go trainModel(c, model)
|
||||||
|
|
||||||
ModelUpdateStatus(c, model.Id, TRAINING)
|
ModelUpdateStatus(c, model.Id, TRAINING)
|
||||||
|
Loading…
Reference in New Issue
Block a user