feat: more work for #23
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import tensorflow as tf
|
||||
import random
|
||||
import pandas as pd
|
||||
from tensorflow import keras
|
||||
from tensorflow.data import AUTOTUNE
|
||||
from keras import layers, losses, optimizers
|
||||
@@ -7,12 +8,28 @@ from keras import layers, losses, optimizers
|
||||
DATA_DIR = "{{ .DataDir }}"
|
||||
image_size = ({{ .Size }})
|
||||
|
||||
#based on https://www.tensorflow.org/tutorials/load_data/images
|
||||
df = pd.read_csv("{{ .RunPath }}/train.csv", dtype=str)
|
||||
keys = tf.constant(df['Id'].dropna())
|
||||
values = tf.constant(list(map(int, df['Index'].dropna())))
|
||||
|
||||
table = tf.lookup.StaticHashTable(
|
||||
initializer=tf.lookup.KeyValueTensorInitializer(
|
||||
keys=keys,
|
||||
values=values,
|
||||
),
|
||||
default_value=tf.constant(-1),
|
||||
name="Indexes"
|
||||
)
|
||||
|
||||
DATA_DIR_PREPARE = DATA_DIR + "/"
|
||||
|
||||
#based on https://www.tensorflow.org/tutorials/load_data/images
|
||||
def pathToLabel(path):
|
||||
path = tf.strings.regex_replace(path, DATA_DIR, "")
|
||||
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
|
||||
path = tf.strings.regex_replace(path, ".jpg", "")
|
||||
return train_labels[tf.strings.to_number(path, out_type=tf.int32)]
|
||||
path = tf.strings.regex_replace(path, ".png", "")
|
||||
return table.lookup(tf.strings.as_string([path]))
|
||||
#return tf.strings.as_string([path])
|
||||
|
||||
def decode_image(img):
|
||||
# channels were reduced to 1 since image is grayscale
|
||||
@@ -75,7 +92,7 @@ model = keras.Sequential([
|
||||
])
|
||||
|
||||
model.compile(
|
||||
loss=losses.SparceCategoricalCrossentropy(),
|
||||
loss=losses.SparseCategoricalCrossentropy(),
|
||||
optimizer=tf.keras.optimizers.Adam(),
|
||||
metrics=['accuracy'])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user