moved code arround
This commit is contained in:
parent
cbfd49c5fb
commit
6a0ac457d7
@ -42,6 +42,11 @@ func ModelDefinitionUpdateStatus(c *Context, id string, status ModelDefinitionSt
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func UpdateStatus (c *Context, table string, id string, status int) (err error) {
|
||||||
|
_, err = c.Db.Exec(fmt.Sprintf("update %s set status = $1 where id = $2", table), status, id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type LayerType, shape string) (err error) {
|
func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type LayerType, shape string) (err error) {
|
||||||
_, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape) values ($1, $2, $3, $4)", def_id, layer_order, layer_type, shape)
|
_, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape) values ($1, $2, $3, $4)", def_id, layer_order, layer_type, shape)
|
||||||
return
|
return
|
||||||
@ -108,9 +113,11 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr
|
|||||||
type layerrow struct {
|
type layerrow struct {
|
||||||
LayerType int
|
LayerType int
|
||||||
Shape string
|
Shape string
|
||||||
|
LayerNum int
|
||||||
}
|
}
|
||||||
|
|
||||||
got := []layerrow{}
|
got := []layerrow{}
|
||||||
|
i := 1
|
||||||
|
|
||||||
for layers.Next() {
|
for layers.Next() {
|
||||||
var row = layerrow{}
|
var row = layerrow{}
|
||||||
@ -118,7 +125,9 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
row.Shape = shapeToSize(row.Shape)
|
row.Shape = shapeToSize(row.Shape)
|
||||||
|
row.LayerNum = 1
|
||||||
got = append(got, row)
|
got = append(got, row)
|
||||||
|
i = i + 1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate run folder
|
// Generate run folder
|
||||||
@ -207,7 +216,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
|||||||
// Get untrained models heads
|
// Get untrained models heads
|
||||||
|
|
||||||
// Status = 2 (INIT)
|
// Status = 2 (INIT)
|
||||||
rows, err := c.Db.Query("select id, range_start, range_end exp_model_head where def_id=$1 and status = 2", definition_id)
|
rows, err := c.Db.Query("select id, range_start, range_end from exp_model_head where def_id=$1 and status = 2", definition_id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -237,6 +246,8 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
UpdateStatus(c, "exp_model_head", exp.id, MODEL_DEFINITION_STATUS_TRANIED)
|
||||||
|
|
||||||
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@ -300,7 +311,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
|||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
tmpl, err := template.New("python_model_template-exp.py").ParseFiles("views/py/python_model_template.py")
|
tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -312,6 +323,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
|||||||
"Layers": got,
|
"Layers": got,
|
||||||
"Size": got[0].Shape,
|
"Size": got[0].Shape,
|
||||||
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
|
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
|
||||||
|
"HeadId": exp.id,
|
||||||
"RunPath": run_path,
|
"RunPath": run_path,
|
||||||
"ColorMode": model.ImageMode,
|
"ColorMode": model.ImageMode,
|
||||||
"Model": model,
|
"Model": model,
|
||||||
@ -613,17 +625,14 @@ func trainModelExp(c *Context, model *BaseModel) {
|
|||||||
var rowv TrainModelRow
|
var rowv TrainModelRow
|
||||||
rowv.acuracy = 0
|
rowv.acuracy = 0
|
||||||
if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy, &rowv.epoch); err != nil {
|
if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy, &rowv.epoch); err != nil {
|
||||||
c.Logger.Error("Failed to train Model Could not read definition from db!Err:")
|
failed("Failed to train Model Could not read definition from db!")
|
||||||
c.Logger.Error(err)
|
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
definitions = append(definitions, rowv)
|
definitions = append(definitions, rowv)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(definitions) == 0 {
|
if len(definitions) == 0 {
|
||||||
c.Logger.Error("No Definitions defined!")
|
failed("No Definitions defined!")
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -652,15 +661,13 @@ func trainModelExp(c *Context, model *BaseModel) {
|
|||||||
c.Logger.Info("Found a definition that reaches target_accuracy!")
|
c.Logger.Info("Found a definition that reaches target_accuracy!")
|
||||||
_, err = c.Db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.epoch, def.id)
|
_, err = c.Db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.epoch, def.id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Logger.Error("Failed to train definition!Err:\n", "err", err)
|
failed("Failed to train definition!")
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = c.Db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", MODEL_DEFINITION_STATUS_CANCELD_TRAINING, def.id, model.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
|
_, err = c.Db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", MODEL_DEFINITION_STATUS_CANCELD_TRAINING, def.id, model.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Logger.Error("Failed to train definition!Err:\n", "err", err)
|
failed("Failed to train definition!")
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -677,8 +684,7 @@ func trainModelExp(c *Context, model *BaseModel) {
|
|||||||
|
|
||||||
_, err = c.Db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.epoch, MODEL_DEFINITION_STATUS_PAUSED_TRAINING, def.id)
|
_, err = c.Db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.epoch, MODEL_DEFINITION_STATUS_PAUSED_TRAINING, def.id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Logger.Error("Failed to train definition!Err:\n", "err", err)
|
failed("Failed to train definition!")
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -731,40 +737,30 @@ func trainModelExp(c *Context, model *BaseModel) {
|
|||||||
|
|
||||||
rows, err := c.Db.Query("select id from model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, MODEL_DEFINITION_STATUS_TRANIED)
|
rows, err := c.Db.Query("select id from model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, MODEL_DEFINITION_STATUS_TRANIED)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Logger.Error("DB: failed to read definition")
|
failed("DB: failed to read definition")
|
||||||
c.Logger.Error(err)
|
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
if !rows.Next() {
|
if !rows.Next() {
|
||||||
// TODO Make the Model status have a message
|
failed("All definitions failed to train!")
|
||||||
c.Logger.Error("All definitions failed to train!")
|
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var id string
|
var id string
|
||||||
if err = rows.Scan(&id); err != nil {
|
if err = rows.Scan(&id); err != nil {
|
||||||
c.Logger.Error("Failed to read id:")
|
failed("Failed to read id")
|
||||||
c.Logger.Error(err)
|
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err = c.Db.Exec("update model_definition set status=$1 where id=$2;", MODEL_DEFINITION_STATUS_READY, id); err != nil {
|
if _, err = c.Db.Exec("update model_definition set status=$1 where id=$2;", MODEL_DEFINITION_STATUS_READY, id); err != nil {
|
||||||
c.Logger.Error("Failed to update model definition")
|
failed("Failed to update model definition")
|
||||||
c.Logger.Error(err)
|
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
to_delete, err := c.Db.Query("select id from model_definition where status != $1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id)
|
to_delete, err := c.Db.Query("select id from model_definition where status != $1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Logger.Error("Failed to select model_definition to delete")
|
failed("Failed to select model_definition to delete")
|
||||||
c.Logger.Error(err)
|
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer to_delete.Close()
|
defer to_delete.Close()
|
||||||
@ -772,9 +768,7 @@ func trainModelExp(c *Context, model *BaseModel) {
|
|||||||
for to_delete.Next() {
|
for to_delete.Next() {
|
||||||
var id string
|
var id string
|
||||||
if to_delete.Scan(&id); err != nil {
|
if to_delete.Scan(&id); err != nil {
|
||||||
c.Logger.Error("Failed to scan the id of a model_definition to delete")
|
failed("Failed to scan the id of a model_definition to delete")
|
||||||
c.Logger.Error(err)
|
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
os.RemoveAll(path.Join("savedData", model.Id, "defs", id))
|
os.RemoveAll(path.Join("savedData", model.Id, "defs", id))
|
||||||
@ -782,9 +776,7 @@ func trainModelExp(c *Context, model *BaseModel) {
|
|||||||
|
|
||||||
// TODO Check if returning also works here
|
// TODO Check if returning also works here
|
||||||
if _, err = c.Db.Exec("delete from model_definition where status!=$1 and model_id=$2;", MODEL_DEFINITION_STATUS_READY, model.Id); err != nil {
|
if _, err = c.Db.Exec("delete from model_definition where status!=$1 and model_id=$2;", MODEL_DEFINITION_STATUS_READY, model.Id); err != nil {
|
||||||
c.Logger.Error("Failed to delete model_definition")
|
failed("Failed to delete model_definition")
|
||||||
c.Logger.Error(err)
|
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1255,4 +1247,63 @@ func handleTrain(handle *Handle) {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
handle.Get("/model/head/epoch/update", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||||||
|
// TODO check auth level
|
||||||
|
if c.Mode != NORMAL {
|
||||||
|
// This should only handle normal requests
|
||||||
|
c.Logger.Warn("This function only works with normal")
|
||||||
|
return c.UnsafeErrorCode(nil, 400, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
f := r.URL.Query()
|
||||||
|
|
||||||
|
accuracy := 0.0
|
||||||
|
|
||||||
|
if !CheckId(f, "head_id") || CheckEmpty(f, "epoch") || !CheckFloat64(f, "accuracy", &accuracy) {
|
||||||
|
c.Logger.Warn("Invalid: model_id or head_id or epoch or accuracy")
|
||||||
|
return c.UnsafeErrorCode(nil, 400, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
accuracy = accuracy * 100
|
||||||
|
|
||||||
|
head_id := f.Get("head_id")
|
||||||
|
epoch, err := strconv.Atoi(f.Get("epoch"))
|
||||||
|
if err != nil {
|
||||||
|
c.Logger.Warn("Epoch is not a number")
|
||||||
|
// No need to improve message because this function is only called internaly
|
||||||
|
return c.UnsafeErrorCode(nil, 400, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := c.Db.Query("select hd.status from exp_model_head as hd where hd.id=$1;", head_id)
|
||||||
|
if err != nil {
|
||||||
|
return c.Error500(err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
if !rows.Next() {
|
||||||
|
c.Logger.Error("Could not get status of model head")
|
||||||
|
return c.Error500(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
var status int
|
||||||
|
err = rows.Scan(&status)
|
||||||
|
if err != nil {
|
||||||
|
return c.Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status != 3 {
|
||||||
|
c.Logger.Warn("Head not on status 3(training)", "status", status)
|
||||||
|
// No need to improve message because this function is only called internaly
|
||||||
|
return c.UnsafeErrorCode(nil, 400, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Logger.Info("Updated model_head!", "head", head_id, "progress", epoch, "accuracy", accuracy)
|
||||||
|
|
||||||
|
_, err = c.Db.Exec("update exp_model_head set epoch_progress=$1, accuracy=$2 where id=$3", epoch, accuracy, head_id)
|
||||||
|
if err != nil {
|
||||||
|
return c.Error500(err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
@ -1,177 +0,0 @@
|
|||||||
import tensorflow as tf
|
|
||||||
import random
|
|
||||||
import pandas as pd
|
|
||||||
from tensorflow import keras
|
|
||||||
from tensorflow.data import AUTOTUNE
|
|
||||||
from keras import layers, losses, optimizers
|
|
||||||
import requests
|
|
||||||
|
|
||||||
class NotifyServerCallback(tf.keras.callbacks.Callback):
|
|
||||||
def on_epoch_end(self, epoch, log, *args, **kwargs):
|
|
||||||
requests.get(f'http://localhost:8000/model/epoch/update?model_id={{.Model.Id}}&epoch={epoch + 1}&accuracy={log["accuracy"]}&definition={{.DefId}}')
|
|
||||||
|
|
||||||
|
|
||||||
DATA_DIR = "{{ .DataDir }}"
|
|
||||||
image_size = ({{ .Size }})
|
|
||||||
|
|
||||||
df = pd.read_csv("{{ .RunPath }}/train.csv", dtype=str)
|
|
||||||
keys = tf.constant(df['Id'].dropna())
|
|
||||||
values = tf.constant(list(map(int, df['Index'].dropna())))
|
|
||||||
|
|
||||||
table = tf.lookup.StaticHashTable(
|
|
||||||
initializer=tf.lookup.KeyValueTensorInitializer(
|
|
||||||
keys=keys,
|
|
||||||
values=values,
|
|
||||||
),
|
|
||||||
default_value=tf.constant(-1),
|
|
||||||
name="Indexes"
|
|
||||||
)
|
|
||||||
|
|
||||||
DATA_DIR_PREPARE = DATA_DIR + "/"
|
|
||||||
|
|
||||||
#based on https://www.tensorflow.org/tutorials/load_data/images
|
|
||||||
def pathToLabel(path):
|
|
||||||
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
|
|
||||||
{{ if eq .Model.Format "png" }}
|
|
||||||
path = tf.strings.regex_replace(path, ".png", "")
|
|
||||||
{{ else if eq .Model.Format "jpeg" }}
|
|
||||||
path = tf.strings.regex_replace(path, ".jpeg", "")
|
|
||||||
{{ else }}
|
|
||||||
ERROR
|
|
||||||
{{ end }}
|
|
||||||
return table.lookup(tf.strings.as_string([path]))
|
|
||||||
|
|
||||||
def decode_image(img):
|
|
||||||
{{ if eq .Model.Format "png" }}
|
|
||||||
img = tf.io.decode_png(img, channels={{.ColorMode}})
|
|
||||||
{{ else if eq .Model.Format "jpeg" }}
|
|
||||||
img = tf.io.decode_jpeg(img, channels={{.ColorMode}})
|
|
||||||
{{ else }}
|
|
||||||
ERROR
|
|
||||||
{{ end }}
|
|
||||||
return tf.image.resize(img, image_size)
|
|
||||||
|
|
||||||
def process_path(path):
|
|
||||||
label = pathToLabel(path)
|
|
||||||
|
|
||||||
img = tf.io.read_file(path)
|
|
||||||
img = decode_image(img)
|
|
||||||
|
|
||||||
return img, label
|
|
||||||
|
|
||||||
def configure_for_performance(ds: tf.data.Dataset, size: int) -> tf.data.Dataset:
|
|
||||||
#ds = ds.cache()
|
|
||||||
ds = ds.shuffle(buffer_size=size)
|
|
||||||
ds = ds.batch(batch_size)
|
|
||||||
ds = ds.prefetch(AUTOTUNE)
|
|
||||||
return ds
|
|
||||||
|
|
||||||
def prepare_dataset(ds: tf.data.Dataset, size: int) -> tf.data.Dataset:
|
|
||||||
ds = ds.map(process_path, num_parallel_calls=AUTOTUNE)
|
|
||||||
ds = configure_for_performance(ds, size)
|
|
||||||
return ds
|
|
||||||
|
|
||||||
def filterDataset(path):
|
|
||||||
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
|
|
||||||
|
|
||||||
{{ if eq .Model.Format "png" }}
|
|
||||||
path = tf.strings.regex_replace(path, ".png", "")
|
|
||||||
{{ else if eq .Model.Format "jpeg" }}
|
|
||||||
path = tf.strings.regex_replace(path, ".jpeg", "")
|
|
||||||
{{ else }}
|
|
||||||
ERROR
|
|
||||||
{{ end }}
|
|
||||||
|
|
||||||
return tf.reshape(table.lookup(tf.strings.as_string([path])), []) != -1
|
|
||||||
|
|
||||||
seed = random.randint(0, 100000000)
|
|
||||||
|
|
||||||
batch_size = 64
|
|
||||||
|
|
||||||
# Read all the files from the direcotry
|
|
||||||
list_ds = tf.data.Dataset.list_files(str(f'{DATA_DIR}/*'), shuffle=False)
|
|
||||||
list_ds = list_ds.filter(filterDataset)
|
|
||||||
|
|
||||||
image_count = len(list(list_ds.as_numpy_iterator()))
|
|
||||||
|
|
||||||
list_ds = list_ds.shuffle(image_count, seed=seed)
|
|
||||||
|
|
||||||
val_size = int(image_count * 0.3)
|
|
||||||
|
|
||||||
train_ds = list_ds.skip(val_size)
|
|
||||||
val_ds = list_ds.take(val_size)
|
|
||||||
|
|
||||||
dataset = prepare_dataset(train_ds, image_count)
|
|
||||||
dataset_validation = prepare_dataset(val_ds, val_size)
|
|
||||||
|
|
||||||
track = 0
|
|
||||||
|
|
||||||
def addBlock(
|
|
||||||
b_size: int,
|
|
||||||
filter_size: int,
|
|
||||||
kernel_size: int = 3,
|
|
||||||
top: bool = True,
|
|
||||||
pooling_same: bool = False,
|
|
||||||
pool_func=layers.MaxPool2D,
|
|
||||||
layerNum = 0
|
|
||||||
):
|
|
||||||
global track
|
|
||||||
# model = keras.Sequential(name=f"{track}-{b_size}-{filter_size}-{kernel_size}")
|
|
||||||
model = keras.Sequential(name=f"layer{layerNum}")
|
|
||||||
track += 1
|
|
||||||
for _ in range(b_size):
|
|
||||||
model.add(layers.Conv2D(
|
|
||||||
filter_size,
|
|
||||||
kernel_size,
|
|
||||||
padding="same"
|
|
||||||
))
|
|
||||||
model.add(layers.ReLU())
|
|
||||||
if top:
|
|
||||||
if pooling_same:
|
|
||||||
model.add(pool_func(padding="same", strides=(1, 1)))
|
|
||||||
else:
|
|
||||||
model.add(pool_func())
|
|
||||||
model.add(layers.BatchNormalization())
|
|
||||||
model.add(layers.LeakyReLU())
|
|
||||||
model.add(layers.Dropout(0.4))
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
{{ if .LoadPrev }}
|
|
||||||
model = tf.keras.saving.load_model('{{.LastModelRunPath}}')
|
|
||||||
{{ else }}
|
|
||||||
model = keras.Sequential()
|
|
||||||
|
|
||||||
{{- range .Layers }}
|
|
||||||
{{- if eq .LayerType 1}}
|
|
||||||
model.add(layers.Rescaling(1./255, name="layer{{ .LayerNum }}"))
|
|
||||||
{{- else if eq .LayerType 2 }}
|
|
||||||
model.add(layers.Dense({{ .Shape }}, activation="sigmoid", name="layer{{ .LayerNum }}"))
|
|
||||||
{{- else if eq .LayerType 3}}
|
|
||||||
model.add(layers.Flatten(name="layer{{ .LayerNum }}"))
|
|
||||||
{{- else if eq .LayerType 4}}
|
|
||||||
model.add(addBlock(2, 128, 3, pool_func=layers.AveragePooling2D, layerNum={{.LayerNum}}))
|
|
||||||
{{- else }}
|
|
||||||
ERROR
|
|
||||||
{{- end }}
|
|
||||||
{{- end }}
|
|
||||||
{{ end }}
|
|
||||||
|
|
||||||
model.compile(
|
|
||||||
loss=losses.SparseCategoricalCrossentropy(),
|
|
||||||
optimizer=tf.keras.optimizers.Adam(),
|
|
||||||
metrics=['accuracy'])
|
|
||||||
|
|
||||||
his = model.fit(dataset, validation_data= dataset_validation, epochs={{.EPOCH_PER_RUN}}, callbacks=[
|
|
||||||
NotifyServerCallback(),
|
|
||||||
tf.keras.callbacks.EarlyStopping("loss", mode="min", patience=5)], use_multiprocessing = True)
|
|
||||||
|
|
||||||
acc = his.history["accuracy"]
|
|
||||||
|
|
||||||
f = open("accuracy.val", "w")
|
|
||||||
f.write(str(acc[-1]))
|
|
||||||
f.close()
|
|
||||||
|
|
||||||
|
|
||||||
tf.saved_model.save(model, "{{ .SaveModelPath }}/model")
|
|
||||||
model.save("{{ .SaveModelPath }}/model.keras")
|
|
@ -8,7 +8,12 @@ import requests
|
|||||||
|
|
||||||
class NotifyServerCallback(tf.keras.callbacks.Callback):
|
class NotifyServerCallback(tf.keras.callbacks.Callback):
|
||||||
def on_epoch_end(self, epoch, log, *args, **kwargs):
|
def on_epoch_end(self, epoch, log, *args, **kwargs):
|
||||||
|
{{ if .HeadId }}
|
||||||
|
requests.get(f'http://localhost:8000//model/head/epoch/update?epoch={epoch + 1}&accuracy={log["accuracy"]}&head_id={{.HeadId}}')
|
||||||
|
{{ else }}
|
||||||
requests.get(f'http://localhost:8000/model/epoch/update?model_id={{.Model.Id}}&epoch={epoch + 1}&accuracy={log["accuracy"]}&definition={{.DefId}}')
|
requests.get(f'http://localhost:8000/model/epoch/update?model_id={{.Model.Id}}&epoch={epoch + 1}&accuracy={log["accuracy"]}&definition={{.DefId}}')
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
DATA_DIR = "{{ .DataDir }}"
|
DATA_DIR = "{{ .DataDir }}"
|
||||||
|
Loading…
Reference in New Issue
Block a user