feat: closes #18 added the code to generate models and other
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
||||
"path"
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
)
|
||||
|
||||
func loadBaseImage(handle *Handle, id string) {
|
||||
@@ -21,7 +22,7 @@ func loadBaseImage(handle *Handle, id string) {
|
||||
// TODO better logging
|
||||
fmt.Println(err)
|
||||
fmt.Printf("Failed to read image for model with id %s\n", id)
|
||||
modelUpdateStatus(handle, id, -1)
|
||||
ModelUpdateStatus(handle, id, -1)
|
||||
return
|
||||
}
|
||||
defer infile.Close()
|
||||
@@ -31,7 +32,7 @@ func loadBaseImage(handle *Handle, id string) {
|
||||
// TODO better logging
|
||||
fmt.Println(err)
|
||||
fmt.Printf("Failed to load image for model with id %s\n", id)
|
||||
modelUpdateStatus(handle, id, -1)
|
||||
ModelUpdateStatus(handle, id, -1)
|
||||
return
|
||||
}
|
||||
if format != "png" {
|
||||
@@ -67,7 +68,7 @@ func loadBaseImage(handle *Handle, id string) {
|
||||
fmt.Println("Other so assuming color")
|
||||
}
|
||||
|
||||
modelUpdateStatus(handle, id, -1)
|
||||
ModelUpdateStatus(handle, id, -1)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -77,7 +78,7 @@ func loadBaseImage(handle *Handle, id string) {
|
||||
// TODO better logging
|
||||
fmt.Println(err)
|
||||
fmt.Printf("Could not update model\n")
|
||||
modelUpdateStatus(handle, id, -1)
|
||||
ModelUpdateStatus(handle, id, -1)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,6 +33,17 @@ func ListClasses(db *sql.DB, model_id string) (cls []ModelClass, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func ModelHasDataPoints(db *sql.DB, model_id string) (result bool, err error) {
|
||||
result = false
|
||||
rows, err := db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 limit 1;", model_id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return rows.Next(), nil
|
||||
}
|
||||
|
||||
var ClassAlreadyExists = errors.New("Class aready exists")
|
||||
|
||||
func CreateClass(db *sql.DB, model_id string, name string) (id string, err error) {
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
)
|
||||
|
||||
func InsertIfNotPresent(ss []string, s string) []string {
|
||||
@@ -31,7 +32,7 @@ func processZipFile(handle *Handle, id string) {
|
||||
reader, err := zip.OpenReader(path.Join("savedData", id, "base_data.zip"))
|
||||
if err != nil {
|
||||
// TODO add msg to error
|
||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
fmt.Printf("Faield to proccess zip file failed to open reader\n")
|
||||
fmt.Println(err)
|
||||
return
|
||||
@@ -51,7 +52,7 @@ func processZipFile(handle *Handle, id string) {
|
||||
|
||||
if paths[0] != "training" && paths[0] != "testing" {
|
||||
fmt.Printf("Invalid file '%s' TODO add msg to response!!!\n", file.Name)
|
||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -66,7 +67,7 @@ func processZipFile(handle *Handle, id string) {
|
||||
fmt.Printf("testing and training are diferent\n")
|
||||
fmt.Println(testing)
|
||||
fmt.Println(training)
|
||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -79,21 +80,21 @@ func processZipFile(handle *Handle, id string) {
|
||||
err = os.MkdirAll(dir_path, os.ModePerm)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to create dir %s\n", dir_path)
|
||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
return
|
||||
}
|
||||
dir_path = path.Join(base_path, "testing", name)
|
||||
err = os.MkdirAll(dir_path, os.ModePerm)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to create dir %s\n", dir_path)
|
||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
return
|
||||
}
|
||||
|
||||
id, err := model_classes.CreateClass(handle.Db, id, name)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to create class '%s' on db\n", name)
|
||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
return
|
||||
}
|
||||
ids[name] = id
|
||||
@@ -108,14 +109,14 @@ func processZipFile(handle *Handle, id string) {
|
||||
f, err := os.Create(file_path)
|
||||
if err != nil {
|
||||
fmt.Printf("Could not create file %s\n", file_path)
|
||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
data, err := reader.Open(file.Name)
|
||||
if err != nil {
|
||||
fmt.Printf("Could not create file %s\n", file_path)
|
||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
return
|
||||
}
|
||||
defer data.Close()
|
||||
@@ -135,13 +136,13 @@ func processZipFile(handle *Handle, id string) {
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to add data point for %s\n", id)
|
||||
fmt.Println(err)
|
||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("Added data to model '%s'!\n", id)
|
||||
modelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING)
|
||||
ModelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING)
|
||||
}
|
||||
|
||||
func handleDataUpload(handle *Handle) {
|
||||
@@ -182,7 +183,7 @@ func handleDataUpload(handle *Handle) {
|
||||
}
|
||||
}
|
||||
|
||||
_, err = getBaseModel(handle.Db, id)
|
||||
_, err = GetBaseModel(handle.Db, id)
|
||||
if err == ModelNotFoundError {
|
||||
return ErrorCode(nil, http.StatusNotFound, AnyMap{
|
||||
"NotFoundMessage": "Model not found",
|
||||
@@ -203,7 +204,7 @@ func handleDataUpload(handle *Handle) {
|
||||
|
||||
f.Write(file)
|
||||
|
||||
modelUpdateStatus(handle, id, PREPARING_ZIP_FILE)
|
||||
ModelUpdateStatus(handle, id, PREPARING_ZIP_FILE)
|
||||
|
||||
go processZipFile(handle, id)
|
||||
|
||||
@@ -230,7 +231,7 @@ func handleDataUpload(handle *Handle) {
|
||||
|
||||
id := f.Get("id")
|
||||
|
||||
model, err := getBaseModel(handle.Db, id)
|
||||
model, err := GetBaseModel(handle.Db, id)
|
||||
if err == ModelNotFoundError {
|
||||
return ErrorCode(nil, http.StatusNotFound, AnyMap{
|
||||
"NotFoundMessage": "Model not found",
|
||||
@@ -260,7 +261,7 @@ func handleDataUpload(handle *Handle) {
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
modelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING)
|
||||
ModelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING)
|
||||
Redirect("/models/edit?id="+id, c.Mode, w, r)
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"path"
|
||||
"strconv"
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
@@ -71,7 +72,7 @@ func handleDelete(handle *Handle) {
|
||||
})
|
||||
}
|
||||
|
||||
var model BaseModel = BaseModel{}
|
||||
model := BaseModel{}
|
||||
model.Id = id
|
||||
|
||||
err = rows.Scan(&model.Name, &model.Status)
|
||||
@@ -80,6 +81,8 @@ func handleDelete(handle *Handle) {
|
||||
}
|
||||
|
||||
switch model.Status {
|
||||
case FAILED_PREPARING_TRAINING:
|
||||
fallthrough
|
||||
case FAILED_PREPARING:
|
||||
deleteModel(handle, id, w, c, model)
|
||||
return nil
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
@@ -70,6 +71,11 @@ func handleEdit(handle *Handle) {
|
||||
case CONFIRM_PRE_TRAINING:
|
||||
|
||||
cls, err := model_classes.ListClasses(handle.Db, id)
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
has_data, err := model_classes.ModelHasDataPoints(handle.Db, id)
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
@@ -77,7 +83,10 @@ func handleEdit(handle *Handle) {
|
||||
LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
|
||||
"Model": model,
|
||||
"Classes": cls,
|
||||
"HasData": has_data,
|
||||
}))
|
||||
case TRAINING:
|
||||
fallthrough
|
||||
case PREPARING_ZIP_FILE:
|
||||
LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
|
||||
"Model": model,
|
||||
|
||||
@@ -2,6 +2,7 @@ package models
|
||||
|
||||
import (
|
||||
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
|
||||
models_train "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
@@ -15,5 +16,8 @@ func HandleModels (handle *Handle) {
|
||||
handleDataUpload(handle)
|
||||
|
||||
model_classes.HandleList(handle)
|
||||
|
||||
// Train endpoints
|
||||
models_train.HandleTrainEndpoints(handle)
|
||||
}
|
||||
|
||||
|
||||
13
logic/models/train/main.go
Normal file
13
logic/models/train/main.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package models_train
|
||||
|
||||
import (
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
func HandleTrainEndpoints(handle *Handle) {
|
||||
handleTrain(handle)
|
||||
handleRest(handle)
|
||||
|
||||
//TODO remove
|
||||
handleTest(handle)
|
||||
}
|
||||
58
logic/models/train/reset.go
Normal file
58
logic/models/train/reset.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package models_train
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
func handleRest(handle *Handle) {
|
||||
handle.Delete("/models/train/reset", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||||
if !CheckAuthLevel(1, w, r, c) {
|
||||
return nil
|
||||
}
|
||||
if c.Mode == JSON {
|
||||
panic("handle JSON /models/train/reset")
|
||||
}
|
||||
|
||||
f, err := MyParseForm(r)
|
||||
if err != nil {
|
||||
// TODO improve response
|
||||
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||
}
|
||||
|
||||
if !CheckId(f, "id") {
|
||||
// TODO improve response
|
||||
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||
}
|
||||
|
||||
id := f.Get("id")
|
||||
|
||||
model, err := GetBaseModel(handle.Db, id)
|
||||
if err == ModelNotFoundError {
|
||||
return ErrorCode(nil, http.StatusNotFound, AnyMap{
|
||||
"NotFoundMessage": "Model not found",
|
||||
"GoBackLink": "/models",
|
||||
})
|
||||
} else if err != nil {
|
||||
// TODO improve response
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
if model.Status != FAILED_PREPARING_TRAINING {
|
||||
// TODO improve response
|
||||
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||
}
|
||||
|
||||
_, err = handle.Db.Exec("delete from model_definition where model_id=$1", model.Id)
|
||||
if err != nil {
|
||||
// TODO improve response
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
ModelUpdateStatus(handle, model.Id, CONFIRM_PRE_TRAINING)
|
||||
Redirect("/models/edit?id=" + model.Id, c.Mode, w, r)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
139
logic/models/train/tensorflow-test.go
Normal file
139
logic/models/train/tensorflow-test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package models_train
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
/*import (
|
||||
tf "github.com/galeone/tensorflow/tensorflow/go"
|
||||
tg "github.com/galeone/tfgo"
|
||||
"github.com/galeone/tfgo/image"
|
||||
"github.com/galeone/tfgo/image/filter"
|
||||
"github.com/galeone/tfgo/image/padding"
|
||||
)*/
|
||||
|
||||
func getDir() string {
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return dir
|
||||
}
|
||||
|
||||
func shapeToSize(shape string) string {
|
||||
split := strings.Split(shape, ",")
|
||||
return strings.Join(split[:len(split) - 1], ",")
|
||||
}
|
||||
|
||||
func handleTest(handle *Handle) {
|
||||
handle.Post("/models/train/test", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||||
if !CheckAuthLevel(1, w, r, c) {
|
||||
return nil
|
||||
}
|
||||
|
||||
id, err := GetIdFromUrl(r, "id")
|
||||
if err != nil {
|
||||
return ErrorCode(err, 400, c.AddMap(nil))
|
||||
}
|
||||
|
||||
rows, err := handle.Db.Query("select mc.name, mdp.file_path from model_classes as mc join model_data_point as mdp on mdp.class_id = mc.id where mdp.model_mode = 1 and mc.model_id = $1 limit 1;", id)
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
var name string
|
||||
var file_path string
|
||||
err = rows.Scan(&name, &file_path)
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
file_path = strings.Replace(file_path, "file://", "", 1)
|
||||
|
||||
img_path := path.Join("savedData", id, "data", "training", name, file_path)
|
||||
|
||||
fmt.Printf("%s\n", img_path)
|
||||
|
||||
definitions, err := handle.Db.Query("select id from model_definition where model_id=$1 and status=2 limit 1;", id)
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
defer definitions.Close()
|
||||
|
||||
if !definitions.Next() {
|
||||
fmt.Println("Did not find definition")
|
||||
return Error500(nil)
|
||||
}
|
||||
|
||||
var definition_id string
|
||||
|
||||
if err = definitions.Scan(&definition_id); err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
defer layers.Close()
|
||||
|
||||
type layerrow struct {
|
||||
LayerType int
|
||||
Shape string
|
||||
}
|
||||
|
||||
got := []layerrow{}
|
||||
|
||||
for layers.Next() {
|
||||
var row = layerrow{}
|
||||
if err = layers.Scan(&row.LayerType, &row.Shape); err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
row.Shape = shapeToSize(row.Shape)
|
||||
got = append(got, row)
|
||||
}
|
||||
|
||||
// Generate folder
|
||||
|
||||
err = os.MkdirAll(path.Join("/tmp", id), os.ModePerm)
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
f, err := os.Create(path.Join("/tmp", id, "run.py"))
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
fmt.Printf("Using path: %s\n", path.Join("/tmp", id, "run.py"))
|
||||
|
||||
tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py")
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
if err = tmpl.Execute(f, AnyMap{
|
||||
"Layers": got,
|
||||
"Size": got[0].Shape,
|
||||
"DataDir": path.Join(getDir(), "savedData", id, "data", "training"),
|
||||
}); err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
w.Write([]byte("Done"))
|
||||
return nil
|
||||
})
|
||||
}
|
||||
158
logic/models/train/train.go
Normal file
158
logic/models/train/train.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package models_train
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
|
||||
id = ""
|
||||
_, err = db.Exec("insert into model_definition (model_id, target_accuracy) values ($1, $2);", model_id, target_accuracy)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := db.Query("select id from model_definition where model_id=$1 order by created_on DESC;", model_id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
return id, errors.New("Something wrong!")
|
||||
}
|
||||
|
||||
err = rows.Scan(&id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type ModelDefinitionStatus int
|
||||
|
||||
const (
|
||||
MODEL_DEFINITION_STATUS_PRE_INIT ModelDefinitionStatus = 1
|
||||
MODEL_DEFINITION_STATUS_INIT = 2
|
||||
MODEL_DEFINITION_STATUS_TRAINING = 3
|
||||
MODEL_DEFINITION_STATUS_TRANIED = 4
|
||||
MODEL_DEFINITION_STATUS_READY = 5
|
||||
)
|
||||
|
||||
func ModelDefinitionUpdateStatus(handle *Handle, id string, status ModelDefinitionStatus) (err error) {
|
||||
_, err = handle.Db.Exec("update model_definition set status = $1 where id = $2", status, id)
|
||||
return
|
||||
}
|
||||
|
||||
func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type int, shape string) (err error) {
|
||||
_, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape) values ($1, $2, $3, $4)", def_id, layer_order, layer_type, shape)
|
||||
return
|
||||
}
|
||||
|
||||
func handleTrain(handle *Handle) {
|
||||
handle.Post("/models/train", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||||
if !CheckAuthLevel(1, w, r, c) {
|
||||
return nil
|
||||
}
|
||||
if c.Mode == JSON {
|
||||
panic("TODO /models/train JSON")
|
||||
}
|
||||
|
||||
r.ParseForm()
|
||||
f := r.Form
|
||||
|
||||
number_of_models := 0
|
||||
accuracy := 0
|
||||
|
||||
if !CheckId(f, "id") || CheckEmpty(f, "model_type") || !CheckNumber(f, "number_of_models", &number_of_models) || !CheckNumber(f, "accuracy", &accuracy) {
|
||||
fmt.Println(
|
||||
!CheckId(f, "id"), CheckEmpty(f, "model_type"), !CheckNumber(f, "number_of_models", &number_of_models), !CheckNumber(f, "accuracy", &accuracy),
|
||||
)
|
||||
// TODO improve this response
|
||||
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||
}
|
||||
|
||||
id := f.Get("id")
|
||||
model_type := f.Get("model_type")
|
||||
// Its not used rn
|
||||
_ = model_type
|
||||
|
||||
model, err := GetBaseModel(handle.Db, id)
|
||||
if err == ModelNotFoundError {
|
||||
return ErrorCode(nil, http.StatusNotFound, c.AddMap(AnyMap{
|
||||
"NotFoundMessage": "Model not found",
|
||||
"GoBackLink": "/models",
|
||||
}))
|
||||
} else if err != nil {
|
||||
// TODO improve this response
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
if model.Status != CONFIRM_PRE_TRAINING {
|
||||
// TODO improve this response
|
||||
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||
}
|
||||
|
||||
cls, err := model_classes.ListClasses(handle.Db, model.Id)
|
||||
if err != nil {
|
||||
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
|
||||
var fid string
|
||||
for i := 0; i < number_of_models; i++ {
|
||||
def_id, err := MakeDefenition(handle.Db, model.Id, accuracy)
|
||||
if err != nil {
|
||||
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
if fid == "" {
|
||||
fid = def_id
|
||||
}
|
||||
|
||||
// TODO change shape of it depends on the type of the image
|
||||
err = MakeLayer(handle.Db, def_id, 1, 1, fmt.Sprintf("%d,%d,1", model.Width, model.Height))
|
||||
if err != nil {
|
||||
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return Error500(err)
|
||||
}
|
||||
err = MakeLayer(handle.Db, def_id, 4, 3, fmt.Sprintf("%d,1", len(cls)))
|
||||
if err != nil {
|
||||
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return Error500(err)
|
||||
}
|
||||
err = MakeLayer(handle.Db, def_id, 5, 2, fmt.Sprintf("%d,1", len(cls)))
|
||||
if err != nil {
|
||||
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
err = ModelDefinitionUpdateStatus(handle, def_id, MODEL_DEFINITION_STATUS_INIT)
|
||||
if err != nil {
|
||||
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return Error500(err)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO start training with id fid
|
||||
|
||||
ModelUpdateStatus(handle, model.Id, TRAINING)
|
||||
Redirect("/models/edit?id=" + model.Id, c.Mode, w, r)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type BaseModel struct {
|
||||
Name string
|
||||
Status int
|
||||
Id string
|
||||
}
|
||||
|
||||
const (
|
||||
FAILED_PREPARING_ZIP_FILE = -2
|
||||
FAILED_PREPARING = -1
|
||||
|
||||
PREPARING = 1
|
||||
CONFIRM_PRE_TRAINING = 2
|
||||
PREPARING_ZIP_FILE = 3
|
||||
)
|
||||
|
||||
var ModelNotFoundError = errors.New("Model not found error")
|
||||
|
||||
func getBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
|
||||
rows, err := db.Query("select name, status, id from models where id=$1;", id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
return nil, ModelNotFoundError
|
||||
}
|
||||
|
||||
base = &BaseModel{}
|
||||
err = rows.Scan(&base.Name, &base.Status, &base.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
49
logic/models/utils/types.go
Normal file
49
logic/models/utils/types.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package models_utils
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type BaseModel struct {
|
||||
Name string
|
||||
Status int
|
||||
Id string
|
||||
|
||||
Width int
|
||||
Height int
|
||||
}
|
||||
|
||||
const (
|
||||
FAILED_TRAINING = -4
|
||||
FAILED_PREPARING_TRAINING = -3
|
||||
FAILED_PREPARING_ZIP_FILE = -2
|
||||
FAILED_PREPARING = -1
|
||||
|
||||
PREPARING = 1
|
||||
CONFIRM_PRE_TRAINING = 2
|
||||
PREPARING_ZIP_FILE = 3
|
||||
TRAINING = 4
|
||||
)
|
||||
|
||||
var ModelNotFoundError = errors.New("Model not found error")
|
||||
|
||||
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
|
||||
rows, err := db.Query("select name, status, id, width, height from models where id=$1;", id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
return nil, ModelNotFoundError
|
||||
}
|
||||
|
||||
base = &BaseModel{}
|
||||
err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package models
|
||||
package models_utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -6,7 +6,8 @@ import (
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
func modelUpdateStatus(handle *Handle, id string, status int) {
|
||||
// TODO make this return and caller handle error
|
||||
func ModelUpdateStatus(handle *Handle, id string, status int) {
|
||||
_, err := handle.Db.Exec("update models set status = $1 where id = $2", status, id)
|
||||
if err != nil {
|
||||
fmt.Println("Failed to update model status")
|
||||
@@ -2,10 +2,12 @@ package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -14,6 +16,21 @@ func CheckEmpty(f url.Values, path string) bool {
|
||||
return !f.Has(path) || f.Get(path) == ""
|
||||
}
|
||||
|
||||
func CheckNumber(f url.Values, path string, number *int) bool {
|
||||
if CheckEmpty(f, path) {
|
||||
fmt.Println("here", path)
|
||||
fmt.Println(f.Get(path))
|
||||
return false
|
||||
}
|
||||
n, err := strconv.Atoi(f.Get(path))
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return false
|
||||
}
|
||||
*number = n
|
||||
return true
|
||||
}
|
||||
|
||||
func CheckId(f url.Values, path string) bool {
|
||||
return !CheckEmpty(f, path) && IsValidUUID(f.Get(path))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user