This commit is contained in:
2023-10-21 00:26:52 +01:00
parent ff9aca2699
commit 805be22388
5 changed files with 179 additions and 77 deletions

View File

@@ -101,7 +101,7 @@ func handleEdit(handle *Handle) {
type defrow struct {
Status int
EpochProgress int
Accuracy int
Accuracy float64
}
def_rows, err := c.Db.Query("select status, epoch_progress, accuracy from model_definition where model_id=$1", model.Id)

View File

@@ -53,6 +53,7 @@ const (
LAYER_INPUT LayerType = 1
LAYER_DENSE = 2
LAYER_FLATTEN = 3
LAYER_SIMPLE_BLOCK = 4
)
func ModelDefinitionUpdateStatus(c *Context, id string, status ModelDefinitionStatus) (err error) {
@@ -207,13 +208,13 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr
return
}
c.Logger.Info("Model finished training!", "accuracy", accuracy)
accuracy, err = strconv.ParseFloat(string(accuracy_file_bytes), 64)
if err != nil {
return
}
c.Logger.Info("Model finished training!", "accuracy", accuracy)
os.RemoveAll(run_path)
return
}
@@ -286,12 +287,11 @@ func trainModel(c *Context, model *BaseModel) {
continue
}
def.epoch += EPOCH_PER_RUN
accuracy = accuracy * 100
int_accuracy := int(accuracy * 100)
if int_accuracy >= def.target_accuracy {
if accuracy >= float64(def.target_accuracy) {
c.Logger.Info("Found a definition that reaches target_accuracy!")
_, err = c.Db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", int_accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.epoch, def.id)
_, err = c.Db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.epoch, def.id)
if err != nil {
c.Logger.Error("Failed to train definition!Err:\n", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
@@ -310,13 +310,19 @@ func trainModel(c *Context, model *BaseModel) {
}
if def.epoch > MAX_EPOCH {
fmt.Printf("Failed to train definition! Accuracy less %d < %d\n", int_accuracy, def.target_accuracy)
fmt.Printf("Failed to train definition! Accuracy less %f < %d\n", accuracy, def.target_accuracy)
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
toTrain = toTrain - 1
newDefinitions = remove(newDefinitions, i)
continue
}
_, err = c.Db.Exec("update model_definition set accuracy=$1, epoch=$2 where id=$3", accuracy, def.epoch, def.id)
if err != nil {
c.Logger.Error("Failed to train definition!Err:\n", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
}
copy(definitions, newDefinitions)
firstRound = false
@@ -403,9 +409,9 @@ func removeFailedDataPoints(c *Context, model *BaseModel) (err error) {
return
}
p := path.Join(base_path, dataPointId + "." + model.Format)
c.Logger.Warn("Removing image", "path", p)
p := path.Join(base_path, dataPointId+"."+model.Format)
c.Logger.Warn("Removing image", "path", p)
err = os.RemoveAll(p)
if err != nil {
@@ -418,57 +424,93 @@ func removeFailedDataPoints(c *Context, model *BaseModel) (err error) {
}
// This generates a definition
func generateDefinition(c *Context, model *BaseModel, number_of_classes int, complexity int) *Error {
var err error = nil
failed := func() *Error {
func generateDefinition(c *Context, model *BaseModel, target_accuracy int, number_of_classes int, complexity int) *Error {
var err error = nil
failed := func() *Error {
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
// TODO improve this response
return c.Error500(err)
}
}
def_id, err := MakeDefenition(c.Db, model.Id, target_accuracy)
if err != nil {
return failed()
}
def_id, err := MakeDefenition(c.Db, model.Id, 0)
if err != nil {
return failed()
}
// Note the shape for now is no used
err = MakeLayer(c.Db, def_id, 1, LAYER_INPUT, fmt.Sprintf("%d,%d,1", model.Width, model.Height))
if err != nil {
return failed()
}
order := 1;
if complexity == 0 {
// Note the shape for now is no used
err = MakeLayer(c.Db, def_id, order, LAYER_INPUT, fmt.Sprintf("%d,%d,1", model.Width, model.Height))
if err != nil {
return failed()
}
order++;
err = MakeLayer(c.Db, def_id, 4, LAYER_FLATTEN, "")
if complexity == 0 {
err = MakeLayer(c.Db, def_id, order, LAYER_FLATTEN, "")
if err != nil {
return failed()
return failed()
}
order++;
loop := int(math.Log2(float64(number_of_classes))/2)
for i := 0; i < loop; i++ {
err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)))
order++;
if err != nil {
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
// TODO improve this response
return c.Error500(err)
}
}
loop := int(math.Log2(float64(number_of_classes)))
for i := 0; i < loop; i++ {
err = MakeLayer(c.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop - i)))
if err != nil {
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
// TODO improve this response
return c.Error500(err)
}
} else if (complexity == 1) {
loop := int((math.Log(float64(model.Width))/math.Log(float64(10))))
if loop == 0 {
loop = 1;
}
for i := 0; i < loop; i++ {
err = MakeLayer(c.Db, def_id, order, LAYER_SIMPLE_BLOCK, "")
order++;
if err != nil {
return failed();
}
}
} else {
c.Logger.Error("Unkown complexity", "complexity", complexity)
return failed()
}
err = MakeLayer(c.Db, def_id, order, LAYER_FLATTEN, "")
if err != nil {
return failed()
}
order++;
err = ModelDefinitionUpdateStatus(c, def_id, MODEL_DEFINITION_STATUS_INIT)
if err != nil {
return failed()
}
loop = int((math.Log(float64(number_of_classes))/math.Log(float64(10)))/2)
if loop == 0 {
loop = 1;
}
for i := 0; i < loop; i++ {
err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)))
order++;
if err != nil {
return failed();
}
}
return nil
} else {
c.Logger.Error("Unkown complexity", "complexity", complexity)
return failed()
}
err = ModelDefinitionUpdateStatus(c, def_id, MODEL_DEFINITION_STATUS_INIT)
if err != nil {
return failed()
}
return nil
}
func generateDefinitions(c *Context, model *BaseModel, number_of_models int) *Error {
func generateDefinitions(c *Context, model *BaseModel, target_accuracy int, number_of_models int) *Error {
cls, err := model_classes.ListClasses(c.Db, model.Id)
if err != nil {
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
@@ -481,12 +523,21 @@ func generateDefinitions(c *Context, model *BaseModel, number_of_models int) *Er
return c.Error500(err)
}
for i := 0; i < number_of_models; i++ {
if (number_of_models == 1) {
if (model.Width < 100 && model.Height < 100 && len(cls) < 30) {
generateDefinition(c, model, target_accuracy, len(cls), 0)
} else {
generateDefinition(c, model, target_accuracy, len(cls), 1)
}
} else {
// TODO handle incrisea the complexity
generateDefinition(c, model, len(cls), 0)
}
for i := 0; i < number_of_models; i++ {
generateDefinition(c, model, target_accuracy, len(cls), 0)
}
}
return nil
return nil
}
func handleTrain(handle *Handle) {
@@ -551,10 +602,10 @@ func handleTrain(handle *Handle) {
return ErrorCode(nil, 400, c.AddMap(nil))
}
full_error := generateDefinitions(c, model, number_of_models)
if full_error != nil {
return full_error
}
full_error := generateDefinitions(c, model, accuracy, number_of_models)
if full_error != nil {
return full_error
}
go trainModel(c, model)
@@ -573,11 +624,15 @@ func handleTrain(handle *Handle) {
f := r.URL.Query()
if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") {
c.Logger.Warn("Invalid: model_id or definition or epoch")
accuracy := 0.0
if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") || !CheckFloat64(f, "accuracy", &accuracy){
c.Logger.Warn("Invalid: model_id or definition or epoch or accuracy")
return c.UnsafeErrorCode(nil, 400, nil)
}
accuracy = accuracy * 100
model_id := f.Get("model_id")
def_id := f.Get("definition")
epoch, err := strconv.Atoi(f.Get("epoch"))
@@ -610,7 +665,9 @@ func handleTrain(handle *Handle) {
return c.UnsafeErrorCode(nil, 400, nil)
}
_, err = c.Db.Exec("update model_definition set epoch_progress=$1 where id=$2", epoch, def_id)
c.Logger.Info("Updated model_definition!", "model", model_id, "progress", epoch, "accuracy", accuracy)
_, err = c.Db.Exec("update model_definition set epoch_progress=$1, accuracy=$2 where id=$3", epoch, accuracy, def_id)
if err != nil {
return c.Error500(err)
}

View File

@@ -31,6 +31,21 @@ func CheckNumber(f url.Values, path string, number *int) bool {
return true
}
func CheckFloat64(f url.Values, path string, number *float64) bool {
if CheckEmpty(f, path) {
fmt.Println("here", path)
fmt.Println(f.Get(path))
return false
}
n, err := strconv.ParseFloat(f.Get(path), 64)
if err != nil {
fmt.Println(err)
return false
}
*number = n
return true
}
func CheckId(f url.Values, path string) bool {
return !CheckEmpty(f, path) && IsValidUUID(f.Get(path))
}