worked on #32
This commit is contained in:
parent
805be22388
commit
90bc3f6acf
@ -454,7 +454,7 @@ func generateDefinition(c *Context, model *BaseModel, target_accuracy int, numbe
|
|||||||
}
|
}
|
||||||
order++;
|
order++;
|
||||||
|
|
||||||
loop := int(math.Log2(float64(number_of_classes))/2)
|
loop := int(math.Log2(float64(number_of_classes)))
|
||||||
for i := 0; i < loop; i++ {
|
for i := 0; i < loop; i++ {
|
||||||
err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)))
|
err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)))
|
||||||
order++;
|
order++;
|
||||||
|
@ -35,7 +35,6 @@ def pathToLabel(path):
|
|||||||
{{ if eq .Model.Format "png" }}
|
{{ if eq .Model.Format "png" }}
|
||||||
path = tf.strings.regex_replace(path, ".png", "")
|
path = tf.strings.regex_replace(path, ".png", "")
|
||||||
{{ else if eq .Model.Format "jpeg" }}
|
{{ else if eq .Model.Format "jpeg" }}
|
||||||
path = tf.strings.regex_replace(path, ".jpg", "")
|
|
||||||
path = tf.strings.regex_replace(path, ".jpeg", "")
|
path = tf.strings.regex_replace(path, ".jpeg", "")
|
||||||
{{ else }}
|
{{ else }}
|
||||||
ERROR
|
ERROR
|
||||||
@ -60,23 +59,28 @@ def process_path(path):
|
|||||||
|
|
||||||
return img, label
|
return img, label
|
||||||
|
|
||||||
def configure_for_performance(ds: tf.data.Dataset) -> tf.data.Dataset:
|
def configure_for_performance(ds: tf.data.Dataset, size: int) -> tf.data.Dataset:
|
||||||
#ds = ds.cache()
|
#ds = ds.cache()
|
||||||
ds = ds.shuffle(buffer_size= 1000)
|
ds = ds.shuffle(buffer_size=size)
|
||||||
ds = ds.batch(batch_size)
|
ds = ds.batch(batch_size)
|
||||||
ds = ds.prefetch(AUTOTUNE)
|
ds = ds.prefetch(AUTOTUNE)
|
||||||
return ds
|
return ds
|
||||||
|
|
||||||
def prepare_dataset(ds: tf.data.Dataset) -> tf.data.Dataset:
|
def prepare_dataset(ds: tf.data.Dataset, size: int) -> tf.data.Dataset:
|
||||||
ds = ds.map(process_path, num_parallel_calls=AUTOTUNE)
|
ds = ds.map(process_path, num_parallel_calls=AUTOTUNE)
|
||||||
ds = configure_for_performance(ds)
|
ds = configure_for_performance(ds, size)
|
||||||
return ds
|
return ds
|
||||||
|
|
||||||
def filterDataset(path):
|
def filterDataset(path):
|
||||||
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
|
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
|
||||||
|
|
||||||
path = tf.strings.regex_replace(path, ".jpg", "")
|
{{ if eq .Model.Format "png" }}
|
||||||
|
path = tf.strings.regex_replace(path, ".png", "")
|
||||||
|
{{ else if eq .Model.Format "jpeg" }}
|
||||||
path = tf.strings.regex_replace(path, ".jpeg", "")
|
path = tf.strings.regex_replace(path, ".jpeg", "")
|
||||||
|
{{ else }}
|
||||||
|
ERROR
|
||||||
|
{{ end }}
|
||||||
|
|
||||||
return tf.reshape(table.lookup(tf.strings.as_string([path])), []) != -1
|
return tf.reshape(table.lookup(tf.strings.as_string([path])), []) != -1
|
||||||
|
|
||||||
@ -97,8 +101,8 @@ val_size = int(image_count * 0.3)
|
|||||||
train_ds = list_ds.skip(val_size)
|
train_ds = list_ds.skip(val_size)
|
||||||
val_ds = list_ds.take(val_size)
|
val_ds = list_ds.take(val_size)
|
||||||
|
|
||||||
dataset = prepare_dataset(train_ds)
|
dataset = prepare_dataset(train_ds, image_count)
|
||||||
dataset_validation = prepare_dataset(val_ds)
|
dataset_validation = prepare_dataset(val_ds, val_size)
|
||||||
|
|
||||||
track = 0
|
track = 0
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user