From 90bc3f6acfec8905c6c9d456ab4e35804b99a2b9 Mon Sep 17 00:00:00 2001 From: Andre Henriques Date: Sat, 21 Oct 2023 12:01:10 +0100 Subject: [PATCH] worked on #32 --- logic/models/train/train.go | 2 +- views/py/python_model_template.py | 20 ++++++++++++-------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/logic/models/train/train.go b/logic/models/train/train.go index 47a4256..dd363fc 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -454,7 +454,7 @@ func generateDefinition(c *Context, model *BaseModel, target_accuracy int, numbe } order++; - loop := int(math.Log2(float64(number_of_classes))/2) + loop := int(math.Log2(float64(number_of_classes))) for i := 0; i < loop; i++ { err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i))) order++; diff --git a/views/py/python_model_template.py b/views/py/python_model_template.py index cbd2186..7fc2ec5 100644 --- a/views/py/python_model_template.py +++ b/views/py/python_model_template.py @@ -35,7 +35,6 @@ def pathToLabel(path): {{ if eq .Model.Format "png" }} path = tf.strings.regex_replace(path, ".png", "") {{ else if eq .Model.Format "jpeg" }} - path = tf.strings.regex_replace(path, ".jpg", "") path = tf.strings.regex_replace(path, ".jpeg", "") {{ else }} ERROR @@ -60,23 +59,28 @@ def process_path(path): return img, label -def configure_for_performance(ds: tf.data.Dataset) -> tf.data.Dataset: +def configure_for_performance(ds: tf.data.Dataset, size: int) -> tf.data.Dataset: #ds = ds.cache() - ds = ds.shuffle(buffer_size= 1000) + ds = ds.shuffle(buffer_size=size) ds = ds.batch(batch_size) ds = ds.prefetch(AUTOTUNE) return ds -def prepare_dataset(ds: tf.data.Dataset) -> tf.data.Dataset: +def prepare_dataset(ds: tf.data.Dataset, size: int) -> tf.data.Dataset: ds = ds.map(process_path, num_parallel_calls=AUTOTUNE) - ds = configure_for_performance(ds) + ds = configure_for_performance(ds, size) return ds def filterDataset(path): path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "") - path = tf.strings.regex_replace(path, ".jpg", "") + {{ if eq .Model.Format "png" }} + path = tf.strings.regex_replace(path, ".png", "") + {{ else if eq .Model.Format "jpeg" }} path = tf.strings.regex_replace(path, ".jpeg", "") + {{ else }} + ERROR + {{ end }} return tf.reshape(table.lookup(tf.strings.as_string([path])), []) != -1 @@ -97,8 +101,8 @@ val_size = int(image_count * 0.3) train_ds = list_ds.skip(val_size) val_ds = list_ds.take(val_size) -dataset = prepare_dataset(train_ds) -dataset_validation = prepare_dataset(val_ds) +dataset = prepare_dataset(train_ds, image_count) +dataset_validation = prepare_dataset(val_ds, val_size) track = 0