fyp/logic/models/train/train.go

197 lines
5.8 KiB
Go
Raw Normal View History

package models_train
import (
"database/sql"
"errors"
"fmt"
"net/http"
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
)
// MakeDefenition inserts a new model_definition row for model_id with the given
// target_accuracy and returns the id of the inserted row.
//
// The id is read back from the INSERT itself (RETURNING). Re-selecting
// "the newest definition of this model" afterwards could return a row inserted
// by a concurrent request for the same model.
// NOTE(review): RETURNING needs PostgreSQL (or SQLite >= 3.35); the $n
// placeholders suggest PostgreSQL — confirm against the driver in use.
//
// On error the returned id is "".
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
	row := db.QueryRow(
		"insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;",
		model_id, target_accuracy,
	)
	switch err = row.Scan(&id); {
	case errors.Is(err, sql.ErrNoRows):
		// The insert succeeded but produced no id; treat it as a failure.
		return "", errors.New("model definition insert returned no id")
	case err != nil:
		return "", fmt.Errorf("creating model definition for model %q: %w", model_id, err)
	}
	return id, nil
}
// ModelDefinitionStatus is the value stored in model_definition.status.
type ModelDefinitionStatus int

// Model definition status values.
//
// NOTE(review): only MODEL_DEFINITION_STATUS_PRE_INIT is of type
// ModelDefinitionStatus. Every other constant has an explicit "= N", so it is an
// untyped int constant. They are kept untyped on purpose: code elsewhere may
// compare them with plain ints, and typing them would break that.
const (
	// MODEL_DEFINITION_STATUS_FAILED_TRAINING presumably marks a definition
	// whose training failed — not used in this file; confirm with callers.
	MODEL_DEFINITION_STATUS_FAILED_TRAINING = -3
	// MODEL_DEFINITION_STATUS_PRE_INIT presumably is the state of a freshly
	// inserted definition whose layers are not yet created — TODO confirm the
	// column default.
	MODEL_DEFINITION_STATUS_PRE_INIT ModelDefinitionStatus = 1
	// MODEL_DEFINITION_STATUS_INIT is set once a definition's layers have been
	// created. trainModel selects definitions in this state.
	MODEL_DEFINITION_STATUS_INIT = 2
	// MODEL_DEFINITION_STATUS_TRAINING presumably marks a definition being
	// trained — not used in this file.
	MODEL_DEFINITION_STATUS_TRAINING = 3
	// MODEL_DEFINITION_STATUS_TRANIED is the original, misspelled name for the
	// "trained" state. Kept for existing callers; prefer
	// MODEL_DEFINITION_STATUS_TRAINED.
	MODEL_DEFINITION_STATUS_TRANIED = 4
	// MODEL_DEFINITION_STATUS_TRAINED is the correctly spelled name for
	// MODEL_DEFINITION_STATUS_TRANIED (same untyped value, 4).
	MODEL_DEFINITION_STATUS_TRAINED = MODEL_DEFINITION_STATUS_TRANIED
	// MODEL_DEFINITION_STATUS_READY presumably marks a definition ready for
	// use — not used in this file.
	MODEL_DEFINITION_STATUS_READY = 5
)
// ModelDefinitionUpdateStatus sets the status column of the model_definition
// row whose id is id. It returns the error from the update, if any.
func ModelDefinitionUpdateStatus(handle *Handle, id string, status ModelDefinitionStatus) (err error) {
	const setStatus = "update model_definition set status = $1 where id = $2"
	if _, err = handle.Db.Exec(setStatus, status, id); err != nil {
		return err
	}
	return nil
}
// MakeLayer inserts one model_definition_layer row for the definition def_id.
// The row stores the layer's position (layer_order), its layer_type code and
// its shape string. It returns the error from the insert, if any.
func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type int, shape string) (err error) {
	const insertLayer = "insert into model_definition_layer (def_id, layer_order, layer_type, shape) values ($1, $2, $3, $4)"
	if _, err = db.Exec(insertLayer, def_id, layer_order, layer_type, shape); err != nil {
		return err
	}
	return nil
}
// trainModel collects the model definitions of model that are in
// MODEL_DEFINITION_STATUS_INIT and trains them.
//
// It is run as a goroutine by the /models/train handler, so it returns nothing.
// Any failure is logged and the model's status is set to FAILED_TRAINING.
//
// NOTE(review): the per-definition training step is still a placeholder.
func trainModel(handle *Handle, model *BaseModel) {
	// fail logs why training could not proceed and marks the model as failed.
	fail := func(err error) {
		fmt.Printf("Failed to trainModel!Err:\n")
		fmt.Println(err)
		ModelUpdateStatus(handle, model.Id, FAILED_TRAINING)
	}

	// Both placeholders must be bound: the status filter and the model id.
	definitionsRows, err := handle.Db.Query(
		"select id from model_definition where status=$1 and model_id=$2",
		MODEL_DEFINITION_STATUS_INIT, model.Id,
	)
	if err != nil {
		fail(err)
		return
	}
	defer definitionsRows.Close()

	definitions := []string{}
	for definitionsRows.Next() {
		var id string
		if err = definitionsRows.Scan(&id); err != nil {
			fail(err)
			return
		}
		definitions = append(definitions, id)
	}
	// Next returning false can also mean iteration stopped on an error.
	if err = definitionsRows.Err(); err != nil {
		fail(err)
		return
	}

	if len(definitions) == 0 {
		// err is nil here, so report the actual reason instead of "<nil>".
		fail(fmt.Errorf("no model definitions with status %d for model %s", MODEL_DEFINITION_STATUS_INIT, model.Id))
		return
	}

	for _, def_id := range definitions {
		// TODO train the definition def_id
		_ = def_id
	}
}
// handleTrain registers the POST /models/train endpoint.
//
// The form must carry "id" (a model id), a non-empty "model_type" and the
// numbers "number_of_models" and "accuracy". The model must be in
// CONFIRM_PRE_TRAINING.
//
// The handler creates number_of_models model definitions with the requested
// target accuracy. Each gets three layers: one shaped from the model's
// width/height, and two shaped from the number of classes. Each definition is
// then marked MODEL_DEFINITION_STATUS_INIT. Finally the model is marked
// TRAINING, trainModel is started in the background, and the client is
// redirected to the model's edit page.
func handleTrain(handle *Handle) {
	handle.Post("/models/train", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
		if !CheckAuthLevel(1, w, r, c) {
			return nil
		}
		if c.Mode == JSON {
			panic("TODO /models/train JSON")
		}

		if err := r.ParseForm(); err != nil {
			// TODO improve this response
			return ErrorCode(nil, 400, c.AddMap(nil))
		}
		f := r.Form

		number_of_models := 0
		accuracy := 0
		if !CheckId(f, "id") || CheckEmpty(f, "model_type") || !CheckNumber(f, "number_of_models", &number_of_models) || !CheckNumber(f, "accuracy", &accuracy) {
			// TODO improve this response
			return ErrorCode(nil, 400, c.AddMap(nil))
		}
		// With no definitions trainModel has nothing to train and would only fail.
		// NOTE(review): there is no upper bound on number_of_models.
		if number_of_models <= 0 {
			// TODO improve this response
			return ErrorCode(nil, 400, c.AddMap(nil))
		}
		// "model_type" is validated above but not used yet.

		model, err := GetBaseModel(handle.Db, f.Get("id"))
		switch {
		case errors.Is(err, ModelNotFoundError):
			return ErrorCode(nil, http.StatusNotFound, c.AddMap(AnyMap{
				"NotFoundMessage": "Model not found",
				"GoBackLink":      "/models",
			}))
		case err != nil:
			// TODO improve this response
			return Error500(err)
		}

		if model.Status != CONFIRM_PRE_TRAINING {
			// TODO improve this response
			return ErrorCode(nil, 400, c.AddMap(nil))
		}

		// failPreparing marks the model as FAILED_PREPARING_TRAINING and reports err as a 500.
		failPreparing := func(err error) *Error {
			ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
			// TODO improve this response
			return Error500(err)
		}

		cls, err := model_classes.ListClasses(handle.Db, model.Id)
		if err != nil {
			return failPreparing(err)
		}

		// Every definition gets the same layers. Their shapes depend only on the
		// model, so they are computed once, outside the loop.
		// TODO change shape of it depends on the type of the image
		layers := []struct {
			order     int
			layerType int
			shape     string
		}{
			{1, 1, fmt.Sprintf("%d,%d,1", model.Width, model.Height)},
			{4, 3, fmt.Sprintf("%d,1", len(cls))},
			{5, 2, fmt.Sprintf("%d,1", len(cls))},
		}

		var fid string
		for i := 0; i < number_of_models; i++ {
			def_id, err := MakeDefenition(handle.Db, model.Id, accuracy)
			if err != nil {
				return failPreparing(err)
			}
			if fid == "" {
				fid = def_id
			}
			for _, l := range layers {
				if err = MakeLayer(handle.Db, def_id, l.order, l.layerType, l.shape); err != nil {
					return failPreparing(err)
				}
			}
			if err = ModelDefinitionUpdateStatus(handle, def_id, MODEL_DEFINITION_STATUS_INIT); err != nil {
				return failPreparing(err)
			}
		}

		// Set TRAINING before starting the goroutine. trainModel may set
		// FAILED_TRAINING, and that must not be overwritten afterwards.
		ModelUpdateStatus(handle, model.Id, TRAINING)
		// TODO start training with id fid
		go trainModel(handle, model)

		Redirect("/models/edit?id="+model.Id, c.Mode, w, r)
		return nil
	})
}