feat: closes #40
This commit is contained in:
parent
f163e25fba
commit
2c3539b81a
@ -17,6 +17,9 @@ import (
|
|||||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const EPOCH_PER_RUN = 20;
|
||||||
|
const MAX_EPOCH = 100
|
||||||
|
|
||||||
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
|
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
|
||||||
id = ""
|
id = ""
|
||||||
rows, err := db.Query("insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;", model_id, target_accuracy)
|
rows, err := db.Query("insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;", model_id, target_accuracy)
|
||||||
@ -34,12 +37,13 @@ func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string
|
|||||||
type ModelDefinitionStatus int
|
type ModelDefinitionStatus int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
MODEL_DEFINITION_STATUS_FAILED_TRAINING = -3
|
MODEL_DEFINITION_STATUS_CANCELD_TRAINING = -4
|
||||||
MODEL_DEFINITION_STATUS_PRE_INIT ModelDefinitionStatus = 1
|
MODEL_DEFINITION_STATUS_FAILED_TRAINING = -3
|
||||||
MODEL_DEFINITION_STATUS_INIT = 2
|
MODEL_DEFINITION_STATUS_PRE_INIT ModelDefinitionStatus = 1
|
||||||
MODEL_DEFINITION_STATUS_TRAINING = 3
|
MODEL_DEFINITION_STATUS_INIT = 2
|
||||||
MODEL_DEFINITION_STATUS_TRANIED = 4
|
MODEL_DEFINITION_STATUS_TRAINING = 3
|
||||||
MODEL_DEFINITION_STATUS_READY = 5
|
MODEL_DEFINITION_STATUS_TRANIED = 4
|
||||||
|
MODEL_DEFINITION_STATUS_READY = 5
|
||||||
)
|
)
|
||||||
|
|
||||||
type LayerType int
|
type LayerType int
|
||||||
@ -104,7 +108,8 @@ func generateCvs(c *Context, run_path string, model_id string) (count int, err e
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func trainDefinition(c *Context, model *BaseModel, definition_id string) (accuracy float64, err error) {
|
func trainDefinition(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
|
||||||
|
c.Logger.Warn("About to start training definition")
|
||||||
accuracy = 0
|
accuracy = 0
|
||||||
layers, err := c.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
layers, err := c.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -153,14 +158,20 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Copy result around
|
||||||
|
result_path := path.Join("savedData", model.Id, "defs", definition_id)
|
||||||
|
|
||||||
if err = tmpl.Execute(f, AnyMap{
|
if err = tmpl.Execute(f, AnyMap{
|
||||||
"Layers": got,
|
"Layers": got,
|
||||||
"Size": got[0].Shape,
|
"Size": got[0].Shape,
|
||||||
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
|
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
|
||||||
"RunPath": run_path,
|
"RunPath": run_path,
|
||||||
"ColorMode": model.ImageMode,
|
"ColorMode": model.ImageMode,
|
||||||
"Model": model,
|
"Model": model,
|
||||||
"DefId": definition_id,
|
"EPOCH_PER_RUN": EPOCH_PER_RUN,
|
||||||
|
"DefId": definition_id,
|
||||||
|
"LoadPrev": load_prev,
|
||||||
|
"LastModelRunPath": path.Join(getDir(), result_path, "model.keras"),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -168,13 +179,10 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
|
|||||||
// Run the command
|
// Run the command
|
||||||
out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).Output()
|
out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).Output()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Logger.Debug(string(out))
|
c.Logger.Debug(string(out))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy result around
|
|
||||||
result_path := path.Join("savedData", model.Id, "defs", definition_id)
|
|
||||||
|
|
||||||
if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
|
if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -183,6 +191,10 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err = exec.Command("cp", "-r", path.Join(run_path, "model.keras"), path.Join(result_path, "model.keras")).Run(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
accuracy_file, err := os.Open(path.Join(run_path, "accuracy.val"))
|
accuracy_file, err := os.Open(path.Join(run_path, "accuracy.val"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@ -194,7 +206,7 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println(string(accuracy_file_bytes))
|
c.Logger.Info("Model finished training!", "accuracy", accuracy)
|
||||||
|
|
||||||
accuracy, err = strconv.ParseFloat(string(accuracy_file_bytes), 64)
|
accuracy, err = strconv.ParseFloat(string(accuracy_file_bytes), 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -205,8 +217,25 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func remove[T interface{}](lst []T, i int) []T {
|
||||||
|
lng := len(lst)
|
||||||
|
if i >= lng {
|
||||||
|
return []T{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if i+1 >= lng {
|
||||||
|
return lst[:lng-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if i == 0 {
|
||||||
|
return lst[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return append(lst[:i], lst[i+1:]...)
|
||||||
|
}
|
||||||
|
|
||||||
func trainModel(c *Context, model *BaseModel) {
|
func trainModel(c *Context, model *BaseModel) {
|
||||||
definitionsRows, err := c.Db.Query("select id, target_accuracy from model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id)
|
definitionsRows, err := c.Db.Query("select id, target_accuracy, epoch from model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Logger.Error("Failed to trainModel!Err:")
|
c.Logger.Error("Failed to trainModel!Err:")
|
||||||
c.Logger.Error(err)
|
c.Logger.Error(err)
|
||||||
@ -218,13 +247,14 @@ func trainModel(c *Context, model *BaseModel) {
|
|||||||
type row struct {
|
type row struct {
|
||||||
id string
|
id string
|
||||||
target_accuracy int
|
target_accuracy int
|
||||||
|
epoch int
|
||||||
}
|
}
|
||||||
|
|
||||||
definitions := []row{}
|
definitions := []row{}
|
||||||
|
|
||||||
for definitionsRows.Next() {
|
for definitionsRows.Next() {
|
||||||
var rowv row
|
var rowv row
|
||||||
if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy); err != nil {
|
if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy, &rowv.epoch); err != nil {
|
||||||
c.Logger.Error("Failed to train Model Could not read definition from db!Err:")
|
c.Logger.Error("Failed to train Model Could not read definition from db!Err:")
|
||||||
c.Logger.Error(err)
|
c.Logger.Error(err)
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
||||||
@ -239,30 +269,58 @@ func trainModel(c *Context, model *BaseModel) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, def := range definitions {
|
toTrain := len(definitions)
|
||||||
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING)
|
firstRound := true
|
||||||
accuracy, err := trainDefinition(c, model, def.id)
|
var newDefinitions = []row{}
|
||||||
if err != nil {
|
copy(newDefinitions, definitions)
|
||||||
c.Logger.Error("Failed to train definition!Err:")
|
for {
|
||||||
c.Logger.Error(err)
|
for i, def := range definitions {
|
||||||
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
|
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING)
|
||||||
continue
|
accuracy, err := trainDefinition(c, model, def.id, !firstRound)
|
||||||
|
if err != nil {
|
||||||
|
c.Logger.Error("Failed to train definition!Err:", "err", err)
|
||||||
|
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
|
||||||
|
toTrain = toTrain - 1
|
||||||
|
newDefinitions = remove(newDefinitions, i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
def.epoch += EPOCH_PER_RUN
|
||||||
|
|
||||||
|
int_accuracy := int(accuracy * 100)
|
||||||
|
|
||||||
|
if int_accuracy >= def.target_accuracy {
|
||||||
|
c.Logger.Info("Found a definition that reaches target_accuracy!")
|
||||||
|
_, err = c.Db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", int_accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.epoch, def.id)
|
||||||
|
if err != nil {
|
||||||
|
c.Logger.Error("Failed to train definition!Err:\n", "err", err)
|
||||||
|
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = c.Db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", MODEL_DEFINITION_STATUS_CANCELD_TRAINING, def.id, model.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
|
||||||
|
if err != nil {
|
||||||
|
c.Logger.Error("Failed to train definition!Err:\n", "err", err)
|
||||||
|
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
toTrain = 0
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if def.epoch > MAX_EPOCH {
|
||||||
|
fmt.Printf("Failed to train definition! Accuracy less %d < %d\n", int_accuracy, def.target_accuracy)
|
||||||
|
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
|
||||||
|
toTrain = toTrain - 1
|
||||||
|
newDefinitions = remove(newDefinitions, i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
copy(definitions, newDefinitions)
|
||||||
int_accuracy := int(accuracy * 100)
|
firstRound = false
|
||||||
|
if toTrain == 0 {
|
||||||
if int_accuracy < def.target_accuracy {
|
break
|
||||||
fmt.Printf("Failed to train definition! Accuracy less %d < %d\n", int_accuracy, def.target_accuracy)
|
|
||||||
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = c.Db.Exec("update model_definition set accuracy=$1, status=$2 where id=$3", int_accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.id)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Failed to train definition!Err:\n")
|
|
||||||
fmt.Println(err)
|
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -335,7 +393,7 @@ func removeFailedDataPoints(db *sql.DB, model *BaseModel) (err error) {
|
|||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
base_path := path.Join("savedData", model.Id, "data")
|
base_path := path.Join("savedData", model.Id, "data")
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var dataPointId string
|
var dataPointId string
|
||||||
@ -343,13 +401,13 @@ func removeFailedDataPoints(db *sql.DB, model *BaseModel) (err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = os.RemoveAll(path.Join(base_path, dataPointId + model.Format))
|
err = os.RemoveAll(path.Join(base_path, dataPointId+model.Format))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = db.Exec("delete from model_data_point as mdp using model_classes as mc where mdp.class_id = mc.id and mc.model_id=$1 and mdp.status=-1;", model.Id)
|
_, err = db.Exec("delete from model_data_point as mdp using model_classes as mc where mdp.class_id = mc.id and mc.model_id=$1 and mdp.status=-1;", model.Id)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -484,56 +542,56 @@ func handleTrain(handle *Handle) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
handle.Get("/model/epoch/update", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
handle.Get("/model/epoch/update", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||||||
// TODO check auth level
|
// TODO check auth level
|
||||||
if c.Mode != NORMAL {
|
if c.Mode != NORMAL {
|
||||||
// This should only handle normal requests
|
// This should only handle normal requests
|
||||||
c.Logger.Warn("This function only works with normal")
|
c.Logger.Warn("This function only works with normal")
|
||||||
return c.UnsafeErrorCode(nil, 400, nil)
|
return c.UnsafeErrorCode(nil, 400, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
f := r.URL.Query()
|
f := r.URL.Query()
|
||||||
|
|
||||||
if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") {
|
if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") {
|
||||||
c.Logger.Warn("Invalid: model_id or definition or epoch")
|
c.Logger.Warn("Invalid: model_id or definition or epoch")
|
||||||
return c.UnsafeErrorCode(nil, 400, nil)
|
return c.UnsafeErrorCode(nil, 400, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
model_id := f.Get("model_id")
|
model_id := f.Get("model_id")
|
||||||
def_id := f.Get("definition")
|
def_id := f.Get("definition")
|
||||||
epoch, err := strconv.Atoi(f.Get("epoch"))
|
epoch, err := strconv.Atoi(f.Get("epoch"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Logger.Warn("Epoch is not a number")
|
c.Logger.Warn("Epoch is not a number")
|
||||||
// No need to improve message because this function is only called internaly
|
// No need to improve message because this function is only called internaly
|
||||||
return c.UnsafeErrorCode(nil, 400, nil)
|
return c.UnsafeErrorCode(nil, 400, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := c.Db.Query("select md.status from model_definition as md where md.model_id=$1 and md.id=$2", model_id, def_id)
|
rows, err := c.Db.Query("select md.status from model_definition as md where md.model_id=$1 and md.id=$2", model_id, def_id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.Error500(err)
|
return c.Error500(err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
if !rows.Next() {
|
if !rows.Next() {
|
||||||
c.Logger.Error("Could not get status of model definition")
|
c.Logger.Error("Could not get status of model definition")
|
||||||
return c.Error500(nil)
|
return c.Error500(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
var status int
|
var status int
|
||||||
err = rows.Scan(&status)
|
err = rows.Scan(&status)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.Error500(err)
|
return c.Error500(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if status != 3 {
|
if status != 3 {
|
||||||
c.Logger.Warn("Definition not on status 3(training)", "status", status)
|
c.Logger.Warn("Definition not on status 3(training)", "status", status)
|
||||||
// No need to improve message because this function is only called internaly
|
// No need to improve message because this function is only called internaly
|
||||||
return c.UnsafeErrorCode(nil, 400, nil)
|
return c.UnsafeErrorCode(nil, 400, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = c.Db.Exec("update model_definition set epoch_progress=$1 where id=$2", epoch, def_id)
|
_, err = c.Db.Exec("update model_definition set epoch_progress=$1 where id=$2", epoch, def_id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.Error500(err)
|
return c.Error500(err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -40,7 +40,6 @@ create table if not exists model_data_point (
|
|||||||
status_message text
|
status_message text
|
||||||
);
|
);
|
||||||
|
|
||||||
-- drop table if exists model_definition;
|
|
||||||
-- drop table if exists model_definition;
|
-- drop table if exists model_definition;
|
||||||
create table if not exists model_definition (
|
create table if not exists model_definition (
|
||||||
id uuid primary key default gen_random_uuid(),
|
id uuid primary key default gen_random_uuid(),
|
||||||
|
@ -434,19 +434,36 @@
|
|||||||
{{/* TODO improve this */}}
|
{{/* TODO improve this */}}
|
||||||
Training the model...<br/>
|
Training the model...<br/>
|
||||||
{{/* TODO Add progress status on definitions */}}
|
{{/* TODO Add progress status on definitions */}}
|
||||||
{{ range .Defs}}
|
<table>
|
||||||
<div>
|
<thead>
|
||||||
<div>
|
<tr>
|
||||||
{{.Status}}
|
<th>
|
||||||
</div>
|
Status
|
||||||
<div>
|
</th>
|
||||||
{{.EpochProgress}}
|
<th>
|
||||||
</div>
|
EpochProgress
|
||||||
<div>
|
</th>
|
||||||
{{.Accuracy}}
|
<th>
|
||||||
</div>
|
Accuracy
|
||||||
</div>
|
</th>
|
||||||
{{ end }}
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{{ range .Defs}}
|
||||||
|
<tr>
|
||||||
|
<td>
|
||||||
|
{{.Status}}
|
||||||
|
</td>
|
||||||
|
<td>
|
||||||
|
{{.EpochProgress}}
|
||||||
|
</td>
|
||||||
|
<td>
|
||||||
|
{{.Accuracy}}
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
{{ end }}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
{{/* TODO Add ability to stop training */}}
|
{{/* TODO Add ability to stop training */}}
|
||||||
</div>
|
</div>
|
||||||
{{/* Model Ready */}}
|
{{/* Model Ready */}}
|
||||||
|
@ -93,6 +93,10 @@ val_ds = list_ds.take(val_size)
|
|||||||
dataset = prepare_dataset(train_ds)
|
dataset = prepare_dataset(train_ds)
|
||||||
dataset_validation = prepare_dataset(val_ds)
|
dataset_validation = prepare_dataset(val_ds)
|
||||||
|
|
||||||
|
|
||||||
|
{{ if .LoadPrev }}
|
||||||
|
model = tf.keras.saving.load_model('{{.LastModelRunPath}}')
|
||||||
|
{{ else }}
|
||||||
model = keras.Sequential([
|
model = keras.Sequential([
|
||||||
{{- range .Layers }}
|
{{- range .Layers }}
|
||||||
{{- if eq .LayerType 1}}
|
{{- if eq .LayerType 1}}
|
||||||
@ -106,13 +110,14 @@ model = keras.Sequential([
|
|||||||
{{- end }}
|
{{- end }}
|
||||||
{{- end }}
|
{{- end }}
|
||||||
])
|
])
|
||||||
|
{{ end }}
|
||||||
|
|
||||||
model.compile(
|
model.compile(
|
||||||
loss=losses.SparseCategoricalCrossentropy(),
|
loss=losses.SparseCategoricalCrossentropy(),
|
||||||
optimizer=tf.keras.optimizers.Adam(),
|
optimizer=tf.keras.optimizers.Adam(),
|
||||||
metrics=['accuracy'])
|
metrics=['accuracy'])
|
||||||
|
|
||||||
his = model.fit(dataset, validation_data= dataset_validation, epochs=50, callbacks=[NotifyServerCallback()])
|
his = model.fit(dataset, validation_data= dataset_validation, epochs={{.EPOCH_PER_RUN}}, callbacks=[NotifyServerCallback()])
|
||||||
|
|
||||||
acc = his.history["accuracy"]
|
acc = his.history["accuracy"]
|
||||||
|
|
||||||
@ -120,6 +125,6 @@ f = open("accuracy.val", "w")
|
|||||||
f.write(str(acc[-1]))
|
f.write(str(acc[-1]))
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
tf.saved_model.save(model, "model")
|
|
||||||
|
|
||||||
# model.save("model.keras", save_format="tf")
|
tf.saved_model.save(model, "model")
|
||||||
|
model.save("model.keras")
|
||||||
|
Loading…
Reference in New Issue
Block a user