This commit is contained in:
Andre Henriques 2023-10-21 12:01:10 +01:00
parent 805be22388
commit 90bc3f6acf
2 changed files with 13 additions and 9 deletions

View File

@ -454,7 +454,7 @@ func generateDefinition(c *Context, model *BaseModel, target_accuracy int, numbe
} }
order++; order++;
loop := int(math.Log2(float64(number_of_classes))/2) loop := int(math.Log2(float64(number_of_classes)))
for i := 0; i < loop; i++ { for i := 0; i < loop; i++ {
err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i))) err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)))
order++; order++;

View File

@ -35,7 +35,6 @@ def pathToLabel(path):
{{ if eq .Model.Format "png" }} {{ if eq .Model.Format "png" }}
path = tf.strings.regex_replace(path, ".png", "") path = tf.strings.regex_replace(path, ".png", "")
{{ else if eq .Model.Format "jpeg" }} {{ else if eq .Model.Format "jpeg" }}
path = tf.strings.regex_replace(path, ".jpg", "")
path = tf.strings.regex_replace(path, ".jpeg", "") path = tf.strings.regex_replace(path, ".jpeg", "")
{{ else }} {{ else }}
ERROR ERROR
@ -60,23 +59,28 @@ def process_path(path):
return img, label return img, label
def configure_for_performance(ds: tf.data.Dataset) -> tf.data.Dataset: def configure_for_performance(ds: tf.data.Dataset, size: int) -> tf.data.Dataset:
#ds = ds.cache() #ds = ds.cache()
ds = ds.shuffle(buffer_size= 1000) ds = ds.shuffle(buffer_size=size)
ds = ds.batch(batch_size) ds = ds.batch(batch_size)
ds = ds.prefetch(AUTOTUNE) ds = ds.prefetch(AUTOTUNE)
return ds return ds
def prepare_dataset(ds: tf.data.Dataset) -> tf.data.Dataset: def prepare_dataset(ds: tf.data.Dataset, size: int) -> tf.data.Dataset:
ds = ds.map(process_path, num_parallel_calls=AUTOTUNE) ds = ds.map(process_path, num_parallel_calls=AUTOTUNE)
ds = configure_for_performance(ds) ds = configure_for_performance(ds, size)
return ds return ds
def filterDataset(path): def filterDataset(path):
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "") path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
path = tf.strings.regex_replace(path, ".jpg", "") {{ if eq .Model.Format "png" }}
path = tf.strings.regex_replace(path, ".png", "")
{{ else if eq .Model.Format "jpeg" }}
path = tf.strings.regex_replace(path, ".jpeg", "") path = tf.strings.regex_replace(path, ".jpeg", "")
{{ else }}
ERROR
{{ end }}
return tf.reshape(table.lookup(tf.strings.as_string([path])), []) != -1 return tf.reshape(table.lookup(tf.strings.as_string([path])), []) != -1
@ -97,8 +101,8 @@ val_size = int(image_count * 0.3)
train_ds = list_ds.skip(val_size) train_ds = list_ds.skip(val_size)
val_ds = list_ds.take(val_size) val_ds = list_ds.take(val_size)
dataset = prepare_dataset(train_ds) dataset = prepare_dataset(train_ds, image_count)
dataset_validation = prepare_dataset(val_ds) dataset_validation = prepare_dataset(val_ds, val_size)
track = 0 track = 0