chore: started working on #38

This commit is contained in:
Andre Henriques 2023-10-10 12:28:49 +01:00
parent 1229ad5373
commit b6afecc682
8 changed files with 225 additions and 137 deletions

View File

@ -9,7 +9,7 @@ var FailedToGetIdAfterInsertError = errors.New("Failed to Get Id After Insert Er
func AddDataPoint(db *sql.DB, class_id string, file_path string, mode DATA_POINT_MODE) (id string, err error) { func AddDataPoint(db *sql.DB, class_id string, file_path string, mode DATA_POINT_MODE) (id string, err error) {
id = "" id = ""
result, err := db.Query("insert into model_data_point (class_id, file_path, model_mode) values ($1, $2, $3) returning id;", class_id, file_path, mode) result, err := db.Query("insert into model_data_point (class_id, file_path, model_mode, status) values ($1, $2, $3, 1) returning id;", class_id, file_path, mode)
if err != nil { if err != nil {
return return
} }
@ -21,3 +21,8 @@ func AddDataPoint(db *sql.DB, class_id string, file_path string, mode DATA_POINT
err = result.Scan(&id) err = result.Scan(&id)
return return
} }
func UpdateDataPointStatus(db *sql.DB, data_point_id string, status int, message *string) (err error) {
_, err = db.Exec("update model_data_point set status=$1, status_message=$2 where id=$3", status, message, data_point_id)
return
}

View File

@ -71,3 +71,14 @@ func CreateClass(db *sql.DB, model_id string, order int, name string) (id string
err = rows.Scan(&id) err = rows.Scan(&id)
return return
} }
func GetNumberOfWrongDataPoints(db *sql.DB, model_id string) (number int, err error) {
number = 0
rows, err := db.Query("select count(mdp.id) from model_data_point as mdp join model_classes as mc on mc.id = mdp.class_id where mc.model_id=$1 and mdp.status=-1;", model_id)
if err != nil { return }
defer rows.Close()
// TODO not an error because if there is no result means that there is no need to count
if !rows.Next() { return }
err = rows.Scan(&number)
return
}

View File

@ -137,9 +137,13 @@ func processZipFile(c *Context, model *BaseModel) {
f.Write(file_data) f.Write(file_data)
if !testImgForModel(c, model, file_path) { if !testImgForModel(c, model, file_path) {
c.Logger.Errorf("Image did not have valid format for model %s\n", file_path) c.Logger.Errorf("Image did not have valid format for model %s (in zip: %s)!", file_path, file.Name)
c.Logger.Warn("Not failling updating data point to status -1")
message := "Image did not have valid format for the model"
if err = model_classes.UpdateDataPointStatus(c.Db, data_point_id, -1, &message); err != nil {
c.Logger.Error("Failed to update data point status")
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_ZIP_FILE) ModelUpdateStatus(c, model.Id, FAILED_PREPARING_ZIP_FILE)
return }
} }
} }

View File

@ -69,10 +69,12 @@ func handleEdit(handle *Handle) {
"Model": model, "Model": model,
})) }))
case CONFIRM_PRE_TRAINING: case CONFIRM_PRE_TRAINING:
wrong_number, err := model_classes.GetNumberOfWrongDataPoints(c.Db, model.Id)
if err != nil { return c.Error500(err) }
cls, err := model_classes.ListClasses(handle.Db, id) cls, err := model_classes.ListClasses(handle.Db, id)
if err != nil { if err != nil { return c.Error500(err) }
return Error500(err)
}
has_data, err := model_classes.ModelHasDataPoints(handle.Db, id) has_data, err := model_classes.ModelHasDataPoints(handle.Db, id)
if err != nil { if err != nil {
@ -83,6 +85,7 @@ func handleEdit(handle *Handle) {
"Model": model, "Model": model,
"Classes": cls, "Classes": cls,
"HasData": has_data, "HasData": has_data,
"NumberOfInvalidImages": wrong_number,
})) }))
case READY: case READY:
LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{ LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{

View File

@ -20,9 +20,13 @@ import (
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) { func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
id = "" id = ""
rows, err := db.Query("insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;", model_id, target_accuracy) rows, err := db.Query("insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;", model_id, target_accuracy)
if err != nil { return } if err != nil {
return
}
defer rows.Close() defer rows.Close()
if !rows.Next() { return id, errors.New("Something wrong!") } if !rows.Next() {
return id, errors.New("Something wrong!")
}
err = rows.Scan(&id) err = rows.Scan(&id)
return return
} }
@ -59,17 +63,27 @@ func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type LayerType,
func generateCvs(c *Context, run_path string, model_id string) (count int, err error) { func generateCvs(c *Context, run_path string, model_id string) (count int, err error) {
classes, err := c.Db.Query("select count(*) from model_classes where model_id=$1;", model_id) classes, err := c.Db.Query("select count(*) from model_classes where model_id=$1;", model_id)
if err != nil { return } if err != nil {
return
}
defer classes.Close() defer classes.Close()
if !classes.Next() { return } if !classes.Next() {
if err = classes.Scan(&count); err != nil { return } return
}
if err = classes.Scan(&count); err != nil {
return
}
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id) data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id)
if err != nil { return } if err != nil {
return
}
defer data.Close() defer data.Close()
f, err := os.Create(path.Join(run_path, "train.csv")) f, err := os.Create(path.Join(run_path, "train.csv"))
if err != nil { return } if err != nil {
return
}
defer f.Close() defer f.Close()
f.Write([]byte("Id,Index\n")) f.Write([]byte("Id,Index\n"))
@ -77,7 +91,9 @@ func generateCvs(c *Context, run_path string, model_id string) (count int, err
var id string var id string
var class_order int var class_order int
var file_path string var file_path string
if err = data.Scan(&id, &class_order, &file_path); err != nil { return } if err = data.Scan(&id, &class_order, &file_path); err != nil {
return
}
if file_path == "id://" { if file_path == "id://" {
f.Write([]byte(id + "," + strconv.Itoa(class_order) + "\n")) f.Write([]byte(id + "," + strconv.Itoa(class_order) + "\n"))
} else { } else {
@ -121,7 +137,9 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
} }
_, err = generateCvs(c, run_path, model.Id) _, err = generateCvs(c, run_path, model.Id)
if err != nil { return } if err != nil {
return
}
// Create python script // Create python script
f, err := os.Create(path.Join(run_path, "run.py")) f, err := os.Create(path.Join(run_path, "run.py"))
@ -307,6 +325,30 @@ func trainModel(c *Context, model *BaseModel) {
ModelUpdateStatus(c, model.Id, READY) ModelUpdateStatus(c, model.Id, READY)
} }
func removeFailedDataPoints(db *sql.DB, model *BaseModel) (err error) {
rows, err := db.Query("select id from model_data_point as mdp join model_classes as mc on mc.id=mpd.class_id where mc.model_id=$1 and mdp.status=-1;", model.Id)
if err != nil {
return
}
defer rows.Close()
base_path := path.Join("savedData", model.Id, "data")
for rows.Next() {
var dataPointId string
err = rows.Scan(&dataPointId)
if err != nil {
return
}
err = os.RemoveAll(path.Join(base_path, dataPointId + model.Format))
if err != nil {
return
}
}
return
}
func handleTrain(handle *Handle) { func handleTrain(handle *Handle) {
handle.Post("/models/train", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { handle.Post("/models/train", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
if !CheckAuthLevel(1, w, r, c) { if !CheckAuthLevel(1, w, r, c) {
@ -376,6 +418,11 @@ func handleTrain(handle *Handle) {
return c.Error500(err) return c.Error500(err)
} }
err = removeFailedDataPoints(c.Db, model.Id)
if err != nil {
return c.Error500(err)
}
var fid string var fid string
for i := 0; i < number_of_models; i++ { for i := 0; i < number_of_models; i++ {
def_id, err := MakeDefenition(handle.Db, model.Id, accuracy) def_id, err := MakeDefenition(handle.Db, model.Id, accuracy)

View File

@ -33,7 +33,11 @@ create table if not exists model_data_point (
file_path text not null, file_path text not null,
-- 1 training -- 1 training
-- 2 testing -- 2 testing
model_mode integer default 1 model_mode integer default 1,
-- -1 Error on creation
-- 1 OK
status integer not null,
status_message text
); );
-- drop table if exists model_definition; -- drop table if exists model_definition;

View File

@ -247,6 +247,11 @@
<p> <p>
You need to upload data so the model can train. You need to upload data so the model can train.
</p> </p>
{{ if gt .NumberOfInvalidImages 0 }}
<p class="danger">
There are images that were loaded that do not have the correct format. These images will be delete when the model trains.
</p>
{{ end }}
<div class="tabs"> <div class="tabs">
<button class="tab" data-tab="create_class"> <button class="tab" data-tab="create_class">
Create Class Create Class
@ -268,6 +273,11 @@
{{ define "train-model-card" }} {{ define "train-model-card" }}
<form hx-post="/models/train" hx-headers='{"REQUEST-TYPE": "html"}' hx-swap="outerHTML" {{ if .Error }} class="submitted" {{end}} > <form hx-post="/models/train" hx-headers='{"REQUEST-TYPE": "html"}' hx-swap="outerHTML" {{ if .Error }} class="submitted" {{end}} >
{{ if .HasData }} {{ if .HasData }}
{{ if gt .NumberOfInvalidImages 0 }}
<p class="danger">
There are images that were loaded that do not have the correct format. These images will be delete when the model trains.
</p>
{{ end }}
{{/* TODO expading mode */}} {{/* TODO expading mode */}}
<input type="hidden" value="{{ .Model.Id }}" name="id" /> <input type="hidden" value="{{ .Model.Id }}" name="id" />
<fieldset> <fieldset>
@ -296,7 +306,7 @@
</button> </button>
{{ else }} {{ else }}
<h2> <h2>
Please provide data to the model first To train the model please provide data to the model first
</h2> </h2>
{{ end }} {{ end }}
</form> </form>

View File

@ -66,6 +66,10 @@ main {
font-size: 1.1rem; font-size: 1.1rem;
} }
.danger {
color: red;
}
/* Generic */ /* Generic */
.button, .button,
button { button {