some work done on the running of the model
This commit is contained in:
parent
30c5b57378
commit
4a95f0211d
@ -20,6 +20,8 @@ func ListClasses(db *sql.DB, model_id string) (cls []ModelClass, err error) {
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
cls = []ModelClass{}
|
||||
|
||||
for rows.Next() {
|
||||
var model ModelClass
|
||||
err = rows.Scan(&model.Id, &model.ModelId, &model.Name)
|
||||
|
@ -8,13 +8,14 @@ import (
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
utils "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
func deleteModel(handle *Handle, id string, w http.ResponseWriter, c *Context, model BaseModel) {
|
||||
c.Logger.Warnf("Removing model with id: %s", id)
|
||||
c.Logger.Warnf("Removing model with id: %s", id)
|
||||
_, err := handle.Db.Exec("delete from models where id=$1;", id)
|
||||
if err != nil {
|
||||
c.Logger.Error(err)
|
||||
c.Logger.Error(err)
|
||||
panic("TODO handle better deleteModel failed delete database query")
|
||||
}
|
||||
|
||||
@ -22,7 +23,7 @@ func deleteModel(handle *Handle, id string, w http.ResponseWriter, c *Context, m
|
||||
c.Logger.Warnf("Removing folder of model with id: %s at %s", id, model_path)
|
||||
err = os.RemoveAll(model_path)
|
||||
if err != nil {
|
||||
c.Logger.Error(err)
|
||||
c.Logger.Error(err)
|
||||
panic("TODO handle better deleteModel failed to delete folder")
|
||||
}
|
||||
|
||||
@ -37,10 +38,76 @@ func deleteModel(handle *Handle, id string, w http.ResponseWriter, c *Context, m
|
||||
}))
|
||||
}
|
||||
|
||||
func deleteModelJSON(c *Context, id string) *Error {
|
||||
c.Logger.Warnf("Removing model with id: %s", id)
|
||||
_, err := c.Db.Exec("delete from models where id=$1;", id)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
model_path := path.Join("./savedData", id)
|
||||
c.Logger.Warnf("Removing folder of model with id: %s at %s", id, model_path)
|
||||
err = os.RemoveAll(model_path)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
return c.SendJSON(id)
|
||||
}
|
||||
|
||||
func handleDelete(handle *Handle) {
|
||||
handle.Delete("/models/delete", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||||
|
||||
if c.Mode == JSON {
|
||||
panic("TODO handle json on models/delete")
|
||||
|
||||
var dat struct {
|
||||
Id string `json:"id" validate:"required"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
if err_ := c.ToJSON(r, &dat); err_ != nil {
|
||||
return err_
|
||||
}
|
||||
|
||||
var model struct {
|
||||
Id string
|
||||
Name string
|
||||
Status int
|
||||
}
|
||||
|
||||
err := utils.GetDBOnce(c, &model, "models where id=$1 and user_id=$2;", dat.Id, c.User.Id)
|
||||
if err == NotFoundError {
|
||||
return c.SendJSONStatus(http.StatusNotFound, "Model not found!")
|
||||
} else if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
switch model.Status {
|
||||
case FAILED_TRAINING:
|
||||
fallthrough
|
||||
case FAILED_PREPARING_ZIP_FILE:
|
||||
fallthrough
|
||||
case FAILED_PREPARING_TRAINING:
|
||||
fallthrough
|
||||
case FAILED_PREPARING:
|
||||
return deleteModelJSON(c, dat.Id)
|
||||
|
||||
case READY:
|
||||
fallthrough
|
||||
case CONFIRM_PRE_TRAINING:
|
||||
if dat.Name == nil {
|
||||
return c.JsonBadRequest("Provided name does not match the model name")
|
||||
}
|
||||
|
||||
if *dat.Name != model.Name {
|
||||
return c.JsonBadRequest("Provided name does not match the model name")
|
||||
}
|
||||
|
||||
return deleteModelJSON(c, dat.Id)
|
||||
default:
|
||||
c.Logger.Warn("Do not know how to handle model in status", "status", model.Status)
|
||||
return c.JsonBadRequest("Model in invalid status")
|
||||
}
|
||||
}
|
||||
|
||||
// This is required to parse delete forms with bodies
|
||||
@ -66,7 +133,7 @@ func handleDelete(handle *Handle) {
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
c.Logger.Warn("Could not find model for", id, c.User.Id)
|
||||
c.Logger.Warn("Could not find model for", id, c.User.Id)
|
||||
return c.ErrorCode(nil, http.StatusNotFound, AnyMap{
|
||||
"NotFoundMessage": "Model not found",
|
||||
"GoBackLink": "/models",
|
||||
@ -81,14 +148,18 @@ func handleDelete(handle *Handle) {
|
||||
}
|
||||
|
||||
switch model.Status {
|
||||
case FAILED_TRAINING: fallthrough
|
||||
case FAILED_PREPARING_ZIP_FILE: fallthrough
|
||||
case FAILED_PREPARING_TRAINING: fallthrough
|
||||
case FAILED_TRAINING:
|
||||
fallthrough
|
||||
case FAILED_PREPARING_ZIP_FILE:
|
||||
fallthrough
|
||||
case FAILED_PREPARING_TRAINING:
|
||||
fallthrough
|
||||
case FAILED_PREPARING:
|
||||
deleteModel(handle, id, w, c, model)
|
||||
return nil
|
||||
|
||||
case READY: fallthrough
|
||||
case READY:
|
||||
fallthrough
|
||||
case CONFIRM_PRE_TRAINING:
|
||||
if CheckEmpty(f, "name") {
|
||||
return c.Error400(nil, "Name is empty", w, "/models/edit.html", "delete-model-card", AnyMap{
|
||||
|
@ -37,7 +37,7 @@ func ReadJPG(scope *op.Scope, imagePath string, channels int64) *image.Image {
|
||||
return image.Scale(0, 255)
|
||||
}
|
||||
|
||||
func runModelNormal(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, err error) {
|
||||
func runModelNormal(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, confidence float32, err error) {
|
||||
order = 0
|
||||
err = nil
|
||||
|
||||
@ -59,10 +59,12 @@ func runModelNormal(c *Context, model *BaseModel, def_id string, inputImage *tf.
|
||||
}
|
||||
}
|
||||
|
||||
confidence = vmax
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, err error) {
|
||||
func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, confidence float32, err error) {
|
||||
|
||||
err = nil
|
||||
order = 0
|
||||
@ -101,6 +103,7 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
|
||||
var predictions = results[0].Value().([][]float32)[0]
|
||||
|
||||
for i, v := range predictions {
|
||||
c.Logger.Info("This is test", "v", v)
|
||||
if v > vmax {
|
||||
order = element.Range_start + i
|
||||
vmax = v
|
||||
@ -109,6 +112,7 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
|
||||
}
|
||||
|
||||
// TODO runthe head model
|
||||
confidence = vmax
|
||||
|
||||
c.Logger.Info("Got", "heads", len(heads))
|
||||
return
|
||||
@ -135,7 +139,7 @@ func handleRun(handle *Handle) {
|
||||
if err_part == io.EOF {
|
||||
break
|
||||
} else if err_part != nil {
|
||||
return c.JsonBadRequest("Invalid multipart data")
|
||||
return c.JsonBadRequest("Invalid multipart data")
|
||||
}
|
||||
if part.FormName() == "id" {
|
||||
buf := new(bytes.Buffer)
|
||||
@ -151,13 +155,13 @@ func handleRun(handle *Handle) {
|
||||
|
||||
model, err := GetBaseModel(handle.Db, id)
|
||||
if err == ModelNotFoundError {
|
||||
return c.JsonBadRequest("Models not found");
|
||||
return c.JsonBadRequest("Models not found")
|
||||
} else if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
if model.Status != READY {
|
||||
return c.JsonBadRequest("Model not ready to run images")
|
||||
return c.JsonBadRequest("Model not ready to run images")
|
||||
}
|
||||
|
||||
def := JustId{}
|
||||
@ -183,7 +187,7 @@ func handleRun(handle *Handle) {
|
||||
img_file.Write(file)
|
||||
|
||||
if !testImgForModel(c, model, img_path) {
|
||||
return c.JsonBadRequest("Provided image does not match the model")
|
||||
return c.JsonBadRequest("Provided image does not match the model")
|
||||
}
|
||||
|
||||
root := tg.NewRoot()
|
||||
@ -206,16 +210,17 @@ func handleRun(handle *Handle) {
|
||||
}
|
||||
|
||||
vi := -1
|
||||
var confidence float32 = 0
|
||||
|
||||
if model.ModelType == 2 {
|
||||
c.Logger.Info("Running model normal", "model", model.Id, "def", def_id)
|
||||
vi, err = runModelExp(c, model, def_id, inputImage)
|
||||
vi, confidence, err = runModelExp(c, model, def_id, inputImage)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
} else {
|
||||
c.Logger.Info("Running model normal", "model", model.Id, "def", def_id)
|
||||
vi, err = runModelNormal(c, model, def_id, inputImage)
|
||||
vi, confidence, err = runModelNormal(c, model, def_id, inputImage)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
@ -228,7 +233,7 @@ func handleRun(handle *Handle) {
|
||||
return c.Error500(err)
|
||||
}
|
||||
if !rows.Next() {
|
||||
return c.SendJSON(nil)
|
||||
return c.SendJSON(nil)
|
||||
}
|
||||
|
||||
var name string
|
||||
@ -236,7 +241,15 @@ func handleRun(handle *Handle) {
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
return c.SendJSON(name)
|
||||
returnValue := struct {
|
||||
Class string `json:"class"`
|
||||
Confidence float32 `json:"confidence"`
|
||||
}{
|
||||
Class: name,
|
||||
Confidence: confidence,
|
||||
}
|
||||
|
||||
return c.SendJSON(returnValue)
|
||||
}
|
||||
|
||||
read_form, err := r.MultipartReader()
|
||||
@ -336,16 +349,17 @@ func handleRun(handle *Handle) {
|
||||
}
|
||||
|
||||
vi := -1
|
||||
var confidence float32 = 0
|
||||
|
||||
if model.ModelType == 2 {
|
||||
c.Logger.Info("Running model normal", "model", model.Id, "def", def_id)
|
||||
vi, err = runModelExp(c, model, def_id, inputImage)
|
||||
vi, confidence, err = runModelExp(c, model, def_id, inputImage)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
} else {
|
||||
c.Logger.Info("Running model normal", "model", model.Id, "def", def_id)
|
||||
vi, err = runModelNormal(c, model, def_id, inputImage)
|
||||
vi, confidence, err = runModelNormal(c, model, def_id, inputImage)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
@ -362,6 +376,7 @@ func handleRun(handle *Handle) {
|
||||
"Model": model,
|
||||
"NotFound": true,
|
||||
"Result": nil,
|
||||
"Confidence": confidence,
|
||||
}))
|
||||
return nil
|
||||
}
|
||||
|
@ -101,6 +101,70 @@ func generateCvs(c *Context, run_path string, model_id string) (count int, err e
|
||||
return
|
||||
}
|
||||
|
||||
func setModelClassStatus(c *Context, status ModelClassStatus, filter string, args ...any) (err error) {
|
||||
_, err = c.Db.Exec("update model_classes set stauts = $1 where "+filter, args...)
|
||||
return
|
||||
}
|
||||
|
||||
func generateCvsExp(c *Context, run_path string, model_id string, doPanic bool) (count int, err error) {
|
||||
|
||||
classes, err := c.Db.Query("select count(*) from model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer classes.Close()
|
||||
|
||||
if !classes.Next() {
|
||||
return
|
||||
}
|
||||
|
||||
if err = classes.Scan(&count); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TO_TRAIN)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if doPanic {
|
||||
return 0, errors.New("No model classes available")
|
||||
}
|
||||
|
||||
return generateCvsExp(c, run_path, model_id, true)
|
||||
}
|
||||
|
||||
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, model_classes.DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer data.Close()
|
||||
|
||||
f, err := os.Create(path.Join(run_path, "train.csv"))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
f.Write([]byte("Id,Index\n"))
|
||||
|
||||
for data.Next() {
|
||||
var id string
|
||||
var class_order int
|
||||
var file_path string
|
||||
if err = data.Scan(&id, &class_order, &file_path); err != nil {
|
||||
return
|
||||
}
|
||||
if file_path == "id://" {
|
||||
f.Write([]byte(id + "," + strconv.Itoa(class_order) + "\n"))
|
||||
} else {
|
||||
return count, errors.New("TODO generateCvs to file_path " + file_path)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func trainDefinition(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
|
||||
c.Logger.Warn("About to start training definition")
|
||||
accuracy = 0
|
||||
@ -137,9 +201,9 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer os.RemoveAll(run_path)
|
||||
defer removeAll(run_path, err)
|
||||
|
||||
_, err = generateCvs(c, run_path, model.Id)
|
||||
classCount, err := generateCvs(c, run_path, model.Id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -171,6 +235,8 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr
|
||||
"LoadPrev": load_prev,
|
||||
"LastModelRunPath": path.Join(getDir(), result_path, "model.keras"),
|
||||
"SaveModelPath": path.Join(getDir(), result_path),
|
||||
"Depth": classCount,
|
||||
"StartPoint": 0,
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
@ -208,6 +274,12 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr
|
||||
return
|
||||
}
|
||||
|
||||
func removeAll(path string, err error) {
|
||||
if err != nil {
|
||||
os.RemoveAll(path)
|
||||
}
|
||||
}
|
||||
|
||||
func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
|
||||
accuracy = 0
|
||||
|
||||
@ -295,9 +367,9 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer os.RemoveAll(run_path)
|
||||
defer removeAll(run_path, err)
|
||||
|
||||
_, err = generateCvs(c, run_path, model.Id)
|
||||
classCount, err := generateCvsExp(c, run_path, model.Id, false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -333,6 +405,8 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
||||
"LastModelRunPath": path.Join(getDir(), result_path, "model.keras"),
|
||||
"SaveModelPath": path.Join(getDir(), result_path),
|
||||
"RemoveTopCount": remove_top_count,
|
||||
"Depth": classCount,
|
||||
"StartPoint": 0,
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
@ -413,7 +487,7 @@ func (nf ToRemoveList) Less(i, j int) bool {
|
||||
func trainModel(c *Context, model *BaseModel) {
|
||||
definitionsRows, err := c.Db.Query("select id, target_accuracy, epoch from model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id)
|
||||
if err != nil {
|
||||
c.Logger.Error("Failed to trainModel!Err:")
|
||||
c.Logger.Error("Failed to train Model! Err:")
|
||||
c.Logger.Error(err)
|
||||
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
||||
return
|
||||
@ -584,7 +658,7 @@ func trainModel(c *Context, model *BaseModel) {
|
||||
|
||||
for to_delete.Next() {
|
||||
var id string
|
||||
if to_delete.Scan(&id); err != nil {
|
||||
if err = to_delete.Scan(&id); err != nil {
|
||||
c.Logger.Error("Failed to scan the id of a model_definition to delete")
|
||||
c.Logger.Error(err)
|
||||
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
||||
@ -604,6 +678,21 @@ func trainModel(c *Context, model *BaseModel) {
|
||||
ModelUpdateStatus(c, model.Id, READY)
|
||||
}
|
||||
|
||||
type TrainModelRowUsable struct {
|
||||
Id string
|
||||
TargetAccuracy int `db:"target_accuracy"`
|
||||
Epoch int
|
||||
Acuracy float64 `db:"0"`
|
||||
}
|
||||
|
||||
type TrainModelRowUsables []*TrainModelRowUsable
|
||||
|
||||
func (nf TrainModelRowUsables) Len() int { return len(nf) }
|
||||
func (nf TrainModelRowUsables) Swap(i, j int) { nf[i], nf[j] = nf[j], nf[i] }
|
||||
func (nf TrainModelRowUsables) Less(i, j int) bool {
|
||||
return nf[i].Acuracy < nf[j].Acuracy
|
||||
}
|
||||
|
||||
func trainModelExp(c *Context, model *BaseModel) {
|
||||
var err error = nil
|
||||
|
||||
@ -612,25 +701,13 @@ func trainModelExp(c *Context, model *BaseModel) {
|
||||
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
||||
}
|
||||
|
||||
definitionsRows, err := c.Db.Query("select id, target_accuracy, epoch from model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id)
|
||||
if err != nil {
|
||||
failed("Failed to trainModel!")
|
||||
return
|
||||
}
|
||||
defer definitionsRows.Close()
|
||||
|
||||
var definitions TraingModelRowDefinitions = []TrainModelRow{}
|
||||
|
||||
for definitionsRows.Next() {
|
||||
var rowv TrainModelRow
|
||||
rowv.acuracy = 0
|
||||
if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy, &rowv.epoch); err != nil {
|
||||
failed("Failed to train Model Could not read definition from db!")
|
||||
return
|
||||
}
|
||||
definitions = append(definitions, rowv)
|
||||
}
|
||||
var definitions TrainModelRowUsables
|
||||
|
||||
definitions, err = GetDbMultitple[TrainModelRowUsable](c, "model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id)
|
||||
if err != nil {
|
||||
failed("Failed to get definitions");
|
||||
return
|
||||
}
|
||||
if len(definitions) == 0 {
|
||||
failed("No Definitions defined!")
|
||||
return
|
||||
@ -642,30 +719,30 @@ func trainModelExp(c *Context, model *BaseModel) {
|
||||
for {
|
||||
var toRemove ToRemoveList = []int{}
|
||||
for i, def := range definitions {
|
||||
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING)
|
||||
accuracy, err := trainDefinitionExp(c, model, def.id, !firstRound)
|
||||
ModelDefinitionUpdateStatus(c, def.Id, MODEL_DEFINITION_STATUS_TRAINING)
|
||||
accuracy, err := trainDefinitionExp(c, model, def.Id, !firstRound)
|
||||
if err != nil {
|
||||
c.Logger.Error("Failed to train definition!Err:", "err", err)
|
||||
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
|
||||
ModelDefinitionUpdateStatus(c, def.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
|
||||
toRemove = append(toRemove, i)
|
||||
continue
|
||||
}
|
||||
def.epoch += EPOCH_PER_RUN
|
||||
def.Epoch += EPOCH_PER_RUN
|
||||
accuracy = accuracy * 100
|
||||
def.acuracy = float64(accuracy)
|
||||
def.Acuracy = float64(accuracy)
|
||||
|
||||
definitions[i].epoch += EPOCH_PER_RUN
|
||||
definitions[i].acuracy = accuracy
|
||||
definitions[i].Epoch += EPOCH_PER_RUN
|
||||
definitions[i].Acuracy = accuracy
|
||||
|
||||
if accuracy >= float64(def.target_accuracy) {
|
||||
if accuracy >= float64(def.TargetAccuracy) {
|
||||
c.Logger.Info("Found a definition that reaches target_accuracy!")
|
||||
_, err = c.Db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.epoch, def.id)
|
||||
_, err = c.Db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.Epoch, def.Id)
|
||||
if err != nil {
|
||||
failed("Failed to train definition!")
|
||||
return
|
||||
}
|
||||
|
||||
_, err = c.Db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", MODEL_DEFINITION_STATUS_CANCELD_TRAINING, def.id, model.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
|
||||
_, err = c.Db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", MODEL_DEFINITION_STATUS_CANCELD_TRAINING, def.Id, model.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
|
||||
if err != nil {
|
||||
failed("Failed to train definition!")
|
||||
return
|
||||
@ -675,14 +752,14 @@ func trainModelExp(c *Context, model *BaseModel) {
|
||||
break
|
||||
}
|
||||
|
||||
if def.epoch > MAX_EPOCH {
|
||||
fmt.Printf("Failed to train definition! Accuracy less %f < %d\n", accuracy, def.target_accuracy)
|
||||
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
|
||||
if def.Epoch > MAX_EPOCH {
|
||||
fmt.Printf("Failed to train definition! Accuracy less %f < %d\n", accuracy, def.TargetAccuracy)
|
||||
ModelDefinitionUpdateStatus(c, def.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
|
||||
toRemove = append(toRemove, i)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = c.Db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.epoch, MODEL_DEFINITION_STATUS_PAUSED_TRAINING, def.id)
|
||||
_, err = c.Db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.Epoch, MODEL_DEFINITION_STATUS_PAUSED_TRAINING, def.Id)
|
||||
if err != nil {
|
||||
failed("Failed to train definition!")
|
||||
return
|
||||
@ -713,14 +790,13 @@ func trainModelExp(c *Context, model *BaseModel) {
|
||||
}
|
||||
|
||||
sort.Sort(sort.Reverse(definitions))
|
||||
acc := definitions[0].Acuracy - 20.0
|
||||
|
||||
acc := definitions[0].acuracy - 20.0
|
||||
|
||||
c.Logger.Info("Training models, Highest acc", "acc", definitions[0].acuracy, "mod_acc", acc)
|
||||
c.Logger.Info("Training models, Highest acc", "acc", definitions[0].Acuracy, "mod_acc", acc)
|
||||
|
||||
toRemove = []int{}
|
||||
for i, def := range definitions {
|
||||
if def.acuracy < acc {
|
||||
if def.Acuracy < acc {
|
||||
toRemove = append(toRemove, i)
|
||||
}
|
||||
}
|
||||
@ -730,7 +806,7 @@ func trainModelExp(c *Context, model *BaseModel) {
|
||||
sort.Sort(sort.Reverse(toRemove))
|
||||
for _, n := range toRemove {
|
||||
c.Logger.Warn("Removing definition not fast enough learning", "n", n)
|
||||
ModelDefinitionUpdateStatus(c, definitions[n].id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
|
||||
ModelDefinitionUpdateStatus(c, definitions[n].Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
|
||||
definitions = remove(definitions, n)
|
||||
}
|
||||
}
|
||||
@ -821,8 +897,7 @@ func splitModel(c *Context, model *BaseModel) (err error) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// TODO reneable it
|
||||
// defer os.RemoveAll(run_path)
|
||||
defer removeAll(run_path, err)
|
||||
|
||||
// Create python script
|
||||
f, err := os.Create(path.Join(run_path, "run.py"))
|
||||
@ -1286,7 +1361,7 @@ func handle_models_train_json(w http.ResponseWriter, r *http.Request, c *Context
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
return c.SendJSON(model.Id)
|
||||
return c.SendJSON(model.Id)
|
||||
}
|
||||
|
||||
func handleTrain(handle *Handle) {
|
||||
|
@ -6,9 +6,9 @@ import (
|
||||
)
|
||||
|
||||
type BaseModel struct {
|
||||
Name string
|
||||
Status int
|
||||
Id string
|
||||
Name string
|
||||
Status int
|
||||
Id string
|
||||
|
||||
ModelType int
|
||||
ImageMode int
|
||||
@ -52,6 +52,14 @@ const (
|
||||
MODEL_DEFINITION_STATUS_READY = 5
|
||||
)
|
||||
|
||||
type ModelClassStatus int
|
||||
|
||||
const (
|
||||
MODEL_CLASS_STATUS_TO_TRAIN ModelClassStatus = 1
|
||||
MODEL_CLASS_STATUS_TRAINING = 2
|
||||
MODEL_CLASS_STATUS_TRAINED = 3
|
||||
)
|
||||
|
||||
var ModelNotFoundError = errors.New("Model not found error")
|
||||
|
||||
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
|
||||
|
@ -387,7 +387,9 @@ func (c Context) SendJSONStatus(status int, dat any) *Error {
|
||||
}
|
||||
|
||||
func (c Context) JsonBadRequest(dat any) *Error {
|
||||
c.SetReportCaller(true)
|
||||
c.Logger.Warn("Request failed with a bad request", "dat", dat)
|
||||
c.SetReportCaller(false)
|
||||
return c.SendJSONStatus(http.StatusBadRequest, dat)
|
||||
}
|
||||
|
||||
|
@ -190,18 +190,29 @@ type Generic struct{ reflect.Type }
|
||||
|
||||
var NotFoundError = errors.New("Not found")
|
||||
|
||||
func GetDbMultitple[T interface{}](c *Context, tablename string, args ...any) ([]*T, error) {
|
||||
t := reflect.TypeFor[T]()
|
||||
nargs := t.NumField()
|
||||
|
||||
query := ""
|
||||
func generateQuery(t reflect.Type) (query string, nargs int) {
|
||||
nargs = t.NumField()
|
||||
query = ""
|
||||
|
||||
for i := 0; i < nargs; i += 1 {
|
||||
query += strings.ToLower(t.Field(i).Name) + ","
|
||||
field := t.Field(i)
|
||||
name, ok := field.Tag.Lookup("db")
|
||||
if !ok {
|
||||
name = field.Name;
|
||||
}
|
||||
query += strings.ToLower(name) + ","
|
||||
}
|
||||
|
||||
// Remove the last comma
|
||||
query = query[0 : len(query)-1]
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func GetDbMultitple[T interface{}](c *Context, tablename string, args ...any) ([]*T, error) {
|
||||
t := reflect.TypeFor[T]()
|
||||
|
||||
query, nargs := generateQuery(t)
|
||||
|
||||
rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
|
||||
if err != nil {
|
||||
@ -242,16 +253,8 @@ func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) {
|
||||
|
||||
func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) error {
|
||||
t := reflect.TypeOf(store).Elem()
|
||||
nargs := t.NumField()
|
||||
|
||||
query := ""
|
||||
|
||||
for i := 0; i < nargs; i += 1 {
|
||||
query += strings.ToLower(t.Field(i).Name) + ","
|
||||
}
|
||||
|
||||
// Remove the last comma
|
||||
query = query[0 : len(query)-1]
|
||||
query, nargs := generateQuery(t)
|
||||
|
||||
rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
|
||||
if err != nil {
|
||||
|
@ -13,6 +13,8 @@ http {
|
||||
server {
|
||||
listen 8000;
|
||||
|
||||
client_max_body_size 1G;
|
||||
|
||||
location / {
|
||||
proxy_http_version 1.1;
|
||||
proxy_pass http://127.0.0.1:5001;
|
||||
|
@ -29,7 +29,12 @@ create table if not exists model_classes (
|
||||
id uuid primary key default gen_random_uuid(),
|
||||
model_id uuid references models (id) on delete cascade,
|
||||
name varchar (70) not null,
|
||||
class_order integer
|
||||
class_order integer,
|
||||
|
||||
-- 1: to_train
|
||||
-- 2: training
|
||||
-- 3: trained
|
||||
status integer default 1,
|
||||
);
|
||||
|
||||
-- drop table if exists model_data_point;
|
||||
|
@ -9,9 +9,9 @@ import requests
|
||||
class NotifyServerCallback(tf.keras.callbacks.Callback):
|
||||
def on_epoch_end(self, epoch, log, *args, **kwargs):
|
||||
{{ if .HeadId }}
|
||||
requests.get(f'http://localhost:8000//model/head/epoch/update?epoch={epoch + 1}&accuracy={log["accuracy"]}&head_id={{.HeadId}}')
|
||||
requests.get(f'http://localhost:8000/api/model/head/epoch/update?epoch={epoch + 1}&accuracy={log["accuracy"]}&head_id={{.HeadId}}')
|
||||
{{ else }}
|
||||
requests.get(f'http://localhost:8000/model/epoch/update?model_id={{.Model.Id}}&epoch={epoch + 1}&accuracy={log["accuracy"]}&definition={{.DefId}}')
|
||||
requests.get(f'http://localhost:8000/api/model/epoch/update?model_id={{.Model.Id}}&epoch={epoch + 1}&accuracy={log["accuracy"]}&definition={{.DefId}}')
|
||||
{{end}}
|
||||
|
||||
|
||||
@ -23,6 +23,9 @@ df = pd.read_csv("{{ .RunPath }}/train.csv", dtype=str)
|
||||
keys = tf.constant(df['Id'].dropna())
|
||||
values = tf.constant(list(map(int, df['Index'].dropna())))
|
||||
|
||||
depth = {{ .Depth }}
|
||||
diff = {{ .StartPoint }}
|
||||
|
||||
table = tf.lookup.StaticHashTable(
|
||||
initializer=tf.lookup.KeyValueTensorInitializer(
|
||||
keys=keys,
|
||||
@ -44,7 +47,8 @@ def pathToLabel(path):
|
||||
{{ else }}
|
||||
ERROR
|
||||
{{ end }}
|
||||
return table.lookup(tf.strings.as_string([path]))
|
||||
|
||||
return tf.one_hot(table.lookup(tf.strings.as_string([path])) - diff, depth)[0]
|
||||
|
||||
def decode_image(img):
|
||||
{{ if eq .Model.Format "png" }}
|
||||
@ -161,7 +165,8 @@ ERROR
|
||||
{{ end }}
|
||||
|
||||
model.compile(
|
||||
loss=losses.SparseCategoricalCrossentropy(),
|
||||
#loss=losses.SparseCategoricalCrossentropy(),
|
||||
loss=losses.BinaryCrossentropy(from_logits=False),
|
||||
optimizer=tf.keras.optimizers.Adam(),
|
||||
metrics=['accuracy'])
|
||||
|
||||
|
@ -1,8 +1,9 @@
|
||||
<script lang="ts">
|
||||
let { replace_slot, accept, file } = $props<{
|
||||
let { replace_slot, accept, file, notExpand } = $props<{
|
||||
replace_slot?: boolean,
|
||||
accept?: string,
|
||||
file?: File,
|
||||
notExpand?: boolean
|
||||
}>();
|
||||
|
||||
let fileInput: HTMLInputElement;
|
||||
@ -27,7 +28,7 @@
|
||||
</script>
|
||||
|
||||
<div class="icon-holder">
|
||||
<button class="icon" class:adapt={replace_slot && file} on:click={() => fileInput.click()}>
|
||||
<button class="icon" class:adapt={replace_slot && file && !notExpand} on:click={() => fileInput.click()}>
|
||||
{#if replace_slot && file}
|
||||
<slot name="replaced" file={file}>
|
||||
<img src={fileData} alt="" />
|
||||
|
@ -42,6 +42,7 @@ export async function post(url: string, body: any) {
|
||||
if (r.status === 401) {
|
||||
userStore.user = undefined;
|
||||
goto("/login")
|
||||
throw r;
|
||||
} else if (r.status !== 200) {
|
||||
throw r;
|
||||
}
|
||||
|
@ -39,7 +39,7 @@
|
||||
New
|
||||
</a>
|
||||
</div>
|
||||
<table>
|
||||
<table class="table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>
|
||||
@ -101,4 +101,5 @@
|
||||
height: calc(100% - 20px);
|
||||
margin-top: 5px;
|
||||
}
|
||||
|
||||
</style>
|
||||
|
@ -1,15 +1,29 @@
|
||||
<script lang="ts">
|
||||
import MessageSimple from 'src/lib/MessageSimple.svelte';
|
||||
import type { Model } from './+page.svelte';
|
||||
import { rdelete } from '$lib/requests.svelte'
|
||||
import { goto } from '$app/navigation';
|
||||
|
||||
let {model}: { model: Model } = $props();
|
||||
let { model } = $props<{ model: Model }>();
|
||||
let name: string = $state("");
|
||||
let submmited: boolean = $state(false);
|
||||
let nameDoesNotMatch: string = $state("");
|
||||
|
||||
function deleteModel() {
|
||||
let messageSimple: MessageSimple;
|
||||
|
||||
async function deleteModel() {
|
||||
submmited = true;
|
||||
nameDoesNotMatch = "";
|
||||
console.error("TODO")
|
||||
messageSimple.display("");
|
||||
|
||||
try {
|
||||
await rdelete("models/delete", {id: model.id, name});
|
||||
goto("/models");
|
||||
} catch (e) {
|
||||
if (e instanceof Response) {
|
||||
messageSimple.display(await e.json());
|
||||
} else {
|
||||
messageSimple.display("Could not delete the model");
|
||||
}
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
@ -19,12 +33,8 @@
|
||||
To delete this model please type "{model.name}":
|
||||
</label>
|
||||
<input name="name" id="name" required bind:value={name} />
|
||||
{#if nameDoesNotMatch }
|
||||
<span class="form-msg red">
|
||||
Name does not match "{model.name}"
|
||||
</span>
|
||||
{/if}
|
||||
</fieldset>
|
||||
<MessageSimple bind:this={messageSimple} />
|
||||
<button class="danger">
|
||||
Delete
|
||||
</button>
|
||||
|
@ -127,12 +127,12 @@
|
||||
...
|
||||
</pre>
|
||||
</div>
|
||||
<FileUpload replace_slot bind:file={file} accept="application/zip" >
|
||||
<FileUpload replace_slot bind:file={file} accept="application/zip" notExpand >
|
||||
<img src="/imgs/upload-icon.png" alt="" />
|
||||
<span>
|
||||
Upload Zip File
|
||||
</span>
|
||||
<div slot="replaced">
|
||||
<div slot="replaced" style="display: inline;">
|
||||
<img src="/imgs/upload-icon.png" alt="" />
|
||||
<span>
|
||||
File selected
|
||||
|
@ -8,7 +8,12 @@
|
||||
|
||||
let file: File | undefined = $state();
|
||||
|
||||
let result: string | undefined = $state();
|
||||
type Result = {
|
||||
class: string,
|
||||
confidence: number,
|
||||
}
|
||||
|
||||
let _result: Promise<Result | undefined> = $state(new Promise(() => {}));
|
||||
let run = $state(false);
|
||||
|
||||
let messages: MessageSimple;
|
||||
@ -25,7 +30,8 @@
|
||||
run = true;
|
||||
|
||||
try {
|
||||
result = await postFormData('models/run', form);
|
||||
_result = await postFormData('models/run', form);
|
||||
console.log(await _result);
|
||||
} catch (e) {
|
||||
if (e instanceof Response) {
|
||||
messages.display(await e.json());
|
||||
@ -60,19 +66,21 @@
|
||||
Run
|
||||
</button>
|
||||
{#if run}
|
||||
{#if !result}
|
||||
<div class="result">
|
||||
<h1>
|
||||
The class was not found
|
||||
</h1>
|
||||
</div>
|
||||
{:else}
|
||||
<div>
|
||||
<h1>
|
||||
Result
|
||||
</h1>
|
||||
The image was classified as {result}
|
||||
</div>
|
||||
{/if}
|
||||
{#await _result then result}
|
||||
{#if !result}
|
||||
<div class="result">
|
||||
<h1>
|
||||
The class was not found
|
||||
</h1>
|
||||
</div>
|
||||
{:else}
|
||||
<div>
|
||||
<h1>
|
||||
Result
|
||||
</h1>
|
||||
The image was classified as {result.class} with confidence: {result.confidence}
|
||||
</div>
|
||||
{/if}
|
||||
{/await}
|
||||
{/if}
|
||||
</form>
|
||||
|
@ -101,3 +101,33 @@ a.button {
|
||||
.card h3 {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
|
||||
.table {
|
||||
width: 100%;
|
||||
box-shadow: 0 2px 8px 1px #66666622;
|
||||
border-radius: 10px;
|
||||
border-collapse: collapse;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.table thead {
|
||||
background: #60606022;
|
||||
}
|
||||
|
||||
.table tr td,
|
||||
.table tr th {
|
||||
border-left: 1px solid #22222244;
|
||||
padding: 15px;
|
||||
}
|
||||
|
||||
.table tr td:first-child,
|
||||
.table tr th:first-child {
|
||||
border-left: none;
|
||||
}
|
||||
|
||||
.table tr td button,
|
||||
.table tr td .button {
|
||||
padding: 5px 10px;
|
||||
box-shadow: 0 2px 5px 1px #66666655;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user