feat: more work for #23
This commit is contained in:
@@ -79,7 +79,7 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
|
||||
if !classes.Next() { return }
|
||||
if err = classes.Scan(&count); err != nil { return }
|
||||
|
||||
data, err := handle.Db.Query("select mpd.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id)
|
||||
data, err := handle.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id)
|
||||
if err != nil { return }
|
||||
defer data.Close()
|
||||
|
||||
@@ -88,7 +88,11 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
|
||||
class_order int
|
||||
}
|
||||
|
||||
got := []row{}
|
||||
|
||||
f, err := os.Create(path.Join(run_path, "train.csv"))
|
||||
if err != nil { return }
|
||||
defer f.Close()
|
||||
f.Write([]byte("Id,Index\n"))
|
||||
|
||||
for data.Next() {
|
||||
var id string
|
||||
@@ -96,7 +100,7 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int,
|
||||
var file_path string
|
||||
if err = data.Scan(&id, &class_order, &file_path); err != nil { return }
|
||||
if file_path == "id://" {
|
||||
got = append(got, row{id, class_order})
|
||||
f.Write([]byte(id + "," + strconv.Itoa(class_order) + "\n"))
|
||||
} else {
|
||||
return count, errors.New("TODO generateCvs to file_path " + file_path)
|
||||
}
|
||||
@@ -137,9 +141,8 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
|
||||
return
|
||||
}
|
||||
|
||||
if err = generateCvs(handle, run_path); err != nil {
|
||||
return
|
||||
}
|
||||
_, err = generateCvs(handle, run_path, model_id)
|
||||
if err != nil { return }
|
||||
|
||||
// Create python script
|
||||
f, err := os.Create(path.Join(run_path, "run.py"))
|
||||
@@ -148,7 +151,6 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
|
||||
tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py")
|
||||
if err != nil {
|
||||
return
|
||||
@@ -158,6 +160,7 @@ func trainDefinition(handle *Handle, model_id string, definition_id string) (acc
|
||||
"Layers": got,
|
||||
"Size": got[0].Shape,
|
||||
"DataDir": path.Join(getDir(), "savedData", model_id, "data"),
|
||||
"RunPath": run_path,
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
@@ -426,7 +429,7 @@ func handleTrain(handle *Handle) {
|
||||
return Error500(err)
|
||||
}
|
||||
// Using sparce
|
||||
err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("1,1", len(cls)))
|
||||
err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d, 1", len(cls)))
|
||||
if err != nil {
|
||||
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
|
||||
Reference in New Issue
Block a user