fix model retrain not working closes #93
This commit is contained in:
parent
026932cfab
commit
f182b205f8
@ -3,6 +3,7 @@ version: '3.1'
|
|||||||
services:
|
services:
|
||||||
db:
|
db:
|
||||||
image: docker.andr3h3nriqu3s.com/services/postgres
|
image: docker.andr3h3nriqu3s.com/services/postgres
|
||||||
|
command: -c 'max_connections=400'
|
||||||
restart: always
|
restart: always
|
||||||
environment:
|
environment:
|
||||||
POSTGRES_PASSWORD: verysafepassword
|
POSTGRES_PASSWORD: verysafepassword
|
||||||
|
@ -1658,150 +1658,6 @@ func trainRetrain(c *Context, model *BaseModel, defId string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleRetrain(c *Context) *Error {
|
|
||||||
var err error = nil
|
|
||||||
if !c.CheckAuthLevel(1) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var dat JustId
|
|
||||||
|
|
||||||
if err_ := c.ToJSON(&dat); err_ != nil {
|
|
||||||
return err_
|
|
||||||
}
|
|
||||||
|
|
||||||
if dat.Id == "" {
|
|
||||||
return c.JsonBadRequest("Please provide a id")
|
|
||||||
}
|
|
||||||
|
|
||||||
model, err := GetBaseModel(c.Db, dat.Id)
|
|
||||||
if err == ModelNotFoundError {
|
|
||||||
return c.JsonBadRequest("Model not found")
|
|
||||||
} else if err != nil {
|
|
||||||
return c.Error500(err)
|
|
||||||
} else if model.Status != READY && model.Status != READY_RETRAIN_FAILED && model.Status != READY_ALTERATION_FAILED {
|
|
||||||
return c.JsonBadRequest("Model in invalid status for re-training")
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Logger.Info("Expanding definitions for models", "id", model.Id)
|
|
||||||
|
|
||||||
classesUpdated := false
|
|
||||||
|
|
||||||
failed := func() *Error {
|
|
||||||
if classesUpdated {
|
|
||||||
ResetClasses(c, model)
|
|
||||||
}
|
|
||||||
|
|
||||||
ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
|
|
||||||
c.Logger.Error("Failed to retrain", "err", err)
|
|
||||||
// TODO improve this response
|
|
||||||
return c.Error500(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var def struct {
|
|
||||||
Id string
|
|
||||||
TargetAccuracy int `db:"target_accuracy"`
|
|
||||||
}
|
|
||||||
|
|
||||||
err = GetDBOnce(c, &def, "model_definition where model_id=$1;", model.Id)
|
|
||||||
if err != nil {
|
|
||||||
return failed()
|
|
||||||
}
|
|
||||||
|
|
||||||
type C struct {
|
|
||||||
Id string
|
|
||||||
ClassOrder int `db:"class_order"`
|
|
||||||
}
|
|
||||||
|
|
||||||
err = c.StartTx()
|
|
||||||
if err != nil {
|
|
||||||
return failed()
|
|
||||||
}
|
|
||||||
|
|
||||||
classes, err := GetDbMultitple[C](
|
|
||||||
c,
|
|
||||||
"model_classes where model_id=$1 and status=$2 order by class_order asc",
|
|
||||||
model.Id,
|
|
||||||
MODEL_CLASS_STATUS_TO_TRAIN,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
_err := c.RollbackTx()
|
|
||||||
if _err != nil {
|
|
||||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
|
||||||
}
|
|
||||||
return failed()
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(classes) == 0 {
|
|
||||||
c.Logger.Error("No classes are available!")
|
|
||||||
_err := c.RollbackTx()
|
|
||||||
if _err != nil {
|
|
||||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
|
||||||
}
|
|
||||||
return failed()
|
|
||||||
}
|
|
||||||
|
|
||||||
//Update the classes
|
|
||||||
{
|
|
||||||
stmt, err2 := c.Prepare("update model_classes set status=$1 where status=$2 and model_id=$3")
|
|
||||||
err = err2
|
|
||||||
if err != nil {
|
|
||||||
_err := c.RollbackTx()
|
|
||||||
if _err != nil {
|
|
||||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
|
||||||
}
|
|
||||||
return failed()
|
|
||||||
}
|
|
||||||
defer stmt.Close()
|
|
||||||
|
|
||||||
_, err = stmt.Exec(MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
|
|
||||||
if err != nil {
|
|
||||||
_err := c.RollbackTx()
|
|
||||||
if _err != nil {
|
|
||||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
|
||||||
}
|
|
||||||
return failed()
|
|
||||||
}
|
|
||||||
|
|
||||||
err = c.CommitTx()
|
|
||||||
if err != nil {
|
|
||||||
_err := c.RollbackTx()
|
|
||||||
if _err != nil {
|
|
||||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
|
||||||
}
|
|
||||||
return failed()
|
|
||||||
}
|
|
||||||
|
|
||||||
classesUpdated = true
|
|
||||||
}
|
|
||||||
|
|
||||||
var newHead = struct {
|
|
||||||
DefId string `db:"def_id"`
|
|
||||||
RangeStart int `db:"range_start"`
|
|
||||||
RangeEnd int `db:"range_end"`
|
|
||||||
status ModelDefinitionStatus `db:"status"`
|
|
||||||
}{
|
|
||||||
def.Id, classes[0].ClassOrder, classes[len(classes)-1].ClassOrder, MODEL_DEFINITION_STATUS_INIT,
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = InsertReturnId(c.GetDb(), &newHead, "exp_model_head", "id")
|
|
||||||
if err != nil {
|
|
||||||
return failed()
|
|
||||||
}
|
|
||||||
|
|
||||||
go trainRetrain(c, model, def.Id)
|
|
||||||
|
|
||||||
_, err = c.Db.Exec("update models set status=$1 where id=$2;", READY_RETRAIN, model.Id)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("Failed to update model status")
|
|
||||||
fmt.Println(err)
|
|
||||||
// TODO improve this response
|
|
||||||
return c.Error500(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.SendJSON(model.Id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func RunTaskTrain(b BasePack, task Task) (err error) {
|
func RunTaskTrain(b BasePack, task Task) (err error) {
|
||||||
l := b.GetLogger()
|
l := b.GetLogger()
|
||||||
|
|
||||||
@ -1929,7 +1785,132 @@ func handleTrain(handle *Handle) {
|
|||||||
return c.SendJSON(id)
|
return c.SendJSON(id)
|
||||||
})
|
})
|
||||||
|
|
||||||
handle.Post("/model/train/retrain", handleRetrain)
|
PostAuthJson(handle, "/model/train/retrain", User_Normal, func(c *Context, dat *JustId) *Error {
|
||||||
|
model, err := GetBaseModel(c.Db, dat.Id)
|
||||||
|
if err == ModelNotFoundError {
|
||||||
|
return c.JsonBadRequest("Model not found")
|
||||||
|
} else if err != nil {
|
||||||
|
return c.E500M("Faield to get model", err)
|
||||||
|
} else if model.Status != READY && model.Status != READY_RETRAIN_FAILED && model.Status != READY_ALTERATION_FAILED {
|
||||||
|
return c.JsonBadRequest("Model in invalid status for re-training")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Logger.Info("Expanding definitions for models", "id", model.Id)
|
||||||
|
|
||||||
|
classesUpdated := false
|
||||||
|
|
||||||
|
failed := func() *Error {
|
||||||
|
if classesUpdated {
|
||||||
|
ResetClasses(c, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
|
||||||
|
return c.E500M("Failed to retrain model", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var def struct {
|
||||||
|
Id string
|
||||||
|
TargetAccuracy int `db:"target_accuracy"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err = GetDBOnce(c, &def, "model_definition where model_id=$1;", model.Id)
|
||||||
|
if err != nil {
|
||||||
|
return failed()
|
||||||
|
}
|
||||||
|
|
||||||
|
type C struct {
|
||||||
|
Id string
|
||||||
|
ClassOrder int `db:"class_order"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.StartTx()
|
||||||
|
if err != nil {
|
||||||
|
return failed()
|
||||||
|
}
|
||||||
|
|
||||||
|
classes, err := GetDbMultitple[C](
|
||||||
|
c,
|
||||||
|
"model_classes where model_id=$1 and status=$2 order by class_order asc",
|
||||||
|
model.Id,
|
||||||
|
MODEL_CLASS_STATUS_TO_TRAIN,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
_err := c.RollbackTx()
|
||||||
|
if _err != nil {
|
||||||
|
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||||
|
}
|
||||||
|
return failed()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(classes) == 0 {
|
||||||
|
c.Logger.Error("No classes are available!")
|
||||||
|
_err := c.RollbackTx()
|
||||||
|
if _err != nil {
|
||||||
|
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||||
|
}
|
||||||
|
return failed()
|
||||||
|
}
|
||||||
|
|
||||||
|
//Update the classes
|
||||||
|
{
|
||||||
|
stmt, err2 := c.Prepare("update model_classes set status=$1 where status=$2 and model_id=$3")
|
||||||
|
err = err2
|
||||||
|
if err != nil {
|
||||||
|
_err := c.RollbackTx()
|
||||||
|
if _err != nil {
|
||||||
|
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||||
|
}
|
||||||
|
return failed()
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
_, err = stmt.Exec(MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
|
||||||
|
if err != nil {
|
||||||
|
_err := c.RollbackTx()
|
||||||
|
if _err != nil {
|
||||||
|
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||||
|
}
|
||||||
|
return failed()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.CommitTx()
|
||||||
|
if err != nil {
|
||||||
|
_err := c.RollbackTx()
|
||||||
|
if _err != nil {
|
||||||
|
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||||
|
}
|
||||||
|
return failed()
|
||||||
|
}
|
||||||
|
|
||||||
|
classesUpdated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
var newHead = struct {
|
||||||
|
DefId string `db:"def_id"`
|
||||||
|
RangeStart int `db:"range_start"`
|
||||||
|
RangeEnd int `db:"range_end"`
|
||||||
|
Status ModelDefinitionStatus `db:"status"`
|
||||||
|
}{
|
||||||
|
def.Id, classes[0].ClassOrder, classes[len(classes)-1].ClassOrder, MODEL_DEFINITION_STATUS_INIT,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = InsertReturnId(c.GetDb(), &newHead, "exp_model_head", "id")
|
||||||
|
if err != nil {
|
||||||
|
return failed()
|
||||||
|
}
|
||||||
|
|
||||||
|
go trainRetrain(c, model, def.Id)
|
||||||
|
|
||||||
|
_, err = c.Db.Exec("update models set status=$1 where id=$2;", READY_RETRAIN, model.Id)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Failed to update model status")
|
||||||
|
fmt.Println(err)
|
||||||
|
// TODO improve this response
|
||||||
|
return c.Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.SendJSON(model.Id)
|
||||||
|
})
|
||||||
|
|
||||||
handle.Get("/model/epoch/update", func(c *Context) *Error {
|
handle.Get("/model/epoch/update", func(c *Context) *Error {
|
||||||
f := c.R.URL.Query()
|
f := c.R.URL.Query()
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -57,7 +58,7 @@ func handleError(err *Error, c *Context) {
|
|||||||
e = c.SendJSON(500)
|
e = c.SendJSON(500)
|
||||||
}
|
}
|
||||||
if e != nil {
|
if e != nil {
|
||||||
c.Logger.Error("Something went very wrong while trying to send and error message")
|
c.Logger.Error("Something went very wrong while trying to send and error message", "stack", string(debug.Stack()))
|
||||||
c.Writer.Write([]byte("505"))
|
c.Writer.Write([]byte("505"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -195,7 +196,7 @@ func DeleteAuthJson[T interface{}](x *Handle, path string, authLevel dbtypes.Use
|
|||||||
func handleLoop(array []HandleFunc, context *Context) {
|
func handleLoop(array []HandleFunc, context *Context) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
context.Logger.Error("Something went very wrong", "Error", r)
|
context.Logger.Error("Something went very wrong", "Error", r, "stack", string(debug.Stack()))
|
||||||
handleError(&Error{500, "500"}, context)
|
handleError(&Error{500, "500"}, context)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -418,8 +419,7 @@ func (c Context) ErrorCode(err error, code int, data any) *Error {
|
|||||||
c.SetReportCaller(false)
|
c.SetReportCaller(false)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Logger.Errorf("Something went wrong returning with: %d\n.Err:\n", code)
|
c.Logger.Error("Something went wrong returning with:", "Error", err)
|
||||||
c.Logger.Error(err)
|
|
||||||
}
|
}
|
||||||
return &Error{code, data}
|
return &Error{code, data}
|
||||||
}
|
}
|
||||||
|
@ -54,7 +54,7 @@
|
|||||||
try {
|
try {
|
||||||
let temp_model: Model = await get(`models/edit?id=${id}`);
|
let temp_model: Model = await get(`models/edit?id=${id}`);
|
||||||
|
|
||||||
if ([3, 6].includes(temp_model.status)) {
|
if ([3, 7, 6].includes(temp_model.status)) {
|
||||||
setTimeout(getModel, 2000);
|
setTimeout(getModel, 2000);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -143,7 +143,7 @@
|
|||||||
>
|
>
|
||||||
Model
|
Model
|
||||||
</button>
|
</button>
|
||||||
{#if _model && [2, 3, 4, 5, 6, 7].includes(_model.status)}
|
{#if _model && [2, 3, 4, 5, 6, 7, -6, -7].includes(_model.status)}
|
||||||
<button
|
<button
|
||||||
class="tab"
|
class="tab"
|
||||||
on:click|preventDefault={setActive('model-data')}
|
on:click|preventDefault={setActive('model-data')}
|
||||||
@ -152,7 +152,7 @@
|
|||||||
Model Data
|
Model Data
|
||||||
</button>
|
</button>
|
||||||
{/if}
|
{/if}
|
||||||
{#if _model && [5, 6, 7].includes(_model.status)}
|
{#if _model && [5, 6, 7, -6, -7].includes(_model.status)}
|
||||||
<button
|
<button
|
||||||
class="tab"
|
class="tab"
|
||||||
on:click|preventDefault={setActive('tasks')}
|
on:click|preventDefault={setActive('tasks')}
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
import { createEventDispatcher } from 'svelte';
|
import { createEventDispatcher } from 'svelte';
|
||||||
import ModelTable from './ModelTable.svelte';
|
import ModelTable from './ModelTable.svelte';
|
||||||
import TrainModel from './TrainModel.svelte';
|
import TrainModel from './TrainModel.svelte';
|
||||||
|
import ZipStructure from './ZipStructure.svelte';
|
||||||
|
|
||||||
let { model, simple } = $props<{ model: Model; simple?: boolean }>();
|
let { model, simple } = $props<{ model: Model; simple?: boolean }>();
|
||||||
|
|
||||||
@ -103,32 +104,7 @@
|
|||||||
<br />
|
<br />
|
||||||
Each of the folders will contain the classes of the model. The folders must be the same
|
Each of the folders will contain the classes of the model. The folders must be the same
|
||||||
in testing and training. The class folders must have the images for the classes.
|
in testing and training. The class folders must have the images for the classes.
|
||||||
<pre>
|
<ZipStructure />
|
||||||
training\
|
|
||||||
class1\
|
|
||||||
img1.png
|
|
||||||
img2.png
|
|
||||||
img2.png
|
|
||||||
...
|
|
||||||
class2\
|
|
||||||
img1.png
|
|
||||||
img2.png
|
|
||||||
img2.png
|
|
||||||
...
|
|
||||||
...
|
|
||||||
testing\
|
|
||||||
class1\
|
|
||||||
img1.png
|
|
||||||
img2.png
|
|
||||||
img2.png
|
|
||||||
...
|
|
||||||
class2\
|
|
||||||
img1.png
|
|
||||||
img2.png
|
|
||||||
img2.png
|
|
||||||
...
|
|
||||||
...
|
|
||||||
</pre>
|
|
||||||
</div>
|
</div>
|
||||||
<FileUpload replace_slot bind:file accept="application/zip" notExpand>
|
<FileUpload replace_slot bind:file accept="application/zip" notExpand>
|
||||||
<img src="/imgs/upload-icon.png" alt="" />
|
<img src="/imgs/upload-icon.png" alt="" />
|
||||||
|
@ -11,6 +11,7 @@
|
|||||||
Colors
|
Colors
|
||||||
} from 'chart.js';
|
} from 'chart.js';
|
||||||
import type { ModelStats } from './types';
|
import type { ModelStats } from './types';
|
||||||
|
import { onDestroy } from 'svelte';
|
||||||
|
|
||||||
Chart.register(
|
Chart.register(
|
||||||
Title,
|
Title,
|
||||||
@ -23,7 +24,7 @@
|
|||||||
Colors
|
Colors
|
||||||
);
|
);
|
||||||
|
|
||||||
let { data } = $props<{ data: ModelStats }>();
|
let { data }: { data: ModelStats } = $props();
|
||||||
|
|
||||||
let ctx: HTMLCanvasElement;
|
let ctx: HTMLCanvasElement;
|
||||||
|
|
||||||
@ -78,6 +79,10 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
onDestroy(() => {
|
||||||
|
if (chart) chart.destroy();
|
||||||
|
});
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<div><canvas bind:this={ctx} /></div>
|
<div><canvas bind:this={ctx} /></div>
|
||||||
|
@ -255,6 +255,8 @@
|
|||||||
{:else if selected_class?.status == 3}
|
{:else if selected_class?.status == 3}
|
||||||
Class trained
|
Class trained
|
||||||
{/if}
|
{/if}
|
||||||
|
{:else}
|
||||||
|
Class to train
|
||||||
{/if}
|
{/if}
|
||||||
<button on:click={() => uploadImageDialog.showModal()}> Upload Image </button>
|
<button on:click={() => uploadImageDialog.showModal()}> Upload Image </button>
|
||||||
</h2>
|
</h2>
|
||||||
|
Loading…
Reference in New Issue
Block a user