diff --git a/a.out b/a.out deleted file mode 100755 index bdddb1f..0000000 Binary files a/a.out and /dev/null differ diff --git a/logic/models/classes/data_point.go b/logic/models/classes/data_point.go index c7263b4..14b5ab0 100644 --- a/logic/models/classes/data_point.go +++ b/logic/models/classes/data_point.go @@ -5,23 +5,19 @@ import ( "errors" ) +var FailedToGetIdAfterInsertError = errors.New("Failed to Get Id After Insert Error") + func AddDataPoint(db *sql.DB, class_id string, file_path string, mode DATA_POINT_MODE) (id string, err error) { id = "" - _, err = db.Exec("insert into model_data_point (class_id, file_path, model_mode) values ($1, $2, $3);", class_id, file_path, mode) + result, err := db.Query("insert into model_data_point (class_id, file_path, model_mode) values ($1, $2, $3) returning id;", class_id, file_path, mode) if err != nil { return } - - rows, err := db.Query("select id from model_data_point where class_id=$1 and file_path=$2 and model_mode=$3", class_id, file_path, mode) - if err != nil { + defer result.Close() + if !result.Next() { + err = FailedToGetIdAfterInsertError return } - defer rows.Close() - - if !rows.Next() { - return id, errors.New("Something worng") - } - - err = rows.Scan(&id) + err = result.Scan(&id) return } diff --git a/logic/models/classes/main.go b/logic/models/classes/main.go index fd43b2c..2921248 100644 --- a/logic/models/classes/main.go +++ b/logic/models/classes/main.go @@ -46,7 +46,7 @@ func ModelHasDataPoints(db *sql.DB, model_id string) (result bool, err error) { var ClassAlreadyExists = errors.New("Class aready exists") -func CreateClass(db *sql.DB, model_id string, name string) (id string, err error) { +func CreateClass(db *sql.DB, model_id string, order int, name string) (id string, err error) { id = "" rows, err := db.Query("select id from model_classes where model_id=$1 and name=$2;", model_id, name) if err != nil { @@ -58,25 +58,16 @@ func CreateClass(db *sql.DB, model_id string, name string) (id string, err error return id, ClassAlreadyExists } - _, err = db.Exec("insert into model_classes (model_id, name) values ($1, $2)", model_id, name) - - if err != nil { - return - } - - rows, err = db.Query("select id from model_classes where model_id=$1 and name=$2;", model_id, name) + rows, err = db.Query("insert into model_classes (model_id, name, class_order) values ($1, $2, $3) returning id;", model_id, name, order) if err != nil { return } defer rows.Close() if !rows.Next() { - return id, errors.New("Something wrong") + return id, errors.New("Insert did not return anything") } - if err = rows.Scan(&id); err != nil { - return - } - - return + err = rows.Scan(&id) + return } diff --git a/logic/models/data.go b/logic/models/data.go index 875434a..354aae5 100644 --- a/logic/models/data.go +++ b/logic/models/data.go @@ -28,11 +28,11 @@ func InsertIfNotPresent(ss []string, s string) []string { return ss } -func processZipFile(handle *Handle, id string) { - reader, err := zip.OpenReader(path.Join("savedData", id, "base_data.zip")) +func processZipFile(handle *Handle, model_id string) { + reader, err := zip.OpenReader(path.Join("savedData", model_id, "base_data.zip")) if err != nil { // TODO add msg to error - ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) + ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE) fmt.Printf("Faield to proccess zip file failed to open reader\n") fmt.Println(err) return @@ -52,7 +52,7 @@ func processZipFile(handle *Handle, id string) { if paths[0] != "training" && paths[0] != "testing" { fmt.Printf("Invalid file '%s' TODO add msg to response!!!\n", file.Name) - ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) + ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE) return } @@ -67,31 +67,21 @@ func processZipFile(handle *Handle, id string) { fmt.Printf("testing and training are diferent\n") fmt.Println(testing) fmt.Println(training) - ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) + ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE) return } - base_path := path.Join("savedData", id, "data") + base_path := path.Join("savedData", model_id, "data") + if err = os.MkdirAll(base_path, os.ModePerm); err != nil { + fmt.Printf("Failed to create base_path dir\n") + ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE) + return + } ids := map[string]string{} - for _, name := range training { - dir_path := path.Join(base_path, "training", name) - err = os.MkdirAll(dir_path, os.ModePerm) - if err != nil { - fmt.Printf("Failed to create dir %s\n", dir_path) - ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) - return - } - dir_path = path.Join(base_path, "testing", name) - err = os.MkdirAll(dir_path, os.ModePerm) - if err != nil { - fmt.Printf("Failed to create dir %s\n", dir_path) - ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) - return - } - - id, err := model_classes.CreateClass(handle.Db, id, name) + for i, name := range training { + id, err := model_classes.CreateClass(handle.Db, model_id, i, name) if err != nil { fmt.Printf("Failed to create class '%s' on db\n", name) ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) @@ -105,23 +95,19 @@ func processZipFile(handle *Handle, id string) { continue } - file_path := path.Join(base_path, file.Name) - f, err := os.Create(file_path) - if err != nil { - fmt.Printf("Could not create file %s\n", file_path) - ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) - return - } - defer f.Close() data, err := reader.Open(file.Name) if err != nil { - fmt.Printf("Could not create file %s\n", file_path) - ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) + fmt.Printf("Could not open file in zip %s\n", file.Name) + ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE) return } defer data.Close() file_data, err := io.ReadAll(data) - f.Write(file_data) + if err != nil { + fmt.Printf("Could not read file file in zip %s\n", file.Name) + ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE) + return + } // TODO check if the file is a valid photo that matched the defined photo on the database @@ -132,17 +118,27 @@ func processZipFile(handle *Handle, id string) { mode = model_classes.DATA_POINT_MODE_TESTING } - _, err = model_classes.AddDataPoint(handle.Db, ids[parts[1]], "file://" + parts[2], mode) + data_point_id, err := model_classes.AddDataPoint(handle.Db, ids[parts[1]], "id://", mode) if err != nil { - fmt.Printf("Failed to add data point for %s\n", id) + fmt.Printf("Failed to add data point for %s\n", model_id) fmt.Println(err) - ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) + ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE) return } + + file_path := path.Join(base_path, data_point_id + ".png") + f, err := os.Create(file_path) + if err != nil { + fmt.Printf("Could not create file %s\n", file_path) + ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE) + return + } + defer f.Close() + f.Write(file_data) } - fmt.Printf("Added data to model '%s'!\n", id) - ModelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING) + fmt.Printf("Added data to model '%s'!\n", model_id) + ModelUpdateStatus(handle, model_id, CONFIRM_PRE_TRAINING) } func handleDataUpload(handle *Handle) { diff --git a/logic/models/delete.go b/logic/models/delete.go index b745b6f..c3d1ea9 100644 --- a/logic/models/delete.go +++ b/logic/models/delete.go @@ -81,6 +81,8 @@ func handleDelete(handle *Handle) { } switch model.Status { + case FAILED_TRAINING: + fallthrough case FAILED_PREPARING_TRAINING: fallthrough case FAILED_PREPARING: diff --git a/logic/models/train/train.go b/logic/models/train/train.go index 9a504e7..f1834d6 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -71,6 +71,40 @@ func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type LayerType, return } +func generateCvs(handle *Handle, run_path string, model_id string) (count int, err error) { + + classes, err := handle.Db.Query("select count(*) from model_classes where model_id=$1;", model_id) + if err != nil { return } + defer classes.Close() + if !classes.Next() { return } + if err = classes.Scan(&count); err != nil { return } + + data, err := handle.Db.Query("select mpd.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id) + if err != nil { return } + defer data.Close() + + type row struct { + path string + class_order int + } + + got := []row{} + + for data.Next() { + var id string + var class_order int + var file_path string + if err = data.Scan(&id, &class_order, &file_path); err != nil { return } + if file_path == "id://" { + got = append(got, row{id, class_order}) + } else { + return count, errors.New("TODO generateCvs to file_path " + file_path) + } + } + + return +} + func trainDefinition(handle *Handle, model_id string, definition_id string) (accuracy float64, err error) { accuracy = 0 layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id) @@ -103,6 +137,10 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc return } + if err = generateCvs(handle, run_path); err != nil { + return + } + // Create python script f, err := os.Create(path.Join(run_path, "run.py")) if err != nil { @@ -110,6 +148,7 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc } defer f.Close() + tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py") if err != nil { return @@ -118,13 +157,15 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc if err = tmpl.Execute(f, AnyMap{ "Layers": got, "Size": got[0].Shape, - "DataDir": path.Join(getDir(), "savedData", model_id, "data", "training"), + "DataDir": path.Join(getDir(), "savedData", model_id, "data"), }); err != nil { return } // Run the command - if err = exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).Run(); err != nil { + out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).Output() + if err != nil { + fmt.Println(string(out)) return } @@ -384,7 +425,8 @@ func handleTrain(handle *Handle) { // TODO improve this response return Error500(err) } - err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d,1", len(cls))) + // Using sparce + err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("1,1", len(cls))) if err != nil { ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING) // TODO improve this response @@ -400,7 +442,6 @@ func handleTrain(handle *Handle) { } // TODO start training with id fid - go trainModel(handle, model) ModelUpdateStatus(handle, model.Id, TRAINING) diff --git a/sql/models.sql b/sql/models.sql index 239a626..5024380 100644 --- a/sql/models.sql +++ b/sql/models.sql @@ -21,7 +21,8 @@ create table if not exists models ( create table if not exists model_classes ( id uuid primary key default gen_random_uuid(), model_id uuid references models (id) on delete cascade, - name varchar (70) not null + name varchar (70) not null, + class_order integer ); -- drop table if exists model_data_point; diff --git a/views/models/edit.html b/views/models/edit.html index 2c03d9a..48a58a7 100644 --- a/views/models/edit.html +++ b/views/models/edit.html @@ -63,7 +63,11 @@ {{range .List}} - {{.FilePath}} + {{ if eq .FilePath "id://" }} + Managed + {{ else }} + {{.FilePath}} + {{ end }} {{ if (eq .Mode 2) }} @@ -73,8 +77,8 @@ {{ end }} - {{ if startsWith .FilePath "file://" }} - + {{ if startsWith .FilePath "id://" }} + {{ else }} TODO img {{ .FilePath }} diff --git a/views/py/python_model_template.py b/views/py/python_model_template.py index a3ce85a..2d07674 100644 --- a/views/py/python_model_template.py +++ b/views/py/python_model_template.py @@ -1,33 +1,64 @@ import tensorflow as tf import random from tensorflow import keras +from tensorflow.data import AUTOTUNE from keras import layers, losses, optimizers +DATA_DIR = "{{ .DataDir }}" +image_size = ({{ .Size }}) + +#based on https://www.tensorflow.org/tutorials/load_data/images + +def pathToLabel(path): + path = tf.strings.regex_replace(path, DATA_DIR, "") + path = tf.strings.regex_replace(path, ".jpg", "") + return train_labels[tf.strings.to_number(path, out_type=tf.int32)] + +def decode_image(img): + # channels were reduced to 1 since image is grayscale + # TODO chnage channel number based if grayscale + img = tf.io.decode_png(img, channels=1) + + return tf.image.resize(img, image_size) + +def process_path(path): + label = pathToLabel(path) + + img = tf.io.read_file(path) + img = decode_image(img) + + return img, label + +def configure_for_performance(ds: tf.data.Dataset) -> tf.data.Dataset: + #ds = ds.cache() + ds = ds.shuffle(buffer_size= 1000) + ds = ds.batch(batch_size) + ds = ds.prefetch(AUTOTUNE) + return ds + +def prepare_dataset(ds: tf.data.Dataset) -> tf.data.Dataset: + ds = ds.map(process_path, num_parallel_calls=AUTOTUNE) + ds = configure_for_performance(ds) + return ds + seed = random.randint(0, 100000000) batch_size = 100 -dataset = keras.utils.image_dataset_from_directory( - "{{ .DataDir }}", - color_mode="rgb", - validation_split=0.2, - label_mode='categorical', - seed=seed, - shuffle=True, - subset="training", - image_size=({{ .Size }}), - batch_size=batch_size) +# Read all the files from the direcotry +list_ds = tf.data.Dataset.list_files(str(f'{DATA_DIR}/*'), shuffle=False) -dataset_validation = keras.utils.image_dataset_from_directory( - "{{ .DataDir }}", - color_mode="rgb", - validation_split=0.2, - label_mode='categorical', - seed=seed, - shuffle=True, - subset="validation", - image_size=({{ .Size }}), - batch_size=batch_size) +image_count = len(list_ds) + +list_ds = list_ds.shuffle(image_count, seed=seed) + +val_size = int(image_count * 0.3) + +train_ds = list_ds.skip(val_size) +val_ds = list_ds.take(val_size) + +dataset = prepare_dataset(train_ds) +dataset_validation = prepare_dataset(val_ds) model = keras.Sequential([ {{- range .Layers }} @@ -44,7 +75,7 @@ model = keras.Sequential([ ]) model.compile( - loss=losses.CategoricalCrossentropy(), + loss=losses.SparceCategoricalCrossentropy(), optimizer=tf.keras.optimizers.Adam(), metrics=['accuracy'])