chore: did more clean up
This commit is contained in:
@@ -11,7 +11,7 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@ package model_classes
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
)
|
||||
|
||||
var FailedToGetIdAfterInsertError = errors.New("Failed to Get Id After Insert Error")
|
||||
@@ -23,6 +25,6 @@ func AddDataPoint(db *sql.DB, class_id string, file_path string, mode DATA_POINT
|
||||
}
|
||||
|
||||
func UpdateDataPointStatus(db *sql.DB, data_point_id string, status int, message *string) (err error) {
|
||||
_, err = db.Exec("update model_data_point set status=$1, status_message=$2 where id=$3", status, message, data_point_id)
|
||||
return
|
||||
_, err = db.Exec("update model_data_point set status=$1, status_message=$2 where id=$3", status, message, data_point_id)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,43 +1,26 @@
|
||||
package model_classes
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
func HandleList(handle *Handle) {
|
||||
handle.Get("/models/data/list", func(c *Context) *Error {
|
||||
if !c.CheckAuthLevel(1) {
|
||||
return nil
|
||||
}
|
||||
|
||||
id, err := GetIdFromUrl(c, "id")
|
||||
if err != nil {
|
||||
return c.JsonBadRequest("Model Class not found!")
|
||||
}
|
||||
|
||||
page := 0
|
||||
if c.R.URL.Query().Has("page") {
|
||||
page_url := c.R.URL.Query().Get("page")
|
||||
page_url_number, err := strconv.Atoi(page_url)
|
||||
if err != nil {
|
||||
return c.JsonBadRequest("Page is not a number")
|
||||
}
|
||||
page = page_url_number
|
||||
}
|
||||
|
||||
type DataList struct {
|
||||
Id string `json:"id" validate:"required"`
|
||||
Page int `json:"page"`
|
||||
}
|
||||
PostAuthJson(handle, "/models/data/list", User_Normal, func(c *Context, dat *DataList) *Error {
|
||||
var class_row struct {
|
||||
Name string
|
||||
Model_id string
|
||||
}
|
||||
|
||||
err = GetDBOnce(c, &class_row, "model_classes where id=$1", id)
|
||||
err := GetDBOnce(c, &class_row, "model_classes where id=$1", dat.Id)
|
||||
if err == NotFoundError {
|
||||
return c.JsonBadRequest("Model Class not found!")
|
||||
} else if err != nil {
|
||||
return c.Error500(err)
|
||||
return c.E500M("Failed to get classes", err)
|
||||
}
|
||||
|
||||
type baserow struct {
|
||||
@@ -47,23 +30,21 @@ func HandleList(handle *Handle) {
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
rows, err := GetDbMultitple[baserow](c, "model_data_point where class_id=$1 limit 11 offset $2", id, page*10)
|
||||
rows, err := GetDbMultitple[baserow](c, "model_data_point where class_id=$1 limit 11 offset $2", dat.Id, dat.Page*10)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
type ReturnType struct {
|
||||
ImageList []*baserow `json:"image_list"`
|
||||
Page int `json:"page"`
|
||||
ShowNext bool `json:"showNext"`
|
||||
return c.E500M("Failed to get classes", err)
|
||||
}
|
||||
|
||||
max_len := min(11, len(rows))
|
||||
|
||||
c.ShowMessage = false
|
||||
return c.SendJSON(ReturnType{
|
||||
return c.SendJSON(struct {
|
||||
ImageList []*baserow `json:"image_list"`
|
||||
Page int `json:"page"`
|
||||
ShowNext bool `json:"showNext"`
|
||||
}{
|
||||
ImageList: rows[0:max_len],
|
||||
Page: page,
|
||||
Page: dat.Page,
|
||||
ShowNext: len(rows) == 11,
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
package model_classes
|
||||
|
||||
type DATA_POINT_MODE int
|
||||
|
||||
const (
|
||||
DATA_POINT_MODE_TRAINING DATA_POINT_MODE = 1
|
||||
DATA_POINT_MODE_TESTING = 2
|
||||
)
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
@@ -68,9 +67,9 @@ func fileProcessor(
|
||||
|
||||
parts := strings.Split(file.Name, "/")
|
||||
|
||||
mode := model_classes.DATA_POINT_MODE_TRAINING
|
||||
mode := DATA_POINT_MODE_TRAINING
|
||||
if parts[0] == "testing" {
|
||||
mode = model_classes.DATA_POINT_MODE_TESTING
|
||||
mode = DATA_POINT_MODE_TESTING
|
||||
}
|
||||
|
||||
data_point_id, err := model_classes.AddDataPoint(c.Db, ids[parts[1]], "id://", mode)
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"path"
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
@@ -14,44 +13,35 @@ func deleteModelJSON(c *Context, id string) *Error {
|
||||
c.Logger.Warnf("Removing model with id: %s", id)
|
||||
_, err := c.Db.Exec("delete from models where id=$1;", id)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
return c.E500M("Failed to delete models", err)
|
||||
}
|
||||
|
||||
model_path := path.Join("./savedData", id)
|
||||
c.Logger.Warnf("Removing folder of model with id: %s at %s", id, model_path)
|
||||
err = os.RemoveAll(model_path)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
return c.E500M("Failed to remove data", err)
|
||||
}
|
||||
|
||||
return c.SendJSON(id)
|
||||
}
|
||||
|
||||
func handleDelete(handle *Handle) {
|
||||
handle.Delete("/models/delete", func(c *Context) *Error {
|
||||
if !c.CheckAuthLevel(1) {
|
||||
return nil
|
||||
}
|
||||
var dat struct {
|
||||
Id string `json:"id" validate:"required"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
if err_ := c.ToJSON(&dat); err_ != nil {
|
||||
return err_
|
||||
}
|
||||
|
||||
type DeleteModel struct {
|
||||
Id string `json:"id" validate:"required"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
}
|
||||
DeleteAuthJson(handle, "/models/delete", User_Normal, func(c *Context, dat *DeleteModel) *Error {
|
||||
var model struct {
|
||||
Id string
|
||||
Name string
|
||||
Status int
|
||||
}
|
||||
|
||||
err := GetDBOnce(c, &model, "models where id=$1 and user_id=$2;", dat.Id, c.User.Id)
|
||||
if err == NotFoundError {
|
||||
return c.SendJSONStatus(http.StatusNotFound, "Model not found!")
|
||||
} else if err != nil {
|
||||
return c.Error500(err)
|
||||
return c.E500M("Faield to get model", err)
|
||||
}
|
||||
|
||||
switch model.Status {
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
|
||||
@@ -6,15 +6,7 @@ import (
|
||||
)
|
||||
|
||||
// Auth level set when path is definied as 1
|
||||
func handleStats(c *Context) *Error {
|
||||
var b struct {
|
||||
Id string `json:"id" validate:"required"`
|
||||
}
|
||||
|
||||
if _err := c.ToJSON(&b); _err != nil {
|
||||
return _err
|
||||
}
|
||||
|
||||
func handleStats(c *Context, b *JustId) *Error {
|
||||
type Row struct {
|
||||
Name string `db:"mc.name" json:"name"`
|
||||
Training string `db:"count(mdp.id) filter (where mdp.model_mode=1)" json:"training"`
|
||||
@@ -23,7 +15,7 @@ func handleStats(c *Context) *Error {
|
||||
|
||||
rows, err := GetDbMultitple[Row](c, "model_data_point as mdp inner join model_classes as mc on mc.id=mdp.class_id where mc.model_id=$1 group by mc.name order by mc.name asc;", b.Id)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
return c.E500M("Failed to get stats", err)
|
||||
}
|
||||
|
||||
c.ShowMessage = false
|
||||
@@ -50,5 +42,5 @@ func handleList(handle *Handle) {
|
||||
return c.SendJSON(got)
|
||||
})
|
||||
|
||||
handle.PostAuth("/models/class/stats", 1, handleStats)
|
||||
PostAuthJson(handle, "/models/class/stats", User_Normal, handleStats)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"path"
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
|
||||
|
||||
tf "github.com/galeone/tensorflow/tensorflow/go"
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
_ "image/png"
|
||||
"os"
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
|
||||
@@ -4,29 +4,17 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
func handleRest(handle *Handle) {
|
||||
handle.Delete("/models/train/reset", func(c *Context) *Error {
|
||||
if !c.CheckAuthLevel(1) {
|
||||
return nil
|
||||
}
|
||||
var dat struct {
|
||||
Id string `json:"id"`
|
||||
}
|
||||
|
||||
if err := c.ToJSON(&dat); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
DeleteAuthJson(handle, "/models/train/reset", User_Normal, func(c *Context, dat *JustId) *Error {
|
||||
model, err := GetBaseModel(c.Db, dat.Id)
|
||||
if err == ModelNotFoundError {
|
||||
return c.JsonBadRequest("Model not found")
|
||||
} else if err != nil {
|
||||
// TODO improve response
|
||||
return c.Error500(err)
|
||||
return c.E500M("Failed to get model", err)
|
||||
}
|
||||
|
||||
if model.Status != FAILED_PREPARING_TRAINING && model.Status != FAILED_TRAINING {
|
||||
@@ -37,8 +25,7 @@ func handleRest(handle *Handle) {
|
||||
|
||||
_, err = c.Db.Exec("delete from model_definition where model_id=$1", model.Id)
|
||||
if err != nil {
|
||||
// TODO improve response
|
||||
return c.Error500(err)
|
||||
return c.E500M("Failed to delete model", err)
|
||||
}
|
||||
|
||||
ModelUpdateStatus(c, model.Id, CONFIRM_PRE_TRAINING)
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
"github.com/charmbracelet/log"
|
||||
)
|
||||
@@ -80,7 +79,7 @@ func generateCvs(c *Context, run_path string, model_id string) (count int, err e
|
||||
}
|
||||
count = co.Count
|
||||
|
||||
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2;", model_id, model_classes.DATA_POINT_MODE_TRAINING)
|
||||
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2;", model_id, DATA_POINT_MODE_TRAINING)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -139,7 +138,7 @@ func generateCvsExp(c *Context, run_path string, model_id string, doPanic bool)
|
||||
return generateCvsExp(c, run_path, model_id, true)
|
||||
}
|
||||
|
||||
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, model_classes.DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING)
|
||||
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -303,7 +302,7 @@ func generateCvsExpandExp(c *Context, run_path string, model_id string, offset i
|
||||
return generateCvsExpandExp(c, run_path, model_id, offset, true)
|
||||
}
|
||||
|
||||
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, model_classes.DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING)
|
||||
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -337,7 +336,7 @@ func generateCvsExpandExp(c *Context, run_path string, model_id string, offset i
|
||||
// This is to load some extra data so that the model has more things to train on
|
||||
//
|
||||
|
||||
data_other, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, model_classes.DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINED, count*10)
|
||||
data_other, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINED, count*10)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
package models_utils
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
)
|
||||
|
||||
const (
|
||||
FAILED_TRAINING = -4
|
||||
FAILED_PREPARING_TRAINING = -3
|
||||
FAILED_PREPARING_ZIP_FILE = -2
|
||||
FAILED_PREPARING = -1
|
||||
|
||||
PREPARING = 1
|
||||
CONFIRM_PRE_TRAINING = 2
|
||||
PREPARING_ZIP_FILE = 3
|
||||
TRAINING = 4
|
||||
READY = 5
|
||||
READY_ALTERATION = 6
|
||||
READY_ALTERATION_FAILED = -6
|
||||
|
||||
READY_RETRAIN = 7
|
||||
READY_RETRAIN_FAILED = -7
|
||||
)
|
||||
|
||||
type ModelDefinitionStatus int
|
||||
|
||||
type LayerType int
|
||||
|
||||
const (
|
||||
LAYER_INPUT LayerType = 1
|
||||
LAYER_DENSE = 2
|
||||
LAYER_FLATTEN = 3
|
||||
LAYER_SIMPLE_BLOCK = 4
|
||||
)
|
||||
|
||||
const (
|
||||
MODEL_DEFINITION_STATUS_CANCELD_TRAINING ModelDefinitionStatus = -4
|
||||
MODEL_DEFINITION_STATUS_FAILED_TRAINING = -3
|
||||
MODEL_DEFINITION_STATUS_PRE_INIT = 1
|
||||
MODEL_DEFINITION_STATUS_INIT = 2
|
||||
MODEL_DEFINITION_STATUS_TRAINING = 3
|
||||
MODEL_DEFINITION_STATUS_PAUSED_TRAINING = 6
|
||||
MODEL_DEFINITION_STATUS_TRANIED = 4
|
||||
MODEL_DEFINITION_STATUS_READY = 5
|
||||
)
|
||||
|
||||
type ModelClassStatus int
|
||||
|
||||
const (
|
||||
MODEL_CLASS_STATUS_TO_TRAIN ModelClassStatus = 1
|
||||
MODEL_CLASS_STATUS_TRAINING = 2
|
||||
MODEL_CLASS_STATUS_TRAINED = 3
|
||||
)
|
||||
|
||||
type ModelHeadStatus int
|
||||
|
||||
const (
|
||||
MODEL_HEAD_STATUS_PRE_INIT ModelHeadStatus = 1
|
||||
MODEL_HEAD_STATUS_INIT = 2
|
||||
MODEL_HEAD_STATUS_TRAINING = 3
|
||||
MODEL_HEAD_STATUS_TRAINED = 4
|
||||
MODEL_HEAD_STATUS_READY = 5
|
||||
)
|
||||
|
||||
type BaseModel struct {
|
||||
Name string
|
||||
Status int
|
||||
Id string
|
||||
|
||||
ModelType int
|
||||
ImageMode int
|
||||
Width int
|
||||
Height int
|
||||
Format string
|
||||
}
|
||||
|
||||
var ModelNotFoundError = errors.New("Model not found error")
|
||||
|
||||
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
|
||||
rows, err := db.Query("select name, status, id, width, height, color_mode, format, model_type from models where id=$1;", id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
return nil, ModelNotFoundError
|
||||
}
|
||||
|
||||
base = &BaseModel{}
|
||||
var colorMode string
|
||||
err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height, &colorMode, &base.Format, &base.ModelType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
base.ImageMode = StringToImageMode(colorMode)
|
||||
return
|
||||
}
|
||||
|
||||
func (m BaseModel) CanEval() bool {
|
||||
if m.Status != READY && m.Status != READY_RETRAIN && m.Status != READY_RETRAIN_FAILED && m.Status != READY_ALTERATION && m.Status != READY_ALTERATION_FAILED {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func StringToImageMode(colorMode string) int {
|
||||
switch colorMode {
|
||||
case "greyscale":
|
||||
return 1
|
||||
case "rgb":
|
||||
return 3
|
||||
default:
|
||||
panic("unkown color mode")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user