feat: more work for #23
This commit is contained in:
parent
a1d1a81ec5
commit
884a978e3b
@ -79,7 +79,7 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
|
|||||||
if !classes.Next() { return }
|
if !classes.Next() { return }
|
||||||
if err = classes.Scan(&count); err != nil { return }
|
if err = classes.Scan(&count); err != nil { return }
|
||||||
|
|
||||||
data, err := handle.Db.Query("select mpd.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id)
|
data, err := handle.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id)
|
||||||
if err != nil { return }
|
if err != nil { return }
|
||||||
defer data.Close()
|
defer data.Close()
|
||||||
|
|
||||||
@ -88,7 +88,11 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
|
|||||||
class_order int
|
class_order int
|
||||||
}
|
}
|
||||||
|
|
||||||
got := []row{}
|
|
||||||
|
f, err := os.Create(path.Join(run_path, "train.csv"))
|
||||||
|
if err != nil { return }
|
||||||
|
defer f.Close()
|
||||||
|
f.Write([]byte("Id,Index\n"))
|
||||||
|
|
||||||
for data.Next() {
|
for data.Next() {
|
||||||
var id string
|
var id string
|
||||||
@ -96,7 +100,7 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
|
|||||||
var file_path string
|
var file_path string
|
||||||
if err = data.Scan(&id, &class_order, &file_path); err != nil { return }
|
if err = data.Scan(&id, &class_order, &file_path); err != nil { return }
|
||||||
if file_path == "id://" {
|
if file_path == "id://" {
|
||||||
got = append(got, row{id, class_order})
|
f.Write([]byte(id + "," + strconv.Itoa(class_order) + "\n"))
|
||||||
} else {
|
} else {
|
||||||
return count, errors.New("TODO generateCvs to file_path " + file_path)
|
return count, errors.New("TODO generateCvs to file_path " + file_path)
|
||||||
}
|
}
|
||||||
@ -137,9 +141,8 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = generateCvs(handle, run_path); err != nil {
|
_, err = generateCvs(handle, run_path, model_id)
|
||||||
return
|
if err != nil { return }
|
||||||
}
|
|
||||||
|
|
||||||
// Create python script
|
// Create python script
|
||||||
f, err := os.Create(path.Join(run_path, "run.py"))
|
f, err := os.Create(path.Join(run_path, "run.py"))
|
||||||
@ -148,7 +151,6 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
|
|||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
|
|
||||||
tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py")
|
tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@ -158,6 +160,7 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
|
|||||||
"Layers": got,
|
"Layers": got,
|
||||||
"Size": got[0].Shape,
|
"Size": got[0].Shape,
|
||||||
"DataDir": path.Join(getDir(), "savedData", model_id, "data"),
|
"DataDir": path.Join(getDir(), "savedData", model_id, "data"),
|
||||||
|
"RunPath": run_path,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -426,7 +429,7 @@ func handleTrain(handle *Handle) {
|
|||||||
return Error500(err)
|
return Error500(err)
|
||||||
}
|
}
|
||||||
// Using sparce
|
// Using sparce
|
||||||
err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("1,1", len(cls)))
|
err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d, 1", len(cls)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||||
// TODO improve this response
|
// TODO improve this response
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import random
|
import random
|
||||||
|
import pandas as pd
|
||||||
from tensorflow import keras
|
from tensorflow import keras
|
||||||
from tensorflow.data import AUTOTUNE
|
from tensorflow.data import AUTOTUNE
|
||||||
from keras import layers, losses, optimizers
|
from keras import layers, losses, optimizers
|
||||||
@ -7,12 +8,28 @@ from keras import layers, losses, optimizers
|
|||||||
DATA_DIR = "{{ .DataDir }}"
|
DATA_DIR = "{{ .DataDir }}"
|
||||||
image_size = ({{ .Size }})
|
image_size = ({{ .Size }})
|
||||||
|
|
||||||
#based on https://www.tensorflow.org/tutorials/load_data/images
|
df = pd.read_csv("{{ .RunPath }}/train.csv", dtype=str)
|
||||||
|
keys = tf.constant(df['Id'].dropna())
|
||||||
|
values = tf.constant(list(map(int, df['Index'].dropna())))
|
||||||
|
|
||||||
|
table = tf.lookup.StaticHashTable(
|
||||||
|
initializer=tf.lookup.KeyValueTensorInitializer(
|
||||||
|
keys=keys,
|
||||||
|
values=values,
|
||||||
|
),
|
||||||
|
default_value=tf.constant(-1),
|
||||||
|
name="Indexes"
|
||||||
|
)
|
||||||
|
|
||||||
|
DATA_DIR_PREPARE = DATA_DIR + "/"
|
||||||
|
|
||||||
|
#based on https://www.tensorflow.org/tutorials/load_data/images
|
||||||
def pathToLabel(path):
|
def pathToLabel(path):
|
||||||
path = tf.strings.regex_replace(path, DATA_DIR, "")
|
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
|
||||||
path = tf.strings.regex_replace(path, ".jpg", "")
|
path = tf.strings.regex_replace(path, ".jpg", "")
|
||||||
return train_labels[tf.strings.to_number(path, out_type=tf.int32)]
|
path = tf.strings.regex_replace(path, ".png", "")
|
||||||
|
return table.lookup(tf.strings.as_string([path]))
|
||||||
|
#return tf.strings.as_string([path])
|
||||||
|
|
||||||
def decode_image(img):
|
def decode_image(img):
|
||||||
# channels were reduced to 1 since image is grayscale
|
# channels were reduced to 1 since image is grayscale
|
||||||
@ -75,7 +92,7 @@ model = keras.Sequential([
|
|||||||
])
|
])
|
||||||
|
|
||||||
model.compile(
|
model.compile(
|
||||||
loss=losses.SparceCategoricalCrossentropy(),
|
loss=losses.SparseCategoricalCrossentropy(),
|
||||||
optimizer=tf.keras.optimizers.Adam(),
|
optimizer=tf.keras.optimizers.Adam(),
|
||||||
metrics=['accuracy'])
|
metrics=['accuracy'])
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user