feat: closes #18 added the code to generate models and other

This commit is contained in:
2023-09-26 20:15:28 +01:00
parent 04de6ad574
commit bad53a13e6
22 changed files with 651 additions and 70 deletions

View File

@@ -12,6 +12,7 @@ import (
"path"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
)
func loadBaseImage(handle *Handle, id string) {
@@ -21,7 +22,7 @@ func loadBaseImage(handle *Handle, id string) {
// TODO better logging
fmt.Println(err)
fmt.Printf("Failed to read image for model with id %s\n", id)
modelUpdateStatus(handle, id, -1)
ModelUpdateStatus(handle, id, -1)
return
}
defer infile.Close()
@@ -31,7 +32,7 @@ func loadBaseImage(handle *Handle, id string) {
// TODO better logging
fmt.Println(err)
fmt.Printf("Failed to load image for model with id %s\n", id)
modelUpdateStatus(handle, id, -1)
ModelUpdateStatus(handle, id, -1)
return
}
if format != "png" {
@@ -67,7 +68,7 @@ func loadBaseImage(handle *Handle, id string) {
fmt.Println("Other so assuming color")
}
modelUpdateStatus(handle, id, -1)
ModelUpdateStatus(handle, id, -1)
return
}
@@ -77,7 +78,7 @@ func loadBaseImage(handle *Handle, id string) {
// TODO better logging
fmt.Println(err)
fmt.Printf("Could not update model\n")
modelUpdateStatus(handle, id, -1)
ModelUpdateStatus(handle, id, -1)
return
}
}

View File

@@ -33,6 +33,17 @@ func ListClasses(db *sql.DB, model_id string) (cls []ModelClass, err error) {
return
}
func ModelHasDataPoints(db *sql.DB, model_id string) (result bool, err error) {
result = false
rows, err := db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 limit 1;", model_id)
if err != nil {
return
}
defer rows.Close()
return rows.Next(), nil
}
var ClassAlreadyExists = errors.New("Class aready exists")
func CreateClass(db *sql.DB, model_id string, name string) (id string, err error) {

View File

@@ -14,6 +14,7 @@ import (
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
)
func InsertIfNotPresent(ss []string, s string) []string {
@@ -31,7 +32,7 @@ func processZipFile(handle *Handle, id string) {
reader, err := zip.OpenReader(path.Join("savedData", id, "base_data.zip"))
if err != nil {
// TODO add msg to error
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
fmt.Printf("Faield to proccess zip file failed to open reader\n")
fmt.Println(err)
return
@@ -51,7 +52,7 @@ func processZipFile(handle *Handle, id string) {
if paths[0] != "training" && paths[0] != "testing" {
fmt.Printf("Invalid file '%s' TODO add msg to response!!!\n", file.Name)
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
return
}
@@ -66,7 +67,7 @@ func processZipFile(handle *Handle, id string) {
fmt.Printf("testing and training are diferent\n")
fmt.Println(testing)
fmt.Println(training)
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
return
}
@@ -79,21 +80,21 @@ func processZipFile(handle *Handle, id string) {
err = os.MkdirAll(dir_path, os.ModePerm)
if err != nil {
fmt.Printf("Failed to create dir %s\n", dir_path)
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
return
}
dir_path = path.Join(base_path, "testing", name)
err = os.MkdirAll(dir_path, os.ModePerm)
if err != nil {
fmt.Printf("Failed to create dir %s\n", dir_path)
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
return
}
id, err := model_classes.CreateClass(handle.Db, id, name)
if err != nil {
fmt.Printf("Failed to create class '%s' on db\n", name)
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
return
}
ids[name] = id
@@ -108,14 +109,14 @@ func processZipFile(handle *Handle, id string) {
f, err := os.Create(file_path)
if err != nil {
fmt.Printf("Could not create file %s\n", file_path)
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
return
}
defer f.Close()
data, err := reader.Open(file.Name)
if err != nil {
fmt.Printf("Could not create file %s\n", file_path)
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
return
}
defer data.Close()
@@ -135,13 +136,13 @@ func processZipFile(handle *Handle, id string) {
if err != nil {
fmt.Printf("Failed to add data point for %s\n", id)
fmt.Println(err)
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
return
}
}
fmt.Printf("Added data to model '%s'!\n", id)
modelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING)
ModelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING)
}
func handleDataUpload(handle *Handle) {
@@ -182,7 +183,7 @@ func handleDataUpload(handle *Handle) {
}
}
_, err = getBaseModel(handle.Db, id)
_, err = GetBaseModel(handle.Db, id)
if err == ModelNotFoundError {
return ErrorCode(nil, http.StatusNotFound, AnyMap{
"NotFoundMessage": "Model not found",
@@ -203,7 +204,7 @@ func handleDataUpload(handle *Handle) {
f.Write(file)
modelUpdateStatus(handle, id, PREPARING_ZIP_FILE)
ModelUpdateStatus(handle, id, PREPARING_ZIP_FILE)
go processZipFile(handle, id)
@@ -230,7 +231,7 @@ func handleDataUpload(handle *Handle) {
id := f.Get("id")
model, err := getBaseModel(handle.Db, id)
model, err := GetBaseModel(handle.Db, id)
if err == ModelNotFoundError {
return ErrorCode(nil, http.StatusNotFound, AnyMap{
"NotFoundMessage": "Model not found",
@@ -260,7 +261,7 @@ func handleDataUpload(handle *Handle) {
return Error500(err)
}
modelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING)
ModelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING)
Redirect("/models/edit?id="+id, c.Mode, w, r)
return nil
})

View File

@@ -7,6 +7,7 @@ import (
"path"
"strconv"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
)
@@ -71,7 +72,7 @@ func handleDelete(handle *Handle) {
})
}
var model BaseModel = BaseModel{}
model := BaseModel{}
model.Id = id
err = rows.Scan(&model.Name, &model.Status)
@@ -80,6 +81,8 @@ func handleDelete(handle *Handle) {
}
switch model.Status {
case FAILED_PREPARING_TRAINING:
fallthrough
case FAILED_PREPARING:
deleteModel(handle, id, w, c, model)
return nil

View File

@@ -5,6 +5,7 @@ import (
"net/http"
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
)
@@ -70,6 +71,11 @@ func handleEdit(handle *Handle) {
case CONFIRM_PRE_TRAINING:
cls, err := model_classes.ListClasses(handle.Db, id)
if err != nil {
return Error500(err)
}
has_data, err := model_classes.ModelHasDataPoints(handle.Db, id)
if err != nil {
return Error500(err)
}
@@ -77,7 +83,10 @@ func handleEdit(handle *Handle) {
LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
"Model": model,
"Classes": cls,
"HasData": has_data,
}))
case TRAINING:
fallthrough
case PREPARING_ZIP_FILE:
LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
"Model": model,

View File

@@ -2,6 +2,7 @@ package models
import (
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
models_train "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
)
@@ -15,5 +16,8 @@ func HandleModels (handle *Handle) {
handleDataUpload(handle)
model_classes.HandleList(handle)
// Train endpoints
models_train.HandleTrainEndpoints(handle)
}

View File

@@ -0,0 +1,13 @@
package models_train
import (
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
)
func HandleTrainEndpoints(handle *Handle) {
handleTrain(handle)
handleRest(handle)
//TODO remove
handleTest(handle)
}

View File

@@ -0,0 +1,58 @@
package models_train
import (
"net/http"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
)
func handleRest(handle *Handle) {
handle.Delete("/models/train/reset", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
if !CheckAuthLevel(1, w, r, c) {
return nil
}
if c.Mode == JSON {
panic("handle JSON /models/train/reset")
}
f, err := MyParseForm(r)
if err != nil {
// TODO improve response
return ErrorCode(nil, 400, c.AddMap(nil))
}
if !CheckId(f, "id") {
// TODO improve response
return ErrorCode(nil, 400, c.AddMap(nil))
}
id := f.Get("id")
model, err := GetBaseModel(handle.Db, id)
if err == ModelNotFoundError {
return ErrorCode(nil, http.StatusNotFound, AnyMap{
"NotFoundMessage": "Model not found",
"GoBackLink": "/models",
})
} else if err != nil {
// TODO improve response
return Error500(err)
}
if model.Status != FAILED_PREPARING_TRAINING {
// TODO improve response
return ErrorCode(nil, 400, c.AddMap(nil))
}
_, err = handle.Db.Exec("delete from model_definition where model_id=$1", model.Id)
if err != nil {
// TODO improve response
return Error500(err)
}
ModelUpdateStatus(handle, model.Id, CONFIRM_PRE_TRAINING)
Redirect("/models/edit?id=" + model.Id, c.Mode, w, r)
return nil
})
}

View File

@@ -0,0 +1,139 @@
package models_train
import (
"fmt"
"net/http"
"os"
"path"
"strings"
"text/template"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
)
/*import (
tf "github.com/galeone/tensorflow/tensorflow/go"
tg "github.com/galeone/tfgo"
"github.com/galeone/tfgo/image"
"github.com/galeone/tfgo/image/filter"
"github.com/galeone/tfgo/image/padding"
)*/
func getDir() string {
dir, err := os.Getwd()
if err != nil {
panic(err)
}
return dir
}
func shapeToSize(shape string) string {
split := strings.Split(shape, ",")
return strings.Join(split[:len(split) - 1], ",")
}
func handleTest(handle *Handle) {
handle.Post("/models/train/test", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
if !CheckAuthLevel(1, w, r, c) {
return nil
}
id, err := GetIdFromUrl(r, "id")
if err != nil {
return ErrorCode(err, 400, c.AddMap(nil))
}
rows, err := handle.Db.Query("select mc.name, mdp.file_path from model_classes as mc join model_data_point as mdp on mdp.class_id = mc.id where mdp.model_mode = 1 and mc.model_id = $1 limit 1;", id)
if err != nil {
return Error500(err)
}
defer rows.Close()
if !rows.Next() {
return Error500(err)
}
var name string
var file_path string
err = rows.Scan(&name, &file_path)
if err != nil {
return Error500(err)
}
file_path = strings.Replace(file_path, "file://", "", 1)
img_path := path.Join("savedData", id, "data", "training", name, file_path)
fmt.Printf("%s\n", img_path)
definitions, err := handle.Db.Query("select id from model_definition where model_id=$1 and status=2 limit 1;", id)
if err != nil {
return Error500(err)
}
defer definitions.Close()
if !definitions.Next() {
fmt.Println("Did not find definition")
return Error500(nil)
}
var definition_id string
if err = definitions.Scan(&definition_id); err != nil {
return Error500(err)
}
layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
if err != nil {
return Error500(err)
}
defer layers.Close()
type layerrow struct {
LayerType int
Shape string
}
got := []layerrow{}
for layers.Next() {
var row = layerrow{}
if err = layers.Scan(&row.LayerType, &row.Shape); err != nil {
return Error500(err)
}
row.Shape = shapeToSize(row.Shape)
got = append(got, row)
}
// Generate folder
err = os.MkdirAll(path.Join("/tmp", id), os.ModePerm)
if err != nil {
return Error500(err)
}
f, err := os.Create(path.Join("/tmp", id, "run.py"))
if err != nil {
return Error500(err)
}
defer f.Close()
fmt.Printf("Using path: %s\n", path.Join("/tmp", id, "run.py"))
tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py")
if err != nil {
return Error500(err)
}
if err = tmpl.Execute(f, AnyMap{
"Layers": got,
"Size": got[0].Shape,
"DataDir": path.Join(getDir(), "savedData", id, "data", "training"),
}); err != nil {
return Error500(err)
}
w.Write([]byte("Done"))
return nil
})
}

158
logic/models/train/train.go Normal file
View File

@@ -0,0 +1,158 @@
package models_train
import (
"database/sql"
"errors"
"fmt"
"net/http"
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
)
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
id = ""
_, err = db.Exec("insert into model_definition (model_id, target_accuracy) values ($1, $2);", model_id, target_accuracy)
if err != nil {
return
}
rows, err := db.Query("select id from model_definition where model_id=$1 order by created_on DESC;", model_id)
if err != nil {
return
}
defer rows.Close()
if !rows.Next() {
return id, errors.New("Something wrong!")
}
err = rows.Scan(&id)
if err != nil {
return
}
return
}
type ModelDefinitionStatus int
const (
MODEL_DEFINITION_STATUS_PRE_INIT ModelDefinitionStatus = 1
MODEL_DEFINITION_STATUS_INIT = 2
MODEL_DEFINITION_STATUS_TRAINING = 3
MODEL_DEFINITION_STATUS_TRANIED = 4
MODEL_DEFINITION_STATUS_READY = 5
)
func ModelDefinitionUpdateStatus(handle *Handle, id string, status ModelDefinitionStatus) (err error) {
_, err = handle.Db.Exec("update model_definition set status = $1 where id = $2", status, id)
return
}
func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type int, shape string) (err error) {
_, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape) values ($1, $2, $3, $4)", def_id, layer_order, layer_type, shape)
return
}
func handleTrain(handle *Handle) {
handle.Post("/models/train", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
if !CheckAuthLevel(1, w, r, c) {
return nil
}
if c.Mode == JSON {
panic("TODO /models/train JSON")
}
r.ParseForm()
f := r.Form
number_of_models := 0
accuracy := 0
if !CheckId(f, "id") || CheckEmpty(f, "model_type") || !CheckNumber(f, "number_of_models", &number_of_models) || !CheckNumber(f, "accuracy", &accuracy) {
fmt.Println(
!CheckId(f, "id"), CheckEmpty(f, "model_type"), !CheckNumber(f, "number_of_models", &number_of_models), !CheckNumber(f, "accuracy", &accuracy),
)
// TODO improve this response
return ErrorCode(nil, 400, c.AddMap(nil))
}
id := f.Get("id")
model_type := f.Get("model_type")
// Its not used rn
_ = model_type
model, err := GetBaseModel(handle.Db, id)
if err == ModelNotFoundError {
return ErrorCode(nil, http.StatusNotFound, c.AddMap(AnyMap{
"NotFoundMessage": "Model not found",
"GoBackLink": "/models",
}))
} else if err != nil {
// TODO improve this response
return Error500(err)
}
if model.Status != CONFIRM_PRE_TRAINING {
// TODO improve this response
return ErrorCode(nil, 400, c.AddMap(nil))
}
cls, err := model_classes.ListClasses(handle.Db, model.Id)
if err != nil {
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
// TODO improve this response
return Error500(err)
}
var fid string
for i := 0; i < number_of_models; i++ {
def_id, err := MakeDefenition(handle.Db, model.Id, accuracy)
if err != nil {
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
// TODO improve this response
return Error500(err)
}
if fid == "" {
fid = def_id
}
// TODO change shape of it depends on the type of the image
err = MakeLayer(handle.Db, def_id, 1, 1, fmt.Sprintf("%d,%d,1", model.Width, model.Height))
if err != nil {
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
// TODO improve this response
return Error500(err)
}
err = MakeLayer(handle.Db, def_id, 4, 3, fmt.Sprintf("%d,1", len(cls)))
if err != nil {
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
// TODO improve this response
return Error500(err)
}
err = MakeLayer(handle.Db, def_id, 5, 2, fmt.Sprintf("%d,1", len(cls)))
if err != nil {
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
// TODO improve this response
return Error500(err)
}
err = ModelDefinitionUpdateStatus(handle, def_id, MODEL_DEFINITION_STATUS_INIT)
if err != nil {
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
// TODO improve this response
return Error500(err)
}
}
// TODO start training with id fid
ModelUpdateStatus(handle, model.Id, TRAINING)
Redirect("/models/edit?id=" + model.Id, c.Mode, w, r)
return nil
})
}

View File

@@ -1,43 +0,0 @@
package models
import (
"database/sql"
"errors"
)
type BaseModel struct {
Name string
Status int
Id string
}
const (
FAILED_PREPARING_ZIP_FILE = -2
FAILED_PREPARING = -1
PREPARING = 1
CONFIRM_PRE_TRAINING = 2
PREPARING_ZIP_FILE = 3
)
var ModelNotFoundError = errors.New("Model not found error")
func getBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
rows, err := db.Query("select name, status, id from models where id=$1;", id)
if err != nil {
return
}
defer rows.Close()
if !rows.Next() {
return nil, ModelNotFoundError
}
base = &BaseModel{}
err = rows.Scan(&base.Name, &base.Status, &base.Id)
if err != nil {
return nil, err
}
return
}

View File

@@ -0,0 +1,49 @@
package models_utils
import (
"database/sql"
"errors"
)
type BaseModel struct {
Name string
Status int
Id string
Width int
Height int
}
const (
FAILED_TRAINING = -4
FAILED_PREPARING_TRAINING = -3
FAILED_PREPARING_ZIP_FILE = -2
FAILED_PREPARING = -1
PREPARING = 1
CONFIRM_PRE_TRAINING = 2
PREPARING_ZIP_FILE = 3
TRAINING = 4
)
var ModelNotFoundError = errors.New("Model not found error")
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
rows, err := db.Query("select name, status, id, width, height from models where id=$1;", id)
if err != nil {
return
}
defer rows.Close()
if !rows.Next() {
return nil, ModelNotFoundError
}
base = &BaseModel{}
err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height)
if err != nil {
return nil, err
}
return
}

View File

@@ -1,4 +1,4 @@
package models
package models_utils
import (
"fmt"
@@ -6,7 +6,8 @@ import (
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
)
func modelUpdateStatus(handle *Handle, id string, status int) {
// TODO make this return and caller handle error
func ModelUpdateStatus(handle *Handle, id string, status int) {
_, err := handle.Db.Exec("update models set status = $1 where id = $2", status, id)
if err != nil {
fmt.Println("Failed to update model status")

View File

@@ -2,10 +2,12 @@ package utils
import (
"errors"
"fmt"
"io"
"mime"
"net/http"
"net/url"
"strconv"
"github.com/google/uuid"
)
@@ -14,6 +16,21 @@ func CheckEmpty(f url.Values, path string) bool {
return !f.Has(path) || f.Get(path) == ""
}
func CheckNumber(f url.Values, path string, number *int) bool {
if CheckEmpty(f, path) {
fmt.Println("here", path)
fmt.Println(f.Get(path))
return false
}
n, err := strconv.Atoi(f.Get(path))
if err != nil {
fmt.Println(err)
return false
}
*number = n
return true
}
func CheckId(f url.Values, path string) bool {
return !CheckEmpty(f, path) && IsValidUUID(f.Get(path))
}