runner now done
This commit is contained in:
parent
9eba93cd3c
commit
a0aed71b3c
@ -18,10 +18,10 @@ const (
|
|||||||
type Layer struct {
|
type Layer struct {
|
||||||
Id string `db:"mdl.id" json:"id"`
|
Id string `db:"mdl.id" json:"id"`
|
||||||
DefinitionId string `db:"mdl.def_id" json:"definition_id"`
|
DefinitionId string `db:"mdl.def_id" json:"definition_id"`
|
||||||
LayerOrder string `db:"mdl.layer_order" json:"layer_order"`
|
LayerOrder int `db:"mdl.layer_order" json:"layer_order"`
|
||||||
LayerType LayerType `db:"mdl.layer_type" json:"layer_type"`
|
LayerType LayerType `db:"mdl.layer_type" json:"layer_type"`
|
||||||
Shape string `db:"mdl.shape" json:"shape"`
|
Shape string `db:"mdl.shape" json:"shape"`
|
||||||
ExpType string `db:"mdl.exp_type" json:"exp_type"`
|
ExpType int `db:"mdl.exp_type" json:"exp_type"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func ShapeToString(args ...int) string {
|
func ShapeToString(args ...int) string {
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
|
||||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||||
|
"github.com/charmbracelet/log"
|
||||||
"github.com/goccy/go-json"
|
"github.com/goccy/go-json"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -39,7 +40,6 @@ func PrepareTraining(handler *Handle, b BasePack, task Task, runner_id string) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
if model.ModelType == 2 {
|
if model.ModelType == 2 {
|
||||||
panic("TODO")
|
|
||||||
full_error := generateExpandableDefinitions(b, model, dat.Accuracy, dat.NumberOfModels)
|
full_error := generateExpandableDefinitions(b, model, dat.Accuracy, dat.NumberOfModels)
|
||||||
if full_error != nil {
|
if full_error != nil {
|
||||||
l.Error("Failed to generate defintions", "err", full_error)
|
l.Error("Failed to generate defintions", "err", full_error)
|
||||||
@ -75,4 +75,44 @@ func CleanUpFailed(b BasePack, task *Task) {
|
|||||||
l.Error("Failed to get status", err)
|
l.Error("Failed to get status", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Set the class status to trained
|
||||||
|
err = SetModelClassStatus(b, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
|
||||||
|
if err != nil {
|
||||||
|
l.Error("Failed to set class status")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func CleanUpFailedRetrain(b BasePack, task *Task) {
|
||||||
|
db := b.GetDb()
|
||||||
|
l := b.GetLogger()
|
||||||
|
|
||||||
|
model, err := GetBaseModel(db, *task.ModelId)
|
||||||
|
if err != nil {
|
||||||
|
l.Error("Failed to get model", "err", err)
|
||||||
|
} else {
|
||||||
|
err = model.UpdateStatus(db, FAILED_TRAINING)
|
||||||
|
if err != nil {
|
||||||
|
l.Error("Failed to get status", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ResetClasses(b, model)
|
||||||
|
ModelUpdateStatus(b, model.Id, READY_RETRAIN_FAILED)
|
||||||
|
|
||||||
|
var defData struct {
|
||||||
|
Id string `db:"md.id"`
|
||||||
|
TargetAcuuracy float64 `db:"md.target_accuracy"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err = GetDBOnce(db, &defData, "models as m inner join model_definition as md on m.id = md.model_id where m.id=$1;", task.ModelId)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("failed to get def data", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err_ := db.Exec("delete from exp_model_head where def_id=$1 and status in (2,3)", defData.Id)
|
||||||
|
if err_ != nil {
|
||||||
|
panic(err_)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -101,6 +101,10 @@ func setModelClassStatus(c BasePack, status ModelClassStatus, filter string, arg
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SetModelClassStatus(c BasePack, status ModelClassStatus, filter string, args ...any) (err error) {
|
||||||
|
return setModelClassStatus(c, status, filter, args...)
|
||||||
|
}
|
||||||
|
|
||||||
func generateCvsExp(c BasePack, run_path string, model_id string, doPanic bool) (count int, err error) {
|
func generateCvsExp(c BasePack, run_path string, model_id string, doPanic bool) (count int, err error) {
|
||||||
|
|
||||||
db := c.GetDb()
|
db := c.GetDb()
|
||||||
@ -1090,7 +1094,7 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = splitModel(c, model); err != nil {
|
if err = SplitModel(c, model); err != nil {
|
||||||
err = setModelClassStatus(c, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
|
err = setModelClassStatus(c, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Error("Failed to split the model! And Failed to set class status")
|
l.Error("Failed to split the model! And Failed to set class status")
|
||||||
@ -1123,7 +1127,7 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func splitModel(c BasePack, model *BaseModel) (err error) {
|
func SplitModel(c BasePack, model *BaseModel) (err error) {
|
||||||
db := c.GetDb()
|
db := c.GetDb()
|
||||||
l := c.GetLogger()
|
l := c.GetLogger()
|
||||||
|
|
||||||
@ -1260,7 +1264,6 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe
|
|||||||
}
|
}
|
||||||
|
|
||||||
db := c.GetDb()
|
db := c.GetDb()
|
||||||
l := c.GetLogger()
|
|
||||||
|
|
||||||
def, err := MakeDefenition(db, model.Id, target_accuracy)
|
def, err := MakeDefenition(db, model.Id, target_accuracy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1279,67 +1282,35 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe
|
|||||||
}
|
}
|
||||||
order++
|
order++
|
||||||
|
|
||||||
if complexity == 0 {
|
loop := max(1, int((math.Log(float64(model.Width)) / math.Log(float64(10)))))
|
||||||
/*
|
for i := 0; i < loop; i++ {
|
||||||
_, err = def.MakeLayer(db, order, LAYER_SIMPLE_BLOCK, "")
|
_, err = def.MakeLayer(db, order, LAYER_SIMPLE_BLOCK, "")
|
||||||
if err != nil {
|
order++
|
||||||
failed()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
order++
|
|
||||||
*/
|
|
||||||
|
|
||||||
_, err = def.MakeLayer(db, order, LAYER_FLATTEN, "")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
failed()
|
failed()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
order++
|
}
|
||||||
|
|
||||||
loop := int(math.Log2(float64(number_of_classes)))
|
_, err = def.MakeLayer(db, order, LAYER_FLATTEN, "")
|
||||||
for i := 0; i < loop; i++ {
|
if err != nil {
|
||||||
_, err = def.MakeLayer(db, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i)))
|
|
||||||
order++
|
|
||||||
if err != nil {
|
|
||||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if complexity == 1 || complexity == 2 {
|
|
||||||
loop := max(1, int((math.Log(float64(model.Width)) / math.Log(float64(10)))))
|
|
||||||
for i := 0; i < loop; i++ {
|
|
||||||
_, err = def.MakeLayer(db, order, LAYER_SIMPLE_BLOCK, "")
|
|
||||||
order++
|
|
||||||
if err != nil {
|
|
||||||
failed()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = def.MakeLayer(db, order, LAYER_FLATTEN, "")
|
|
||||||
if err != nil {
|
|
||||||
failed()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
order++
|
|
||||||
|
|
||||||
loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2)
|
|
||||||
if loop == 0 {
|
|
||||||
loop = 1
|
|
||||||
}
|
|
||||||
for i := 0; i < loop; i++ {
|
|
||||||
_, err = def.MakeLayer(db, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i)))
|
|
||||||
order++
|
|
||||||
if err != nil {
|
|
||||||
failed()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
l.Error("Unkown complexity", "complexity", complexity)
|
|
||||||
failed()
|
failed()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
order++
|
||||||
|
|
||||||
|
loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2)
|
||||||
|
if loop == 0 {
|
||||||
|
loop = 1
|
||||||
|
}
|
||||||
|
for i := 0; i < loop; i++ {
|
||||||
|
_, err = def.MakeLayer(db, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i)))
|
||||||
|
order++
|
||||||
|
if err != nil {
|
||||||
|
failed()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return def.UpdateStatus(db, DEFINITION_STATUS_INIT)
|
return def.UpdateStatus(db, DEFINITION_STATUS_INIT)
|
||||||
}
|
}
|
||||||
@ -1410,19 +1381,12 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
|
|||||||
|
|
||||||
order := 1
|
order := 1
|
||||||
|
|
||||||
width := model.Width
|
err = MakeLayerExpandable(c.GetDb(), def_id, order, LAYER_INPUT, ShapeToString(3, model.Width, model.Height), 1)
|
||||||
height := model.Height
|
if err != nil {
|
||||||
|
failed()
|
||||||
// Note the shape of the first layer defines the import size
|
return
|
||||||
if complexity == 2 {
|
|
||||||
// Note the shape for now is no used
|
|
||||||
width := int(math.Pow(2, math.Floor(math.Log(float64(model.Width))/math.Log(2.0))))
|
|
||||||
height := int(math.Pow(2, math.Floor(math.Log(float64(model.Height))/math.Log(2.0))))
|
|
||||||
l.Warn("Complexity 2 creating model with smaller size", "width", width, "height", height)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = MakeLayerExpandable(c.GetDb(), def_id, order, LAYER_INPUT, fmt.Sprintf("%d,%d,1", width, height), 1)
|
|
||||||
|
|
||||||
order++
|
order++
|
||||||
|
|
||||||
// handle the errors inside the pervious if block
|
// handle the errors inside the pervious if block
|
||||||
@ -1460,7 +1424,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
|
|||||||
order++
|
order++
|
||||||
|
|
||||||
// Flatten the blocks into dense
|
// Flatten the blocks into dense
|
||||||
err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*2), 1)
|
err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, ShapeToString(number_of_classes*2), 1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
failed()
|
failed()
|
||||||
return
|
return
|
||||||
@ -1474,7 +1438,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
|
|||||||
loop = max(loop, 3)
|
loop = max(loop, 3)
|
||||||
|
|
||||||
for i := 0; i < loop; i++ {
|
for i := 0; i < loop; i++ {
|
||||||
err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)*2), 2)
|
err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i)*2), 2)
|
||||||
order++
|
order++
|
||||||
if err != nil {
|
if err != nil {
|
||||||
failed()
|
failed()
|
||||||
@ -1747,6 +1711,13 @@ func RunTaskRetrain(b BasePack, task Task) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err = db.Exec("update exp_model_head set status=$1 where status=$2 and model_id=$3", MODEL_HEAD_STATUS_READY, MODEL_HEAD_STATUS_TRAINING, model.Id)
|
||||||
|
if err != nil {
|
||||||
|
l.Error("Error while updating the classes", "error", err)
|
||||||
|
failed()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
task.UpdateStatusLog(b, TASK_DONE, "Model finished retraining")
|
task.UpdateStatusLog(b, TASK_DONE, "Model finished retraining")
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train"
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train"
|
||||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
|
||||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||||
|
"github.com/charmbracelet/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func verifyRunner(c *Context, dat *JustId) (runner *Runner, e *Error) {
|
func verifyRunner(c *Context, dat *JustId) (runner *Runner, e *Error) {
|
||||||
@ -34,6 +35,12 @@ type VerifyTask struct {
|
|||||||
TaskId string `json:"taskId" validate:"required"`
|
TaskId string `json:"taskId" validate:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RunnerTrainDef struct {
|
||||||
|
Id string `json:"id" validate:"required"`
|
||||||
|
TaskId string `json:"taskId" validate:"required"`
|
||||||
|
DefId string `json:"defId" validate:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
func verifyTask(x *Handle, c *Context, dat *VerifyTask) (task *Task, error *Error) {
|
func verifyTask(x *Handle, c *Context, dat *VerifyTask) (task *Task, error *Error) {
|
||||||
mutex := x.DataMap["runners_mutex"].(*sync.Mutex)
|
mutex := x.DataMap["runners_mutex"].(*sync.Mutex)
|
||||||
mutex.Lock()
|
mutex.Lock()
|
||||||
@ -53,6 +60,18 @@ func verifyTask(x *Handle, c *Context, dat *VerifyTask) (task *Task, error *Erro
|
|||||||
return runner_data["task"].(*Task), nil
|
return runner_data["task"].(*Task), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func clearRunnerTask(x *Handle, runner_id string) {
|
||||||
|
mutex := x.DataMap["runners_mutex"].(*sync.Mutex)
|
||||||
|
mutex.Lock()
|
||||||
|
defer mutex.Unlock()
|
||||||
|
|
||||||
|
var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{})
|
||||||
|
var runner_data map[string]interface{} = runners[runner_id].(map[string]interface{})
|
||||||
|
runner_data["task"] = nil
|
||||||
|
runners[runner_id] = runner_data
|
||||||
|
x.DataMap["runners"] = runners
|
||||||
|
}
|
||||||
|
|
||||||
func handleRemoteRunner(x *Handle) {
|
func handleRemoteRunner(x *Handle) {
|
||||||
|
|
||||||
type RegisterRunner struct {
|
type RegisterRunner struct {
|
||||||
@ -202,6 +221,8 @@ func handleRemoteRunner(x *Handle) {
|
|||||||
switch task.TaskType {
|
switch task.TaskType {
|
||||||
case int(TASK_TYPE_TRAINING):
|
case int(TASK_TYPE_TRAINING):
|
||||||
CleanUpFailed(c, task)
|
CleanUpFailed(c, task)
|
||||||
|
case int(TASK_TYPE_RETRAINING):
|
||||||
|
CleanUpFailedRetrain(c, task)
|
||||||
case int(TASK_TYPE_CLASSIFICATION):
|
case int(TASK_TYPE_CLASSIFICATION):
|
||||||
// DO nothing
|
// DO nothing
|
||||||
default:
|
default:
|
||||||
@ -237,6 +258,8 @@ func handleRemoteRunner(x *Handle) {
|
|||||||
switch task.TaskType {
|
switch task.TaskType {
|
||||||
case int(TASK_TYPE_TRAINING):
|
case int(TASK_TYPE_TRAINING):
|
||||||
status = DEFINITION_STATUS_INIT
|
status = DEFINITION_STATUS_INIT
|
||||||
|
case int(TASK_TYPE_RETRAINING):
|
||||||
|
fallthrough
|
||||||
case int(TASK_TYPE_CLASSIFICATION):
|
case int(TASK_TYPE_CLASSIFICATION):
|
||||||
status = DEFINITION_STATUS_READY
|
status = DEFINITION_STATUS_READY
|
||||||
default:
|
default:
|
||||||
@ -268,27 +291,35 @@ func handleRemoteRunner(x *Handle) {
|
|||||||
return error
|
return error
|
||||||
}
|
}
|
||||||
|
|
||||||
switch task.TaskType {
|
|
||||||
case int(TASK_TYPE_TRAINING):
|
|
||||||
//DO NOTHING
|
|
||||||
case int(TASK_TYPE_CLASSIFICATION):
|
|
||||||
//DO NOTHING
|
|
||||||
default:
|
|
||||||
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
|
|
||||||
return c.JsonBadRequest("Task is not the right type go get the definitions")
|
|
||||||
}
|
|
||||||
|
|
||||||
model, err := GetBaseModel(c, *task.ModelId)
|
model, err := GetBaseModel(c, *task.ModelId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.E500M("Failed to get model information", err)
|
return c.E500M("Failed to get model information", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
classes, err := model.GetClasses(c, "and status=$2 order by mc.class_order asc", CLASS_STATUS_TO_TRAIN)
|
switch task.TaskType {
|
||||||
if err != nil {
|
case int(TASK_TYPE_TRAINING):
|
||||||
return c.E500M("Failed to get the model classes", err)
|
classes, err := model.GetClasses(c, "and status in ($2, $3) order by mc.class_order asc", CLASS_STATUS_TO_TRAIN, CLASS_STATUS_TRAINING)
|
||||||
|
if err != nil {
|
||||||
|
return c.E500M("Failed to get the model classes", err)
|
||||||
|
}
|
||||||
|
return c.SendJSON(classes)
|
||||||
|
case int(TASK_TYPE_RETRAINING):
|
||||||
|
classes, err := model.GetClasses(c, "and status=$2 order by mc.class_order asc", CLASS_STATUS_TRAINING)
|
||||||
|
if err != nil {
|
||||||
|
return c.E500M("Failed to get the model classes", err)
|
||||||
|
}
|
||||||
|
return c.SendJSON(classes)
|
||||||
|
case int(TASK_TYPE_CLASSIFICATION):
|
||||||
|
classes, err := model.GetClasses(c, "and status=$2 order by mc.class_order asc", CLASS_STATUS_TRAINED)
|
||||||
|
if err != nil {
|
||||||
|
return c.E500M("Failed to get the model classes", err)
|
||||||
|
}
|
||||||
|
return c.SendJSON(classes)
|
||||||
|
default:
|
||||||
|
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
|
||||||
|
return c.JsonBadRequest("Task is not the right type go get the definitions")
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.SendJSON(classes)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
type RunnerTrainDefStatus struct {
|
type RunnerTrainDefStatus struct {
|
||||||
@ -326,12 +357,13 @@ func handleRemoteRunner(x *Handle) {
|
|||||||
return c.SendJSON("Ok")
|
return c.SendJSON("Ok")
|
||||||
})
|
})
|
||||||
|
|
||||||
type RunnerTrainDefLayers struct {
|
type RunnerTrainDefHeadStatus struct {
|
||||||
Id string `json:"id" validate:"required"`
|
Id string `json:"id" validate:"required"`
|
||||||
TaskId string `json:"taskId" validate:"required"`
|
TaskId string `json:"taskId" validate:"required"`
|
||||||
DefId string `json:"defId" validate:"required"`
|
DefId string `json:"defId" validate:"required"`
|
||||||
|
Status ModelHeadStatus `json:"status" validate:"required"`
|
||||||
}
|
}
|
||||||
PostAuthJson(x, "/tasks/runner/train/def/layers", User_Normal, func(c *Context, dat *RunnerTrainDefLayers) *Error {
|
PostAuthJson(x, "/tasks/runner/train/def/head/status", User_Normal, func(c *Context, dat *RunnerTrainDefHeadStatus) *Error {
|
||||||
_, error := verifyRunner(c, &JustId{Id: dat.Id})
|
_, error := verifyRunner(c, &JustId{Id: dat.Id})
|
||||||
if error != nil {
|
if error != nil {
|
||||||
return error
|
return error
|
||||||
@ -352,6 +384,69 @@ func handleRemoteRunner(x *Handle) {
|
|||||||
return c.E500M("Failed to get definition information", err)
|
return c.E500M("Failed to get definition information", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err = c.Exec("update exp_model_head set status=$1 where def_id=$2;", dat.Status, def.Id)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to train definition!")
|
||||||
|
return c.E500M("Failed to train definition", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.SendJSON("Ok")
|
||||||
|
})
|
||||||
|
|
||||||
|
type RunnerRetrainDefHeadStatus struct {
|
||||||
|
Id string `json:"id" validate:"required"`
|
||||||
|
TaskId string `json:"taskId" validate:"required"`
|
||||||
|
HeadId string `json:"defId" validate:"required"`
|
||||||
|
Status ModelHeadStatus `json:"status" validate:"required"`
|
||||||
|
}
|
||||||
|
PostAuthJson(x, "/tasks/runner/retrain/def/head/status", User_Normal, func(c *Context, dat *RunnerRetrainDefHeadStatus) *Error {
|
||||||
|
_, error := verifyRunner(c, &JustId{Id: dat.Id})
|
||||||
|
if error != nil {
|
||||||
|
return error
|
||||||
|
}
|
||||||
|
|
||||||
|
task, error := verifyTask(x, c, &VerifyTask{Id: dat.Id, TaskId: dat.TaskId})
|
||||||
|
if error != nil {
|
||||||
|
return error
|
||||||
|
}
|
||||||
|
|
||||||
|
if task.TaskType != int(TASK_TYPE_RETRAINING) {
|
||||||
|
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
|
||||||
|
return c.JsonBadRequest("Task is not the right type go get the definitions")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := UpdateStatus(c.GetDb(), "exp_model_head", dat.HeadId, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
|
||||||
|
return c.E500M("Failed to update head status", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.SendJSON("Ok")
|
||||||
|
})
|
||||||
|
PostAuthJson(x, "/tasks/runner/train/def/layers", User_Normal, func(c *Context, dat *RunnerTrainDef) *Error {
|
||||||
|
_, error := verifyRunner(c, &JustId{Id: dat.Id})
|
||||||
|
if error != nil {
|
||||||
|
return error
|
||||||
|
}
|
||||||
|
|
||||||
|
task, error := verifyTask(x, c, &VerifyTask{Id: dat.Id, TaskId: dat.TaskId})
|
||||||
|
if error != nil {
|
||||||
|
return error
|
||||||
|
}
|
||||||
|
|
||||||
|
switch task.TaskType {
|
||||||
|
case int(TASK_TYPE_TRAINING):
|
||||||
|
// Do nothing
|
||||||
|
case int(TASK_TYPE_RETRAINING):
|
||||||
|
// Do nothing
|
||||||
|
default:
|
||||||
|
c.Logger.Error("Task not is not the right type to get the layers", "task type", task.TaskType)
|
||||||
|
return c.JsonBadRequest("Task is not the right type go get the layers")
|
||||||
|
}
|
||||||
|
|
||||||
|
def, err := GetDefinition(c, dat.DefId)
|
||||||
|
if err != nil {
|
||||||
|
return c.E500M("Failed to get definition information", err)
|
||||||
|
}
|
||||||
|
|
||||||
layers, err := def.GetLayers(c, " order by layer_order asc")
|
layers, err := def.GetLayers(c, " order by layer_order asc")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.E500M("Failed to get layers", err)
|
return c.E500M("Failed to get layers", err)
|
||||||
@ -371,7 +466,12 @@ func handleRemoteRunner(x *Handle) {
|
|||||||
return error
|
return error
|
||||||
}
|
}
|
||||||
|
|
||||||
if task.TaskType != int(TASK_TYPE_TRAINING) {
|
switch task.TaskType {
|
||||||
|
case int(TASK_TYPE_TRAINING):
|
||||||
|
// DO nothing
|
||||||
|
case int(TASK_TYPE_RETRAINING):
|
||||||
|
// DO nothing
|
||||||
|
default:
|
||||||
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
|
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
|
||||||
return c.JsonBadRequest("Task is not the right type go get the definitions")
|
return c.JsonBadRequest("Task is not the right type go get the definitions")
|
||||||
}
|
}
|
||||||
@ -486,6 +586,8 @@ func handleRemoteRunner(x *Handle) {
|
|||||||
switch task.TaskType {
|
switch task.TaskType {
|
||||||
case int(TASK_TYPE_TRAINING):
|
case int(TASK_TYPE_TRAINING):
|
||||||
//DO NOTHING
|
//DO NOTHING
|
||||||
|
case int(TASK_TYPE_RETRAINING):
|
||||||
|
//DO NOTHING
|
||||||
case int(TASK_TYPE_CLASSIFICATION):
|
case int(TASK_TYPE_CLASSIFICATION):
|
||||||
//DO NOTHING
|
//DO NOTHING
|
||||||
default:
|
default:
|
||||||
@ -501,6 +603,74 @@ func handleRemoteRunner(x *Handle) {
|
|||||||
return c.SendJSON(model)
|
return c.SendJSON(model)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
PostAuthJson(x, "/tasks/runner/heads", User_Normal, func(c *Context, dat *RunnerTrainDef) *Error {
|
||||||
|
_, error := verifyRunner(c, &JustId{Id: dat.Id})
|
||||||
|
if error != nil {
|
||||||
|
return error
|
||||||
|
}
|
||||||
|
|
||||||
|
task, error := verifyTask(x, c, &VerifyTask{Id: dat.Id, TaskId: dat.TaskId})
|
||||||
|
if error != nil {
|
||||||
|
return error
|
||||||
|
}
|
||||||
|
|
||||||
|
type ExpHead struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Start int `db:"range_start" json:"start"`
|
||||||
|
End int `db:"range_end" json:"end"`
|
||||||
|
}
|
||||||
|
|
||||||
|
switch task.TaskType {
|
||||||
|
case int(TASK_TYPE_TRAINING):
|
||||||
|
fallthrough
|
||||||
|
case int(TASK_TYPE_RETRAINING):
|
||||||
|
// status = 2 (INIT) 3 (TRAINING)
|
||||||
|
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1 and status in (2,3)", dat.DefId)
|
||||||
|
if err != nil {
|
||||||
|
return c.E500M("Failed getting active heads", err)
|
||||||
|
}
|
||||||
|
return c.SendJSON(heads)
|
||||||
|
case int(TASK_TYPE_CLASSIFICATION):
|
||||||
|
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1", dat.DefId)
|
||||||
|
if err != nil {
|
||||||
|
return c.E500M("Failed getting active heads", err)
|
||||||
|
}
|
||||||
|
return c.SendJSON(heads)
|
||||||
|
default:
|
||||||
|
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
|
||||||
|
return c.JsonBadRequest("Task is not the right type go get the definitions")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
PostAuthJson(x, "/tasks/runner/train_exp/class/status/train", User_Normal, func(c *Context, dat *VerifyTask) *Error {
|
||||||
|
_, error := verifyRunner(c, &JustId{Id: dat.Id})
|
||||||
|
if error != nil {
|
||||||
|
return error
|
||||||
|
}
|
||||||
|
|
||||||
|
task, error := verifyTask(x, c, dat)
|
||||||
|
if error != nil {
|
||||||
|
return error
|
||||||
|
}
|
||||||
|
|
||||||
|
if task.TaskType != int(TASK_TYPE_TRAINING) {
|
||||||
|
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
|
||||||
|
return c.JsonBadRequest("Task is not the right type go get the definitions")
|
||||||
|
}
|
||||||
|
|
||||||
|
model, err := GetBaseModel(c, *task.ModelId)
|
||||||
|
if err != nil {
|
||||||
|
return c.E500M("Failed to get model", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = SetModelClassStatus(c, CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TO_TRAIN)
|
||||||
|
if err != nil {
|
||||||
|
return c.E500M("Failed update status", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.SendJSON("Ok")
|
||||||
|
})
|
||||||
|
|
||||||
PostAuthJson(x, "/tasks/runner/train/done", User_Normal, func(c *Context, dat *VerifyTask) *Error {
|
PostAuthJson(x, "/tasks/runner/train/done", User_Normal, func(c *Context, dat *VerifyTask) *Error {
|
||||||
_, error := verifyRunner(c, &JustId{Id: dat.Id})
|
_, error := verifyRunner(c, &JustId{Id: dat.Id})
|
||||||
if error != nil {
|
if error != nil {
|
||||||
@ -568,6 +738,13 @@ func handleRemoteRunner(x *Handle) {
|
|||||||
return c.E500M("Failed to delete unsed definitions", err)
|
return c.E500M("Failed to delete unsed definitions", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set the class status to trained
|
||||||
|
err = SetModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1;", model.Id)
|
||||||
|
if err != nil {
|
||||||
|
c.Logger.Error("Failed to set class status")
|
||||||
|
return c.E500M("Failed to set class status", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err = model.UpdateStatus(c, READY); err != nil {
|
if err = model.UpdateStatus(c, READY); err != nil {
|
||||||
model.UpdateStatus(c, FAILED_TRAINING)
|
model.UpdateStatus(c, FAILED_TRAINING)
|
||||||
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
|
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
|
||||||
@ -576,16 +753,7 @@ func handleRemoteRunner(x *Handle) {
|
|||||||
|
|
||||||
task.UpdateStatusLog(c, TASK_DONE, "Model finished training")
|
task.UpdateStatusLog(c, TASK_DONE, "Model finished training")
|
||||||
|
|
||||||
mutex := x.DataMap["runners_mutex"].(*sync.Mutex)
|
clearRunnerTask(x, dat.Id)
|
||||||
mutex.Lock()
|
|
||||||
defer mutex.Unlock()
|
|
||||||
|
|
||||||
var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{})
|
|
||||||
var runner_data map[string]interface{} = runners[dat.Id].(map[string]interface{})
|
|
||||||
runner_data["task"] = nil
|
|
||||||
runners[dat.Id] = runner_data
|
|
||||||
x.DataMap["runners"] = runners
|
|
||||||
|
|
||||||
return c.SendJSON("Ok")
|
return c.SendJSON("Ok")
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -635,4 +803,153 @@ func handleRemoteRunner(x *Handle) {
|
|||||||
|
|
||||||
return c.SendJSON("Ok")
|
return c.SendJSON("Ok")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
PostAuthJson(x, "/tasks/runner/train_exp/done", User_Normal, func(c *Context, dat *VerifyTask) *Error {
|
||||||
|
_, error := verifyRunner(c, &JustId{Id: dat.Id})
|
||||||
|
if error != nil {
|
||||||
|
return error
|
||||||
|
}
|
||||||
|
|
||||||
|
task, error := verifyTask(x, c, dat)
|
||||||
|
if error != nil {
|
||||||
|
return error
|
||||||
|
}
|
||||||
|
|
||||||
|
if task.TaskType != int(TASK_TYPE_TRAINING) {
|
||||||
|
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
|
||||||
|
return c.JsonBadRequest("Task is not the right type go get the definitions")
|
||||||
|
}
|
||||||
|
|
||||||
|
model, err := GetBaseModel(c, *task.ModelId)
|
||||||
|
if err != nil {
|
||||||
|
c.Logger.Error("Failed to get model", "err", err)
|
||||||
|
return c.E500M("Failed to get mode", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO add check the to the model
|
||||||
|
|
||||||
|
var def Definition
|
||||||
|
err = GetDBOnce(c, &def, "model_definition as md where model_id=$1 and status=$2 order by accuracy desc limit 1;", task.ModelId, DEFINITION_STATUS_TRANIED)
|
||||||
|
if err == NotFoundError {
|
||||||
|
// TODO Make the Model status have a message
|
||||||
|
c.Logger.Error("All definitions failed to train!")
|
||||||
|
model.UpdateStatus(c, FAILED_TRAINING)
|
||||||
|
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "All definition failed to train!")
|
||||||
|
clearRunnerTask(x, dat.Id)
|
||||||
|
return c.SendJSON("Ok")
|
||||||
|
} else if err != nil {
|
||||||
|
model.UpdateStatus(c, FAILED_TRAINING)
|
||||||
|
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to get model definition")
|
||||||
|
return c.E500M("Failed to get model definition", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = def.UpdateStatus(c, DEFINITION_STATUS_READY); err != nil {
|
||||||
|
model.UpdateStatus(c, FAILED_TRAINING)
|
||||||
|
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to update model definition")
|
||||||
|
return c.E500M("Failed to update model definition", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
to_delete, err := GetDbMultitple[JustId](c, "model_definition where status!=$1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id)
|
||||||
|
if err != nil {
|
||||||
|
c.GetLogger().Error("Failed to select model_definition to delete")
|
||||||
|
return c.E500M("Failed to select model definition to delete", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, d := range to_delete {
|
||||||
|
os.RemoveAll(path.Join("savedData", model.Id, "defs", d.Id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO Check if returning also works here
|
||||||
|
if _, err = c.Exec("delete from model_definition where status!=$1 and model_id=$2;", MODEL_DEFINITION_STATUS_READY, model.Id); err != nil {
|
||||||
|
model.UpdateStatus(c, FAILED_TRAINING)
|
||||||
|
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
|
||||||
|
return c.E500M("Failed to delete unsed definitions", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = SplitModel(c, model); err != nil {
|
||||||
|
err = SetModelClassStatus(c, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
|
||||||
|
if err != nil {
|
||||||
|
c.Logger.Error("Failed to split the model! And Failed to set class status")
|
||||||
|
return c.E500M("Failed to split the model", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Logger.Error("Failed to split the model")
|
||||||
|
return c.E500M("Failed to split the model", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the class status to trained
|
||||||
|
err = SetModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
|
||||||
|
if err != nil {
|
||||||
|
c.Logger.Error("Failed to set class status")
|
||||||
|
return c.E500M("Failed to set class status", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Logger.Warn("Removing base model for", "model", model.Id, "def", def.Id)
|
||||||
|
os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model"))
|
||||||
|
os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model.keras"))
|
||||||
|
|
||||||
|
if err = model.UpdateStatus(c, READY); err != nil {
|
||||||
|
model.UpdateStatus(c, FAILED_TRAINING)
|
||||||
|
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
|
||||||
|
return c.E500M("Failed to update status of model", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
task.UpdateStatusLog(c, TASK_DONE, "Model finished training")
|
||||||
|
|
||||||
|
mutex := x.DataMap["runners_mutex"].(*sync.Mutex)
|
||||||
|
mutex.Lock()
|
||||||
|
defer mutex.Unlock()
|
||||||
|
|
||||||
|
var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{})
|
||||||
|
var runner_data map[string]interface{} = runners[dat.Id].(map[string]interface{})
|
||||||
|
runner_data["task"] = nil
|
||||||
|
runners[dat.Id] = runner_data
|
||||||
|
x.DataMap["runners"] = runners
|
||||||
|
|
||||||
|
return c.SendJSON("Ok")
|
||||||
|
})
|
||||||
|
|
||||||
|
PostAuthJson(x, "/tasks/runner/retrain/done", User_Normal, func(c *Context, dat *VerifyTask) *Error {
|
||||||
|
_, error := verifyRunner(c, &JustId{Id: dat.Id})
|
||||||
|
if error != nil {
|
||||||
|
return error
|
||||||
|
}
|
||||||
|
|
||||||
|
task, error := verifyTask(x, c, dat)
|
||||||
|
if error != nil {
|
||||||
|
return error
|
||||||
|
}
|
||||||
|
|
||||||
|
if task.TaskType != int(TASK_TYPE_RETRAINING) {
|
||||||
|
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
|
||||||
|
return c.JsonBadRequest("Task is not the right type go get the definitions")
|
||||||
|
}
|
||||||
|
|
||||||
|
model, err := GetBaseModel(c, *task.ModelId)
|
||||||
|
if err != nil {
|
||||||
|
c.Logger.Error("Failed to get model", "err", err)
|
||||||
|
return c.E500M("Failed to get mode", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = SetModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
|
||||||
|
if err != nil {
|
||||||
|
return c.E500M("Failed to set class status", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = c.Exec("update exp_model_head set status=$1 where status=$2 and model_id=$3", MODEL_HEAD_STATUS_READY, MODEL_HEAD_STATUS_TRAINING, model.Id)
|
||||||
|
if err != nil {
|
||||||
|
return c.E500M("Failed to set head status", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = model.UpdateStatus(c, READY)
|
||||||
|
if err != nil {
|
||||||
|
return c.E500M("Failed to set class status", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
task.UpdateStatusLog(c, TASK_DONE, "Model finished training")
|
||||||
|
clearRunnerTask(x, dat.Id)
|
||||||
|
|
||||||
|
return c.SendJSON("Ok")
|
||||||
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -120,6 +120,12 @@ func handleRemoteTask(handler *Handle, base BasePack, runner_id string, task Tas
|
|||||||
defer mutex.Unlock()
|
defer mutex.Unlock()
|
||||||
|
|
||||||
switch task.TaskType {
|
switch task.TaskType {
|
||||||
|
case int(TASK_TYPE_RETRAINING):
|
||||||
|
runners := handler.DataMap["runners"].(map[string]interface{})
|
||||||
|
runner := runners[runner_id].(map[string]interface{})
|
||||||
|
runner["task"] = &task
|
||||||
|
runners[runner_id] = runner
|
||||||
|
handler.DataMap["runners"] = runners
|
||||||
case int(TASK_TYPE_TRAINING):
|
case int(TASK_TYPE_TRAINING):
|
||||||
if err := PrepareTraining(handler, base, task, runner_id); err != nil {
|
if err := PrepareTraining(handler, base, task, runner_id); err != nil {
|
||||||
logger.Error("Failed to prepare for training", "err", err)
|
logger.Error("Failed to prepare for training", "err", err)
|
||||||
|
Loading…
Reference in New Issue
Block a user