Fix model retrain not working (closes #93)
This commit is contained in:
@@ -1658,150 +1658,6 @@ func trainRetrain(c *Context, model *BaseModel, defId string) {
|
||||
}
|
||||
}
|
||||
|
||||
func handleRetrain(c *Context) *Error {
|
||||
var err error = nil
|
||||
if !c.CheckAuthLevel(1) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var dat JustId
|
||||
|
||||
if err_ := c.ToJSON(&dat); err_ != nil {
|
||||
return err_
|
||||
}
|
||||
|
||||
if dat.Id == "" {
|
||||
return c.JsonBadRequest("Please provide a id")
|
||||
}
|
||||
|
||||
model, err := GetBaseModel(c.Db, dat.Id)
|
||||
if err == ModelNotFoundError {
|
||||
return c.JsonBadRequest("Model not found")
|
||||
} else if err != nil {
|
||||
return c.Error500(err)
|
||||
} else if model.Status != READY && model.Status != READY_RETRAIN_FAILED && model.Status != READY_ALTERATION_FAILED {
|
||||
return c.JsonBadRequest("Model in invalid status for re-training")
|
||||
}
|
||||
|
||||
c.Logger.Info("Expanding definitions for models", "id", model.Id)
|
||||
|
||||
classesUpdated := false
|
||||
|
||||
failed := func() *Error {
|
||||
if classesUpdated {
|
||||
ResetClasses(c, model)
|
||||
}
|
||||
|
||||
ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
|
||||
c.Logger.Error("Failed to retrain", "err", err)
|
||||
// TODO improve this response
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
var def struct {
|
||||
Id string
|
||||
TargetAccuracy int `db:"target_accuracy"`
|
||||
}
|
||||
|
||||
err = GetDBOnce(c, &def, "model_definition where model_id=$1;", model.Id)
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
|
||||
type C struct {
|
||||
Id string
|
||||
ClassOrder int `db:"class_order"`
|
||||
}
|
||||
|
||||
err = c.StartTx()
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
|
||||
classes, err := GetDbMultitple[C](
|
||||
c,
|
||||
"model_classes where model_id=$1 and status=$2 order by class_order asc",
|
||||
model.Id,
|
||||
MODEL_CLASS_STATUS_TO_TRAIN,
|
||||
)
|
||||
if err != nil {
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
|
||||
if len(classes) == 0 {
|
||||
c.Logger.Error("No classes are available!")
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
|
||||
//Update the classes
|
||||
{
|
||||
stmt, err2 := c.Prepare("update model_classes set status=$1 where status=$2 and model_id=$3")
|
||||
err = err2
|
||||
if err != nil {
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
_, err = stmt.Exec(MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
|
||||
if err != nil {
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
|
||||
err = c.CommitTx()
|
||||
if err != nil {
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
|
||||
classesUpdated = true
|
||||
}
|
||||
|
||||
var newHead = struct {
|
||||
DefId string `db:"def_id"`
|
||||
RangeStart int `db:"range_start"`
|
||||
RangeEnd int `db:"range_end"`
|
||||
status ModelDefinitionStatus `db:"status"`
|
||||
}{
|
||||
def.Id, classes[0].ClassOrder, classes[len(classes)-1].ClassOrder, MODEL_DEFINITION_STATUS_INIT,
|
||||
}
|
||||
|
||||
_, err = InsertReturnId(c.GetDb(), &newHead, "exp_model_head", "id")
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
|
||||
go trainRetrain(c, model, def.Id)
|
||||
|
||||
_, err = c.Db.Exec("update models set status=$1 where id=$2;", READY_RETRAIN, model.Id)
|
||||
if err != nil {
|
||||
fmt.Println("Failed to update model status")
|
||||
fmt.Println(err)
|
||||
// TODO improve this response
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
return c.SendJSON(model.Id)
|
||||
}
|
||||
|
||||
func RunTaskTrain(b BasePack, task Task) (err error) {
|
||||
l := b.GetLogger()
|
||||
|
||||
@@ -1929,7 +1785,132 @@ func handleTrain(handle *Handle) {
|
||||
return c.SendJSON(id)
|
||||
})
|
||||
|
||||
handle.Post("/model/train/retrain", handleRetrain)
|
||||
// Retrain an existing model. The decoded JustId body identifies the model.
// The handler validates the model status, moves the to-train classes to
// training inside a transaction, inserts a new exp_model_head row for the
// class_order range of those classes, starts trainRetrain in a goroutine and
// finally sets the model status to READY_RETRAIN. On success the model id is
// sent back as JSON.
PostAuthJson(handle, "/model/train/retrain", User_Normal, func(c *Context, dat *JustId) *Error {
	model, err := GetBaseModel(c.Db, dat.Id)
	switch {
	case err == ModelNotFoundError:
		return c.JsonBadRequest("Model not found")
	case err != nil:
		return c.E500M("Faield to get model", err)
	case model.Status != READY && model.Status != READY_RETRAIN_FAILED && model.Status != READY_ALTERATION_FAILED:
		return c.JsonBadRequest("Model in invalid status for re-training")
	}

	c.Logger.Info("Expanding definitions for models", "id", model.Id)

	classesUpdated := false

	// failed reads the handler-level err when it is invoked, so every
	// assignment below must target that variable (never shadow it).
	failed := func() *Error {
		if classesUpdated {
			ResetClasses(c, model)
		}

		ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
		return c.E500M("Failed to retrain model", err)
	}

	// abortTx rolls back the open transaction, logs a rollback failure
	// without hiding the original error, and then reports the failure.
	abortTx := func() *Error {
		if rollbackErr := c.RollbackTx(); rollbackErr != nil {
			c.Logger.Error("Two errors happended rollback failed", "err", rollbackErr)
		}
		return failed()
	}

	var def struct {
		Id             string
		TargetAccuracy int `db:"target_accuracy"`
	}

	err = GetDBOnce(c, &def, "model_definition where model_id=$1;", model.Id)
	if err != nil {
		return failed()
	}

	type C struct {
		Id         string
		ClassOrder int `db:"class_order"`
	}

	err = c.StartTx()
	if err != nil {
		return failed()
	}

	classes, err := GetDbMultitple[C](
		c,
		"model_classes where model_id=$1 and status=$2 order by class_order asc",
		model.Id,
		MODEL_CLASS_STATUS_TO_TRAIN,
	)
	if err != nil {
		return abortTx()
	}

	if len(classes) == 0 {
		c.Logger.Error("No classes are available!")
		// Give the failure path a real cause; previously err was nil here, so
		// the reported error carried no information.
		err = fmt.Errorf("no classes available to retrain for model %v", model.Id)
		return abortTx()
	}

	// Update the classes
	stmt, prepareErr := c.Prepare("update model_classes set status=$1 where status=$2 and model_id=$3")
	// Assign to the captured err instead of declaring a new one with :=.
	err = prepareErr
	if err != nil {
		return abortTx()
	}
	defer stmt.Close()

	_, err = stmt.Exec(MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
	if err != nil {
		return abortTx()
	}

	err = c.CommitTx()
	if err != nil {
		return abortTx()
	}

	// From here on a failure must also reset the classes.
	classesUpdated = true

	// Fields are exported so a reflection-based insert can read them.
	newHead := struct {
		DefId      string                `db:"def_id"`
		RangeStart int                   `db:"range_start"`
		RangeEnd   int                   `db:"range_end"`
		Status     ModelDefinitionStatus `db:"status"`
	}{
		DefId:      def.Id,
		RangeStart: classes[0].ClassOrder,
		RangeEnd:   classes[len(classes)-1].ClassOrder,
		Status:     MODEL_DEFINITION_STATUS_INIT,
	}

	_, err = InsertReturnId(c.GetDb(), &newHead, "exp_model_head", "id")
	if err != nil {
		return failed()
	}

	// NOTE(review): training starts before the status update below; if that
	// update fails the client gets a 500 while training keeps running.
	go trainRetrain(c, model, def.Id)

	_, err = c.Db.Exec("update models set status=$1 where id=$2;", READY_RETRAIN, model.Id)
	if err != nil {
		c.Logger.Error("Failed to update model status", "err", err)
		// TODO improve this response
		return c.Error500(err)
	}

	return c.SendJSON(model.Id)
})
|
||||
|
||||
handle.Get("/model/epoch/update", func(c *Context) *Error {
|
||||
f := c.R.URL.Query()
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -57,7 +58,7 @@ func handleError(err *Error, c *Context) {
|
||||
e = c.SendJSON(500)
|
||||
}
|
||||
if e != nil {
|
||||
c.Logger.Error("Something went very wrong while trying to send and error message")
|
||||
c.Logger.Error("Something went very wrong while trying to send and error message", "stack", string(debug.Stack()))
|
||||
c.Writer.Write([]byte("505"))
|
||||
}
|
||||
}
|
||||
@@ -195,7 +196,7 @@ func DeleteAuthJson[T interface{}](x *Handle, path string, authLevel dbtypes.Use
|
||||
func handleLoop(array []HandleFunc, context *Context) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
context.Logger.Error("Something went very wrong", "Error", r)
|
||||
context.Logger.Error("Something went very wrong", "Error", r, "stack", string(debug.Stack()))
|
||||
handleError(&Error{500, "500"}, context)
|
||||
}
|
||||
}()
|
||||
@@ -418,8 +419,7 @@ func (c Context) ErrorCode(err error, code int, data any) *Error {
|
||||
c.SetReportCaller(false)
|
||||
}
|
||||
if err != nil {
|
||||
c.Logger.Errorf("Something went wrong returning with: %d\n.Err:\n", code)
|
||||
c.Logger.Error(err)
|
||||
c.Logger.Error("Something went wrong returning with:", "Error", err)
|
||||
}
|
||||
return &Error{code, data}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user