feat: more work for #23
This commit is contained in:
parent
a1d1a81ec5
commit
884a978e3b
@ -79,7 +79,7 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
|
||||
if !classes.Next() { return }
|
||||
if err = classes.Scan(&count); err != nil { return }
|
||||
|
||||
data, err := handle.Db.Query("select mpd.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id)
|
||||
data, err := handle.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id)
|
||||
if err != nil { return }
|
||||
defer data.Close()
|
||||
|
||||
@ -88,7 +88,11 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
|
||||
class_order int
|
||||
}
|
||||
|
||||
got := []row{}
|
||||
|
||||
f, err := os.Create(path.Join(run_path, "train.csv"))
|
||||
if err != nil { return }
|
||||
defer f.Close()
|
||||
f.Write([]byte("Id,Index\n"))
|
||||
|
||||
for data.Next() {
|
||||
var id string
|
||||
@ -96,7 +100,7 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
|
||||
var file_path string
|
||||
if err = data.Scan(&id, &class_order, &file_path); err != nil { return }
|
||||
if file_path == "id://" {
|
||||
got = append(got, row{id, class_order})
|
||||
f.Write([]byte(id + "," + strconv.Itoa(class_order) + "\n"))
|
||||
} else {
|
||||
return count, errors.New("TODO generateCvs to file_path " + file_path)
|
||||
}
|
||||
@ -137,9 +141,8 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
|
||||
return
|
||||
}
|
||||
|
||||
if err = generateCvs(handle, run_path); err != nil {
|
||||
return
|
||||
}
|
||||
_, err = generateCvs(handle, run_path, model_id)
|
||||
if err != nil { return }
|
||||
|
||||
// Create python script
|
||||
f, err := os.Create(path.Join(run_path, "run.py"))
|
||||
@ -148,7 +151,6 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
|
||||
tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py")
|
||||
if err != nil {
|
||||
return
|
||||
@ -158,6 +160,7 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
|
||||
"Layers": got,
|
||||
"Size": got[0].Shape,
|
||||
"DataDir": path.Join(getDir(), "savedData", model_id, "data"),
|
||||
"RunPath": run_path,
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
@ -426,7 +429,7 @@ func handleTrain(handle *Handle) {
|
||||
return Error500(err)
|
||||
}
|
||||
// Using sparce
|
||||
err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("1,1", len(cls)))
|
||||
err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d, 1", len(cls)))
|
||||
if err != nil {
|
||||
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
|
@ -1,5 +1,6 @@
|
||||
import tensorflow as tf
|
||||
import random
|
||||
import pandas as pd
|
||||
from tensorflow import keras
|
||||
from tensorflow.data import AUTOTUNE
|
||||
from keras import layers, losses, optimizers
|
||||
@ -7,12 +8,28 @@ from keras import layers, losses, optimizers
|
||||
DATA_DIR = "{{ .DataDir }}"
|
||||
image_size = ({{ .Size }})
|
||||
|
||||
#based on https://www.tensorflow.org/tutorials/load_data/images
|
||||
df = pd.read_csv("{{ .RunPath }}/train.csv", dtype=str)
|
||||
keys = tf.constant(df['Id'].dropna())
|
||||
values = tf.constant(list(map(int, df['Index'].dropna())))
|
||||
|
||||
table = tf.lookup.StaticHashTable(
|
||||
initializer=tf.lookup.KeyValueTensorInitializer(
|
||||
keys=keys,
|
||||
values=values,
|
||||
),
|
||||
default_value=tf.constant(-1),
|
||||
name="Indexes"
|
||||
)
|
||||
|
||||
DATA_DIR_PREPARE = DATA_DIR + "/"
|
||||
|
||||
#based on https://www.tensorflow.org/tutorials/load_data/images
|
||||
def pathToLabel(path):
|
||||
path = tf.strings.regex_replace(path, DATA_DIR, "")
|
||||
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
|
||||
path = tf.strings.regex_replace(path, ".jpg", "")
|
||||
return train_labels[tf.strings.to_number(path, out_type=tf.int32)]
|
||||
path = tf.strings.regex_replace(path, ".png", "")
|
||||
return table.lookup(tf.strings.as_string([path]))
|
||||
#return tf.strings.as_string([path])
|
||||
|
||||
def decode_image(img):
|
||||
# channels were reduced to 1 since image is grayscale
|
||||
@ -75,7 +92,7 @@ model = keras.Sequential([
|
||||
])
|
||||
|
||||
model.compile(
|
||||
loss=losses.SparceCategoricalCrossentropy(),
|
||||
loss=losses.SparseCategoricalCrossentropy(),
|
||||
optimizer=tf.keras.optimizers.Adam(),
|
||||
metrics=['accuracy'])
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user