feat: closes #18 added the code to generate models and other
This commit is contained in:
parent
04de6ad574
commit
bad53a13e6
@ -7,7 +7,7 @@ tmp_dir = "tmp"
|
|||||||
bin = "./tmp/main"
|
bin = "./tmp/main"
|
||||||
cmd = "go build -o ./tmp/main ."
|
cmd = "go build -o ./tmp/main ."
|
||||||
delay = 0
|
delay = 0
|
||||||
exclude_dir = ["assets", "tmp", "vendor", "testData", "savedData"]
|
exclude_dir = ["assets", "tmp", "vendor", "testData", "savedData", "tensorflow"]
|
||||||
exclude_file = []
|
exclude_file = []
|
||||||
exclude_regex = ["_test.go"]
|
exclude_regex = ["_test.go"]
|
||||||
exclude_unchanged = false
|
exclude_unchanged = false
|
||||||
|
3
go.mod
3
go.mod
@ -3,7 +3,10 @@ module git.andr3h3nriqu3s.com/andr3/fyp
|
|||||||
go 1.20
|
go 1.20
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/galeone/tensorflow/tensorflow/go v0.0.0-20221023090153-6b7fa0680c3e // indirect
|
||||||
|
github.com/galeone/tfgo v0.0.0-20230715013254-16113111dc99 // indirect
|
||||||
github.com/google/uuid v1.3.1 // indirect
|
github.com/google/uuid v1.3.1 // indirect
|
||||||
github.com/lib/pq v1.10.9 // indirect
|
github.com/lib/pq v1.10.9 // indirect
|
||||||
golang.org/x/crypto v0.13.0 // indirect
|
golang.org/x/crypto v0.13.0 // indirect
|
||||||
|
google.golang.org/protobuf v1.28.1 // indirect
|
||||||
)
|
)
|
||||||
|
10
go.sum
10
go.sum
@ -1,6 +1,16 @@
|
|||||||
|
github.com/galeone/tensorflow/tensorflow/go v0.0.0-20221023090153-6b7fa0680c3e h1:9+2AEFZymTi25FIIcDwuzcOPH04z9+fV6XeLiGORPDI=
|
||||||
|
github.com/galeone/tensorflow/tensorflow/go v0.0.0-20221023090153-6b7fa0680c3e/go.mod h1:TelZuq26kz2jysARBwOrTv16629hyUsHmIoj54QqyFo=
|
||||||
|
github.com/galeone/tfgo v0.0.0-20230715013254-16113111dc99 h1:8Bt1P/zy1gb37L4n8CGgp1qmFwBV5729kxVfj0sqhJk=
|
||||||
|
github.com/galeone/tfgo v0.0.0-20230715013254-16113111dc99/go.mod h1:3YgYBeIX42t83uP27Bd4bSMxTnQhSbxl0pYSkCDB1tc=
|
||||||
|
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||||
|
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
|
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
|
||||||
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||||
golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck=
|
golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck=
|
||||||
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
||||||
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||||
|
google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
|
||||||
|
google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
"path"
|
"path"
|
||||||
|
|
||||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||||
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func loadBaseImage(handle *Handle, id string) {
|
func loadBaseImage(handle *Handle, id string) {
|
||||||
@ -21,7 +22,7 @@ func loadBaseImage(handle *Handle, id string) {
|
|||||||
// TODO better logging
|
// TODO better logging
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
fmt.Printf("Failed to read image for model with id %s\n", id)
|
fmt.Printf("Failed to read image for model with id %s\n", id)
|
||||||
modelUpdateStatus(handle, id, -1)
|
ModelUpdateStatus(handle, id, -1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer infile.Close()
|
defer infile.Close()
|
||||||
@ -31,7 +32,7 @@ func loadBaseImage(handle *Handle, id string) {
|
|||||||
// TODO better logging
|
// TODO better logging
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
fmt.Printf("Failed to load image for model with id %s\n", id)
|
fmt.Printf("Failed to load image for model with id %s\n", id)
|
||||||
modelUpdateStatus(handle, id, -1)
|
ModelUpdateStatus(handle, id, -1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if format != "png" {
|
if format != "png" {
|
||||||
@ -67,7 +68,7 @@ func loadBaseImage(handle *Handle, id string) {
|
|||||||
fmt.Println("Other so assuming color")
|
fmt.Println("Other so assuming color")
|
||||||
}
|
}
|
||||||
|
|
||||||
modelUpdateStatus(handle, id, -1)
|
ModelUpdateStatus(handle, id, -1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -77,7 +78,7 @@ func loadBaseImage(handle *Handle, id string) {
|
|||||||
// TODO better logging
|
// TODO better logging
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
fmt.Printf("Could not update model\n")
|
fmt.Printf("Could not update model\n")
|
||||||
modelUpdateStatus(handle, id, -1)
|
ModelUpdateStatus(handle, id, -1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -33,6 +33,17 @@ func ListClasses(db *sql.DB, model_id string) (cls []ModelClass, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ModelHasDataPoints(db *sql.DB, model_id string) (result bool, err error) {
|
||||||
|
result = false
|
||||||
|
rows, err := db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 limit 1;", model_id)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
return rows.Next(), nil
|
||||||
|
}
|
||||||
|
|
||||||
var ClassAlreadyExists = errors.New("Class aready exists")
|
var ClassAlreadyExists = errors.New("Class aready exists")
|
||||||
|
|
||||||
func CreateClass(db *sql.DB, model_id string, name string) (id string, err error) {
|
func CreateClass(db *sql.DB, model_id string, name string) (id string, err error) {
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
|
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
|
||||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||||
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func InsertIfNotPresent(ss []string, s string) []string {
|
func InsertIfNotPresent(ss []string, s string) []string {
|
||||||
@ -31,7 +32,7 @@ func processZipFile(handle *Handle, id string) {
|
|||||||
reader, err := zip.OpenReader(path.Join("savedData", id, "base_data.zip"))
|
reader, err := zip.OpenReader(path.Join("savedData", id, "base_data.zip"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO add msg to error
|
// TODO add msg to error
|
||||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||||
fmt.Printf("Faield to proccess zip file failed to open reader\n")
|
fmt.Printf("Faield to proccess zip file failed to open reader\n")
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
return
|
return
|
||||||
@ -51,7 +52,7 @@ func processZipFile(handle *Handle, id string) {
|
|||||||
|
|
||||||
if paths[0] != "training" && paths[0] != "testing" {
|
if paths[0] != "training" && paths[0] != "testing" {
|
||||||
fmt.Printf("Invalid file '%s' TODO add msg to response!!!\n", file.Name)
|
fmt.Printf("Invalid file '%s' TODO add msg to response!!!\n", file.Name)
|
||||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -66,7 +67,7 @@ func processZipFile(handle *Handle, id string) {
|
|||||||
fmt.Printf("testing and training are diferent\n")
|
fmt.Printf("testing and training are diferent\n")
|
||||||
fmt.Println(testing)
|
fmt.Println(testing)
|
||||||
fmt.Println(training)
|
fmt.Println(training)
|
||||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -79,21 +80,21 @@ func processZipFile(handle *Handle, id string) {
|
|||||||
err = os.MkdirAll(dir_path, os.ModePerm)
|
err = os.MkdirAll(dir_path, os.ModePerm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Failed to create dir %s\n", dir_path)
|
fmt.Printf("Failed to create dir %s\n", dir_path)
|
||||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dir_path = path.Join(base_path, "testing", name)
|
dir_path = path.Join(base_path, "testing", name)
|
||||||
err = os.MkdirAll(dir_path, os.ModePerm)
|
err = os.MkdirAll(dir_path, os.ModePerm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Failed to create dir %s\n", dir_path)
|
fmt.Printf("Failed to create dir %s\n", dir_path)
|
||||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
id, err := model_classes.CreateClass(handle.Db, id, name)
|
id, err := model_classes.CreateClass(handle.Db, id, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Failed to create class '%s' on db\n", name)
|
fmt.Printf("Failed to create class '%s' on db\n", name)
|
||||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ids[name] = id
|
ids[name] = id
|
||||||
@ -108,14 +109,14 @@ func processZipFile(handle *Handle, id string) {
|
|||||||
f, err := os.Create(file_path)
|
f, err := os.Create(file_path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Could not create file %s\n", file_path)
|
fmt.Printf("Could not create file %s\n", file_path)
|
||||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
data, err := reader.Open(file.Name)
|
data, err := reader.Open(file.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Could not create file %s\n", file_path)
|
fmt.Printf("Could not create file %s\n", file_path)
|
||||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer data.Close()
|
defer data.Close()
|
||||||
@ -135,13 +136,13 @@ func processZipFile(handle *Handle, id string) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Failed to add data point for %s\n", id)
|
fmt.Printf("Failed to add data point for %s\n", id)
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
modelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Added data to model '%s'!\n", id)
|
fmt.Printf("Added data to model '%s'!\n", id)
|
||||||
modelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING)
|
ModelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING)
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleDataUpload(handle *Handle) {
|
func handleDataUpload(handle *Handle) {
|
||||||
@ -182,7 +183,7 @@ func handleDataUpload(handle *Handle) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = getBaseModel(handle.Db, id)
|
_, err = GetBaseModel(handle.Db, id)
|
||||||
if err == ModelNotFoundError {
|
if err == ModelNotFoundError {
|
||||||
return ErrorCode(nil, http.StatusNotFound, AnyMap{
|
return ErrorCode(nil, http.StatusNotFound, AnyMap{
|
||||||
"NotFoundMessage": "Model not found",
|
"NotFoundMessage": "Model not found",
|
||||||
@ -203,7 +204,7 @@ func handleDataUpload(handle *Handle) {
|
|||||||
|
|
||||||
f.Write(file)
|
f.Write(file)
|
||||||
|
|
||||||
modelUpdateStatus(handle, id, PREPARING_ZIP_FILE)
|
ModelUpdateStatus(handle, id, PREPARING_ZIP_FILE)
|
||||||
|
|
||||||
go processZipFile(handle, id)
|
go processZipFile(handle, id)
|
||||||
|
|
||||||
@ -230,7 +231,7 @@ func handleDataUpload(handle *Handle) {
|
|||||||
|
|
||||||
id := f.Get("id")
|
id := f.Get("id")
|
||||||
|
|
||||||
model, err := getBaseModel(handle.Db, id)
|
model, err := GetBaseModel(handle.Db, id)
|
||||||
if err == ModelNotFoundError {
|
if err == ModelNotFoundError {
|
||||||
return ErrorCode(nil, http.StatusNotFound, AnyMap{
|
return ErrorCode(nil, http.StatusNotFound, AnyMap{
|
||||||
"NotFoundMessage": "Model not found",
|
"NotFoundMessage": "Model not found",
|
||||||
@ -260,7 +261,7 @@ func handleDataUpload(handle *Handle) {
|
|||||||
return Error500(err)
|
return Error500(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
modelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING)
|
ModelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING)
|
||||||
Redirect("/models/edit?id="+id, c.Mode, w, r)
|
Redirect("/models/edit?id="+id, c.Mode, w, r)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"path"
|
"path"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -71,7 +72,7 @@ func handleDelete(handle *Handle) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
var model BaseModel = BaseModel{}
|
model := BaseModel{}
|
||||||
model.Id = id
|
model.Id = id
|
||||||
|
|
||||||
err = rows.Scan(&model.Name, &model.Status)
|
err = rows.Scan(&model.Name, &model.Status)
|
||||||
@ -80,6 +81,8 @@ func handleDelete(handle *Handle) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch model.Status {
|
switch model.Status {
|
||||||
|
case FAILED_PREPARING_TRAINING:
|
||||||
|
fallthrough
|
||||||
case FAILED_PREPARING:
|
case FAILED_PREPARING:
|
||||||
deleteModel(handle, id, w, c, model)
|
deleteModel(handle, id, w, c, model)
|
||||||
return nil
|
return nil
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
|
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
|
||||||
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -70,6 +71,11 @@ func handleEdit(handle *Handle) {
|
|||||||
case CONFIRM_PRE_TRAINING:
|
case CONFIRM_PRE_TRAINING:
|
||||||
|
|
||||||
cls, err := model_classes.ListClasses(handle.Db, id)
|
cls, err := model_classes.ListClasses(handle.Db, id)
|
||||||
|
if err != nil {
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
has_data, err := model_classes.ModelHasDataPoints(handle.Db, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Error500(err)
|
return Error500(err)
|
||||||
}
|
}
|
||||||
@ -77,7 +83,10 @@ func handleEdit(handle *Handle) {
|
|||||||
LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
|
LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
|
||||||
"Model": model,
|
"Model": model,
|
||||||
"Classes": cls,
|
"Classes": cls,
|
||||||
|
"HasData": has_data,
|
||||||
}))
|
}))
|
||||||
|
case TRAINING:
|
||||||
|
fallthrough
|
||||||
case PREPARING_ZIP_FILE:
|
case PREPARING_ZIP_FILE:
|
||||||
LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
|
LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
|
||||||
"Model": model,
|
"Model": model,
|
||||||
|
@ -2,6 +2,7 @@ package models
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
|
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
|
||||||
|
models_train "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train"
|
||||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -15,5 +16,8 @@ func HandleModels (handle *Handle) {
|
|||||||
handleDataUpload(handle)
|
handleDataUpload(handle)
|
||||||
|
|
||||||
model_classes.HandleList(handle)
|
model_classes.HandleList(handle)
|
||||||
|
|
||||||
|
// Train endpoints
|
||||||
|
models_train.HandleTrainEndpoints(handle)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
13
logic/models/train/main.go
Normal file
13
logic/models/train/main.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
package models_train
|
||||||
|
|
||||||
|
import (
|
||||||
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func HandleTrainEndpoints(handle *Handle) {
|
||||||
|
handleTrain(handle)
|
||||||
|
handleRest(handle)
|
||||||
|
|
||||||
|
//TODO remove
|
||||||
|
handleTest(handle)
|
||||||
|
}
|
58
logic/models/train/reset.go
Normal file
58
logic/models/train/reset.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package models_train
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||||
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func handleRest(handle *Handle) {
|
||||||
|
handle.Delete("/models/train/reset", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||||||
|
if !CheckAuthLevel(1, w, r, c) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if c.Mode == JSON {
|
||||||
|
panic("handle JSON /models/train/reset")
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := MyParseForm(r)
|
||||||
|
if err != nil {
|
||||||
|
// TODO improve response
|
||||||
|
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !CheckId(f, "id") {
|
||||||
|
// TODO improve response
|
||||||
|
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
id := f.Get("id")
|
||||||
|
|
||||||
|
model, err := GetBaseModel(handle.Db, id)
|
||||||
|
if err == ModelNotFoundError {
|
||||||
|
return ErrorCode(nil, http.StatusNotFound, AnyMap{
|
||||||
|
"NotFoundMessage": "Model not found",
|
||||||
|
"GoBackLink": "/models",
|
||||||
|
})
|
||||||
|
} else if err != nil {
|
||||||
|
// TODO improve response
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if model.Status != FAILED_PREPARING_TRAINING {
|
||||||
|
// TODO improve response
|
||||||
|
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = handle.Db.Exec("delete from model_definition where model_id=$1", model.Id)
|
||||||
|
if err != nil {
|
||||||
|
// TODO improve response
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ModelUpdateStatus(handle, model.Id, CONFIRM_PRE_TRAINING)
|
||||||
|
Redirect("/models/edit?id=" + model.Id, c.Mode, w, r)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
139
logic/models/train/tensorflow-test.go
Normal file
139
logic/models/train/tensorflow-test.go
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
package models_train
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"strings"
|
||||||
|
"text/template"
|
||||||
|
|
||||||
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
/*import (
|
||||||
|
tf "github.com/galeone/tensorflow/tensorflow/go"
|
||||||
|
tg "github.com/galeone/tfgo"
|
||||||
|
"github.com/galeone/tfgo/image"
|
||||||
|
"github.com/galeone/tfgo/image/filter"
|
||||||
|
"github.com/galeone/tfgo/image/padding"
|
||||||
|
)*/
|
||||||
|
|
||||||
|
func getDir() string {
|
||||||
|
dir, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return dir
|
||||||
|
}
|
||||||
|
|
||||||
|
func shapeToSize(shape string) string {
|
||||||
|
split := strings.Split(shape, ",")
|
||||||
|
return strings.Join(split[:len(split) - 1], ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleTest(handle *Handle) {
|
||||||
|
handle.Post("/models/train/test", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||||||
|
if !CheckAuthLevel(1, w, r, c) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := GetIdFromUrl(r, "id")
|
||||||
|
if err != nil {
|
||||||
|
return ErrorCode(err, 400, c.AddMap(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := handle.Db.Query("select mc.name, mdp.file_path from model_classes as mc join model_data_point as mdp on mdp.class_id = mc.id where mdp.model_mode = 1 and mc.model_id = $1 limit 1;", id)
|
||||||
|
if err != nil {
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
if !rows.Next() {
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var name string
|
||||||
|
var file_path string
|
||||||
|
err = rows.Scan(&name, &file_path)
|
||||||
|
if err != nil {
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
file_path = strings.Replace(file_path, "file://", "", 1)
|
||||||
|
|
||||||
|
img_path := path.Join("savedData", id, "data", "training", name, file_path)
|
||||||
|
|
||||||
|
fmt.Printf("%s\n", img_path)
|
||||||
|
|
||||||
|
definitions, err := handle.Db.Query("select id from model_definition where model_id=$1 and status=2 limit 1;", id)
|
||||||
|
if err != nil {
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
defer definitions.Close()
|
||||||
|
|
||||||
|
if !definitions.Next() {
|
||||||
|
fmt.Println("Did not find definition")
|
||||||
|
return Error500(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
var definition_id string
|
||||||
|
|
||||||
|
if err = definitions.Scan(&definition_id); err != nil {
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
||||||
|
if err != nil {
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
defer layers.Close()
|
||||||
|
|
||||||
|
type layerrow struct {
|
||||||
|
LayerType int
|
||||||
|
Shape string
|
||||||
|
}
|
||||||
|
|
||||||
|
got := []layerrow{}
|
||||||
|
|
||||||
|
for layers.Next() {
|
||||||
|
var row = layerrow{}
|
||||||
|
if err = layers.Scan(&row.LayerType, &row.Shape); err != nil {
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
row.Shape = shapeToSize(row.Shape)
|
||||||
|
got = append(got, row)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate folder
|
||||||
|
|
||||||
|
err = os.MkdirAll(path.Join("/tmp", id), os.ModePerm)
|
||||||
|
if err != nil {
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.Create(path.Join("/tmp", id, "run.py"))
|
||||||
|
if err != nil {
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
fmt.Printf("Using path: %s\n", path.Join("/tmp", id, "run.py"))
|
||||||
|
|
||||||
|
tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py")
|
||||||
|
if err != nil {
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tmpl.Execute(f, AnyMap{
|
||||||
|
"Layers": got,
|
||||||
|
"Size": got[0].Shape,
|
||||||
|
"DataDir": path.Join(getDir(), "savedData", id, "data", "training"),
|
||||||
|
}); err != nil {
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Write([]byte("Done"))
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
158
logic/models/train/train.go
Normal file
158
logic/models/train/train.go
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
package models_train
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
|
||||||
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||||
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
|
||||||
|
id = ""
|
||||||
|
_, err = db.Exec("insert into model_definition (model_id, target_accuracy) values ($1, $2);", model_id, target_accuracy)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := db.Query("select id from model_definition where model_id=$1 order by created_on DESC;", model_id)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
if !rows.Next() {
|
||||||
|
return id, errors.New("Something wrong!")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rows.Scan(&id)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type ModelDefinitionStatus int
|
||||||
|
|
||||||
|
const (
|
||||||
|
MODEL_DEFINITION_STATUS_PRE_INIT ModelDefinitionStatus = 1
|
||||||
|
MODEL_DEFINITION_STATUS_INIT = 2
|
||||||
|
MODEL_DEFINITION_STATUS_TRAINING = 3
|
||||||
|
MODEL_DEFINITION_STATUS_TRANIED = 4
|
||||||
|
MODEL_DEFINITION_STATUS_READY = 5
|
||||||
|
)
|
||||||
|
|
||||||
|
func ModelDefinitionUpdateStatus(handle *Handle, id string, status ModelDefinitionStatus) (err error) {
|
||||||
|
_, err = handle.Db.Exec("update model_definition set status = $1 where id = $2", status, id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type int, shape string) (err error) {
|
||||||
|
_, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape) values ($1, $2, $3, $4)", def_id, layer_order, layer_type, shape)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleTrain(handle *Handle) {
|
||||||
|
handle.Post("/models/train", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||||||
|
if !CheckAuthLevel(1, w, r, c) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if c.Mode == JSON {
|
||||||
|
panic("TODO /models/train JSON")
|
||||||
|
}
|
||||||
|
|
||||||
|
r.ParseForm()
|
||||||
|
f := r.Form
|
||||||
|
|
||||||
|
number_of_models := 0
|
||||||
|
accuracy := 0
|
||||||
|
|
||||||
|
if !CheckId(f, "id") || CheckEmpty(f, "model_type") || !CheckNumber(f, "number_of_models", &number_of_models) || !CheckNumber(f, "accuracy", &accuracy) {
|
||||||
|
fmt.Println(
|
||||||
|
!CheckId(f, "id"), CheckEmpty(f, "model_type"), !CheckNumber(f, "number_of_models", &number_of_models), !CheckNumber(f, "accuracy", &accuracy),
|
||||||
|
)
|
||||||
|
// TODO improve this response
|
||||||
|
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
id := f.Get("id")
|
||||||
|
model_type := f.Get("model_type")
|
||||||
|
// Its not used rn
|
||||||
|
_ = model_type
|
||||||
|
|
||||||
|
model, err := GetBaseModel(handle.Db, id)
|
||||||
|
if err == ModelNotFoundError {
|
||||||
|
return ErrorCode(nil, http.StatusNotFound, c.AddMap(AnyMap{
|
||||||
|
"NotFoundMessage": "Model not found",
|
||||||
|
"GoBackLink": "/models",
|
||||||
|
}))
|
||||||
|
} else if err != nil {
|
||||||
|
// TODO improve this response
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if model.Status != CONFIRM_PRE_TRAINING {
|
||||||
|
// TODO improve this response
|
||||||
|
return ErrorCode(nil, 400, c.AddMap(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
cls, err := model_classes.ListClasses(handle.Db, model.Id)
|
||||||
|
if err != nil {
|
||||||
|
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||||
|
// TODO improve this response
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
var fid string
|
||||||
|
for i := 0; i < number_of_models; i++ {
|
||||||
|
def_id, err := MakeDefenition(handle.Db, model.Id, accuracy)
|
||||||
|
if err != nil {
|
||||||
|
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||||
|
// TODO improve this response
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fid == "" {
|
||||||
|
fid = def_id
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO change shape of it depends on the type of the image
|
||||||
|
err = MakeLayer(handle.Db, def_id, 1, 1, fmt.Sprintf("%d,%d,1", model.Width, model.Height))
|
||||||
|
if err != nil {
|
||||||
|
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||||
|
// TODO improve this response
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
err = MakeLayer(handle.Db, def_id, 4, 3, fmt.Sprintf("%d,1", len(cls)))
|
||||||
|
if err != nil {
|
||||||
|
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||||
|
// TODO improve this response
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
err = MakeLayer(handle.Db, def_id, 5, 2, fmt.Sprintf("%d,1", len(cls)))
|
||||||
|
if err != nil {
|
||||||
|
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||||
|
// TODO improve this response
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ModelDefinitionUpdateStatus(handle, def_id, MODEL_DEFINITION_STATUS_INIT)
|
||||||
|
if err != nil {
|
||||||
|
ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING)
|
||||||
|
// TODO improve this response
|
||||||
|
return Error500(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO start training with id fid
|
||||||
|
|
||||||
|
ModelUpdateStatus(handle, model.Id, TRAINING)
|
||||||
|
Redirect("/models/edit?id=" + model.Id, c.Mode, w, r)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
@ -1,43 +0,0 @@
|
|||||||
package models
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
type BaseModel struct {
|
|
||||||
Name string
|
|
||||||
Status int
|
|
||||||
Id string
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
FAILED_PREPARING_ZIP_FILE = -2
|
|
||||||
FAILED_PREPARING = -1
|
|
||||||
|
|
||||||
PREPARING = 1
|
|
||||||
CONFIRM_PRE_TRAINING = 2
|
|
||||||
PREPARING_ZIP_FILE = 3
|
|
||||||
)
|
|
||||||
|
|
||||||
var ModelNotFoundError = errors.New("Model not found error")
|
|
||||||
|
|
||||||
func getBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
|
|
||||||
rows, err := db.Query("select name, status, id from models where id=$1;", id)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
if !rows.Next() {
|
|
||||||
return nil, ModelNotFoundError
|
|
||||||
}
|
|
||||||
|
|
||||||
base = &BaseModel{}
|
|
||||||
err = rows.Scan(&base.Name, &base.Status, &base.Id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
49
logic/models/utils/types.go
Normal file
49
logic/models/utils/types.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package models_utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BaseModel struct {
|
||||||
|
Name string
|
||||||
|
Status int
|
||||||
|
Id string
|
||||||
|
|
||||||
|
Width int
|
||||||
|
Height int
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
FAILED_TRAINING = -4
|
||||||
|
FAILED_PREPARING_TRAINING = -3
|
||||||
|
FAILED_PREPARING_ZIP_FILE = -2
|
||||||
|
FAILED_PREPARING = -1
|
||||||
|
|
||||||
|
PREPARING = 1
|
||||||
|
CONFIRM_PRE_TRAINING = 2
|
||||||
|
PREPARING_ZIP_FILE = 3
|
||||||
|
TRAINING = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
var ModelNotFoundError = errors.New("Model not found error")
|
||||||
|
|
||||||
|
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
|
||||||
|
rows, err := db.Query("select name, status, id, width, height from models where id=$1;", id)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
if !rows.Next() {
|
||||||
|
return nil, ModelNotFoundError
|
||||||
|
}
|
||||||
|
|
||||||
|
base = &BaseModel{}
|
||||||
|
err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package models
|
package models_utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -6,7 +6,8 @@ import (
|
|||||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func modelUpdateStatus(handle *Handle, id string, status int) {
|
// TODO make this return and caller handle error
|
||||||
|
func ModelUpdateStatus(handle *Handle, id string, status int) {
|
||||||
_, err := handle.Db.Exec("update models set status = $1 where id = $2", status, id)
|
_, err := handle.Db.Exec("update models set status = $1 where id = $2", status, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("Failed to update model status")
|
fmt.Println("Failed to update model status")
|
@ -2,10 +2,12 @@ package utils
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"mime"
|
"mime"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
@ -14,6 +16,21 @@ func CheckEmpty(f url.Values, path string) bool {
|
|||||||
return !f.Has(path) || f.Get(path) == ""
|
return !f.Has(path) || f.Get(path) == ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CheckNumber(f url.Values, path string, number *int) bool {
|
||||||
|
if CheckEmpty(f, path) {
|
||||||
|
fmt.Println("here", path)
|
||||||
|
fmt.Println(f.Get(path))
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
n, err := strconv.Atoi(f.Get(path))
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
*number = n
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func CheckId(f url.Values, path string) bool {
|
func CheckId(f url.Values, path string) bool {
|
||||||
return !CheckEmpty(f, path) && IsValidUUID(f.Get(path))
|
return !CheckEmpty(f, path) && IsValidUUID(f.Get(path))
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
-- drop table if exists model_data_point;
|
-- drop table if exists model_data_point;
|
||||||
-- drop table if exists model_defenitions;
|
-- drop table if exists model_definition_layer;
|
||||||
|
-- drop table if exists model_definition;
|
||||||
-- drop table if exists models;
|
-- drop table if exists models;
|
||||||
create table if not exists models (
|
create table if not exists models (
|
||||||
id uuid primary key default gen_random_uuid(),
|
id uuid primary key default gen_random_uuid(),
|
||||||
@ -32,3 +33,37 @@ create table if not exists model_data_point (
|
|||||||
-- 2 testing
|
-- 2 testing
|
||||||
model_mode integer default 1
|
model_mode integer default 1
|
||||||
);
|
);
|
||||||
|
|
||||||
|
-- drop table if exists model_definition;
|
||||||
|
-- drop table if exists model_definition;
|
||||||
|
create table if not exists model_definition (
|
||||||
|
id uuid primary key default gen_random_uuid(),
|
||||||
|
model_id uuid references models (id) on delete cascade,
|
||||||
|
accuracy integer default 0,
|
||||||
|
target_accuracy integer not null,
|
||||||
|
epoch integer default 0,
|
||||||
|
-- TODO add max epoch
|
||||||
|
-- 1: Pre Init
|
||||||
|
-- 2: Init
|
||||||
|
-- 3: Training
|
||||||
|
-- 4: Tranied
|
||||||
|
-- 5: Ready
|
||||||
|
status integer default 1,
|
||||||
|
created_on timestamp default current_timestamp
|
||||||
|
);
|
||||||
|
|
||||||
|
-- drop table if exists model_definition_layer;
|
||||||
|
create table if not exists model_definition_layer (
|
||||||
|
id uuid primary key default gen_random_uuid(),
|
||||||
|
def_id uuid references model_definition (id) on delete cascade,
|
||||||
|
layer_order integer not null,
|
||||||
|
-- 1: input
|
||||||
|
-- 2: dense
|
||||||
|
-- 3: flatten
|
||||||
|
-- TODO add conv
|
||||||
|
layer_type integer not null,
|
||||||
|
-- ei 28,28,1
|
||||||
|
-- a 28x28 grayscale image
|
||||||
|
shape text not null
|
||||||
|
);
|
||||||
|
|
||||||
|
@ -252,8 +252,39 @@
|
|||||||
{{ end }}
|
{{ end }}
|
||||||
|
|
||||||
{{ define "train-model-card" }}
|
{{ define "train-model-card" }}
|
||||||
<form hx-delete="/models/train" hx-headers='{"REQUEST-TYPE": "html"}' hx-swap="outerHTML" {{ if .Error }} class="submitted" {{end}} >
|
<form hx-post="/models/train" hx-headers='{"REQUEST-TYPE": "html"}' hx-swap="outerHTML" {{ if .Error }} class="submitted" {{end}} >
|
||||||
tain menu
|
{{ if .HasData }}
|
||||||
|
{{/* TODO expading mode */}}
|
||||||
|
<input type="hidden" value="{{ .Model.Id }}" name="id" />
|
||||||
|
<fieldset>
|
||||||
|
<legend>
|
||||||
|
Model Type
|
||||||
|
</legend>
|
||||||
|
<div class="input-radial">
|
||||||
|
<input id="model_type_simple" value="simple" name="model_type" type="radio" checked />
|
||||||
|
<label for="model_type_simple">Simple</label>
|
||||||
|
</div>
|
||||||
|
</fieldset>
|
||||||
|
{{/* TODO allow more models to be created */}}
|
||||||
|
<fieldset>
|
||||||
|
<label for="number_of_models">Number of Models</label>
|
||||||
|
<input id="number_of_models" type="number" name="number_of_models" value="1" />
|
||||||
|
</fieldset>
|
||||||
|
{{/* TODO to Change the acc */}}
|
||||||
|
<fieldset>
|
||||||
|
<label for="accuracy">Target accuracy</label>
|
||||||
|
<input id="accuracy" type="number" name="accuracy" value="95" />
|
||||||
|
</fieldset>
|
||||||
|
{{/* TODO allow to chose the base of the model */}}
|
||||||
|
{{/* TODO allow to change the shape of the model */}}
|
||||||
|
<button>
|
||||||
|
Train
|
||||||
|
</button>
|
||||||
|
{{ else }}
|
||||||
|
<h2>
|
||||||
|
Please provide data to the model first
|
||||||
|
</h2>
|
||||||
|
{{ end }}
|
||||||
</form>
|
</form>
|
||||||
{{ end }}
|
{{ end }}
|
||||||
|
|
||||||
@ -313,6 +344,29 @@
|
|||||||
{{/* TODO improve this */}}
|
{{/* TODO improve this */}}
|
||||||
Processing zip file...
|
Processing zip file...
|
||||||
</div>
|
</div>
|
||||||
|
{{/* FAILED TO Prepare for training */}}
|
||||||
|
{{ else if (eq .Model.Status -3)}}
|
||||||
|
{{ template "base-model-card" . }}
|
||||||
|
<form hx-delete="/models/train/reset" hx-headers='{"REQUEST-TYPE": "html"}' hx-swap="outerHTML">
|
||||||
|
Failed Prepare for training.<br/>
|
||||||
|
<div class="spacer" ></div>
|
||||||
|
<input type="hidden" name="id" value="{{ .Model.Id }}" />
|
||||||
|
<button class="danger">
|
||||||
|
Try Again
|
||||||
|
</button>
|
||||||
|
</form>
|
||||||
|
{{ template "delete-model-card" . }}
|
||||||
|
{{ else if (eq .Model.Status 4)}}
|
||||||
|
{{ template "base-model-card" . }}
|
||||||
|
<div class="card" hx-get="/models/edit?id={{ .Model.Id }}" hx-headers='{"REQUEST-TYPE": "htmlfull"}' hx-push="true" hx-swap="outerHTML" hx-target=".app" hx-trigger="load delay:2s" >
|
||||||
|
{{/* TODO improve this */}}
|
||||||
|
Training the model...<br/>
|
||||||
|
{{/* TODO Add progress status on definitions */}}
|
||||||
|
{{/* TODO Add aility to stop training */}}
|
||||||
|
</div>
|
||||||
|
<button hx-post="/models/train/test?id={{ .Model.Id }}" hx-headers='{"REQUEST-TYPE": "html"}'>
|
||||||
|
Test
|
||||||
|
</button>
|
||||||
{{ else }}
|
{{ else }}
|
||||||
<h1>
|
<h1>
|
||||||
Unknown Status of the model.
|
Unknown Status of the model.
|
||||||
|
@ -40,7 +40,7 @@
|
|||||||
</table>
|
</table>
|
||||||
{{else}}
|
{{else}}
|
||||||
<h2 class="text-center">
|
<h2 class="text-center">
|
||||||
You don't have any model
|
You don't have any models
|
||||||
</h2>
|
</h2>
|
||||||
<div class="text-center">
|
<div class="text-center">
|
||||||
<a class="button padded" hx-get="/models/add" hx-headers='{"REQUEST-TYPE": "htmlfull"}' hx-push-url="true" hx-swap="outerHTML" hx-target=".app">
|
<a class="button padded" hx-get="/models/add" hx-headers='{"REQUEST-TYPE": "htmlfull"}' hx-push-url="true" hx-swap="outerHTML" hx-target=".app">
|
||||||
|
47
views/py/python_model_template.py
Normal file
47
views/py/python_model_template.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
import tensorflow as tf
|
||||||
|
import random
|
||||||
|
from tensorflow import keras
|
||||||
|
from keras import layers, losses, optimizers
|
||||||
|
|
||||||
|
seed = random.randint(0, 100000000)
|
||||||
|
|
||||||
|
batch_size = 100
|
||||||
|
|
||||||
|
dataset = keras.utils.image_dataset_from_directory(
|
||||||
|
"{{ .DataDir }}",
|
||||||
|
color_mode="rgb",
|
||||||
|
validation_split=0.2,
|
||||||
|
label_mode='int',
|
||||||
|
seed=seed,
|
||||||
|
subset="training",
|
||||||
|
image_size=({{ .Size }}),
|
||||||
|
batch_size=batch_size)
|
||||||
|
|
||||||
|
dataset_validation = keras.utils.image_dataset_from_directory(
|
||||||
|
"{{ .DataDir }}",
|
||||||
|
color_mode="rgb",
|
||||||
|
validation_split=0.2,
|
||||||
|
label_mode='int',
|
||||||
|
seed=seed,
|
||||||
|
subset="validation",
|
||||||
|
image_size=({{ .Size }}),
|
||||||
|
batch_size=batch_size)
|
||||||
|
|
||||||
|
model = keras.Sequential([
|
||||||
|
{{- range .Layers }}
|
||||||
|
{{- if eq .LayerType 1}}
|
||||||
|
layers.Rescaling(1./255),
|
||||||
|
{{- else if eq .LayerType 2 }}
|
||||||
|
layers.Dense({{ .Shape }}, activation="relu"),
|
||||||
|
{{- else if eq .LayerType 3}}
|
||||||
|
layers.Flatten(),
|
||||||
|
{{- else }}
|
||||||
|
ERROR
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}
|
||||||
|
])
|
||||||
|
|
||||||
|
model.compile(loss=losses.SparseCategoricalCrossentropy(), optimizer=tf.keras.optimizers.Adam())
|
||||||
|
|
||||||
|
his = model.fit(dataset, validation_data= dataset_validation, epochs=100)
|
||||||
|
|
@ -176,7 +176,8 @@ form {
|
|||||||
box-shadow: none;
|
box-shadow: none;
|
||||||
}
|
}
|
||||||
|
|
||||||
form label {
|
form label,
|
||||||
|
form fieldset legend {
|
||||||
display: block;
|
display: block;
|
||||||
padding-bottom: 5px;
|
padding-bottom: 5px;
|
||||||
font-size: 1.2rem;
|
font-size: 1.2rem;
|
||||||
@ -224,6 +225,16 @@ form button {
|
|||||||
padding: 10px;
|
padding: 10px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
form .input-radial input[type="radio"] {
|
||||||
|
width: auto;
|
||||||
|
box-shadow: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
form .input-radial label {
|
||||||
|
display: inline;
|
||||||
|
font-size: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
/* Upload files */
|
/* Upload files */
|
||||||
|
|
||||||
form fieldset.file-upload input[type="file"] {
|
form fieldset.file-upload input[type="file"] {
|
||||||
|
Loading…
Reference in New Issue
Block a user