From c844aeabe4925e0b77f20aa83574d66be4589692 Mon Sep 17 00:00:00 2001 From: Andre Henriques Date: Sun, 22 Oct 2023 23:02:39 +0100 Subject: [PATCH] closes #49 and possible done #46 --- .air.toml | 2 +- logic/models/add.go | 43 ++--- logic/models/test.go | 41 ++--- logic/models/train/main.go | 8 +- logic/models/train/train.go | 257 +++++++++++++++++++++--------- main.go | 2 + sql/models.sql | 2 +- views/models/edit.html | 26 ++- views/py/python_model_template.py | 10 +- 9 files changed, 256 insertions(+), 135 deletions(-) diff --git a/.air.toml b/.air.toml index 4d1231b..2d1d980 100644 --- a/.air.toml +++ b/.air.toml @@ -14,7 +14,7 @@ tmp_dir = "tmp" follow_symlink = false full_bin = "" include_dir = [] - include_ext = ["go", "tpl", "tmpl", "html"] + include_ext = ["go", "tpl", "tmpl"] # , "html" include_file = [] kill_delay = "0s" log = "build-errors.log" diff --git a/logic/models/add.go b/logic/models/add.go index 5b573e0..0149164 100644 --- a/logic/models/add.go +++ b/logic/models/add.go @@ -2,7 +2,6 @@ package models import ( "bytes" - "fmt" "image" "image/color" _ "image/jpeg" @@ -21,7 +20,7 @@ func loadBaseImage(c *Context, id string) { infile, err := os.Open(path.Join("savedData", id, "baseimage.png")) if err != nil { c.Logger.Errorf("Failed to read image for model with id %s\n", id) - c.Logger.Error(err) + c.Logger.Error(err) ModelUpdateStatus(c, id, FAILED_PREPARING) return } @@ -30,7 +29,7 @@ func loadBaseImage(c *Context, id string) { src, format, err := image.Decode(infile) if err != nil { c.Logger.Errorf("Failed to decode image for model with id %s\n", id) - c.Logger.Error(err) + c.Logger.Error(err) ModelUpdateStatus(c, id, FAILED_PREPARING) return } @@ -38,8 +37,8 @@ func loadBaseImage(c *Context, id string) { case "png": case "jpeg": default: - // TODO better logging - fmt.Printf("Found unkown format '%s'\n", format) + ModelUpdateStatus(c, id, FAILED_PREPARING) + c.Logger.Error("Found unkown format '%s'\n", "format", format) panic("Handle diferent files than .png") } @@ -53,24 +52,26 @@ func loadBaseImage(c *Context, id string) { fallthrough case color.GrayModel: model_color = "greyscale" + case color.NRGBAModel: + fallthrough case color.YCbCrModel: model_color = "rgb" default: - fmt.Println("Do not know how to handle this color model") + c.Logger.Error("Do not know how to handle this color model") if src.ColorModel() == color.RGBA64Model { - fmt.Println("Color is rgb") - } else if src.ColorModel() == color.NRGBAModel { - fmt.Println("Color is nrgb") + c.Logger.Error("Color is rgb") + } else if src.ColorModel() == color.NRGBA64Model { + c.Logger.Error("Color is nrgb 64") } else if src.ColorModel() == color.AlphaModel { - fmt.Println("Color is alpha") + c.Logger.Error("Color is alpha") } else if src.ColorModel() == color.CMYKModel { - fmt.Println("Color is cmyk") + c.Logger.Error("Color is cmyk") } else { - fmt.Println("Other so assuming color") + c.Logger.Error("Other so assuming color") } - ModelUpdateStatus(c, id, -1) + ModelUpdateStatus(c, id, FAILED_PREPARING) return } @@ -91,7 +92,7 @@ func handleAdd(handle *Handle) { return nil } if c.Mode == JSON { - // TODO json + // TODO json panic("TODO JSON") } @@ -130,7 +131,7 @@ func handleAdd(handle *Handle) { row, err := handle.Db.Query("select id from models where name=$1 and user_id=$2;", name, c.User.Id) if err != nil { - return Error500(err) + return c.Error500(err) } if row.Next() { @@ -143,12 +144,12 @@ func handleAdd(handle *Handle) { _, err = handle.Db.Exec("insert into models (user_id, name) values ($1, $2)", c.User.Id, name) if err != nil { - return Error500(err) + return c.Error500(err) } row, err = handle.Db.Query("select id from models where name=$1 and user_id=$2;", name, c.User.Id) if err != nil { - return Error500(err) + return c.Error500(err) } if !row.Next() { @@ -158,7 +159,7 @@ func handleAdd(handle *Handle) { var id string err = row.Scan(&id) if err != nil { - return Error500(err) + return c.Error500(err) } // TODO mk this path configurable @@ -166,17 +167,17 @@ func handleAdd(handle *Handle) { err = os.Mkdir(dir_path, os.ModePerm) if err != nil { - return Error500(err) + return c.Error500(err) } f, err := os.Create(path.Join(dir_path, "baseimage.png")) if err != nil { - return Error500(err) + return c.Error500(err) } defer f.Close() f.Write(file) - fmt.Printf("Created model with id %s! Started to proccess image!\n", id) + c.Logger.Warn("Created model with id %s! Started to proccess image!\n", "id", id) go loadBaseImage(c, id) Redirect("/models/edit?id="+id, c.Mode, w, r) diff --git a/logic/models/test.go b/logic/models/test.go index bb33d7b..5b7f915 100644 --- a/logic/models/test.go +++ b/logic/models/test.go @@ -15,14 +15,14 @@ func testImgForModel(c *Context, model *BaseModel, path string) (result bool) { infile, err := os.Open(path) if err != nil { - c.Logger.Errorf("Failed to read image for model with id %s\nErr:%s", model.Id, err) + c.Logger.Errorf("Failed to read image for model with id %s\nErr:%s", model.Id, err) return } defer infile.Close() src, format, err := image.Decode(infile) if err != nil { - c.Logger.Errorf("Failed to decode image for model with id %s\nErr:%s", model.Id, err) + c.Logger.Errorf("Failed to decode image for model with id %s\nErr:%s", model.Id, err) return } @@ -32,18 +32,21 @@ func testImgForModel(c *Context, model *BaseModel, path string) (result bool) { width, height := bounds.Max.X, bounds.Max.Y switch src.ColorModel() { - case color.Gray16Model: fallthrough + case color.Gray16Model: + fallthrough case color.GrayModel: model_color = "greyscale" + case color.NRGBAModel: + fallthrough case color.YCbCrModel: model_color = "rgb" default: - c.Logger.Error("Do not know how to handle this color model") + c.Logger.Error("Do not know how to handle this color model") if src.ColorModel() == color.RGBA64Model { c.Logger.Info("Color is rgb") - } else if src.ColorModel() == color.NRGBAModel { - c.Logger.Info("Color is nrgb") + } else if src.ColorModel() == color.NRGBA64Model { + c.Logger.Info("Color is nrgb 64") } else if src.ColorModel() == color.AlphaModel { c.Logger.Info("Color is alpha") } else if src.ColorModel() == color.CMYKModel { @@ -51,23 +54,23 @@ func testImgForModel(c *Context, model *BaseModel, path string) (result bool) { } else { c.Logger.Info("Other so assuming color") } - return + return } - if (StringToImageMode(model_color) != model.ImageMode) { - c.Logger.Warn("Color Mode does not match with model color mode", model_color, model.ImageMode) - return - } + if StringToImageMode(model_color) != model.ImageMode { + c.Logger.Warn("Color Mode does not match with model color mode", model_color, model.ImageMode) + return + } - if height != model.Height || width != model.Width { - c.Logger.Warn("Image size does not match model size", width, height, model.Width, model.Height) - return - } + if height != model.Height || width != model.Width { + c.Logger.Warn("Image size does not match model size", width, height, model.Width, model.Height) + return + } - if format != model.Format { - c.Logger.Warn("Image format does not match model", format, model.Format) - return - } + if format != model.Format { + c.Logger.Warn("Image format does not match model", format, model.Format) + return + } return true } diff --git a/logic/models/train/main.go b/logic/models/train/main.go index 966ec71..befdb5b 100644 --- a/logic/models/train/main.go +++ b/logic/models/train/main.go @@ -5,9 +5,9 @@ import ( ) func HandleTrainEndpoints(handle *Handle) { - handleTrain(handle) - handleRest(handle) + handleTrain(handle) + handleRest(handle) - //TODO remove - handleTest(handle) + //TODO remove + handleTest(handle) } diff --git a/logic/models/train/train.go b/logic/models/train/train.go index dd363fc..0f7605f 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -10,6 +10,7 @@ import ( "os" "os/exec" "path" + "sort" "strconv" "text/template" @@ -43,6 +44,7 @@ const ( MODEL_DEFINITION_STATUS_PRE_INIT ModelDefinitionStatus = 1 MODEL_DEFINITION_STATUS_INIT = 2 MODEL_DEFINITION_STATUS_TRAINING = 3 + MODEL_DEFINITION_STATUS_PAUSED_TRAINING = 6 MODEL_DEFINITION_STATUS_TRANIED = 4 MODEL_DEFINITION_STATUS_READY = 5 ) @@ -50,10 +52,10 @@ const ( type LayerType int const ( - LAYER_INPUT LayerType = 1 - LAYER_DENSE = 2 - LAYER_FLATTEN = 3 - LAYER_SIMPLE_BLOCK = 4 + LAYER_INPUT LayerType = 1 + LAYER_DENSE = 2 + LAYER_FLATTEN = 3 + LAYER_SIMPLE_BLOCK = 4 ) func ModelDefinitionUpdateStatus(c *Context, id string, status ModelDefinitionStatus) (err error) { @@ -142,6 +144,7 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr if err != nil { return } + defer os.RemoveAll(run_path) _, err = generateCvs(c, run_path, model.Id) if err != nil { @@ -174,29 +177,24 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr "DefId": definition_id, "LoadPrev": load_prev, "LastModelRunPath": path.Join(getDir(), result_path, "model.keras"), + "SaveModelPath": path.Join(getDir(), result_path), }); err != nil { return } // Run the command - out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).Output() + out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).CombinedOutput() if err != nil { c.Logger.Debug(string(out)) return } + c.Logger.Info("Python finished running") + if err = os.MkdirAll(result_path, os.ModePerm); err != nil { return } - if err = exec.Command("cp", "-r", path.Join(run_path, "model"), path.Join(result_path, "model")).Run(); err != nil { - return - } - - if err = exec.Command("cp", "-r", path.Join(run_path, "model.keras"), path.Join(result_path, "model.keras")).Run(); err != nil { - return - } - accuracy_file, err := os.Open(path.Join(run_path, "accuracy.val")) if err != nil { return @@ -214,8 +212,6 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr } c.Logger.Info("Model finished training!", "accuracy", accuracy) - - os.RemoveAll(run_path) return } @@ -236,6 +232,29 @@ func remove[T interface{}](lst []T, i int) []T { return append(lst[:i], lst[i+1:]...) } +type TrainModelRow struct { + id string + target_accuracy int + epoch int + acuracy float64 +} + +type TraingModelRowDefinitions []TrainModelRow + +func (nf TraingModelRowDefinitions) Len() int { return len(nf) } +func (nf TraingModelRowDefinitions) Swap(i, j int) { nf[i], nf[j] = nf[j], nf[i] } +func (nf TraingModelRowDefinitions) Less(i, j int) bool { + return nf[i].acuracy < nf[j].acuracy +} + +type ToRemoveList []int + +func (nf ToRemoveList) Len() int { return len(nf) } +func (nf ToRemoveList) Swap(i, j int) { nf[i], nf[j] = nf[j], nf[i] } +func (nf ToRemoveList) Less(i, j int) bool { + return nf[i] < nf[j] +} + func trainModel(c *Context, model *BaseModel) { definitionsRows, err := c.Db.Query("select id, target_accuracy, epoch from model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id) if err != nil { @@ -246,16 +265,11 @@ func trainModel(c *Context, model *BaseModel) { } defer definitionsRows.Close() - type row struct { - id string - target_accuracy int - epoch int - } - - definitions := []row{} + var definitions TraingModelRowDefinitions = []TrainModelRow{} for definitionsRows.Next() { - var rowv row + var rowv TrainModelRow + rowv.acuracy = 0 if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy, &rowv.epoch); err != nil { c.Logger.Error("Failed to train Model Could not read definition from db!Err:") c.Logger.Error(err) @@ -271,23 +285,23 @@ func trainModel(c *Context, model *BaseModel) { return } - toTrain := len(definitions) firstRound := true - var newDefinitions = []row{} - copy(newDefinitions, definitions) + finished := false + for { + var toRemove ToRemoveList = []int{} for i, def := range definitions { ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING) accuracy, err := trainDefinition(c, model, def.id, !firstRound) if err != nil { c.Logger.Error("Failed to train definition!Err:", "err", err) ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) - toTrain = toTrain - 1 - newDefinitions = remove(newDefinitions, i) + toRemove = append(toRemove, i) continue } def.epoch += EPOCH_PER_RUN accuracy = accuracy * 100 + def.acuracy = accuracy if accuracy >= float64(def.target_accuracy) { c.Logger.Info("Found a definition that reaches target_accuracy!") @@ -305,30 +319,68 @@ func trainModel(c *Context, model *BaseModel) { return } - toTrain = 0 + finished = true break } if def.epoch > MAX_EPOCH { fmt.Printf("Failed to train definition! Accuracy less %f < %d\n", accuracy, def.target_accuracy) ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) - toTrain = toTrain - 1 - newDefinitions = remove(newDefinitions, i) + toRemove = append(toRemove, i) continue } - _, err = c.Db.Exec("update model_definition set accuracy=$1, epoch=$2 where id=$3", accuracy, def.epoch, def.id) - if err != nil { - c.Logger.Error("Failed to train definition!Err:\n", "err", err) - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) - return - } + _, err = c.Db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.epoch, MODEL_DEFINITION_STATUS_PAUSED_TRAINING, def.id) + if err != nil { + c.Logger.Error("Failed to train definition!Err:\n", "err", err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return + } } - copy(definitions, newDefinitions) + firstRound = false - if toTrain == 0 { + if finished { break } + + sort.Reverse(toRemove) + + c.Logger.Info("Round done", "toRemove", toRemove) + + for _, n := range toRemove { + definitions = remove(definitions, n) + } + + len_def := len(definitions) + + if len_def == 0 { + break + } + + if len_def == 1 { + continue + } + + sort.Sort(definitions) + + acc := definitions[0].acuracy - 20 + + c.Logger.Info("Training models, Highest acc", "acc", acc) + + toRemove = []int{} + for i, def := range definitions { + if def.acuracy < acc { + toRemove = append(toRemove, i) + } + } + + c.Logger.Info("Removing due to accuracy", "toRemove", toRemove) + + sort.Reverse(toRemove) + for _, n := range toRemove { + c.Logger.Warn("Removing definition not fast enough learning", "n", n) + definitions = remove(definitions, n) + } } rows, err := c.Db.Query("select id from model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, MODEL_DEFINITION_STATUS_TRANIED) @@ -437,14 +489,26 @@ func generateDefinition(c *Context, model *BaseModel, target_accuracy int, numbe return failed() } - order := 1; + order := 1 - // Note the shape for now is no used - err = MakeLayer(c.Db, def_id, order, LAYER_INPUT, fmt.Sprintf("%d,%d,1", model.Width, model.Height)) - if err != nil { - return failed() + // Note the shape of the first layer defines the import size + if complexity == 2 { + // Note the shape for now is no used + width := int(math.Pow(2, math.Floor(math.Log(float64(model.Width))/math.Log(2.0)))) + height := int(math.Pow(2, math.Floor(math.Log(float64(model.Height))/math.Log(2.0)))) + c.Logger.Warn("Complexity 2 creating model with smaller size", "width", width, "height", height) + err = MakeLayer(c.Db, def_id, order, LAYER_INPUT, fmt.Sprintf("%d,%d,1", width, height)) + if err != nil { + return failed() + } + order++ + } else { + err = MakeLayer(c.Db, def_id, order, LAYER_INPUT, fmt.Sprintf("%d,%d,1", model.Width, model.Height)) + if err != nil { + return failed() + } + order++ } - order++; if complexity == 0 { @@ -452,12 +516,12 @@ func generateDefinition(c *Context, model *BaseModel, target_accuracy int, numbe if err != nil { return failed() } - order++; + order++ loop := int(math.Log2(float64(number_of_classes))) for i := 0; i < loop; i++ { err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i))) - order++; + order++ if err != nil { ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) // TODO improve this response @@ -465,17 +529,17 @@ func generateDefinition(c *Context, model *BaseModel, target_accuracy int, numbe } } - } else if (complexity == 1) { + } else if complexity == 1 { - loop := int((math.Log(float64(model.Width))/math.Log(float64(10)))) - if loop == 0 { - loop = 1; - } + loop := int((math.Log(float64(model.Width)) / math.Log(float64(10)))) + if loop == 0 { + loop = 1 + } for i := 0; i < loop; i++ { err = MakeLayer(c.Db, def_id, order, LAYER_SIMPLE_BLOCK, "") - order++; + order++ if err != nil { - return failed(); + return failed() } } @@ -483,17 +547,49 @@ func generateDefinition(c *Context, model *BaseModel, target_accuracy int, numbe if err != nil { return failed() } - order++; + order++ - loop = int((math.Log(float64(number_of_classes))/math.Log(float64(10)))/2) - if loop == 0 { - loop = 1; - } + loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2) + if loop == 0 { + loop = 1 + } for i := 0; i < loop; i++ { err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i))) - order++; + order++ if err != nil { - return failed(); + return failed() + } + } + + } else if complexity == 2 { + + loop := int((math.Log(float64(model.Width)) / math.Log(float64(10)))) + if loop == 0 { + loop = 1 + } + for i := 0; i < loop; i++ { + err = MakeLayer(c.Db, def_id, order, LAYER_SIMPLE_BLOCK, "") + order++ + if err != nil { + return failed() + } + } + + err = MakeLayer(c.Db, def_id, order, LAYER_FLATTEN, "") + if err != nil { + return failed() + } + order++ + + loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2) + if loop == 0 { + loop = 1 + } + for i := 0; i < loop; i++ { + err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i))) + order++ + if err != nil { + return failed() } } @@ -523,19 +619,26 @@ func generateDefinitions(c *Context, model *BaseModel, target_accuracy int, numb return c.Error500(err) } - if (number_of_models == 1) { - if (model.Width < 100 && model.Height < 100 && len(cls) < 30) { - generateDefinition(c, model, target_accuracy, len(cls), 0) - } else { - generateDefinition(c, model, target_accuracy, len(cls), 1) - } - } else { - // TODO handle incrisea the complexity - for i := 0; i < number_of_models; i++ { - generateDefinition(c, model, target_accuracy, len(cls), 0) - } - } + cls_len := len(cls) + if number_of_models == 1 { + if model.Width < 100 && model.Height < 100 && cls_len < 30 { + generateDefinition(c, model, target_accuracy, cls_len, 0) + } else if model.Width > 100 && model.Height > 100 { + generateDefinition(c, model, target_accuracy, cls_len, 2) + } else { + generateDefinition(c, model, target_accuracy, cls_len, 1) + } + } else if number_of_models == 3 { + for i := 0; i < number_of_models; i++ { + generateDefinition(c, model, target_accuracy, cls_len, i) + } + } else { + // TODO handle incrisea the complexity + for i := 0; i < number_of_models; i++ { + generateDefinition(c, model, target_accuracy, cls_len, 0) + } + } return nil } @@ -624,14 +727,14 @@ func handleTrain(handle *Handle) { f := r.URL.Query() - accuracy := 0.0 + accuracy := 0.0 - if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") || !CheckFloat64(f, "accuracy", &accuracy){ + if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") || !CheckFloat64(f, "accuracy", &accuracy) { c.Logger.Warn("Invalid: model_id or definition or epoch or accuracy") return c.UnsafeErrorCode(nil, 400, nil) } - accuracy = accuracy * 100 + accuracy = accuracy * 100 model_id := f.Get("model_id") def_id := f.Get("definition") @@ -665,7 +768,7 @@ func handleTrain(handle *Handle) { return c.UnsafeErrorCode(nil, 400, nil) } - c.Logger.Info("Updated model_definition!", "model", model_id, "progress", epoch, "accuracy", accuracy) + c.Logger.Info("Updated model_definition!", "model", model_id, "progress", epoch, "accuracy", accuracy) _, err = c.Db.Exec("update model_definition set epoch_progress=$1, accuracy=$2 where id=$3", epoch, accuracy, def_id) if err != nil { diff --git a/main.go b/main.go index 0017644..570caf0 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" + "github.com/charmbracelet/log" _ "github.com/lib/pq" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models" @@ -36,6 +37,7 @@ func main() { _, err = db.Exec("update models set status=$1 where status=$2", models_utils.FAILED_TRAINING, models_utils.TRAINING) if err != nil { + log.Warn("Database might not be on") panic(err) } diff --git a/sql/models.sql b/sql/models.sql index ba22056..aa57d62 100644 --- a/sql/models.sql +++ b/sql/models.sql @@ -14,7 +14,7 @@ create table if not exists models ( width integer, height integer, color_mode varchar (20), - format varchar (20) + format varchar (20) default '' ); -- drop table if exists model_data_point; diff --git a/views/models/edit.html b/views/models/edit.html index 011d7c6..aacc485 100644 --- a/views/models/edit.html +++ b/views/models/edit.html @@ -438,27 +438,37 @@ - Status - - - EpochProgress + Training Round Progress Accuracy + + Status + {{ range .Defs}} - {{.Status}} + {{.EpochProgress}}/20 - {{.EpochProgress}} + {{.Accuracy}}% - - {{.Accuracy}} + + {{ if (eq .Status 2) }} + + {{ else if (eq .Status 3) }} + + {{ else if (eq .Status 6) }} + + {{ else if (eq .Status -3) }} + + {{ else }} + {{.Status}} + {{ end }} {{ end }} diff --git a/views/py/python_model_template.py b/views/py/python_model_template.py index 7fc2ec5..4c28085 100644 --- a/views/py/python_model_template.py +++ b/views/py/python_model_template.py @@ -8,7 +8,7 @@ import requests class NotifyServerCallback(tf.keras.callbacks.Callback): def on_epoch_end(self, epoch, log, *args, **kwargs): - requests.get(f'http://localhost:8000/model/epoch/update?model_id={{.Model.Id}}&epoch={epoch}&accuracy={log["accuracy"]}&definition={{.DefId}}') + requests.get(f'http://localhost:8000/model/epoch/update?model_id={{.Model.Id}}&epoch={epoch + 1}&accuracy={log["accuracy"]}&definition={{.DefId}}') DATA_DIR = "{{ .DataDir }}" @@ -160,7 +160,9 @@ model.compile( optimizer=tf.keras.optimizers.Adam(), metrics=['accuracy']) -his = model.fit(dataset, validation_data= dataset_validation, epochs={{.EPOCH_PER_RUN}}, callbacks=[NotifyServerCallback()], use_multiprocessing = True) +his = model.fit(dataset, validation_data= dataset_validation, epochs={{.EPOCH_PER_RUN}}, callbacks=[ + NotifyServerCallback(), + tf.keras.callbacks.EarlyStopping("loss", mode="min", patience=5)], use_multiprocessing = True) acc = his.history["accuracy"] @@ -169,5 +171,5 @@ f.write(str(acc[-1])) f.close() -tf.saved_model.save(model, "model") -model.save("model.keras") +tf.saved_model.save(model, "{{ .SaveModelPath }}/model") +model.save("{{ .SaveModelPath }}/model.keras")