feat: more work for #23

Andre Henriques 2023-10-03 11:55:22 +01:00
parent a1d1a81ec5
commit 884a978e3b
2 changed files with 32 additions and 12 deletions

View File

@@ -79,7 +79,7 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
     if !classes.Next() { return }
     if err = classes.Scan(&count); err != nil { return }
-    data, err := handle.Db.Query("select mpd.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id)
+    data, err := handle.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id)
     if err != nil { return }
     defer data.Close()
@@ -88,7 +88,11 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
         class_order int
     }
-    got := []row{}
+    f, err := os.Create(path.Join(run_path, "train.csv"))
+    if err != nil { return }
+    defer f.Close()
+    f.Write([]byte("Id,Index\n"))
 
     for data.Next() {
         var id string
@@ -96,7 +100,7 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
         var file_path string
         if err = data.Scan(&id, &class_order, &file_path); err != nil { return }
         if file_path == "id://" {
-            got = append(got, row{id, class_order})
+            f.Write([]byte(id + "," + strconv.Itoa(class_order) + "\n"))
         } else {
             return count, errors.New("TODO generateCvs to file_path " + file_path)
         }
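
The two hunks above drop the in-memory got slice and write a train.csv straight into run_path. A rough sketch of what that file ends up holding and how the Python template consumes it; the ids and class indexes below are made up for illustration, only the "Id,Index" header and the id,class_order rows come from the code above:

    import io
    import pandas as pd

    # Hypothetical contents of <run_path>/train.csv as written by generateCvs:
    csv_text = "Id,Index\n101,0\n102,1\n103,1\n"

    # The template reads it with dtype=str, so the Index column has to be
    # converted back to int before it can be used as lookup-table values:
    df = pd.read_csv(io.StringIO(csv_text), dtype=str)
    ids = list(df['Id'])                   # ['101', '102', '103']
    indexes = list(map(int, df['Index']))  # [0, 1, 1]
    print(ids, indexes)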
@@ -137,9 +141,8 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
         return
     }
 
-    if err = generateCvs(handle, run_path); err != nil {
-        return
-    }
+    _, err = generateCvs(handle, run_path, model_id)
+    if err != nil { return }
 
     // Create python script
     f, err := os.Create(path.Join(run_path, "run.py"))
@@ -148,7 +151,6 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
     }
     defer f.Close()
-
 
     tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py")
     if err != nil {
         return
@@ -158,6 +160,7 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
         "Layers": got,
         "Size": got[0].Shape,
         "DataDir": path.Join(getDir(), "savedData", model_id, "data"),
+        "RunPath": run_path,
     }); err != nil {
         return
     }
@@ -426,7 +429,7 @@ func handleTrain(handle *Handle) {
         return Error500(err)
     }
     // Using sparce
-    err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("1,1", len(cls)))
+    err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d, 1", len(cls)))
     if err != nil {
         ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
         // TODO improve this response
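
Note on the fmt.Sprintf fix above: the old format string "1,1" contains no verb, so len(cls) was never substituted and the result came out as 1,1%!(EXTRA int=N); "%d, 1" yields the intended "<class count>, 1" shape string for the dense layer.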

View File

@@ -1,5 +1,6 @@
 import tensorflow as tf
 import random
+import pandas as pd
 from tensorflow import keras
 from tensorflow.data import AUTOTUNE
 from keras import layers, losses, optimizers
@@ -7,12 +8,28 @@ from keras import layers, losses, optimizers
 DATA_DIR = "{{ .DataDir }}"
 image_size = ({{ .Size }})
 
-#based on https://www.tensorflow.org/tutorials/load_data/images
+df = pd.read_csv("{{ .RunPath }}/train.csv", dtype=str)
+keys = tf.constant(df['Id'].dropna())
+values = tf.constant(list(map(int, df['Index'].dropna())))
+
+table = tf.lookup.StaticHashTable(
+    initializer=tf.lookup.KeyValueTensorInitializer(
+        keys=keys,
+        values=values,
+    ),
+    default_value=tf.constant(-1),
+    name="Indexes"
+)
+
+DATA_DIR_PREPARE = DATA_DIR + "/"
+
+#based on https://www.tensorflow.org/tutorials/load_data/images
 def pathToLabel(path):
-    path = tf.strings.regex_replace(path, DATA_DIR, "")
+    path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
     path = tf.strings.regex_replace(path, ".jpg", "")
-    return train_labels[tf.strings.to_number(path, out_type=tf.int32)]
+    path = tf.strings.regex_replace(path, ".png", "")
+    return table.lookup(tf.strings.as_string([path]))
+    #return tf.strings.as_string([path])
 
 def decode_image(img):
     # channels were reduced to 1 since image is grayscale
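
A minimal sketch of the lookup this hunk introduces, using made-up ids and class indexes in place of train.csv: once pathToLabel's regex_replace calls strip a file path down to its bare id, the StaticHashTable resolves it to an integer label, and unknown ids fall through to the default value.

    import tensorflow as tf

    # Made-up ids (as strings) and class indexes standing in for train.csv:
    keys = tf.constant(["101", "102", "103"])
    values = tf.constant([0, 1, 1])
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys, values),
        default_value=-1,  # unknown ids map to -1, easy to spot downstream
    )

    # A path like "<DataDir>/102.png" is reduced to "102" by the regex_replace
    # calls, then looked up to get its class index:
    print(table.lookup(tf.constant(["102"])).numpy())  # [1]
    print(table.lookup(tf.constant(["999"])).numpy())  # [-1]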
@@ -75,7 +92,7 @@ model = keras.Sequential([
 ])
 
 model.compile(
-    loss=losses.SparceCategoricalCrossentropy(),
+    loss=losses.SparseCategoricalCrossentropy(),
     optimizer=tf.keras.optimizers.Adam(),
     metrics=['accuracy'])
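
The last hunk is a spelling fix: keras.losses has no SparceCategoricalCrossentropy, so the generated script would fail with an AttributeError as soon as it ran. A small sketch of why the sparse variant fits here: it takes integer class indices directly, which is exactly what the train.csv lookup table returns, so no one-hot encoding step is needed.

    import tensorflow as tf

    # SparseCategoricalCrossentropy expects integer labels, matching the integer
    # class indexes produced by the lookup table above:
    loss = tf.keras.losses.SparseCategoricalCrossentropy()
    y_true = [1, 0]                    # integer class indices, no one-hot needed
    y_pred = [[0.1, 0.9], [0.8, 0.2]]  # per-class probabilities from the model
    print(loss(y_true, y_pred).numpy())  # small loss, predictions match the labels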