From ff9aca269927c1c84c8f6d205325095345e5e950 Mon Sep 17 00:00:00 2001
From: Andre Henriques
Date: Fri, 20 Oct 2023 13:11:46 +0100
Subject: [PATCH] closes #41

---
Review note: in filterDataset the extension patterns are now escaped and
anchored (r"\.jpg$", r"\.jpeg$"). tf.strings.regex_replace treats its pattern
as a regular expression, so the original ".jpg" / ".jpeg" also matched any
character followed by "jpg" / "jpeg" anywhere in the path. The two blank
lines inside the function were replaced by comments, so the hunk line counts
and the diffstat below are unchanged.

 logic/models/train/train.go       |  2 +-
 views/py/python_model_template.py | 11 ++++++++++-
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/logic/models/train/train.go b/logic/models/train/train.go
index b25e8db..3e1196a 100644
--- a/logic/models/train/train.go
+++ b/logic/models/train/train.go
@@ -79,7 +79,7 @@ func generateCvs(c *Context, run_path string, model_id string) (count int, err e
 		return
 	}
 
-	data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id)
+	data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2;", model_id, model_classes.DATA_POINT_MODE_TRAINING)
 	if err != nil {
 		return
 	}
diff --git a/views/py/python_model_template.py b/views/py/python_model_template.py
index 0046bde..a55b99f 100644
--- a/views/py/python_model_template.py
+++ b/views/py/python_model_template.py
@@ -74,14 +74,23 @@ def prepare_dataset(ds: tf.data.Dataset) -> tf.data.Dataset:
     ds = configure_for_performance(ds)
     return ds
 
+def filterDataset(path):
+    path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
+    # Strip only a trailing image extension: the dot is escaped and the match is anchored.
+    path = tf.strings.regex_replace(path, r"\.jpg$", "")
+    path = tf.strings.regex_replace(path, r"\.jpeg$", "")
+    # Keep the file only when table.lookup yields something other than -1 for it.
+    return tf.reshape(table.lookup(tf.strings.as_string([path])), []) != -1
+
 seed = random.randint(0, 100000000)
 
 batch_size = 100
 
 # Read all the files from the direcotry
 list_ds = tf.data.Dataset.list_files(str(f'{DATA_DIR}/*'), shuffle=False)
+list_ds = list_ds.filter(filterDataset)
 
-image_count = len(list_ds)
+image_count = len(list(list_ds.as_numpy_iterator()))
 
 list_ds = list_ds.shuffle(image_count, seed=seed)
 