feat: started working on #23
This commit is contained in:
parent
bc948d4796
commit
a1d1a81ec5
@ -5,23 +5,19 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var FailedToGetIdAfterInsertError = errors.New("Failed to Get Id After Insert Error")
|
||||||
|
|
||||||
func AddDataPoint(db *sql.DB, class_id string, file_path string, mode DATA_POINT_MODE) (id string, err error) {
|
func AddDataPoint(db *sql.DB, class_id string, file_path string, mode DATA_POINT_MODE) (id string, err error) {
|
||||||
id = ""
|
id = ""
|
||||||
_, err = db.Exec("insert into model_data_point (class_id, file_path, model_mode) values ($1, $2, $3);", class_id, file_path, mode)
|
result, err := db.Query("insert into model_data_point (class_id, file_path, model_mode) values ($1, $2, $3) returning id;", class_id, file_path, mode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
defer result.Close()
|
||||||
rows, err := db.Query("select id from model_data_point where class_id=$1 and file_path=$2 and model_mode=$3", class_id, file_path, mode)
|
if !result.Next() {
|
||||||
if err != nil {
|
err = FailedToGetIdAfterInsertError
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
err = result.Scan(&id)
|
||||||
|
|
||||||
if !rows.Next() {
|
|
||||||
return id, errors.New("Something worng")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = rows.Scan(&id)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -46,7 +46,7 @@ func ModelHasDataPoints(db *sql.DB, model_id string) (result bool, err error) {
|
|||||||
|
|
||||||
var ClassAlreadyExists = errors.New("Class aready exists")
|
var ClassAlreadyExists = errors.New("Class aready exists")
|
||||||
|
|
||||||
func CreateClass(db *sql.DB, model_id string, name string) (id string, err error) {
|
func CreateClass(db *sql.DB, model_id string, order int, name string) (id string, err error) {
|
||||||
id = ""
|
id = ""
|
||||||
rows, err := db.Query("select id from model_classes where model_id=$1 and name=$2;", model_id, name)
|
rows, err := db.Query("select id from model_classes where model_id=$1 and name=$2;", model_id, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -58,25 +58,16 @@ func CreateClass(db *sql.DB, model_id string, name string) (id string, err error
|
|||||||
return id, ClassAlreadyExists
|
return id, ClassAlreadyExists
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = db.Exec("insert into model_classes (model_id, name) values ($1, $2)", model_id, name)
|
rows, err = db.Query("insert into model_classes (model_id, name, class_order) values ($1, $2, $3) returning id;", model_id, name, order)
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err = db.Query("select id from model_classes where model_id=$1 and name=$2;", model_id, name)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
if !rows.Next() {
|
if !rows.Next() {
|
||||||
return id, errors.New("Something wrong")
|
return id, errors.New("Insert did not return anything")
|
||||||
}
|
|
||||||
|
|
||||||
if err = rows.Scan(&id); err != nil {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = rows.Scan(&id)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -28,11 +28,11 @@ func InsertIfNotPresent(ss []string, s string) []string {
|
|||||||
return ss
|
return ss
|
||||||
}
|
}
|
||||||
|
|
||||||
func processZipFile(handle *Handle, id string) {
|
func processZipFile(handle *Handle, model_id string) {
|
||||||
reader, err := zip.OpenReader(path.Join("savedData", id, "base_data.zip"))
|
reader, err := zip.OpenReader(path.Join("savedData", model_id, "base_data.zip"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO add msg to error
|
// TODO add msg to error
|
||||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE)
|
||||||
fmt.Printf("Faield to proccess zip file failed to open reader\n")
|
fmt.Printf("Faield to proccess zip file failed to open reader\n")
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
return
|
return
|
||||||
@ -52,7 +52,7 @@ func processZipFile(handle *Handle, id string) {
|
|||||||
|
|
||||||
if paths[0] != "training" && paths[0] != "testing" {
|
if paths[0] != "training" && paths[0] != "testing" {
|
||||||
fmt.Printf("Invalid file '%s' TODO add msg to response!!!\n", file.Name)
|
fmt.Printf("Invalid file '%s' TODO add msg to response!!!\n", file.Name)
|
||||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -67,31 +67,21 @@ func processZipFile(handle *Handle, id string) {
|
|||||||
fmt.Printf("testing and training are diferent\n")
|
fmt.Printf("testing and training are diferent\n")
|
||||||
fmt.Println(testing)
|
fmt.Println(testing)
|
||||||
fmt.Println(training)
|
fmt.Println(training)
|
||||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
base_path := path.Join("savedData", id, "data")
|
base_path := path.Join("savedData", model_id, "data")
|
||||||
|
if err = os.MkdirAll(base_path, os.ModePerm); err != nil {
|
||||||
|
fmt.Printf("Failed to create base_path dir\n")
|
||||||
|
ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
ids := map[string]string{}
|
ids := map[string]string{}
|
||||||
|
|
||||||
for _, name := range training {
|
for i, name := range training {
|
||||||
dir_path := path.Join(base_path, "training", name)
|
id, err := model_classes.CreateClass(handle.Db, model_id, i, name)
|
||||||
err = os.MkdirAll(dir_path, os.ModePerm)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Failed to create dir %s\n", dir_path)
|
|
||||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
dir_path = path.Join(base_path, "testing", name)
|
|
||||||
err = os.MkdirAll(dir_path, os.ModePerm)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Failed to create dir %s\n", dir_path)
|
|
||||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
id, err := model_classes.CreateClass(handle.Db, id, name)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Failed to create class '%s' on db\n", name)
|
fmt.Printf("Failed to create class '%s' on db\n", name)
|
||||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||||
@ -105,23 +95,19 @@ func processZipFile(handle *Handle, id string) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
file_path := path.Join(base_path, file.Name)
|
|
||||||
f, err := os.Create(file_path)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Could not create file %s\n", file_path)
|
|
||||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
data, err := reader.Open(file.Name)
|
data, err := reader.Open(file.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Could not create file %s\n", file_path)
|
fmt.Printf("Could not open file in zip %s\n", file.Name)
|
||||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer data.Close()
|
defer data.Close()
|
||||||
file_data, err := io.ReadAll(data)
|
file_data, err := io.ReadAll(data)
|
||||||
f.Write(file_data)
|
if err != nil {
|
||||||
|
fmt.Printf("Could not read file file in zip %s\n", file.Name)
|
||||||
|
ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// TODO check if the file is a valid photo that matched the defined photo on the database
|
// TODO check if the file is a valid photo that matched the defined photo on the database
|
||||||
|
|
||||||
@ -132,17 +118,27 @@ func processZipFile(handle *Handle, id string) {
|
|||||||
mode = model_classes.DATA_POINT_MODE_TESTING
|
mode = model_classes.DATA_POINT_MODE_TESTING
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = model_classes.AddDataPoint(handle.Db, ids[parts[1]], "file://" + parts[2], mode)
|
data_point_id, err := model_classes.AddDataPoint(handle.Db, ids[parts[1]], "id://", mode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Failed to add data point for %s\n", id)
|
fmt.Printf("Failed to add data point for %s\n", model_id)
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
file_path := path.Join(base_path, data_point_id + ".png")
|
||||||
|
f, err := os.Create(file_path)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Could not create file %s\n", file_path)
|
||||||
|
ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
f.Write(file_data)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Added data to model '%s'!\n", id)
|
fmt.Printf("Added data to model '%s'!\n", model_id)
|
||||||
ModelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING)
|
ModelUpdateStatus(handle, model_id, CONFIRM_PRE_TRAINING)
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleDataUpload(handle *Handle) {
|
func handleDataUpload(handle *Handle) {
|
||||||
|
@ -81,6 +81,8 @@ func handleDelete(handle *Handle) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch model.Status {
|
switch model.Status {
|
||||||
|
case FAILED_TRAINING:
|
||||||
|
fallthrough
|
||||||
case FAILED_PREPARING_TRAINING:
|
case FAILED_PREPARING_TRAINING:
|
||||||
fallthrough
|
fallthrough
|
||||||
case FAILED_PREPARING:
|
case FAILED_PREPARING:
|
||||||
|
@ -71,6 +71,40 @@ func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type LayerType,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateCvs(handle *Handle, run_path string, model_id string) (count int, err error) {
|
||||||
|
|
||||||
|
classes, err := handle.Db.Query("select count(*) from model_classes where model_id=$1;", model_id)
|
||||||
|
if err != nil { return }
|
||||||
|
defer classes.Close()
|
||||||
|
if !classes.Next() { return }
|
||||||
|
if err = classes.Scan(&count); err != nil { return }
|
||||||
|
|
||||||
|
data, err := handle.Db.Query("select mpd.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id)
|
||||||
|
if err != nil { return }
|
||||||
|
defer data.Close()
|
||||||
|
|
||||||
|
type row struct {
|
||||||
|
path string
|
||||||
|
class_order int
|
||||||
|
}
|
||||||
|
|
||||||
|
got := []row{}
|
||||||
|
|
||||||
|
for data.Next() {
|
||||||
|
var id string
|
||||||
|
var class_order int
|
||||||
|
var file_path string
|
||||||
|
if err = data.Scan(&id, &class_order, &file_path); err != nil { return }
|
||||||
|
if file_path == "id://" {
|
||||||
|
got = append(got, row{id, class_order})
|
||||||
|
} else {
|
||||||
|
return count, errors.New("TODO generateCvs to file_path " + file_path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func trainDefinition(handle *Handle, model_id string, definition_id string) (accuracy float64, err error) {
|
func trainDefinition(handle *Handle, model_id string, definition_id string) (accuracy float64, err error) {
|
||||||
accuracy = 0
|
accuracy = 0
|
||||||
layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
||||||
@ -103,6 +137,10 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err = generateCvs(handle, run_path); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Create python script
|
// Create python script
|
||||||
f, err := os.Create(path.Join(run_path, "run.py"))
|
f, err := os.Create(path.Join(run_path, "run.py"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -110,6 +148,7 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
|
|||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
|
|
||||||
tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py")
|
tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@ -118,13 +157,15 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
|
|||||||
if err = tmpl.Execute(f, AnyMap{
|
if err = tmpl.Execute(f, AnyMap{
|
||||||
"Layers": got,
|
"Layers": got,
|
||||||
"Size": got[0].Shape,
|
"Size": got[0].Shape,
|
||||||
"DataDir": path.Join(getDir(), "savedData", model_id, "data", "training"),
|
"DataDir": path.Join(getDir(), "savedData", model_id, "data"),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run the command
|
// Run the command
|
||||||
if err = exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).Run(); err != nil {
|
out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).Output()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(string(out))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -384,7 +425,8 @@ func handleTrain(handle *Handle) {
|
|||||||
// TODO improve this response
|
// TODO improve this response
|
||||||
return Error500(err)
|
return Error500(err)
|
||||||
}
|
}
|
||||||
err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d,1", len(cls)))
|
// Using sparce
|
||||||
|
err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("1,1", len(cls)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||||
// TODO improve this response
|
// TODO improve this response
|
||||||
@ -400,7 +442,6 @@ func handleTrain(handle *Handle) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO start training with id fid
|
// TODO start training with id fid
|
||||||
|
|
||||||
go trainModel(handle, model)
|
go trainModel(handle, model)
|
||||||
|
|
||||||
ModelUpdateStatus(handle, model.Id, TRAINING)
|
ModelUpdateStatus(handle, model.Id, TRAINING)
|
||||||
|
@ -21,7 +21,8 @@ create table if not exists models (
|
|||||||
create table if not exists model_classes (
|
create table if not exists model_classes (
|
||||||
id uuid primary key default gen_random_uuid(),
|
id uuid primary key default gen_random_uuid(),
|
||||||
model_id uuid references models (id) on delete cascade,
|
model_id uuid references models (id) on delete cascade,
|
||||||
name varchar (70) not null
|
name varchar (70) not null,
|
||||||
|
class_order integer
|
||||||
);
|
);
|
||||||
|
|
||||||
-- drop table if exists model_data_point;
|
-- drop table if exists model_data_point;
|
||||||
|
@ -63,7 +63,11 @@
|
|||||||
{{range .List}}
|
{{range .List}}
|
||||||
<tr>
|
<tr>
|
||||||
<td>
|
<td>
|
||||||
|
{{ if eq .FilePath "id://" }}
|
||||||
|
Managed
|
||||||
|
{{ else }}
|
||||||
{{.FilePath}}
|
{{.FilePath}}
|
||||||
|
{{ end }}
|
||||||
</td>
|
</td>
|
||||||
<td>
|
<td>
|
||||||
{{ if (eq .Mode 2) }}
|
{{ if (eq .Mode 2) }}
|
||||||
@ -73,8 +77,8 @@
|
|||||||
{{ end }}
|
{{ end }}
|
||||||
</td>
|
</td>
|
||||||
<td class="text-center">
|
<td class="text-center">
|
||||||
{{ if startsWith .FilePath "file://" }}
|
{{ if startsWith .FilePath "id://" }}
|
||||||
<img src="/savedData/{{ $.ModelId }}/data/{{ if (eq .Mode 2) }}testing{{ else }}training{{ end }}/{{ $.Name }}/{{ replace .FilePath "file://" "" 1 }}" height="30px" width="30px" style="object-fit: contain;" />
|
<img src="/savedData/{{ $.ModelId }}/data/{{ .Id }}.png" height="30px" width="30px" style="object-fit: contain;" />
|
||||||
{{ else }}
|
{{ else }}
|
||||||
TODO
|
TODO
|
||||||
img {{ .FilePath }}
|
img {{ .FilePath }}
|
||||||
|
@ -1,33 +1,64 @@
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import random
|
import random
|
||||||
from tensorflow import keras
|
from tensorflow import keras
|
||||||
|
from tensorflow.data import AUTOTUNE
|
||||||
from keras import layers, losses, optimizers
|
from keras import layers, losses, optimizers
|
||||||
|
|
||||||
|
DATA_DIR = "{{ .DataDir }}"
|
||||||
|
image_size = ({{ .Size }})
|
||||||
|
|
||||||
|
#based on https://www.tensorflow.org/tutorials/load_data/images
|
||||||
|
|
||||||
|
def pathToLabel(path):
|
||||||
|
path = tf.strings.regex_replace(path, DATA_DIR, "")
|
||||||
|
path = tf.strings.regex_replace(path, ".jpg", "")
|
||||||
|
return train_labels[tf.strings.to_number(path, out_type=tf.int32)]
|
||||||
|
|
||||||
|
def decode_image(img):
|
||||||
|
# channels were reduced to 1 since image is grayscale
|
||||||
|
# TODO chnage channel number based if grayscale
|
||||||
|
img = tf.io.decode_png(img, channels=1)
|
||||||
|
|
||||||
|
return tf.image.resize(img, image_size)
|
||||||
|
|
||||||
|
def process_path(path):
|
||||||
|
label = pathToLabel(path)
|
||||||
|
|
||||||
|
img = tf.io.read_file(path)
|
||||||
|
img = decode_image(img)
|
||||||
|
|
||||||
|
return img, label
|
||||||
|
|
||||||
|
def configure_for_performance(ds: tf.data.Dataset) -> tf.data.Dataset:
|
||||||
|
#ds = ds.cache()
|
||||||
|
ds = ds.shuffle(buffer_size= 1000)
|
||||||
|
ds = ds.batch(batch_size)
|
||||||
|
ds = ds.prefetch(AUTOTUNE)
|
||||||
|
return ds
|
||||||
|
|
||||||
|
def prepare_dataset(ds: tf.data.Dataset) -> tf.data.Dataset:
|
||||||
|
ds = ds.map(process_path, num_parallel_calls=AUTOTUNE)
|
||||||
|
ds = configure_for_performance(ds)
|
||||||
|
return ds
|
||||||
|
|
||||||
seed = random.randint(0, 100000000)
|
seed = random.randint(0, 100000000)
|
||||||
|
|
||||||
batch_size = 100
|
batch_size = 100
|
||||||
|
|
||||||
dataset = keras.utils.image_dataset_from_directory(
|
# Read all the files from the direcotry
|
||||||
"{{ .DataDir }}",
|
list_ds = tf.data.Dataset.list_files(str(f'{DATA_DIR}/*'), shuffle=False)
|
||||||
color_mode="rgb",
|
|
||||||
validation_split=0.2,
|
|
||||||
label_mode='categorical',
|
|
||||||
seed=seed,
|
|
||||||
shuffle=True,
|
|
||||||
subset="training",
|
|
||||||
image_size=({{ .Size }}),
|
|
||||||
batch_size=batch_size)
|
|
||||||
|
|
||||||
dataset_validation = keras.utils.image_dataset_from_directory(
|
image_count = len(list_ds)
|
||||||
"{{ .DataDir }}",
|
|
||||||
color_mode="rgb",
|
list_ds = list_ds.shuffle(image_count, seed=seed)
|
||||||
validation_split=0.2,
|
|
||||||
label_mode='categorical',
|
val_size = int(image_count * 0.3)
|
||||||
seed=seed,
|
|
||||||
shuffle=True,
|
train_ds = list_ds.skip(val_size)
|
||||||
subset="validation",
|
val_ds = list_ds.take(val_size)
|
||||||
image_size=({{ .Size }}),
|
|
||||||
batch_size=batch_size)
|
dataset = prepare_dataset(train_ds)
|
||||||
|
dataset_validation = prepare_dataset(val_ds)
|
||||||
|
|
||||||
model = keras.Sequential([
|
model = keras.Sequential([
|
||||||
{{- range .Layers }}
|
{{- range .Layers }}
|
||||||
@ -44,7 +75,7 @@ model = keras.Sequential([
|
|||||||
])
|
])
|
||||||
|
|
||||||
model.compile(
|
model.compile(
|
||||||
loss=losses.CategoricalCrossentropy(),
|
loss=losses.SparceCategoricalCrossentropy(),
|
||||||
optimizer=tf.keras.optimizers.Adam(),
|
optimizer=tf.keras.optimizers.Adam(),
|
||||||
metrics=['accuracy'])
|
metrics=['accuracy'])
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user