feat: more work for #23

This commit is contained in:
2023-10-03 11:55:22 +01:00
parent a1d1a81ec5
commit 884a978e3b
2 changed files with 32 additions and 12 deletions

View File

@@ -1,5 +1,6 @@
import tensorflow as tf
import random
import pandas as pd
from tensorflow import keras
from tensorflow.data import AUTOTUNE
from keras import layers, losses, optimizers
@@ -7,12 +8,28 @@ from keras import layers, losses, optimizers
# NOTE(review): the {{ ... }} placeholders look like Go text/template fields;
# presumably they are substituted before this script is executed — confirm.
# Directory prefix of the image files; pathToLabel strips it from each path.
DATA_DIR = "{{ .DataDir }}"
# Image size tuple; its contents come from the template's .Size field.
image_size = ({{ .Size }})
#based on https://www.tensorflow.org/tutorials/load_data/images
# Build an in-graph lookup table that maps each image Id (string) to its
# integer class Index, as listed in train.csv.
df = pd.read_csv("{{ .RunPath }}/train.csv", dtype=str)
# Drop incomplete rows as a unit. Calling dropna() on each column separately
# would shift keys against values whenever only one of the two fields is
# missing in a row, silently assigning wrong labels (or failing on a length
# mismatch between keys and values).
labelled_rows = df.dropna(subset=['Id', 'Index'])
keys = tf.constant(labelled_rows['Id'].tolist(), dtype=tf.string)
# Values are int32 so they match the dtype of default_value below.
values = tf.constant(labelled_rows['Index'].astype(int).tolist(), dtype=tf.int32)
table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=keys,
        values=values,
    ),
    # Ids that are not present in train.csv resolve to -1.
    default_value=tf.constant(-1),
    name="Indexes"
)
# Directory prefix including the trailing slash, so that stripping it from a
# file path leaves only the part after the directory.
DATA_DIR_PREPARE = DATA_DIR + "/"


# based on https://www.tensorflow.org/tutorials/load_data/images
def pathToLabel(path):
    """Map an image file path to its integer class index.

    The directory prefix and the image extension are stripped from ``path``
    and the remainder is looked up in ``table``.

    Args:
        path: string tensor holding the file path (assumed to be a scalar,
            e.g. one element of a file-listing dataset — TODO confirm).

    Returns:
        A tensor of shape [1] with the class index found in ``table``, or the
        table's default value (-1) when the id is unknown.
    """
    # NOTE(review): DATA_DIR_PREPARE is interpreted as a regular expression
    # here; a data directory containing regex metacharacters would need
    # escaping before being used as the pattern.
    path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
    # Escape the dot and anchor at the end so only a real trailing extension
    # is removed; an unescaped "." matches any character anywhere in the name.
    path = tf.strings.regex_replace(path, r"\.(jpg|png)$", "")
    # Wrapping the path in a list keeps the result as a length-1 tensor.
    return table.lookup(tf.strings.as_string([path]))
def decode_image(img):
# channels were reduced to 1 since image is grayscale
@@ -75,7 +92,7 @@ model = keras.Sequential([
])
# Configure training. The loss is passed exactly once, using the correctly
# spelled SparseCategoricalCrossentropy class (a repeated `loss=` keyword is a
# syntax error, and `SparceCategoricalCrossentropy` does not exist).
# NOTE(review): this loss expects integer class labels — confirm that matches
# what the dataset pipeline yields.
model.compile(
    loss=losses.SparseCategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])