closes #49 and possible done #46

This commit is contained in:
Andre Henriques 2023-10-22 23:02:39 +01:00
parent 90bc3f6acf
commit c844aeabe4
9 changed files with 256 additions and 135 deletions

View File

@ -14,7 +14,7 @@ tmp_dir = "tmp"
follow_symlink = false follow_symlink = false
full_bin = "" full_bin = ""
include_dir = [] include_dir = []
include_ext = ["go", "tpl", "tmpl", "html"] include_ext = ["go", "tpl", "tmpl"] # , "html"
include_file = [] include_file = []
kill_delay = "0s" kill_delay = "0s"
log = "build-errors.log" log = "build-errors.log"

View File

@ -2,7 +2,6 @@ package models
import ( import (
"bytes" "bytes"
"fmt"
"image" "image"
"image/color" "image/color"
_ "image/jpeg" _ "image/jpeg"
@ -38,8 +37,8 @@ func loadBaseImage(c *Context, id string) {
case "png": case "png":
case "jpeg": case "jpeg":
default: default:
// TODO better logging ModelUpdateStatus(c, id, FAILED_PREPARING)
fmt.Printf("Found unkown format '%s'\n", format) c.Logger.Error("Found unkown format '%s'\n", "format", format)
panic("Handle diferent files than .png") panic("Handle diferent files than .png")
} }
@ -53,24 +52,26 @@ func loadBaseImage(c *Context, id string) {
fallthrough fallthrough
case color.GrayModel: case color.GrayModel:
model_color = "greyscale" model_color = "greyscale"
case color.NRGBAModel:
fallthrough
case color.YCbCrModel: case color.YCbCrModel:
model_color = "rgb" model_color = "rgb"
default: default:
fmt.Println("Do not know how to handle this color model") c.Logger.Error("Do not know how to handle this color model")
if src.ColorModel() == color.RGBA64Model { if src.ColorModel() == color.RGBA64Model {
fmt.Println("Color is rgb") c.Logger.Error("Color is rgb")
} else if src.ColorModel() == color.NRGBAModel { } else if src.ColorModel() == color.NRGBA64Model {
fmt.Println("Color is nrgb") c.Logger.Error("Color is nrgb 64")
} else if src.ColorModel() == color.AlphaModel { } else if src.ColorModel() == color.AlphaModel {
fmt.Println("Color is alpha") c.Logger.Error("Color is alpha")
} else if src.ColorModel() == color.CMYKModel { } else if src.ColorModel() == color.CMYKModel {
fmt.Println("Color is cmyk") c.Logger.Error("Color is cmyk")
} else { } else {
fmt.Println("Other so assuming color") c.Logger.Error("Other so assuming color")
} }
ModelUpdateStatus(c, id, -1) ModelUpdateStatus(c, id, FAILED_PREPARING)
return return
} }
@ -130,7 +131,7 @@ func handleAdd(handle *Handle) {
row, err := handle.Db.Query("select id from models where name=$1 and user_id=$2;", name, c.User.Id) row, err := handle.Db.Query("select id from models where name=$1 and user_id=$2;", name, c.User.Id)
if err != nil { if err != nil {
return Error500(err) return c.Error500(err)
} }
if row.Next() { if row.Next() {
@ -143,12 +144,12 @@ func handleAdd(handle *Handle) {
_, err = handle.Db.Exec("insert into models (user_id, name) values ($1, $2)", c.User.Id, name) _, err = handle.Db.Exec("insert into models (user_id, name) values ($1, $2)", c.User.Id, name)
if err != nil { if err != nil {
return Error500(err) return c.Error500(err)
} }
row, err = handle.Db.Query("select id from models where name=$1 and user_id=$2;", name, c.User.Id) row, err = handle.Db.Query("select id from models where name=$1 and user_id=$2;", name, c.User.Id)
if err != nil { if err != nil {
return Error500(err) return c.Error500(err)
} }
if !row.Next() { if !row.Next() {
@ -158,7 +159,7 @@ func handleAdd(handle *Handle) {
var id string var id string
err = row.Scan(&id) err = row.Scan(&id)
if err != nil { if err != nil {
return Error500(err) return c.Error500(err)
} }
// TODO mk this path configurable // TODO mk this path configurable
@ -166,17 +167,17 @@ func handleAdd(handle *Handle) {
err = os.Mkdir(dir_path, os.ModePerm) err = os.Mkdir(dir_path, os.ModePerm)
if err != nil { if err != nil {
return Error500(err) return c.Error500(err)
} }
f, err := os.Create(path.Join(dir_path, "baseimage.png")) f, err := os.Create(path.Join(dir_path, "baseimage.png"))
if err != nil { if err != nil {
return Error500(err) return c.Error500(err)
} }
defer f.Close() defer f.Close()
f.Write(file) f.Write(file)
fmt.Printf("Created model with id %s! Started to proccess image!\n", id) c.Logger.Warn("Created model with id %s! Started to proccess image!\n", "id", id)
go loadBaseImage(c, id) go loadBaseImage(c, id)
Redirect("/models/edit?id="+id, c.Mode, w, r) Redirect("/models/edit?id="+id, c.Mode, w, r)

View File

@ -32,9 +32,12 @@ func testImgForModel(c *Context, model *BaseModel, path string) (result bool) {
width, height := bounds.Max.X, bounds.Max.Y width, height := bounds.Max.X, bounds.Max.Y
switch src.ColorModel() { switch src.ColorModel() {
case color.Gray16Model: fallthrough case color.Gray16Model:
fallthrough
case color.GrayModel: case color.GrayModel:
model_color = "greyscale" model_color = "greyscale"
case color.NRGBAModel:
fallthrough
case color.YCbCrModel: case color.YCbCrModel:
model_color = "rgb" model_color = "rgb"
default: default:
@ -42,8 +45,8 @@ func testImgForModel(c *Context, model *BaseModel, path string) (result bool) {
if src.ColorModel() == color.RGBA64Model { if src.ColorModel() == color.RGBA64Model {
c.Logger.Info("Color is rgb") c.Logger.Info("Color is rgb")
} else if src.ColorModel() == color.NRGBAModel { } else if src.ColorModel() == color.NRGBA64Model {
c.Logger.Info("Color is nrgb") c.Logger.Info("Color is nrgb 64")
} else if src.ColorModel() == color.AlphaModel { } else if src.ColorModel() == color.AlphaModel {
c.Logger.Info("Color is alpha") c.Logger.Info("Color is alpha")
} else if src.ColorModel() == color.CMYKModel { } else if src.ColorModel() == color.CMYKModel {
@ -54,7 +57,7 @@ func testImgForModel(c *Context, model *BaseModel, path string) (result bool) {
return return
} }
if (StringToImageMode(model_color) != model.ImageMode) { if StringToImageMode(model_color) != model.ImageMode {
c.Logger.Warn("Color Mode does not match with model color mode", model_color, model.ImageMode) c.Logger.Warn("Color Mode does not match with model color mode", model_color, model.ImageMode)
return return
} }

View File

@ -10,6 +10,7 @@ import (
"os" "os"
"os/exec" "os/exec"
"path" "path"
"sort"
"strconv" "strconv"
"text/template" "text/template"
@ -43,6 +44,7 @@ const (
MODEL_DEFINITION_STATUS_PRE_INIT ModelDefinitionStatus = 1 MODEL_DEFINITION_STATUS_PRE_INIT ModelDefinitionStatus = 1
MODEL_DEFINITION_STATUS_INIT = 2 MODEL_DEFINITION_STATUS_INIT = 2
MODEL_DEFINITION_STATUS_TRAINING = 3 MODEL_DEFINITION_STATUS_TRAINING = 3
MODEL_DEFINITION_STATUS_PAUSED_TRAINING = 6
MODEL_DEFINITION_STATUS_TRANIED = 4 MODEL_DEFINITION_STATUS_TRANIED = 4
MODEL_DEFINITION_STATUS_READY = 5 MODEL_DEFINITION_STATUS_READY = 5
) )
@ -142,6 +144,7 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr
if err != nil { if err != nil {
return return
} }
defer os.RemoveAll(run_path)
_, err = generateCvs(c, run_path, model.Id) _, err = generateCvs(c, run_path, model.Id)
if err != nil { if err != nil {
@ -174,29 +177,24 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr
"DefId": definition_id, "DefId": definition_id,
"LoadPrev": load_prev, "LoadPrev": load_prev,
"LastModelRunPath": path.Join(getDir(), result_path, "model.keras"), "LastModelRunPath": path.Join(getDir(), result_path, "model.keras"),
"SaveModelPath": path.Join(getDir(), result_path),
}); err != nil { }); err != nil {
return return
} }
// Run the command // Run the command
out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).Output() out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).CombinedOutput()
if err != nil { if err != nil {
c.Logger.Debug(string(out)) c.Logger.Debug(string(out))
return return
} }
c.Logger.Info("Python finished running")
if err = os.MkdirAll(result_path, os.ModePerm); err != nil { if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
return return
} }
if err = exec.Command("cp", "-r", path.Join(run_path, "model"), path.Join(result_path, "model")).Run(); err != nil {
return
}
if err = exec.Command("cp", "-r", path.Join(run_path, "model.keras"), path.Join(result_path, "model.keras")).Run(); err != nil {
return
}
accuracy_file, err := os.Open(path.Join(run_path, "accuracy.val")) accuracy_file, err := os.Open(path.Join(run_path, "accuracy.val"))
if err != nil { if err != nil {
return return
@ -214,8 +212,6 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr
} }
c.Logger.Info("Model finished training!", "accuracy", accuracy) c.Logger.Info("Model finished training!", "accuracy", accuracy)
os.RemoveAll(run_path)
return return
} }
@ -236,6 +232,29 @@ func remove[T interface{}](lst []T, i int) []T {
return append(lst[:i], lst[i+1:]...) return append(lst[:i], lst[i+1:]...)
} }
type TrainModelRow struct {
id string
target_accuracy int
epoch int
acuracy float64
}
type TraingModelRowDefinitions []TrainModelRow
func (nf TraingModelRowDefinitions) Len() int { return len(nf) }
func (nf TraingModelRowDefinitions) Swap(i, j int) { nf[i], nf[j] = nf[j], nf[i] }
func (nf TraingModelRowDefinitions) Less(i, j int) bool {
return nf[i].acuracy < nf[j].acuracy
}
type ToRemoveList []int
func (nf ToRemoveList) Len() int { return len(nf) }
func (nf ToRemoveList) Swap(i, j int) { nf[i], nf[j] = nf[j], nf[i] }
func (nf ToRemoveList) Less(i, j int) bool {
return nf[i] < nf[j]
}
func trainModel(c *Context, model *BaseModel) { func trainModel(c *Context, model *BaseModel) {
definitionsRows, err := c.Db.Query("select id, target_accuracy, epoch from model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id) definitionsRows, err := c.Db.Query("select id, target_accuracy, epoch from model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id)
if err != nil { if err != nil {
@ -246,16 +265,11 @@ func trainModel(c *Context, model *BaseModel) {
} }
defer definitionsRows.Close() defer definitionsRows.Close()
type row struct { var definitions TraingModelRowDefinitions = []TrainModelRow{}
id string
target_accuracy int
epoch int
}
definitions := []row{}
for definitionsRows.Next() { for definitionsRows.Next() {
var rowv row var rowv TrainModelRow
rowv.acuracy = 0
if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy, &rowv.epoch); err != nil { if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy, &rowv.epoch); err != nil {
c.Logger.Error("Failed to train Model Could not read definition from db!Err:") c.Logger.Error("Failed to train Model Could not read definition from db!Err:")
c.Logger.Error(err) c.Logger.Error(err)
@ -271,23 +285,23 @@ func trainModel(c *Context, model *BaseModel) {
return return
} }
toTrain := len(definitions)
firstRound := true firstRound := true
var newDefinitions = []row{} finished := false
copy(newDefinitions, definitions)
for { for {
var toRemove ToRemoveList = []int{}
for i, def := range definitions { for i, def := range definitions {
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING) ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING)
accuracy, err := trainDefinition(c, model, def.id, !firstRound) accuracy, err := trainDefinition(c, model, def.id, !firstRound)
if err != nil { if err != nil {
c.Logger.Error("Failed to train definition!Err:", "err", err) c.Logger.Error("Failed to train definition!Err:", "err", err)
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
toTrain = toTrain - 1 toRemove = append(toRemove, i)
newDefinitions = remove(newDefinitions, i)
continue continue
} }
def.epoch += EPOCH_PER_RUN def.epoch += EPOCH_PER_RUN
accuracy = accuracy * 100 accuracy = accuracy * 100
def.acuracy = accuracy
if accuracy >= float64(def.target_accuracy) { if accuracy >= float64(def.target_accuracy) {
c.Logger.Info("Found a definition that reaches target_accuracy!") c.Logger.Info("Found a definition that reaches target_accuracy!")
@ -305,30 +319,68 @@ func trainModel(c *Context, model *BaseModel) {
return return
} }
toTrain = 0 finished = true
break break
} }
if def.epoch > MAX_EPOCH { if def.epoch > MAX_EPOCH {
fmt.Printf("Failed to train definition! Accuracy less %f < %d\n", accuracy, def.target_accuracy) fmt.Printf("Failed to train definition! Accuracy less %f < %d\n", accuracy, def.target_accuracy)
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
toTrain = toTrain - 1 toRemove = append(toRemove, i)
newDefinitions = remove(newDefinitions, i)
continue continue
} }
_, err = c.Db.Exec("update model_definition set accuracy=$1, epoch=$2 where id=$3", accuracy, def.epoch, def.id) _, err = c.Db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.epoch, MODEL_DEFINITION_STATUS_PAUSED_TRAINING, def.id)
if err != nil { if err != nil {
c.Logger.Error("Failed to train definition!Err:\n", "err", err) c.Logger.Error("Failed to train definition!Err:\n", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING) ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return return
} }
} }
copy(definitions, newDefinitions)
firstRound = false firstRound = false
if toTrain == 0 { if finished {
break break
} }
sort.Reverse(toRemove)
c.Logger.Info("Round done", "toRemove", toRemove)
for _, n := range toRemove {
definitions = remove(definitions, n)
}
len_def := len(definitions)
if len_def == 0 {
break
}
if len_def == 1 {
continue
}
sort.Sort(definitions)
acc := definitions[0].acuracy - 20
c.Logger.Info("Training models, Highest acc", "acc", acc)
toRemove = []int{}
for i, def := range definitions {
if def.acuracy < acc {
toRemove = append(toRemove, i)
}
}
c.Logger.Info("Removing due to accuracy", "toRemove", toRemove)
sort.Reverse(toRemove)
for _, n := range toRemove {
c.Logger.Warn("Removing definition not fast enough learning", "n", n)
definitions = remove(definitions, n)
}
} }
rows, err := c.Db.Query("select id from model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, MODEL_DEFINITION_STATUS_TRANIED) rows, err := c.Db.Query("select id from model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, MODEL_DEFINITION_STATUS_TRANIED)
@ -437,14 +489,26 @@ func generateDefinition(c *Context, model *BaseModel, target_accuracy int, numbe
return failed() return failed()
} }
order := 1; order := 1
// Note the shape of the first layer defines the import size
if complexity == 2 {
// Note the shape for now is no used // Note the shape for now is no used
width := int(math.Pow(2, math.Floor(math.Log(float64(model.Width))/math.Log(2.0))))
height := int(math.Pow(2, math.Floor(math.Log(float64(model.Height))/math.Log(2.0))))
c.Logger.Warn("Complexity 2 creating model with smaller size", "width", width, "height", height)
err = MakeLayer(c.Db, def_id, order, LAYER_INPUT, fmt.Sprintf("%d,%d,1", width, height))
if err != nil {
return failed()
}
order++
} else {
err = MakeLayer(c.Db, def_id, order, LAYER_INPUT, fmt.Sprintf("%d,%d,1", model.Width, model.Height)) err = MakeLayer(c.Db, def_id, order, LAYER_INPUT, fmt.Sprintf("%d,%d,1", model.Width, model.Height))
if err != nil { if err != nil {
return failed() return failed()
} }
order++; order++
}
if complexity == 0 { if complexity == 0 {
@ -452,12 +516,12 @@ func generateDefinition(c *Context, model *BaseModel, target_accuracy int, numbe
if err != nil { if err != nil {
return failed() return failed()
} }
order++; order++
loop := int(math.Log2(float64(number_of_classes))) loop := int(math.Log2(float64(number_of_classes)))
for i := 0; i < loop; i++ { for i := 0; i < loop; i++ {
err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i))) err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)))
order++; order++
if err != nil { if err != nil {
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
// TODO improve this response // TODO improve this response
@ -465,17 +529,17 @@ func generateDefinition(c *Context, model *BaseModel, target_accuracy int, numbe
} }
} }
} else if (complexity == 1) { } else if complexity == 1 {
loop := int((math.Log(float64(model.Width)) / math.Log(float64(10)))) loop := int((math.Log(float64(model.Width)) / math.Log(float64(10))))
if loop == 0 { if loop == 0 {
loop = 1; loop = 1
} }
for i := 0; i < loop; i++ { for i := 0; i < loop; i++ {
err = MakeLayer(c.Db, def_id, order, LAYER_SIMPLE_BLOCK, "") err = MakeLayer(c.Db, def_id, order, LAYER_SIMPLE_BLOCK, "")
order++; order++
if err != nil { if err != nil {
return failed(); return failed()
} }
} }
@ -483,17 +547,49 @@ func generateDefinition(c *Context, model *BaseModel, target_accuracy int, numbe
if err != nil { if err != nil {
return failed() return failed()
} }
order++; order++
loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2) loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2)
if loop == 0 { if loop == 0 {
loop = 1; loop = 1
} }
for i := 0; i < loop; i++ { for i := 0; i < loop; i++ {
err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i))) err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)))
order++; order++
if err != nil { if err != nil {
return failed(); return failed()
}
}
} else if complexity == 2 {
loop := int((math.Log(float64(model.Width)) / math.Log(float64(10))))
if loop == 0 {
loop = 1
}
for i := 0; i < loop; i++ {
err = MakeLayer(c.Db, def_id, order, LAYER_SIMPLE_BLOCK, "")
order++
if err != nil {
return failed()
}
}
err = MakeLayer(c.Db, def_id, order, LAYER_FLATTEN, "")
if err != nil {
return failed()
}
order++
loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2)
if loop == 0 {
loop = 1
}
for i := 0; i < loop; i++ {
err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)))
order++
if err != nil {
return failed()
} }
} }
@ -523,20 +619,27 @@ func generateDefinitions(c *Context, model *BaseModel, target_accuracy int, numb
return c.Error500(err) return c.Error500(err)
} }
if (number_of_models == 1) { cls_len := len(cls)
if (model.Width < 100 && model.Height < 100 && len(cls) < 30) {
generateDefinition(c, model, target_accuracy, len(cls), 0) if number_of_models == 1 {
if model.Width < 100 && model.Height < 100 && cls_len < 30 {
generateDefinition(c, model, target_accuracy, cls_len, 0)
} else if model.Width > 100 && model.Height > 100 {
generateDefinition(c, model, target_accuracy, cls_len, 2)
} else { } else {
generateDefinition(c, model, target_accuracy, len(cls), 1) generateDefinition(c, model, target_accuracy, cls_len, 1)
}
} else if number_of_models == 3 {
for i := 0; i < number_of_models; i++ {
generateDefinition(c, model, target_accuracy, cls_len, i)
} }
} else { } else {
// TODO handle incrisea the complexity // TODO handle incrisea the complexity
for i := 0; i < number_of_models; i++ { for i := 0; i < number_of_models; i++ {
generateDefinition(c, model, target_accuracy, len(cls), 0) generateDefinition(c, model, target_accuracy, cls_len, 0)
} }
} }
return nil return nil
} }

View File

@ -4,6 +4,7 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/charmbracelet/log"
_ "github.com/lib/pq" _ "github.com/lib/pq"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models"
@ -36,6 +37,7 @@ func main() {
_, err = db.Exec("update models set status=$1 where status=$2", models_utils.FAILED_TRAINING, models_utils.TRAINING) _, err = db.Exec("update models set status=$1 where status=$2", models_utils.FAILED_TRAINING, models_utils.TRAINING)
if err != nil { if err != nil {
log.Warn("Database might not be on")
panic(err) panic(err)
} }

View File

@ -14,7 +14,7 @@ create table if not exists models (
width integer, width integer,
height integer, height integer,
color_mode varchar (20), color_mode varchar (20),
format varchar (20) format varchar (20) default ''
); );
-- drop table if exists model_data_point; -- drop table if exists model_data_point;

View File

@ -438,27 +438,37 @@
<thead> <thead>
<tr> <tr>
<th> <th>
Status Training Round Progress
</th>
<th>
EpochProgress
</th> </th>
<th> <th>
Accuracy Accuracy
</th> </th>
<th>
Status
</th>
</tr> </tr>
</thead> </thead>
<tbody> <tbody>
{{ range .Defs}} {{ range .Defs}}
<tr> <tr>
<td> <td>
{{.EpochProgress}}/20
</td>
<td>
{{.Accuracy}}%
</td>
<td style="text-align: center;">
{{ if (eq .Status 2) }}
<span class="bi bi-book" style="color: green;"></span>
{{ else if (eq .Status 3) }}
<span class="bi bi-book-half" style="color: green;"></span>
{{ else if (eq .Status 6) }}
<span class="bi bi-book-half" style="color: orange;"></span>
{{ else if (eq .Status -3) }}
<span class="bi bi-book-half" style="color: red;"></span>
{{ else }}
{{.Status}} {{.Status}}
</td> {{ end }}
<td>
{{.EpochProgress}}
</td>
<td>
{{.Accuracy}}
</td> </td>
</tr> </tr>
{{ end }} {{ end }}

View File

@ -8,7 +8,7 @@ import requests
class NotifyServerCallback(tf.keras.callbacks.Callback): class NotifyServerCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, log, *args, **kwargs): def on_epoch_end(self, epoch, log, *args, **kwargs):
requests.get(f'http://localhost:8000/model/epoch/update?model_id={{.Model.Id}}&epoch={epoch}&accuracy={log["accuracy"]}&definition={{.DefId}}') requests.get(f'http://localhost:8000/model/epoch/update?model_id={{.Model.Id}}&epoch={epoch + 1}&accuracy={log["accuracy"]}&definition={{.DefId}}')
DATA_DIR = "{{ .DataDir }}" DATA_DIR = "{{ .DataDir }}"
@ -160,7 +160,9 @@ model.compile(
optimizer=tf.keras.optimizers.Adam(), optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy']) metrics=['accuracy'])
his = model.fit(dataset, validation_data= dataset_validation, epochs={{.EPOCH_PER_RUN}}, callbacks=[NotifyServerCallback()], use_multiprocessing = True) his = model.fit(dataset, validation_data= dataset_validation, epochs={{.EPOCH_PER_RUN}}, callbacks=[
NotifyServerCallback(),
tf.keras.callbacks.EarlyStopping("loss", mode="min", patience=5)], use_multiprocessing = True)
acc = his.history["accuracy"] acc = his.history["accuracy"]
@ -169,5 +171,5 @@ f.write(str(acc[-1]))
f.close() f.close()
tf.saved_model.save(model, "model") tf.saved_model.save(model, "{{ .SaveModelPath }}/model")
model.save("model.keras") model.save("{{ .SaveModelPath }}/model.keras")