This commit is contained in:
Andre Henriques 2023-10-21 12:01:10 +01:00
parent 805be22388
commit 90bc3f6acf
2 changed files with 13 additions and 9 deletions

View File

@ -454,7 +454,7 @@ func generateDefinition(c *Context, model *BaseModel, target_accuracy int, numbe
}
order++;
loop := int(math.Log2(float64(number_of_classes))/2)
loop := int(math.Log2(float64(number_of_classes)))
for i := 0; i < loop; i++ {
err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)))
order++;

View File

@ -35,7 +35,6 @@ def pathToLabel(path):
{{ if eq .Model.Format "png" }}
path = tf.strings.regex_replace(path, ".png", "")
{{ else if eq .Model.Format "jpeg" }}
path = tf.strings.regex_replace(path, ".jpg", "")
path = tf.strings.regex_replace(path, ".jpeg", "")
{{ else }}
ERROR
@ -60,23 +59,28 @@ def process_path(path):
return img, label
def configure_for_performance(ds: tf.data.Dataset) -> tf.data.Dataset:
def configure_for_performance(ds: tf.data.Dataset, size: int) -> tf.data.Dataset:
#ds = ds.cache()
ds = ds.shuffle(buffer_size= 1000)
ds = ds.shuffle(buffer_size=size)
ds = ds.batch(batch_size)
ds = ds.prefetch(AUTOTUNE)
return ds
def prepare_dataset(ds: tf.data.Dataset) -> tf.data.Dataset:
def prepare_dataset(ds: tf.data.Dataset, size: int) -> tf.data.Dataset:
ds = ds.map(process_path, num_parallel_calls=AUTOTUNE)
ds = configure_for_performance(ds)
ds = configure_for_performance(ds, size)
return ds
def filterDataset(path):
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
path = tf.strings.regex_replace(path, ".jpg", "")
{{ if eq .Model.Format "png" }}
path = tf.strings.regex_replace(path, ".png", "")
{{ else if eq .Model.Format "jpeg" }}
path = tf.strings.regex_replace(path, ".jpeg", "")
{{ else }}
ERROR
{{ end }}
return tf.reshape(table.lookup(tf.strings.as_string([path])), []) != -1
@ -97,8 +101,8 @@ val_size = int(image_count * 0.3)
train_ds = list_ds.skip(val_size)
val_ds = list_ds.take(val_size)
dataset = prepare_dataset(train_ds)
dataset_validation = prepare_dataset(val_ds)
dataset = prepare_dataset(train_ds, image_count)
dataset_validation = prepare_dataset(val_ds, val_size)
track = 0