feat: closes #18 added the code to generate models and other
This commit is contained in:
parent
04de6ad574
commit
bad53a13e6
@ -7,7 +7,7 @@ tmp_dir = "tmp"
|
||||
bin = "./tmp/main"
|
||||
cmd = "go build -o ./tmp/main ."
|
||||
delay = 0
|
||||
exclude_dir = ["assets", "tmp", "vendor", "testData", "savedData"]
|
||||
exclude_dir = ["assets", "tmp", "vendor", "testData", "savedData", "tensorflow"]
|
||||
exclude_file = []
|
||||
exclude_regex = ["_test.go"]
|
||||
exclude_unchanged = false
|
||||
|
3
go.mod
3
go.mod
@ -3,7 +3,10 @@ module git.andr3h3nriqu3s.com/andr3/fyp
|
||||
go 1.20
|
||||
|
||||
require (
|
||||
github.com/galeone/tensorflow/tensorflow/go v0.0.0-20221023090153-6b7fa0680c3e // indirect
|
||||
github.com/galeone/tfgo v0.0.0-20230715013254-16113111dc99 // indirect
|
||||
github.com/google/uuid v1.3.1 // indirect
|
||||
github.com/lib/pq v1.10.9 // indirect
|
||||
golang.org/x/crypto v0.13.0 // indirect
|
||||
google.golang.org/protobuf v1.28.1 // indirect
|
||||
)
|
||||
|
10
go.sum
10
go.sum
@ -1,6 +1,16 @@
|
||||
github.com/galeone/tensorflow/tensorflow/go v0.0.0-20221023090153-6b7fa0680c3e h1:9+2AEFZymTi25FIIcDwuzcOPH04z9+fV6XeLiGORPDI=
|
||||
github.com/galeone/tensorflow/tensorflow/go v0.0.0-20221023090153-6b7fa0680c3e/go.mod h1:TelZuq26kz2jysARBwOrTv16629hyUsHmIoj54QqyFo=
|
||||
github.com/galeone/tfgo v0.0.0-20230715013254-16113111dc99 h1:8Bt1P/zy1gb37L4n8CGgp1qmFwBV5729kxVfj0sqhJk=
|
||||
github.com/galeone/tfgo v0.0.0-20230715013254-16113111dc99/go.mod h1:3YgYBeIX42t83uP27Bd4bSMxTnQhSbxl0pYSkCDB1tc=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
|
||||
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck=
|
||||
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
|
||||
google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"path"
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
)
|
||||
|
||||
func loadBaseImage(handle *Handle, id string) {
|
||||
@ -21,7 +22,7 @@ func loadBaseImage(handle *Handle, id string) {
|
||||
// TODO better logging
|
||||
fmt.Println(err)
|
||||
fmt.Printf("Failed to read image for model with id %s\n", id)
|
||||
modelUpdateStatus(handle, id, -1)
|
||||
ModelUpdateStatus(handle, id, -1)
|
||||
return
|
||||
}
|
||||
defer infile.Close()
|
||||
@ -31,7 +32,7 @@ func loadBaseImage(handle *Handle, id string) {
|
||||
// TODO better logging
|
||||
fmt.Println(err)
|
||||
fmt.Printf("Failed to load image for model with id %s\n", id)
|
||||
modelUpdateStatus(handle, id, -1)
|
||||
ModelUpdateStatus(handle, id, -1)
|
||||
return
|
||||
}
|
||||
if format != "png" {
|
||||
@ -67,7 +68,7 @@ func loadBaseImage(handle *Handle, id string) {
|
||||
fmt.Println("Other so assuming color")
|
||||
}
|
||||
|
||||
modelUpdateStatus(handle, id, -1)
|
||||
ModelUpdateStatus(handle, id, -1)
|
||||
return
|
||||
}
|
||||
|
||||
@ -77,7 +78,7 @@ func loadBaseImage(handle *Handle, id string) {
|
||||
// TODO better logging
|
||||
fmt.Println(err)
|
||||
fmt.Printf("Could not update model\n")
|
||||
modelUpdateStatus(handle, id, -1)
|
||||
ModelUpdateStatus(handle, id, -1)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -33,6 +33,17 @@ func ListClasses(db *sql.DB, model_id string) (cls []ModelClass, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func ModelHasDataPoints(db *sql.DB, model_id string) (result bool, err error) {
|
||||
result = false
|
||||
rows, err := db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 limit 1;", model_id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return rows.Next(), nil
|
||||
}
|
||||
|
||||
var ClassAlreadyExists = errors.New("Class aready exists")
|
||||
|
||||
func CreateClass(db *sql.DB, model_id string, name string) (id string, err error) {
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
|
||||
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
)
|
||||
|
||||
func InsertIfNotPresent(ss []string, s string) []string {
|
||||
@ -31,7 +32,7 @@ func processZipFile(handle *Handle, id string) {
|
||||
reader, err := zip.OpenReader(path.Join("savedData", id, "base_data.zip"))
|
||||
if err != nil {
|
||||
// TODO add msg to error
|
||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
fmt.Printf("Faield to proccess zip file failed to open reader\n")
|
||||
fmt.Println(err)
|
||||
return
|
||||
@ -51,7 +52,7 @@ func processZipFile(handle *Handle, id string) {
|
||||
|
||||
if paths[0] != "training" && paths[0] != "testing" {
|
||||
fmt.Printf("Invalid file '%s' TODO add msg to response!!!\n", file.Name)
|
||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
return
|
||||
}
|
||||
|
||||
@ -66,7 +67,7 @@ func processZipFile(handle *Handle, id string) {
|
||||
fmt.Printf("testing and training are diferent\n")
|
||||
fmt.Println(testing)
|
||||
fmt.Println(training)
|
||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
return
|
||||
}
|
||||
|
||||
@ -79,21 +80,21 @@ func processZipFile(handle *Handle, id string) {
|
||||
err = os.MkdirAll(dir_path, os.ModePerm)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to create dir %s\n", dir_path)
|
||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
return
|
||||
}
|
||||
dir_path = path.Join(base_path, "testing", name)
|
||||
err = os.MkdirAll(dir_path, os.ModePerm)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to create dir %s\n", dir_path)
|
||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
return
|
||||
}
|
||||
|
||||
id, err := model_classes.CreateClass(handle.Db, id, name)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to create class '%s' on db\n", name)
|
||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
return
|
||||
}
|
||||
ids[name] = id
|
||||
@ -108,14 +109,14 @@ func processZipFile(handle *Handle, id string) {
|
||||
f, err := os.Create(file_path)
|
||||
if err != nil {
|
||||
fmt.Printf("Could not create file %s\n", file_path)
|
||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
data, err := reader.Open(file.Name)
|
||||
if err != nil {
|
||||
fmt.Printf("Could not create file %s\n", file_path)
|
||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
return
|
||||
}
|
||||
defer data.Close()
|
||||
@ -135,13 +136,13 @@ func processZipFile(handle *Handle, id string) {
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to add data point for %s\n", id)
|
||||
fmt.Println(err)
|
||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("Added data to model '%s'!\n", id)
|
||||
modelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING)
|
||||
ModelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING)
|
||||
}
|
||||
|
||||
func handleDataUpload(handle *Handle) {
|
||||
@ -182,7 +183,7 @@ func handleDataUpload(handle *Handle) {
|
||||
}
|
||||
}
|
||||
|
||||
_, err = getBaseModel(handle.Db, id)
|
||||
_, err = GetBaseModel(handle.Db, id)
|
||||
if err == ModelNotFoundError {
|
||||
return ErrorCode(nil, http.StatusNotFound, AnyMap{
|
||||
"NotFoundMessage": "Model not found",
|
||||
@ -203,7 +204,7 @@ func handleDataUpload(handle *Handle) {
|
||||
|
||||
f.Write(file)
|
||||
|
||||
modelUpdateStatus(handle, id, PREPARING_ZIP_FILE)
|
||||
ModelUpdateStatus(handle, id, PREPARING_ZIP_FILE)
|
||||
|
||||
go processZipFile(handle, id)
|
||||
|
||||
@ -230,7 +231,7 @@ func handleDataUpload(handle *Handle) {
|
||||
|
||||
id := f.Get("id")
|
||||
|
||||
model, err := getBaseModel(handle.Db, id)
|
||||
model, err := GetBaseModel(handle.Db, id)
|
||||
if err == ModelNotFoundError {
|
||||
return ErrorCode(nil, http.StatusNotFound, AnyMap{
|
||||
"NotFoundMessage": "Model not found",
|
||||
@ -260,7 +261,7 @@ func handleDataUpload(handle *Handle) {
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
modelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING)
|
||||
ModelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING)
|
||||
Redirect("/models/edit?id="+id, c.Mode, w, r)
|
||||
return nil
|
||||
})
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"path"
|
||||
"strconv"
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
@ -71,7 +72,7 @@ func handleDelete(handle *Handle) {
|
||||
})
|
||||
}
|
||||
|
||||
var model BaseModel = BaseModel{}
|
||||
model := BaseModel{}
|
||||
model.Id = id
|
||||
|
||||
err = rows.Scan(&model.Name, &model.Status)
|
||||
@ -80,6 +81,8 @@ func handleDelete(handle *Handle) {
|
||||
}
|
||||
|
||||
switch model.Status {
|
||||
case FAILED_PREPARING_TRAINING:
|
||||
fallthrough
|
||||
case FAILED_PREPARING:
|
||||
deleteModel(handle, id, w, c, model)
|
||||
return nil
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
@ -70,6 +71,11 @@ func handleEdit(handle *Handle) {
|
||||
case CONFIRM_PRE_TRAINING:
|
||||
|
||||
cls, err := model_classes.ListClasses(handle.Db, id)
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
has_data, err := model_classes.ModelHasDataPoints(handle.Db, id)
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
@ -77,7 +83,10 @@ func handleEdit(handle *Handle) {
|
||||
LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
|
||||
"Model": model,
|
||||
"Classes": cls,
|
||||
"HasData": has_data,
|
||||
}))
|
||||
case TRAINING:
|
||||
fallthrough
|
||||
case PREPARING_ZIP_FILE:
|
||||
LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
|
||||
"Model": model,
|
||||
|
@ -2,6 +2,7 @@ package models
|
||||
|
||||
import (
|
||||
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
|
||||
models_train "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
@ -15,5 +16,8 @@ func HandleModels (handle *Handle) {
|
||||
handleDataUpload(handle)
|
||||
|
||||
model_classes.HandleList(handle)
|
||||
|
||||
// Train endpoints
|
||||
models_train.HandleTrainEndpoints(handle)
|
||||
}
|
||||
|
||||
|
13
logic/models/train/main.go
Normal file
13
logic/models/train/main.go
Normal file
@ -0,0 +1,13 @@
|
||||
package models_train
|
||||
|
||||
import (
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
func HandleTrainEndpoints(handle *Handle) {
|
||||
handleTrain(handle)
|
||||
handleRest(handle)
|
||||
|
||||
//TODO remove
|
||||
handleTest(handle)
|
||||
}
|
58
logic/models/train/reset.go
Normal file
58
logic/models/train/reset.go
Normal file
@ -0,0 +1,58 @@
|
||||
package models_train
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
func handleRest(handle *Handle) {
|
||||
handle.Delete("/models/train/reset", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||||
if !CheckAuthLevel(1, w, r, c) {
|
||||
return nil
|
||||
}
|
||||
if c.Mode == JSON {
|
||||
panic("handle JSON /models/train/reset")
|
||||
}
|
||||
|
||||
f, err := MyParseForm(r)
|
||||
if err != nil {
|
||||
// TODO improve response
|
||||
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||
}
|
||||
|
||||
if !CheckId(f, "id") {
|
||||
// TODO improve response
|
||||
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||
}
|
||||
|
||||
id := f.Get("id")
|
||||
|
||||
model, err := GetBaseModel(handle.Db, id)
|
||||
if err == ModelNotFoundError {
|
||||
return ErrorCode(nil, http.StatusNotFound, AnyMap{
|
||||
"NotFoundMessage": "Model not found",
|
||||
"GoBackLink": "/models",
|
||||
})
|
||||
} else if err != nil {
|
||||
// TODO improve response
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
if model.Status != FAILED_PREPARING_TRAINING {
|
||||
// TODO improve response
|
||||
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||
}
|
||||
|
||||
_, err = handle.Db.Exec("delete from model_definition where model_id=$1", model.Id)
|
||||
if err != nil {
|
||||
// TODO improve response
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
ModelUpdateStatus(handle, model.Id, CONFIRM_PRE_TRAINING)
|
||||
Redirect("/models/edit?id=" + model.Id, c.Mode, w, r)
|
||||
return nil
|
||||
})
|
||||
}
|
139
logic/models/train/tensorflow-test.go
Normal file
139
logic/models/train/tensorflow-test.go
Normal file
@ -0,0 +1,139 @@
|
||||
package models_train
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
/*import (
|
||||
tf "github.com/galeone/tensorflow/tensorflow/go"
|
||||
tg "github.com/galeone/tfgo"
|
||||
"github.com/galeone/tfgo/image"
|
||||
"github.com/galeone/tfgo/image/filter"
|
||||
"github.com/galeone/tfgo/image/padding"
|
||||
)*/
|
||||
|
||||
func getDir() string {
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return dir
|
||||
}
|
||||
|
||||
func shapeToSize(shape string) string {
|
||||
split := strings.Split(shape, ",")
|
||||
return strings.Join(split[:len(split) - 1], ",")
|
||||
}
|
||||
|
||||
func handleTest(handle *Handle) {
|
||||
handle.Post("/models/train/test", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||||
if !CheckAuthLevel(1, w, r, c) {
|
||||
return nil
|
||||
}
|
||||
|
||||
id, err := GetIdFromUrl(r, "id")
|
||||
if err != nil {
|
||||
return ErrorCode(err, 400, c.AddMap(nil))
|
||||
}
|
||||
|
||||
rows, err := handle.Db.Query("select mc.name, mdp.file_path from model_classes as mc join model_data_point as mdp on mdp.class_id = mc.id where mdp.model_mode = 1 and mc.model_id = $1 limit 1;", id)
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
var name string
|
||||
var file_path string
|
||||
err = rows.Scan(&name, &file_path)
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
file_path = strings.Replace(file_path, "file://", "", 1)
|
||||
|
||||
img_path := path.Join("savedData", id, "data", "training", name, file_path)
|
||||
|
||||
fmt.Printf("%s\n", img_path)
|
||||
|
||||
definitions, err := handle.Db.Query("select id from model_definition where model_id=$1 and status=2 limit 1;", id)
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
defer definitions.Close()
|
||||
|
||||
if !definitions.Next() {
|
||||
fmt.Println("Did not find definition")
|
||||
return Error500(nil)
|
||||
}
|
||||
|
||||
var definition_id string
|
||||
|
||||
if err = definitions.Scan(&definition_id); err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
defer layers.Close()
|
||||
|
||||
type layerrow struct {
|
||||
LayerType int
|
||||
Shape string
|
||||
}
|
||||
|
||||
got := []layerrow{}
|
||||
|
||||
for layers.Next() {
|
||||
var row = layerrow{}
|
||||
if err = layers.Scan(&row.LayerType, &row.Shape); err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
row.Shape = shapeToSize(row.Shape)
|
||||
got = append(got, row)
|
||||
}
|
||||
|
||||
// Generate folder
|
||||
|
||||
err = os.MkdirAll(path.Join("/tmp", id), os.ModePerm)
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
f, err := os.Create(path.Join("/tmp", id, "run.py"))
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
fmt.Printf("Using path: %s\n", path.Join("/tmp", id, "run.py"))
|
||||
|
||||
tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py")
|
||||
if err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
if err = tmpl.Execute(f, AnyMap{
|
||||
"Layers": got,
|
||||
"Size": got[0].Shape,
|
||||
"DataDir": path.Join(getDir(), "savedData", id, "data", "training"),
|
||||
}); err != nil {
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
w.Write([]byte("Done"))
|
||||
return nil
|
||||
})
|
||||
}
|
158
logic/models/train/train.go
Normal file
158
logic/models/train/train.go
Normal file
@ -0,0 +1,158 @@
|
||||
package models_train
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
|
||||
id = ""
|
||||
_, err = db.Exec("insert into model_definition (model_id, target_accuracy) values ($1, $2);", model_id, target_accuracy)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := db.Query("select id from model_definition where model_id=$1 order by created_on DESC;", model_id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
return id, errors.New("Something wrong!")
|
||||
}
|
||||
|
||||
err = rows.Scan(&id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type ModelDefinitionStatus int
|
||||
|
||||
const (
|
||||
MODEL_DEFINITION_STATUS_PRE_INIT ModelDefinitionStatus = 1
|
||||
MODEL_DEFINITION_STATUS_INIT = 2
|
||||
MODEL_DEFINITION_STATUS_TRAINING = 3
|
||||
MODEL_DEFINITION_STATUS_TRANIED = 4
|
||||
MODEL_DEFINITION_STATUS_READY = 5
|
||||
)
|
||||
|
||||
func ModelDefinitionUpdateStatus(handle *Handle, id string, status ModelDefinitionStatus) (err error) {
|
||||
_, err = handle.Db.Exec("update model_definition set status = $1 where id = $2", status, id)
|
||||
return
|
||||
}
|
||||
|
||||
func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type int, shape string) (err error) {
|
||||
_, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape) values ($1, $2, $3, $4)", def_id, layer_order, layer_type, shape)
|
||||
return
|
||||
}
|
||||
|
||||
func handleTrain(handle *Handle) {
|
||||
handle.Post("/models/train", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||||
if !CheckAuthLevel(1, w, r, c) {
|
||||
return nil
|
||||
}
|
||||
if c.Mode == JSON {
|
||||
panic("TODO /models/train JSON")
|
||||
}
|
||||
|
||||
r.ParseForm()
|
||||
f := r.Form
|
||||
|
||||
number_of_models := 0
|
||||
accuracy := 0
|
||||
|
||||
if !CheckId(f, "id") || CheckEmpty(f, "model_type") || !CheckNumber(f, "number_of_models", &number_of_models) || !CheckNumber(f, "accuracy", &accuracy) {
|
||||
fmt.Println(
|
||||
!CheckId(f, "id"), CheckEmpty(f, "model_type"), !CheckNumber(f, "number_of_models", &number_of_models), !CheckNumber(f, "accuracy", &accuracy),
|
||||
)
|
||||
// TODO improve this response
|
||||
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||
}
|
||||
|
||||
id := f.Get("id")
|
||||
model_type := f.Get("model_type")
|
||||
// Its not used rn
|
||||
_ = model_type
|
||||
|
||||
model, err := GetBaseModel(handle.Db, id)
|
||||
if err == ModelNotFoundError {
|
||||
return ErrorCode(nil, http.StatusNotFound, c.AddMap(AnyMap{
|
||||
"NotFoundMessage": "Model not found",
|
||||
"GoBackLink": "/models",
|
||||
}))
|
||||
} else if err != nil {
|
||||
// TODO improve this response
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
if model.Status != CONFIRM_PRE_TRAINING {
|
||||
// TODO improve this response
|
||||
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||
}
|
||||
|
||||
cls, err := model_classes.ListClasses(handle.Db, model.Id)
|
||||
if err != nil {
|
||||
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
|
||||
var fid string
|
||||
for i := 0; i < number_of_models; i++ {
|
||||
def_id, err := MakeDefenition(handle.Db, model.Id, accuracy)
|
||||
if err != nil {
|
||||
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
if fid == "" {
|
||||
fid = def_id
|
||||
}
|
||||
|
||||
// TODO change shape of it depends on the type of the image
|
||||
err = MakeLayer(handle.Db, def_id, 1, 1, fmt.Sprintf("%d,%d,1", model.Width, model.Height))
|
||||
if err != nil {
|
||||
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return Error500(err)
|
||||
}
|
||||
err = MakeLayer(handle.Db, def_id, 4, 3, fmt.Sprintf("%d,1", len(cls)))
|
||||
if err != nil {
|
||||
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return Error500(err)
|
||||
}
|
||||
err = MakeLayer(handle.Db, def_id, 5, 2, fmt.Sprintf("%d,1", len(cls)))
|
||||
if err != nil {
|
||||
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return Error500(err)
|
||||
}
|
||||
|
||||
err = ModelDefinitionUpdateStatus(handle, def_id, MODEL_DEFINITION_STATUS_INIT)
|
||||
if err != nil {
|
||||
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||
// TODO improve this response
|
||||
return Error500(err)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO start training with id fid
|
||||
|
||||
ModelUpdateStatus(handle, model.Id, TRAINING)
|
||||
Redirect("/models/edit?id=" + model.Id, c.Mode, w, r)
|
||||
return nil
|
||||
})
|
||||
}
|
@ -1,43 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type BaseModel struct {
|
||||
Name string
|
||||
Status int
|
||||
Id string
|
||||
}
|
||||
|
||||
const (
|
||||
FAILED_PREPARING_ZIP_FILE = -2
|
||||
FAILED_PREPARING = -1
|
||||
|
||||
PREPARING = 1
|
||||
CONFIRM_PRE_TRAINING = 2
|
||||
PREPARING_ZIP_FILE = 3
|
||||
)
|
||||
|
||||
var ModelNotFoundError = errors.New("Model not found error")
|
||||
|
||||
func getBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
|
||||
rows, err := db.Query("select name, status, id from models where id=$1;", id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
return nil, ModelNotFoundError
|
||||
}
|
||||
|
||||
base = &BaseModel{}
|
||||
err = rows.Scan(&base.Name, &base.Status, &base.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return
|
||||
}
|
49
logic/models/utils/types.go
Normal file
49
logic/models/utils/types.go
Normal file
@ -0,0 +1,49 @@
|
||||
package models_utils
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type BaseModel struct {
|
||||
Name string
|
||||
Status int
|
||||
Id string
|
||||
|
||||
Width int
|
||||
Height int
|
||||
}
|
||||
|
||||
const (
|
||||
FAILED_TRAINING = -4
|
||||
FAILED_PREPARING_TRAINING = -3
|
||||
FAILED_PREPARING_ZIP_FILE = -2
|
||||
FAILED_PREPARING = -1
|
||||
|
||||
PREPARING = 1
|
||||
CONFIRM_PRE_TRAINING = 2
|
||||
PREPARING_ZIP_FILE = 3
|
||||
TRAINING = 4
|
||||
)
|
||||
|
||||
var ModelNotFoundError = errors.New("Model not found error")
|
||||
|
||||
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
|
||||
rows, err := db.Query("select name, status, id, width, height from models where id=$1;", id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
return nil, ModelNotFoundError
|
||||
}
|
||||
|
||||
base = &BaseModel{}
|
||||
err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package models
|
||||
package models_utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@ -6,7 +6,8 @@ import (
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
func modelUpdateStatus(handle *Handle, id string, status int) {
|
||||
// TODO make this return and caller handle error
|
||||
func ModelUpdateStatus(handle *Handle, id string, status int) {
|
||||
_, err := handle.Db.Exec("update models set status = $1 where id = $2", status, id)
|
||||
if err != nil {
|
||||
fmt.Println("Failed to update model status")
|
@ -2,10 +2,12 @@ package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@ -14,6 +16,21 @@ func CheckEmpty(f url.Values, path string) bool {
|
||||
return !f.Has(path) || f.Get(path) == ""
|
||||
}
|
||||
|
||||
func CheckNumber(f url.Values, path string, number *int) bool {
|
||||
if CheckEmpty(f, path) {
|
||||
fmt.Println("here", path)
|
||||
fmt.Println(f.Get(path))
|
||||
return false
|
||||
}
|
||||
n, err := strconv.Atoi(f.Get(path))
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return false
|
||||
}
|
||||
*number = n
|
||||
return true
|
||||
}
|
||||
|
||||
func CheckId(f url.Values, path string) bool {
|
||||
return !CheckEmpty(f, path) && IsValidUUID(f.Get(path))
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
-- drop table if exists model_data_point;
|
||||
-- drop table if exists model_defenitions;
|
||||
-- drop table if exists model_definition_layer;
|
||||
-- drop table if exists model_definition;
|
||||
-- drop table if exists models;
|
||||
create table if not exists models (
|
||||
id uuid primary key default gen_random_uuid(),
|
||||
@ -32,3 +33,37 @@ create table if not exists model_data_point (
|
||||
-- 2 testing
|
||||
model_mode integer default 1
|
||||
);
|
||||
|
||||
-- drop table if exists model_definition;
|
||||
-- drop table if exists model_definition;
|
||||
create table if not exists model_definition (
|
||||
id uuid primary key default gen_random_uuid(),
|
||||
model_id uuid references models (id) on delete cascade,
|
||||
accuracy integer default 0,
|
||||
target_accuracy integer not null,
|
||||
epoch integer default 0,
|
||||
-- TODO add max epoch
|
||||
-- 1: Pre Init
|
||||
-- 2: Init
|
||||
-- 3: Training
|
||||
-- 4: Tranied
|
||||
-- 5: Ready
|
||||
status integer default 1,
|
||||
created_on timestamp default current_timestamp
|
||||
);
|
||||
|
||||
-- drop table if exists model_definition_layer;
|
||||
create table if not exists model_definition_layer (
|
||||
id uuid primary key default gen_random_uuid(),
|
||||
def_id uuid references model_definition (id) on delete cascade,
|
||||
layer_order integer not null,
|
||||
-- 1: input
|
||||
-- 2: dense
|
||||
-- 3: flatten
|
||||
-- TODO add conv
|
||||
layer_type integer not null,
|
||||
-- ei 28,28,1
|
||||
-- a 28x28 grayscale image
|
||||
shape text not null
|
||||
);
|
||||
|
||||
|
@ -252,8 +252,39 @@
|
||||
{{ end }}
|
||||
|
||||
{{ define "train-model-card" }}
|
||||
<form hx-delete="/models/train" hx-headers='{"REQUEST-TYPE": "html"}' hx-swap="outerHTML" {{ if .Error }} class="submitted" {{end}} >
|
||||
tain menu
|
||||
<form hx-post="/models/train" hx-headers='{"REQUEST-TYPE": "html"}' hx-swap="outerHTML" {{ if .Error }} class="submitted" {{end}} >
|
||||
{{ if .HasData }}
|
||||
{{/* TODO expading mode */}}
|
||||
<input type="hidden" value="{{ .Model.Id }}" name="id" />
|
||||
<fieldset>
|
||||
<legend>
|
||||
Model Type
|
||||
</legend>
|
||||
<div class="input-radial">
|
||||
<input id="model_type_simple" value="simple" name="model_type" type="radio" checked />
|
||||
<label for="model_type_simple">Simple</label>
|
||||
</div>
|
||||
</fieldset>
|
||||
{{/* TODO allow more models to be created */}}
|
||||
<fieldset>
|
||||
<label for="number_of_models">Number of Models</label>
|
||||
<input id="number_of_models" type="number" name="number_of_models" value="1" />
|
||||
</fieldset>
|
||||
{{/* TODO to Change the acc */}}
|
||||
<fieldset>
|
||||
<label for="accuracy">Target accuracy</label>
|
||||
<input id="accuracy" type="number" name="accuracy" value="95" />
|
||||
</fieldset>
|
||||
{{/* TODO allow to chose the base of the model */}}
|
||||
{{/* TODO allow to change the shape of the model */}}
|
||||
<button>
|
||||
Train
|
||||
</button>
|
||||
{{ else }}
|
||||
<h2>
|
||||
Please provide data to the model first
|
||||
</h2>
|
||||
{{ end }}
|
||||
</form>
|
||||
{{ end }}
|
||||
|
||||
@ -313,6 +344,29 @@
|
||||
{{/* TODO improve this */}}
|
||||
Processing zip file...
|
||||
</div>
|
||||
{{/* FAILED TO Prepare for training */}}
|
||||
{{ else if (eq .Model.Status -3)}}
|
||||
{{ template "base-model-card" . }}
|
||||
<form hx-delete="/models/train/reset" hx-headers='{"REQUEST-TYPE": "html"}' hx-swap="outerHTML">
|
||||
Failed Prepare for training.<br/>
|
||||
<div class="spacer" ></div>
|
||||
<input type="hidden" name="id" value="{{ .Model.Id }}" />
|
||||
<button class="danger">
|
||||
Try Again
|
||||
</button>
|
||||
</form>
|
||||
{{ template "delete-model-card" . }}
|
||||
{{ else if (eq .Model.Status 4)}}
|
||||
{{ template "base-model-card" . }}
|
||||
<div class="card" hx-get="/models/edit?id={{ .Model.Id }}" hx-headers='{"REQUEST-TYPE": "htmlfull"}' hx-push="true" hx-swap="outerHTML" hx-target=".app" hx-trigger="load delay:2s" >
|
||||
{{/* TODO improve this */}}
|
||||
Training the model...<br/>
|
||||
{{/* TODO Add progress status on definitions */}}
|
||||
{{/* TODO Add aility to stop training */}}
|
||||
</div>
|
||||
<button hx-post="/models/train/test?id={{ .Model.Id }}" hx-headers='{"REQUEST-TYPE": "html"}'>
|
||||
Test
|
||||
</button>
|
||||
{{ else }}
|
||||
<h1>
|
||||
Unknown Status of the model.
|
||||
|
@ -40,7 +40,7 @@
|
||||
</table>
|
||||
{{else}}
|
||||
<h2 class="text-center">
|
||||
You don't have any model
|
||||
You don't have any models
|
||||
</h2>
|
||||
<div class="text-center">
|
||||
<a class="button padded" hx-get="/models/add" hx-headers='{"REQUEST-TYPE": "htmlfull"}' hx-push-url="true" hx-swap="outerHTML" hx-target=".app">
|
||||
|
47
views/py/python_model_template.py
Normal file
47
views/py/python_model_template.py
Normal file
@ -0,0 +1,47 @@
|
||||
import tensorflow as tf
|
||||
import random
|
||||
from tensorflow import keras
|
||||
from keras import layers, losses, optimizers
|
||||
|
||||
seed = random.randint(0, 100000000)
|
||||
|
||||
batch_size = 100
|
||||
|
||||
dataset = keras.utils.image_dataset_from_directory(
|
||||
"{{ .DataDir }}",
|
||||
color_mode="rgb",
|
||||
validation_split=0.2,
|
||||
label_mode='int',
|
||||
seed=seed,
|
||||
subset="training",
|
||||
image_size=({{ .Size }}),
|
||||
batch_size=batch_size)
|
||||
|
||||
dataset_validation = keras.utils.image_dataset_from_directory(
|
||||
"{{ .DataDir }}",
|
||||
color_mode="rgb",
|
||||
validation_split=0.2,
|
||||
label_mode='int',
|
||||
seed=seed,
|
||||
subset="validation",
|
||||
image_size=({{ .Size }}),
|
||||
batch_size=batch_size)
|
||||
|
||||
model = keras.Sequential([
|
||||
{{- range .Layers }}
|
||||
{{- if eq .LayerType 1}}
|
||||
layers.Rescaling(1./255),
|
||||
{{- else if eq .LayerType 2 }}
|
||||
layers.Dense({{ .Shape }}, activation="relu"),
|
||||
{{- else if eq .LayerType 3}}
|
||||
layers.Flatten(),
|
||||
{{- else }}
|
||||
ERROR
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
])
|
||||
|
||||
model.compile(loss=losses.SparseCategoricalCrossentropy(), optimizer=tf.keras.optimizers.Adam())
|
||||
|
||||
his = model.fit(dataset, validation_data= dataset_validation, epochs=100)
|
||||
|
@ -176,7 +176,8 @@ form {
|
||||
box-shadow: none;
|
||||
}
|
||||
|
||||
form label {
|
||||
form label,
|
||||
form fieldset legend {
|
||||
display: block;
|
||||
padding-bottom: 5px;
|
||||
font-size: 1.2rem;
|
||||
@ -224,6 +225,16 @@ form button {
|
||||
padding: 10px;
|
||||
}
|
||||
|
||||
form .input-radial input[type="radio"] {
|
||||
width: auto;
|
||||
box-shadow: none;
|
||||
}
|
||||
|
||||
form .input-radial label {
|
||||
display: inline;
|
||||
font-size: 1rem;
|
||||
}
|
||||
|
||||
/* Upload files */
|
||||
|
||||
form fieldset.file-upload input[type="file"] {
|
||||
|
Loading…
Reference in New Issue
Block a user