fix model retrain not working closes #93

This commit is contained in:
Andre Henriques 2024-04-16 16:02:57 +01:00
parent 026932cfab
commit f182b205f8
7 changed files with 144 additions and 179 deletions

View File

@ -3,6 +3,7 @@ version: '3.1'
services: services:
db: db:
image: docker.andr3h3nriqu3s.com/services/postgres image: docker.andr3h3nriqu3s.com/services/postgres
command: -c 'max_connections=400'
restart: always restart: always
environment: environment:
POSTGRES_PASSWORD: verysafepassword POSTGRES_PASSWORD: verysafepassword

View File

@ -1658,150 +1658,6 @@ func trainRetrain(c *Context, model *BaseModel, defId string) {
} }
} }
func handleRetrain(c *Context) *Error {
var err error = nil
if !c.CheckAuthLevel(1) {
return nil
}
var dat JustId
if err_ := c.ToJSON(&dat); err_ != nil {
return err_
}
if dat.Id == "" {
return c.JsonBadRequest("Please provide a id")
}
model, err := GetBaseModel(c.Db, dat.Id)
if err == ModelNotFoundError {
return c.JsonBadRequest("Model not found")
} else if err != nil {
return c.Error500(err)
} else if model.Status != READY && model.Status != READY_RETRAIN_FAILED && model.Status != READY_ALTERATION_FAILED {
return c.JsonBadRequest("Model in invalid status for re-training")
}
c.Logger.Info("Expanding definitions for models", "id", model.Id)
classesUpdated := false
failed := func() *Error {
if classesUpdated {
ResetClasses(c, model)
}
ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
c.Logger.Error("Failed to retrain", "err", err)
// TODO improve this response
return c.Error500(err)
}
var def struct {
Id string
TargetAccuracy int `db:"target_accuracy"`
}
err = GetDBOnce(c, &def, "model_definition where model_id=$1;", model.Id)
if err != nil {
return failed()
}
type C struct {
Id string
ClassOrder int `db:"class_order"`
}
err = c.StartTx()
if err != nil {
return failed()
}
classes, err := GetDbMultitple[C](
c,
"model_classes where model_id=$1 and status=$2 order by class_order asc",
model.Id,
MODEL_CLASS_STATUS_TO_TRAIN,
)
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
if len(classes) == 0 {
c.Logger.Error("No classes are available!")
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
//Update the classes
{
stmt, err2 := c.Prepare("update model_classes set status=$1 where status=$2 and model_id=$3")
err = err2
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
defer stmt.Close()
_, err = stmt.Exec(MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
err = c.CommitTx()
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
classesUpdated = true
}
var newHead = struct {
DefId string `db:"def_id"`
RangeStart int `db:"range_start"`
RangeEnd int `db:"range_end"`
status ModelDefinitionStatus `db:"status"`
}{
def.Id, classes[0].ClassOrder, classes[len(classes)-1].ClassOrder, MODEL_DEFINITION_STATUS_INIT,
}
_, err = InsertReturnId(c.GetDb(), &newHead, "exp_model_head", "id")
if err != nil {
return failed()
}
go trainRetrain(c, model, def.Id)
_, err = c.Db.Exec("update models set status=$1 where id=$2;", READY_RETRAIN, model.Id)
if err != nil {
fmt.Println("Failed to update model status")
fmt.Println(err)
// TODO improve this response
return c.Error500(err)
}
return c.SendJSON(model.Id)
}
func RunTaskTrain(b BasePack, task Task) (err error) { func RunTaskTrain(b BasePack, task Task) (err error) {
l := b.GetLogger() l := b.GetLogger()
@ -1929,7 +1785,132 @@ func handleTrain(handle *Handle) {
return c.SendJSON(id) return c.SendJSON(id)
}) })
handle.Post("/model/train/retrain", handleRetrain) PostAuthJson(handle, "/model/train/retrain", User_Normal, func(c *Context, dat *JustId) *Error {
model, err := GetBaseModel(c.Db, dat.Id)
if err == ModelNotFoundError {
return c.JsonBadRequest("Model not found")
} else if err != nil {
return c.E500M("Faield to get model", err)
} else if model.Status != READY && model.Status != READY_RETRAIN_FAILED && model.Status != READY_ALTERATION_FAILED {
return c.JsonBadRequest("Model in invalid status for re-training")
}
c.Logger.Info("Expanding definitions for models", "id", model.Id)
classesUpdated := false
failed := func() *Error {
if classesUpdated {
ResetClasses(c, model)
}
ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
return c.E500M("Failed to retrain model", err)
}
var def struct {
Id string
TargetAccuracy int `db:"target_accuracy"`
}
err = GetDBOnce(c, &def, "model_definition where model_id=$1;", model.Id)
if err != nil {
return failed()
}
type C struct {
Id string
ClassOrder int `db:"class_order"`
}
err = c.StartTx()
if err != nil {
return failed()
}
classes, err := GetDbMultitple[C](
c,
"model_classes where model_id=$1 and status=$2 order by class_order asc",
model.Id,
MODEL_CLASS_STATUS_TO_TRAIN,
)
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
if len(classes) == 0 {
c.Logger.Error("No classes are available!")
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
//Update the classes
{
stmt, err2 := c.Prepare("update model_classes set status=$1 where status=$2 and model_id=$3")
err = err2
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
defer stmt.Close()
_, err = stmt.Exec(MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
err = c.CommitTx()
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
classesUpdated = true
}
var newHead = struct {
DefId string `db:"def_id"`
RangeStart int `db:"range_start"`
RangeEnd int `db:"range_end"`
Status ModelDefinitionStatus `db:"status"`
}{
def.Id, classes[0].ClassOrder, classes[len(classes)-1].ClassOrder, MODEL_DEFINITION_STATUS_INIT,
}
_, err = InsertReturnId(c.GetDb(), &newHead, "exp_model_head", "id")
if err != nil {
return failed()
}
go trainRetrain(c, model, def.Id)
_, err = c.Db.Exec("update models set status=$1 where id=$2;", READY_RETRAIN, model.Id)
if err != nil {
fmt.Println("Failed to update model status")
fmt.Println(err)
// TODO improve this response
return c.Error500(err)
}
return c.SendJSON(model.Id)
})
handle.Get("/model/epoch/update", func(c *Context) *Error { handle.Get("/model/epoch/update", func(c *Context) *Error {
f := c.R.URL.Query() f := c.R.URL.Query()

View File

@ -9,6 +9,7 @@ import (
"net/http" "net/http"
"os" "os"
"path" "path"
"runtime/debug"
"strings" "strings"
"time" "time"
@ -57,7 +58,7 @@ func handleError(err *Error, c *Context) {
e = c.SendJSON(500) e = c.SendJSON(500)
} }
if e != nil { if e != nil {
c.Logger.Error("Something went very wrong while trying to send and error message") c.Logger.Error("Something went very wrong while trying to send and error message", "stack", string(debug.Stack()))
c.Writer.Write([]byte("505")) c.Writer.Write([]byte("505"))
} }
} }
@ -195,7 +196,7 @@ func DeleteAuthJson[T interface{}](x *Handle, path string, authLevel dbtypes.Use
func handleLoop(array []HandleFunc, context *Context) { func handleLoop(array []HandleFunc, context *Context) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
context.Logger.Error("Something went very wrong", "Error", r) context.Logger.Error("Something went very wrong", "Error", r, "stack", string(debug.Stack()))
handleError(&Error{500, "500"}, context) handleError(&Error{500, "500"}, context)
} }
}() }()
@ -418,8 +419,7 @@ func (c Context) ErrorCode(err error, code int, data any) *Error {
c.SetReportCaller(false) c.SetReportCaller(false)
} }
if err != nil { if err != nil {
c.Logger.Errorf("Something went wrong returning with: %d\n.Err:\n", code) c.Logger.Error("Something went wrong returning with:", "Error", err)
c.Logger.Error(err)
} }
return &Error{code, data} return &Error{code, data}
} }

View File

@ -54,7 +54,7 @@
try { try {
let temp_model: Model = await get(`models/edit?id=${id}`); let temp_model: Model = await get(`models/edit?id=${id}`);
if ([3, 6].includes(temp_model.status)) { if ([3, 7, 6].includes(temp_model.status)) {
setTimeout(getModel, 2000); setTimeout(getModel, 2000);
} }
@ -143,7 +143,7 @@
> >
Model Model
</button> </button>
{#if _model && [2, 3, 4, 5, 6, 7].includes(_model.status)} {#if _model && [2, 3, 4, 5, 6, 7, -6, -7].includes(_model.status)}
<button <button
class="tab" class="tab"
on:click|preventDefault={setActive('model-data')} on:click|preventDefault={setActive('model-data')}
@ -152,7 +152,7 @@
Model Data Model Data
</button> </button>
{/if} {/if}
{#if _model && [5, 6, 7].includes(_model.status)} {#if _model && [5, 6, 7, -6, -7].includes(_model.status)}
<button <button
class="tab" class="tab"
on:click|preventDefault={setActive('tasks')} on:click|preventDefault={setActive('tasks')}

View File

@ -15,6 +15,7 @@
import { createEventDispatcher } from 'svelte'; import { createEventDispatcher } from 'svelte';
import ModelTable from './ModelTable.svelte'; import ModelTable from './ModelTable.svelte';
import TrainModel from './TrainModel.svelte'; import TrainModel from './TrainModel.svelte';
import ZipStructure from './ZipStructure.svelte';
let { model, simple } = $props<{ model: Model; simple?: boolean }>(); let { model, simple } = $props<{ model: Model; simple?: boolean }>();
@ -103,32 +104,7 @@
<br /> <br />
Each of the folders will contain the classes of the model. The folders must be the same Each of the folders will contain the classes of the model. The folders must be the same
in testing and training. The class folders must have the images for the classes. in testing and training. The class folders must have the images for the classes.
<pre> <ZipStructure />
training\
class1\
img1.png
img2.png
img2.png
...
class2\
img1.png
img2.png
img2.png
...
...
testing\
class1\
img1.png
img2.png
img2.png
...
class2\
img1.png
img2.png
img2.png
...
...
</pre>
</div> </div>
<FileUpload replace_slot bind:file accept="application/zip" notExpand> <FileUpload replace_slot bind:file accept="application/zip" notExpand>
<img src="/imgs/upload-icon.png" alt="" /> <img src="/imgs/upload-icon.png" alt="" />

View File

@ -11,6 +11,7 @@
Colors Colors
} from 'chart.js'; } from 'chart.js';
import type { ModelStats } from './types'; import type { ModelStats } from './types';
import { onDestroy } from 'svelte';
Chart.register( Chart.register(
Title, Title,
@ -23,7 +24,7 @@
Colors Colors
); );
let { data } = $props<{ data: ModelStats }>(); let { data }: { data: ModelStats } = $props();
let ctx: HTMLCanvasElement; let ctx: HTMLCanvasElement;
@ -78,6 +79,10 @@
} }
} }
}); });
onDestroy(() => {
if (chart) chart.destroy();
});
</script> </script>
<div><canvas bind:this={ctx} /></div> <div><canvas bind:this={ctx} /></div>

View File

@ -255,6 +255,8 @@
{:else if selected_class?.status == 3} {:else if selected_class?.status == 3}
Class trained Class trained
{/if} {/if}
{:else}
Class to train
{/if} {/if}
<button on:click={() => uploadImageDialog.showModal()}> Upload Image </button> <button on:click={() => uploadImageDialog.showModal()}> Upload Image </button>
</h2> </h2>