diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..84e9a67
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,5 @@
+tmp/
+testData/
+savedData/
+!savedData/.keep
+fyp
diff --git a/DockerfileDev b/DockerfileNginxDev
similarity index 100%
rename from DockerfileDev
rename to DockerfileNginxDev
diff --git a/DockerfileServer b/DockerfileServer
new file mode 100644
index 0000000..0c06cf5
--- /dev/null
+++ b/DockerfileServer
@@ -0,0 +1,54 @@
+FROM docker.io/nvidia/cuda:11.8.0-devel-ubuntu22.04
+
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt-get update
+RUN apt-get install -y wget sudo pkg-config libopencv-dev unzip python3-pip
+
+RUN pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0
+
+RUN mkdir /go
+ENV GOPATH=/go
+
+RUN wget https://go.dev/dl/go1.22.2.linux-amd64.tar.gz
+RUN tar -xvf go1.22.2.linux-amd64.tar.gz -C /usr/local
+ENV PATH=$PATH:/usr/local/go/bin
+
+RUN mkdir /app
+WORKDIR /app
+
+ADD go.mod .
+ADD go.sum .
+ADD main.go .
+ADD logic logic
+
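+# The || true lets the build continue even if this fails (libtorch is not set up yet); running it here likely just warms the Go module cache into this layer.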
+RUN go install || true
+
+WORKDIR /root
+
+RUN wget https://github.com/sugarme/gotch/releases/download/v0.9.0/setup-libtorch.sh
+RUN chmod +x setup-libtorch.sh
+ENV CUDA_VER=11.8
+ENV GOTCH_VER=v0.9.1
+RUN bash setup-libtorch.sh
+ENV GOTCH_LIBTORCH="/usr/local/lib/libtorch"
+ENV LIBRARY_PATH="$LIBRARY_PATH:$GOTCH_LIBTORCH/lib"
+ENV CPATH="$CPATH:$GOTCH_LIBTORCH/lib:$GOTCH_LIBTORCH/include:$GOTCH_LIBTORCH/include/torch/csrc/api/include"
+ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$GOTCH_LIBTORCH/lib:/usr/lib64-nvidia:/usr/local/cuda-${CUDA_VERSION}/lib64"
+RUN wget https://github.com/sugarme/gotch/releases/download/v0.9.0/setup-gotch.sh
+RUN chmod +x setup-gotch.sh
+RUN echo 'root ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers
+RUN bash setup-gotch.sh
+
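+# Work around gotch's expected include layout: link the installed libtorch headers into the gotch v0.9.1 module directory and expose libcudnn under /usr/local/lib.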
+RUN ln -s /usr/local/lib/libtorch/include/torch/csrc /usr/local/lib/libtorch/include/torch/csrc/api/include/torch
+RUN mkdir -p /go/pkg/mod/github.com/sugarme/gotch@v0.9.1/libtch/libtorch/include/torch/csrc/api
+RUN find /usr/local/lib/libtorch/include -maxdepth 1 -type d | tail -n +2 | grep -ve 'torch$' | xargs -I{} ln -s {} /go/pkg/mod/github.com/sugarme/gotch@v0.9.1/libtch/libtorch/include
+RUN ln -s /usr/local/lib/libtorch/include/torch/csrc/api/include /go/pkg/mod/github.com/sugarme/gotch@v0.9.1/libtch/libtorch/include/torch/csrc/api/include
+RUN find /usr/local/lib/libtorch/include/torch -maxdepth 1 -type f | xargs -I{} ln -s {} /go/pkg/mod/github.com/sugarme/gotch@v0.9.1/libtch/libtorch/include/torch
+RUN ln -s /usr/local/lib/libtorch/lib/libcudnn.so.8 /usr/local/lib/libcudnn.so
+
+WORKDIR /app
+
+ADD . .
+RUN go install || true
+
+CMD ["bash", "-c", "go run ."]
diff --git a/go.mod b/go.mod
index 9ac9940..f93209a 100644
--- a/go.mod
+++ b/go.mod
@@ -4,8 +4,6 @@ go 1.21
require (
github.com/charmbracelet/log v0.3.1
- github.com/galeone/tensorflow/tensorflow/go v0.0.0-20240119075110-6ad3cf65adfe
- github.com/galeone/tfgo v0.0.0-20230715013254-16113111dc99
github.com/google/uuid v1.6.0
github.com/lib/pq v1.10.9
golang.org/x/crypto v0.19.0
@@ -34,6 +32,7 @@ require (
github.com/muesli/termenv v0.15.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/rivo/uniseg v0.4.6 // indirect
+ github.com/sugarme/gotch v0.9.1 // indirect
golang.org/x/exp v0.0.0-20240119083558-1b970713d09a // indirect
golang.org/x/net v0.21.0 // indirect
golang.org/x/sync v0.1.0 // indirect
diff --git a/go.sum b/go.sum
index 3a312a2..c95a5a6 100644
--- a/go.sum
+++ b/go.sum
@@ -13,12 +13,6 @@ github.com/charmbracelet/log v0.3.1/go.mod h1:OR4E1hutLsax3ZKpXbgUqPtTjQfrh1pG3z
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
-github.com/galeone/tensorflow/tensorflow/go v0.0.0-20221023090153-6b7fa0680c3e h1:9+2AEFZymTi25FIIcDwuzcOPH04z9+fV6XeLiGORPDI=
-github.com/galeone/tensorflow/tensorflow/go v0.0.0-20221023090153-6b7fa0680c3e/go.mod h1:TelZuq26kz2jysARBwOrTv16629hyUsHmIoj54QqyFo=
-github.com/galeone/tensorflow/tensorflow/go v0.0.0-20240119075110-6ad3cf65adfe h1:7yELf1NFEwECpXMGowkoftcInMlVtLTCdwWLmxKgzNM=
-github.com/galeone/tensorflow/tensorflow/go v0.0.0-20240119075110-6ad3cf65adfe/go.mod h1:TelZuq26kz2jysARBwOrTv16629hyUsHmIoj54QqyFo=
-github.com/galeone/tfgo v0.0.0-20230715013254-16113111dc99 h1:8Bt1P/zy1gb37L4n8CGgp1qmFwBV5729kxVfj0sqhJk=
-github.com/galeone/tfgo v0.0.0-20230715013254-16113111dc99/go.mod h1:3YgYBeIX42t83uP27Bd4bSMxTnQhSbxl0pYSkCDB1tc=
github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
@@ -74,7 +68,13 @@ github.com/rivo/uniseg v0.4.6 h1:Sovz9sDSwbOz9tgUy8JpT+KgCkPYJEN/oYzlJiYTNLg=
github.com/rivo/uniseg v0.4.6/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
+github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/sugarme/gotch v0.9.1 h1:J6JCE1C2AfPmM1xk0p46LdzWtfNvbvZZnWdkj9v54jo=
+github.com/sugarme/gotch v0.9.1/go.mod h1:dien16KQcZPg/g+YiEH3q3ldHlKO2//2I2i2Gp5OQcI=
+github.com/wangkuiyi/gotorch v0.0.0-20201028015551-9afed2f3ad7b h1:oJfm5gCGdy9k2Yb+qmMR+HMRQ89CbVDsDi6DD9AZSTk=
+github.com/wangkuiyi/gotorch v0.0.0-20201028015551-9afed2f3ad7b/go.mod h1:WC7g+ojb7tPOZhHI2+ZI7ZXTW7uzF9uFOZfZgIX+SjI=
+github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck=
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc=
diff --git a/lib b/lib
new file mode 120000
index 0000000..7ba1558
--- /dev/null
+++ b/lib
@@ -0,0 +1 @@
+/usr/local/lib
\ No newline at end of file
diff --git a/lib.go.back b/lib.go.back
new file mode 100644
index 0000000..283709c
--- /dev/null
+++ b/lib.go.back
@@ -0,0 +1,10 @@
+package libtch
+
+// #cgo LDFLAGS: -lstdc++ -ltorch -lc10 -ltorch_cpu -L${SRCDIR}/libtorch/lib
+// #cgo LDFLAGS: -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcudnn -lcaffe2_nvrtc -lnvrtc-builtins -lnvrtc -lnvToolsExt -lc10_cuda -ltorch_cuda
+// #cgo CFLAGS: -I${SRCDIR} -O3 -Wall -Wno-unused-variable -Wno-deprecated-declarations -Wno-c++11-narrowing -g -Wno-sign-compare -Wno-unused-function
+// #cgo CFLAGS: -D_GLIBCXX_USE_CXX11_ABI=0
+// #cgo CFLAGS: -I/usr/local/cuda/include
+// #cgo CXXFLAGS: -std=c++17 -I${SRCDIR} -g -O3
+// #cgo CXXFLAGS: -I${SRCDIR}/libtorch/lib -I${SRCDIR}/libtorch/include -I${SRCDIR}/libtorch/include/torch/csrc/api/include -I/opt/libtorch/include/torch/csrc/api/include
+import "C"
diff --git a/logic/db_types/classes.go b/logic/db_types/classes.go
index 49eee43..3492763 100644
--- a/logic/db_types/classes.go
+++ b/logic/db_types/classes.go
@@ -6,3 +6,19 @@ const (
DATA_POINT_MODE_TRAINING DATA_POINT_MODE = 1
DATA_POINT_MODE_TESTING = 2
)
+
+type ModelClassStatus int
+
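+// iota + 1 keeps the numeric values at 1..3, matching the replaced MODEL_CLASS_STATUS_* constants already stored in the database.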
+const (
+ CLASS_STATUS_TO_TRAIN ModelClassStatus = iota + 1
+ CLASS_STATUS_TRAINING
+ CLASS_STATUS_TRAINED
+)
+
+type ModelClass struct {
+ Id string `db:"mc.id"`
+ ModelId string `db:"mc.model_id"`
+ Name string `db:"mc.name"`
+ ClassOrder int `db:"mc.class_order"`
+ Status int `db:"mc.status"`
+}
diff --git a/logic/db_types/definitions.go b/logic/db_types/definitions.go
new file mode 100644
index 0000000..8242672
--- /dev/null
+++ b/logic/db_types/definitions.go
@@ -0,0 +1,95 @@
+package dbtypes
+
+import (
+ "time"
+
+ "git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
+)
+
+type DefinitionStatus int
+
+const (
+ DEFINITION_STATUS_CANCELD_TRAINING DefinitionStatus = -4
+ DEFINITION_STATUS_FAILED_TRAINING = -3
+ DEFINITION_STATUS_PRE_INIT = 1
+ DEFINITION_STATUS_INIT = 2
+ DEFINITION_STATUS_TRAINING = 3
+ DEFINITION_STATUS_PAUSED_TRAINING = 6
+ DEFINITION_STATUS_TRANIED = 4
+ DEFINITION_STATUS_READY = 5
+)
+
+type Definition struct {
+ Id string `db:"md.id"`
+ ModelId string `db:"md.model_id"`
+ Accuracy float64 `db:"md.accuracy"`
+ TargetAccuracy int `db:"md.target_accuracy"`
+ Epoch int `db:"md.epoch"`
+ Status int `db:"md.status"`
+ CreatedOn time.Time `db:"md.created_on"`
+ EpochProgress int `db:"md.epoch_progress"`
+}
+
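+// SortByAccuracyDefinitions implements sort.Interface, ordering definitions by ascending accuracy.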
+type SortByAccuracyDefinitions []*Definition
+
+func (nf SortByAccuracyDefinitions) Len() int { return len(nf) }
+func (nf SortByAccuracyDefinitions) Swap(i, j int) { nf[i], nf[j] = nf[j], nf[i] }
+func (nf SortByAccuracyDefinitions) Less(i, j int) bool {
+ return nf[i].Accuracy < nf[j].Accuracy
+}
+
+func GetDefinition(db db.Db, definition_id string) (definition Definition, err error) {
+ err = GetDBOnce(db, &definition, "model_definition as md where id=$1;", definition_id)
+ return
+}
+
+func MakeDefenition(db db.Db, model_id string, target_accuracy int) (definition Definition, err error) {
+ var NewDefinition = struct {
+ ModelId string `db:"model_id"`
+ TargetAccuracy int `db:"target_accuracy"`
+ }{ModelId: model_id, TargetAccuracy: target_accuracy}
+
+ id, err := InsertReturnId(db, &NewDefinition, "model_definition", "id")
+ if err != nil {
+ return
+ }
+ return GetDefinition(db, id)
+}
+
+func (d Definition) UpdateStatus(db db.Db, status DefinitionStatus) (err error) {
+ _, err = db.Exec("update model_definition set status=$1 where id=$2", status, d.Id)
+ return
+}
+
+func (d Definition) MakeLayer(db db.Db, layer_order int, layer_type LayerType, shape string) (layer Layer, err error) {
+ var NewLayer = struct {
+ DefinitionId string `db:"def_id"`
+ LayerOrder int `db:"layer_order"`
+ LayerType LayerType `db:"layer_type"`
+ Shape string `db:"shape"`
+ }{
+ DefinitionId: d.Id,
+ LayerOrder: layer_order,
+ LayerType: layer_type,
+ Shape: shape,
+ }
+
+ id, err := InsertReturnId(db, &NewLayer, "model_definition_layer", "id")
+ if err != nil {
+ return
+ }
+
+ return GetLayer(db, id)
+}
+
+func (d Definition) GetLayers(db db.Db, filter string, args ...any) (layer []*Layer, err error) {
+ args = append(args, d.Id)
+ return GetDbMultitple[Layer](db, "model_definition_layer as mdl where mdl.def_id=$1 "+filter, args...)
+}
+
+func (d *Definition) UpdateAfterEpoch(db db.Db, accuracy float64) (err error) {
+ d.Accuracy = accuracy
+ d.Epoch += 1
+ _, err = db.Exec("update model_definition set epoch=$1, accuracy=$2 where id=$3", d.Epoch, d.Accuracy, d.Id)
+ return
+}
diff --git a/logic/db_types/layer.go b/logic/db_types/layer.go
new file mode 100644
index 0000000..1d738ed
--- /dev/null
+++ b/logic/db_types/layer.go
@@ -0,0 +1,50 @@
+package dbtypes
+
+import (
+ "encoding/json"
+
+ "git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
+)
+
+type LayerType int
+
+const (
+ LAYER_INPUT LayerType = 1
+ LAYER_DENSE = 2
+ LAYER_FLATTEN = 3
+ LAYER_SIMPLE_BLOCK = 4
+)
+
+type Layer struct {
+ Id string `db:"mdl.id"`
+ DefinitionId string `db:"mdl.def_id"`
+ LayerOrder string `db:"mdl.layer_order"`
+ LayerType LayerType `db:"mdl.layer_type"`
+ Shape string `db:"mdl.shape"`
+ ExpType string `db:"mdl.exp_type"`
+}
+
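+// Shapes are stored as JSON-encoded integer arrays (for example "[1,28,28]").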
+func ShapeToString(args ...int) string {
+ text, err := json.Marshal(args)
+ if err != nil {
+ panic("Could not generate Shape")
+ }
+ return string(text)
+}
+
+func StringToShape(str string) (shape []int64) {
+ err := json.Unmarshal([]byte(str), &shape)
+ if err != nil {
+ panic("Could not parse Shape")
+ }
+ return
+}
+
+func (l Layer) GetShape() []int64 {
+ return StringToShape(l.Shape)
+}
+
+func GetLayer(db db.Db, layer_id string) (layer Layer, err error) {
+ err = GetDBOnce(db, &layer, "model_definition_layer as mdl where mdl.id=$1", layer_id)
+ return
+}
diff --git a/logic/db_types/types.go b/logic/db_types/types.go
index 97fb993..0b033a2 100644
--- a/logic/db_types/types.go
+++ b/logic/db_types/types.go
@@ -1,9 +1,12 @@
package dbtypes
import (
- "errors"
+ "fmt"
+ "os"
+ "path"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
+ "github.com/jackc/pgx/v5"
)
const (
@@ -24,36 +27,6 @@ const (
READY_RETRAIN_FAILED = -7
)
-type ModelDefinitionStatus int
-
-type LayerType int
-
-const (
- LAYER_INPUT LayerType = 1
- LAYER_DENSE = 2
- LAYER_FLATTEN = 3
- LAYER_SIMPLE_BLOCK = 4
-)
-
-const (
- MODEL_DEFINITION_STATUS_CANCELD_TRAINING ModelDefinitionStatus = -4
- MODEL_DEFINITION_STATUS_FAILED_TRAINING = -3
- MODEL_DEFINITION_STATUS_PRE_INIT = 1
- MODEL_DEFINITION_STATUS_INIT = 2
- MODEL_DEFINITION_STATUS_TRAINING = 3
- MODEL_DEFINITION_STATUS_PAUSED_TRAINING = 6
- MODEL_DEFINITION_STATUS_TRANIED = 4
- MODEL_DEFINITION_STATUS_READY = 5
-)
-
-type ModelClassStatus int
-
-const (
- MODEL_CLASS_STATUS_TO_TRAIN ModelClassStatus = 1
- MODEL_CLASS_STATUS_TRAINING = 2
- MODEL_CLASS_STATUS_TRAINED = 3
-)
-
type ModelHeadStatus int
const (
@@ -78,8 +51,6 @@ type BaseModel struct {
CanTrain int `db:"can_train"`
}
-var ModelNotFoundError = errors.New("Model not found error")
-
func GetBaseModel(db db.Db, id string) (base *BaseModel, err error) {
var model BaseModel
err = GetDBOnce(db, &model, "models where id=$1", id)
@@ -97,11 +68,104 @@ func (m BaseModel) CanEval() bool {
return true
}
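+// removeFailedDataPoints deletes the on-disk images and database rows of data points marked as failed (status -1).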
+func (m BaseModel) removeFailedDataPoints(c BasePack) (err error) {
+ rows, err := c.GetDb().Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id=mdp.class_id where mc.model_id=$1 and mdp.status=-1;", m.Id)
+ if err != nil {
+ return
+ }
+ defer rows.Close()
+
+ base_path := path.Join("savedData", m.Id, "data")
+
+ for rows.Next() {
+ var dataPointId string
+ err = rows.Scan(&dataPointId)
+ if err != nil {
+ return
+ }
+
+ p := path.Join(base_path, dataPointId+"."+m.Format)
+
+ c.GetLogger().Warn("Removing image", "path", p)
+
+ err = os.RemoveAll(p)
+ if err != nil {
+ return
+ }
+ }
+
+ _, err = c.GetDb().Exec("delete from model_data_point as mdp using model_classes as mc where mdp.class_id = mc.id and mc.model_id=$1 and mdp.status=-1;", m.Id)
+ return
+}
+
+// Do NOT pass unfiltered user input in the filters string.
+func (m BaseModel) GetDefinitions(db db.Db, filters string, args ...any) ([]*Definition, error) {
+ n_args := []any{m.Id}
+ n_args = append(n_args, args...)
+ return GetDbMultitple[Definition](db, fmt.Sprintf("model_definition as md where md.model_id=$1 %s", filters), n_args...)
+}
+
+// Do NOT pass unfiltered user input in the filters string.
+func (m BaseModel) GetClasses(db db.Db, filters string, args ...any) ([]*ModelClass, error) {
+ n_args := []any{m.Id}
+ n_args = append(n_args, args...)
+ return GetDbMultitple[ModelClass](db, fmt.Sprintf("model_classes as mc where mc.model_id=$1 %s", filters), n_args...)
+}
+
+type DataPointIterator struct {
+ rows pgx.Rows
+ Model BaseModel
+}
+
+type DataPoint struct {
+ Class int
+ Path string
+}
+
+func (iter DataPointIterator) Close() {
+ iter.rows.Close()
+}
+
+func (m BaseModel) DataPoints(db db.Db, mode DATA_POINT_MODE) (data []DataPoint, err error) {
+ rows, err := db.Query(
+ "select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner "+
+ "join model_classes as mc on mc.id = mdp.class_id "+
+ "where mc.model_id = $1 and mdp.model_mode=$2;",
+ m.Id, mode)
+ if err != nil {
+ return
+ }
+ defer rows.Close()
+
+ data = []DataPoint{}
+
+ for rows.Next() {
+ var id string
+ var class_order int
+ var file_path string
+ if err = rows.Scan(&id, &class_order, &file_path); err != nil {
+ return
+ }
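+		// A file_path of "id://" means the image is stored locally under savedData/<model id>/data/.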
+ if file_path == "id://" {
+ data = append(data, DataPoint{
+ Path: path.Join("./savedData", m.Id, "data", id+"."+m.Format),
+ Class: class_order,
+ })
+ } else {
+ panic("TODO remote file path")
+ }
+ }
+ return
+}
+
+const RGB string = "rgb"
+const GRAY string = "greyscale"
+
func StringToImageMode(colorMode string) int {
switch colorMode {
- case "greyscale":
+ case GRAY:
return 1
- case "rgb":
+ case RGB:
return 3
default:
panic("unkown color mode")
diff --git a/logic/db_types/utils.go b/logic/db_types/utils.go
index 352cacd..c58c243 100644
--- a/logic/db_types/utils.go
+++ b/logic/db_types/utils.go
@@ -14,11 +14,13 @@ import (
"github.com/charmbracelet/log"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
+ "github.com/jackc/pgx/v5/pgconn"
db "git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
)
type BasePack interface {
+ db.Db
GetDb() db.Db
GetLogger() *log.Logger
GetHost() string
@@ -42,6 +44,18 @@ func (b BasePackStruct) GetLogger() *log.Logger {
return b.Logger
}
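+// Query, Exec and Begin delegate to the wrapped Db so that BasePackStruct also satisfies the db.Db side of the BasePack interface.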
+func (c BasePackStruct) Query(query string, args ...any) (pgx.Rows, error) {
+ return c.Db.Query(query, args...)
+}
+
+func (c BasePackStruct) Exec(query string, args ...any) (pgconn.CommandTag, error) {
+ return c.Db.Exec(query, args...)
+}
+
+func (c BasePackStruct) Begin() (pgx.Tx, error) {
+ return c.Db.Begin()
+}
+
func CheckEmpty(f url.Values, path string) bool {
return !f.Has(path) || f.Get(path) == ""
}
diff --git a/logic/models/classes/main.go b/logic/models/classes/main.go
index 81d7563..fc30ce7 100644
--- a/logic/models/classes/main.go
+++ b/logic/models/classes/main.go
@@ -7,15 +7,15 @@ import (
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
)
-type ModelClass struct {
+type ModelClassJSON struct {
Id string `json:"id"`
ModelId string `json:"model_id" db:"model_id"`
Name string `json:"name"`
Status int `json:"status"`
}
-func ListClasses(c BasePack, model_id string) (cls []*ModelClass, err error) {
- return GetDbMultitple[ModelClass](c.GetDb(), "model_classes where model_id=$1", model_id)
+func ListClassesJSON(c BasePack, model_id string) (cls []*ModelClassJSON, err error) {
+ return GetDbMultitple[ModelClassJSON](c.GetDb(), "model_classes where model_id=$1", model_id)
}
func ModelHasDataPoints(db db.Db, model_id string) (result bool, err error) {
diff --git a/logic/models/data.go b/logic/models/data.go
index 47112a0..b46d71d 100644
--- a/logic/models/data.go
+++ b/logic/models/data.go
@@ -435,7 +435,7 @@ func handleDataUpload(handle *Handle) {
}
model, err := GetBaseModel(handle.Db, id)
- if err == ModelNotFoundError {
+ if err == NotFoundError {
return c.SendJSONStatus(http.StatusNotFound, "Model not found")
} else if err != nil {
return c.Error500(err)
@@ -468,7 +468,7 @@ func handleDataUpload(handle *Handle) {
}
PostAuthJson(handle, "/models/data/class/new", User_Normal, func(c *Context, obj *CreateNewEmptyClass) *Error {
model, err := GetBaseModel(c.Db, obj.Id)
- if err == ModelNotFoundError {
+ if err == NotFoundError {
return c.JsonBadRequest("Model not found")
} else if err != nil {
return c.E500M("Failed to get model information", err)
@@ -495,7 +495,7 @@ func handleDataUpload(handle *Handle) {
return c.E500M("Could not create class", err)
}
- var modelClass model_classes.ModelClass
+ var modelClass model_classes.ModelClassJSON
err = GetDBOnce(c, &modelClass, "model_classes where id=$1;", id)
if err != nil {
return c.E500M("Failed to get class information but class was creted", err)
@@ -518,7 +518,7 @@ func handleDataUpload(handle *Handle) {
c.Logger.Info("model", "model", *model_id)
model, err := GetBaseModel(c.Db, *model_id)
- if err == ModelNotFoundError {
+ if err == NotFoundError {
return c.JsonBadRequest("Could not find the model")
} else if err != nil {
return c.E500M("Error getting model information", err)
@@ -626,7 +626,7 @@ func handleDataUpload(handle *Handle) {
c.Logger.Info("Trying to expand model", "id", id)
model, err := GetBaseModel(handle.Db, id)
- if err == ModelNotFoundError {
+ if err == NotFoundError {
return c.SendJSONStatus(http.StatusNotFound, "Model not found")
} else if err != nil {
return c.Error500(err)
@@ -670,7 +670,7 @@ func handleDataUpload(handle *Handle) {
}
model, err := GetBaseModel(handle.Db, dat.Id)
- if err == ModelNotFoundError {
+ if err == NotFoundError {
return c.SendJSONStatus(http.StatusNotFound, "Model not found")
} else if err != nil {
return c.Error500(err)
@@ -704,7 +704,7 @@ func handleDataUpload(handle *Handle) {
return c.Error500(err)
}
} else {
- _, err = handle.Db.Exec("delete from model_classes where model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TO_TRAIN)
+ _, err = handle.Db.Exec("delete from model_classes where model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TO_TRAIN)
if err != nil {
return c.Error500(err)
}
diff --git a/logic/models/edit.go b/logic/models/edit.go
index 0a6efa8..ebe1002 100644
--- a/logic/models/edit.go
+++ b/logic/models/edit.go
@@ -24,7 +24,7 @@ func handleEdit(handle *Handle) {
return c.Error500(err)
}
- cls, err := model_classes.ListClasses(c, model.Id)
+ cls, err := model_classes.ListClassesJSON(c, model.Id)
if err != nil {
return c.Error500(err)
}
@@ -35,9 +35,9 @@ func handleEdit(handle *Handle) {
}
type ReturnType struct {
- Classes []*model_classes.ModelClass `json:"classes"`
- HasData bool `json:"has_data"`
- NumberOfInvalidImages int `json:"number_of_invalid_images"`
+ Classes []*model_classes.ModelClassJSON `json:"classes"`
+ HasData bool `json:"has_data"`
+ NumberOfInvalidImages int `json:"number_of_invalid_images"`
}
c.ShowMessage = false
@@ -109,7 +109,7 @@ func handleEdit(handle *Handle) {
layers := []layerdef{}
for _, def := range defs {
- if def.Status == MODEL_DEFINITION_STATUS_TRAINING {
+ if def.Status == DEFINITION_STATUS_TRAINING {
rows, err := c.Db.Query("select id, layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", def.Id)
if err != nil {
return c.Error500(err)
@@ -166,7 +166,7 @@ func handleEdit(handle *Handle) {
for i, def := range defs {
var lay *[]layerdef = nil
- if def.Status == MODEL_DEFINITION_STATUS_TRAINING && !setLayers {
+ if def.Status == DEFINITION_STATUS_TRAINING && !setLayers {
lay = &layers
setLayers = true
}
diff --git a/logic/models/run.go b/logic/models/run.go.notemp
similarity index 100%
rename from logic/models/run.go
rename to logic/models/run.go.notemp
diff --git a/logic/models/train/reset.go b/logic/models/train/reset.go
index bc92eeb..bbc96a8 100644
--- a/logic/models/train/reset.go
+++ b/logic/models/train/reset.go
@@ -11,7 +11,7 @@ import (
func handleRest(handle *Handle) {
DeleteAuthJson(handle, "/models/train/reset", User_Normal, func(c *Context, dat *JustId) *Error {
model, err := GetBaseModel(c.Db, dat.Id)
- if err == ModelNotFoundError {
+ if err == NotFoundError {
return c.JsonBadRequest("Model not found")
} else if err != nil {
return c.E500M("Failed to get model", err)
diff --git a/logic/models/train/torch/modelloader/modelloader.go b/logic/models/train/torch/modelloader/modelloader.go
new file mode 100644
index 0000000..f5e6e64
--- /dev/null
+++ b/logic/models/train/torch/modelloader/modelloader.go
@@ -0,0 +1,149 @@
+package imageloader
+
+import (
+ "git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
+ types "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
+ "github.com/sugarme/gotch"
+ torch "github.com/sugarme/gotch/ts"
+ "github.com/sugarme/gotch/vision"
+)
+
+type Dataset struct {
+ TrainImages *torch.Tensor
+ TrainLabels *torch.Tensor
+ TestImages *torch.Tensor
+ TestLabels *torch.Tensor
+ TrainImagesSize int
+ TestImagesSize int
+ Device gotch.Device
+}
+
+func LoadImagesAndLables(db db.Db, m *types.BaseModel, mode types.DATA_POINT_MODE, classStart int, classEnd int) (imgs, labels *torch.Tensor, count int, err error) {
+	train_points, err := m.DataPoints(db, mode)
+ if err != nil {
+ return
+ }
+
+ size := int64(classEnd - classStart + 1)
+
+ pimgs := []*torch.Tensor{}
+ plabels := []*torch.Tensor{}
+
+ for _, point := range train_points {
+ var img, label *torch.Tensor
+ img, err = vision.Load(point.Path)
+ if err != nil {
+ return
+ }
+ pimgs = append(pimgs, img)
+
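+		// Build a one-hot label over [classStart, classEnd]; points outside that range get an all-zero label.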
+ t_label := make([]int, size)
+ if point.Class <= classEnd && point.Class >= classStart {
+ t_label[point.Class-classStart] = 1
+ }
+
+ label, err = torch.OfSlice(t_label)
+ if err != nil {
+ return
+ }
+ plabels = append(plabels, label)
+ }
+
+	imgs, err = torch.Stack(pimgs, 0)
+	if err != nil {
+		return
+	}
+
+	labels, err = torch.Stack(plabels, 0)
+	if err != nil {
+		return
+	}
+
+	count = len(pimgs)
+
+ labels, err = labels.ToDtype(gotch.Float, false, false, true)
+ if err != nil {
+ return
+ }
+
+ imgs, err = imgs.ToDtype(gotch.Float, false, false, true)
+ if err != nil {
+ return
+ }
+
+ return
+}
+
+func NewDataset(db db.Db, m *types.BaseModel, classStart int, classEnd int) (ds *Dataset, err error) {
+ trainImages, trainLabels, train_count, err := LoadImagesAndLables(db, m, types.DATA_POINT_MODE_TRAINING, classStart, classEnd)
+ if err != nil {
+ return
+ }
+
+ testImages, testLabels, test_count, err := LoadImagesAndLables(db, m, types.DATA_POINT_MODE_TESTING, classStart, classEnd)
+ if err != nil {
+ return
+ }
+
+ ds = &Dataset{
+ TrainImages: trainImages,
+ TrainLabels: trainLabels,
+ TestImages: testImages,
+ TestLabels: testLabels,
+ TrainImagesSize: train_count,
+ TestImagesSize: test_count,
+ Device: gotch.CPU,
+ }
+ return
+}
+
+func (ds *Dataset) To(device gotch.Device) (err error) {
+ ds.TrainImages, err = ds.TrainImages.ToDevice(device, ds.TrainImages.DType(), device.IsCuda(), true, true)
+ if err != nil {
+ return
+ }
+
+ ds.TrainLabels, err = ds.TrainLabels.ToDevice(device, ds.TrainLabels.DType(), device.IsCuda(), true, true)
+ if err != nil {
+ return
+ }
+
+ ds.TestImages, err = ds.TestImages.ToDevice(device, ds.TestImages.DType(), device.IsCuda(), true, true)
+ if err != nil {
+ return
+ }
+
+ ds.TestLabels, err = ds.TestLabels.ToDevice(device, ds.TestLabels.DType(), device.IsCuda(), true, true)
+ if err != nil {
+ return
+ }
+
+ ds.Device = device
+ return
+}
+
+func (ds *Dataset) TestIter(batchSize int64) *torch.Iter2 {
+ return torch.MustNewIter2(ds.TestImages, ds.TestLabels, batchSize)
+}
+
+func (ds *Dataset) TrainIter(batchSize int64) (iter *torch.Iter2, err error) {
+
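+	// Iterate over detached copies so the dataset's stored training tensors are left untouched between epochs.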
+ train_images, err := ds.TrainImages.DetachCopy(false)
+ if err != nil {
+ return
+ }
+
+ train_labels, err := ds.TrainLabels.DetachCopy(false)
+ if err != nil {
+ return
+ }
+
+ iter, err = torch.NewIter2(train_images, train_labels, batchSize)
+ if err != nil {
+ return
+ }
+
+ return
+}
diff --git a/logic/models/train/torch/torch.go b/logic/models/train/torch/torch.go
new file mode 100644
index 0000000..7a34723
--- /dev/null
+++ b/logic/models/train/torch/torch.go
@@ -0,0 +1,81 @@
+package train
+
+import (
+ types "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
+
+ "github.com/charmbracelet/log"
+ "github.com/sugarme/gotch"
+ "github.com/sugarme/gotch/nn"
+
+ //"github.com/sugarme/gotch"
+ //"github.com/sugarme/gotch/vision"
+ torch "github.com/sugarme/gotch/ts"
+)
+
+type IForwardable interface {
+ Forward(xs *torch.Tensor) *torch.Tensor
+}
+
+// Container for a model
+type ContainerModel struct {
+ Seq *nn.SequentialT
+ Vs *nn.VarStore
+}
+
+func (n *ContainerModel) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
+ return n.Seq.ForwardT(x, train)
+}
+
+func (n *ContainerModel) To(device gotch.Device) {
+ n.Vs.ToDevice(device)
+}
+
+func BuildModel(layers []*types.Layer, _lastLinearSize int64, addSigmoid bool) *ContainerModel {
+
+ base_vs := nn.NewVarStore(gotch.CPU)
+ vs := base_vs.Root()
+ seq := nn.SeqT()
+
+ var lastLinearSize int64 = _lastLinearSize
+ lastLinearConv := []int64{}
+
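+	// lastLinearConv tracks the running convolutional output shape (index 0 is treated as the channel count); lastLinearSize tracks the flat width fed into Dense layers.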
+ for _, layer := range layers {
+ if layer.LayerType == types.LAYER_INPUT {
+ lastLinearConv = layer.GetShape()
+ log.Info("Input: ", "In:", lastLinearConv)
+ } else if layer.LayerType == types.LAYER_DENSE {
+ shape := layer.GetShape()
+ log.Info("New Dense: ", "In:", lastLinearSize, "out:", shape[0])
+ seq.Add(NewLinear(vs, lastLinearSize, shape[0]))
+ lastLinearSize = shape[0]
+ } else if layer.LayerType == types.LAYER_FLATTEN {
+ seq.Add(NewFlatten())
+ lastLinearSize = 1
+ for _, i := range lastLinearConv {
+ lastLinearSize *= i
+ }
+ log.Info("Flatten: ", "In:", lastLinearConv, "out:", lastLinearSize)
+ } else if layer.LayerType == types.LAYER_SIMPLE_BLOCK {
+ log.Info("New Block: ", "In:", lastLinearConv, "out:", []int64{lastLinearConv[1] / 2, lastLinearConv[2] / 2, 128})
+ seq.Add(NewSimpleBlock(vs, lastLinearConv[0]))
+ lastLinearConv[0] = 128
+ lastLinearConv[1] /= 2
+ lastLinearConv[2] /= 2
+ }
+ }
+
+ if addSigmoid {
+ seq.Add(NewSigmoid())
+ }
+
+ b := &ContainerModel{
+ Seq: seq,
+ Vs: base_vs,
+ }
+ return b
+}
+
+func SaveModel(model *ContainerModel, modelFn string) (err error) {
+ model.Vs.ToDevice(gotch.CPU)
+ return model.Vs.Save(modelFn)
+}
diff --git a/logic/models/train/torch/utils.go b/logic/models/train/torch/utils.go
new file mode 100644
index 0000000..6f7083c
--- /dev/null
+++ b/logic/models/train/torch/utils.go
@@ -0,0 +1,167 @@
+package train
+
+import (
+ "github.com/charmbracelet/log"
+
+ "github.com/sugarme/gotch/nn"
+ torch "github.com/sugarme/gotch/ts"
+)
+
+func or_panic(err error) {
+ if err != nil {
+ log.Fatal(err)
+ }
+}
+
+type SimpleBlock struct {
+ C1, C2 *nn.Conv2D
+ BN1 *nn.BatchNorm
+}
+
+// NewSimpleBlock returns a SimpleBlock made of two convolutions followed by a batch norm.
+func NewSimpleBlock(vs *nn.Path, inplanes int64) *SimpleBlock {
+ conf1 := nn.DefaultConv2DConfig()
+ conf1.Stride = []int64{2, 2}
+
+ conf2 := nn.DefaultConv2DConfig()
+ conf2.Padding = []int64{2, 2}
+
+ b := &SimpleBlock{
+ C1: nn.NewConv2D(vs, inplanes, 128, 3, conf1),
+ C2: nn.NewConv2D(vs, 128, 128, 3, conf2),
+ BN1: nn.NewBatchNorm(vs, 2, 128, nn.DefaultBatchNormConfig()),
+ }
+ return b
+}
+
+// Forward method
+func (b *SimpleBlock) Forward(x *torch.Tensor) *torch.Tensor {
+ identity := x
+
+ out := b.C1.Forward(x)
+ out = out.MustRelu(false)
+
+ out = b.C2.Forward(out)
+ out = out.MustRelu(false)
+
+ shape, err := out.Size()
+ or_panic(err)
+
+ out, err = out.AdaptiveAvgPool2d(shape, false)
+ or_panic(err)
+
+ out = b.BN1.Forward(out)
+ out, err = out.LeakyRelu(false)
+ or_panic(err)
+
+ out = out.MustAdd(identity, false)
+ out = out.MustRelu(false)
+
+ return out
+}
+
+func (b *SimpleBlock) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
+ identity := x
+
+ out := b.C1.ForwardT(x, train)
+ out = out.MustRelu(false)
+
+ out = b.C2.ForwardT(out, train)
+ out = out.MustRelu(false)
+
+ shape, err := out.Size()
+ or_panic(err)
+
+ out, err = out.AdaptiveAvgPool2d(shape, false)
+ or_panic(err)
+
+ out = b.BN1.ForwardT(out, train)
+ out, err = out.LeakyRelu(false)
+ or_panic(err)
+
+ out = out.MustAdd(identity, false)
+ out = out.MustRelu(false)
+
+ return out
+}
+
+type MyLinear struct {
+ FC1 *nn.Linear
+}
+
+// NewLinear returns a MyLinear wrapping a single fully connected layer.
+func NewLinear(vs *nn.Path, in, out int64) *MyLinear {
+ config := nn.DefaultLinearConfig()
+ b := &MyLinear{
+ FC1: nn.NewLinear(vs, in, out, config),
+ }
+ return b
+}
+
+// Forward method
+func (b *MyLinear) Forward(x *torch.Tensor) *torch.Tensor {
+ var err error
+
+ out := b.FC1.Forward(x)
+
+ out, err = out.Relu(false)
+ or_panic(err)
+
+ return out
+}
+
+func (b *MyLinear) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
+ var err error
+
+ out := b.FC1.ForwardT(x, train)
+
+ out, err = out.Relu(false)
+ or_panic(err)
+
+ return out
+}
+
+type Flatten struct{}
+
+// NewFlatten returns a Flatten layer.
+func NewFlatten() *Flatten {
+ return &Flatten{}
+}
+
+// Forward method
+func (b *Flatten) Forward(x *torch.Tensor) *torch.Tensor {
+
+ out, err := x.Flatten(1, -1, false)
+ or_panic(err)
+
+ return out
+}
+
+func (b *Flatten) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
+
+ out, err := x.Flatten(1, -1, false)
+ or_panic(err)
+
+ return out
+}
+
+type Sigmoid struct{}
+
+func NewSigmoid() *Sigmoid {
+ return &Sigmoid{}
+}
+
+func (b *Sigmoid) Forward(x *torch.Tensor) *torch.Tensor {
+ out, err := x.Sigmoid(false)
+ or_panic(err)
+
+ return out
+}
+
+func (b *Sigmoid) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
+ out, err := x.Sigmoid(false)
+ or_panic(err)
+
+ return out
+}
+
diff --git a/logic/models/train/train.go b/logic/models/train/train_normal.go
similarity index 73%
rename from logic/models/train/train.go
rename to logic/models/train/train_normal.go
index 7b18af7..c73ee58 100644
--- a/logic/models/train/train.go
+++ b/logic/models/train/train_normal.go
@@ -8,6 +8,7 @@ import (
"os"
"os/exec"
"path"
+ "runtime/debug"
"sort"
"strconv"
"strings"
@@ -15,12 +16,16 @@ import (
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
- model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
+ my_torch "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch"
+ modelloader "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/modelloader"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
"github.com/charmbracelet/log"
"github.com/goccy/go-json"
+ "github.com/sugarme/gotch"
+ "github.com/sugarme/gotch/nn"
+ torch "github.com/sugarme/gotch/ts"
)
const EPOCH_PER_RUN = 20
@@ -39,21 +44,6 @@ func getDir() string {
return dir
}
-// This function creates a new model_definition
-func MakeDefenition(db db.Db, model_id string, target_accuracy int) (id string, err error) {
- var NewDefinition = struct {
- ModelId string `db:"model_id"`
- TargetAccuracy int `db:"target_accuracy"`
- }{ModelId: model_id, TargetAccuracy: target_accuracy}
-
- return InsertReturnId(db, &NewDefinition, "model_definition", "id")
-}
-
-func ModelDefinitionUpdateStatus(c BasePack, id string, status ModelDefinitionStatus) (err error) {
- _, err = c.GetDb().Exec("update model_definition set status = $1 where id = $2", status, id)
- return
-}
-
func MakeLayer(db db.Db, def_id string, layer_order int, layer_type LayerType, shape string) (err error) {
_, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape) values ($1, $2, $3, $4)", def_id, layer_order, layer_type, shape)
return
@@ -64,48 +54,6 @@ func MakeLayerExpandable(db db.Db, def_id string, layer_order int, layer_type La
return
}
-func generateCvs(c BasePack, run_path string, model_id string) (count int, err error) {
- db := c.GetDb()
-
- var co struct {
- Count int `db:"count(*)"`
- }
- err = GetDBOnce(db, &co, "model_classes where model_id=$1;", model_id)
- if err != nil {
- return
- }
- count = co.Count
-
- data, err := db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2;", model_id, DATA_POINT_MODE_TRAINING)
- if err != nil {
- return
- }
- defer data.Close()
-
- f, err := os.Create(path.Join(run_path, "train.csv"))
- if err != nil {
- return
- }
- defer f.Close()
- f.Write([]byte("Id,Index\n"))
-
- for data.Next() {
- var id string
- var class_order int
- var file_path string
- if err = data.Scan(&id, &class_order, &file_path); err != nil {
- return
- }
- if file_path == "id://" {
- f.Write([]byte(id + "," + strconv.Itoa(class_order) + "\n"))
- } else {
- return count, errors.New("TODO generateCvs to file_path " + file_path)
- }
- }
-
- return
-}
-
func setModelClassStatus(c BasePack, status ModelClassStatus, filter string, args ...any) (err error) {
_, err = c.GetDb().Exec(fmt.Sprintf("update model_classes set status=%d where %s", status, filter), args...)
return
@@ -118,14 +66,14 @@ func generateCvsExp(c BasePack, run_path string, model_id string, doPanic bool)
var co struct {
Count int `db:"count(*)"`
}
- err = GetDBOnce(db, &co, "model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
+ err = GetDBOnce(db, &co, "model_classes where model_id=$1 and status=$2;", model_id, CLASS_STATUS_TRAINING)
if err != nil {
return
}
count = co.Count
if count == 0 {
- err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TO_TRAIN)
+ err = setModelClassStatus(c, CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, CLASS_STATUS_TO_TRAIN)
if err != nil {
return
}
@@ -137,7 +85,7 @@ func generateCvsExp(c BasePack, run_path string, model_id string, doPanic bool)
return generateCvsExp(c, run_path, model_id, true)
}
- data, err := db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING)
+ data, err := db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, DATA_POINT_MODE_TRAINING, CLASS_STATUS_TRAINING)
if err != nil {
return
}
@@ -167,117 +115,182 @@ func generateCvsExp(c BasePack, run_path string, model_id string, doPanic bool)
return
}
-func trainDefinition(c BasePack, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
- l := c.GetLogger()
+func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_torch.ContainerModel, classes []*ModelClass) (accuracy float64, model *my_torch.ContainerModel, err error) {
+ log := c.GetLogger()
db := c.GetDb()
- l.Warn("About to start training definition")
+ log.Warn("About to start training definition")
+
+ model = in_model
accuracy = 0
- layers, err := db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
- if err != nil {
- return
- }
- defer layers.Close()
- type layerrow struct {
- LayerType int
- Shape string
- LayerNum int
- }
-
- got := []layerrow{}
- i := 1
-
- for layers.Next() {
- var row = layerrow{}
- if err = layers.Scan(&row.LayerType, &row.Shape); err != nil {
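+	// Reuse the network carried over from the previous round via in_model; a new model is only built from the stored layers on the first pass.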
+ if model == nil {
+ var layers []*Layer
+ layers, err = def.GetLayers(db, " order by layer_order asc")
+ if err != nil {
return
}
- row.Shape = shapeToSize(row.Shape)
- row.LayerNum = 1
- got = append(got, row)
- i = i + 1
+
+ model = my_torch.BuildModel(layers, 0, true)
}
- // Generate run folder
- run_path := path.Join("/tmp", model.Id, "defs", definition_id)
+ // TODO Make the runner provide this
+ // device := gotch.CudaIfAvailable()
+ device := gotch.CPU
- err = os.MkdirAll(run_path, os.ModePerm)
+ result_path := path.Join(getDir(), "savedData", m.Id, "defs", def.Id)
+ err = os.MkdirAll(result_path, os.ModePerm)
if err != nil {
return
}
- classCount, err := generateCvs(c, run_path, model.Id)
+ model.To(device)
+ defer model.To(gotch.CPU)
+
+ var ds *modelloader.Dataset
+ ds, err = modelloader.NewDataset(db, m, classes[0].ClassOrder, classes[len(classes)-1].ClassOrder)
if err != nil {
return
}
- // Create python script
- f, err := os.Create(path.Join(run_path, "run.py"))
- if err != nil {
- return
- }
- defer f.Close()
-
- tmpl, err := template.New("python_model_template.py").ParseFiles("views/py/python_model_template.py")
+ err = ds.To(device)
if err != nil {
return
}
- // Copy result around
- result_path := path.Join("savedData", model.Id, "defs", definition_id)
-
- if err = tmpl.Execute(f, AnyMap{
- "Layers": got,
- "Size": got[0].Shape,
- "DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
- "RunPath": run_path,
- "ColorMode": model.ImageMode,
- "Model": model,
- "EPOCH_PER_RUN": EPOCH_PER_RUN,
- "DefId": definition_id,
- "LoadPrev": load_prev,
- "LastModelRunPath": path.Join(getDir(), result_path, "model.keras"),
- "SaveModelPath": path.Join(getDir(), result_path),
- "Depth": classCount,
- "StartPoint": 0,
- "Host": c.GetHost(),
- }); err != nil {
- return
- }
-
- // Run the command
- out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).CombinedOutput()
- if err != nil {
- l.Debug(string(out))
- return
- }
-
- l.Info("Python finished running")
-
- if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
- return
- }
-
- accuracy_file, err := os.Open(path.Join(run_path, "accuracy.val"))
- if err != nil {
- return
- }
- defer accuracy_file.Close()
-
- accuracy_file_bytes, err := io.ReadAll(accuracy_file)
+ opt, err := nn.DefaultAdamConfig().Build(model.Vs, 0.001)
if err != nil {
return
}
- accuracy, err = strconv.ParseFloat(string(accuracy_file_bytes), 64)
+ for epoch := 0; epoch < EPOCH_PER_RUN; epoch++ {
+ var trainIter *torch.Iter2
+ trainIter, err = ds.TrainIter(64)
+ if err != nil {
+ return
+ }
+ // trainIter.ToDevice(device)
+
+ log.Info("epoch", "epoch", epoch)
+
+ var trainLoss float64 = 0
+ var trainCorrect float64 = 0
+ ok := true
+ for ok {
+ var item torch.Iter2Item
+ var loss *torch.Tensor
+ item, ok = trainIter.Next()
+ if !ok {
+ continue
+ }
+
+ pred := model.ForwardT(item.Data, true)
+
+ // Calculate loss
+
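+			// Empty tensors are passed for the optional weight/pos_weight arguments, with reduction mode 1 (mean in libtorch's Reduction enum).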
+ loss, err = pred.BinaryCrossEntropyWithLogits(item.Label, &torch.Tensor{}, &torch.Tensor{}, 1, false)
+ if err != nil {
+ return
+ }
+
+ loss, err = loss.SetRequiresGrad(true, false)
+ if err != nil {
+ return
+ }
+
+ err = opt.ZeroGrad()
+ if err != nil {
+ return
+ }
+
+ err = loss.Backward()
+ if err != nil {
+ return
+ }
+
+ err = opt.Step()
+ if err != nil {
+ return
+ }
+
+ trainLoss = loss.Float64Values()[0]
+
+ // Calculate accuracy
+
+ var p_pred, p_labels *torch.Tensor
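+			// Compare the argmax of the predictions against the argmax of the one-hot labels to count correct examples in this batch.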
+ p_pred, err = pred.Argmax([]int64{1}, true, false)
+ if err != nil {
+ return
+ }
+
+ p_labels, err = item.Label.Argmax([]int64{1}, true, false)
+ if err != nil {
+ return
+ }
+
+ floats := p_pred.Float64Values()
+ floats_labels := p_labels.Float64Values()
+
+ for i := range floats {
+ if floats[i] == floats_labels[i] {
+ trainCorrect += 1
+ }
+ }
+ }
+
+ log.Info("model training epoch done loss", "loss", trainLoss, "correct", trainCorrect, "out", ds.TrainImagesSize, "accuracy", trainCorrect/float64(ds.TrainImagesSize))
+
+ /*correct := int64(0)
+ //torch.NoGrad(func() {
+ ok = true
+ testIter := ds.TestIter(64)
+ for ok {
+ var item torch.Iter2Item
+ item, ok = testIter.Next()
+ if !ok {
+ continue
+ }
+
+ output := model.Forward(item.Data)
+
+ var pred, labels *torch.Tensor
+ pred, err = output.Argmax([]int64{1}, true, false)
+ if err != nil {
+ return
+ }
+
+ labels, err = item.Label.Argmax([]int64{1}, true, false)
+ if err != nil {
+ return
+ }
+
+ floats := pred.Float64Values()
+ floats_labels := labels.Float64Values()
+
+ for i := range floats {
+ if floats[i] == floats_labels[i] {
+ correct += 1
+ }
+ }
+ }
+
+ accuracy = float64(correct) / float64(ds.TestImagesSize)
+
+ log.Info("Eval accuracy", "accuracy", accuracy)
+
+ err = def.UpdateAfterEpoch(db, accuracy*100)
+ if err != nil {
+ return
+ }*/
+ //})
+ }
+
+ err = my_torch.SaveModel(model, path.Join(result_path, "model.dat"))
if err != nil {
return
}
- os.RemoveAll(run_path)
-
- l.Info("Model finished training!", "accuracy", accuracy)
+ log.Info("Model finished training!", "accuracy", accuracy)
return
}
@@ -287,7 +300,7 @@ func generateCvsExpandExp(c BasePack, run_path string, model_id string, offset i
var co struct {
Count int `db:"count(*)"`
}
- err = GetDBOnce(db, &co, "model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
+ err = GetDBOnce(db, &co, "model_classes where model_id=$1 and status=$2;", model_id, CLASS_STATUS_TRAINING)
if err != nil {
return
}
@@ -296,7 +309,7 @@ func generateCvsExpandExp(c BasePack, run_path string, model_id string, offset i
count := co.Count
if count == 0 {
- err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TO_TRAIN)
+ err = setModelClassStatus(c, CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, CLASS_STATUS_TO_TRAIN)
if err != nil {
return
} else if doPanic {
@@ -305,7 +318,7 @@ func generateCvsExpandExp(c BasePack, run_path string, model_id string, offset i
return generateCvsExpandExp(c, run_path, model_id, offset, true)
}
- data, err := db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING)
+ data, err := db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, DATA_POINT_MODE_TRAINING, CLASS_STATUS_TRAINING)
if err != nil {
return
}
@@ -339,7 +352,7 @@ func generateCvsExpandExp(c BasePack, run_path string, model_id string, offset i
// This is to load some extra data so that the model has more things to train on
//
- data_other, err := db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINED, count*10)
+ data_other, err := db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, DATA_POINT_MODE_TRAINING, CLASS_STATUS_TRAINED, count*10)
if err != nil {
return
}
@@ -392,7 +405,7 @@ func trainDefinitionExpandExp(c BasePack, model *BaseModel, definition_id string
l.Info("Got exp head", "head", exp)
- if err = UpdateStatus(c.GetDb(), "exp_model_head", exp.Id, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
+ if err = UpdateStatus(c.GetDb(), "exp_model_head", exp.Id, DEFINITION_STATUS_TRAINING); err != nil {
return
}
@@ -568,7 +581,7 @@ func trainDefinitionExp(c BasePack, model *BaseModel, definition_id string, load
exp := heads[0]
- if err = UpdateStatus(db, "exp_model_head", exp.Id, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
+ if err = UpdateStatus(db, "exp_model_head", exp.Id, DEFINITION_STATUS_TRAINING); err != nil {
return
}
@@ -731,70 +744,67 @@ func (nf ToRemoveList) Less(i, j int) bool {
func trainModel(c BasePack, model *BaseModel) (err error) {
db := c.GetDb()
- l := c.GetLogger()
+ log := c.GetLogger()
- definitionsRows, err := db.Query("select id, target_accuracy, epoch from model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id)
- if err != nil {
- l.Error("Failed to train Model! Err:")
- l.Error(err)
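+	// fail logs the error with a stack trace and marks the whole model as failed to train.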
+ fail := func(err error) {
+ log.Error("Failed to train Model!", "err", err, "stack", string(debug.Stack()))
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
+ }
+
+ defs, err := model.GetDefinitions(db, "and md.status=$2", DEFINITION_STATUS_INIT)
+ if err != nil {
+ fail(err)
return
}
- defer definitionsRows.Close()
- var definitions TraingModelRowDefinitions = []TrainModelRow{}
-
- for definitionsRows.Next() {
- var rowv TrainModelRow
- rowv.acuracy = 0
- if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy, &rowv.epoch); err != nil {
- l.Error("Failed to train Model Could not read definition from db!Err:")
- l.Error(err)
- ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
- return
- }
- definitions = append(definitions, rowv)
- }
+ var definitions SortByAccuracyDefinitions = defs
if len(definitions) == 0 {
- l.Error("No Definitions defined!")
- ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
+		fail(errors.New("No definitions defined!"))
return
}
- firstRound := true
finished := false
+ models := map[string]*my_torch.ContainerModel{}
+
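+	// Only classes still marked to-train are loaded here; their class_order range defines the width of the label vector used by the dataset.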
+	classes, err := model.GetClasses(db, " and status=$2 order by mc.class_order asc", CLASS_STATUS_TO_TRAIN)
+	if err != nil {
+		fail(err)
+		return
+	}
+
for {
+ // Keep track of definitions that did not train fast enough
var toRemove ToRemoveList = []int{}
+
for i, def := range definitions {
- ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING)
- accuracy, err := trainDefinition(c, model, def.id, !firstRound)
+
+ err := def.UpdateStatus(c, DEFINITION_STATUS_TRAINING)
if err != nil {
- l.Error("Failed to train definition!Err:", "err", err)
- ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
+				log.Error("Could not set definition status to training", "err", err)
+ def.UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING)
toRemove = append(toRemove, i)
continue
}
- def.epoch += EPOCH_PER_RUN
- accuracy = accuracy * 100
- def.acuracy = float64(accuracy)
- definitions[i].epoch += EPOCH_PER_RUN
- definitions[i].acuracy = accuracy
+ accuracy, ml_model, err := trainDefinition(c, model, def, models[def.Id], classes)
+ if err != nil {
+				log.Error("Failed to train definition!", "err", err)
+ def.UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING)
+ toRemove = append(toRemove, i)
+ continue
+ }
+ models[def.Id] = ml_model
- if accuracy >= float64(def.target_accuracy) {
- l.Info("Found a definition that reaches target_accuracy!")
- _, err = db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.epoch, def.id)
+ if accuracy >= float64(def.TargetAccuracy) {
+ log.Info("Found a definition that reaches target_accuracy!")
+ _, err = db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, DEFINITION_STATUS_TRANIED, def.Epoch, def.Id)
if err != nil {
- l.Error("Failed to train definition!Err:\n", "err", err)
+					log.Error("Failed to train definition!", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return err
}
- _, err = db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", MODEL_DEFINITION_STATUS_CANCELD_TRAINING, def.id, model.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
+ _, err = db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", DEFINITION_STATUS_CANCELD_TRAINING, def.Id, model.Id, DEFINITION_STATUS_FAILED_TRAINING)
if err != nil {
- l.Error("Failed to train definition!Err:\n", "err", err)
+					log.Error("Failed to train definition!", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return err
}
@@ -803,31 +813,32 @@ func trainModel(c BasePack, model *BaseModel) (err error) {
break
}
- if def.epoch > MAX_EPOCH {
- fmt.Printf("Failed to train definition! Accuracy less %f < %d\n", accuracy, def.target_accuracy)
- ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
+ if def.Epoch > MAX_EPOCH {
+ fmt.Printf("Failed to train definition! Accuracy less %f < %d\n", accuracy, def.TargetAccuracy)
+ def.UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING)
toRemove = append(toRemove, i)
continue
}
- _, err = db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.epoch, MODEL_DEFINITION_STATUS_PAUSED_TRAINING, def.id)
+ _, err = db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.Epoch, DEFINITION_STATUS_PAUSED_TRAINING, def.Id)
if err != nil {
- l.Error("Failed to train definition!Err:\n", "err", err)
+ log.Error("Failed to train definition!Err:\n", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return err
}
}
- firstRound = false
if finished {
break
}
sort.Sort(sort.Reverse(toRemove))
- l.Info("Round done", "toRemove", toRemove)
+ log.Info("Round done", "toRemove", toRemove)
for _, n := range toRemove {
+			// Clean up unused models
+ models[definitions[n].Id] = nil
definitions = remove(definitions, n)
}
@@ -843,62 +854,50 @@ func trainModel(c BasePack, model *BaseModel) (err error) {
sort.Sort(sort.Reverse(definitions))
- acc := definitions[0].acuracy - 20.0
+ acc := definitions[0].Accuracy - 20.0
- l.Info("Training models, Highest acc", "acc", definitions[0].acuracy, "mod_acc", acc)
+ log.Info("Training models, Highest acc", "acc", definitions[0].Accuracy, "mod_acc", acc)
toRemove = []int{}
for i, def := range definitions {
- if def.acuracy < acc {
+ if def.Accuracy < acc {
toRemove = append(toRemove, i)
}
}
- l.Info("Removing due to accuracy", "toRemove", toRemove)
+ log.Info("Removing due to accuracy", "toRemove", toRemove)
sort.Sort(sort.Reverse(toRemove))
for _, n := range toRemove {
- l.Warn("Removing definition not fast enough learning", "n", n)
- ModelDefinitionUpdateStatus(c, definitions[n].id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
+ log.Warn("Removing definition not fast enough learning", "n", n)
+ definitions[n].UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING)
+ models[definitions[n].Id] = nil
definitions = remove(definitions, n)
}
}
- rows, err := db.Query("select id from model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, MODEL_DEFINITION_STATUS_TRANIED)
+ var def Definition
+ err = GetDBOnce(c, &def, "model_definition as md where md.model_id=$1 and md.status=$2 order by md.accuracy desc limit 1;", model.Id, DEFINITION_STATUS_TRANIED)
if err != nil {
- l.Error("DB: failed to read definition")
- l.Error(err)
- ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
- return
- }
- defer rows.Close()
-
- if !rows.Next() {
- // TODO Make the Model status have a message
- l.Error("All definitions failed to train!")
+ if err == NotFoundError {
+ log.Error("All definitions failed to train!")
+ } else {
+ log.Error("DB: failed to read definition", "err", err)
+ }
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
- var id string
- if err = rows.Scan(&id); err != nil {
- l.Error("Failed to read id:")
- l.Error(err)
+ if err = def.UpdateStatus(c, DEFINITION_STATUS_READY); err != nil {
+ log.Error("Failed to update model definition", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
- if _, err = db.Exec("update model_definition set status=$1 where id=$2;", MODEL_DEFINITION_STATUS_READY, id); err != nil {
- l.Error("Failed to update model definition")
- l.Error(err)
- ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
- return
- }
-
- to_delete, err := db.Query("select id from model_definition where status != $1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id)
+ to_delete, err := db.Query("select id from model_definition where status != $1 and model_id=$2", DEFINITION_STATUS_READY, model.Id)
if err != nil {
- l.Error("Failed to select model_definition to delete")
- l.Error(err)
+ log.Error("Failed to select model_definition to delete")
+ log.Error(err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
@@ -907,8 +906,7 @@ func trainModel(c BasePack, model *BaseModel) (err error) {
for to_delete.Next() {
var id string
if err = to_delete.Scan(&id); err != nil {
- l.Error("Failed to scan the id of a model_definition to delete")
- l.Error(err)
+ log.Error("Failed to scan the id of a model_definition to delete", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
@@ -916,9 +914,9 @@ func trainModel(c BasePack, model *BaseModel) (err error) {
}
// TODO Check if returning also works here
- if _, err = db.Exec("delete from model_definition where status!=$1 and model_id=$2;", MODEL_DEFINITION_STATUS_READY, model.Id); err != nil {
- l.Error("Failed to delete model_definition")
- l.Error(err)
+ if _, err = db.Exec("delete from model_definition where status!=$1 and model_id=$2;", DEFINITION_STATUS_READY, model.Id); err != nil {
+ log.Error("Failed to delete model_definition")
+ log.Error(err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
@@ -949,7 +947,7 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
var definitions TrainModelRowUsables
- definitions, err = GetDbMultitple[TrainModelRowUsable](db, "model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id)
+ definitions, err = GetDbMultitple[TrainModelRowUsable](db, "model_definition where status=$1 and model_id=$2", DEFINITION_STATUS_INIT, model.Id)
if err != nil {
l.Error("Failed to get definitions")
return
@@ -965,11 +963,11 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
for {
var toRemove ToRemoveList = []int{}
for i, def := range definitions {
- ModelDefinitionUpdateStatus(c, def.Id, MODEL_DEFINITION_STATUS_TRAINING)
+ Definition{Id: def.Id}.UpdateStatus(c, DEFINITION_STATUS_TRAINING)
accuracy, err := trainDefinitionExp(c, model, def.Id, !firstRound)
if err != nil {
l.Error("Failed to train definition!Err:", "err", err)
- ModelDefinitionUpdateStatus(c, def.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
+			Definition{Id: def.Id}.UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING)
toRemove = append(toRemove, i)
continue
}
@@ -982,13 +980,13 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
if accuracy >= float64(def.TargetAccuracy) {
l.Info("Found a definition that reaches target_accuracy!")
- _, err = db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.Epoch, def.Id)
+ _, err = db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, DEFINITION_STATUS_TRANIED, def.Epoch, def.Id)
if err != nil {
l.Error("Failed to train definition!")
return err
}
- _, err = db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", MODEL_DEFINITION_STATUS_CANCELD_TRAINING, def.Id, model.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
+ _, err = db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", DEFINITION_STATUS_CANCELD_TRAINING, def.Id, model.Id, DEFINITION_STATUS_FAILED_TRAINING)
if err != nil {
l.Error("Failed to train definition!")
return err
@@ -1006,12 +1004,13 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
if def.Epoch > MAX_EPOCH {
fmt.Printf("Failed to train definition! Accuracy less %f < %d\n", accuracy, def.TargetAccuracy)
- ModelDefinitionUpdateStatus(c, def.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
+
+ Definition{Id: def.Id}.UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING)
toRemove = append(toRemove, i)
continue
}
- _, err = db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.Epoch, MODEL_DEFINITION_STATUS_PAUSED_TRAINING, def.Id)
+ _, err = db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.Epoch, DEFINITION_STATUS_PAUSED_TRAINING, def.Id)
if err != nil {
l.Error("Failed to train definition!")
return err
@@ -1056,17 +1055,18 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
sort.Sort(sort.Reverse(toRemove))
for _, n := range toRemove {
l.Warn("Removing definition not fast enough learning", "n", n)
- ModelDefinitionUpdateStatus(c, definitions[n].Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
+
+ Definition{Id: definitions[n].Id}.UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING)
definitions = remove(definitions, n)
}
}
var dat JustId
- err = GetDBOnce(db, &dat, "model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, MODEL_DEFINITION_STATUS_TRANIED)
+ err = GetDBOnce(db, &dat, "model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, DEFINITION_STATUS_TRANIED)
if err == NotFoundError {
// Set the class status to trained
- err = setModelClassStatus(c, MODEL_CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
+ err = setModelClassStatus(c, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
if err != nil {
l.Error("All definitions failed to train! And Failed to set class status")
return err
@@ -1079,12 +1079,12 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
return err
}
- if _, err = db.Exec("update model_definition set status=$1 where id=$2;", MODEL_DEFINITION_STATUS_READY, dat.Id); err != nil {
+ if _, err = db.Exec("update model_definition set status=$1 where id=$2;", DEFINITION_STATUS_READY, dat.Id); err != nil {
l.Error("Failed to update model definition")
return err
}
- to_delete, err := GetDbMultitple[JustId](db, "model_definition where status!=$1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id)
+ to_delete, err := GetDbMultitple[JustId](db, "model_definition where status!=$1 and model_id=$2", DEFINITION_STATUS_READY, model.Id)
if err != nil {
l.Error("Failed to select model_definition to delete")
return err
@@ -1095,13 +1095,13 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
}
// TODO Check if returning also works here
- if _, err = db.Exec("delete from model_definition where status!=$1 and model_id=$2;", MODEL_DEFINITION_STATUS_READY, model.Id); err != nil {
+ if _, err = db.Exec("delete from model_definition where status!=$1 and model_id=$2;", DEFINITION_STATUS_READY, model.Id); err != nil {
l.Error("Failed to delete model_definition")
return err
}
if err = splitModel(c, model); err != nil {
- err = setModelClassStatus(c, MODEL_CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
+ err = setModelClassStatus(c, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
if err != nil {
l.Error("Failed to split the model! And Failed to set class status")
return err
@@ -1112,7 +1112,7 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
}
// Set the class status to trained
- err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
+ err = setModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
if err != nil {
l.Error("Failed to set class status")
return err
@@ -1226,36 +1226,6 @@ func splitModel(c BasePack, model *BaseModel) (err error) {
return
}
-func removeFailedDataPoints(c BasePack, model *BaseModel) (err error) {
- rows, err := c.GetDb().Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id=mdp.class_id where mc.model_id=$1 and mdp.status=-1;", model.Id)
- if err != nil {
- return
- }
- defer rows.Close()
-
- base_path := path.Join("savedData", model.Id, "data")
-
- for rows.Next() {
- var dataPointId string
- err = rows.Scan(&dataPointId)
- if err != nil {
- return
- }
-
- p := path.Join(base_path, dataPointId+"."+model.Format)
-
- c.GetLogger().Warn("Removing image", "path", p)
-
- err = os.RemoveAll(p)
- if err != nil {
- return
- }
- }
-
- _, err = c.GetDb().Exec("delete from model_data_point as mdp using model_classes as mc where mdp.class_id = mc.id and mc.model_id=$1 and mdp.status=-1;", model.Id)
- return
-}
-
// This generates a definition
func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, number_of_classes int, complexity int) (err error) {
failed := func() {
@@ -1265,7 +1235,7 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe
db := c.GetDb()
l := c.GetLogger()
- def_id, err := MakeDefenition(db, model.Id, target_accuracy)
+ def, err := MakeDefenition(db, model.Id, target_accuracy)
if err != nil {
failed()
return
@@ -1274,28 +1244,16 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe
order := 1
// Note the shape of the first layer defines the import size
- if complexity == 2 {
- // Note the shape for now is no used
- width := int(math.Pow(2, math.Floor(math.Log(float64(model.Width))/math.Log(2.0))))
- height := int(math.Pow(2, math.Floor(math.Log(float64(model.Height))/math.Log(2.0))))
- l.Warn("Complexity 2 creating model with smaller size", "width", width, "height", height)
- err = MakeLayer(db, def_id, order, LAYER_INPUT, fmt.Sprintf("%d,%d,1", width, height))
- if err != nil {
- failed()
- return
- }
- order++
- } else {
- err = MakeLayer(db, def_id, order, LAYER_INPUT, fmt.Sprintf("%d,%d,1", model.Width, model.Height))
- if err != nil {
- failed()
- return
- }
- order++
+ //_, err = def.MakeLayer(db, order, LAYER_INPUT, ShapeToString(model.Width, model.Height, model.ImageMode))
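+	// NOTE: the leading 3 is assumed to be the input channel count (RGB); the commented-out call above would instead derive the shape from model.ImageMode.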
+ _, err = def.MakeLayer(db, order, LAYER_INPUT, ShapeToString(3, model.Width, model.Height))
+ if err != nil {
+ failed()
+ return
}
+ order++
if complexity == 0 {
- err = MakeLayer(db, def_id, order, LAYER_FLATTEN, "")
+ _, err = def.MakeLayer(db, order, LAYER_FLATTEN, "")
if err != nil {
failed()
return
@@ -1304,22 +1262,17 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe
loop := int(math.Log2(float64(number_of_classes)))
for i := 0; i < loop; i++ {
- err = MakeLayer(db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)))
+ _, err = def.MakeLayer(db, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i)))
order++
if err != nil {
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
return
}
}
-
} else if complexity == 1 || complexity == 2 {
-
- loop := int((math.Log(float64(model.Width)) / math.Log(float64(10))))
- if loop == 0 {
- loop = 1
- }
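+		// Scale the number of simple blocks with log10 of the image width, keeping at least one block.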
+		loop := max(1, int(math.Log(float64(model.Width))/math.Log(10)))
for i := 0; i < loop; i++ {
- err = MakeLayer(db, def_id, order, LAYER_SIMPLE_BLOCK, "")
+ _, err = def.MakeLayer(db, order, LAYER_SIMPLE_BLOCK, "")
order++
if err != nil {
failed()
@@ -1327,7 +1280,7 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe
}
}
- err = MakeLayer(db, def_id, order, LAYER_FLATTEN, "")
+ _, err = def.MakeLayer(db, order, LAYER_FLATTEN, "")
if err != nil {
failed()
return
@@ -1339,7 +1292,7 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe
loop = 1
}
for i := 0; i < loop; i++ {
- err = MakeLayer(db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)))
+ _, err = def.MakeLayer(db, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i)))
order++
if err != nil {
failed()
@@ -1347,57 +1300,47 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe
}
}
} else {
- log.Error("Unkown complexity", "complexity", complexity)
+		l.Error("Unknown complexity", "complexity", complexity)
failed()
return
}
- err = ModelDefinitionUpdateStatus(c, def_id, MODEL_DEFINITION_STATUS_INIT)
- if err != nil {
- failed()
- return
- }
-
- return nil
+ return def.UpdateStatus(db, DEFINITION_STATUS_INIT)
}
func generateDefinitions(c BasePack, model *BaseModel, target_accuracy int, number_of_models int) (err error) {
- cls, err := model_classes.ListClasses(c, model.Id)
+ cls, err := model.GetClasses(c, "")
if err != nil {
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
return
}
- err = removeFailedDataPoints(c, model)
- if err != nil {
- return
- }
-
cls_len := len(cls)
if number_of_models == 1 {
if model.Width < 100 && model.Height < 100 && cls_len < 30 {
- generateDefinition(c, model, target_accuracy, cls_len, 0)
+ err = generateDefinition(c, model, target_accuracy, cls_len, 0)
} else if model.Width > 100 && model.Height > 100 {
- generateDefinition(c, model, target_accuracy, cls_len, 2)
+ err = generateDefinition(c, model, target_accuracy, cls_len, 2)
} else {
- generateDefinition(c, model, target_accuracy, cls_len, 1)
+ err = generateDefinition(c, model, target_accuracy, cls_len, 1)
}
- } else if number_of_models == 3 {
- for i := 0; i < number_of_models; i++ {
- generateDefinition(c, model, target_accuracy, cls_len, i)
+ if err != nil {
+ return
}
} else {
- // TODO handle incrisea the complexity
for i := 0; i < number_of_models; i++ {
- generateDefinition(c, model, target_accuracy, cls_len, 0)
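+			// Increase the complexity with each extra model, capped at 2, the highest supported complexity.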
+ err = generateDefinition(c, model, target_accuracy, cls_len, min(i, 2))
+ if err != nil {
+ return
+ }
}
}
return nil
}
-func ExpModelHeadUpdateStatus(db db.Db, id string, status ModelDefinitionStatus) (err error) {
+func ExpModelHeadUpdateStatus(db db.Db, id string, status DefinitionStatus) (err error) {
_, err = db.Exec("update model_definition set status = $1 where id = $2", status, id)
return
}
@@ -1417,7 +1360,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
return
}
- def_id, err := MakeDefenition(c.GetDb(), model.Id, target_accuracy)
+ def, err := MakeDefenition(c.GetDb(), model.Id, target_accuracy)
if err != nil {
failed()
return
@@ -1436,7 +1379,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
l.Warn("Complexity 2 creating model with smaller size", "width", width, "height", height)
}
- err = MakeLayerExpandable(c.GetDb(), def_id, order, LAYER_INPUT, fmt.Sprintf("%d,%d,1", width, height), 1)
+ err = MakeLayerExpandable(c.GetDb(), def.Id, order, LAYER_INPUT, fmt.Sprintf("%d,%d,1", width, height), 1)
order++
@@ -1458,7 +1401,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
//loop = max(loop, 3)
for i := 0; i < loop; i++ {
- err = MakeLayerExpandable(db, def_id, order, LAYER_SIMPLE_BLOCK, "", 1)
+ err = MakeLayerExpandable(db, def.Id, order, LAYER_SIMPLE_BLOCK, "", 1)
order++
if err != nil {
failed()
@@ -1467,7 +1410,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
}
// Flatten the blocks into dense
- err = MakeLayerExpandable(db, def_id, order, LAYER_FLATTEN, "", 1)
+ err = MakeLayerExpandable(db, def.Id, order, LAYER_FLATTEN, "", 1)
if err != nil {
failed()
return
@@ -1475,7 +1418,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
order++
// Flatten the blocks into dense
- err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*2), 1)
+ err = MakeLayerExpandable(db, def.Id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*2), 1)
if err != nil {
failed()
return
@@ -1489,7 +1432,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
// loop = max(loop, 3)
for i := 0; i < loop; i++ {
- err = MakeLayer(db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)))
+ err = MakeLayer(db, def.Id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)))
order++
if err != nil {
failed()
@@ -1498,12 +1441,12 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
}
var newHead = struct {
- DefId string `db:"def_id"`
- RangeStart int `db:"range_start"`
- RangeEnd int `db:"range_end"`
- Status ModelDefinitionStatus `db:"status"`
+ DefId string `db:"def_id"`
+ RangeStart int `db:"range_start"`
+ RangeEnd int `db:"range_end"`
+ Status DefinitionStatus `db:"status"`
}{
- def_id, 0, number_of_classes - 1, MODEL_DEFINITION_STATUS_INIT,
+ def.Id, 0, number_of_classes - 1, DEFINITION_STATUS_INIT,
}
_, err = InsertReturnId(c.GetDb(), &newHead, "exp_model_head", "id")
if err != nil {
@@ -1511,7 +1454,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
return
}
- err = ModelDefinitionUpdateStatus(c, def_id, MODEL_DEFINITION_STATUS_INIT)
+ err = def.UpdateStatus(c, DEFINITION_STATUS_INIT)
if err != nil {
failed()
return
@@ -1522,15 +1465,9 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
// TODO make this json friendy
func generateExpandableDefinitions(c BasePack, model *BaseModel, target_accuracy int, number_of_models int) (err error) {
- cls, err := model_classes.ListClasses(c, model.Id)
+ cls, err := model.GetClasses(c, "")
if err != nil {
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
- // TODO improve this response
- return
- }
-
- err = removeFailedDataPoints(c, model)
- if err != nil {
return
}
@@ -1557,7 +1494,7 @@ func generateExpandableDefinitions(c BasePack, model *BaseModel, target_accuracy
}
func ResetClasses(c BasePack, model *BaseModel) {
- _, err := c.GetDb().Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TO_TRAIN, MODEL_CLASS_STATUS_TRAINING, model.Id)
+ _, err := c.GetDb().Exec("update model_classes set status=$1 where status=$2 and model_id=$3", CLASS_STATUS_TO_TRAIN, CLASS_STATUS_TRAINING, model.Id)
if err != nil {
c.GetLogger().Error("Error while reseting the classes", "error", err)
}
@@ -1574,7 +1511,7 @@ func trainExpandable(c *Context, model *BaseModel) {
var definitions TrainModelRowUsables
- definitions, err = GetDbMultitple[TrainModelRowUsable](c, "model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id)
+ definitions, err = GetDbMultitple[TrainModelRowUsable](c, "model_definition where status=$1 and model_id=$2", DEFINITION_STATUS_READY, model.Id)
if err != nil {
failed("Failed to get definitions")
return
@@ -1612,7 +1549,7 @@ func trainExpandable(c *Context, model *BaseModel) {
}
// Set the class status to trained
- err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
+ err = setModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
if err != nil {
failed("Failed to set class status")
return
@@ -1629,9 +1566,7 @@ func RunTaskTrain(b BasePack, task Task) (err error) {
task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed to get model information")
l.Error("Failed to get model information", "err", err)
return err
- }
-
- if model.Status != TRAINING {
+ } else if model.Status != TRAINING {
task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Model not in the correct status for training")
return errors.New("Model not in the right status")
}
@@ -1649,6 +1584,10 @@ func RunTaskTrain(b BasePack, task Task) (err error) {
}
if model.ModelType == 2 {
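+			// Expandable-model training has not been ported to the torch backend yet; fail the task and model until it is reimplemented.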
+ task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "TODO expandable models")
+ ModelUpdateStatus(b, model.Id, FAILED_TRAINING)
+ panic("todo")
+
full_error := generateExpandableDefinitions(b, model, dat.Accuracy, dat.NumberOfModels)
if full_error != nil {
l.Error("Failed to generate defintions", "err", full_error)
@@ -1658,6 +1597,7 @@ func RunTaskTrain(b BasePack, task Task) (err error) {
} else {
full_error := generateDefinitions(b, model, dat.Accuracy, dat.NumberOfModels)
if full_error != nil {
+			l.Error("Failed to generate definitions", "err", full_error)
task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed generate model")
return errors.New("Failed to generate definitions")
}
@@ -1682,6 +1622,9 @@ func RunTaskTrain(b BasePack, task Task) (err error) {
}
func RunTaskRetrain(b BasePack, task Task) (err error) {
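+	// Retraining has not been ported to the torch backend yet; fail the task until it is reimplemented.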
+ task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "TODO retrain with torch")
+ panic("TODO")
+
model, err := GetBaseModel(b.GetDb(), *task.ModelId)
if err != nil {
return err
@@ -1743,7 +1686,7 @@ func RunTaskRetrain(b BasePack, task Task) (err error) {
l.Info("Model updaded")
- _, err = db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINED, MODEL_CLASS_STATUS_TRAINING, model.Id)
+ _, err = db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", CLASS_STATUS_TRAINED, CLASS_STATUS_TRAINING, model.Id)
if err != nil {
l.Error("Error while updating the classes", "error", err)
failed()
@@ -1772,7 +1715,7 @@ func handleTrain(handle *Handle) {
}
model, err := GetBaseModel(c.Db, dat.Id)
- if err == ModelNotFoundError {
+ if err == NotFoundError {
return c.JsonBadRequest("Model not found")
} else if err != nil {
return c.E500M("Failed to get model information", err)
@@ -1826,7 +1769,7 @@ func handleTrain(handle *Handle) {
PostAuthJson(handle, "/model/train/retrain", User_Normal, func(c *Context, dat *JustId) *Error {
model, err := GetBaseModel(c.Db, dat.Id)
- if err == ModelNotFoundError {
+ if err == NotFoundError {
return c.JsonBadRequest("Model not found")
} else if err != nil {
return c.E500M("Faield to get model", err)
@@ -1873,7 +1816,7 @@ func handleTrain(handle *Handle) {
c,
"model_classes where model_id=$1 and status=$2 order by class_order asc",
model.Id,
- MODEL_CLASS_STATUS_TO_TRAIN,
+ CLASS_STATUS_TO_TRAIN,
)
if err != nil {
_err := c.RollbackTx()
@@ -1894,7 +1837,7 @@ func handleTrain(handle *Handle) {
//Update the classes
{
- _, err = c.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
+ _, err = c.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", CLASS_STATUS_TRAINING, CLASS_STATUS_TO_TRAIN, model.Id)
if err != nil {
_err := c.RollbackTx()
if _err != nil {
@@ -1916,12 +1859,12 @@ func handleTrain(handle *Handle) {
}
var newHead = struct {
- DefId string `db:"def_id"`
- RangeStart int `db:"range_start"`
- RangeEnd int `db:"range_end"`
- Status ModelDefinitionStatus `db:"status"`
+ DefId string `db:"def_id"`
+ RangeStart int `db:"range_start"`
+ RangeEnd int `db:"range_end"`
+ Status DefinitionStatus `db:"status"`
}{
- def.Id, classes[0].ClassOrder, classes[len(classes)-1].ClassOrder, MODEL_DEFINITION_STATUS_INIT,
+ def.Id, classes[0].ClassOrder, classes[len(classes)-1].ClassOrder, DEFINITION_STATUS_INIT,
}
_, err = InsertReturnId(c.GetDb(), &newHead, "exp_model_head", "id")
diff --git a/logic/stats/tasks.go b/logic/stats/tasks.go
index c6fcb72..93749b8 100644
--- a/logic/stats/tasks.go
+++ b/logic/stats/tasks.go
@@ -14,7 +14,7 @@ func handleTasksStats(handle *Handle) {
}
PostAuthJson(handle, "/stats/task/model/day", User_Normal, func(c *Context, dat *ModelTasksStatsRequest) *Error {
model, err := GetBaseModel(c, dat.ModelId)
- if err == ModelNotFoundError {
+ if err == NotFoundError {
return c.JsonBadRequest("Model not found!")
} else if err != nil {
return c.E500M("Failed to get model", err)
diff --git a/logic/tasks/agreement.go b/logic/tasks/agreement.go
index bae6f65..60cc596 100644
--- a/logic/tasks/agreement.go
+++ b/logic/tasks/agreement.go
@@ -14,7 +14,7 @@ func handleRequests(x *Handle) {
PostAuthJson(x, "/task/agreement", User_Normal, func(c *Context, dat *AgreementRequest) *Error {
var task Task
err := GetDBOnce(c, &task, "tasks where id=$1", dat.Id)
- if err == ModelNotFoundError {
+ if err == NotFoundError {
return c.JsonBadRequest("Model not found")
} else if err != nil {
return c.E500M("Failed to get task data", err)
diff --git a/logic/tasks/list.go b/logic/tasks/list.go
index 019621f..0ff0423 100644
--- a/logic/tasks/list.go
+++ b/logic/tasks/list.go
@@ -46,7 +46,7 @@ func handleList(handler *Handle) {
if requestData.ModelId != "" {
_, err := GetBaseModel(c.Db, requestData.ModelId)
- if err == ModelNotFoundError {
+ if err == NotFoundError {
return c.SendJSONStatus(404, "Model not found!")
} else if err != nil {
return c.Error500(err)
diff --git a/logic/tasks/runner/runner.go b/logic/tasks/runner/runner.go
index 4e6f967..a522efa 100644
--- a/logic/tasks/runner/runner.go
+++ b/logic/tasks/runner/runner.go
@@ -11,7 +11,8 @@ import (
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
- . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models"
+
+ // . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/users"
@@ -52,9 +53,10 @@ func runner(config Config, db db.Db, task_channel chan Task, index int, back_cha
if task.TaskType == int(TASK_TYPE_CLASSIFICATION) {
logger.Info("Classification Task")
- if err = ClassifyTask(base, task); err != nil {
+ /*if err = ClassifyTask(base, task); err != nil {
logger.Error("Classification task failed", "error", err)
- }
+ }*/
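+				// The old classification path is disabled until the task runner moves to the pytorch backend; mark the task as failed instead.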
+ task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "TODO move tasks to pytorch")
back_channel <- index
continue
diff --git a/logic/utils/handler.go b/logic/utils/handler.go
index ebf79c9..7d06043 100644
--- a/logic/utils/handler.go
+++ b/logic/utils/handler.go
@@ -392,7 +392,7 @@ func (c *Context) GetModelFromId(id_path string) (*dbtypes.BaseModel, *Error) {
}
model, err := dbtypes.GetBaseModel(c.Db, id)
- if err == dbtypes.ModelNotFoundError {
+ if err == dbtypes.NotFoundError {
return nil, c.SendJSONStatus(http.StatusNotFound, "Model not found")
} else if err != nil {
return nil, c.Error500(err)
diff --git a/run.sh b/run.sh
new file mode 100644
index 0000000..393055b
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,2 @@
+podman run --rm --network host --gpus all -ti -v "$(pwd)":/app -e "TERM=xterm-256color" fyp-server bash
+
diff --git a/webpage/src/routes/models/edit/+page.svelte b/webpage/src/routes/models/edit/+page.svelte
index d3af0c9..f3a2d2f 100644
--- a/webpage/src/routes/models/edit/+page.svelte
+++ b/webpage/src/routes/models/edit/+page.svelte
@@ -215,7 +215,7 @@
{:else if m.status == -3 || m.status == -4}