import tensorflow as tf
import random
import pandas as pd
from tensorflow import keras
from tensorflow.data import AUTOTUNE
from keras import layers, losses, optimizers
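# The {{ ... }} placeholders are filled in when this template is rendered (Go text/template syntax).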
DATA_DIR = "{{ .DataDir }}"
image_size = ({{ .Size }})
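# Build a lookup table mapping image ids from train.csv to their integer class indexes.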
df = pd.read_csv("{{ .RunPath }}/train.csv", dtype=str)
keys = tf.constant(df['Id'].dropna())
values = tf.constant(list(map(int, df['Index'].dropna())))

table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=keys,
        values=values,
    ),
    default_value=tf.constant(-1),
    name="Indexes"
)
DATA_DIR_PREPARE = DATA_DIR + "/"
# Based on https://www.tensorflow.org/tutorials/load_data/images
def pathToLabel(path):
    # Strip the data directory prefix and the file extension to recover the image id.
    path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
    path = tf.strings.regex_replace(path, ".jpg", "")
    path = tf.strings.regex_replace(path, ".png", "")
    # Look the id up in the table; missing ids map to the default value -1.
    return table.lookup(tf.strings.as_string([path]))
    #return tf.strings.as_string([path])
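# Decode raw image bytes; the decoder branch is selected at render time from .Model.Format,
# then the image is resized to image_size.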
def decode_image(img):
{{ if eq .Model.Format "png" }}
    img = tf.io.decode_png(img, channels={{.ColorMode}})
{{ else if eq .Model.Format "jpeg" }}
    img = tf.io.decode_jpeg(img, channels={{.ColorMode}})
{{ else }}
    ERROR
{{ end }}
    return tf.image.resize(img, image_size)
def process_path(path):
    label = pathToLabel(path)

    img = tf.io.read_file(path)
    img = decode_image(img)

    return img, label
def configure_for_performance(ds: tf.data.Dataset) -> tf.data.Dataset:
    #ds = ds.cache()
    ds = ds.shuffle(buffer_size=1000)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(AUTOTUNE)
    return ds
def prepare_dataset(ds: tf.data.Dataset) -> tf.data.Dataset:
    ds = ds.map(process_path, num_parallel_calls=AUTOTUNE)
    ds = configure_for_performance(ds)
    return ds
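# Shuffle seed and batch size for the input pipeline.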
seed = random.randint(0, 100000000)
batch_size = 100
# Read all the files from the directory
list_ds = tf.data.Dataset.list_files(f'{DATA_DIR}/*', shuffle=False)
image_count = len(list_ds)
list_ds = list_ds.shuffle(image_count, seed=seed)

# Hold out 30% of the images for validation.
val_size = int(image_count * 0.3)
train_ds = list_ds.skip(val_size)
val_ds = list_ds.take(val_size)

dataset = prepare_dataset(train_ds)
dataset_validation = prepare_dataset(val_ds)
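# The layer list is expanded from the template's .Layers definition:
# LayerType 1 = Rescaling, 2 = Dense (sigmoid), 3 = Flatten.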
model = keras.Sequential([
{{- range .Layers }}
{{- if eq .LayerType 1 }}
    layers.Rescaling(1./255),
{{- else if eq .LayerType 2 }}
    layers.Dense({{ .Shape }}, activation="sigmoid"),
{{- else if eq .LayerType 3 }}
    layers.Flatten(),
{{- else }}
    ERROR
{{- end }}
{{- end }}
])
model.compile(
    loss=losses.SparseCategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])
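# Train for 50 epochs, evaluating on the held-out validation split each epoch.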
his = model.fit(dataset, validation_data=dataset_validation, epochs=50)
acc = his.history["accuracy"]

# Write the final training accuracy to accuracy.val.
with open("accuracy.val", "w") as f:
    f.write(str(acc[-1]))
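# Export the trained model in TensorFlow SavedModel format.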
tf.saved_model.save(model, "model")
# model.save("model.keras", save_format="tf")