fix model retrain not working closes #93
This commit is contained in:
parent
026932cfab
commit
f182b205f8
@ -3,6 +3,7 @@ version: '3.1'
|
||||
services:
|
||||
db:
|
||||
image: docker.andr3h3nriqu3s.com/services/postgres
|
||||
command: -c 'max_connections=400'
|
||||
restart: always
|
||||
environment:
|
||||
POSTGRES_PASSWORD: verysafepassword
|
||||
|
@ -1658,150 +1658,6 @@ func trainRetrain(c *Context, model *BaseModel, defId string) {
|
||||
}
|
||||
}
|
||||
|
||||
func handleRetrain(c *Context) *Error {
|
||||
var err error = nil
|
||||
if !c.CheckAuthLevel(1) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var dat JustId
|
||||
|
||||
if err_ := c.ToJSON(&dat); err_ != nil {
|
||||
return err_
|
||||
}
|
||||
|
||||
if dat.Id == "" {
|
||||
return c.JsonBadRequest("Please provide a id")
|
||||
}
|
||||
|
||||
model, err := GetBaseModel(c.Db, dat.Id)
|
||||
if err == ModelNotFoundError {
|
||||
return c.JsonBadRequest("Model not found")
|
||||
} else if err != nil {
|
||||
return c.Error500(err)
|
||||
} else if model.Status != READY && model.Status != READY_RETRAIN_FAILED && model.Status != READY_ALTERATION_FAILED {
|
||||
return c.JsonBadRequest("Model in invalid status for re-training")
|
||||
}
|
||||
|
||||
c.Logger.Info("Expanding definitions for models", "id", model.Id)
|
||||
|
||||
classesUpdated := false
|
||||
|
||||
failed := func() *Error {
|
||||
if classesUpdated {
|
||||
ResetClasses(c, model)
|
||||
}
|
||||
|
||||
ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
|
||||
c.Logger.Error("Failed to retrain", "err", err)
|
||||
// TODO improve this response
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
var def struct {
|
||||
Id string
|
||||
TargetAccuracy int `db:"target_accuracy"`
|
||||
}
|
||||
|
||||
err = GetDBOnce(c, &def, "model_definition where model_id=$1;", model.Id)
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
|
||||
type C struct {
|
||||
Id string
|
||||
ClassOrder int `db:"class_order"`
|
||||
}
|
||||
|
||||
err = c.StartTx()
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
|
||||
classes, err := GetDbMultitple[C](
|
||||
c,
|
||||
"model_classes where model_id=$1 and status=$2 order by class_order asc",
|
||||
model.Id,
|
||||
MODEL_CLASS_STATUS_TO_TRAIN,
|
||||
)
|
||||
if err != nil {
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
|
||||
if len(classes) == 0 {
|
||||
c.Logger.Error("No classes are available!")
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
|
||||
//Update the classes
|
||||
{
|
||||
stmt, err2 := c.Prepare("update model_classes set status=$1 where status=$2 and model_id=$3")
|
||||
err = err2
|
||||
if err != nil {
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
_, err = stmt.Exec(MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
|
||||
if err != nil {
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
|
||||
err = c.CommitTx()
|
||||
if err != nil {
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
|
||||
classesUpdated = true
|
||||
}
|
||||
|
||||
var newHead = struct {
|
||||
DefId string `db:"def_id"`
|
||||
RangeStart int `db:"range_start"`
|
||||
RangeEnd int `db:"range_end"`
|
||||
status ModelDefinitionStatus `db:"status"`
|
||||
}{
|
||||
def.Id, classes[0].ClassOrder, classes[len(classes)-1].ClassOrder, MODEL_DEFINITION_STATUS_INIT,
|
||||
}
|
||||
|
||||
_, err = InsertReturnId(c.GetDb(), &newHead, "exp_model_head", "id")
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
|
||||
go trainRetrain(c, model, def.Id)
|
||||
|
||||
_, err = c.Db.Exec("update models set status=$1 where id=$2;", READY_RETRAIN, model.Id)
|
||||
if err != nil {
|
||||
fmt.Println("Failed to update model status")
|
||||
fmt.Println(err)
|
||||
// TODO improve this response
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
return c.SendJSON(model.Id)
|
||||
}
|
||||
|
||||
func RunTaskTrain(b BasePack, task Task) (err error) {
|
||||
l := b.GetLogger()
|
||||
|
||||
@ -1929,7 +1785,132 @@ func handleTrain(handle *Handle) {
|
||||
return c.SendJSON(id)
|
||||
})
|
||||
|
||||
handle.Post("/model/train/retrain", handleRetrain)
|
||||
PostAuthJson(handle, "/model/train/retrain", User_Normal, func(c *Context, dat *JustId) *Error {
|
||||
model, err := GetBaseModel(c.Db, dat.Id)
|
||||
if err == ModelNotFoundError {
|
||||
return c.JsonBadRequest("Model not found")
|
||||
} else if err != nil {
|
||||
return c.E500M("Faield to get model", err)
|
||||
} else if model.Status != READY && model.Status != READY_RETRAIN_FAILED && model.Status != READY_ALTERATION_FAILED {
|
||||
return c.JsonBadRequest("Model in invalid status for re-training")
|
||||
}
|
||||
|
||||
c.Logger.Info("Expanding definitions for models", "id", model.Id)
|
||||
|
||||
classesUpdated := false
|
||||
|
||||
failed := func() *Error {
|
||||
if classesUpdated {
|
||||
ResetClasses(c, model)
|
||||
}
|
||||
|
||||
ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
|
||||
return c.E500M("Failed to retrain model", err)
|
||||
}
|
||||
|
||||
var def struct {
|
||||
Id string
|
||||
TargetAccuracy int `db:"target_accuracy"`
|
||||
}
|
||||
|
||||
err = GetDBOnce(c, &def, "model_definition where model_id=$1;", model.Id)
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
|
||||
type C struct {
|
||||
Id string
|
||||
ClassOrder int `db:"class_order"`
|
||||
}
|
||||
|
||||
err = c.StartTx()
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
|
||||
classes, err := GetDbMultitple[C](
|
||||
c,
|
||||
"model_classes where model_id=$1 and status=$2 order by class_order asc",
|
||||
model.Id,
|
||||
MODEL_CLASS_STATUS_TO_TRAIN,
|
||||
)
|
||||
if err != nil {
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
|
||||
if len(classes) == 0 {
|
||||
c.Logger.Error("No classes are available!")
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
|
||||
//Update the classes
|
||||
{
|
||||
stmt, err2 := c.Prepare("update model_classes set status=$1 where status=$2 and model_id=$3")
|
||||
err = err2
|
||||
if err != nil {
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
_, err = stmt.Exec(MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
|
||||
if err != nil {
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
|
||||
err = c.CommitTx()
|
||||
if err != nil {
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
|
||||
classesUpdated = true
|
||||
}
|
||||
|
||||
var newHead = struct {
|
||||
DefId string `db:"def_id"`
|
||||
RangeStart int `db:"range_start"`
|
||||
RangeEnd int `db:"range_end"`
|
||||
Status ModelDefinitionStatus `db:"status"`
|
||||
}{
|
||||
def.Id, classes[0].ClassOrder, classes[len(classes)-1].ClassOrder, MODEL_DEFINITION_STATUS_INIT,
|
||||
}
|
||||
|
||||
_, err = InsertReturnId(c.GetDb(), &newHead, "exp_model_head", "id")
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
|
||||
go trainRetrain(c, model, def.Id)
|
||||
|
||||
_, err = c.Db.Exec("update models set status=$1 where id=$2;", READY_RETRAIN, model.Id)
|
||||
if err != nil {
|
||||
fmt.Println("Failed to update model status")
|
||||
fmt.Println(err)
|
||||
// TODO improve this response
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
return c.SendJSON(model.Id)
|
||||
})
|
||||
|
||||
handle.Get("/model/epoch/update", func(c *Context) *Error {
|
||||
f := c.R.URL.Query()
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -57,7 +58,7 @@ func handleError(err *Error, c *Context) {
|
||||
e = c.SendJSON(500)
|
||||
}
|
||||
if e != nil {
|
||||
c.Logger.Error("Something went very wrong while trying to send and error message")
|
||||
c.Logger.Error("Something went very wrong while trying to send and error message", "stack", string(debug.Stack()))
|
||||
c.Writer.Write([]byte("505"))
|
||||
}
|
||||
}
|
||||
@ -195,7 +196,7 @@ func DeleteAuthJson[T interface{}](x *Handle, path string, authLevel dbtypes.Use
|
||||
func handleLoop(array []HandleFunc, context *Context) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
context.Logger.Error("Something went very wrong", "Error", r)
|
||||
context.Logger.Error("Something went very wrong", "Error", r, "stack", string(debug.Stack()))
|
||||
handleError(&Error{500, "500"}, context)
|
||||
}
|
||||
}()
|
||||
@ -418,8 +419,7 @@ func (c Context) ErrorCode(err error, code int, data any) *Error {
|
||||
c.SetReportCaller(false)
|
||||
}
|
||||
if err != nil {
|
||||
c.Logger.Errorf("Something went wrong returning with: %d\n.Err:\n", code)
|
||||
c.Logger.Error(err)
|
||||
c.Logger.Error("Something went wrong returning with:", "Error", err)
|
||||
}
|
||||
return &Error{code, data}
|
||||
}
|
||||
|
@ -54,7 +54,7 @@
|
||||
try {
|
||||
let temp_model: Model = await get(`models/edit?id=${id}`);
|
||||
|
||||
if ([3, 6].includes(temp_model.status)) {
|
||||
if ([3, 7, 6].includes(temp_model.status)) {
|
||||
setTimeout(getModel, 2000);
|
||||
}
|
||||
|
||||
@ -143,7 +143,7 @@
|
||||
>
|
||||
Model
|
||||
</button>
|
||||
{#if _model && [2, 3, 4, 5, 6, 7].includes(_model.status)}
|
||||
{#if _model && [2, 3, 4, 5, 6, 7, -6, -7].includes(_model.status)}
|
||||
<button
|
||||
class="tab"
|
||||
on:click|preventDefault={setActive('model-data')}
|
||||
@ -152,7 +152,7 @@
|
||||
Model Data
|
||||
</button>
|
||||
{/if}
|
||||
{#if _model && [5, 6, 7].includes(_model.status)}
|
||||
{#if _model && [5, 6, 7, -6, -7].includes(_model.status)}
|
||||
<button
|
||||
class="tab"
|
||||
on:click|preventDefault={setActive('tasks')}
|
||||
|
@ -15,6 +15,7 @@
|
||||
import { createEventDispatcher } from 'svelte';
|
||||
import ModelTable from './ModelTable.svelte';
|
||||
import TrainModel from './TrainModel.svelte';
|
||||
import ZipStructure from './ZipStructure.svelte';
|
||||
|
||||
let { model, simple } = $props<{ model: Model; simple?: boolean }>();
|
||||
|
||||
@ -103,32 +104,7 @@
|
||||
<br />
|
||||
Each of the folders will contain the classes of the model. The folders must be the same
|
||||
in testing and training. The class folders must have the images for the classes.
|
||||
<pre>
|
||||
training\
|
||||
class1\
|
||||
img1.png
|
||||
img2.png
|
||||
img2.png
|
||||
...
|
||||
class2\
|
||||
img1.png
|
||||
img2.png
|
||||
img2.png
|
||||
...
|
||||
...
|
||||
testing\
|
||||
class1\
|
||||
img1.png
|
||||
img2.png
|
||||
img2.png
|
||||
...
|
||||
class2\
|
||||
img1.png
|
||||
img2.png
|
||||
img2.png
|
||||
...
|
||||
...
|
||||
</pre>
|
||||
<ZipStructure />
|
||||
</div>
|
||||
<FileUpload replace_slot bind:file accept="application/zip" notExpand>
|
||||
<img src="/imgs/upload-icon.png" alt="" />
|
||||
|
@ -11,6 +11,7 @@
|
||||
Colors
|
||||
} from 'chart.js';
|
||||
import type { ModelStats } from './types';
|
||||
import { onDestroy } from 'svelte';
|
||||
|
||||
Chart.register(
|
||||
Title,
|
||||
@ -23,7 +24,7 @@
|
||||
Colors
|
||||
);
|
||||
|
||||
let { data } = $props<{ data: ModelStats }>();
|
||||
let { data }: { data: ModelStats } = $props();
|
||||
|
||||
let ctx: HTMLCanvasElement;
|
||||
|
||||
@ -78,6 +79,10 @@
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
onDestroy(() => {
|
||||
if (chart) chart.destroy();
|
||||
});
|
||||
</script>
|
||||
|
||||
<div><canvas bind:this={ctx} /></div>
|
||||
|
@ -255,6 +255,8 @@
|
||||
{:else if selected_class?.status == 3}
|
||||
Class trained
|
||||
{/if}
|
||||
{:else}
|
||||
Class to train
|
||||
{/if}
|
||||
<button on:click={() => uploadImageDialog.showModal()}> Upload Image </button>
|
||||
</h2>
|
||||
|
Loading…
Reference in New Issue
Block a user