fix model retrain not working closes #93

This commit is contained in:
2024-04-16 16:02:57 +01:00
parent 026932cfab
commit f182b205f8
7 changed files with 144 additions and 179 deletions

View File

@@ -1658,150 +1658,6 @@ func trainRetrain(c *Context, model *BaseModel, defId string) {
}
}
func handleRetrain(c *Context) *Error {
var err error = nil
if !c.CheckAuthLevel(1) {
return nil
}
var dat JustId
if err_ := c.ToJSON(&dat); err_ != nil {
return err_
}
if dat.Id == "" {
return c.JsonBadRequest("Please provide a id")
}
model, err := GetBaseModel(c.Db, dat.Id)
if err == ModelNotFoundError {
return c.JsonBadRequest("Model not found")
} else if err != nil {
return c.Error500(err)
} else if model.Status != READY && model.Status != READY_RETRAIN_FAILED && model.Status != READY_ALTERATION_FAILED {
return c.JsonBadRequest("Model in invalid status for re-training")
}
c.Logger.Info("Expanding definitions for models", "id", model.Id)
classesUpdated := false
failed := func() *Error {
if classesUpdated {
ResetClasses(c, model)
}
ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
c.Logger.Error("Failed to retrain", "err", err)
// TODO improve this response
return c.Error500(err)
}
var def struct {
Id string
TargetAccuracy int `db:"target_accuracy"`
}
err = GetDBOnce(c, &def, "model_definition where model_id=$1;", model.Id)
if err != nil {
return failed()
}
type C struct {
Id string
ClassOrder int `db:"class_order"`
}
err = c.StartTx()
if err != nil {
return failed()
}
classes, err := GetDbMultitple[C](
c,
"model_classes where model_id=$1 and status=$2 order by class_order asc",
model.Id,
MODEL_CLASS_STATUS_TO_TRAIN,
)
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
if len(classes) == 0 {
c.Logger.Error("No classes are available!")
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
//Update the classes
{
stmt, err2 := c.Prepare("update model_classes set status=$1 where status=$2 and model_id=$3")
err = err2
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
defer stmt.Close()
_, err = stmt.Exec(MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
err = c.CommitTx()
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
classesUpdated = true
}
var newHead = struct {
DefId string `db:"def_id"`
RangeStart int `db:"range_start"`
RangeEnd int `db:"range_end"`
status ModelDefinitionStatus `db:"status"`
}{
def.Id, classes[0].ClassOrder, classes[len(classes)-1].ClassOrder, MODEL_DEFINITION_STATUS_INIT,
}
_, err = InsertReturnId(c.GetDb(), &newHead, "exp_model_head", "id")
if err != nil {
return failed()
}
go trainRetrain(c, model, def.Id)
_, err = c.Db.Exec("update models set status=$1 where id=$2;", READY_RETRAIN, model.Id)
if err != nil {
fmt.Println("Failed to update model status")
fmt.Println(err)
// TODO improve this response
return c.Error500(err)
}
return c.SendJSON(model.Id)
}
func RunTaskTrain(b BasePack, task Task) (err error) {
l := b.GetLogger()
@@ -1929,7 +1785,132 @@ func handleTrain(handle *Handle) {
return c.SendJSON(id)
})
handle.Post("/model/train/retrain", handleRetrain)
PostAuthJson(handle, "/model/train/retrain", User_Normal, func(c *Context, dat *JustId) *Error {
model, err := GetBaseModel(c.Db, dat.Id)
if err == ModelNotFoundError {
return c.JsonBadRequest("Model not found")
} else if err != nil {
return c.E500M("Faield to get model", err)
} else if model.Status != READY && model.Status != READY_RETRAIN_FAILED && model.Status != READY_ALTERATION_FAILED {
return c.JsonBadRequest("Model in invalid status for re-training")
}
c.Logger.Info("Expanding definitions for models", "id", model.Id)
classesUpdated := false
failed := func() *Error {
if classesUpdated {
ResetClasses(c, model)
}
ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
return c.E500M("Failed to retrain model", err)
}
var def struct {
Id string
TargetAccuracy int `db:"target_accuracy"`
}
err = GetDBOnce(c, &def, "model_definition where model_id=$1;", model.Id)
if err != nil {
return failed()
}
type C struct {
Id string
ClassOrder int `db:"class_order"`
}
err = c.StartTx()
if err != nil {
return failed()
}
classes, err := GetDbMultitple[C](
c,
"model_classes where model_id=$1 and status=$2 order by class_order asc",
model.Id,
MODEL_CLASS_STATUS_TO_TRAIN,
)
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
if len(classes) == 0 {
c.Logger.Error("No classes are available!")
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
//Update the classes
{
stmt, err2 := c.Prepare("update model_classes set status=$1 where status=$2 and model_id=$3")
err = err2
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
defer stmt.Close()
_, err = stmt.Exec(MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
err = c.CommitTx()
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
classesUpdated = true
}
var newHead = struct {
DefId string `db:"def_id"`
RangeStart int `db:"range_start"`
RangeEnd int `db:"range_end"`
Status ModelDefinitionStatus `db:"status"`
}{
def.Id, classes[0].ClassOrder, classes[len(classes)-1].ClassOrder, MODEL_DEFINITION_STATUS_INIT,
}
_, err = InsertReturnId(c.GetDb(), &newHead, "exp_model_head", "id")
if err != nil {
return failed()
}
go trainRetrain(c, model, def.Id)
_, err = c.Db.Exec("update models set status=$1 where id=$2;", READY_RETRAIN, model.Id)
if err != nil {
fmt.Println("Failed to update model status")
fmt.Println(err)
// TODO improve this response
return c.Error500(err)
}
return c.SendJSON(model.Id)
})
handle.Get("/model/epoch/update", func(c *Context) *Error {
f := c.R.URL.Query()

View File

@@ -9,6 +9,7 @@ import (
"net/http"
"os"
"path"
"runtime/debug"
"strings"
"time"
@@ -57,7 +58,7 @@ func handleError(err *Error, c *Context) {
e = c.SendJSON(500)
}
if e != nil {
c.Logger.Error("Something went very wrong while trying to send and error message")
c.Logger.Error("Something went very wrong while trying to send and error message", "stack", string(debug.Stack()))
c.Writer.Write([]byte("505"))
}
}
@@ -195,7 +196,7 @@ func DeleteAuthJson[T interface{}](x *Handle, path string, authLevel dbtypes.Use
func handleLoop(array []HandleFunc, context *Context) {
defer func() {
if r := recover(); r != nil {
context.Logger.Error("Something went very wrong", "Error", r)
context.Logger.Error("Something went very wrong", "Error", r, "stack", string(debug.Stack()))
handleError(&Error{500, "500"}, context)
}
}()
@@ -418,8 +419,7 @@ func (c Context) ErrorCode(err error, code int, data any) *Error {
c.SetReportCaller(false)
}
if err != nil {
c.Logger.Errorf("Something went wrong returning with: %d\n.Err:\n", code)
c.Logger.Error(err)
c.Logger.Error("Something went wrong returning with:", "Error", err)
}
return &Error{code, data}
}