Compare commits


5 Commits
runner ... main

SHA1       Message                                            Date
652542d261 Improved classification performance                2024-05-15 05:32:49 +01:00
516d1d7634 added docker compose to run everything in one go   2024-05-12 15:29:36 +01:00
0c0d16c846 multiple fixes                                     2024-05-11 01:11:07 +01:00
0ac6ac8dce runner-go (#102)                                   2024-05-10 02:13:02 +01:00
           Reviewed-on: #102
           Co-authored-by: Andre Henriques <andr3h3nriqu3s@gmail.com>
           Co-committed-by: Andre Henriques <andr3h3nriqu3s@gmail.com>
edd1e4c123 fix tensorflow version                             2024-05-06 18:17:15 +01:00
92 changed files with 6464 additions and 4320 deletions

View File

@@ -1,6 +1,6 @@
 # vi: ft=dockerfile
 FROM docker.io/nginx
-ADD nginx.dev.conf /nginx.conf
+ADD nginx.proxy.conf /nginx.conf
 CMD ["nginx", "-c", "/nginx.conf", "-g", "daemon off;"]

View File

@@ -14,7 +14,7 @@ ENV PATH=$PATH:/usr/local/go/bin
 ENV GOPATH=/go
 RUN bash -c 'curl -L "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.9.1.tar.gz" | tar -C /usr/local -xz'
-# RUN bash -c 'curl -L "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.13.1.tar.gz" | tar -C /usr/local -xz'
+RUN bash -c 'curl -L "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.15.0.tar.gz" | tar -C /usr/local -xz'
 RUN ldconfig
 RUN ln -s /usr/bin/python3 /usr/bin/python
@@ -31,7 +31,10 @@ ADD go.mod .
 ADD go.sum .
 ADD main.go .
 ADD logic logic
+ADD entrypoint.sh .
 RUN go install || true
-CMD ["go", "run", "."]
+RUN go build .
+CMD ["./entrypoint.sh"]

README.md (new file, 42 lines)
View File

@@ -0,0 +1,42 @@
# Configure the system
Go to the config.toml file and set up your hostname.
# Build the containers
Running these commands in the root of the project will build the necessary container images.
Make sure that your docker/podman installation supports domain name resolution between containers.
```bash
docker build -t andre-fyp-proxy -f DockerfileProxy .
docker build -t andre-fyp-server -f DockerfileServer .
cd webpage
docker build -t andre-fyp-web-server .
cd ..
```
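To confirm the three images were built, you can list them; this is just a quick sanity check with standard docker commands, not a step from the original instructions:
```bash
docker image ls | grep andre-fyp
```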
# Run the docker compose
Running docker compose starts the database server, the web page server, the proxy server, and the main server.
```bash
docker compose up
```
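If you want to keep this terminal free, compose can also run detached; `-d` is a standard docker compose flag rather than anything project-specific:
```bash
docker compose up -d
```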
# Setup the Database
In another terminal, create the database and tables.
Note: the password can be changed in the docker-compose file.
```bash
PGPASSWORD=verysafepassword psql -h localhost -U postgres -f sql/base.sql
PGPASSWORD=verysafepassword psql -h localhost -U postgres -d fyp -f sql/user.sql
PGPASSWORD=verysafepassword psql -h localhost -U postgres -d fyp -f sql/models.sql
PGPASSWORD=verysafepassword psql -h localhost -U postgres -d fyp -f sql/tasks.sql
```
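To verify that the schema was created, list the tables in the fyp database (this assumes the same password and database name used above):
```bash
PGPASSWORD=verysafepassword psql -h localhost -U postgres -d fyp -c '\dt'
```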
# Restart docker compose
Now restart docker compose and the system should be available under the domain name set in the config.toml file.
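A minimal restart sequence, assuming compose was started in this same directory, would be:
```bash
docker compose down
docker compose up -d
```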

View File

@@ -12,7 +12,12 @@ USER = "service"
 [Worker]
 PULLING_TIME = "500ms"
-NUMBER_OF_WORKERS = 1
+NUMBER_OF_WORKERS = 16
 [DB]
 MAX_CONNECTIONS = 600
+host = "db"
+port = 5432
+user = "postgres"
+password = "verysafepassword"
+dbname = "fyp"

View File

@@ -1,11 +1,44 @@
-version: "3.1"
 services:
   db:
-    image: docker.andr3h3nriqu3s.com/services/postgres
+    image: docker.io/postgres:16.3
     command: -c 'max_connections=600'
     restart: always
+    networks:
+      - fyp-network
     environment:
       POSTGRES_PASSWORD: verysafepassword
     ports:
       - "5432:5432"
+  web-page:
+    image: andre-fyp-web-server
+    hostname: webpage
+    networks:
+      - fyp-network
+  server:
+    image: andre-fyp-server
+    hostname: server
+    networks:
+      - fyp-network
+    depends_on:
+      - db
+    volumes:
+      - "./config.toml:/app/config.toml"
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: 1
+              capabilities: [gpu]
+  proxy-server:
+    image: andre-fyp-proxy
+    networks:
+      - fyp-network
+    ports:
+      - "8000:8000"
+    depends_on:
+      - web-page
+      - server
+networks:
+  fyp-network: {}

entrypoint.sh (new executable file, 4 lines)
View File

@@ -0,0 +1,4 @@
#!/bin/bash
while true; do
	./fyp
done

go.mod (4 lines changed)
View File

@@ -9,10 +9,11 @@ require (
 	github.com/google/uuid v1.6.0
 	github.com/lib/pq v1.10.9
 	golang.org/x/crypto v0.19.0
+	github.com/BurntSushi/toml v1.3.2
+	github.com/goccy/go-json v0.10.2
 )
 require (
-	github.com/BurntSushi/toml v1.3.2 // indirect
 	github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
 	github.com/charmbracelet/lipgloss v0.9.1 // indirect
 	github.com/gabriel-vasile/mimetype v1.4.3 // indirect
@@ -20,7 +21,6 @@ require (
 	github.com/go-playground/locales v0.14.1 // indirect
 	github.com/go-playground/universal-translator v0.18.1 // indirect
 	github.com/go-playground/validator/v10 v10.19.0 // indirect
-	github.com/goccy/go-json v0.10.2 // indirect
 	github.com/jackc/pgpassfile v1.0.0 // indirect
 	github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
 	github.com/jackc/pgx v3.6.2+incompatible // indirect

View File

@@ -87,9 +87,9 @@ func (d Definition) GetLayers(db db.Db, filter string, args ...any) (layer []*La
 	return GetDbMultitple[Layer](db, "model_definition_layer as mdl where mdl.def_id=$1 "+filter, args...)
 }
-func (d *Definition) UpdateAfterEpoch(db db.Db, accuracy float64) (err error) {
+func (d *Definition) UpdateAfterEpoch(db db.Db, accuracy float64, epoch int) (err error) {
 	d.Accuracy = accuracy
-	d.Epoch += 1
+	d.Epoch += epoch
 	_, err = db.Exec("update model_definition set epoch=$1, accuracy=$2 where id=$3", d.Epoch, d.Accuracy, d.Id)
 	return
 }

View File

@@ -2,8 +2,10 @@ package dbtypes
 import (
 	"encoding/json"
+	"fmt"
 	"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
+	"github.com/charmbracelet/log"
 )
 type LayerType int
@@ -18,15 +20,28 @@ const (
 type Layer struct {
 	Id string `db:"mdl.id" json:"id"`
 	DefinitionId string `db:"mdl.def_id" json:"definition_id"`
-	LayerOrder string `db:"mdl.layer_order" json:"layer_order"`
+	LayerOrder int `db:"mdl.layer_order" json:"layer_order"`
 	LayerType LayerType `db:"mdl.layer_type" json:"layer_type"`
 	Shape string `db:"mdl.shape" json:"shape"`
-	ExpType string `db:"mdl.exp_type" json:"exp_type"`
+	ExpType int `db:"mdl.exp_type" json:"exp_type"`
+}
+func (x *Layer) ShapeToSize() {
+	v := x.GetShape()
+	switch x.LayerType {
+	case LAYER_INPUT:
+		x.Shape = fmt.Sprintf("%d,%d", v[1], v[2])
+	case LAYER_DENSE:
+		x.Shape = fmt.Sprintf("(%d)", v[0])
+	default:
+		x.Shape = "ERROR"
+	}
 }
 func ShapeToString(args ...int) string {
 	text, err := json.Marshal(args)
 	if err != nil {
+		log.Error("json err!", "err", err)
 		panic("Could not generate Shape")
 	}
 	return string(text)
@@ -35,12 +50,16 @@ func ShapeToString(args ...int) string {
 func StringToShape(str string) (shape []int64) {
 	err := json.Unmarshal([]byte(str), &shape)
 	if err != nil {
+		log.Error("json err!", "err", err)
 		panic("Could not parse Shape")
 	}
 	return
 }
 func (l Layer) GetShape() []int64 {
+	if l.Shape == "" {
+		return []int64{}
+	}
 	return StringToShape(l.Shape)
 }

View File

@@ -3,7 +3,6 @@ package dbtypes
 import (
 	"errors"
 	"fmt"
-	"path"
 	"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
 )
@@ -51,17 +50,16 @@ const (
 )
 type BaseModel struct {
-	Name string
-	Status int
-	Id string
-
-	ModelType int `db:"model_type"`
-	ImageModeRaw string `db:"color_mode"`
-	ImageMode int `db:"0"`
-	Width int
-	Height int
-	Format string
-	CanTrain int `db:"can_train"`
+	Name string `json:"name"`
+	Status int `json:"status"`
+	Id string `json:"id"`
+	ModelType int `db:"model_type" json:"model_type"`
+	ImageModeRaw string `db:"color_mode" json:"image_more_raw"`
+	ImageMode int `db:"0" json:"image_mode"`
+	Width int `json:"width"`
+	Height int `json:"height"`
+	Format string `json:"format"`
+	CanTrain int `db:"can_train" json:"can_train"`
 }
 var ModelNotFoundError = errors.New("Model not found error")
@@ -102,6 +100,7 @@ func (m *BaseModel) UpdateStatus(db db.Db, status ModelStatus) (err error) {
 }
 type DataPoint struct {
+	Id string `json:"id"`
 	Class int `json:"class"`
 	Path string `json:"path"`
 }
@@ -126,14 +125,11 @@ func (m BaseModel) DataPoints(db db.Db, mode DATA_POINT_MODE) (data []DataPoint,
 		if err = rows.Scan(&id, &class_order, &file_path); err != nil {
 			return
 		}
-		if file_path == "id://" {
-			data = append(data, DataPoint{
-				Path: path.Join("./savedData", m.Id, "data", id+"."+m.Format),
-				Class: class_order,
-			})
-		} else {
-			panic("TODO remote file path")
-		}
+		data = append(data, DataPoint{
+			Id: id,
+			Path: file_path,
+			Class: class_order,
+		})
 	}
 	return
 }

View File

@@ -14,10 +14,19 @@ const (
 )
 type User struct {
-	Id string `db:"u.id"`
-	Username string `db:"u.username"`
-	Email string `db:"u.email"`
-	UserType int `db:"u.user_type"`
+	Id string `db:"u.id" json:"id"`
+	Username string `db:"u.username" json:"username"`
+	Email string `db:"u.email" json:"email"`
+	UserType int `db:"u.user_type" json:"user_type"`
+}
+func UserFromId(db db.Db, id string) (*User, error) {
+	var user User
+	err := GetDBOnce(db, &user, "users as u where u.id=$1", id)
+	if err != nil {
+		return nil, err
+	}
+	return &user, nil
 }
 func UserFromToken(db db.Db, token string) (*User, error) {

View File

@@ -16,7 +16,6 @@ import (
 )
 func loadBaseImage(c *Context, id string) {
-	// TODO handle more types than png
 	infile, err := os.Open(path.Join("savedData", id, "baseimage.png"))
 	if err != nil {
 		c.Logger.Errorf("Failed to read image for model with id %s\n", id)
@@ -54,21 +53,29 @@ func loadBaseImage(c *Context, id string) {
 		model_color = "greyscale"
 	case color.NRGBAModel:
 		fallthrough
+	case color.RGBAModel:
+		fallthrough
 	case color.YCbCrModel:
 		model_color = "rgb"
 	default:
 		c.Logger.Error("Do not know how to handle this color model")
 		if src.ColorModel() == color.RGBA64Model {
-			c.Logger.Error("Color is rgb")
+			c.Logger.Error("Color is rgb 64")
 		} else if src.ColorModel() == color.NRGBA64Model {
 			c.Logger.Error("Color is nrgb 64")
 		} else if src.ColorModel() == color.AlphaModel {
 			c.Logger.Error("Color is alpha")
 		} else if src.ColorModel() == color.CMYKModel {
 			c.Logger.Error("Color is cmyk")
+		} else if src.ColorModel() == color.NRGBA64Model {
+			c.Logger.Error("Color is cmyk")
+		} else if src.ColorModel() == color.NYCbCrAModel {
+			c.Logger.Error("Color is cmyk a")
+		} else if src.ColorModel() == color.Alpha16Model {
+			c.Logger.Error("Color is cmyk a")
 		} else {
-			c.Logger.Error("Other so assuming color")
+			c.Logger.Error("Other so assuming color", "color mode", src.ColorModel())
 		}
 		ModelUpdateStatus(c, id, FAILED_PREPARING)

View File

@@ -136,17 +136,16 @@ func processZipFile(c *Context, model *BaseModel) {
 			return
 		}
-		if paths[0] != "training" {
+		if paths[0] == "training" {
 			training = InsertIfNotPresent(training, paths[1])
-		} else if paths[0] != "testing" {
+		} else if paths[0] == "testing" {
 			testing = InsertIfNotPresent(testing, paths[1])
 		}
 	}
 	if !reflect.DeepEqual(testing, training) {
-		c.Logger.Info("Diff", "testing", testing, "training", training)
-		failed("Testing and Training datesets are diferent")
-		return
+		c.Logger.Warn("Diff", "testing", testing, "training", training)
+		c.Logger.Warn("Testing and traing datasets differ")
 	}
 	base_path := path.Join("savedData", model.Id, "data")
@@ -266,16 +265,15 @@ func processZipFileExpand(c *Context, model *BaseModel) {
 			return
 		}
-		if paths[0] != "training" {
+		if paths[0] == "training" {
 			training = InsertIfNotPresent(training, paths[1])
-		} else if paths[0] != "testing" {
+		} else if paths[0] == "testing" {
 			testing = InsertIfNotPresent(testing, paths[1])
 		}
 	}
 	if !reflect.DeepEqual(testing, training) {
-		failed("testing and training are diferent")
-		return
+		c.GetLogger().Warn("testing and training differ", "testing", testing, "training", training)
 	}
 	base_path := path.Join("savedData", model.Id, "data")
@@ -636,7 +634,8 @@ func handleDataUpload(handle *Handle) {
 	// TODO work in allowing the model to add new in the pre ready moment
 	if model.Status != READY {
-		return c.JsonBadRequest("Model not in the correct state to add a more classes")
+		c.GetLogger().Error("Model not in the ready status", "status", model.Status)
+		return c.JsonBadRequest("Model not in the correct state to add more classes")
 	}
 	// TODO mk this path configurable

View File

@@ -4,6 +4,7 @@ import (
 	"errors"
 	"os"
 	"path"
+	"runtime/debug"
 	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
 	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
@@ -37,11 +38,19 @@ func ReadJPG(scope *op.Scope, imagePath string, channels int64) *image.Image {
 	return image.Scale(0, 255)
 }
-func runModelNormal(base BasePack, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, confidence float32, err error) {
+func runModelNormal(model *BaseModel, def_id string, inputImage *tf.Tensor, data *RunnerModelData) (order int, confidence float32, err error) {
 	order = 0
 	err = nil
-	tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
+	var tf_model *tg.Model = nil
+	if data.Id != nil && *data.Id == def_id {
+		tf_model = data.Model
+	} else {
+		tf_model = tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
+		data.Model = tf_model
+		data.Id = &def_id
+	}
 	results := tf_model.Exec([]tf.Output{
 		tf_model.Op("StatefulPartitionedCall", 0),
@@ -125,10 +134,15 @@ func runModelExp(base BasePack, model *BaseModel, def_id string, inputImage *tf.
 	return
 }
-func ClassifyTask(base BasePack, task Task) (err error) {
+type RunnerModelData struct {
+	Id *string
+	Model *tg.Model
+}
+func ClassifyTask(base BasePack, task Task, data *RunnerModelData) (err error) {
 	defer func() {
 		if r := recover(); r != nil {
-			base.GetLogger().Error("Task failed due to", "error", r)
+			base.GetLogger().Error("Task failed due to", "error", r, "stack", string(debug.Stack()))
 			task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Task failed running")
 		}
 	}()
@@ -186,6 +200,8 @@ func ClassifyTask(base BasePack, task Task) (err error) {
 	if model.ModelType == 2 {
 		base.GetLogger().Info("Running model normal", "model", model.Id, "def", def_id)
+		data.Model = nil
+		data.Id = nil
 		vi, confidence, err = runModelExp(base, model, def_id, inputImage)
 		if err != nil {
 			task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to run model")
@@ -193,7 +209,7 @@ func ClassifyTask(base BasePack, task Task) (err error) {
 		}
 	} else {
 		base.GetLogger().Info("Running model normal", "model", model.Id, "def", def_id)
-		vi, confidence, err = runModelNormal(base, model, def_id, inputImage)
+		vi, confidence, err = runModelNormal(model, def_id, inputImage, data)
 		if err != nil {
 			task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to run model")
 			return

View File

@@ -38,6 +38,8 @@ func TestImgForModel(c *Context, model *BaseModel, path string) (result bool) {
 		model_color = "greyscale"
 	case color.NRGBAModel:
 		fallthrough
+	case color.RGBAModel:
+		fallthrough
 	case color.YCbCrModel:
 		model_color = "rgb"
 	default:

View File

@@ -6,6 +6,7 @@ import (
 	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
 	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
 	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
+	"github.com/charmbracelet/log"
 	"github.com/goccy/go-json"
 )
@@ -39,7 +40,6 @@ func PrepareTraining(handler *Handle, b BasePack, task Task, runner_id string) (
 	}
 	if model.ModelType == 2 {
-		panic("TODO")
 		full_error := generateExpandableDefinitions(b, model, dat.Accuracy, dat.NumberOfModels)
 		if full_error != nil {
 			l.Error("Failed to generate defintions", "err", full_error)
@@ -57,7 +57,6 @@ func PrepareTraining(handler *Handle, b BasePack, task Task, runner_id string) (
 	runners := handler.DataMap["runners"].(map[string]interface{})
 	runner := runners[runner_id].(map[string]interface{})
 	runner["task"] = &task
-
 	runners[runner_id] = runner
 	handler.DataMap["runners"] = runners
@@ -76,4 +75,44 @@ func CleanUpFailed(b BasePack, task *Task) {
 			l.Error("Failed to get status", err)
 		}
 	}
+	// Set the class status to trained
+	err = SetModelClassStatus(b, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
+	if err != nil {
+		l.Error("Failed to set class status")
+		return
+	}
+}
+func CleanUpFailedRetrain(b BasePack, task *Task) {
+	db := b.GetDb()
+	l := b.GetLogger()
+	model, err := GetBaseModel(db, *task.ModelId)
+	if err != nil {
+		l.Error("Failed to get model", "err", err)
+	} else {
+		err = model.UpdateStatus(db, FAILED_TRAINING)
+		if err != nil {
+			l.Error("Failed to get status", err)
+		}
+	}
+	ResetClasses(b, model)
+	ModelUpdateStatus(b, model.Id, READY_RETRAIN_FAILED)
+	var defData struct {
+		Id string `db:"md.id"`
+		TargetAcuuracy float64 `db:"md.target_accuracy"`
+	}
+	err = GetDBOnce(db, &defData, "models as m inner join model_definition as md on m.id = md.model_id where m.id=$1;", task.ModelId)
+	if err != nil {
+		log.Error("failed to get def data", err)
+		return
+	}
+	_, err_ := db.Exec("delete from exp_model_head where def_id=$1 and status in (2,3)", defData.Id)
+	if err_ != nil {
+		panic(err_)
+	}
 }

View File

@@ -101,6 +101,10 @@ func setModelClassStatus(c BasePack, status ModelClassStatus, filter string, arg
 	return
 }
+func SetModelClassStatus(c BasePack, status ModelClassStatus, filter string, args ...any) (err error) {
+	return setModelClassStatus(c, status, filter, args...)
+}
 func generateCvsExp(c BasePack, run_path string, model_id string, doPanic bool) (count int, err error) {
 	db := c.GetDb()
@@ -157,40 +161,23 @@ func generateCvsExp(c BasePack, run_path string, model_id string, doPanic bool)
 	return
 }
-func trainDefinition(c BasePack, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
+func trainDefinition(c BasePack, model *BaseModel, def Definition, load_prev bool) (accuracy float64, err error) {
 	l := c.GetLogger()
-	db := c.GetDb()
 	l.Warn("About to start training definition")
 	accuracy = 0
-	layers, err := db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
+	layers, err := def.GetLayers(c.GetDb(), " order by layer_order asc;")
 	if err != nil {
 		return
 	}
-	defer layers.Close()
-	type layerrow struct {
-		LayerType int
-		Shape string
-		LayerNum int
-	}
-	got := []layerrow{}
-	i := 1
-	for layers.Next() {
-		var row = layerrow{}
-		if err = layers.Scan(&row.LayerType, &row.Shape); err != nil {
-			return
-		}
-		row.Shape = shapeToSize(row.Shape)
-		row.LayerNum = 1
-		got = append(got, row)
-		i = i + 1
+	for _, layer := range layers {
+		layer.ShapeToSize()
 	}
 	// Generate run folder
-	run_path := path.Join("/tmp", model.Id, "defs", definition_id)
+	run_path := path.Join("/tmp", model.Id, "defs", def.Id)
 	err = os.MkdirAll(run_path, os.ModePerm)
 	if err != nil {
@@ -215,17 +202,17 @@ func trainDefinition(c BasePack, model *BaseModel, definition_id string, load_pr
 	}
 	// Copy result around
-	result_path := path.Join("savedData", model.Id, "defs", definition_id)
+	result_path := path.Join("savedData", model.Id, "defs", def.Id)
 	if err = tmpl.Execute(f, AnyMap{
-		"Layers": got,
-		"Size": got[0].Shape,
+		"Layers": layers,
+		"Size": layers[0].Shape,
 		"DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
 		"RunPath": run_path,
 		"ColorMode": model.ImageMode,
 		"Model": model,
 		"EPOCH_PER_RUN": EPOCH_PER_RUN,
-		"DefId": definition_id,
+		"DefId": def.Id,
 		"LoadPrev": load_prev,
 		"LastModelRunPath": path.Join(getDir(), result_path, "model.keras"),
 		"SaveModelPath": path.Join(getDir(), result_path),
@@ -352,7 +339,7 @@ func generateCvsExpandExp(c BasePack, run_path string, model_id string, offset i
 	return
 }
-func trainDefinitionExpandExp(c BasePack, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
+func trainDefinitionExpandExp(c BasePack, model *BaseModel, def Definition, load_prev bool) (accuracy float64, err error) {
 	accuracy = 0
 	l := c.GetLogger()
@@ -367,7 +354,7 @@ func trainDefinitionExpandExp(c BasePack, model *BaseModel, definition_id string
 	}
 	// status = 2 (INIT) 3 (TRAINING)
-	heads, err := GetDbMultitple[ExpHead](c.GetDb(), "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
+	heads, err := GetDbMultitple[ExpHead](c.GetDb(), "exp_model_head where def_id=$1 and (status = 2 or status = 3)", def.Id)
 	if err != nil {
 		return
 	} else if len(heads) == 0 {
@@ -386,62 +373,49 @@ func trainDefinitionExpandExp(c BasePack, model *BaseModel, definition_id string
 		return
 	}
-	layers, err := c.GetDb().Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
+	layers, err := def.GetLayers(c.GetDb(), " order by layer_order asc;")
 	if err != nil {
 		return
 	}
-	defer layers.Close()
-	type layerrow struct {
-		LayerType int
-		Shape string
-		ExpType int
-		LayerNum int
-	}
-	got := []layerrow{}
+	var got []*Layer
 	i := 1
-	var last *layerrow = nil
+	var last *Layer = nil
 	got_2 := false
-	var first *layerrow = nil
+	var first *Layer = nil
-	for layers.Next() {
-		var row = layerrow{}
-		if err = layers.Scan(&row.LayerType, &row.Shape, &row.ExpType); err != nil {
-			return
-		}
+	for _, layer := range layers {
+		layer.ShapeToSize()
 		// Keep track of the first layer so we can keep the size of the image
 		if first == nil {
-			first = &row
+			first = layer
 		}
-		row.LayerNum = i
-		row.Shape = shapeToSize(row.Shape)
-		if row.ExpType == 2 {
+		if layer.ExpType == 2 {
 			if !got_2 {
-				got = append(got, *last)
+				got = append(got, last)
 				got_2 = true
 			}
-			got = append(got, row)
+			got = append(got, layer)
 		}
-		last = &row
+		last = layer
 		i += 1
 	}
-	got = append(got, layerrow{
+	got = append(got, &Layer{
 		LayerType: LAYER_DENSE,
 		Shape: fmt.Sprintf("%d", exp.End-exp.Start+1),
 		ExpType: 2,
-		LayerNum: i,
+		LayerOrder: len(got),
 	})
 	l.Info("Got layers", "layers", got)
 	// Generate run folder
-	run_path := path.Join("/tmp", model.Id+"-defs-"+definition_id+"-retrain")
+	run_path := path.Join("/tmp", model.Id+"-defs-"+def.Id+"-retrain")
 	err = os.MkdirAll(run_path, os.ModePerm)
 	if err != nil {
@@ -472,7 +446,7 @@ func trainDefinitionExpandExp(c BasePack, model *BaseModel, definition_id string
 	}
 	// Copy result around
-	result_path := path.Join("savedData", model.Id, "defs", definition_id)
+	result_path := path.Join("savedData", model.Id, "defs", def.Id)
 	if err = tmpl.Execute(f, AnyMap{
 		"Layers": got,
@@ -528,7 +502,7 @@ func trainDefinitionExpandExp(c BasePack, model *BaseModel, definition_id string
 	return
 }
-func trainDefinitionExp(c BasePack, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
+func trainDefinitionExp(c BasePack, model *BaseModel, def Definition, load_prev bool) (accuracy float64, err error) {
 	accuracy = 0
 	l := c.GetLogger()
 	db := c.GetDb()
@@ -544,7 +518,7 @@ func trainDefinitionExp(c BasePack, model *BaseModel, definition_id string, load
 	}
 	// status = 2 (INIT) 3 (TRAINING)
-	heads, err := GetDbMultitple[ExpHead](db, "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
+	heads, err := GetDbMultitple[ExpHead](db, "exp_model_head where def_id=$1 and (status = 2 or status = 3)", def.Id)
 	if err != nil {
 		return
 	} else if len(heads) == 0 {
@@ -562,42 +536,24 @@ func trainDefinitionExp(c BasePack, model *BaseModel, definition_id string, load
 		return
 	}
-	layers, err := db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
+	layers, err := def.GetLayers(db, " order by layer_order asc;")
 	if err != nil {
 		return
 	}
-	defer layers.Close()
-	type layerrow struct {
-		LayerType int
-		Shape string
-		ExpType int
-		LayerNum int
-	}
-	got := []layerrow{}
-	i := 1
-	for layers.Next() {
-		var row = layerrow{}
-		if err = layers.Scan(&row.LayerType, &row.Shape, &row.ExpType); err != nil {
-			return
-		}
-		row.LayerNum = i
-		row.Shape = shapeToSize(row.Shape)
-		got = append(got, row)
-		i += 1
+	for _, layer := range layers {
+		layer.ShapeToSize()
 	}
-	got = append(got, layerrow{
+	layers = append(layers, &Layer{
 		LayerType: LAYER_DENSE,
 		Shape: fmt.Sprintf("%d", exp.End-exp.Start+1),
 		ExpType: 2,
-		LayerNum: i,
+		LayerOrder: len(layers),
 	})
 	// Generate run folder
-	run_path := path.Join("/tmp", model.Id+"-defs-"+definition_id)
+	run_path := path.Join("/tmp", model.Id+"-defs-"+def.Id)
 	err = os.MkdirAll(run_path, os.ModePerm)
 	if err != nil {
@@ -624,11 +580,11 @@ func trainDefinitionExp(c BasePack, model *BaseModel, definition_id string, load
 	}
 	// Copy result around
-	result_path := path.Join("savedData", model.Id, "defs", definition_id)
+	result_path := path.Join("savedData", model.Id, "defs", def.Id)
 	if err = tmpl.Execute(f, AnyMap{
-		"Layers": got,
-		"Size": got[0].Shape,
+		"Layers": layers,
+		"Size": layers[0].Shape,
 		"DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
 		"HeadId": exp.Id,
 		"RunPath": run_path,
@@ -696,21 +652,6 @@ func remove[T interface{}](lst []T, i int) []T {
 	return append(lst[:i], lst[i+1:]...)
 }
-type TrainModelRow struct {
-	id string
-	target_accuracy int
-	epoch int
-	acuracy float64
-}
-type TraingModelRowDefinitions []TrainModelRow
-func (nf TraingModelRowDefinitions) Len() int { return len(nf) }
-func (nf TraingModelRowDefinitions) Swap(i, j int) { nf[i], nf[j] = nf[j], nf[i] }
-func (nf TraingModelRowDefinitions) Less(i, j int) bool {
-	return nf[i].acuracy < nf[j].acuracy
-}
 type ToRemoveList []int
 func (nf ToRemoveList) Len() int { return len(nf) }
@@ -723,30 +664,16 @@ func trainModel(c BasePack, model *BaseModel) (err error) {
 	db := c.GetDb()
 	l := c.GetLogger()
-	definitionsRows, err := db.Query("select id, target_accuracy, epoch from model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id)
+	defs_, err := model.GetDefinitions(db, "and md.status=$2", MODEL_DEFINITION_STATUS_INIT)
 	if err != nil {
-		l.Error("Failed to train Model! Err:")
-		l.Error(err)
+		l.Error("Failed to train Model!", "err", err)
 		ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING))
 		return
 	}
-	defer definitionsRows.Close()
-	var definitions TraingModelRowDefinitions = []TrainModelRow{}
-	for definitionsRows.Next() {
-		var rowv TrainModelRow
-		rowv.acuracy = 0
-		if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy, &rowv.epoch); err != nil {
-			l.Error("Failed to train Model Could not read definition from db!Err:")
-			l.Error(err)
-			ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING))
-			return
-		}
-		definitions = append(definitions, rowv)
-	}
-	if len(definitions) == 0 {
+	var defs SortByAccuracyDefinitions = defs_
+	if len(defs) == 0 {
 		l.Error("No Definitions defined!")
 		ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING))
 		return
@@ -757,32 +684,29 @@ func trainModel(c BasePack, model *BaseModel) (err error) {
 	for {
 		var toRemove ToRemoveList = []int{}
-		for i, def := range definitions {
-			ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING)
-			accuracy, err := trainDefinition(c, model, def.id, !firstRound)
+		for i, def := range defs {
+			ModelDefinitionUpdateStatus(c, def.Id, MODEL_DEFINITION_STATUS_TRAINING)
+			accuracy, err := trainDefinition(c, model, *def, !firstRound)
 			if err != nil {
 				l.Error("Failed to train definition!Err:", "err", err)
-				ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
+				ModelDefinitionUpdateStatus(c, def.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
 				toRemove = append(toRemove, i)
 				continue
 			}
-			def.epoch += EPOCH_PER_RUN
+			def.Epoch += EPOCH_PER_RUN
 			accuracy = accuracy * 100
-			def.acuracy = float64(accuracy)
-			definitions[i].epoch += EPOCH_PER_RUN
-			definitions[i].acuracy = accuracy
-			if accuracy >= float64(def.target_accuracy) {
+			def.Accuracy = float64(accuracy)
+			if accuracy >= float64(def.TargetAccuracy) {
 				l.Info("Found a definition that reaches target_accuracy!")
-				_, err = db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.epoch, def.id)
+				_, err = db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.Epoch, def.Id)
 				if err != nil {
 					l.Error("Failed to train definition!Err:\n", "err", err)
 					ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING))
 					return err
 				}
-				_, err = db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", MODEL_DEFINITION_STATUS_CANCELD_TRAINING, def.id, model.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
+				_, err = db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", MODEL_DEFINITION_STATUS_CANCELD_TRAINING, def.Id, model.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
 				if err != nil {
 					l.Error("Failed to train definition!Err:\n", "err", err)
 					ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING))
@@ -793,14 +717,14 @@ func trainModel(c BasePack, model *BaseModel) (err error) {
 				break
 			}
-			if def.epoch > MAX_EPOCH {
-				fmt.Printf("Failed to train definition! Accuracy less %f < %d\n", accuracy, def.target_accuracy)
-				ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
+			if def.Epoch > MAX_EPOCH {
+				fmt.Printf("Failed to train definition! Accuracy less %f < %d\n", accuracy, def.TargetAccuracy)
+				ModelDefinitionUpdateStatus(c, def.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
 				toRemove = append(toRemove, i)
 				continue
 			}
-			_, err = db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.epoch, MODEL_DEFINITION_STATUS_PAUSED_TRAINING, def.id)
+			_, err = db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.Epoch, MODEL_DEFINITION_STATUS_PAUSED_TRAINING, def.Id)
 			if err != nil {
 				l.Error("Failed to train definition!Err:\n", "err", err)
 				ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING))
@@ -818,28 +742,26 @@ func trainModel(c BasePack, model *BaseModel) (err error) {
 		l.Info("Round done", "toRemove", toRemove)
 		for _, n := range toRemove {
-			definitions = remove(definitions, n)
+			defs = remove(defs, n)
 		}
-		len_def := len(definitions)
+		len_def := len(defs)
 		if len_def == 0 {
 			break
-		}
-		if len_def == 1 {
+		} else if len_def == 1 {
 			continue
 		}
-		sort.Sort(sort.Reverse(definitions))
-		acc := definitions[0].acuracy - 20.0
-		l.Info("Training models, Highest acc", "acc", definitions[0].acuracy, "mod_acc", acc)
+		sort.Sort(sort.Reverse(defs))
+		acc := defs[0].Accuracy - 20.0
+		l.Info("Training models, Highest acc", "acc", defs[0].Accuracy, "mod_acc", acc)
 		toRemove = []int{}
-		for i, def := range definitions {
-			if def.acuracy < acc {
+		for i, def := range defs {
+			if def.Accuracy < acc {
 				toRemove = append(toRemove, i)
 			}
 		}
@@ -849,8 +771,8 @@ func trainModel(c BasePack, model *BaseModel) (err error) {
 		sort.Sort(sort.Reverse(toRemove))
 		for _, n := range toRemove {
 			l.Warn("Removing definition not fast enough learning", "n", n)
-			ModelDefinitionUpdateStatus(c, definitions[n].id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
-			definitions = remove(definitions, n)
+			ModelDefinitionUpdateStatus(c, defs[n].Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
+			defs = remove(defs, n)
 		}
 	}
@@ -918,33 +840,18 @@ func trainModel(c BasePack, model *BaseModel) (err error) {
 		return
 	}
-type TrainModelRowUsable struct {
-	Id string
-	TargetAccuracy int `db:"target_accuracy"`
-	Epoch int
-	Acuracy float64 `db:"0"`
-}
-type TrainModelRowUsables []*TrainModelRowUsable
-func (nf TrainModelRowUsables) Len() int { return len(nf) }
-func (nf TrainModelRowUsables) Swap(i, j int) { nf[i], nf[j] = nf[j], nf[i] }
-func (nf TrainModelRowUsables) Less(i, j int) bool {
-	return nf[i].Acuracy < nf[j].Acuracy
-}
 func trainModelExp(c BasePack, model *BaseModel) (err error) {
 	l := c.GetLogger()
 	db := c.GetDb()
-	var definitions TrainModelRowUsables
-	definitions, err = GetDbMultitple[TrainModelRowUsable](db, "model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id)
+	defs_, err := model.GetDefinitions(db, " and status=$2;", MODEL_DEFINITION_STATUS_INIT)
 	if err != nil {
 		l.Error("Failed to get definitions")
 		return
 	}
-	if len(definitions) == 0 {
+	var defs SortByAccuracyDefinitions = defs_
+	if len(defs) == 0 {
 		l.Error("No Definitions defined!")
 		return errors.New("No Definitions found")
 	}
@@ -954,9 +861,9 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
 	for {
 		var toRemove ToRemoveList = []int{}
-		for i, def := range definitions {
+		for i, def := range defs {
 			ModelDefinitionUpdateStatus(c, def.Id, MODEL_DEFINITION_STATUS_TRAINING)
-			accuracy, err := trainDefinitionExp(c, model, def.Id, !firstRound)
+			accuracy, err := trainDefinitionExp(c, model, *def, !firstRound)
 			if err != nil {
 				l.Error("Failed to train definition!Err:", "err", err)
 				ModelDefinitionUpdateStatus(c, def.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
@@ -965,10 +872,10 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
 			}
 			def.Epoch += EPOCH_PER_RUN
 			accuracy = accuracy * 100
-			def.Acuracy = float64(accuracy)
-			definitions[i].Epoch += EPOCH_PER_RUN
-			definitions[i].Acuracy = accuracy
+			def.Accuracy = float64(accuracy)
+			defs[i].Epoch += EPOCH_PER_RUN
+			defs[i].Accuracy = accuracy
 			if accuracy >= float64(def.TargetAccuracy) {
 				l.Info("Found a definition that reaches target_accuracy!")
@@ -1018,10 +925,10 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
 		l.Info("Round done", "toRemove", toRemove)
 		for _, n := range toRemove {
-			definitions = remove(definitions, n)
+			defs = remove(defs, n)
 		}
-		len_def := len(definitions)
+		len_def := len(defs)
 		if len_def == 0 {
 			break
@@ -1029,14 +936,14 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
 			continue
 		}
-		sort.Sort(sort.Reverse(definitions))
-		acc := definitions[0].Acuracy - 20.0
-		l.Info("Training models, Highest acc", "acc", definitions[0].Acuracy, "mod_acc", acc)
+		sort.Sort(sort.Reverse(defs))
+		acc := defs[0].Accuracy - 20.0
+		l.Info("Training models, Highest acc", "acc", defs[0].Accuracy, "mod_acc", acc)
 		toRemove = []int{}
-		for i, def := range definitions {
-			if def.Acuracy < acc {
+		for i, def := range defs {
+			if def.Accuracy < acc {
 				toRemove = append(toRemove, i)
 			}
 		}
@@ -1046,8 +953,8 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
 		sort.Sort(sort.Reverse(toRemove))
 		for _, n := range toRemove {
 			l.Warn("Removing definition not fast enough learning", "n", n)
-			ModelDefinitionUpdateStatus(c, definitions[n].Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
-			definitions = remove(definitions, n)
+			ModelDefinitionUpdateStatus(c, defs[n].Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
+			defs = remove(defs, n)
 		}
 	}
@@ -1062,6 +969,12 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
 		return err
 	}
+	err = model.UpdateStatus(db, FAILED_TRAINING)
+	if err != nil {
+		l.Error("All definitions failed to train! And Failed to set model status")
+		return err
+	}
 	l.Error("All definitions failed to train!")
 	return err
 } else if err != nil {
@@ -1090,7 +1003,7 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
 		return err
 	}
-	if err = splitModel(c, model); err != nil {
+	if err = SplitModel(c, model); err != nil {
 		err = setModelClassStatus(c, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
 		if err != nil {
 			l.Error("Failed to split the model! And Failed to set class status")
@@ -1123,7 +1036,7 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
 	return
 }
-func splitModel(c BasePack, model *BaseModel) (err error) {
+func SplitModel(c BasePack, model *BaseModel) (err error) {
 	db := c.GetDb()
 	l := c.GetLogger()
@@ -1260,7 +1173,6 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe
 	}
 	db := c.GetDb()
-	l := c.GetLogger()
 	def, err := MakeDefenition(db, model.Id, target_accuracy)
 	if err != nil {
@@ -1279,67 +1191,35 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe
 	}
 	order++
-	if complexity == 0 {
-		/*
-			_, err = def.MakeLayer(db, order, LAYER_SIMPLE_BLOCK, "")
-			if err != nil {
-				failed()
-				return
-			}
-			order++
-		*/
-		_, err = def.MakeLayer(db, order, LAYER_FLATTEN, "")
-		if err != nil {
-			failed()
-			return
-		}
-		order++
-		loop := int(math.Log2(float64(number_of_classes)))
-		for i := 0; i < loop; i++ {
-			_, err = def.MakeLayer(db, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i)))
-			order++
-			if err != nil {
-				ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
-				return
-			}
-		}
-	} else if complexity == 1 || complexity == 2 {
-		loop := max(1, int((math.Log(float64(model.Width)) / math.Log(float64(10)))))
-		for i := 0; i < loop; i++ {
-			_, err = def.MakeLayer(db, order, LAYER_SIMPLE_BLOCK, "")
-			order++
-			if err != nil {
-				failed()
-				return
-			}
-		}
-		_, err = def.MakeLayer(db, order, LAYER_FLATTEN, "")
-		if err != nil {
-			failed()
-			return
-		}
-		order++
-		loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2)
-		if loop == 0 {
-			loop = 1
-		}
-		for i := 0; i < loop; i++ {
-			_, err = def.MakeLayer(db, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i)))
-			order++
-			if err != nil {
-				failed()
-				return
-			}
-		}
-	} else {
-		l.Error("Unkown complexity", "complexity", complexity)
-		failed()
-		return
-	}
+	loop := max(1, int(math.Ceil((math.Log(float64(model.Width))/math.Log(float64(10)))))+1)
+	for i := 0; i < loop; i++ {
+		_, err = def.MakeLayer(db, order, LAYER_SIMPLE_BLOCK, "")
+		order++
+		if err != nil {
+			failed()
+			return
+		}
+	}
+	_, err = def.MakeLayer(db, order, LAYER_FLATTEN, "")
+	if err != nil {
+		failed()
+		return
+	}
+	order++
+	loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2)
+	if loop == 0 {
+		loop = 1
+	}
+	for i := 0; i < loop; i++ {
+		_, err = def.MakeLayer(db, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i)))
+		order++
+		if err != nil {
+			failed()
+			return
+		}
+	}
 	return def.UpdateStatus(db, DEFINITION_STATUS_INIT)
 }
@@ -1410,29 +1290,16 @@
 	order := 1
-	width := model.Width
-	height := model.Height
-	// Note the shape of the first layer defines the import size
-	if complexity == 2 {
-		// Note the shape for now is no used
-		width := int(math.Pow(2, math.Floor(math.Log(float64(model.Width))/math.Log(2.0))))
-		height := int(math.Pow(2, math.Floor(math.Log(float64(model.Height))/math.Log(2.0))))
-		l.Warn("Complexity 2 creating model with smaller size", "width", width, "height", height)
-	}
-	err = MakeLayerExpandable(c.GetDb(), def_id, order, LAYER_INPUT, fmt.Sprintf("%d,%d,1", width, height), 1)
-	order++
-	// handle the errors inside the pervious if block
+	err = MakeLayerExpandable(c.GetDb(), def_id, order, LAYER_INPUT, ShapeToString(3, model.Width, model.Height), 1)
 	if err != nil {
 		failed()
 		return
 	}
+	order++
 	// Create the blocks
-	loop := int((math.Log(float64(model.Width)) / math.Log(float64(10))))
+	loop := int(math.Ceil((math.Log(float64(model.Width)) / math.Log(float64(10))))) + 1
 	/*if model.Width < 50 && model.Height < 50 {
 		loop = 0
@@ -1440,7 +1307,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
 	log.Info("Size of the simple block", "loop", loop)
-	//loop = max(loop, 3)
+	loop = max(loop, min(2, model.ImageMode))
 	for i := 0; i < loop; i++ {
 		err = MakeLayerExpandable(db, def_id, order, LAYER_SIMPLE_BLOCK, "", 1)
@@ -1460,7 +1327,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
 	order++
 	// Flatten the blocks into dense
-	err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*2), 1)
+	err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, ShapeToString(number_of_classes*2), 1)
 	if err != nil {
 		failed()
 		return
@@ -1474,7 +1341,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
 	loop = max(loop, 3)
 	for i := 0; i < loop; i++ {
-		err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)*2), 2)
+		err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i)*2), 2)
 		order++
 		if err != nil {
 			failed()
@@ -1557,31 +1424,31 @@ func trainExpandable(c *Context, model *BaseModel) {
 		ResetClasses(c, model)
 	}
-	var definitions TrainModelRowUsables
-	definitions, err = GetDbMultitple[TrainModelRowUsable](c, "model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id)
+	defs_, err := model.GetDefinitions(c, " and status=$2", MODEL_DEFINITION_STATUS_READY)
 	if err != nil {
 		failed("Failed to get definitions")
 		return
 	}
-	if len(definitions) != 1 {
+	var defs SortByAccuracyDefinitions = defs_
+	if len(defs) != 1 {
 		failed("There should only be one definition available!")
 		return
 	}
 	firstRound := true
-	def := definitions[0]
+	def := defs[0]
 	epoch := 0
 	for {
-		acc, err := trainDefinitionExp(c, model, def.Id, !firstRound)
+		acc, err := trainDefinitionExp(c, model, *def, !firstRound)
 		if err != nil {
 			failed("Failed to train definition!")
 			return
 		}
 		epoch += EPOCH_PER_RUN
-		if float64(acc*100) >= float64(def.Acuracy) {
+		if float64(acc*100) >= float64(def.Accuracy) {
 			c.Logger.Info("Found a definition that reaches target_accuracy!")
 			_, err = c.Db.Exec("update exp_model_head set status=$1 where def_id=$2 and status=$3;", MODEL_HEAD_STATUS_READY, def.Id, MODEL_HEAD_STATUS_TRAINING)
@@ -1686,22 +1553,18 @@ func RunTaskRetrain(b BasePack, task Task) (err error) {
 	task.UpdateStatusLog(b, TASK_RUNNING, "Model retraining")
-	var defData struct {
-		Id string `db:"md.id"`
-		TargetAcuuracy float64 `db:"md.target_accuracy"`
-	}
-	err = GetDBOnce(db, &defData, "models as m inner join model_definition as md on m.id = md.model_id where m.id=$1;", task.ModelId)
+	defs, err := model.GetDefinitions(db, "")
 	if err != nil {
 		failed()
 		return
 	}
+	def := *defs[0]
 	failed = func() {
 		ResetClasses(b, model)
 		ModelUpdateStatus(b, model.Id, READY_RETRAIN_FAILED)
 		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Model failed retraining")
-		_, err_ := db.Exec("delete from exp_model_head where def_id=$1 and status in (2,3)", defData.Id)
+		_, err_ := db.Exec("delete from exp_model_head where def_id=$1 and status in (2,3)", def.Id)
 		if err_ != nil {
 			panic(err_)
 		}
@@ -1712,21 +1575,21 @@ func RunTaskRetrain(b BasePack, task Task) (err error) {
 	var epocs = 0
 	// TODO make max epochs come from db
 	// TODO re increase the target accuracy
-	for acc*100 < defData.TargetAcuuracy-5 && epocs < 10 {
+	for acc*100 < float64(def.TargetAccuracy)-5 && epocs < 10 {
 		// This is something I have to check
-		acc, err = trainDefinitionExpandExp(b, model, defData.Id, epocs > 0)
+		acc, err = trainDefinitionExpandExp(b, model, def, epocs > 0)
 		if err != nil {
 			failed()
 			return
 		}
-		l.Info("Retrained model", "accuracy", acc, "target", defData.TargetAcuuracy)
+		l.Info("Retrained model", "accuracy", acc, "target", def.TargetAccuracy)
 		epocs += 1
 	}
-	if acc*100 < defData.TargetAcuuracy {
-		l.Error("Model never achived targetd accuracy", "acc", acc*100, "target", defData.TargetAcuuracy)
+	if acc*100 < float64(def.TargetAccuracy)-5 {
+		l.Error("Model never achived targetd accuracy", "acc", acc*100, "target", def.TargetAccuracy)
 		failed()
 		return
 	}
@@ -1747,6 +1610,13 @@ func RunTaskRetrain(b BasePack, task Task) (err error) {
 		return
 	}
+	_, err = db.Exec("update exp_model_head set status=$1 where status=$2 and def_id=$3", MODEL_HEAD_STATUS_READY, MODEL_HEAD_STATUS_TRAINING, def.Id)
+	if err != nil {
+		l.Error("Error while updating the classes", "error", err)
+		failed()
+		return
+	}
 	task.UpdateStatusLog(b, TASK_DONE, "Model finished retraining")
 	return

View File

@@ -68,7 +68,7 @@ func handleTasksStats(handle *Handle) {
 		} else if task.Status < 2 {
 			total.Classfication_pre += 1
 			hours[hour].Classfication_pre += 1
-		} else if task.Status < 4 {
+		} else if task.Status < 4 || task.Status == 5 {
 			total.Classfication_running += 1
 			hours[hour].Classfication_running += 1
 		}

View File

@@ -1,7 +0,0 @@
-# Runner Protocol
-```
-/----\
-\/ |
-Register -> Init -> Active ---> Ready -> Info
-```

View File

@@ -9,4 +9,5 @@ func HandleTasks(handle *Handle) {
 	handleList(handle)
 	handleRequests(handle)
 	handleRemoteRunner(handle)
+	handleRunnerData(handle)
 }

View File

@@ -1,6 +1,8 @@
 package tasks
 import (
+	"os"
+	"path"
 	"sync"
 	"time"
@@ -8,6 +10,7 @@ import (
 	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train"
 	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
 	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
+	"github.com/charmbracelet/log"
 )
 func verifyRunner(c *Context, dat *JustId) (runner *Runner, e *Error) {
@@ -32,6 +35,12 @@ type VerifyTask struct {
 	TaskId string `json:"taskId" validate:"required"`
 }
+type RunnerTrainDef struct {
+	Id string `json:"id" validate:"required"`
+	TaskId string `json:"taskId" validate:"required"`
+	DefId string `json:"defId" validate:"required"`
+}
 func verifyTask(x *Handle, c *Context, dat *VerifyTask) (task *Task, error *Error) {
 	mutex := x.DataMap["runners_mutex"].(*sync.Mutex)
 	mutex.Lock()
@@ -51,6 +60,18 @@ func verifyTask(x *Handle, c *Context, dat *VerifyTask) (task *Task, error *Erro
 	return runner_data["task"].(*Task), nil
 }
+func clearRunnerTask(x *Handle, runner_id string) {
+	mutex := x.DataMap["runners_mutex"].(*sync.Mutex)
+	mutex.Lock()
+	defer mutex.Unlock()
+	var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{})
+	var runner_data map[string]interface{} = runners[runner_id].(map[string]interface{})
+	runner_data["task"] = nil
+	runners[runner_id] = runner_data
+	x.DataMap["runners"] = runners
+}
 func handleRemoteRunner(x *Handle) {
 	type RegisterRunner struct {
@ -200,6 +221,10 @@ func handleRemoteRunner(x *Handle) {
switch task.TaskType { switch task.TaskType {
case int(TASK_TYPE_TRAINING): case int(TASK_TYPE_TRAINING):
CleanUpFailed(c, task) CleanUpFailed(c, task)
case int(TASK_TYPE_RETRAINING):
CleanUpFailedRetrain(c, task)
case int(TASK_TYPE_CLASSIFICATION):
// DO nothing
default: default:
panic("Do not know how to handle this") panic("Do not know how to handle this")
} }
@ -218,7 +243,7 @@ func handleRemoteRunner(x *Handle) {
return c.SendJSON("Ok") return c.SendJSON("Ok")
}) })
PostAuthJson(x, "/tasks/runner/train/defs", User_Normal, func(c *Context, dat *VerifyTask) *Error { PostAuthJson(x, "/tasks/runner/defs", User_Normal, func(c *Context, dat *VerifyTask) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id}) _, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil { if error != nil {
return error return error
@ -229,7 +254,15 @@ func handleRemoteRunner(x *Handle) {
return error return error
} }
if task.TaskType != int(TASK_TYPE_TRAINING) { var status DefinitionStatus
switch task.TaskType {
case int(TASK_TYPE_TRAINING):
status = DEFINITION_STATUS_INIT
case int(TASK_TYPE_RETRAINING):
fallthrough
case int(TASK_TYPE_CLASSIFICATION):
status = DEFINITION_STATUS_READY
default:
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions") return c.JsonBadRequest("Task is not the right type go get the definitions")
} }
@ -239,7 +272,7 @@ func handleRemoteRunner(x *Handle) {
return c.E500M("Failed to get model information", err) return c.E500M("Failed to get model information", err)
} }
defs, err := model.GetDefinitions(c, "and md.status=$2", DEFINITION_STATUS_INIT) defs, err := model.GetDefinitions(c, "and md.status=$2", status)
if err != nil { if err != nil {
return c.E500M("Failed to get the model definitions", err) return c.E500M("Failed to get the model definitions", err)
} }
@ -247,7 +280,7 @@ func handleRemoteRunner(x *Handle) {
return c.SendJSON(defs) return c.SendJSON(defs)
}) })
PostAuthJson(x, "/tasks/runner/train/classes", User_Normal, func(c *Context, dat *VerifyTask) *Error { PostAuthJson(x, "/tasks/runner/classes", User_Normal, func(c *Context, dat *VerifyTask) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id}) _, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil { if error != nil {
return error return error
@ -258,22 +291,35 @@ func handleRemoteRunner(x *Handle) {
return error return error
} }
if task.TaskType != int(TASK_TYPE_TRAINING) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
model, err := GetBaseModel(c, *task.ModelId) model, err := GetBaseModel(c, *task.ModelId)
if err != nil { if err != nil {
return c.E500M("Failed to get model information", err) return c.E500M("Failed to get model information", err)
} }
classes, err := model.GetClasses(c, "and status=$2 order by mc.class_order asc", CLASS_STATUS_TO_TRAIN) switch task.TaskType {
if err != nil { case int(TASK_TYPE_TRAINING):
return c.E500M("Failed to get the model classes", err) classes, err := model.GetClasses(c, "and status in ($2, $3) order by mc.class_order asc", CLASS_STATUS_TO_TRAIN, CLASS_STATUS_TRAINING)
if err != nil {
return c.E500M("Failed to get the model classes", err)
}
return c.SendJSON(classes)
case int(TASK_TYPE_RETRAINING):
classes, err := model.GetClasses(c, "and status=$2 order by mc.class_order asc", CLASS_STATUS_TRAINING)
if err != nil {
return c.E500M("Failed to get the model classes", err)
}
return c.SendJSON(classes)
case int(TASK_TYPE_CLASSIFICATION):
classes, err := model.GetClasses(c, "and status=$2 order by mc.class_order asc", CLASS_STATUS_TRAINED)
if err != nil {
return c.E500M("Failed to get the model classes", err)
}
return c.SendJSON(classes)
default:
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
} }
return c.SendJSON(classes)
}) })
type RunnerTrainDefStatus struct { type RunnerTrainDefStatus struct {
@ -311,12 +357,13 @@ func handleRemoteRunner(x *Handle) {
return c.SendJSON("Ok") return c.SendJSON("Ok")
}) })
type RunnerTrainDefLayers struct { type RunnerTrainDefHeadStatus struct {
Id string `json:"id" validate:"required"` Id string `json:"id" validate:"required"`
TaskId string `json:"taskId" validate:"required"` TaskId string `json:"taskId" validate:"required"`
DefId string `json:"defId" validate:"required"` DefId string `json:"defId" validate:"required"`
Status ModelHeadStatus `json:"status" validate:"required"`
} }
PostAuthJson(x, "/tasks/runner/train/def/layers", User_Normal, func(c *Context, dat *RunnerTrainDefLayers) *Error { PostAuthJson(x, "/tasks/runner/train/def/head/status", User_Normal, func(c *Context, dat *RunnerTrainDefHeadStatus) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id}) _, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil { if error != nil {
return error return error
@ -337,6 +384,69 @@ func handleRemoteRunner(x *Handle) {
return c.E500M("Failed to get definition information", err) return c.E500M("Failed to get definition information", err)
} }
_, err = c.Exec("update exp_model_head set status=$1 where def_id=$2;", dat.Status, def.Id)
if err != nil {
log.Error("Failed to train definition!")
return c.E500M("Failed to train definition", err)
}
return c.SendJSON("Ok")
})
type RunnerRetrainDefHeadStatus struct {
Id string `json:"id" validate:"required"`
TaskId string `json:"taskId" validate:"required"`
HeadId string `json:"defId" validate:"required"`
Status ModelHeadStatus `json:"status" validate:"required"`
}
PostAuthJson(x, "/tasks/runner/retrain/def/head/status", User_Normal, func(c *Context, dat *RunnerRetrainDefHeadStatus) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, &VerifyTask{Id: dat.Id, TaskId: dat.TaskId})
if error != nil {
return error
}
if task.TaskType != int(TASK_TYPE_RETRAINING) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
if err := UpdateStatus(c.GetDb(), "exp_model_head", dat.HeadId, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
return c.E500M("Failed to update head status", err)
}
return c.SendJSON("Ok")
})
PostAuthJson(x, "/tasks/runner/train/def/layers", User_Normal, func(c *Context, dat *RunnerTrainDef) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, &VerifyTask{Id: dat.Id, TaskId: dat.TaskId})
if error != nil {
return error
}
switch task.TaskType {
case int(TASK_TYPE_TRAINING):
// Do nothing
case int(TASK_TYPE_RETRAINING):
// Do nothing
default:
c.Logger.Error("Task not is not the right type to get the layers", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the layers")
}
def, err := GetDefinition(c, dat.DefId)
if err != nil {
return c.E500M("Failed to get definition information", err)
}
layers, err := def.GetLayers(c, " order by layer_order asc") layers, err := def.GetLayers(c, " order by layer_order asc")
if err != nil { if err != nil {
return c.E500M("Failed to get layers", err) return c.E500M("Failed to get layers", err)
@ -356,7 +466,12 @@ func handleRemoteRunner(x *Handle) {
return error return error
} }
if task.TaskType != int(TASK_TYPE_TRAINING) { switch task.TaskType {
case int(TASK_TYPE_TRAINING):
// DO nothing
case int(TASK_TYPE_RETRAINING):
// DO nothing
default:
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions") return c.JsonBadRequest("Task is not the right type go get the definitions")
} }
@ -383,4 +498,463 @@ func handleRemoteRunner(x *Handle) {
Training: training_points, Training: training_points,
}) })
}) })
type RunnerTrainDefEpoch struct {
Id string `json:"id" validate:"required"`
TaskId string `json:"taskId" validate:"required"`
DefId string `json:"defId" validate:"required"`
Epoch int `json:"epoch" validate:"required"`
Accuracy float64 `json:"accuracy" validate:"required"`
}
PostAuthJson(x, "/tasks/runner/train/epoch", User_Normal, func(c *Context, dat *RunnerTrainDefEpoch) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, &VerifyTask{
Id: dat.Id,
TaskId: dat.TaskId,
})
if error != nil {
return error
}
if task.TaskType != int(TASK_TYPE_TRAINING) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
def, err := GetDefinition(c, dat.DefId)
if err != nil {
return c.E500M("Failed to get definition information", err)
}
err = def.UpdateAfterEpoch(c, dat.Accuracy, dat.Epoch)
if err != nil {
return c.E500M("Failed to update model", err)
}
return c.SendJSON("Ok")
})
PostAuthJson(x, "/tasks/runner/train/mark-failed", User_Normal, func(c *Context, dat *VerifyTask) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, &VerifyTask{
Id: dat.Id,
TaskId: dat.TaskId,
})
if error != nil {
return error
}
if task.TaskType != int(TASK_TYPE_TRAINING) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
_, err := c.Exec(
"update model_definition set status=$1 "+
"where model_id=$2 and status in ($3, $4)",
MODEL_DEFINITION_STATUS_CANCELD_TRAINING,
task.ModelId,
MODEL_DEFINITION_STATUS_TRAINING,
MODEL_DEFINITION_STATUS_PAUSED_TRAINING,
)
if err != nil {
return c.E500M("Failed to mark definition as failed", err)
}
return c.SendJSON("Ok")
})
PostAuthJson(x, "/tasks/runner/model", User_Normal, func(c *Context, dat *VerifyTask) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, dat)
if error != nil {
return error
}
switch task.TaskType {
case int(TASK_TYPE_TRAINING):
//DO NOTHING
case int(TASK_TYPE_RETRAINING):
//DO NOTHING
case int(TASK_TYPE_CLASSIFICATION):
//DO NOTHING
default:
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
model, err := GetBaseModel(c, *task.ModelId)
if err != nil {
return c.E500M("Failed to get model information", err)
}
return c.SendJSON(model)
})
PostAuthJson(x, "/tasks/runner/heads", User_Normal, func(c *Context, dat *RunnerTrainDef) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, &VerifyTask{Id: dat.Id, TaskId: dat.TaskId})
if error != nil {
return error
}
type ExpHead struct {
Id string `json:"id"`
Start int `db:"range_start" json:"start"`
End int `db:"range_end" json:"end"`
}
switch task.TaskType {
case int(TASK_TYPE_TRAINING):
fallthrough
case int(TASK_TYPE_RETRAINING):
// status = 2 (INIT) 3 (TRAINING)
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1 and status in (2,3)", dat.DefId)
if err != nil {
return c.E500M("Failed getting active heads", err)
}
return c.SendJSON(heads)
case int(TASK_TYPE_CLASSIFICATION):
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1", dat.DefId)
if err != nil {
return c.E500M("Failed getting active heads", err)
}
return c.SendJSON(heads)
default:
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
})
PostAuthJson(x, "/tasks/runner/train_exp/class/status/train", User_Normal, func(c *Context, dat *VerifyTask) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, dat)
if error != nil {
return error
}
if task.TaskType != int(TASK_TYPE_TRAINING) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
model, err := GetBaseModel(c, *task.ModelId)
if err != nil {
return c.E500M("Failed to get model", err)
}
err = SetModelClassStatus(c, CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TO_TRAIN)
if err != nil {
return c.E500M("Failed update status", err)
}
return c.SendJSON("Ok")
})
PostAuthJson(x, "/tasks/runner/train/done", User_Normal, func(c *Context, dat *VerifyTask) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, dat)
if error != nil {
return error
}
if task.TaskType != int(TASK_TYPE_TRAINING) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
model, err := GetBaseModel(c, *task.ModelId)
if err != nil {
c.Logger.Error("Failed to get model", "err", err)
return c.E500M("Failed to get mode", err)
}
var def Definition
err = GetDBOnce(c, &def, "model_definition as md where model_id=$1 and status=$2 order by accuracy desc limit 1;", task.ModelId, DEFINITION_STATUS_TRANIED)
if err == NotFoundError {
// TODO Make the Model status have a message
c.Logger.Error("All definitions failed to train!")
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "All definition failed to train!")
return c.SendJSON("Ok")
} else if err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to get model definition")
return c.E500M("Failed to get model definition", err)
}
if err = def.UpdateStatus(c, DEFINITION_STATUS_READY); err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to update model definition")
return c.E500M("Failed to update model definition", err)
}
to_delete, err := c.Query("select id from model_definition where status != $1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id)
if err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
return c.E500M("Failed to delete unsed definitions", err)
}
defer to_delete.Close()
for to_delete.Next() {
var id string
if err = to_delete.Scan(&id); err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
return c.E500M("Failed to delete unsed definitions", err)
}
os.RemoveAll(path.Join("savedData", model.Id, "defs", id))
}
// TODO Check if returning also works here
if _, err = c.Exec("delete from model_definition where status!=$1 and model_id=$2;", MODEL_DEFINITION_STATUS_READY, model.Id); err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
return c.E500M("Failed to delete unsed definitions", err)
}
// Set the class status to trained
err = SetModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1;", model.Id)
if err != nil {
c.Logger.Error("Failed to set class status")
return c.E500M("Failed to set class status", err)
}
if err = model.UpdateStatus(c, READY); err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
return c.E500M("Failed to update status of model", err)
}
task.UpdateStatusLog(c, TASK_DONE, "Model finished training")
clearRunnerTask(x, dat.Id)
return c.SendJSON("Ok")
})
type RunnerClassDone struct {
Id string `json:"id" validate:"required"`
TaskId string `json:"taskId" validate:"required"`
Result string `json:"result" validate:"required"`
}
PostAuthJson(x, "/tasks/runner/class/done", User_Normal, func(c *Context, dat *RunnerClassDone) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, &VerifyTask{
Id: dat.Id,
TaskId: dat.TaskId,
})
if error != nil {
return error
}
if task.TaskType != int(TASK_TYPE_CLASSIFICATION) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
err := task.SetResultText(c, dat.Result)
if err != nil {
return c.E500M("Failed to update the task", err)
}
err = task.UpdateStatus(c, TASK_DONE, "Task completed")
if err != nil {
return c.E500M("Failed to update task", err)
}
mutex := x.DataMap["runners_mutex"].(*sync.Mutex)
mutex.Lock()
defer mutex.Unlock()
var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{})
var runner_data map[string]interface{} = runners[dat.Id].(map[string]interface{})
runner_data["task"] = nil
runners[dat.Id] = runner_data
x.DataMap["runners"] = runners
return c.SendJSON("Ok")
})
PostAuthJson(x, "/tasks/runner/train_exp/done", User_Normal, func(c *Context, dat *VerifyTask) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, dat)
if error != nil {
return error
}
if task.TaskType != int(TASK_TYPE_TRAINING) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
model, err := GetBaseModel(c, *task.ModelId)
if err != nil {
c.Logger.Error("Failed to get model", "err", err)
return c.E500M("Failed to get mode", err)
}
// TODO add the check to the model
var def Definition
err = GetDBOnce(c, &def, "model_definition as md where model_id=$1 and status=$2 order by accuracy desc limit 1;", task.ModelId, DEFINITION_STATUS_TRANIED)
if err == NotFoundError {
// TODO Make the Model status have a message
c.Logger.Error("All definitions failed to train!")
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "All definition failed to train!")
clearRunnerTask(x, dat.Id)
return c.SendJSON("Ok")
} else if err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to get model definition")
return c.E500M("Failed to get model definition", err)
}
if err = def.UpdateStatus(c, DEFINITION_STATUS_READY); err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to update model definition")
return c.E500M("Failed to update model definition", err)
}
to_delete, err := GetDbMultitple[JustId](c, "model_definition where status!=$1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id)
if err != nil {
c.GetLogger().Error("Failed to select model_definition to delete")
return c.E500M("Failed to select model definition to delete", err)
}
for _, d := range to_delete {
os.RemoveAll(path.Join("savedData", model.Id, "defs", d.Id))
}
// TODO Check if returning also works here
if _, err = c.Exec("delete from model_definition where status!=$1 and model_id=$2;", MODEL_DEFINITION_STATUS_READY, model.Id); err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
return c.E500M("Failed to delete unsed definitions", err)
}
if err = SplitModel(c, model); err != nil {
err = SetModelClassStatus(c, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
if err != nil {
c.Logger.Error("Failed to split the model! And Failed to set class status")
return c.E500M("Failed to split the model", err)
}
c.Logger.Error("Failed to split the model")
return c.E500M("Failed to split the model", err)
}
// Set the class status to trained
err = SetModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
if err != nil {
c.Logger.Error("Failed to set class status")
return c.E500M("Failed to set class status", err)
}
c.Logger.Warn("Removing base model for", "model", model.Id, "def", def.Id)
os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model"))
os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model.keras"))
if err = model.UpdateStatus(c, READY); err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
return c.E500M("Failed to update status of model", err)
}
task.UpdateStatusLog(c, TASK_DONE, "Model finished training")
mutex := x.DataMap["runners_mutex"].(*sync.Mutex)
mutex.Lock()
defer mutex.Unlock()
var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{})
var runner_data map[string]interface{} = runners[dat.Id].(map[string]interface{})
runner_data["task"] = nil
runners[dat.Id] = runner_data
x.DataMap["runners"] = runners
return c.SendJSON("Ok")
})
PostAuthJson(x, "/tasks/runner/retrain/done", User_Normal, func(c *Context, dat *VerifyTask) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, dat)
if error != nil {
return error
}
if task.TaskType != int(TASK_TYPE_RETRAINING) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
model, err := GetBaseModel(c, *task.ModelId)
if err != nil {
c.Logger.Error("Failed to get model", "err", err)
return c.E500M("Failed to get mode", err)
}
err = SetModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
if err != nil {
return c.E500M("Failed to set class status", err)
}
defs, err := model.GetDefinitions(c, "")
if err != nil {
return c.E500M("Failed to get definitions", err)
}
_, err = c.Exec("update exp_model_head set status=$1 where status=$2 and def_id=$3", MODEL_HEAD_STATUS_READY, MODEL_HEAD_STATUS_TRAINING, defs[0].Id)
if err != nil {
return c.E500M("Failed to set head status", err)
}
err = model.UpdateStatus(c, READY)
if err != nil {
return c.E500M("Failed to set class status", err)
}
task.UpdateStatusLog(c, TASK_DONE, "Model finished training")
clearRunnerTask(x, dat.Id)
return c.SendJSON("Ok")
})
} }

View File

@ -19,6 +19,8 @@ import (
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
) )
var QUEUE_SIZE = 10
/** /**
* Actually runs the code * Actually runs the code
*/ */
@ -47,17 +49,28 @@ func runner(config Config, db db.Db, task_channel chan Task, index int, back_cha
Host: config.Hostname, Host: config.Hostname,
} }
loaded_model := RunnerModelData{
Id: nil,
Model: nil,
}
count := 0
for task := range task_channel { for task := range task_channel {
logger.Info("Got task", "task", task) logger.Info("Got task", "task", task)
task.UpdateStatusLog(base, TASK_PICKED_UP, "Runner picked up task") task.UpdateStatusLog(base, TASK_PICKED_UP, "Runner picked up task")
if task.TaskType == int(TASK_TYPE_CLASSIFICATION) { if task.TaskType == int(TASK_TYPE_CLASSIFICATION) {
logger.Info("Classification Task") logger.Info("Classification Task")
if err = ClassifyTask(base, task); err != nil { if err = ClassifyTask(base, task, &loaded_model); err != nil {
logger.Error("Classification task failed", "error", err) logger.Error("Classification task failed", "error", err)
} }
back_channel <- index if count == QUEUE_SIZE {
back_channel <- index
count = 0
} else {
count += 1
}
continue continue
} else if task.TaskType == int(TASK_TYPE_TRAINING) { } else if task.TaskType == int(TASK_TYPE_TRAINING) {
logger.Info("Training Task") logger.Info("Training Task")
@ -65,7 +78,12 @@ func runner(config Config, db db.Db, task_channel chan Task, index int, back_cha
logger.Error("Failed to tain the model", "error", err) logger.Error("Failed to tain the model", "error", err)
} }
back_channel <- index if count == QUEUE_SIZE {
back_channel <- index
count = 0
} else {
count += 1
}
continue continue
} else if task.TaskType == int(TASK_TYPE_RETRAINING) { } else if task.TaskType == int(TASK_TYPE_RETRAINING) {
logger.Info("Retraining Task") logger.Info("Retraining Task")
@ -73,7 +91,12 @@ func runner(config Config, db db.Db, task_channel chan Task, index int, back_cha
logger.Error("Failed to tain the model", "error", err) logger.Error("Failed to tain the model", "error", err)
} }
back_channel <- index if count == QUEUE_SIZE {
back_channel <- index
count = 0
} else {
count += 1
}
continue continue
} else if task.TaskType == int(TASK_TYPE_DELETE_USER) { } else if task.TaskType == int(TASK_TYPE_DELETE_USER) {
logger.Warn("User deleting Task") logger.Warn("User deleting Task")
@ -81,13 +104,23 @@ func runner(config Config, db db.Db, task_channel chan Task, index int, back_cha
logger.Error("Failed to tain the model", "error", err) logger.Error("Failed to tain the model", "error", err)
} }
back_channel <- index if count == QUEUE_SIZE {
back_channel <- index
count = 0
} else {
count += 1
}
continue continue
} }
logger.Error("Do not know how to route task", "task", task) logger.Error("Do not know how to route task", "task", task)
task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Do not know how to route task") task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Do not know how to route task")
back_channel <- index if count == QUEUE_SIZE {
back_channel <- index
count = 0
} else {
count += 1
}
} }
} }
@ -120,10 +153,22 @@ func handleRemoteTask(handler *Handle, base BasePack, runner_id string, task Tas
defer mutex.Unlock() defer mutex.Unlock()
switch task.TaskType { switch task.TaskType {
case int(TASK_TYPE_RETRAINING):
runners := handler.DataMap["runners"].(map[string]interface{})
runner := runners[runner_id].(map[string]interface{})
runner["task"] = &task
runners[runner_id] = runner
handler.DataMap["runners"] = runners
case int(TASK_TYPE_TRAINING): case int(TASK_TYPE_TRAINING):
if err := PrepareTraining(handler, base, task, runner_id); err != nil { if err := PrepareTraining(handler, base, task, runner_id); err != nil {
logger.Error("Failed to prepare for training", "err", err) logger.Error("Failed to prepare for training", "err", err)
} }
case int(TASK_TYPE_CLASSIFICATION):
runners := handler.DataMap["runners"].(map[string]interface{})
runner := runners[runner_id].(map[string]interface{})
runner["task"] = &task
runners[runner_id] = runner
handler.DataMap["runners"] = runners
default: default:
logger.Error("Not sure what to do panicing", "taskType", task.TaskType) logger.Error("Not sure what to do panicing", "taskType", task.TaskType)
panic("not sure what to do") panic("not sure what to do")
@ -133,7 +178,7 @@ func handleRemoteTask(handler *Handle, base BasePack, runner_id string, task Tas
/** /**
* Tells the orchestrator to look at the task list from time to time * Tells the orchestrator to look at the task list from time to time
*/ */
func attentionSeeker(config Config, back_channel chan int) { func attentionSeeker(config Config, db db.Db, back_channel chan int) {
logger := log.NewWithOptions(os.Stdout, log.Options{ logger := log.NewWithOptions(os.Stdout, log.Options{
ReportCaller: true, ReportCaller: true,
ReportTimestamp: true, ReportTimestamp: true,
@ -158,6 +203,20 @@ func attentionSeeker(config Config, back_channel chan int) {
for true { for true {
back_channel <- 0 back_channel <- 0
for {
var s struct {
Count int `json:"count(*)"`
}
err := GetDBOnce(db, &s, "tasks where stauts = 5 or status = 3")
if err != nil {
break
}
if s.Count == 0 {
break
}
time.Sleep(t)
}
time.Sleep(t) time.Sleep(t)
} }
} }
@ -173,9 +232,7 @@ func RunnerOrchestrator(db db.Db, config Config, handler *Handle) {
Prefix: "Runner Orchestrator Logger", Prefix: "Runner Orchestrator Logger",
}) })
// Setup vars setupHandle(handler)
handler.DataMap["runners"] = map[string]interface{}{}
handler.DataMap["runners_mutex"] = &sync.Mutex{}
base := BasePackStruct{ base := BasePackStruct{
Db: db, Db: db,
@ -184,17 +241,22 @@ func RunnerOrchestrator(db db.Db, config Config, handler *Handle) {
} }
gpu_workers := config.GpuWorker.NumberOfWorkers gpu_workers := config.GpuWorker.NumberOfWorkers
def_wait, err := time.ParseDuration(config.GpuWorker.Pulling)
if err != nil {
logger.Error("Failed to load", "error", err)
return
}
logger.Info("Starting runners") logger.Info("Starting runners")
task_runners := make([]chan Task, gpu_workers) task_runners := make([]chan Task, gpu_workers)
task_runners_used := make([]bool, gpu_workers) task_runners_used := make([]int, gpu_workers)
// One more to accommodate the Attention Seeker channel // One more to accommodate the Attention Seeker channel
back_channel := make(chan int, gpu_workers+1) back_channel := make(chan int, gpu_workers+1)
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
logger.Error("Recovered in Orchestrator restarting", "due to", r) logger.Error("Recovered in Orchestrator restarting", "due to", r, "stack", string(debug.Stack()))
for x := range task_runners { for x := range task_runners {
close(task_runners[x]) close(task_runners[x])
} }
@ -203,87 +265,120 @@ func RunnerOrchestrator(db db.Db, config Config, handler *Handle) {
} }
}() }()
go attentionSeeker(config, back_channel) // go attentionSeeker(config, db, back_channel)
// Start the runners // Start the runners
for i := 0; i < gpu_workers; i++ { for i := 0; i < gpu_workers; i++ {
task_runners[i] = make(chan Task, 10) task_runners[i] = make(chan Task, QUEUE_SIZE)
task_runners_used[i] = false task_runners_used[i] = 0
AddLocalRunner(handler, LocalRunner{
RunnerNum: i + 1,
Task: nil,
})
go runner(config, db, task_runners[i], i+1, back_channel) go runner(config, db, task_runners[i], i+1, back_channel)
} }
var task_to_dispatch *Task = nil used := 0
wait := time.Nanosecond * 100
for i := range back_channel { for {
out := true
if i > 0 { for out {
logger.Info("Runner freed", "runner", i) select {
task_runners_used[i-1] = false case i := <-back_channel:
} else if i < 0 { if i != 0 {
logger.Error("Runner died! Restarting!", "runner", i) if i > 0 {
i = int(math.Abs(float64(i)) - 1) logger.Info("Runner freed", "runner", i)
task_runners_used[i] = false task_runners_used[i-1] = 0
go runner(config, db, task_runners[i], i+1, back_channel) used = 0
} else if i < 0 {
logger.Error("Runner died! Restarting!", "runner", i)
i = int(math.Abs(float64(i)) - 1)
task_runners_used[i] = 0
used = 0
go runner(config, db, task_runners[i], i+1, back_channel)
}
AddLocalTask(handler, int(math.Abs(float64(i))), nil)
} else if used == len(task_runners_used) {
continue
}
case <-time.After(wait):
if wait == time.Nanosecond*100 {
wait = def_wait
}
out = false
}
} }
if task_to_dispatch == nil { for {
var task TaskT tasks, err := GetDbMultitple[TaskT](db, "tasks as t "+
err := GetDBOnce(db, &task, "tasks as t "+
// Get dependencies // Get dependencies
"left join tasks_dependencies as td on t.id=td.main_id "+ "left join tasks_dependencies as td on t.id=td.main_id "+
// Get the task that the dependency resolves to // Get the task that the dependency resolves to
"left join tasks as t2 on t2.id=td.dependent_id "+ "left join tasks as t2 on t2.id=td.dependent_id "+
"where t.status=1 "+ "where t.status=1 "+
"group by t.id having count(td.id) filter (where t2.status in (0,1,2,3)) = 0;") "group by t.id having count(td.id) filter (where t2.status in (0,1,2,3)) = 0 limit 20;")
if err != NotFoundError && err != nil { if err != NotFoundError && err != nil {
log.Error("Failed to get tasks from db", "err", err) log.Error("Failed to get tasks from db", "err", err)
continue continue
} }
if err == NotFoundError { if err == NotFoundError || len(tasks) == 0 {
task_to_dispatch = nil break
} else {
temp := Task(task)
task_to_dispatch = &temp
} }
}
if task_to_dispatch != nil { for _, task_to_dispatch := range tasks {
ttd := Task(*task_to_dispatch)
if task_to_dispatch != nil && task_to_dispatch.TaskType != int(TASK_TYPE_DELETE_USER) {
// TODO split tasks into cpu tasks and GPU tasks
mutex := handler.DataMap["runners_mutex"].(*sync.Mutex)
mutex.Lock()
remote_runners := handler.DataMap["runners"].(map[string]interface{})
// Only let CPU tasks be done by the local users for k, v := range remote_runners {
if task_to_dispatch.TaskType == int(TASK_TYPE_DELETE_USER) { runner_data := v.(map[string]interface{})
for i := 0; i < len(task_runners_used); i += 1 { runner_info := runner_data["runner_info"].(*Runner)
if !task_runners_used[i] {
task_runners[i] <- *task_to_dispatch if runner_data["task"] != nil {
task_runners_used[i] = true continue
}
if runner_info.UserId != task_to_dispatch.UserId {
continue
}
go handleRemoteTask(handler, base, k, ttd)
task_to_dispatch = nil task_to_dispatch = nil
break break
} }
}
continue
}
mutex := handler.DataMap["runners_mutex"].(*sync.Mutex) mutex.Unlock()
mutex.Lock()
remote_runners := handler.DataMap["runners"].(map[string]interface{})
for k, v := range remote_runners {
runner_data := v.(map[string]interface{})
runner_info := runner_data["runner_info"].(*Runner)
if runner_data["task"] != nil {
continue
} }
if runner_info.UserId == task_to_dispatch.UserId { used = 0
go handleRemoteTask(handler, base, k, *task_to_dispatch) if task_to_dispatch != nil {
task_to_dispatch = nil for i := 0; i < len(task_runners_used); i += 1 {
if task_runners_used[i] <= QUEUE_SIZE {
ttd.UpdateStatusLog(base, TASK_QUEUED, "Runner picked up task")
task_runners[i] <- ttd
task_runners_used[i] += 1
AddLocalTask(handler, i+1, &ttd)
task_to_dispatch = nil
wait = time.Nanosecond * 100
break
} else {
used += 1
}
}
}
if used == len(task_runners_used) {
break break
} }
} }
mutex.Unlock() if used == len(task_runners_used) {
break
}
} }
} }
} }

View File

@ -0,0 +1,51 @@
package task_runner
import (
"sync"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
)
type LocalRunner struct {
RunnerNum int `json:"id"`
Task *Task `json:"task"`
}
type LocalRunners map[int]*LocalRunner
func LockRunners(handler *Handle, t string) *sync.Mutex {
req := t + "_runners_mutex"
if t == "" {
req = "runners_mutex"
}
mutex := handler.DataMap[req].(*sync.Mutex)
mutex.Lock()
return mutex
}
func setupHandle(handler *Handle) {
// Setup Remote Runner data
handler.DataMap["runners"] = map[string]interface{}{}
handler.DataMap["runners_mutex"] = &sync.Mutex{}
// Setup Local Runner data
handler.DataMap["local_runners"] = &LocalRunners{}
handler.DataMap["local_runners_mutex"] = &sync.Mutex{}
}
func AddLocalRunner(handler *Handle, runner LocalRunner) {
mutex := LockRunners(handler, "local")
defer mutex.Unlock()
runners := handler.DataMap["local_runners"].(*LocalRunners)
(*runners)[runner.RunnerNum] = &runner
}
func AddLocalTask(handler *Handle, runner_id int, task *Task) {
mutex := LockRunners(handler, "local")
defer mutex.Unlock()
runners := handler.DataMap["local_runners"].(*LocalRunners)
(*(*runners)[runner_id]).Task = task
}

View File

@ -0,0 +1,25 @@
package tasks
import (
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/runner"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
)
func handleRunnerData(x *Handle) {
type NonType struct{}
PostAuthJson(x, "/tasks/runner/info", User_Admin, func(c *Context, dat *NonType) *Error {
mutex_remote := LockRunners(x, "")
defer mutex_remote.Unlock()
mutex_local := LockRunners(x, "local")
defer mutex_local.Unlock()
return c.SendJSON(struct {
RemoteRunners map[string]interface{} `json:"remoteRunners"`
LocalRunner *LocalRunners `json:"localRunners"`
}{
RemoteRunners: x.DataMap["runners"].(map[string]interface{}),
LocalRunner: x.DataMap["local_runners"].(*LocalRunners),
})
})
}

View File

@ -50,6 +50,7 @@ const (
TASK_PREPARING = 0 TASK_PREPARING = 0
TASK_TODO = 1 TASK_TODO = 1
TASK_PICKED_UP = 2 TASK_PICKED_UP = 2
TASK_QUEUED = 5
TASK_RUNNING = 3 TASK_RUNNING = 3
TASK_DONE = 4 TASK_DONE = 4
) )
@ -101,7 +102,11 @@ func (t Task) SetResult(base BasePack, result any) (err error) {
if err != nil { if err != nil {
return return
} }
_, err = base.GetDb().Exec("update tasks set result=$1 where id=$2", text, t.Id) return t.SetResultText(base, string(text))
}
func (t Task) SetResultText(base BasePack, text string) (err error) {
_, err = base.GetDb().Exec("update tasks set result=$1 where id=$2", []byte(text), t.Id)
return return
} }

View File

@ -241,6 +241,17 @@ func UsersEndpints(db db.Db, handle *Handle) {
return c.SendJSON(userReturn) return c.SendJSON(userReturn)
}) })
PostAuthJson(handle, "/user/info/get", User_Admin, func(c *Context, dat *JustId) *Error {
var user *User
user, err := UserFromId(c, dat.Id)
if err == NotFoundError {
return c.SendJSONStatus(404, "User not found")
} else if err != nil {
return c.E500M("Could not get user", err)
}
return c.SendJSON(user)
})
// Handles updating users // Handles updating users
type UpdateUserData struct { type UpdateUserData struct {
Id string `json:"id"` Id string `json:"id"`

View File

@ -23,7 +23,12 @@ type ServiceUser struct {
} }
type DbInfo struct { type DbInfo struct {
MaxConnections int `toml:"max_connections"` MaxConnections int `toml:"max_connections"`
Host string `toml:"host"`
Port int `toml:"port"`
User string `toml:"user"`
Password string `toml:"password"`
Dbname string `toml:"dbname"`
} }
type Config struct { type Config struct {
@ -97,7 +102,7 @@ func (c *Config) Cleanup(db db.Db) {
failLog(err) failLog(err)
_, err = db.Exec("update models set status=$1 where status=$2", FAILED_PREPARING, PREPARING) _, err = db.Exec("update models set status=$1 where status=$2", FAILED_PREPARING, PREPARING)
failLog(err) failLog(err)
_, err = db.Exec("update tasks set status=$1 where status=$2", TASK_TODO, TASK_PICKED_UP) _, err = db.Exec("update tasks set status=$1 where status=$2 or status=$3", TASK_TODO, TASK_PICKED_UP, TASK_QUEUED)
failLog(err) failLog(err)
tasks, err := GetDbMultitple[Task](db, "tasks where status=$1", TASK_RUNNING) tasks, err := GetDbMultitple[Task](db, "tasks where status=$1", TASK_RUNNING)
@ -114,12 +119,16 @@ func (c *Config) Cleanup(db db.Db) {
tasks[i].UpdateStatus(base, TASK_FAILED_RUNNING, "Task inturupted by server restart please try again") tasks[i].UpdateStatus(base, TASK_FAILED_RUNNING, "Task inturupted by server restart please try again")
_, err = db.Exec("update models set status=$1 where id=$2", READY_RETRAIN_FAILED, tasks[i].ModelId) _, err = db.Exec("update models set status=$1 where id=$2", READY_RETRAIN_FAILED, tasks[i].ModelId)
failLog(err) failLog(err)
_, err = db.Exec("update model_classes set status=$1 where model_id=$2 and status=$3", CLASS_STATUS_TO_TRAIN, tasks[i].ModelId, CLASS_STATUS_TRAINING)
failLog(err)
continue continue
} }
if tasks[i].TaskType == int(TASK_TYPE_TRAINING) { if tasks[i].TaskType == int(TASK_TYPE_TRAINING) {
tasks[i].UpdateStatus(base, TASK_FAILED_RUNNING, "Task inturupted by server restart please try again") tasks[i].UpdateStatus(base, TASK_FAILED_RUNNING, "Task inturupted by server restart please try again")
_, err = db.Exec("update models set status=$1 where id=$2", FAILED_TRAINING, tasks[i].ModelId) _, err = db.Exec("update models set status=$1 where id=$2", FAILED_TRAINING, tasks[i].ModelId)
failLog(err) failLog(err)
_, err = db.Exec("update model_classes set status=$1 where model_id=$2 and status=$3", CLASS_STATUS_TO_TRAIN, tasks[i].ModelId, CLASS_STATUS_TRAINING)
failLog(err)
continue continue
} }
} }

View File

@ -449,7 +449,7 @@ func (x Handle) createContext(handler *Handle, r *http.Request, w http.ResponseW
logger := log.NewWithOptions(os.Stdout, log.Options{ logger := log.NewWithOptions(os.Stdout, log.Options{
ReportCaller: true, ReportCaller: true,
ReportTimestamp: true, ReportTimestamp: true,
TimeFormat: time.Kitchen, TimeFormat: time.DateTime,
Prefix: r.URL.Path, Prefix: r.URL.Path,
}) })

main.go
View File

@ -15,25 +15,18 @@ import (
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
) )
const (
host = "localhost"
port = 5432
user = "postgres"
password = "verysafepassword"
dbname = "aistuff"
)
func main() { func main() {
config := LoadConfig()
log.Info("Config loaded!", "config", config)
psqlInfo := fmt.Sprintf("host=%s port=%d user=%s "+ psqlInfo := fmt.Sprintf("host=%s port=%d user=%s "+
"password=%s dbname=%s sslmode=disable", "password=%s dbname=%s sslmode=disable",
host, port, user, password, dbname) config.DbInfo.Host, config.DbInfo.Port, config.DbInfo.User, config.DbInfo.Password, config.DbInfo.Dbname)
db := db.StartUp(psqlInfo) db := db.StartUp(psqlInfo)
defer db.Close() defer db.Close()
config := LoadConfig()
log.Info("Config loaded!", "config", config)
config.GenerateToken(db) config.GenerateToken(db)
//TODO check if file structure exists to save data //TODO check if file structure exists to save data

View File

@ -1,6 +1,6 @@
events { events {
worker_connections 1024; worker_connections 2024;
} }
http { http {
@ -17,7 +17,7 @@ http {
location / { location / {
proxy_http_version 1.1; proxy_http_version 1.1;
proxy_pass http://127.0.0.1:5001; proxy_pass http://webpage:5001;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header Upgrade $http_upgrade; proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection $connection_upgrade; proxy_set_header Connection $connection_upgrade;
@ -25,7 +25,7 @@ http {
location /api { location /api {
proxy_pass http://127.0.0.1:5002; proxy_pass http://server:5002;
} }
} }
} }

View File

@ -1,5 +1,4 @@
# tensorflow[and-cuda] == 2.15.1 tensorflow[and-cuda] == 2.15.1
tensorflow[and-cuda] == 2.9.1
pandas pandas
# Make sure to install the nvidia pyindex first # Make sure to install the nvidia pyindex first
# nvidia-pyindex # nvidia-pyindex

run.sh
View File

@ -1,2 +1,2 @@
#!/bin/bash #!/bin/bash
podman run --rm --network host --gpus all -ti -v $(pwd):/app -e "TERM=xterm-256color" fyp-server bash podman run --network host --gpus all --replace --name fyp-server --ulimit=nofile=100000:100000 -it -v $(pwd):/app -e "TERM=xterm-256color" --restart=always andre-fyp-server

runner/.gitignore
View File

@ -1 +0,0 @@
target/

runner/Cargo.lock

File diff suppressed because it is too large

View File

@ -1,17 +0,0 @@
[package]
name = "runner"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.82"
serde = { version = "1.0.200", features = ["derive"] }
toml = "0.8.12"
reqwest = { version = "0.12", features = ["json"] }
tokio = { version = "1", features = ["full"] }
serde_json = "1.0.116"
serde_repr = "0.1"
tch = { version = "0.16.0", features = ["download-libtorch"] }
rand = "0.8.5"

View File

@ -1,12 +0,0 @@
FROM docker.io/nvidia/cuda:11.7.1-devel-ubuntu22.04
RUN apt-get update
RUN apt-get install -y curl
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
ENV PATH="$PATH:/root/.cargo/bin"
RUN rustup toolchain install stable
RUN apt-get install -y pkg-config libssl-dev
WORKDIR /app

View File

@ -1,3 +0,0 @@
hostname = "https://testing.andr3h3nriqu3s.com/api"
token = "d2bc41e8293937bcd9397870c98f97acc9603f742924b518e193cd1013e45d57897aa302b364001c72b458afcfb34239dfaf38a66b318e5cbc973eea"
data_path = "/home/andr3/Documents/my-repos/fyp"

View File

@ -1 +0,0 @@
id = "a7cec9e9-1d05-4633-8bc5-6faabe4fd5a3"

View File

@ -1,2 +0,0 @@
#!/bin/bash
podman run --rm --network host --gpus all -ti -v $(pwd):/app -e "TERM=xterm-256color" fyp-runner bash

View File

@ -1,115 +0,0 @@
use crate::{model::DataPoint, settings::ConfigFile};
use std::{path::Path, sync::Arc};
use tch::Tensor;
pub struct DataLoader {
pub batch_size: i64,
pub len: usize,
pub inputs: Vec<Tensor>,
pub labels: Vec<Tensor>,
pub pos: usize,
}
fn import_image(
item: &DataPoint,
base_path: &Path,
classes_len: i64,
inputs: &mut Vec<Tensor>,
labels: &mut Vec<Tensor>,
) {
inputs.push(
tch::vision::image::load(base_path.join(&item.path))
.ok()
.unwrap()
.unsqueeze(0),
);
if item.class >= 0 {
let t = tch::Tensor::from_slice(&[item.class]).onehot(classes_len as i64);
labels.push(t);
} else {
labels.push(tch::Tensor::zeros(
[1, classes_len as i64],
(tch::Kind::Float, tch::Device::Cpu),
))
}
}
impl DataLoader {
pub fn new(
config: Arc<ConfigFile>,
data: Vec<DataPoint>,
classes_len: i64,
batch_size: i64,
) -> DataLoader {
let len: f64 = (data.len() as f64) / (batch_size as f64);
let min_len: i64 = len.floor() as i64;
let max_len: i64 = len.ceil() as i64;
println!(
"Creating dataloader data len: {} len: {} min_len: {} max_len:{}",
data.len(),
len,
min_len,
max_len
);
let base_path = Path::new(&config.data_path);
let mut inputs: Vec<Tensor> = Vec::new();
let mut all_labels: Vec<Tensor> = Vec::new();
for batch in 0..min_len {
let mut batch_acc: Vec<Tensor> = Vec::new();
let mut labels: Vec<Tensor> = Vec::new();
for image in 0..batch_size {
let i: usize = (batch * batch_size + image).try_into().unwrap();
let item = &data[i];
import_image(item, base_path, classes_len, &mut batch_acc, &mut labels)
}
inputs.push(tch::Tensor::cat(&batch_acc[0..], 0));
all_labels.push(tch::Tensor::cat(&labels[0..], 0));
}
// Import the last batch that has irregular sizing
if min_len != max_len {
let mut batch_acc: Vec<Tensor> = Vec::new();
let mut labels: Vec<Tensor> = Vec::new();
for image in 0..(data.len() - (batch_size * min_len) as usize) {
let i: usize = (min_len * batch_size + (image as i64)) as usize;
let item = &data[i];
import_image(item, base_path, classes_len, &mut batch_acc, &mut labels);
}
inputs.push(tch::Tensor::cat(&batch_acc[0..], 0));
all_labels.push(tch::Tensor::cat(&labels[0..], 0));
}
println!("ins shape: {:?}", inputs[0].size());
return DataLoader {
batch_size,
inputs,
labels: all_labels,
len: max_len as usize,
pos: 0,
};
}
pub fn restart(self: &mut DataLoader) {
self.pos = 0;
}
pub fn next(self: &mut DataLoader) -> Option<(Tensor, Tensor)> {
if self.pos >= self.len {
return None;
}
let input = self.inputs[self.pos].empty_like();
self.inputs[self.pos] = self.inputs[self.pos].clone(&input);
let label = self.labels[self.pos].empty_like();
self.labels[self.pos] = self.labels[self.pos].clone(&label);
self.pos += 1;
return Some((input, label));
}
}

View File

@ -1,206 +0,0 @@
mod dataloader;
mod model;
mod settings;
mod tasks;
mod training;
mod types;
use crate::settings::*;
use crate::tasks::{fail_task, Task, TaskType};
use crate::training::handle_train;
use anyhow::{bail, Result};
use reqwest::StatusCode;
use serde_json::json;
use std::{fs, process::exit, sync::Arc, time::Duration};
enum ResultAlive {
Ok,
Error,
NotInit,
}
async fn send_keep_alive_message(
config: Arc<ConfigFile>,
runner_data: Arc<RunnerData>,
) -> ResultAlive {
let client = reqwest::Client::new();
let to_send = json!({
"id": runner_data.id,
});
let resp = client
.post(format!("{}/tasks/runner/beat", config.hostname))
.header("token", &config.token)
.body(to_send.to_string())
.send()
.await;
if resp.is_err() {
return ResultAlive::Error;
}
let resp = resp.ok();
if resp.is_none() {
return ResultAlive::Error;
}
let resp = resp.unwrap();
// TODO see if the message is related to not being inited
if resp.status() != 200 {
println!("Could not connect with the status");
return ResultAlive::Error;
}
ResultAlive::Ok
}
async fn keep_alive(config: Arc<ConfigFile>, runner_data: Arc<RunnerData>) -> Result<()> {
let mut failed = 0;
loop {
match send_keep_alive_message(config.clone(), runner_data.clone()).await {
ResultAlive::Error => failed += 1,
ResultAlive::NotInit => {
println!("Runner not inited! Restarting!");
exit(1)
}
ResultAlive::Ok => failed = 0,
}
// TODO move to config
if failed > 20 {
println!("Failed to connect to API! More than 20 times in a row stoping");
exit(1)
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
async fn handle_task(
task: Task,
config: Arc<ConfigFile>,
runner_data: Arc<RunnerData>,
) -> Result<()> {
let res = match task.task_type {
TaskType::Training => handle_train(&task, config.clone(), runner_data.clone()).await,
_ => {
println!("Do not know how to handle this task #{:?}", task);
bail!("Failed")
}
};
if res.is_err() {
println!("task failed #{:?}", res);
fail_task(
&task,
config,
runner_data,
"Do not know how to handle this kind of task",
)
.await?
}
Ok(())
}
#[tokio::main]
async fn main() -> Result<()> {
// Load config file
let config_data = fs::read_to_string("./config.toml")?;
let mut config: ConfigFile = toml::from_str(&config_data)?;
let client = reqwest::Client::new();
if config.config_path == None {
config.config_path = Some(String::from("./data.toml"))
}
let runner_data: RunnerData = load_runner_data(&config).await?;
let to_send = json!({
"id": runner_data.id,
});
// Inform the server that the runner is available
let resp = client
.post(format!("{}/tasks/runner/init", config.hostname))
.header("token", &config.token)
.body(to_send.to_string())
.send()
.await?;
if resp.status() != 200 {
println!(
"Could not connect with the api: status {} body {}",
resp.status(),
resp.text().await?
);
return Ok(());
}
let res = resp.json::<String>().await?;
if res != "Ok" {
print!("Unexpected problem: {}", res);
return Ok(());
}
let config = Arc::new(config);
let runner_data = Arc::new(runner_data);
let config_alive = config.clone();
let runner_data_alive = runner_data.clone();
std::thread::spawn(move || keep_alive(config_alive.clone(), runner_data_alive.clone()));
println!("Started main loop");
loop {
//TODO move time to config
tokio::time::sleep(Duration::from_secs(1)).await;
let to_send = json!({ "id": runner_data.id });
let resp = client
.post(format!("{}/tasks/runner/active", config.hostname))
.header("token", &config.token)
.body(to_send.to_string())
.send()
.await;
if resp.is_err() || resp.as_ref().ok().is_none() {
println!("Failed to get info from server {:?}", resp);
continue;
}
let resp = resp?;
match resp.status() {
// No active task
StatusCode::NOT_FOUND => (),
StatusCode::OK => {
println!("Found task!");
let task: Result<Task, reqwest::Error> = resp.json().await;
if task.is_err() || task.as_ref().ok().is_none() {
println!("Failed to resolve the json {:?}", task);
continue;
}
let task = task?;
let res = handle_task(task, config.clone(), runner_data.clone()).await;
if res.is_err() || res.as_ref().ok().is_none() {
println!("Failed to run the task");
}
_ = res;
()
}
_ => {
println!("Unexpected error #{:?}", resp);
exit(1)
}
}
}
}

View File

@ -1,117 +0,0 @@
use anyhow::bail;
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use tch::{
nn::{self, Module},
Device,
};
#[derive(Debug)]
pub struct Model {
pub vs: nn::VarStore,
pub seq: nn::Sequential,
pub layers: Vec<Layer>,
}
#[derive(Debug, Clone, Copy, Serialize_repr, Deserialize_repr)]
#[repr(i8)]
pub enum LayerType {
Input = 1,
Dense = 2,
Flatten = 3,
SimpleBlock = 4,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Layer {
pub id: String,
pub definition_id: String,
pub layer_order: String,
pub layer_type: LayerType,
pub shape: String,
pub exp_type: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct DataPoint {
pub class: i64,
pub path: String,
}
pub fn build_model(layers: Vec<Layer>, last_linear_size: i64, add_sigmoid: bool) -> Model {
let vs = nn::VarStore::new(Device::Cuda(0));
let mut seq = nn::seq();
let mut last_linear_size = last_linear_size;
let mut last_linear_conv: Vec<i64> = Vec::new();
for layer in layers.iter() {
match layer.layer_type {
LayerType::Input => {
last_linear_conv = serde_json::from_str(&layer.shape).ok().unwrap();
println!("Layer: Input, In: {:?}", last_linear_conv);
}
LayerType::Dense => {
let shape: Vec<i64> = serde_json::from_str(&layer.shape).ok().unwrap();
println!("Layer: Dense, In: {}, Out: {}", last_linear_size, shape[0]);
seq = seq
.add(nn::linear(
&vs.root(),
last_linear_size,
shape[0],
Default::default(),
))
.add_fn(|xs| xs.relu());
last_linear_size = shape[0];
}
LayerType::Flatten => {
seq = seq.add_fn(|xs| xs.flatten(1, -1));
last_linear_size = 1;
for i in &last_linear_conv {
last_linear_size *= i;
}
println!(
"Layer: flatten, In: {:?}, Out: {}",
last_linear_conv, last_linear_size
)
}
LayerType::SimpleBlock => {
let new_last_linear_conv =
vec![128, last_linear_conv[1] / 2, last_linear_conv[2] / 2];
println!(
"Layer: block, In: {:?}, Put: {:?}",
last_linear_conv, new_last_linear_conv,
);
let out_size = vec![new_last_linear_conv[1], new_last_linear_conv[2]];
seq = seq
.add(nn::conv2d(
&vs.root(),
last_linear_conv[0],
128,
3,
nn::ConvConfig::default(),
))
.add_fn(|xs| xs.relu())
.add(nn::conv2d(
&vs.root(),
128,
128,
3,
nn::ConvConfig::default(),
))
.add_fn(|xs| xs.relu())
.add_fn(move |xs| xs.adaptive_avg_pool2d([out_size[1], out_size[1]]))
.add_fn(|xs| xs.leaky_relu());
//m_layers = append(m_layers, NewSimpleBlock(vs, lastLinearConv[0]))
last_linear_conv = new_last_linear_conv;
}
}
}
if add_sigmoid {
seq = seq.add_fn(|xs| xs.sigmoid());
}
return Model { vs, layers, seq };
}

View File

@ -1,57 +0,0 @@
use anyhow::{bail, Result};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::{fs, path};
#[derive(Deserialize)]
pub struct ConfigFile {
// Hostname to connect with the api
pub hostname: String,
// Token used in the api to authenticate
pub token: String,
// Path to where to store some generated configuration values
// defaults to ./data.toml
pub config_path: Option<String>,
// Data Path
// Path to where the data is mounted
pub data_path: String,
}
#[derive(Deserialize, Serialize)]
pub struct RunnerData {
pub id: String,
}
pub async fn load_runner_data(config: &ConfigFile) -> Result<RunnerData> {
let data_path = config.config_path.as_ref().unwrap();
let data_path = path::Path::new(&*data_path);
if data_path.exists() {
let runner_data = fs::read_to_string(data_path)?;
Ok(toml::from_str(&runner_data)?)
} else {
let client = reqwest::Client::new();
let to_send = json!({
"token": config.token,
"type": 1,
});
let register_resp = client
.post(format!("{}/tasks/runner/register", config.hostname))
.header("token", &config.token)
.body(to_send.to_string())
.send()
.await?;
if register_resp.status() != 200 {
bail!(format!("Could not create runner {:#?}", register_resp));
}
let runner_data: RunnerData = register_resp.json().await?;
fs::write(data_path, toml::to_string(&runner_data)?)
.expect("Faield to save data for runner");
Ok(runner_data)
}
}

View File

@ -1,90 +0,0 @@
use std::sync::Arc;
use anyhow::{bail, Result};
use serde::Deserialize;
use serde_json::json;
use serde_repr::Deserialize_repr;
use crate::{ConfigFile, RunnerData};
#[derive(Clone, Copy, Deserialize_repr, Debug)]
#[repr(i8)]
pub enum TaskStatus {
FailedRunning = -2,
FailedCreation = -1,
Preparing = 0,
Todo = 1,
PickedUp = 2,
Running = 3,
Done = 4,
}
#[derive(Clone, Copy, Deserialize_repr, Debug)]
#[repr(i8)]
pub enum TaskType {
Classification = 1,
Training = 2,
Retraining = 3,
DeleteUser = 4,
}
#[derive(Deserialize, Debug)]
pub struct Task {
pub id: String,
pub user_id: String,
pub model_id: String,
pub status: TaskStatus,
pub status_message: String,
pub user_confirmed: i8,
pub compacted: i8,
#[serde(alias = "type")]
pub task_type: TaskType,
pub extra_task_info: String,
pub result: String,
pub created: String,
}
pub async fn fail_task(
task: &Task,
config: Arc<ConfigFile>,
runner_data: Arc<RunnerData>,
reason: &str,
) -> Result<()> {
println!("Marking Task as failed");
let client = reqwest::Client::new();
let to_send = json!({
"id": runner_data.id,
"taskId": task.id,
"reason": reason
});
let resp = client
.post(format!("{}/tasks/runner/fail", config.hostname))
.header("token", &config.token)
.body(to_send.to_string())
.send()
.await?;
if resp.status() != 200 {
println!("Failed to update status of task");
bail!("Failed to update status of task");
}
Ok(())
}
impl Task {
pub async fn fail(
self: &mut Task,
config: Arc<ConfigFile>,
runner_data: Arc<RunnerData>,
reason: &str,
) -> Result<()> {
fail_task(self, config, runner_data, reason).await?;
self.status = TaskStatus::FailedRunning;
self.status_message = reason.to_string();
Ok(())
}
}

View File

@ -1,599 +0,0 @@
use crate::{
dataloader::DataLoader,
model::{self, build_model},
settings::{ConfigFile, RunnerData},
tasks::{fail_task, Task},
types::{DataPointRequest, Definition, ModelClass},
};
use std::{
io::{self, Write},
sync::Arc,
};
use anyhow::Result;
use rand::{seq::SliceRandom, thread_rng};
use serde_json::json;
use tch::{
nn::{self, Module, OptimizerConfig},
Cuda, Tensor,
};
pub async fn handle_train(
task: &Task,
config: Arc<ConfigFile>,
runner_data: Arc<RunnerData>,
) -> Result<()> {
let client = reqwest::Client::new();
println!("Preparing to train a model");
let to_send = json!({
"id": runner_data.id,
"taskId": task.id,
});
let mut defs: Vec<Definition> = client
.post(format!("{}/tasks/runner/train/defs", config.hostname))
.header("token", &config.token)
.body(to_send.to_string())
.send()
.await?
.json()
.await?;
if defs.len() == 0 {
println!("No defs found");
fail_task(task, config, runner_data, "No definitions found").await?;
return Ok(());
}
let classes: Vec<ModelClass> = client
.post(format!("{}/tasks/runner/train/classes", config.hostname))
.header("token", &config.token)
.body(to_send.to_string())
.send()
.await?
.json()
.await?;
let data: DataPointRequest = client
.post(format!("{}/tasks/runner/train/datapoints", config.hostname))
.header("token", &config.token)
.body(to_send.to_string())
.send()
.await?
.json()
.await?;
let mut testing = data.testing;
testing.shuffle(&mut thread_rng());
let mut data_loader = DataLoader::new(config.clone(), testing, classes.len() as i64, 64);
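// The trailing 64 is presumably the batch size: each epoch below drains this DataLoader of (inputs, labels) batches.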
// TODO make this a vec
let mut model: Option<model::Model> = None;
loop {
let config = config.clone();
let runner_data = runner_data.clone();
let mut to_remove: Vec<usize> = Vec::new();
let mut def_iter = defs.iter_mut();
let mut i: usize = 0;
while let Some(def) = def_iter.next() {
def.updateStatus(
task,
config.clone(),
runner_data.clone(),
crate::types::DefinitionStatus::Training,
)
.await?;
let model_err = train_definition(
def,
&mut data_loader,
model,
config.clone(),
runner_data.clone(),
&task,
)
.await;
if model_err.is_err() {
println!("Failed to create model {:?}", model_err);
model = None;
to_remove.push(i);
i += 1;
continue;
}
model = model_err?;
i += 1;
}
defs = defs
.into_iter()
.enumerate()
.filter(|&(i, _)| to_remove.iter().any(|b| *b == i))
.map(|(_, e)| e)
.collect();
break;
}
fail_task(task, config, runner_data, "TODO").await?;
Ok(())
/*
for {
// Keep track of definitions that did not train fast enough
var toRemove ToRemoveList = []int{}
for i, def := range definitions {
accuracy, ml_model, err := trainDefinition(c, model, def, models[def.Id], classes)
if err != nil {
log.Error("Failed to train definition!Err:", "err", err)
def.UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING)
toRemove = append(toRemove, i)
continue
}
models[def.Id] = ml_model
if accuracy >= float64(def.TargetAccuracy) {
log.Info("Found a definition that reaches target_accuracy!")
_, err = db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, DEFINITION_STATUS_TRANIED, def.Epoch, def.Id)
if err != nil {
log.Error("Failed to train definition!Err:\n", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return err
}
_, err = db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", DEFINITION_STATUS_CANCELD_TRAINING, def.Id, model.Id, DEFINITION_STATUS_FAILED_TRAINING)
if err != nil {
log.Error("Failed to train definition!Err:\n", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return err
}
finished = true
break
}
if def.Epoch > MAX_EPOCH {
fmt.Printf("Failed to train definition! Accuracy less %f < %d\n", accuracy, def.TargetAccuracy)
def.UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING)
toRemove = append(toRemove, i)
continue
}
_, err = db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.Epoch, DEFINITION_STATUS_PAUSED_TRAINING, def.Id)
if err != nil {
log.Error("Failed to train definition!Err:\n", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return err
}
}
if finished {
break
}
sort.Sort(sort.Reverse(toRemove))
log.Info("Round done", "toRemove", toRemove)
for _, n := range toRemove {
// Clean up unused models
models[definitions[n].Id] = nil
definitions = remove(definitions, n)
}
len_def := len(definitions)
if len_def == 0 {
break
}
if len_def == 1 {
continue
}
sort.Sort(sort.Reverse(definitions))
acc := definitions[0].Accuracy - 20.0
log.Info("Training models, Highest acc", "acc", definitions[0].Accuracy, "mod_acc", acc)
toRemove = []int{}
for i, def := range definitions {
if def.Accuracy < acc {
toRemove = append(toRemove, i)
}
}
log.Info("Removing due to accuracy", "toRemove", toRemove)
sort.Sort(sort.Reverse(toRemove))
for _, n := range toRemove {
log.Warn("Removing definition not fast enough learning", "n", n)
definitions[n].UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING)
models[definitions[n].Id] = nil
definitions = remove(definitions, n)
}
}
var def Definition
err = GetDBOnce(c, &def, "model_definition as md where md.model_id=$1 and md.status=$2 order by md.accuracy desc limit 1;", model.Id, DEFINITION_STATUS_TRANIED)
if err != nil {
if err == NotFoundError {
log.Error("All definitions failed to train!")
} else {
log.Error("DB: failed to read definition", "err", err)
}
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
if err = def.UpdateStatus(c, DEFINITION_STATUS_READY); err != nil {
log.Error("Failed to update model definition", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
to_delete, err := db.Query("select id from model_definition where status != $1 and model_id=$2", DEFINITION_STATUS_READY, model.Id)
if err != nil {
log.Error("Failed to select model_definition to delete")
log.Error(err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
defer to_delete.Close()
for to_delete.Next() {
var id string
if err = to_delete.Scan(&id); err != nil {
log.Error("Failed to scan the id of a model_definition to delete", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
os.RemoveAll(path.Join("savedData", model.Id, "defs", id))
}
// TODO Check if returning also works here
if _, err = db.Exec("delete from model_definition where status!=$1 and model_id=$2;", DEFINITION_STATUS_READY, model.Id); err != nil {
log.Error("Failed to delete model_definition")
log.Error(err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
ModelUpdateStatus(c, model.Id, READY)
return
*/
}
async fn train_definition(
def: &Definition,
data_loader: &mut DataLoader,
model: Option<model::Model>,
config: Arc<ConfigFile>,
runner_data: Arc<RunnerData>,
task: &Task,
) -> Result<Option<model::Model>> {
let client = reqwest::Client::new();
println!("About to start training definition");
let mut accuracy = 0;
let model = model.unwrap_or({
let layers: Vec<model::Layer> = client
.post(format!("{}/tasks/runner/train/def/layers", config.hostname))
.header("token", &config.token)
.body(
json!({
"id": runner_data.id,
"taskId": task.id,
"defId": def.id,
})
.to_string(),
)
.send()
.await?
.json()
.await?;
build_model(layers, 0, true)
});
// TODO CUDA
// get device
// Move model to cuda
let mut opt = nn::Adam::default().build(&model.vs, 1e-3)?;
let mut last_acc = 0.0;
for epoch in 1..40 {
data_loader.restart();
let mut mean_loss: f64 = 0.0;
let mut mean_acc: f64 = 0.0;
while let Some((inputs, labels)) = data_loader.next() {
let inputs = inputs
.to_kind(tch::Kind::Float)
.to_device(tch::Device::Cuda(0));
let labels = labels
.to_kind(tch::Kind::Float)
.to_device(tch::Device::Cuda(0));
let out = model.seq.forward(&inputs);
let weight: Option<Tensor> = None;
let loss = out.binary_cross_entropy(&labels, weight, tch::Reduction::Mean);
opt.backward_step(&loss);
mean_loss += loss
.to_device(tch::Device::Cpu)
.unsqueeze(0)
.double_value(&[0]);
let out = out.to_device(tch::Device::Cpu);
let test = out.empty_like();
_ = out.clone(&test);
let out = test.argmax(1, true);
let mut labels = labels.to_device(tch::Device::Cpu);
labels = labels.unsqueeze(-1);
let size = out.size()[0];
let mut acc = 0;
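// Labels are one-hot, so a prediction counts as correct when the label at the predicted class index is 1.0.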
for i in 0..size {
let res = out.double_value(&[i]);
let exp = labels.double_value(&[i, res as i64]);
if exp == 1.0 {
acc += 1;
}
}
mean_acc += acc as f64 / size as f64;
last_acc = acc as f64 / size as f64;
}
print!(
"\repoch: {} loss: {} acc: {} l acc: {} ",
epoch,
mean_loss / data_loader.len as f64,
mean_acc / data_loader.len as f64,
last_acc
);
io::stdout().flush().expect("Unable to flush stdout");
}
println!("\nlast acc: {}", last_acc);
return Ok(Some(model));
/*
opt, err := my_nn.DefaultAdamConfig().Build(model.Vs, 0.001)
if err != nil {
return
}
for epoch := 0; epoch < EPOCH_PER_RUN; epoch++ {
var trainIter *torch.Iter2
trainIter, err = ds.TrainIter(32)
if err != nil {
return
}
// trainIter.ToDevice(device)
log.Info("epoch", "epoch", epoch)
var trainLoss float64 = 0
var trainCorrect float64 = 0
ok := true
for ok {
var item torch.Iter2Item
var loss *torch.Tensor
item, ok = trainIter.Next()
if !ok {
continue
}
data := item.Data
data, err = data.ToDevice(device, gotch.Float, false, true, false)
if err != nil {
return
}
var size []int64
size, err = data.Size()
if err != nil {
return
}
var zeros *torch.Tensor
zeros, err = torch.Zeros(size, gotch.Float, device)
if err != nil {
return
}
data, err = zeros.Add(data, true)
if err != nil {
return
}
log.Info("\n\nhere 1, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
data, err = data.SetRequiresGrad(true, false)
if err != nil {
return
}
log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
err = data.RetainGrad(false)
if err != nil {
return
}
log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
pred := model.ForwardT(data, true)
pred, err = pred.SetRequiresGrad(true, true)
if err != nil {
return
}
err = pred.RetainGrad(false)
if err != nil {
return
}
label := item.Label
label, err = label.ToDevice(device, gotch.Float, false, true, false)
if err != nil {
return
}
label, err = label.SetRequiresGrad(true, true)
if err != nil {
return
}
err = label.RetainGrad(false)
if err != nil {
return
}
// Calculate loss
loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 2, false)
if err != nil {
return
}
loss, err = loss.SetRequiresGrad(true, false)
if err != nil {
return
}
err = loss.RetainGrad(false)
if err != nil {
return
}
err = opt.ZeroGrad()
if err != nil {
return
}
err = loss.Backward()
if err != nil {
return
}
log.Info("pred grad", "pred", pred.MustGrad(false).MustMax(false).Float64Values())
log.Info("pred grad", "outs", label.MustGrad(false).MustMax(false).Float64Values())
log.Info("pred grad", "data", data.MustGrad(false).MustMax(false).Float64Values(), "lol", data.MustRetainsGrad(false))
vars := model.Vs.Variables()
for k, v := range vars {
log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(false).Float64Values(), "lol", v.MustRetainsGrad(false))
}
model.Debug()
err = opt.Step()
if err != nil {
return
}
trainLoss = loss.Float64Values()[0]
// Calculate accuracy
/ *var p_pred, p_labels *torch.Tensor
p_pred, err = pred.Argmax([]int64{1}, true, false)
if err != nil {
return
}
p_labels, err = item.Label.Argmax([]int64{1}, true, false)
if err != nil {
return
}
floats := p_pred.Float64Values()
floats_labels := p_labels.Float64Values()
for i := range floats {
if floats[i] == floats_labels[i] {
trainCorrect += 1
}
} * /
panic("fornow")
}
//v := []float64{}
log.Info("model training epoch done loss", "loss", trainLoss, "correct", trainCorrect, "out", ds.TrainImagesSize, "accuracy", trainCorrect/float64(ds.TrainImagesSize))
/ *correct := int64(0)
//torch.NoGrad(func() {
ok = true
testIter := ds.TestIter(64)
for ok {
var item torch.Iter2Item
item, ok = testIter.Next()
if !ok {
continue
}
output := model.Forward(item.Data)
var pred, labels *torch.Tensor
pred, err = output.Argmax([]int64{1}, true, false)
if err != nil {
return
}
labels, err = item.Label.Argmax([]int64{1}, true, false)
if err != nil {
return
}
floats := pred.Float64Values()
floats_labels := labels.Float64Values()
for i := range floats {
if floats[i] == floats_labels[i] {
correct += 1
}
}
}
accuracy = float64(correct) / float64(ds.TestImagesSize)
log.Info("Eval accuracy", "accuracy", accuracy)
err = def.UpdateAfterEpoch(db, accuracy*100)
if err != nil {
return
}* /
//})
}
result_path := path.Join(getDir(), "savedData", m.Id, "defs", def.Id)
err = os.MkdirAll(result_path, os.ModePerm)
if err != nil {
return
}
err = my_torch.SaveModel(model, path.Join(result_path, "model.dat"))
if err != nil {
return
}
log.Info("Model finished training!", "accuracy", accuracy)
return
*/
}

View File

@ -1,89 +0,0 @@
use crate::{model, tasks::Task, ConfigFile, RunnerData};
use anyhow::{bail, Result};
use serde::Deserialize;
use serde_json::json;
use serde_repr::{Deserialize_repr, Serialize_repr};
use std::sync::Arc;
#[derive(Clone, Copy, Deserialize_repr, Serialize_repr, Debug)]
#[repr(i8)]
pub enum DefinitionStatus {
CanceldTraining = -4,
FailedTraining = -3,
PreInit = 1,
Init = 2,
Training = 3,
PausedTraining = 6,
Tranied = 4,
Ready = 5,
}
#[derive(Deserialize, Debug)]
pub struct Definition {
pub id: String,
pub model_id: String,
pub accuracy: f64,
pub target_accuracy: i64,
pub epoch: i64,
pub status: i64,
pub created: String,
pub epoch_progress: i64,
}
impl Definition {
pub async fn updateStatus(
self: &mut Definition,
task: &Task,
config: Arc<ConfigFile>,
runner_data: Arc<RunnerData>,
status: DefinitionStatus,
) -> Result<()> {
println!("Marking Task as faield");
let client = reqwest::Client::new();
let to_send = json!({
"id": runner_data.id,
"taskId": task.id,
"defId": self.id,
"status": status,
});
let resp = client
.post(format!("{}/tasks/runner/train/def/status", config.hostname))
.header("token", &config.token)
.body(to_send.to_string())
.send()
.await?;
if resp.status() != 200 {
println!("Failed to update status of task");
bail!("Failed to update status of task");
}
Ok(())
}
}
#[derive(Clone, Copy, Deserialize_repr, Debug)]
#[repr(i8)]
pub enum ModelClassStatus {
ToTrain = 1,
Training = 2,
Trained = 3,
}
#[derive(Deserialize, Debug)]
pub struct ModelClass {
pub id: String,
pub model_id: String,
pub name: String,
pub class_order: i64,
pub status: ModelClassStatus,
}
#[derive(Deserialize, Debug)]
pub struct DataPointRequest {
pub testing: Vec<model::DataPoint>,
pub training: Vec<model::DataPoint>,
}

View File

@ -1 +1 @@
CREATE DATABASE aistuff; CREATE DATABASE fyp;

View File

@ -35,7 +35,7 @@ create table if not exists model_classes (
-- 1: to_train -- 1: to_train
-- 2: training -- 2: training
-- 3: trained -- 3: trained
status integer default 1, status integer default 1
); );
-- drop table if exists model_data_point; -- drop table if exists model_data_point;
@ -59,7 +59,6 @@ create table if not exists model_definition (
accuracy real default 0, accuracy real default 0,
target_accuracy integer not null, target_accuracy integer not null,
epoch integer default 0, epoch integer default 0,
-- TODO add max epoch
-- 1: Pre Init -- 1: Pre Init
-- 2: Init -- 2: Init
-- 3: Training -- 3: Training
@ -78,7 +77,7 @@ create table if not exists model_definition_layer (
-- 1: input -- 1: input
-- 2: dense -- 2: dense
-- 3: flatten -- 3: flatten
-- TODO add conv -- 4: block
layer_type integer not null, layer_type integer not null,
-- ei 28,28,1 -- ei 28,28,1
-- a 28x28 grayscale image -- a 28x28 grayscale image
@ -102,7 +101,6 @@ create table if not exists exp_model_head (
accuracy real default 0, accuracy real default 0,
-- TODO add max epoch
-- 1: Pre Init -- 1: Pre Init
-- 2: Init -- 2: Init
-- 3: Training -- 3: Training

View File

@ -143,6 +143,15 @@ def addBlock(
model.add(layers.Dropout(0.4)) model.add(layers.Dropout(0.4))
return model return model
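# Residual block (functional API): two same-padded convolutions over x, an identity skip connection via Add, then ReLU and BatchNormalization; assumes x already has `filters` channels so the Add shapes line up.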
def resblock(x, kernelsize = 3, filters = 128):
fx = layers.Conv2D(filters, kernelsize, activation='relu', padding='same')(x)
fx = layers.BatchNormalization()(fx)
fx = layers.Conv2D(filters, kernelsize, padding='same')(fx)
out = layers.Add()([x,fx])
out = layers.ReLU()(out)
out = layers.BatchNormalization()(out)
return out
{{ if .LoadPrev }} {{ if .LoadPrev }}
model = tf.keras.saving.load_model('{{.LastModelRunPath}}') model = tf.keras.saving.load_model('{{.LastModelRunPath}}')

1
webpage/.dockerignore Symbolic link
View File

@ -0,0 +1 @@
.gitignore

View File

@ -27,5 +27,11 @@ module.exports = {
parser: '@typescript-eslint/parser' parser: '@typescript-eslint/parser'
} }
} }
] ],
rules: {
'svelte/no-at-html-tags': 'off',
// TODO remove this
'@typescript-eslint/no-explicit-any': 'off'
}
}; };

9
webpage/Dockerfile Normal file
View File

@ -0,0 +1,9 @@
FROM docker.io/node:22
ADD . .
RUN npm install
RUN npm run build
CMD ["npm", "run", "preview"]

Binary file not shown.

4125
webpage/package-lock.json generated Normal file

File diff suppressed because it is too large

View File

@ -1,38 +1,41 @@
{ {
"name": "webpage", "name": "webpage",
"version": "0.0.1", "version": "0.0.1",
"private": true, "private": true,
"scripts": { "scripts": {
"dev:raw": "vite dev", "dev:raw": "vite dev",
"dev": "vite dev --port 5001 --host", "dev": "vite dev --port 5001 --host",
"build": "vite build", "build": "vite build",
"preview": "vite preview", "preview": "vite preview --port 5001 --host",
"check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json", "check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json",
"check:watch": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json --watch", "check:watch": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json --watch",
"lint": "prettier --check . && eslint .", "lint": "prettier --check . && eslint .",
"format": "prettier --write ." "format": "prettier --write ."
}, },
"devDependencies": { "devDependencies": {
"@sveltejs/adapter-auto": "^3.2.0", "@sveltejs/adapter-auto": "^3.2.0",
"@sveltejs/kit": "^2.5.6", "@sveltejs/kit": "^2.5.6",
"@sveltejs/vite-plugin-svelte": "3.0.0", "@sveltejs/vite-plugin-svelte": "^3.0.0",
"@types/eslint": "^8.56.9", "@types/d3": "^7.4.3",
"@typescript-eslint/eslint-plugin": "^7.7.0", "@types/eslint": "^8.56.9",
"@typescript-eslint/parser": "^7.7.0", "@typescript-eslint/eslint-plugin": "^7.7.0",
"eslint": "^8.57.0", "@typescript-eslint/parser": "^7.7.0",
"eslint-config-prettier": "^9.1.0", "eslint": "^8.57.0",
"eslint-plugin-svelte": "^2.37.0", "eslint-config-prettier": "^9.1.0",
"prettier": "^3.2.5", "eslint-plugin-svelte": "^2.37.0",
"prettier-plugin-svelte": "^3.2.3", "prettier": "^3.2.5",
"sass": "^1.75.0", "prettier-plugin-svelte": "^3.2.3",
"svelte": "5.0.0-next.104", "sass": "^1.75.0",
"svelte-check": "^3.6.9", "svelte": "^5.0.0-next.104",
"tslib": "^2.6.2", "svelte-check": "^3.6.9",
"typescript": "^5.4.5", "tslib": "^2.6.2",
"vite": "^5.2.8" "typescript": "^5.4.5",
}, "vite": "^5.2.8"
"type": "module", },
"dependencies": { "type": "module",
"chart.js": "^4.4.2" "dependencies": {
} "chart.js": "^4.4.2",
"d3": "^7.9.0",
"highlight.js": "^11.9.0"
}
} }

View File

@ -15,8 +15,18 @@
{/if} {/if}
<li class="expand"></li> <li class="expand"></li>
{#if userStore.user} {#if userStore.user}
{#if userStore.user.user_type == 2}
<li>
<a href="/admin/runners">
<span class="bi bi-cpu-fill"></span>
Runner
</a>
</li>
{/if}
<li> <li>
<a href="/user/info"> <span class="bi bi-person-fill"></span> {userStore.user.username} </a> <a href="/user/info"> <span class="bi bi-person-fill"></span> {userStore.user.username} </a>
</li>
<li>
<a href="/logout"> <span class="bi bi-box-arrow-right"></span> Logout </a> <a href="/logout"> <span class="bi bi-box-arrow-right"></span> Logout </a>
</li> </li>
{:else} {:else}

View File

@ -9,6 +9,8 @@
<link rel="preconnect" href="https://fonts.googleapis.com" /> <link rel="preconnect" href="https://fonts.googleapis.com" />
<link rel="preconnect" href="https://fonts.googleapis.com" /> <link rel="preconnect" href="https://fonts.googleapis.com" />
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin /> <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin />
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/github.min.css">
<link href="https://fonts.googleapis.com/css2?family=Andada+Pro:ital,wght@0,400..840;1,400..840&family=Bebas+Neue&family=Fira+Code:wght@300..700&display=swap" rel="stylesheet">
<link <link
href="https://fonts.googleapis.com/css2?family=Andada+Pro:ital,wght@0,400..840;1,400..840&family=Bebas+Neue&display=swap" href="https://fonts.googleapis.com/css2?family=Andada+Pro:ital,wght@0,400..840;1,400..840&family=Bebas+Neue&display=swap"
rel="stylesheet" rel="stylesheet"

View File

@ -36,7 +36,7 @@
class="icon" class="icon"
class:adapt={replace_slot && file && !notExpand} class:adapt={replace_slot && file && !notExpand}
type="button" type="button"
on:click={() => fileInput.click()} onclick={() => fileInput.click()}
> >
{#if replace_slot && file} {#if replace_slot && file}
<slot name="replaced" {file}> <slot name="replaced" {file}>
@ -54,6 +54,6 @@
required required
{accept} {accept}
bind:this={fileInput} bind:this={fileInput}
on:change={onChange} onchange={onChange}
/> />
</div> </div>

View File

@ -1,57 +0,0 @@
<script context="module" lang="ts">
export type DisplayFn = (
msg: string,
options?: {
type?: 'error' | 'success';
timeToShow?: number;
}
) => void;
</script>
<script lang="ts">
let message = $state<string | undefined>(undefined);
let type = $state<'error' | 'success'>('error');
let timeout: number | undefined = undefined;
export function clear() {
if (timeout) clearTimeout(timeout);
message = undefined;
}
export function display(
msg: string,
options?: {
type?: 'error' | 'success';
timeToShow?: number;
}
) {
if (timeout) clearTimeout(timeout);
if (!msg) {
message = undefined;
return;
}
let { type: l_type, timeToShow } = options ?? { type: 'error', timeToShow: undefined };
if (l_type) {
type = l_type;
}
message = msg;
if (timeToShow) {
timeout = setTimeout(() => {
message = undefined;
timeout = undefined;
}, timeToShow);
}
}
</script>
{#if message}
<div class="form-msg {type}">
{message}
</div>
{/if}

View File

@ -1,5 +1,5 @@
<script lang="ts"> <script lang="ts">
let { title } = $props<{ title: string }>(); let { title }: { title: string } = $props();
let isHovered = $state(false); let isHovered = $state(false);
let x = $state(0); let x = $state(0);
@ -30,10 +30,10 @@
<div <div
bind:this={div} bind:this={div}
on:mouseover={mouseOver} onmouseover={mouseOver}
on:mouseleave={mouseLeave} onmouseleave={mouseLeave}
on:mousemove={mouseMove} onmousemove={mouseMove}
on:focus={focus} onfocus={focus}
role="tooltip" role="tooltip"
class="tooltipContainer" class="tooltipContainer"
> >

6
webpage/src/lib/utils.ts Normal file
View File

@ -0,0 +1,6 @@
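// Wraps an event handler so the browser's default action (e.g. native form submission) is cancelled before the handler runs; intended for Svelte 5 onsubmit/onclick attributes.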
export function preventDefault(fn: any) {
return function (event: Event) {
event.preventDefault();
fn.call(this, event);
};
}

View File

@ -1,7 +1,7 @@
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { rdelete } from '$lib/requests.svelte'; import { rdelete } from '$lib/requests.svelte';
type User = { export type User = {
token: string; token: string;
id: string; id: string;
user_type: number; user_type: number;

View File

@ -0,0 +1,350 @@
<script lang="ts">
import { goto } from '$app/navigation';
import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import { post, showMessage } from 'src/lib/requests.svelte';
import { userStore } from 'src/routes/UserStore.svelte';
import { onMount } from 'svelte';
import * as d3 from 'd3';
import type { Base } from './types';
import CardInfo from './CardInfo.svelte';
let width = $state(0);
let height = $state(0);
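// d3-drag handlers: while dragging, pin the node to the pointer via fx/fy (and mark it selected); on release, clear fx/fy so the force simulation resumes control of the node.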
function drag(simulation: d3.Simulation<d3.HierarchyNode<Base>, undefined>) {
function dragstarted(event: any, d: any) {
if (!event.active) simulation.alphaTarget(0.3).restart();
d.fx = d.x;
d.fy = d.y;
selected = d.data;
}
function dragged(event: any, d: any) {
d.fx = event.x;
d.fy = event.y;
}
function dragended(event: any, d: any) {
if (!event.active) simulation.alphaTarget(0);
d.fx = null;
d.fy = null;
}
return d3.drag().on('start', dragstarted).on('drag', dragged).on('end', dragended);
}
let graph: HTMLDivElement;
let selected: Base | undefined = $state();
async function getData() {
const dataObj: Base = {
name: 'API',
type: 'api',
children: []
};
if (!dataObj.children) throw new Error();
const localRunners: Base[] = [];
const remotePairs: Record<string, Base[]> = {};
try {
let data = await post('tasks/runner/info', {});
if (Object.keys(data.localRunners).length > 0) {
for (const objId of Object.keys(data.localRunners)) {
localRunners.push({ name: objId, type: 'local_runner', task: data.localRunners[objId] });
}
dataObj.children.push({
name: 'local runners',
type: 'runner_group',
children: localRunners
});
}
if (Object.keys(data.remoteRunners).length > 0) {
for (const objId of Object.keys(data.remoteRunners)) {
let obj = data.remoteRunners[objId];
if (remotePairs[obj.runner_info.user_id as string]) {
remotePairs[obj.runner_info.user_id as string].push({
name: objId,
type: 'runner',
task: obj.task,
parent: data.remoteRunners[objId].runner_info.user_id
});
} else {
remotePairs[data.remoteRunners[objId].runner_info.user_id] = [
{
name: objId,
type: 'runner',
task: obj.task,
parent: data.remoteRunners[objId].runner_info.user_id
}
];
}
}
}
dataObj.children.push({
name: 'remote runners',
type: 'runner_group',
task: undefined,
children: Object.keys(remotePairs).map(
(name) =>
({
name,
type: 'user_group',
task: undefined,
children: remotePairs[name]
}) as Base
)
});
} catch (ex) {
showMessage(ex, notificationStore, 'Failed to get Runner information');
return;
}
const root = d3.hierarchy(dataObj);
const links = root.links();
const nodes = root.descendants();
console.log(root, links, nodes);
const simulation = d3
.forceSimulation(nodes)
.force(
'link',
d3
.forceLink(links)
.id((d: any) => d.id)
.distance((d: any) => {
let data = d.source.data as Base;
switch (data.type) {
case 'api':
return 150;
case 'runner_group':
return 90;
case 'user_group':
return 80;
case 'runner':
case 'local_runner':
return 20;
default:
throw new Error();
}
})
.strength(1)
)
.force('charge', d3.forceManyBody().strength(-1000))
.force('x', d3.forceX())
.force('y', d3.forceY());
const svg = d3
.create('svg')
.attr('width', width)
.attr('height', height - 62)
.attr('viewBox', [-width / 2, -height / 2, width, height])
.attr('style', 'max-width: 100%; height: auto;');
// Append links.
const link = svg
.append('g')
.attr('stroke', '#999')
.attr('stroke-opacity', 0.6)
.selectAll('line')
.data(links)
.join('line');
const database_svg = `
<svg xmlns="http://www.w3.org/2000/svg" stroke-width="0.2" width="32" height="32" fill="currentColor" class="bi bi-database" viewBox="0 0 32 32">
<path transform="scale(2)" d="M4.318 2.687C5.234 2.271 6.536 2 8 2s2.766.27 3.682.687C12.644 3.125 13 3.627 13 4c0 .374-.356.875-1.318 1.313C10.766 5.729 9.464 6 8 6s-2.766-.27-3.682-.687C3.356 4.875 3 4.373 3 4c0-.374.356-.875 1.318-1.313M13 5.698V7c0 .374-.356.875-1.318 1.313C10.766 8.729 9.464 9 8 9s-2.766-.27-3.682-.687C3.356 7.875 3 7.373 3 7V5.698c.271.202.58.378.904.525C4.978 6.711 6.427 7 8 7s3.022-.289 4.096-.777A5 5 0 0 0 13 5.698M14 4c0-1.007-.875-1.755-1.904-2.223C11.022 1.289 9.573 1 8 1s-3.022.289-4.096.777C2.875 2.245 2 2.993 2 4v9c0 1.007.875 1.755 1.904 2.223C4.978 15.71 6.427 16 8 16s3.022-.289 4.096-.777C13.125 14.755 14 14.007 14 13zm-1 4.698V10c0 .374-.356.875-1.318 1.313C10.766 11.729 9.464 12 8 12s-2.766-.27-3.682-.687C3.356 10.875 3 10.373 3 10V8.698c.271.202.58.378.904.525C4.978 9.71 6.427 10 8 10s3.022-.289 4.096-.777A5 5 0 0 0 13 8.698m0 3V13c0 .374-.356.875-1.318 1.313C10.766 14.729 9.464 15 8 15s-2.766-.27-3.682-.687C3.356 13.875 3 13.373 3 13v-1.302c.271.202.58.378.904.525C4.978 12.71 6.427 13 8 13s3.022-.289 4.096-.777c.324-.147.633-.323.904-.525"/>
</svg>
`;
const cpu_svg = `
<svg stroke="white" fill="white" xmlns="http://www.w3.org/2000/svg" stroke-width="0.2" width="32" height="32" fill="currentColor" class="bi bi-cpu-fill" viewBox="0 0 32 32">
<path transform="scale(2)" d="M6.5 6a.5.5 0 0 0-.5.5v3a.5.5 0 0 0 .5.5h3a.5.5 0 0 0 .5-.5v-3a.5.5 0 0 0-.5-.5z"/>
<path transform="scale(2)" d="M5.5.5a.5.5 0 0 0-1 0V2A2.5 2.5 0 0 0 2 4.5H.5a.5.5 0 0 0 0 1H2v1H.5a.5.5 0 0 0 0 1H2v1H.5a.5.5 0 0 0 0 1H2v1H.5a.5.5 0 0 0 0 1H2A2.5 2.5 0 0 0 4.5 14v1.5a.5.5 0 0 0 1 0V14h1v1.5a.5.5 0 0 0 1 0V14h1v1.5a.5.5 0 0 0 1 0V14h1v1.5a.5.5 0 0 0 1 0V14a2.5 2.5 0 0 0 2.5-2.5h1.5a.5.5 0 0 0 0-1H14v-1h1.5a.5.5 0 0 0 0-1H14v-1h1.5a.5.5 0 0 0 0-1H14v-1h1.5a.5.5 0 0 0 0-1H14A2.5 2.5 0 0 0 11.5 2V.5a.5.5 0 0 0-1 0V2h-1V.5a.5.5 0 0 0-1 0V2h-1V.5a.5.5 0 0 0-1 0V2h-1zm1 4.5h3A1.5 1.5 0 0 1 11 6.5v3A1.5 1.5 0 0 1 9.5 11h-3A1.5 1.5 0 0 1 5 9.5v-3A1.5 1.5 0 0 1 6.5 5"/>
</svg>
`;
const user_svg = `
<svg fill="white" stroke="white" xmlns="http://www.w3.org/2000/svg" stroke-width="0.2" width="32" height="32" fill="currentColor" class="bi bi-person-fill" viewBox="0 0 32 32">
<path transform="scale(2)" d="M3 14s-1 0-1-1 1-4 6-4 6 3 6 4-1 1-1 1zm5-6a3 3 0 1 0 0-6 3 3 0 0 0 0 6"/>
</svg>
`;
const inbox_fill = `
<svg fill="white" stroke="white" xmlns="http://www.w3.org/2000/svg" stroke-width="0.2" width="32" height="32" fill="currentColor" class="bi bi-inbox-fill" viewBox="0 0 32 32">
<path transform="scale(2)" d="M4.98 4a.5.5 0 0 0-.39.188L1.54 8H6a.5.5 0 0 1 .5.5 1.5 1.5 0 1 0 3 0A.5.5 0 0 1 10 8h4.46l-3.05-3.812A.5.5 0 0 0 11.02 4zm-1.17-.437A1.5 1.5 0 0 1 4.98 3h6.04a1.5 1.5 0 0 1 1.17.563l3.7 4.625a.5.5 0 0 1 .106.374l-.39 3.124A1.5 1.5 0 0 1 14.117 13H1.883a1.5 1.5 0 0 1-1.489-1.314l-.39-3.124a.5.5 0 0 1 .106-.374z"/>
</svg>
`;
const node = svg
.append('g')
.attr('fill', '#fff')
.attr('stroke', '#000')
.attr('stroke-width', 1.5)
.selectAll('g')
.data(nodes)
.join('g')
.attr('style', 'cursor: pointer;')
.call(drag(simulation) as any)
.on('click', (e) => {
console.log('test');
function findData(obj: HTMLElement) {
if ((obj as any).__data__) {
return (obj as any).__data__;
}
if (!obj.parentElement) {
throw new Error();
}
return findData(obj.parentElement);
}
let obj = findData(e.srcElement);
console.log(obj);
selected = obj.data;
});
node
.append('circle')
.attr('fill', (d: any) => {
let data = d.data as Base;
switch (data.type) {
case 'api':
return '#caf0f8';
case 'runner_group':
return '#00b4d8';
case 'user_group':
return '#0000ff';
case 'runner':
case 'local_runner':
return '#03045e';
default:
throw new Error();
}
})
.attr('stroke', (d: any) => {
let data = d.data as Base;
switch (data.type) {
case 'api':
case 'user_group':
case 'runner_group':
return '#fff';
case 'runner':
case 'local_runner':
// TODO make this reliant on the status
return '#000';
default:
throw new Error();
}
})
.attr('r', (d: any) => {
let data = d.data as Base;
switch (data.type) {
case 'api':
return 30;
case 'runner_group':
return 20;
case 'user_group':
return 25;
case 'runner':
case 'local_runner':
return 30;
default:
throw new Error();
}
})
.append('title')
.text((d: any) => d.data.name);
node
.filter((d) => {
return ['api', 'local_runner', 'runner', 'user_group', 'runner_group'].includes(
d.data.type
);
})
.append('g')
.html((d) => {
switch (d.data.type) {
case 'api':
return database_svg;
case 'user_group':
return user_svg;
case 'runner_group':
return inbox_fill;
case 'local_runner':
case 'runner':
return cpu_svg;
default:
throw new Error();
}
});
simulation.on('tick', () => {
link
.attr('x1', (d: any) => d.source.x)
.attr('y1', (d: any) => d.source.y)
.attr('x2', (d: any) => d.target.x)
.attr('y2', (d: any) => d.target.y);
node
.select('circle')
.attr('cx', (d: any) => d.x)
.attr('cy', (d: any) => d.y);
node
.select('svg')
.attr('x', (d: any) => d.x - 16)
.attr('y', (d: any) => d.y - 16);
});
//invalidation.then(() => simulation.stop());
graph.appendChild(svg.node() as any);
}
$effect(() => {
console.log(selected);
});
onMount(() => {
// Check if logged in and admin
if (!userStore.user || userStore.user.user_type != 2) {
goto('/');
return;
}
getData();
});
</script>
<svelte:window bind:innerWidth={width} bind:innerHeight={height} />
<svelte:head>
<title>Runners</title>
</svelte:head>
<div class="graph-container">
<div class="graph" bind:this={graph}></div>
{#if selected}
<div class="selected">
<CardInfo item={selected} />
</div>
{/if}
</div>
<style lang="css">
.graph-container {
position: relative;
.selected {
position: absolute;
right: 40px;
top: 40px;
width: 20%;
height: auto;
padding: 20px;
background: white;
border-radius: 20px;
box-shadow: 1px 1px 8px 2px #22222244;
}
}
</style>

View File

@ -0,0 +1,90 @@
<script lang="ts">
import { post, showMessage } from 'src/lib/requests.svelte';
import type { Base } from './types';
import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import type { User } from 'src/routes/UserStore.svelte';
import Spinner from 'src/lib/Spinner.svelte';
import Tooltip from 'src/lib/Tooltip.svelte';
let { item }: { item: Base } = $props();
let user_data: User | undefined = $state();
async function getUserData(id: string) {
try {
user_data = await post('user/info/get', { id });
console.log(user_data);
} catch (ex) {
showMessage(ex, notificationStore, 'Could not get user information');
}
}
$effect(() => {
user_data = undefined;
if (item.type == 'user_group') {
getUserData(item.name);
} else if (item.type == 'runner') {
getUserData(item.parent ?? '');
}
});
</script>
{#if item.type == 'api'}
<h3>API</h3>
{:else if item.type == 'runner_group'}
<h3>Runner Group</h3>
This represents the group of {item.name}.
{:else if item.type == 'user_group'}
<h3>User</h3>
{#if user_data}
All runners connected to this node belong to <span class="accent">{user_data.username}</span>
{:else}
<div style="text-align: center;">
<Spinner />
</div>
{/if}
{:else if item.type == 'local_runner'}
<h3>Local Runner</h3>
This is a local runner
<div>
{#if item.task}
This runner is running a <Tooltip title={item.task.id}>task</Tooltip>
{:else}
Not running any task
{/if}
</div>
{:else if item.type == 'runner'}
<h3>Runner</h3>
{#if user_data}
<p>
This is a remote runner. This runner is owned by <span class="accent"
>{user_data?.username}</span
>
</p>
<div>
{#if item.task}
This runner is running a <Tooltip title={item.task.id}>task</Tooltip>
{:else}
Not running any task
{/if}
</div>
{:else}
<div style="text-align: center;">
<Spinner />
</div>
{/if}
{:else}
{item.type}
{/if}
<style lang="scss">
h3 {
text-align: center;
margin: 0;
}
.accent {
background: #22222222;
padding: 1px;
border-radius: 5px;
}
</style>

View File

@ -0,0 +1,10 @@
import type { Task } from 'src/routes/models/edit/tasks/types';
export type BaseType = 'api' | 'runner_group' | 'user_group' | 'runner' | 'local_runner';
export type Base = {
name: string;
type: BaseType;
children?: Base[];
task?: Task;
parent?: string;
};

View File

@ -3,6 +3,7 @@
import 'src/styles/forms.css'; import 'src/styles/forms.css';
import { userStore } from '../UserStore.svelte'; import { userStore } from '../UserStore.svelte';
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { preventDefault } from 'src/lib/utils';
let submitted = $state(false); let submitted = $state(false);
@ -39,7 +40,7 @@
<div class="login-page"> <div class="login-page">
<div> <div>
<h1>Login</h1> <h1>Login</h1>
<form on:submit|preventDefault={onSubmit} class:submitted> <form onsubmit={preventDefault(onSubmit)} class:submitted>
<fieldset> <fieldset>
<label for="email">Email</label> <label for="email">Email</label>
<input type="email" required name="email" bind:value={loginData.email} /> <input type="email" required name="email" bind:value={loginData.email} />

View File

@ -1,26 +1,22 @@
<script lang="ts"> <script lang="ts">
import MessageSimple from 'src/lib/MessageSimple.svelte';
import { onMount } from 'svelte'; import { onMount } from 'svelte';
import { get } from '$lib/requests.svelte'; import { get, showMessage } from '$lib/requests.svelte';
import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import Spinner from 'src/lib/Spinner.svelte';
let list = $state< let list = $state<
{ | {
name: string; name: string;
id: string; id: string;
}[] }[]
>([]); | undefined
>(undefined);
let message: MessageSimple;
onMount(async () => { onMount(async () => {
try { try {
list = await get('models'); list = await get('models');
} catch (e) { } catch (e) {
if (e instanceof Response) { showMessage(e, notificationStore, 'Could not request list of models');
message.display(await e.json());
} else {
message.display('Could not request list of models');
}
} }
}); });
</script> </script>
@ -30,39 +26,44 @@
</svelte:head> </svelte:head>
<main> <main>
<MessageSimple bind:this={message} /> {#if list}
{#if list.length > 0} {#if list.length > 0}
<div class="list-header"> <div class="list-header">
<h2>My Models</h2> <h2>My Models</h2>
<div class="expand"></div> <div class="expand"></div>
<a class="button" href="/models/add"> New </a> <a class="button" href="/models/add"> New </a>
</div> </div>
<table class="table"> <table class="table">
<thead> <thead>
<tr>
<th> Name </th>
<th>
<!-- Open Button -->
</th>
</tr>
</thead>
<tbody>
{#each list as item}
<tr> <tr>
<td> <th> Name </th>
{item.name} <th>
</td> <!-- Open Button -->
<td class="text-center"> </th>
<a class="button simple" href="/models/edit?id={item.id}"> Edit </a>
</td>
</tr> </tr>
{/each} </thead>
</tbody> <tbody>
</table> {#each list as item}
<tr>
<td>
{item.name}
</td>
<td class="text-center">
<a class="button simple" href="/models/edit?id={item.id}"> Edit </a>
</td>
</tr>
{/each}
</tbody>
</table>
{:else}
<h2 class="text-center">You don't have any models</h2>
<div class="text-center">
<a class="button padded" href="/models/add"> Create a new model </a>
</div>
{/if}
{:else} {:else}
<h2 class="text-center">You don't have any models</h2> <div style="text-align: center;">
<div class="text-center"> <Spinner />
<a class="button padded" href="/models/add"> Create a new model </a>
</div> </div>
{/if} {/if}
</main> </main>
@ -84,11 +85,4 @@
.list-header .expand { .list-header .expand {
flex-grow: 1; flex-grow: 1;
} }
.list-header .button,
.list-header button {
padding: 10px 10px;
height: calc(100% - 20px);
margin-top: 5px;
}
</style> </style>

View File

@ -1,15 +1,14 @@
<script lang="ts"> <script lang="ts">
import FileUpload from 'src/lib/FileUpload.svelte'; import FileUpload from 'src/lib/FileUpload.svelte';
import MessageSimple from 'src/lib/MessageSimple.svelte'; import { postFormData, showMessage } from 'src/lib/requests.svelte';
import { postFormData } from 'src/lib/requests.svelte';
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import 'src/styles/forms.css'; import 'src/styles/forms.css';
import { preventDefault } from 'src/lib/utils';
let submitted = $state(false); let submitted = $state(false);
let message: MessageSimple;
let buttonClicked: Promise<void> = $state(Promise.resolve()); let buttonClicked: Promise<void> = $state(Promise.resolve());
let data = $state<{ let data = $state<{
@ -21,7 +20,6 @@
}); });
async function onSubmit() { async function onSubmit() {
message.display('');
buttonClicked = new Promise<void>(() => {}); buttonClicked = new Promise<void>(() => {});
if (!data.file || !data.name) return; if (!data.file || !data.name) return;
@ -34,11 +32,7 @@
let id = await postFormData('models/add', formData); let id = await postFormData('models/add', formData);
goto(`/models/edit?id=${id}`); goto(`/models/edit?id=${id}`);
} catch (e) { } catch (e) {
if (e instanceof Response) { showMessage(e, notificationStore, 'Was not able to create model');
message.display(await e.json());
} else {
message.display('Was not able to create model');
}
} }
buttonClicked = Promise.resolve(); buttonClicked = Promise.resolve();
@ -51,7 +45,7 @@
<main> <main>
<h1>Create new Model</h1> <h1>Create new Model</h1>
<form class:submitted on:submit|preventDefault={onSubmit}> <form class:submitted onsubmit={preventDefault(onSubmit)}>
<fieldset> <fieldset>
<label for="name">Name</label> <label for="name">Name</label>
<input id="name" name="name" required bind:value={data.name} /> <input id="name" name="name" required bind:value={data.name} />
@ -75,7 +69,6 @@
</div> </div>
</FileUpload> </FileUpload>
</fieldset> </fieldset>
<MessageSimple bind:this={message} />
{#await buttonClicked} {#await buttonClicked}
<div class="text-center">File Uploading</div> <div class="text-center">File Uploading</div>
{:then} {:then}

View File

@ -30,18 +30,19 @@
import BaseModelInfo from './BaseModelInfo.svelte'; import BaseModelInfo from './BaseModelInfo.svelte';
import DeleteModel from './DeleteModel.svelte'; import DeleteModel from './DeleteModel.svelte';
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { get, rdelete } from 'src/lib/requests.svelte'; import { get, rdelete, showMessage } from 'src/lib/requests.svelte';
import MessageSimple from '$lib/MessageSimple.svelte'; import { preventDefault } from 'src/lib/utils';
import ModelData from './ModelData.svelte'; import ModelData from './ModelData.svelte';
import DeleteZip from './DeleteZip.svelte'; import DeleteZip from './DeleteZip.svelte';
import RunModel from './RunModel.svelte'; import RunModel from './RunModel.svelte';
import Tabs from 'src/lib/Tabs.svelte'; import Tabs from 'src/lib/Tabs.svelte';
import TasksDataPage from './TasksDataPage.svelte'; import TasksDataPage from './TasksDataPage.svelte';
import ModelDataPage from './ModelDataPage.svelte'; import ModelDataPage from './ModelDataPage.svelte';
import 'src/styles/forms.css'; import 'src/styles/forms.css';
import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import Spinner from 'src/lib/Spinner.svelte';
let model: Promise<Model> = $state(new Promise(() => {})); let model: Promise<Model> = $state(new Promise(() => {}));
let _model: Model | undefined = $state(undefined); let _model: Model | undefined = $state(undefined);
@ -92,10 +93,7 @@
getModel(); getModel();
}); });
let resetMessages: MessageSimple;
async function resetModel() { async function resetModel() {
resetMessages.display('');
let _model = await model; let _model = await model;
try { try {
@ -105,11 +103,7 @@
getModel(); getModel();
} catch (e) { } catch (e) {
if (e instanceof Response) { showMessage(e, notificationStore, 'Could not reset model!');
resetMessages.display(await e.json());
} else {
resetMessages.display('Could not reset model!');
}
} }
} }
@ -147,7 +141,8 @@
<div slot="buttons" let:setActive let:isActive> <div slot="buttons" let:setActive let:isActive>
<button <button
class="tab" class="tab"
on:click|preventDefault={setActive('model')} type="button"
onclick={setActive('model')}
class:selected={isActive('model')} class:selected={isActive('model')}
> >
Model Model
@ -155,7 +150,8 @@
{#if _model && [2, 3, 4, 5, 6, 7, -6, -7].includes(_model.status)} {#if _model && [2, 3, 4, 5, 6, 7, -6, -7].includes(_model.status)}
<button <button
class="tab" class="tab"
on:click|preventDefault={setActive('model-data')} type="button"
onclick={setActive('model-data')}
class:selected={isActive('model-data')} class:selected={isActive('model-data')}
> >
Model Data Model Data
@ -164,7 +160,8 @@
{#if _model && [5, 6, 7, -6, -7].includes(_model.status)} {#if _model && [5, 6, 7, -6, -7].includes(_model.status)}
<button <button
class="tab" class="tab"
on:click|preventDefault={setActive('tasks')} type="button"
onclick={setActive('tasks')}
class:selected={isActive('tasks')} class:selected={isActive('tasks')}
> >
Tasks Tasks
@ -172,7 +169,7 @@
{/if} {/if}
</div> </div>
{#if _model} {#if _model}
<ModelDataPage model={_model} on:reload={getModel} active={isActive('model-data')} /> <ModelDataPage model={_model} onreload={getModel} active={isActive('model-data')} />
<TasksDataPage model={_model} active={isActive('tasks')} /> <TasksDataPage model={_model} active={isActive('tasks')} />
{/if} {/if}
<div class="content" class:selected={isActive('model')}> <div class="content" class:selected={isActive('model')}>
@ -192,7 +189,6 @@
<h1 class="text-center"> <h1 class="text-center">
{m.name} {m.name}
</h1> </h1>
<!-- TODO improve message -->
<h2 class="text-center">Failed to prepare model</h2> <h2 class="text-center">Failed to prepare model</h2>
<DeleteModel model={m} /> <DeleteModel model={m} />
@ -200,25 +196,23 @@
<!-- PRE TRAINING STATUS --> <!-- PRE TRAINING STATUS -->
{:else if m.status == 2} {:else if m.status == 2}
<BaseModelInfo model={m} /> <BaseModelInfo model={m} />
<ModelData model={m} on:reload={getModel} /> <ModelData model={m} onreload={getModel} />
<!-- {{ template "train-model-card" . }} --> <!-- {{ template "train-model-card" . }} -->
<DeleteModel model={m} /> <DeleteModel model={m} />
{:else if m.status == -2} {:else if m.status == -2}
<BaseModelInfo model={m} /> <BaseModelInfo model={m} />
<DeleteZip model={m} on:reload={getModel} /> <DeleteZip model={m} onreload={getModel} />
<DeleteModel model={m} /> <DeleteModel model={m} />
{:else if m.status == 3} {:else if m.status == 3}
<BaseModelInfo model={m} /> <BaseModelInfo model={m} />
<div class="card"> <div class="card">
<!-- TODO improve this --> Processing zip file... <Spinner />
Processing zip file...
</div> </div>
{:else if m.status == -3 || m.status == -4} {:else if m.status == -3 || m.status == -4}
<BaseModelInfo model={m} /> <BaseModelInfo model={m} />
<form on:submit={resetModel}> <form onsubmit={preventDefault(resetModel)}>
Failed Prepare for training.<br /> Failed Prepare for training.<br />
<div class="spacer"></div> <div class="spacer"></div>
<MessageSimple bind:this={resetMessages} />
<button class="danger"> Try Again </button> <button class="danger"> Try Again </button>
</form> </form>
<DeleteModel model={m} /> <DeleteModel model={m} />
@ -337,7 +331,7 @@
<div class="card">Model expading... Processing ZIP file</div> <div class="card">Model expading... Processing ZIP file</div>
{/if} {/if}
{#if m.status == -6} {#if m.status == -6}
<DeleteZip model={m} on:reload={getModel} expand /> <DeleteZip model={m} onreload={getModel} expand />
{/if} {/if}
{#if m.status == -7} {#if m.status == -7}
<form> <form>
@ -346,7 +340,7 @@
</form> </form>
{/if} {/if}
{#if m.model_type == 2} {#if m.model_type == 2}
<ModelData simple model={m} on:reload={getModel} /> <ModelData simple model={m} onreload={getModel} />
{/if} {/if}
<DeleteModel model={m} /> <DeleteModel model={m} />
{:else} {:else}
@ -383,10 +377,4 @@
table tr th:first-child { table tr th:first-child {
border-left: none; border-left: none;
} }
table tr td button,
table tr td .button {
padding: 5px 10px;
box-shadow: 0 2px 5px 1px #66666655;
}
</style> </style>

View File

@ -1,6 +1,6 @@
<script lang="ts"> <script lang="ts">
import type { Model } from './+page.svelte'; import type { Model } from './+page.svelte';
let { model } = $props<{ model: Model }>(); let { model }: { model: Model } = $props();
</script> </script>
<div class="card model-card"> <div class="card model-card">

View File

@ -1,39 +1,31 @@
<script lang="ts"> <script lang="ts">
import MessageSimple from 'src/lib/MessageSimple.svelte';
import type { Model } from './+page.svelte'; import type { Model } from './+page.svelte';
import { rdelete } from '$lib/requests.svelte'; import { rdelete, showMessage } from '$lib/requests.svelte';
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import { preventDefault } from 'src/lib/utils';
let { model } = $props<{ model: Model }>(); let { model }: { model: Model } = $props();
let name: string = $state(''); let name: string = $state('');
let submmited: boolean = $state(false); let submmited: boolean = $state(false);
let messageSimple: MessageSimple;
async function deleteModel() { async function deleteModel() {
submmited = true; submmited = true;
messageSimple.display('');
try { try {
await rdelete('models/delete', { id: model.id, name }); await rdelete('models/delete', { id: model.id, name });
goto('/models'); goto('/models');
} catch (e) { } catch (e) {
if (e instanceof Response) { showMessage(e, notificationStore, 'Could not delete the model');
messageSimple.display(await e.json());
} else {
messageSimple.display('Could not delete the model');
}
} }
} }
</script> </script>
<form on:submit|preventDefault={deleteModel} class:submmited class="danger-bg"> <form onsubmit={deleteModel} class:submmited class="danger-bg">
<fieldset> <fieldset>
<label for="name"> <label for="name">
To delete this model please type "{model.name}": To delete this model please type "{model.name}":
</label> </label>
<input name="name" id="name" required bind:value={name} /> <input name="name" id="name" required bind:value={name} />
</fieldset> </fieldset>
<MessageSimple bind:this={messageSimple} />
<button class="danger"> Delete </button> <button class="danger"> Delete </button>
</form> </form>

View File

@ -1,32 +1,30 @@
<script lang="ts"> <script lang="ts">
import { rdelete } from 'src/lib/requests.svelte'; import { rdelete, showMessage } from 'src/lib/requests.svelte';
import type { Model } from './+page.svelte'; import type { Model } from './+page.svelte';
import MessageSimple from 'src/lib/MessageSimple.svelte'; import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import { createEventDispatcher } from 'svelte'; import { preventDefault } from 'src/lib/utils';
let message: MessageSimple; let {
model,
let { model, expand } = $props<{ model: Model; expand?: boolean }>(); expand,
onreload = () => {}
const dispatch = createEventDispatcher<{ reload: void }>(); }: {
model: Model;
expand?: boolean;
onreload?: () => void;
} = $props();
async function deleteZip() { async function deleteZip() {
message.clear();
try { try {
await rdelete('models/data/delete-zip-file', { id: model.id }); await rdelete('models/data/delete-zip-file', { id: model.id });
dispatch('reload'); onreload();
} catch (e) { } catch (e) {
if (e instanceof Response) { showMessage(e, notificationStore, 'Could not delete the zip file');
message.display(await e.json());
} else {
message.display('Could not delete the zip file');
}
} }
} }
</script> </script>
<form on:submit|preventDefault={deleteZip}> <form onsubmit={preventDefault(deleteZip)}>
{#if expand} {#if expand}
Failed to process the zip file.<br /> Failed to process the zip file.<br />
Delete the file and upload a correct version to add more classes.<br /> Delete the file and upload a correct version to add more classes.<br />
@ -37,6 +35,5 @@
<br /> <br />
{/if} {/if}
<div class="spacer"></div> <div class="spacer"></div>
<MessageSimple bind:this={message} />
<button class="danger"> Delete Zip File </button> <button class="danger"> Delete Zip File </button>
</form> </form>

View File

@ -1,38 +1,29 @@
<script lang="ts" context="module">
export type Class = {
name: string;
id: string;
status: number;
};
</script>
<script lang="ts"> <script lang="ts">
import FileUpload from 'src/lib/FileUpload.svelte'; import FileUpload from 'src/lib/FileUpload.svelte';
import Tabs from 'src/lib/Tabs.svelte'; import Tabs from 'src/lib/Tabs.svelte';
import type { Model } from './+page.svelte'; import type { Model } from './+page.svelte';
import { postFormData, get } from 'src/lib/requests.svelte'; import type { Class } from './types';
import MessageSimple from 'src/lib/MessageSimple.svelte'; import { postFormData, get, showMessage } from 'src/lib/requests.svelte';
import { createEventDispatcher } from 'svelte';
import ModelTable from './ModelTable.svelte'; import ModelTable from './ModelTable.svelte';
import TrainModel from './TrainModel.svelte'; import TrainModel from './TrainModel.svelte';
import ZipStructure from './ZipStructure.svelte'; import ZipStructure from './ZipStructure.svelte';
import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import { preventDefault } from 'src/lib/utils';
let { model, simple } = $props<{ model: Model; simple?: boolean }>(); let {
model,
simple,
onreload = () => {}
}: { model: Model; simple?: boolean; onreload?: () => void } = $props();
let classes: Class[] = $state([]); let classes: Class[] = $state([]);
let has_data: boolean = $state(false); let has_data: boolean = $state(false);
let file: File | undefined = $state(); let file: File | undefined = $state();
const dispatch = createEventDispatcher<{
reload: void;
}>();
let uploading: Promise<void> = $state(Promise.resolve()); let uploading: Promise<void> = $state(Promise.resolve());
let numberOfInvalidImages = $state(0); let numberOfInvalidImages = $state(0);
let uploadImage: MessageSimple;
async function uploadZip() { async function uploadZip() {
if (!file) return; if (!file) return;
@ -44,13 +35,9 @@
try { try {
await postFormData('models/data/upload', form); await postFormData('models/data/upload', form);
dispatch('reload'); onreload();
} catch (e) { } catch (e) {
if (e instanceof Response) { showMessage(e, notificationStore, 'Could not upload data');
uploadImage.display(await e.json());
} else {
uploadImage.display('');
}
} }
uploading = Promise.resolve(); uploading = Promise.resolve();
@ -67,8 +54,8 @@
classes = data.classes; classes = data.classes;
numberOfInvalidImages = data.number_of_invalid_images; numberOfInvalidImages = data.number_of_invalid_images;
has_data = data.has_data; has_data = data.has_data;
} catch { } catch (e) {
return; showMessage(e, notificationStore, 'Could not get information on classes');
} }
} }
</script> </script>
@ -80,22 +67,22 @@
<p>You need to upload data so the model can train.</p> <p>You need to upload data so the model can train.</p>
<Tabs active="upload" let:isActive> <Tabs active="upload" let:isActive>
<div slot="buttons" let:setActive let:isActive> <div slot="buttons" let:setActive let:isActive>
<button class="tab" class:selected={isActive('upload')} on:click={setActive('upload')}> <button class="tab" class:selected={isActive('upload')} onclick={setActive('upload')}>
Upload Upload
</button> </button>
<button <!--button
class="tab" class="tab"
class:selected={isActive('create-class')} class:selected={isActive('create-class')}
on:click={setActive('create-class')} onclick={setActive('create-class')}
> >
Create Class Create Class
</button> </button-->
<button class="tab" class:selected={isActive('api')} on:click={setActive('api')}> <!--button class="tab" class:selected={isActive('api')} onclick={setActive('api')}>
Api Api
</button> </button-->
</div> </div>
<div class="content" class:selected={isActive('upload')}> <div class="content" class:selected={isActive('upload')}>
<form on:submit|preventDefault={uploadZip}> <form onsubmit={preventDefault(uploadZip)}>
<fieldset class="file-upload"> <fieldset class="file-upload">
<label for="file">Data file</label> <label for="file">Data file</label>
<div class="form-msg"> <div class="form-msg">
@ -115,7 +102,6 @@
</div> </div>
</FileUpload> </FileUpload>
</fieldset> </fieldset>
<MessageSimple bind:this={uploadImage} />
{#if file} {#if file}
{#await uploading} {#await uploading}
<button disabled> Uploading </button> <button disabled> Uploading </button>
@ -125,10 +111,10 @@
{/if} {/if}
</form> </form>
</div> </div>
<div class="content" class:selected={isActive('create-class')}> <!--div class="content" class:selected={isActive('create-class')}>
<ModelTable {classes} {model} on:reload={() => dispatch('reload')} /> <ModelTable {classes} {model} {onreload} />
</div> </div-->
<div class="content" class:selected={isActive('api')}>TODO</div> <!--div class="content" class:selected={isActive('api')}>TODO</div-->
</Tabs> </Tabs>
<div class="tabs"></div> <div class="tabs"></div>
{:else} {:else}
@ -136,36 +122,14 @@
{#if numberOfInvalidImages > 0} {#if numberOfInvalidImages > 0}
<p class="danger"> <p class="danger">
There are {numberOfInvalidImages} images that were loaded that do not have the correct format. There are {numberOfInvalidImages} images that were loaded that do not have the correct format.
These images will be delete when the model trains. These images will be deleted when the model trains.
</p> </p>
{/if} {/if}
<Tabs active="create-class" let:isActive> <ModelTable {classes} {model} {onreload} />
<div slot="buttons" let:setActive let:isActive>
<button
class="tab"
class:selected={isActive('create-class')}
on:click={setActive('create-class')}
>
Create Class
</button>
<button class="tab" class:selected={isActive('api')} on:click={setActive('api')}>
Api
</button>
</div>
<div class="content" class:selected={isActive('create-class')}>
<ModelTable {classes} {model} on:reload={() => dispatch('reload')} />
</div>
<div class="content" class:selected={isActive('api')}>TODO</div>
</Tabs>
{/if} {/if}
</div> </div>
{/if} {/if}
{#if classes.some((item) => item.status == 1) && ![-6, 6].includes(model.status)} {#if classes.some((item) => item.status == 1) && ![-6, 6].includes(model.status)}
<TrainModel <TrainModel number_of_invalid_images={numberOfInvalidImages} {model} {has_data} {onreload} />
number_of_invalid_images={numberOfInvalidImages}
{model}
{has_data}
on:reload={() => dispatch('reload')}
/>
{/if} {/if}

View File

@ -1,5 +1,4 @@
<script lang="ts"> <script lang="ts">
import { createEventDispatcher } from 'svelte';
import type { Model } from './+page.svelte'; import type { Model } from './+page.svelte';
import ModelData from './ModelData.svelte'; import ModelData from './ModelData.svelte';
import { post, showMessage } from 'src/lib/requests.svelte'; import { post, showMessage } from 'src/lib/requests.svelte';
@ -7,9 +6,11 @@
import type { ModelStats } from './types'; import type { ModelStats } from './types';
import DeleteZip from './DeleteZip.svelte'; import DeleteZip from './DeleteZip.svelte';
let { model, active }: { model: Model; active?: boolean } = $props(); let {
model,
const dispatch = createEventDispatcher<{ reload: void }>(); active,
onreload = () => {}
}: { model: Model; active?: boolean; onreload?: () => void } = $props();
$effect(() => { $effect(() => {
if (active) getData(); if (active) getData();
@ -34,14 +35,14 @@
{/if} {/if}
{#if [-6, -2].includes(model.status)} {#if [-6, -2].includes(model.status)}
<DeleteZip {model} on:reload={() => dispatch('reload')} expand /> <DeleteZip {model} {onreload} expand />
{/if} {/if}
<ModelData <ModelData
{model} {model}
on:reload={() => { onreload={() => {
getData(); getData();
dispatch('reload'); onreload();
}} }}
/> />
</div> </div>

View File

@ -97,7 +97,7 @@
}); });
</script> </script>
<div><canvas bind:this={ctx} /></div> <div><canvas bind:this={ctx}></canvas></div>
<style lang="scss"> <style lang="scss">
canvas { canvas {

View File

@ -1,33 +1,24 @@
<script lang="ts" context="module">
export type Image = {
file_path: string;
mode: number;
status: number;
id: string;
};
</script>
<script lang="ts"> <script lang="ts">
import Tabs from 'src/lib/Tabs.svelte'; import Tabs from 'src/lib/Tabs.svelte';
import type { Class } from './ModelData.svelte'; import type { Class, Image } from './types';
import { post, postFormData, rdelete, showMessage } from 'src/lib/requests.svelte'; import { post, postFormData, rdelete, showMessage } from 'src/lib/requests.svelte';
import type { Model } from './+page.svelte'; import type { Model } from './+page.svelte';
import FileUpload from 'src/lib/FileUpload.svelte'; import FileUpload from 'src/lib/FileUpload.svelte';
import MessageSimple from 'src/lib/MessageSimple.svelte';
import { createEventDispatcher } from 'svelte';
import ZipStructure from './ZipStructure.svelte'; import ZipStructure from './ZipStructure.svelte';
import { notificationStore } from 'src/lib/NotificationsStore.svelte';
const dispatch = createEventDispatcher<{ reload: void }>(); import { preventDefault } from 'src/lib/utils.js';
import CreateNewClass from './api/CreateNewClass.svelte';
let selected_class: Class | undefined = $state(); let selected_class: Class | undefined = $state();
let { classes, model }: { classes: Class[]; model: Model } = $props(); let { classes, model, onreload }: { classes: Class[]; model: Model; onreload?: () => void } =
$props();
let createClass: { className: string } = $state({ let createClass: { className: string } = $state({
className: '' className: ''
}); });
let page = $state(0); let page = $state(-1);
let showNext = $state(false); let showNext = $state(false);
let image_list = $state<Image[]>([]); let image_list = $state<Image[]>([]);
@ -41,9 +32,10 @@
}); });
async function getList() { async function getList() {
if (!selected_class) return;
try { try {
let res = await post('models/data/list', { let res = await post('models/data/list', {
id: selected_class?.id ?? '', id: selected_class.id,
page: page page: page
}); });
showNext = res.showNext; showNext = res.showNext;
@ -53,19 +45,20 @@
} }
} }
$effect(() => {
getList();
});
$effect(() => { $effect(() => {
if (selected_class) { if (selected_class) {
page = 0; page = 0;
getList();
} }
}); });
let file: File | undefined = $state(); let file: File | undefined = $state();
let uploadImage: MessageSimple;
let uploading = $state(Promise.resolve()); let uploading = $state(Promise.resolve());
async function uploadZip() { async function uploadZip() {
uploadImage.clear();
if (!file) return; if (!file) return;
uploading = new Promise(() => {}); uploading = new Promise(() => {});
@ -76,19 +69,14 @@
try { try {
await postFormData('models/data/class/upload', form); await postFormData('models/data/class/upload', form);
dispatch('reload'); if (onreload) onreload();
} catch (e) { } catch (e) {
if (e instanceof Response) { showMessage(e, notificationStore, 'Failed to upload');
uploadImage.display(await e.json());
} else {
uploadImage.display('');
}
} }
uploading = Promise.resolve(); uploading = Promise.resolve();
} }
let createNewClassMessages: MessageSimple;
async function createNewClass() { async function createNewClass() {
try { try {
const r = await post('models/data/class/new', { const r = await post('models/data/class/new', {
@ -100,7 +88,7 @@
classes = classes; classes = classes;
getList(); getList();
} catch (e) { } catch (e) {
showMessage(e, createNewClassMessages); showMessage(e, notificationStore);
} }
} }
@ -109,12 +97,11 @@
rdelete('models/data/point', { id }); rdelete('models/data/point', { id });
getList(); getList();
} catch (e) { } catch (e) {
console.error('TODO notify user', e); showMessage(e, notificationStore);
} }
} }
let addFile: File | undefined = $state(); let addFile: File | undefined = $state();
let addImageMessages: MessageSimple;
let adding = $state(Promise.resolve()); let adding = $state(Promise.resolve());
let uploadImageDialog: HTMLDialogElement; let uploadImageDialog: HTMLDialogElement;
async function addImage() { async function addImage() {
@ -136,7 +123,7 @@
addFile = undefined; addFile = undefined;
getList(); getList();
} catch (e) { } catch (e) {
showMessage(e, addImageMessages); showMessage(e, notificationStore);
} }
} }
</script> </script>
@ -151,30 +138,30 @@
{#each classes as item} {#each classes as item}
<button <button
style="width: auto; white-space: nowrap;" style="width: auto; white-space: nowrap;"
on:click={() => setActiveClass(item, setActive)} onclick={() => setActiveClass(item, setActive)}
class="tab" class="tab"
class:selected={isActive(item.name)} class:selected={isActive(item.name)}
> >
{item.name} {item.name}
{#if model.model_type == 2} {#if model.model_type == 2}
{#if item.status == 1} {#if item.status == 1}
<span class="bi bi-book" style="color: orange;" /> <span class="bi bi-book" style="color: orange;"></span>
{:else if item.status == 2} {:else if item.status == 2}
<span class="bi bi-book" style="color: green;" /> <span class="bi bi-book" style="color: green;"></span>
{:else if item.status == 3} {:else if item.status == 3}
<span class="bi bi-check" style="color: green;" /> <span class="bi bi-check" style="color: green;"></span>
{/if} {/if}
{/if} {/if}
</button> </button>
{/each} {/each}
</div> </div>
<button <button
on:click={() => { onclick={() => {
setActive('-----New Class-----')(); setActive('-----New Class-----')();
selected_class = undefined; selected_class = undefined;
}} }}
> >
<span class="bi bi-plus" /> <span class="bi bi-plus"></span>
</button> </button>
</div> </div>
{#if selected_class == undefined && isActive('-----New Class-----')} {#if selected_class == undefined && isActive('-----New Class-----')}
@ -184,21 +171,31 @@
<div slot="buttons" let:setActive let:isActive> <div slot="buttons" let:setActive let:isActive>
<button <button
class="tab" class="tab"
on:click|preventDefault={setActive('zip')} type="button"
onclick={setActive('zip')}
class:selected={isActive('zip')} class:selected={isActive('zip')}
> >
Zip Zip
</button> </button>
<button <button
class="tab" class="tab"
on:click|preventDefault={setActive('empty')} type="button"
onclick={setActive('empty')}
class:selected={isActive('empty')} class:selected={isActive('empty')}
> >
Empty Class Empty Class
</button> </button>
<button
class="tab"
type="button"
onclick={setActive('api')}
class:selected={isActive('api')}
>
API
</button>
</div> </div>
<div class="content" class:selected={isActive('zip')}> <div class="content" class:selected={isActive('zip')}>
<form on:submit|preventDefault={uploadZip}> <form onsubmit={preventDefault(uploadZip)}>
<fieldset class="file-upload"> <fieldset class="file-upload">
<label for="file">Data file</label> <label for="file">Data file</label>
<div class="form-msg"> <div class="form-msg">
@ -218,7 +215,6 @@
</div> </div>
</FileUpload> </FileUpload>
</fieldset> </fieldset>
<MessageSimple bind:this={uploadImage} />
{#if file} {#if file}
{#await uploading} {#await uploading}
<button disabled> Uploading </button> <button disabled> Uploading </button>
@ -229,7 +225,7 @@
</form> </form>
</div> </div>
<div class="content" class:selected={isActive('empty')}> <div class="content" class:selected={isActive('empty')}>
<form on:submit|preventDefault={createNewClass}> <form onsubmit={preventDefault(createNewClass)}>
<div class="form-msg"> <div class="form-msg">
This creates an empty class that allows images to be added later This creates an empty class that allows images to be added later
</div> </div>
@ -237,10 +233,12 @@
<label for="className">Class Name</label> <label for="className">Class Name</label>
<input required name="className" bind:value={createClass.className} /> <input required name="className" bind:value={createClass.className} />
</fieldset> </fieldset>
<MessageSimple bind:this={createNewClassMessages} />
<button> Create New Class </button> <button> Create New Class </button>
</form> </form>
</div> </div>
<div class="content" class:selected={isActive('api')}>
<CreateNewClass {model} />
</div>
</Tabs> </Tabs>
</div> </div>
{/if} {/if}
@ -258,7 +256,7 @@
{:else} {:else}
Class to train Class to train
{/if} {/if}
<button on:click={() => uploadImageDialog.showModal()}> Upload Image </button> <button onclick={() => uploadImageDialog.showModal()}> Upload Image </button>
</h2> </h2>
<table> <table>
<thead> <thead>
@ -314,7 +312,7 @@
{/if} {/if}
</td> </td>
<td style="width: 3ch"> <td style="width: 3ch">
<button class="danger" on:click={() => deleteDataPoint(image.id)}> <button class="danger" onclick={() => deleteDataPoint(image.id)}>
<span class="bi bi-trash"></span> <span class="bi bi-trash"></span>
</button> </button>
</td> </td>
@ -325,7 +323,7 @@
<div class="flex justify-center align-center"> <div class="flex justify-center align-center">
<div class="grow-1 flex justify-end align-center"> <div class="grow-1 flex justify-end align-center">
{#if page > 0} {#if page > 0}
<button on:click={() => (page -= 1)}> Prev </button> <button onclick={() => (page -= 1)}> Prev </button>
{/if} {/if}
</div> </div>
@ -335,7 +333,7 @@
<div class="grow-1 flex justify-start align-center"> <div class="grow-1 flex justify-start align-center">
{#if showNext} {#if showNext}
<button on:click={() => (page += 1)}> Next </button> <button onclick={() => (page += 1)}> Next </button>
{/if} {/if}
</div> </div>
</div> </div>
@ -345,7 +343,7 @@
{/if} {/if}
<dialog class="newImageDialog" bind:this={uploadImageDialog}> <dialog class="newImageDialog" bind:this={uploadImageDialog}>
<form on:submit|preventDefault={addImage}> <form onsubmit={preventDefault(addImage)}>
<fieldset class="file-upload"> <fieldset class="file-upload">
<label for="file">Data file</label> <label for="file">Data file</label>
<div class="form-msg"> <div class="form-msg">
@ -365,7 +363,6 @@
</div> </div>
</FileUpload> </FileUpload>
</fieldset> </fieldset>
<MessageSimple bind:this={addImageMessages} />
{#if addFile} {#if addFile}
{#await adding} {#await adding}
<button disabled> Uploading </button> <button disabled> Uploading </button>
@ -415,10 +412,4 @@
table tr th:first-child { table tr th:first-child {
border-left: none; border-left: none;
} }
table tr td button,
table tr td .button {
padding: 5px 10px;
box-shadow: 0 2px 5px 1px #66666655;
}
</style> </style>
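Note on the pattern above: the forms in this file (and the ones below) switch from Svelte's `on:submit|preventDefault` modifier to `onsubmit={preventDefault(handler)}`. The wrapper itself lives in `src/lib/utils` and is not included in this diff; a minimal sketch of the shape these call sites assume:

```typescript
// Hypothetical sketch of the preventDefault wrapper assumed by the onsubmit changes above;
// the actual src/lib/utils implementation is not part of this diff.
export function preventDefault<E extends Event>(handler: (event: E) => unknown) {
	return (event: E) => {
		// Stop the browser's default form submission, then run the original handler.
		event.preventDefault();
		return handler(event);
	};
}
```

With this shape, handlers such as `uploadZip`, `createNewClass` or `addImage` keep their original signatures.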

View File

@ -1,26 +1,29 @@
<script lang="ts"> <script lang="ts">
import { post, postFormData } from 'src/lib/requests.svelte'; import { post, postFormData, showMessage } from 'src/lib/requests.svelte';
import type { Model } from './+page.svelte'; import type { Model } from './+page.svelte';
import FileUpload from 'src/lib/FileUpload.svelte'; import FileUpload from 'src/lib/FileUpload.svelte';
import MessageSimple from 'src/lib/MessageSimple.svelte'; import { onDestroy } from 'svelte';
import { createEventDispatcher, onDestroy } from 'svelte';
import Spinner from 'src/lib/Spinner.svelte'; import Spinner from 'src/lib/Spinner.svelte';
import type { Task } from './TasksTable.svelte'; import type { Task } from './tasks/types';
import Tabs from 'src/lib/Tabs.svelte';
import hljs from 'highlight.js';
import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import { preventDefault } from 'src/lib/utils';
let { model } = $props<{ model: Model }>(); let {
model,
onupload = () => {},
ontaskReload = () => {}
}: { model: Model; onupload?: () => void; ontaskReload?: () => void } = $props();
let file: File | undefined = $state(); let file: File | undefined = $state();
const dispatch = createEventDispatcher<{ upload: void; taskReload: void }>();
let _result: Promise<Task> = $state(new Promise(() => {})); let _result: Promise<Task> = $state(new Promise(() => {}));
let run = $state(false); let run = $state(false);
let last_task: string | undefined = $state(); let last_task: string | undefined = $state();
let last_task_timeout: number | null = null; let last_task_timeout: number | null = null;
let messages: MessageSimple;
async function reloadLastTimeout() { async function reloadLastTimeout() {
if (!last_task) { if (!last_task) {
return; return;
@ -31,7 +34,7 @@
const r = await post('tasks/task', { id: last_task }); const r = await post('tasks/task', { id: last_task });
if ([0, 1, 2, 3].includes(r.status)) { if ([0, 1, 2, 3].includes(r.status)) {
setTimeout(reloadLastTimeout, 500); setTimeout(reloadLastTimeout, 500);
setTimeout(() => dispatch('taskReload'), 500); setTimeout(ontaskReload, 500);
} else { } else {
_result = Promise.resolve(r); _result = Promise.resolve(r);
} }
@ -42,7 +45,6 @@
async function submit() { async function submit() {
if (!file) return; if (!file) return;
messages.clear();
let form = new FormData(); let form = new FormData();
form.append('json_data', JSON.stringify({ id: model.id })); form.append('json_data', JSON.stringify({ id: model.id }));
@ -56,14 +58,10 @@
file = undefined; file = undefined;
last_task_timeout = setTimeout(() => reloadLastTimeout()); last_task_timeout = setTimeout(() => reloadLastTimeout());
} catch (e) { } catch (e) {
if (e instanceof Response) { showMessage(e, notificationStore, 'Could not run the model');
messages.display(await e.json());
} else {
messages.display('Could not run the model');
}
} }
dispatch('upload'); onupload();
} }
onDestroy(() => { onDestroy(() => {
@ -73,39 +71,111 @@
}); });
</script> </script>
<form on:submit|preventDefault={submit}> <Tabs active="upload" let:isActive>
<fieldset class="file-upload"> <div class="buttons" slot="buttons" let:setActive let:isActive>
<label for="file">Image</label> <button class="tab" class:selected={isActive('upload')} onclick={setActive('upload')}>
<div class="form-msg">Run image through them model and get the result</div> Upload
</button>
<button class="tab" class:selected={isActive('api')} onclick={setActive('api')}> Api </button>
</div>
<div class="content" class:selected={isActive('api')}>
<div class="codeinfo">
To perform an image classification please follow the example below:
<pre style="font-family: Fira Code;">{@html hljs.highlight(
`let form = new FormData();
form.append('json_data', JSON.stringify({ id: '${model.id}' }));
form.append('file', file, 'file');
<FileUpload replace_slot bind:file accept="image/*"> const headers = new Headers();
<img src="/imgs/upload-icon.png" alt="" /> headers.append('response-type', 'application/json');
<span> Upload image </span> headers.append('token', token);
<div slot="replaced-name">
<span> Image selected </span> const r = await fetch('${window.location.protocol}//${window.location.hostname}/api/tasks/start/image', {
</div> method: 'POST',
</FileUpload> headers: headers,
</fieldset> body: form
<MessageSimple bind:this={messages} /> });`,
<button> Run </button> { language: 'javascript' }
{#if run} ).value}</pre>
{#await _result} On success the request will return JSON in this format:
<h1> <pre style="font-family: Fira Code;">{@html hljs.highlight(
Processing Image! <Spinner /> `{ "id": "00000000-0000-0000-0000-000000000000" }`,
</h1> { language: 'json' }
{:then result} ).value}</pre>
{#if result.status == 4} This id can be used to query the API for the result of the task:
{@const res = JSON.parse(result.result)} <pre style="font-family: Fira Code;">{@html hljs.highlight(
<div> `const headers = new Headers();
<h1>Result</h1> headers.append('content-type', 'application/json');
The image was classified as {res.class} with confidence: {res.confidence} headers.append('token', token);
</div>
{:else} const r = await fetch('${window.location.protocol}//${window.location.hostname}/api/tasks/task', {
<div class="result"> method: 'POST',
<h1>There was a problem running the task:</h1> headers: headers,
{result?.status_message} body: JSON.stringify({ id: '00000000-0000-0000-0000-000000000000' })
</div> });`,
{ language: 'javascript' }
).value}</pre>
Once the task shows status 4, the data can be obtained in the result field. The successful
return value has this type:
<pre style="font-family: Fira Code;">{@html hljs.highlight(
`{
"id": string,
"user_id": string,
"model_id": string,
"status": number,
"status_message": string,
"user_confirmed": number,
"compacted": number,
"type": number,
"extra_task_info": string,
"result": string,
"created": string
}`,
{ language: 'javascript' }
).value}</pre>
</div>
</div>
<div class="content" class:selected={isActive('upload')}>
<form onsubmit={preventDefault(submit)} style="box-shadow: none;">
<fieldset class="file-upload">
<label for="file">Image</label>
<div class="form-msg">Run an image through the model and get the result</div>
<FileUpload replace_slot bind:file accept="image/*">
<img src="/imgs/upload-icon.png" alt="" />
<span> Upload image </span>
<div slot="replaced-name">
<span> Image selected </span>
</div>
</FileUpload>
</fieldset>
<button> Run </button>
{#if run}
{#await _result}
<h1>
Processing Image! <Spinner />
</h1>
{:then result}
{#if result.status == 4}
{@const res = JSON.parse(result.result)}
<div>
<h1>Result</h1>
The image was classified as {res.class} with confidence: {res.confidence}
</div>
{:else}
<div class="result">
<h1>There was a problem running the task:</h1>
{result?.status_message}
</div>
{/if}
{/await}
{/if} {/if}
{/await} </form>
{/if} </div>
</form> </Tabs>
<style lang="scss">
.codeinfo {
padding: 20px;
}
</style>
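The API tab added above documents a two-step flow: start a task, then poll `tasks/task` until it finishes. A minimal client-side sketch based only on those examples (the base URL, token value and function name are placeholders; 0-3 are the in-progress statuses used by `reloadLastTimeout`, 4 is success):

```typescript
// Sketch of the flow documented in the API tab above; endpoints, form fields and headers come
// from the examples, while the polling interval and error handling are assumptions.
async function classifyImage(base: string, token: string, modelId: string, file: File) {
	const form = new FormData();
	form.append('json_data', JSON.stringify({ id: modelId }));
	form.append('file', file, 'file');

	const start = await fetch(`${base}/api/tasks/start/image`, {
		method: 'POST',
		headers: { 'response-type': 'application/json', token },
		body: form
	});
	const { id } = await start.json();

	// Poll until the task leaves the queued/running statuses (0-3).
	for (;;) {
		const r = await fetch(`${base}/api/tasks/task`, {
			method: 'POST',
			headers: { 'content-type': 'application/json', token },
			body: JSON.stringify({ id })
		});
		const task = await r.json();
		if (![0, 1, 2, 3].includes(task.status)) return task;
		await new Promise((resolve) => setTimeout(resolve, 500));
	}
}
```

On success the returned task has `status` 4 and its `result` field is a JSON string holding `class` and `confidence`, which is what the upload tab renders.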

View File

@ -1,5 +1,4 @@
<script lang="ts"> <script lang="ts">
import { post } from 'src/lib/requests.svelte';
import type { Model } from 'src/routes/models/edit/+page.svelte'; import type { Model } from 'src/routes/models/edit/+page.svelte';
import RunModel from './RunModel.svelte'; import RunModel from './RunModel.svelte';
import TasksTable from './tasks/TasksTable.svelte'; import TasksTable from './tasks/TasksTable.svelte';
@ -12,7 +11,7 @@
{#if active} {#if active}
<div class="content selected"> <div class="content selected">
<RunModel {model} on:upload={() => table.getList()} on:taskReload={() => table.getList()} /> <RunModel {model} onupload={() => table.getList()} ontaskReload={() => table.getList()} />
<TasksTable {model} bind:this={table} /> <TasksTable {model} bind:this={table} />
<Stats {model} /> <Stats {model} />
</div> </div>

View File

@ -1,14 +1,20 @@
<script lang="ts"> <script lang="ts">
import MessageSimple from 'src/lib/MessageSimple.svelte'; import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import type { Model } from './+page.svelte'; import type { Model } from './+page.svelte';
import { post } from 'src/lib/requests.svelte'; import { post, showMessage } from 'src/lib/requests.svelte';
import { createEventDispatcher } from 'svelte'; import { preventDefault } from 'src/lib/utils';
let { number_of_invalid_images, has_data, model } = $props<{ let {
number_of_invalid_images,
has_data,
model,
onreload = () => {}
}: {
number_of_invalid_images: number; number_of_invalid_images: number;
has_data: boolean; has_data: boolean;
model: Model; model: Model;
}>(); onreload?: () => void;
} = $props();
let data = $state({ let data = $state({
model_type: 'simple', model_type: 'simple',
@ -18,54 +24,39 @@
let submitted = $state(false); let submitted = $state(false);
let dispatch = createEventDispatcher<{ reload: void }>();
let messages: MessageSimple;
async function submit() { async function submit() {
messages.clear();
submitted = true; submitted = true;
try { try {
await post('models/train', { await post('models/train', {
id: model.id, id: model.id,
...data ...data
}); });
dispatch('reload'); onreload();
} catch (e) { } catch (e) {
if (e instanceof Response) { showMessage(e, notificationStore, 'Could not start the training of the model');
messages.display(await e.json());
} else {
messages.display('Could not start the training of the model');
}
} }
} }
async function submitRetrain() { async function submitRetrain() {
messages.clear();
submitted = true; submitted = true;
try { try {
await post('model/train/retrain', { id: model.id }); await post('model/train/retrain', { id: model.id });
dispatch('reload'); onreload();
} catch (e) { } catch (e) {
if (e instanceof Response) { showMessage(e, notificationStore, 'Could not start the training of the model');
messages.display(await e.json());
} else {
messages.display('Could not start the training of the model');
}
} }
} }
</script> </script>
{#if model.status == 2} {#if model.status == 2}
<form class:submitted on:submit|preventDefault={submit}> <form class:submitted onsubmit={preventDefault(submit)}>
{#if has_data} {#if has_data}
{#if number_of_invalid_images > 0} {#if number_of_invalid_images > 0}
<p class="danger"> <p class="danger">
There are {number_of_invalid_images} images that were loaded that do not have the correct format. There are {number_of_invalid_images} images that were loaded that do not have the correct format.
These images will be delete when the model trains. These images will be deleted when the model trains.
</p> </p>
{/if} {/if}
<MessageSimple bind:this={messages} />
<!-- TODO expanding mode --> <!-- TODO expanding mode -->
<fieldset> <fieldset>
<legend> Model Type </legend> <legend> Model Type </legend>
@ -110,17 +101,16 @@
<h2>To train the model please provide data to the model first</h2> <h2>To train the model please provide data to the model first</h2>
{/if} {/if}
</form> </form>
{:else} {:else if ![4, 6, 7].includes(model.status)}
<form class:submitted on:submit|preventDefault={submitRetrain}> <form class:submitted onsubmit={preventDefault(submitRetrain)}>
{#if has_data} {#if has_data}
<h2>This model has new classes and can be expanded</h2> <h2>This model has new classes and can be expanded</h2>
{#if number_of_invalid_images > 0} {#if number_of_invalid_images > 0}
<p class="danger"> <p class="danger">
There are {number_of_invalid_images} images that were loaded that do not have the correct format. There are {number_of_invalid_images} images that were loaded that do not have the correct format.
These images will be delete when the model trains. These images will be deleted when the model trains.
</p> </p>
{/if} {/if}
<MessageSimple bind:this={messages} />
<button> Retrain </button> <button> Retrain </button>
{:else} {:else}
<h2>To train the model please provide data to the model first</h2> <h2>To train the model please provide data to the model first</h2>

View File

@ -0,0 +1,27 @@
<script lang="ts">
import hljs from 'highlight.js';
import type { Model } from '../+page.svelte';
let { model }: { model: Model } = $props();
</script>
To create a new class via the API you can:
<pre style="font-family: Fira Code;">{@html hljs.highlight(
`let form = new FormData();
form.append('json_data', JSON.stringify({
id: '${model.id}',
name: 'New class name'
}));
form.append('file', file, 'file');
const headers = new Headers();
headers.append('response-type', 'application/json');
headers.append('token', token);
const r = await fetch('${window.location.protocol}//${window.location.hostname}/models/data/class/new', {
method: 'POST',
headers: headers,
body: form
});`,
{ language: 'javascript' }
).value}</pre>
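Wrapped into a reusable helper, the request above could look like this; the endpoint path and form fields come from the snippet, while the function name and error handling are illustrative:

```typescript
// Illustrative wrapper around the class-creation endpoint shown above.
async function createClassWithZip(base: string, token: string, modelId: string, name: string, zip: Blob) {
	const form = new FormData();
	form.append('json_data', JSON.stringify({ id: modelId, name }));
	form.append('file', zip, 'file');

	const r = await fetch(`${base}/models/data/class/new`, {
		method: 'POST',
		headers: { 'response-type': 'application/json', token },
		body: form
	});
	if (!r.ok) throw r; // the rest of the codebase throws the Response on failure
	return r.json();
}
```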

View File

@ -1,5 +1,5 @@
<script lang="ts"> <script lang="ts">
import { onDestroy, onMount } from 'svelte'; import { onDestroy } from 'svelte';
import type { Model } from '../+page.svelte'; import type { Model } from '../+page.svelte';
import { post, showMessage } from 'src/lib/requests.svelte'; import { post, showMessage } from 'src/lib/requests.svelte';
import type { DataPoint, TasksStatsDay } from 'src/types/stats/task'; import type { DataPoint, TasksStatsDay } from 'src/types/stats/task';
@ -18,7 +18,6 @@
PointElement, PointElement,
LineElement LineElement
} from 'chart.js'; } from 'chart.js';
import ModelData from '../ModelData.svelte';
Chart.register( Chart.register(
Title, Title,
@ -57,7 +56,9 @@
} }
let pie: HTMLCanvasElement; let pie: HTMLCanvasElement;
let pie2: HTMLCanvasElement;
let pieChart: Chart<'pie'> | undefined; let pieChart: Chart<'pie'> | undefined;
let pie2Chart: Chart<'pie'> | undefined;
function createPie(s: TasksStatsDay) { function createPie(s: TasksStatsDay) {
if (pieChart) { if (pieChart) {
pieChart.destroy(); pieChart.destroy();
@ -72,23 +73,31 @@
'Classification Failure', 'Classification Failure',
'Classification Preparing', 'Classification Preparing',
'Classification Running', 'Classification Running',
'Classification Unknown', 'Classification Unknown'
'Non Classification Error',
'Non Classification Success'
], ],
datasets: [ datasets: [
{ {
label: 'Total', label: 'Total',
data: [ data: [t.c_error, t.c_success, t.c_failure, t.c_pre_running, t.c_running, t.c_unknown]
t.c_error, }
t.c_success, ]
t.c_failure, },
t.c_pre_running, options: {
t.c_running, animation: false
t.c_unknown, }
t.nc_error, });
t.nc_success
] if (pie2Chart) {
pie2Chart.destroy();
}
pie2Chart = new Chart(pie2, {
type: 'pie',
data: {
labels: ['Non Classification Error', 'Non Classification Success'],
datasets: [
{
label: 'Total',
data: [t.nc_error, t.nc_success]
} }
] ]
}, },
@ -115,7 +124,6 @@
nc_error: 'Non Classification Error', nc_error: 'Non Classification Error',
nc_success: 'Non Classification Success' nc_success: 'Non Classification Success'
}; };
let t = s.total;
let labels = new Array(24).fill(0).map((_, i) => i); let labels = new Array(24).fill(0).map((_, i) => i);
lineChart = new Chart(line, { lineChart = new Chart(line, {
type: 'line', type: 'line',
@ -147,8 +155,13 @@
<h1>Statistics (Day)</h1> <h1>Statistics (Day)</h1>
<h2>Total</h2> <h2>Total</h2>
<div> <div class="pies">
<canvas bind:this={pie}></canvas> <div>
<canvas bind:this={pie}></canvas>
</div>
<div>
<canvas bind:this={pie2}></canvas>
</div>
</div> </div>
<h2>Hourly</h2> <h2>Hourly</h2>
@ -160,4 +173,12 @@
canvas { canvas {
width: 100%; width: 100%;
} }
.pies {
display: flex;
align-content: stretch;
div {
width: 50%;
}
}
</style> </style>
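The `createPie` changes above keep one `Chart` instance per canvas and tear the old one down before re-creating it (Chart.js will not bind a second chart to a canvas that still has a live one). A small illustrative helper capturing that pattern, not part of the diff:

```typescript
import { Chart } from 'chart.js';
import type { ChartConfiguration } from 'chart.js';

// Destroy the previous instance bound to this canvas before creating a new one; the helper
// name is hypothetical and only mirrors what createPie does for both pie and pie2.
function rebuildPie(
	canvas: HTMLCanvasElement,
	previous: Chart<'pie'> | undefined,
	config: ChartConfiguration<'pie'>
): Chart<'pie'> {
	if (previous) previous.destroy();
	return new Chart(canvas, config);
}
```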

View File

@ -1,16 +1,4 @@
<script lang="ts" context="module"> <script lang="ts" context="module">
export type Task = {
id: string;
user_id: string;
model_id: string;
status: number;
status_message: string;
user_confirmed: number;
compacted: number;
type: number;
created: string;
result: string;
};
export const TaskType = { export const TaskType = {
TASK_FAILED_RUNNING: -2, TASK_FAILED_RUNNING: -2,
TASK_FAILED_CREATION: -1, TASK_FAILED_CREATION: -1,
@ -52,9 +40,10 @@
<script lang="ts"> <script lang="ts">
import { post, showMessage } from 'src/lib/requests.svelte'; import { post, showMessage } from 'src/lib/requests.svelte';
import type { Model } from '../+page.svelte'; import type { Model } from '../+page.svelte';
import MessageSimple from 'src/lib/MessageSimple.svelte';
import Tooltip from 'src/lib/Tooltip.svelte'; import Tooltip from 'src/lib/Tooltip.svelte';
import type { Task } from './types';
let { model }: { model: Model } = $props(); let { model }: { model: Model } = $props();
let page = $state(0); let page = $state(0);
@ -80,7 +69,6 @@
} }
}); });
let userPreceptionMessages: MessageSimple;
// This returns a function that performs the call; it does not perform the call itself // This returns a function that performs the call; it does not perform the call itself
function userPreception(task: string, agree: number) { function userPreception(task: string, agree: number) {
return async function () { return async function () {
@ -99,7 +87,6 @@
<div> <div>
<h2>Tasks</h2> <h2>Tasks</h2>
<MessageSimple bind:this={userPreceptionMessages} />
<table> <table>
<thead> <thead>
<tr> <tr>
@ -156,14 +143,14 @@
<div> <div>
{#if task.user_confirmed != 1} {#if task.user_confirmed != 1}
<Tooltip title="Agree with the result of the task"> <Tooltip title="Agree with the result of the task">
<button type="button" on:click={userPreception(task.id, 1)}> <button type="button" onclick={userPreception(task.id, 1)}>
<span class="bi bi-check"></span> <span class="bi bi-check"></span>
</button> </button>
</Tooltip> </Tooltip>
{/if} {/if}
{#if task.user_confirmed != -1} {#if task.user_confirmed != -1}
<Tooltip title="Disagree with the result"> <Tooltip title="Disagree with the result">
<button class="danger" type="button" on:click={userPreception(task.id, -1)}> <button class="danger" type="button" onclick={userPreception(task.id, -1)}>
<span class="bi bi-x-lg"></span> <span class="bi bi-x-lg"></span>
</button> </button>
</Tooltip> </Tooltip>
@ -208,7 +195,7 @@
<div class="flex justify-center align-center"> <div class="flex justify-center align-center">
<div class="grow-1 flex justify-end align-center"> <div class="grow-1 flex justify-end align-center">
{#if page > 0} {#if page > 0}
<button on:click={() => (page -= 1)}> Prev </button> <button onclick={() => (page -= 1)}> Prev </button>
{/if} {/if}
</div> </div>
@ -218,7 +205,7 @@
<div class="grow-1 flex justify-start align-center"> <div class="grow-1 flex justify-start align-center">
{#if showNext} {#if showNext}
<button on:click={() => (page += 1)}> Next </button> <button onclick={() => (page += 1)}> Next </button>
{/if} {/if}
</div> </div>
</div> </div>
@ -229,10 +216,6 @@
width: 100%; width: 100%;
display: flex; display: flex;
justify-content: space-between; justify-content: space-between;
& > button {
margin: 3px 5px;
}
} }
table { table {

View File

@ -0,0 +1,12 @@
export type Task = {
id: string;
user_id: string;
model_id: string;
status: number;
status_message: string;
user_confirmed: number;
compacted: number;
type: number;
created: string;
result: string;
};
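For reference, the components above only branch on a few of these status values; a small illustrative helper over the new `Task` type (not part of the diff, and only covering the codes visible here) could be:

```typescript
import type { Task } from './types';

// Statuses 0-3 are treated as still in progress (matching the polling in RunModel) and 4 as
// finished successfully, at which point result holds a JSON string.
export function isTaskRunning(task: Task): boolean {
	return [0, 1, 2, 3].includes(task.status);
}

export function taskResult(task: Task): { class: string; confidence: number } | undefined {
	return task.status === 4 ? JSON.parse(task.result) : undefined;
}
```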

View File

@ -3,3 +3,16 @@ export type ModelStats = Array<{
training: number; training: number;
testing: number; testing: number;
}>; }>;
export type Class = {
name: string;
id: string;
status: number;
};
export type Image = {
file_path: string;
mode: number;
status: number;
id: string;
};

View File

@ -3,6 +3,7 @@
import 'src/styles/forms.css'; import 'src/styles/forms.css';
import { userStore } from '../UserStore.svelte'; import { userStore } from '../UserStore.svelte';
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { preventDefault } from 'src/lib/utils';
let submitted = $state(false); let submitted = $state(false);
@ -39,7 +40,7 @@
<div class="login-page"> <div class="login-page">
<div> <div>
<h1>Register</h1> <h1>Register</h1>
<form on:submit|preventDefault={onSubmit} class:submitted> <form onsubmit={preventDefault(onSubmit)} class:submitted>
<fieldset> <fieldset>
<label for="username">Username</label> <label for="username">Username</label>
<input required name="username" bind:value={loginData.username} /> <input required name="username" bind:value={loginData.username} />

View File

@ -4,10 +4,11 @@
import { onMount } from 'svelte'; import { onMount } from 'svelte';
import 'src/styles/forms.css'; import 'src/styles/forms.css';
import { post } from 'src/lib/requests.svelte'; import { post, showMessage } from 'src/lib/requests.svelte';
import MessageSimple, { type DisplayFn } from 'src/lib/MessageSimple.svelte';
import TokenTable from './TokenTable.svelte'; import TokenTable from './TokenTable.svelte';
import DeleteUser from './DeleteUser.svelte'; import DeleteUser from './DeleteUser.svelte';
import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import { preventDefault } from 'src/lib/utils';
onMount(() => { onMount(() => {
if (!userStore.isLogin()) { if (!userStore.isLogin()) {
@ -26,12 +27,8 @@
let submiitedEmail = $state(false); let submiitedEmail = $state(false);
let submiitedPassword = $state(false); let submiitedPassword = $state(false);
let msgEmail: MessageSimple;
let msgPassword: MessageSimple;
async function onSubmitEmail() { async function onSubmitEmail() {
submiitedEmail = true; submiitedEmail = true;
msgEmail.display('');
if (!userStore.user) return; if (!userStore.user) return;
@ -44,31 +41,31 @@
...userStore.user, ...userStore.user,
...req ...req
}; };
msgEmail.display('User updated successufly!', { type: 'success', timeToShow: 10000 }); notificationStore.add({
message: 'User updated successfully!',
type: 'success',
timeToLive: 10000
});
} catch (e) { } catch (e) {
if (e instanceof Response) { showMessage(e, notificationStore, 'Could not update email');
msgEmail.display(await e.json());
} else {
msgEmail.display('Could not update email');
}
} }
} }
async function onSubmitPassword() { async function onSubmitPassword() {
submiitedPassword = true; submiitedPassword = true;
msgPassword.display('');
if (!userStore.user) return; if (!userStore.user) return;
try { try {
await post('user/info/password', passwordData); await post('user/info/password', passwordData);
passwordData = { old_password: '', password: '', password2: '' }; passwordData = { old_password: '', password: '', password2: '' };
msgPassword.display('Password updated successufly!', { type: 'success', timeToShow: 10000 });
notificationStore.add({
message: 'Password updated successfully!',
type: 'success',
timeToLive: 10000
});
} catch (e) { } catch (e) {
if (e instanceof Response) { showMessage(e, notificationStore, 'Could not update password');
msgPassword.display(await e.json());
} else {
msgPassword.display('Could not update password');
}
} }
} }
</script> </script>
@ -80,15 +77,14 @@
<div class="login-page"> <div class="login-page">
<div> <div>
<h1>User Information</h1> <h1>User Information</h1>
<form on:submit|preventDefault={onSubmitEmail} class:submiitedEmail> <form onsubmit={preventDefault(onSubmitEmail)} class:submiitedEmail>
<fieldset> <fieldset>
<label for="email">Email</label> <label for="email">Email</label>
<input type="email" required name="email" bind:value={email} /> <input type="email" required name="email" bind:value={email} />
</fieldset> </fieldset>
<MessageSimple bind:this={msgEmail} />
<button> Update </button> <button> Update </button>
</form> </form>
<form on:submit|preventDefault={onSubmitPassword} class:submiitedPassword> <form onsubmit={preventDefault(onSubmitPassword)} class:submiitedPassword>
<fieldset> <fieldset>
<label for="old_password">Old Password</label> <label for="old_password">Old Password</label>
<input <input
@ -106,7 +102,6 @@
<label for="password2">Repeat New Password</label> <label for="password2">Repeat New Password</label>
<input required bind:value={passwordData.password2} name="password2" type="password" /> <input required bind:value={passwordData.password2} name="password2" type="password" />
</fieldset> </fieldset>
<MessageSimple bind:this={msgPassword} />
<div> <div>
<button> Update </button> <button> Update </button>
</div> </div>

View File

@ -1,36 +1,32 @@
<script lang="ts"> <script lang="ts">
import {createEventDispatcher} from 'svelte';
import MessageSimple from 'src/lib/MessageSimple.svelte';
import Tooltip from 'src/lib/Tooltip.svelte'; import Tooltip from 'src/lib/Tooltip.svelte';
import 'src/styles/forms.css'; import 'src/styles/forms.css';
import {post} from 'src/lib/requests.svelte'; import { post } from 'src/lib/requests.svelte';
import Spinner from 'src/lib/Spinner.svelte'; import Spinner from 'src/lib/Spinner.svelte';
import { preventDefault } from 'src/lib/utils';
let { onreload = () => {} }: { onreload?: () => void } = $props();
const dispatch = createEventDispatcher<{reload: void}>();
let addNewToken = $state(false); let addNewToken = $state(false);
let messages: MessageSimple;
let expiry_date: HTMLInputElement = $state(undefined as any); let expiry_date: HTMLInputElement = $state(undefined as any);
type NewToken = {
name: string,
expiry: number,
token: string,
}
let token: Promise<NewToken> | undefined = $state(); type NewToken = {
name: string;
expiry: number;
token: string;
};
let token: Promise<NewToken> | undefined = $state();
let newToken: { let newToken: {
name: string; name: string;
expiry: string; expiry: string;
password: string; password: string;
} = $state({ } = $state({
name: '', name: '',
expiry: '', expiry: '',
password: '', password: ''
}); });
async function createToken(e: SubmitEvent & { currentTarget: HTMLFormElement }) { async function createToken(e: SubmitEvent & { currentTarget: HTMLFormElement }) {
@ -47,86 +43,85 @@
} }
} }
try { try {
const r = await post("user/token/add", { const r = await post('user/token/add', {
name: newToken.name, name: newToken.name,
expiry: expiry, expiry: expiry,
password: newToken.password, password: newToken.password
}); });
token = Promise.resolve(r) token = Promise.resolve(r);
setTimeout(() => dispatch('reload'), 500) setTimeout(onreload, 500);
} catch (e) { } catch (e) {
token = undefined; token = undefined;
console.error("Notify user", e) console.error('Notify user', e);
} }
} }
$effect(() => { $effect(() => {
if (expiry_date) { if (expiry_date) {
if (isNaN(Number(newToken.expiry))) { if (isNaN(Number(newToken.expiry))) {
expiry_date.setCustomValidity('Invalid Number'); expiry_date.setCustomValidity('Invalid Number');
} else { } else {
expiry_date.setCustomValidity(''); expiry_date.setCustomValidity('');
} }
} }
}); });
</script> </script>
{#if addNewToken} {#if addNewToken}
<div> <div>
<h2>Add New Token</h2> <h2>Add New Token</h2>
{#if !token} {#if !token}
<form on:submit|preventDefault={createToken}> <form onsubmit={preventDefault(createToken)}>
<fieldset> <fieldset>
<label for="name">Name</label> <label for="name">Name</label>
<input required bind:value={newToken.name} name="name" /> <input required bind:value={newToken.name} name="name" />
</fieldset> </fieldset>
<fieldset> <fieldset>
<label for="expiry_date">Expiry Date</label> <label for="expiry_date">Expiry Date</label>
<div class="flex"> <div class="flex">
<input bind:this={expiry_date} bind:value={newToken.expiry} name="expiry_date" /> <input bind:this={expiry_date} bind:value={newToken.expiry} name="expiry_date" />
<Tooltip title="Time in seconds. Leave empty to last forever"> <Tooltip title="Time in seconds. Leave empty to last forever">
<span class="center-question bi bi-question-circle-fill" /> <span class="center-question bi bi-question-circle-fill"></span>
</Tooltip> </Tooltip>
</div> </div>
</fieldset> </fieldset>
<fieldset> <fieldset>
<label for="password">Password</label> <label for="password">Password</label>
<input required bind:value={newToken.password} name="password" /> <input required bind:value={newToken.password} name="password" />
</fieldset> </fieldset>
<MessageSimple bind:this={messages} /> <div>
<div> <button> Update </button>
<button> Update </button> </div>
</div> </form>
</form> {:else}
{:else} {#await token}
{#await token} <Spinner /> Generating
<Spinner /> Generating {:then t}
{:then t} <h3>Token generated</h3>
<h3> Token generated </h3> <form onsubmit={preventDefault(() => {})}>
<form on:submit|preventDefault={() => {}}> <fieldset>
<fieldset> <label for="token">Token</label>
<label for="token">Token</label> <div class="flex">
<div class="flex"> <input value={t.token} oninput={(e) => e.preventDefault()} name="token" />
<input value={t.token} on:input={(e) => e.preventDefault() } name="token" /> <div style="width: 5em;">
<div style="width: 5em;"> <button onclick={() => navigator.clipboard.writeText(t.token)}>
<button on:click={() => navigator.clipboard.writeText(t.token)} > <span class="bi bi-clipboard"></span>
<span class="bi bi-clipboard" /> </button>
</button> </div>
</div> </div>
</div> </fieldset>
</fieldset> <div>
<div> <button onclick={() => (token = undefined)}> Generate new token </button>
<button on:click={() => token = undefined}> Generate new token </button> </div>
</div> </form>
</form> {:catch e}
{:catch e} {e}
{e} {/await}
{/await} {/if}
{/if}
</div> </div>
{:else} {:else}
<div> <div>
<button class="expander" on:click={() => (addNewToken = true)}> Add New Token </button> <button class="expander" onclick={() => (addNewToken = true)}> Add New Token </button>
</div> </div>
{/if} {/if}
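The token generated here is what the earlier API examples send in the `token` header. A sketch of the call `createToken` makes, pulled into a helper (names are illustrative; how an empty expiry is encoded is not visible in this diff):

```typescript
import { post } from 'src/lib/requests.svelte';

type NewToken = { name: string; expiry: number; token: string };

// Mirrors the submission above: user/token/add takes a name, an expiry in seconds and the
// account password, and resolves to the token string used by the API examples.
async function createApiToken(name: string, expirySeconds: number, password: string): Promise<string> {
	const r: NewToken = await post('user/token/add', {
		name,
		expiry: expirySeconds,
		password
	});
	return r.token;
}
```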

View File

@ -2,6 +2,7 @@
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { notificationStore } from 'src/lib/NotificationsStore.svelte'; import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import { rdelete, showMessage } from 'src/lib/requests.svelte'; import { rdelete, showMessage } from 'src/lib/requests.svelte';
import { preventDefault } from 'src/lib/utils';
import { userStore } from 'src/routes/UserStore.svelte'; import { userStore } from 'src/routes/UserStore.svelte';
let data = $state({ password: '' }); let data = $state({ password: '' });
@ -23,7 +24,7 @@
} }
</script> </script>
<form class="danger-bg" on:submit|preventDefault={deleteUser}> <form class="danger-bg" onsubmit={preventDefault(deleteUser)}>
<h2 class="no-top-margin">Delete user</h2> <h2 class="no-top-margin">Delete user</h2>
Deleting the user will delete all your data stored in the service including the images. Deleting the user will delete all your data stored in the service including the images.
<fieldset> <fieldset>

View File

@ -10,7 +10,7 @@
import { post, rdelete } from 'src/lib/requests.svelte'; import { post, rdelete } from 'src/lib/requests.svelte';
import { userStore } from 'src/routes/UserStore.svelte'; import { userStore } from 'src/routes/UserStore.svelte';
import AddToken from './AddToken.svelte'; import AddToken from './AddToken.svelte';
let page = $state(0); let page = $state(0);
let showNext = $state(false); let showNext = $state(false);
@ -70,7 +70,7 @@
{new Date(token.create_date).toLocaleString()} {new Date(token.create_date).toLocaleString()}
</td> </td>
<td> <td>
<button class="danger" on:click={() => removeToken(token)}> <button class="danger" onclick={() => removeToken(token)}>
<span class="bi bi-trash"></span> <span class="bi bi-trash"></span>
</button> </button>
</td> </td>
@ -81,7 +81,7 @@
<div class="flex justify-center align-center"> <div class="flex justify-center align-center">
<div class="grow-1 flex justify-end align-center"> <div class="grow-1 flex justify-end align-center">
{#if page > 0} {#if page > 0}
<button on:click={() => (page -= 1)}> Prev </button> <button onclick={() => (page -= 1)}> Prev </button>
{/if} {/if}
</div> </div>
@ -91,25 +91,15 @@
<div class="grow-1 flex justify-start align-center"> <div class="grow-1 flex justify-start align-center">
{#if showNext} {#if showNext}
<button on:click={() => (page += 1)}> Next </button> <button onclick={() => (page += 1)}> Next </button>
{/if} {/if}
</div> </div>
</div> </div>
</div> </div>
<AddToken on:reload={getList} /> <AddToken onreload={getList} />
<style lang="scss"> <style lang="scss">
.buttons {
width: 100%;
display: flex;
justify-content: space-between;
& > button {
margin: 3px 5px;
}
}
table { table {
width: 100%; width: 100%;
box-shadow: 0 2px 8px 1px #66666622; box-shadow: 0 2px 8px 1px #66666622;
@ -132,10 +122,4 @@
table tr th:first-child { table tr th:first-child {
border-left: none; border-left: none;
} }
table tr td button,
table tr td .button {
padding: 5px 10px;
box-shadow: 0 2px 5px 1px #66666655;
}
</style> </style>

View File

@ -65,6 +65,7 @@ button.expander::after {
a.button { a.button {
text-decoration: none; text-decoration: none;
height: 1.6em;
} }
.flex { .flex {
@ -190,3 +191,12 @@ form.danger-bg {
background-color: var(--danger-transparent); background-color: var(--danger-transparent);
border: 1px solid var(--danger); border: 1px solid var(--danger);
} }
pre {
font-family: 'Fira Code';
background-color: #f6f8fa;
word-break: break-word;
white-space: pre-wrap;
border-radius: 10px;
padding: 10px;
}

View File

@ -12,10 +12,10 @@ const config = {
// If your environment is not supported or you settled on a specific environment, switch out the adapter. // If your environment is not supported or you settled on a specific environment, switch out the adapter.
// See https://kit.svelte.dev/docs/adapters for more information about adapters. // See https://kit.svelte.dev/docs/adapters for more information about adapters.
adapter: adapter(), adapter: adapter(),
alias: { alias: {
src: "src", src: 'src',
routes: "src/routes", routes: 'src/routes'
} }
} }
}; };

View File

@ -2,5 +2,10 @@ import { sveltekit } from '@sveltejs/kit/vite';
import { defineConfig } from 'vite'; import { defineConfig } from 'vite';
export default defineConfig({ export default defineConfig({
plugins: [sveltekit()] plugins: [sveltekit()],
build: {
commonjsOptions: {
esmExternals: true
}
}
}); });