worked on #32
This commit is contained in:
@@ -101,7 +101,7 @@ func handleEdit(handle *Handle) {
|
||||
type defrow struct {
|
||||
Status int
|
||||
EpochProgress int
|
||||
Accuracy int
|
||||
Accuracy float64
|
||||
}
|
||||
|
||||
def_rows, err := c.Db.Query("select status, epoch_progress, accuracy from model_definition where model_id=$1", model.Id)
|
||||
|
||||
@@ -53,6 +53,7 @@ const (
|
||||
LAYER_INPUT LayerType = 1
|
||||
LAYER_DENSE = 2
|
||||
LAYER_FLATTEN = 3
|
||||
LAYER_SIMPLE_BLOCK = 4
|
||||
)
|
||||
|
||||
func ModelDefinitionUpdateStatus(c *Context, id string, status ModelDefinitionStatus) (err error) {
|
||||
@@ -207,13 +208,13 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr
|
||||
return
|
||||
}
|
||||
|
||||
c.Logger.Info("Model finished training!", "accuracy", accuracy)
|
||||
|
||||
accuracy, err = strconv.ParseFloat(string(accuracy_file_bytes), 64)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.Logger.Info("Model finished training!", "accuracy", accuracy)
|
||||
|
||||
os.RemoveAll(run_path)
|
||||
return
|
||||
}
|
||||
@@ -286,12 +287,11 @@ func trainModel(c *Context, model *BaseModel) {
|
||||
continue
|
||||
}
|
||||
def.epoch += EPOCH_PER_RUN
|
||||
accuracy = accuracy * 100
|
||||
|
||||
int_accuracy := int(accuracy * 100)
|
||||
|
||||
if int_accuracy >= def.target_accuracy {
|
||||
if accuracy >= float64(def.target_accuracy) {
|
||||
c.Logger.Info("Found a definition that reaches target_accuracy!")
|
||||
_, err = c.Db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", int_accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.epoch, def.id)
|
||||
_, err = c.Db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.epoch, def.id)
|
||||
if err != nil {
|
||||
c.Logger.Error("Failed to train definition!Err:\n", "err", err)
|
||||
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
||||
@@ -310,13 +310,19 @@ func trainModel(c *Context, model *BaseModel) {
|
||||
}
|
||||
|
||||
if def.epoch > MAX_EPOCH {
|
||||
fmt.Printf("Failed to train definition! Accuracy less %d < %d\n", int_accuracy, def.target_accuracy)
|
||||
fmt.Printf("Failed to train definition! Accuracy less %f < %d\n", accuracy, def.target_accuracy)
|
||||
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
|
||||
toTrain = toTrain - 1
|
||||
newDefinitions = remove(newDefinitions, i)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = c.Db.Exec("update model_definition set accuracy=$1, epoch=$2 where id=$3", accuracy, def.epoch, def.id)
|
||||
if err != nil {
|
||||
c.Logger.Error("Failed to train definition!Err:\n", "err", err)
|
||||
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
||||
return
|
||||
}
|
||||
}
|
||||
copy(definitions, newDefinitions)
|
||||
firstRound = false
|
||||
@@ -403,9 +409,9 @@ func removeFailedDataPoints(c *Context, model *BaseModel) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
p := path.Join(base_path, dataPointId + "." + model.Format)
|
||||
|
||||
c.Logger.Warn("Removing image", "path", p)
|
||||
p := path.Join(base_path, dataPointId+"."+model.Format)
|
||||
|
||||
c.Logger.Warn("Removing image", "path", p)
|
||||
|
||||
err = os.RemoveAll(p)
|
||||
if err != nil {
|
||||
@@ -418,57 +424,93 @@ func removeFailedDataPoints(c *Context, model *BaseModel) (err error) {
|
||||
}
|
||||
|
||||
// This generates a definition
|
||||
func generateDefinition(c *Context, model *BaseModel, number_of_classes int, complexity int) *Error {
|
||||
var err error = nil
|
||||
failed := func() *Error {
|
||||
func generateDefinition(c *Context, model *BaseModel, target_accuracy int, number_of_classes int, complexity int) *Error {
|
||||
var err error = nil
|
||||
failed := func() *Error {
|
||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return c.Error500(err)
|
||||
}
|
||||
}
|
||||
|
||||
def_id, err := MakeDefenition(c.Db, model.Id, target_accuracy)
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
|
||||
def_id, err := MakeDefenition(c.Db, model.Id, 0)
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
|
||||
// Note the shape for now is no used
|
||||
err = MakeLayer(c.Db, def_id, 1, LAYER_INPUT, fmt.Sprintf("%d,%d,1", model.Width, model.Height))
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
order := 1;
|
||||
|
||||
if complexity == 0 {
|
||||
// Note the shape for now is no used
|
||||
err = MakeLayer(c.Db, def_id, order, LAYER_INPUT, fmt.Sprintf("%d,%d,1", model.Width, model.Height))
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
order++;
|
||||
|
||||
err = MakeLayer(c.Db, def_id, 4, LAYER_FLATTEN, "")
|
||||
if complexity == 0 {
|
||||
|
||||
err = MakeLayer(c.Db, def_id, order, LAYER_FLATTEN, "")
|
||||
if err != nil {
|
||||
return failed()
|
||||
return failed()
|
||||
}
|
||||
order++;
|
||||
|
||||
loop := int(math.Log2(float64(number_of_classes))/2)
|
||||
for i := 0; i < loop; i++ {
|
||||
err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)))
|
||||
order++;
|
||||
if err != nil {
|
||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return c.Error500(err)
|
||||
}
|
||||
}
|
||||
|
||||
loop := int(math.Log2(float64(number_of_classes)))
|
||||
for i := 0; i < loop; i++ {
|
||||
err = MakeLayer(c.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop - i)))
|
||||
if err != nil {
|
||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return c.Error500(err)
|
||||
}
|
||||
} else if (complexity == 1) {
|
||||
|
||||
loop := int((math.Log(float64(model.Width))/math.Log(float64(10))))
|
||||
if loop == 0 {
|
||||
loop = 1;
|
||||
}
|
||||
for i := 0; i < loop; i++ {
|
||||
err = MakeLayer(c.Db, def_id, order, LAYER_SIMPLE_BLOCK, "")
|
||||
order++;
|
||||
if err != nil {
|
||||
return failed();
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
c.Logger.Error("Unkown complexity", "complexity", complexity)
|
||||
return failed()
|
||||
}
|
||||
err = MakeLayer(c.Db, def_id, order, LAYER_FLATTEN, "")
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
order++;
|
||||
|
||||
err = ModelDefinitionUpdateStatus(c, def_id, MODEL_DEFINITION_STATUS_INIT)
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
loop = int((math.Log(float64(number_of_classes))/math.Log(float64(10)))/2)
|
||||
if loop == 0 {
|
||||
loop = 1;
|
||||
}
|
||||
for i := 0; i < loop; i++ {
|
||||
err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)))
|
||||
order++;
|
||||
if err != nil {
|
||||
return failed();
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
} else {
|
||||
c.Logger.Error("Unkown complexity", "complexity", complexity)
|
||||
return failed()
|
||||
}
|
||||
|
||||
err = ModelDefinitionUpdateStatus(c, def_id, MODEL_DEFINITION_STATUS_INIT)
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateDefinitions(c *Context, model *BaseModel, number_of_models int) *Error {
|
||||
func generateDefinitions(c *Context, model *BaseModel, target_accuracy int, number_of_models int) *Error {
|
||||
cls, err := model_classes.ListClasses(c.Db, model.Id)
|
||||
if err != nil {
|
||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
||||
@@ -481,12 +523,21 @@ func generateDefinitions(c *Context, model *BaseModel, number_of_models int) *Er
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
for i := 0; i < number_of_models; i++ {
|
||||
if (number_of_models == 1) {
|
||||
if (model.Width < 100 && model.Height < 100 && len(cls) < 30) {
|
||||
generateDefinition(c, model, target_accuracy, len(cls), 0)
|
||||
} else {
|
||||
generateDefinition(c, model, target_accuracy, len(cls), 1)
|
||||
}
|
||||
} else {
|
||||
// TODO handle incrisea the complexity
|
||||
generateDefinition(c, model, len(cls), 0)
|
||||
}
|
||||
for i := 0; i < number_of_models; i++ {
|
||||
generateDefinition(c, model, target_accuracy, len(cls), 0)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleTrain(handle *Handle) {
|
||||
@@ -551,10 +602,10 @@ func handleTrain(handle *Handle) {
|
||||
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||
}
|
||||
|
||||
full_error := generateDefinitions(c, model, number_of_models)
|
||||
if full_error != nil {
|
||||
return full_error
|
||||
}
|
||||
full_error := generateDefinitions(c, model, accuracy, number_of_models)
|
||||
if full_error != nil {
|
||||
return full_error
|
||||
}
|
||||
|
||||
go trainModel(c, model)
|
||||
|
||||
@@ -573,11 +624,15 @@ func handleTrain(handle *Handle) {
|
||||
|
||||
f := r.URL.Query()
|
||||
|
||||
if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") {
|
||||
c.Logger.Warn("Invalid: model_id or definition or epoch")
|
||||
accuracy := 0.0
|
||||
|
||||
if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") || !CheckFloat64(f, "accuracy", &accuracy){
|
||||
c.Logger.Warn("Invalid: model_id or definition or epoch or accuracy")
|
||||
return c.UnsafeErrorCode(nil, 400, nil)
|
||||
}
|
||||
|
||||
accuracy = accuracy * 100
|
||||
|
||||
model_id := f.Get("model_id")
|
||||
def_id := f.Get("definition")
|
||||
epoch, err := strconv.Atoi(f.Get("epoch"))
|
||||
@@ -610,7 +665,9 @@ func handleTrain(handle *Handle) {
|
||||
return c.UnsafeErrorCode(nil, 400, nil)
|
||||
}
|
||||
|
||||
_, err = c.Db.Exec("update model_definition set epoch_progress=$1 where id=$2", epoch, def_id)
|
||||
c.Logger.Info("Updated model_definition!", "model", model_id, "progress", epoch, "accuracy", accuracy)
|
||||
|
||||
_, err = c.Db.Exec("update model_definition set epoch_progress=$1, accuracy=$2 where id=$3", epoch, accuracy, def_id)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
@@ -31,6 +31,21 @@ func CheckNumber(f url.Values, path string, number *int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func CheckFloat64(f url.Values, path string, number *float64) bool {
|
||||
if CheckEmpty(f, path) {
|
||||
fmt.Println("here", path)
|
||||
fmt.Println(f.Get(path))
|
||||
return false
|
||||
}
|
||||
n, err := strconv.ParseFloat(f.Get(path), 64)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return false
|
||||
}
|
||||
*number = n
|
||||
return true
|
||||
}
|
||||
|
||||
func CheckId(f url.Values, path string) bool {
|
||||
return !CheckEmpty(f, path) && IsValidUUID(f.Get(path))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user