diff --git a/logic/models/train/train.go b/logic/models/train/train.go index f1834d6..effc4c0 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -79,7 +79,7 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int, if !classes.Next() { return } if err = classes.Scan(&count); err != nil { return } - data, err := handle.Db.Query("select mpd.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id) + data, err := handle.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id) if err != nil { return } defer data.Close() @@ -88,7 +88,11 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int, class_order int } - got := []row{} + + f, err := os.Create(path.Join(run_path, "train.csv")) + if err != nil { return } + defer f.Close() + f.Write([]byte("Id,Index\n")) for data.Next() { var id string @@ -96,7 +100,7 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int, var file_path string if err = data.Scan(&id, &class_order, &file_path); err != nil { return } if file_path == "id://" { - got = append(got, row{id, class_order}) + f.Write([]byte(id + "," + strconv.Itoa(class_order) + "\n")) } else { return count, errors.New("TODO generateCvs to file_path " + file_path) } @@ -137,9 +141,8 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc return } - if err = generateCvs(handle, run_path); err != nil { - return - } + _, err = generateCvs(handle, run_path, model_id) + if err != nil { return } // Create python script f, err := os.Create(path.Join(run_path, "run.py")) @@ -148,7 +151,6 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc } defer f.Close() - tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py") if err != nil { return @@ -158,6 +160,7 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc "Layers": got, "Size": got[0].Shape, "DataDir": path.Join(getDir(), "savedData", model_id, "data"), + "RunPath": run_path, }); err != nil { return } @@ -426,7 +429,7 @@ func handleTrain(handle *Handle) { return Error500(err) } // Using sparce - err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("1,1", len(cls))) + err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d, 1", len(cls))) if err != nil { ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING) // TODO improve this response diff --git a/views/py/python_model_template.py b/views/py/python_model_template.py index 2d07674..f33b23f 100644 --- a/views/py/python_model_template.py +++ b/views/py/python_model_template.py @@ -1,5 +1,6 @@ import tensorflow as tf import random +import pandas as pd from tensorflow import keras from tensorflow.data import AUTOTUNE from keras import layers, losses, optimizers @@ -7,12 +8,28 @@ from keras import layers, losses, optimizers DATA_DIR = "{{ .DataDir }}" image_size = ({{ .Size }}) -#based on https://www.tensorflow.org/tutorials/load_data/images +df = pd.read_csv("{{ .RunPath }}/train.csv", dtype=str) +keys = tf.constant(df['Id'].dropna()) +values = tf.constant(list(map(int, df['Index'].dropna()))) +table = tf.lookup.StaticHashTable( + initializer=tf.lookup.KeyValueTensorInitializer( + keys=keys, + values=values, + ), + default_value=tf.constant(-1), + name="Indexes" +) + +DATA_DIR_PREPARE = DATA_DIR + "/" + +#based on https://www.tensorflow.org/tutorials/load_data/images def pathToLabel(path): - path = tf.strings.regex_replace(path, DATA_DIR, "") + path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "") path = tf.strings.regex_replace(path, ".jpg", "") - return train_labels[tf.strings.to_number(path, out_type=tf.int32)] + path = tf.strings.regex_replace(path, ".png", "") + return table.lookup(tf.strings.as_string([path])) + #return tf.strings.as_string([path]) def decode_image(img): # channels were reduced to 1 since image is grayscale @@ -75,7 +92,7 @@ model = keras.Sequential([ ]) model.compile( - loss=losses.SparceCategoricalCrossentropy(), + loss=losses.SparseCategoricalCrossentropy(), optimizer=tf.keras.optimizers.Adam(), metrics=['accuracy'])