feat: more work for #23

This commit is contained in:
2023-10-03 11:55:22 +01:00
parent a1d1a81ec5
commit 884a978e3b
2 changed files with 32 additions and 12 deletions

View File

@@ -79,7 +79,7 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
if !classes.Next() { return }
if err = classes.Scan(&count); err != nil { return }
data, err := handle.Db.Query("select mpd.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id)
data, err := handle.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id)
if err != nil { return }
defer data.Close()
@@ -88,7 +88,11 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
class_order int
}
got := []row{}
f, err := os.Create(path.Join(run_path, "train.csv"))
if err != nil { return }
defer f.Close()
f.Write([]byte("Id,Index\n"))
for data.Next() {
var id string
@@ -96,7 +100,7 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
var file_path string
if err = data.Scan(&id, &class_order, &file_path); err != nil { return }
if file_path == "id://" {
got = append(got, row{id, class_order})
f.Write([]byte(id + "," + strconv.Itoa(class_order) + "\n"))
} else {
return count, errors.New("TODO generateCvs to file_path " + file_path)
}
@@ -137,9 +141,8 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
return
}
if err = generateCvs(handle, run_path); err != nil {
return
}
_, err = generateCvs(handle, run_path, model_id)
if err != nil { return }
// Create python script
f, err := os.Create(path.Join(run_path, "run.py"))
@@ -148,7 +151,6 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
}
defer f.Close()
tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py")
if err != nil {
return
@@ -158,6 +160,7 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
"Layers": got,
"Size": got[0].Shape,
"DataDir": path.Join(getDir(), "savedData", model_id, "data"),
"RunPath": run_path,
}); err != nil {
return
}
@@ -426,7 +429,7 @@ func handleTrain(handle *Handle) {
return Error500(err)
}
// Using sparce
err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("1,1", len(cls)))
err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d, 1", len(cls)))
if err != nil {
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
// TODO improve this response