feat: started working on #23
@@ -1,33 +1,64 @@
import tensorflow as tf
import random
from tensorflow import keras
from tensorflow.data import AUTOTUNE
from keras import layers, losses, optimizers

DATA_DIR = "{{ .DataDir }}"
image_size = ({{ .Size }})
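
# NOTE: {{ .DataDir }} and {{ .Size }} appear to be Go template placeholders that
# are substituted by the generator before this script runs; the concrete values
# are not defined in this file.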

# based on https://www.tensorflow.org/tutorials/load_data/images


def path_to_label(path):
    path = tf.strings.regex_replace(path, DATA_DIR, "")
    # escape the dot so only a literal ".jpg" suffix is stripped
    path = tf.strings.regex_replace(path, "\\.jpg", "")
    # train_labels is expected to be defined elsewhere in the generated script
    return train_labels[tf.strings.to_number(path, out_type=tf.int32)]
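
# e.g. "<DATA_DIR>123.jpg" -> train_labels[123]; this assumes stripping DATA_DIR
# leaves only the numeric file name (i.e. DATA_DIR covers the trailing path
# separator), otherwise tf.strings.to_number would fail on the leading "/".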


def decode_image(img):
    # channels is 1 because the images are grayscale
    # TODO: change the channel count based on whether the input is grayscale
    img = tf.io.decode_png(img, channels=1)

    return tf.image.resize(img, image_size)


def process_path(path):
    label = path_to_label(path)

    img = tf.io.read_file(path)
    img = decode_image(img)

    return img, label


def configure_for_performance(ds: tf.data.Dataset) -> tf.data.Dataset:
    # ds = ds.cache()
    ds = ds.shuffle(buffer_size=1000)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(AUTOTUNE)
    return ds
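
# shuffle -> batch -> prefetch is the usual tf.data ordering; prefetch(AUTOTUNE)
# overlaps input preprocessing with model execution on the accelerator.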


def prepare_dataset(ds: tf.data.Dataset) -> tf.data.Dataset:
    ds = ds.map(process_path, num_parallel_calls=AUTOTUNE)
    ds = configure_for_performance(ds)
    return ds


seed = random.randint(0, 100000000)

batch_size = 100

dataset = keras.utils.image_dataset_from_directory(
    "{{ .DataDir }}",
    color_mode="rgb",
    validation_split=0.2,
    label_mode="categorical",
    seed=seed,
    shuffle=True,
    subset="training",
    image_size=({{ .Size }}),
    batch_size=batch_size)
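# NOTE: this dataset (and dataset_validation below) is re-assigned at the bottom
# of the file by the manual list_files pipeline, so the
# image_dataset_from_directory results appear to be unused work-in-progress for #23.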
# Read all the files from the directory; shuffle=False keeps the listing
# deterministic, since shuffling happens below with an explicit seed
list_ds = tf.data.Dataset.list_files(f"{DATA_DIR}/*", shuffle=False)

dataset_validation = keras.utils.image_dataset_from_directory(
    "{{ .DataDir }}",
    color_mode="rgb",
    validation_split=0.2,
    label_mode="categorical",
    seed=seed,
    shuffle=True,
    subset="validation",
    image_size=({{ .Size }}),
    batch_size=batch_size)
image_count = len(list_ds)

list_ds = list_ds.shuffle(image_count, seed=seed)

val_size = int(image_count * 0.3)

train_ds = list_ds.skip(val_size)
val_ds = list_ds.take(val_size)
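
# NOTE: this manual split holds out 30% of the files, while the
# image_dataset_from_directory calls above use validation_split=0.2.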

dataset = prepare_dataset(train_ds)
dataset_validation = prepare_dataset(val_ds)

model = keras.Sequential([
{{- range .Layers }}
@@ -44,7 +75,7 @@ model = keras.Sequential([
])

model.compile(
-    loss=losses.CategoricalCrossentropy(),
+    loss=losses.SparseCategoricalCrossentropy(),
    optimizer=optimizers.Adam(),
    metrics=["accuracy"])
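
# A likely next step (not part of this commit) would be something like:
# model.fit(dataset, validation_data=dataset_validation, epochs=10)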