feat: more work for #23

This commit is contained in:
Andre Henriques 2023-10-03 11:55:22 +01:00
parent a1d1a81ec5
commit 884a978e3b
2 changed files with 32 additions and 12 deletions

View File

@@ -79,7 +79,7 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
if !classes.Next() { return }
if err = classes.Scan(&count); err != nil { return }
data, err := handle.Db.Query("select mpd.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id)
data, err := handle.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id)
if err != nil { return }
defer data.Close()
@@ -88,7 +88,11 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
class_order int
}
got := []row{}
f, err := os.Create(path.Join(run_path, "train.csv"))
if err != nil { return }
defer f.Close()
f.Write([]byte("Id,Index\n"))
for data.Next() {
var id string
@@ -96,7 +100,7 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
var file_path string
if err = data.Scan(&id, &class_order, &file_path); err != nil { return }
if file_path == "id://" {
got = append(got, row{id, class_order})
f.Write([]byte(id + "," + strconv.Itoa(class_order) + "\n"))
} else {
return count, errors.New("TODO generateCvs to file_path " + file_path)
}
@@ -137,9 +141,8 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
return
}
if err = generateCvs(handle, run_path); err != nil {
return
}
_, err = generateCvs(handle, run_path, model_id)
if err != nil { return }
// Create python script
f, err := os.Create(path.Join(run_path, "run.py"))
@@ -148,7 +151,6 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
}
defer f.Close()
tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py")
if err != nil {
return
@@ -158,6 +160,7 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
"Layers": got,
"Size": got[0].Shape,
"DataDir": path.Join(getDir(), "savedData", model_id, "data"),
"RunPath": run_path,
}); err != nil {
return
}
@@ -426,7 +429,7 @@ func handleTrain(handle *Handle) {
return Error500(err)
}
// Using sparce
err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("1,1", len(cls)))
err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d, 1", len(cls)))
if err != nil {
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
// TODO improve this response

View File

@@ -1,5 +1,6 @@
import tensorflow as tf
import random
import pandas as pd
from tensorflow import keras
from tensorflow.data import AUTOTUNE
from keras import layers, losses, optimizers
@@ -7,12 +8,28 @@ from keras import layers, losses, optimizers
DATA_DIR = "{{ .DataDir }}"
image_size = ({{ .Size }})
#based on https://www.tensorflow.org/tutorials/load_data/images
df = pd.read_csv("{{ .RunPath }}/train.csv", dtype=str)
keys = tf.constant(df['Id'].dropna())
values = tf.constant(list(map(int, df['Index'].dropna())))
table = tf.lookup.StaticHashTable(
initializer=tf.lookup.KeyValueTensorInitializer(
keys=keys,
values=values,
),
default_value=tf.constant(-1),
name="Indexes"
)
DATA_DIR_PREPARE = DATA_DIR + "/"
#based on https://www.tensorflow.org/tutorials/load_data/images
def pathToLabel(path):
path = tf.strings.regex_replace(path, DATA_DIR, "")
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
path = tf.strings.regex_replace(path, ".jpg", "")
return train_labels[tf.strings.to_number(path, out_type=tf.int32)]
path = tf.strings.regex_replace(path, ".png", "")
return table.lookup(tf.strings.as_string([path]))
#return tf.strings.as_string([path])
def decode_image(img):
# channels were reduced to 1 since image is grayscale
@@ -75,7 +92,7 @@ model = keras.Sequential([
])
model.compile(
loss=losses.SparceCategoricalCrossentropy(),
loss=losses.SparseCategoricalCrossentropy(),
optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy'])