Compare commits

..

2 Commits
main ... runner

Author SHA1 Message Date
b1e4211e6a more work on the rust runner 2024-05-06 12:48:02 +01:00
e22df8adc9 started working on runner 2024-05-06 01:10:58 +01:00
92 changed files with 4320 additions and 6464 deletions

View File

@ -1,6 +1,6 @@
# vi: ft=dockerfile # vi: ft=dockerfile
FROM docker.io/nginx FROM docker.io/nginx
ADD nginx.proxy.conf /nginx.conf ADD nginx.dev.conf /nginx.conf
CMD ["nginx", "-c", "/nginx.conf", "-g", "daemon off;"] CMD ["nginx", "-c", "/nginx.conf", "-g", "daemon off;"]

View File

@ -14,7 +14,7 @@ ENV PATH=$PATH:/usr/local/go/bin
ENV GOPATH=/go ENV GOPATH=/go
RUN bash -c 'curl -L "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.9.1.tar.gz" | tar -C /usr/local -xz' RUN bash -c 'curl -L "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.9.1.tar.gz" | tar -C /usr/local -xz'
RUN bash -c 'curl -L "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.15.0.tar.gz" | tar -C /usr/local -xz' # RUN bash -c 'curl -L "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.13.1.tar.gz" | tar -C /usr/local -xz'
RUN ldconfig RUN ldconfig
RUN ln -s /usr/bin/python3 /usr/bin/python RUN ln -s /usr/bin/python3 /usr/bin/python
@ -31,10 +31,7 @@ ADD go.mod .
ADD go.sum . ADD go.sum .
ADD main.go . ADD main.go .
ADD logic logic ADD logic logic
ADD entrypoint.sh .
RUN go install || true RUN go install || true
RUN go build . CMD ["go", "run", "."]
CMD ["./entrypoint.sh"]

View File

@ -1,42 +0,0 @@
# Configure the system
Go to the config.toml file and setup your hostname
# Build the containers
Running this commands on the root of the project will setup the nessesary.
Make sure that your docker/podman installation supports domain name resolution between containers
```bash
docker build -t andre-fyp-proxy -f DockerfileProxy
docker build -t andre-fyp-server -f DockerfileServer
cd webpage
docker build -t andre-fyp-web-server .
cd ..
```
# Run the docker compose
Running docker compose sets up the database server, the web page server, the proxy server and the main server
```bash
docker compose up
```
# Setup the Database
On another terminal instance create the database and tables.
Note: the password can be changed in the docker-compose file
```bash
PGPASSWORD=verysafepassword psql -h localhost -U postgres -f sql/base.sql
PGPASSWORD=verysafepassword psql -h localhost -U postgres -d fyp -f sql/user.sql
PGPASSWORD=verysafepassword psql -h localhost -U postgres -d fyp -f sql/models.sql
PGPASSWORD=verysafepassword psql -h localhost -U postgres -d fyp -f sql/tasks.sql
```
# Restart docker compose
Now restart docker compose and the system should be available under the domain name set up on the config.toml file

View File

@ -12,12 +12,7 @@ USER = "service"
[Worker] [Worker]
PULLING_TIME = "500ms" PULLING_TIME = "500ms"
NUMBER_OF_WORKERS = 16 NUMBER_OF_WORKERS = 1
[DB] [DB]
MAX_CONNECTIONS = 600 MAX_CONNECTIONS = 600
host = "db"
port = 5432
user = "postgres"
password = "verysafepassword"
dbname = "fyp"

View File

@ -1,44 +1,11 @@
version: "3.1"
services: services:
db: db:
image: docker.io/postgres:16.3 image: docker.andr3h3nriqu3s.com/services/postgres
command: -c 'max_connections=600' command: -c 'max_connections=600'
restart: always restart: always
networks:
- fyp-network
environment: environment:
POSTGRES_PASSWORD: verysafepassword POSTGRES_PASSWORD: verysafepassword
ports: ports:
- "5432:5432" - "5432:5432"
web-page:
image: andre-fyp-web-server
hostname: webpage
networks:
- fyp-network
server:
image: andre-fyp-server
hostname: server
networks:
- fyp-network
depends_on:
- db
volumes:
- "./config.toml:/app/config.toml"
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
proxy-server:
image: andre-fyp-proxy
networks:
- fyp-network
ports:
- "8000:8000"
depends_on:
- web-page
- server
networks:
fyp-network: {}

View File

@ -1,4 +0,0 @@
#/bin/bash
while true; do
./fyp
done

4
go.mod
View File

@ -9,11 +9,10 @@ require (
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/lib/pq v1.10.9 github.com/lib/pq v1.10.9
golang.org/x/crypto v0.19.0 golang.org/x/crypto v0.19.0
github.com/BurntSushi/toml v1.3.2
github.com/goccy/go-json v0.10.2
) )
require ( require (
github.com/BurntSushi/toml v1.3.2 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/charmbracelet/lipgloss v0.9.1 // indirect github.com/charmbracelet/lipgloss v0.9.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect
@ -21,6 +20,7 @@ require (
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.19.0 // indirect github.com/go-playground/validator/v10 v10.19.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx v3.6.2+incompatible // indirect github.com/jackc/pgx v3.6.2+incompatible // indirect

View File

@ -87,9 +87,9 @@ func (d Definition) GetLayers(db db.Db, filter string, args ...any) (layer []*La
return GetDbMultitple[Layer](db, "model_definition_layer as mdl where mdl.def_id=$1 "+filter, args...) return GetDbMultitple[Layer](db, "model_definition_layer as mdl where mdl.def_id=$1 "+filter, args...)
} }
func (d *Definition) UpdateAfterEpoch(db db.Db, accuracy float64, epoch int) (err error) { func (d *Definition) UpdateAfterEpoch(db db.Db, accuracy float64) (err error) {
d.Accuracy = accuracy d.Accuracy = accuracy
d.Epoch += epoch d.Epoch += 1
_, err = db.Exec("update model_definition set epoch=$1, accuracy=$2 where id=$3", d.Epoch, d.Accuracy, d.Id) _, err = db.Exec("update model_definition set epoch=$1, accuracy=$2 where id=$3", d.Epoch, d.Accuracy, d.Id)
return return
} }

View File

@ -2,10 +2,8 @@ package dbtypes
import ( import (
"encoding/json" "encoding/json"
"fmt"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db" "git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
"github.com/charmbracelet/log"
) )
type LayerType int type LayerType int
@ -20,28 +18,15 @@ const (
type Layer struct { type Layer struct {
Id string `db:"mdl.id" json:"id"` Id string `db:"mdl.id" json:"id"`
DefinitionId string `db:"mdl.def_id" json:"definition_id"` DefinitionId string `db:"mdl.def_id" json:"definition_id"`
LayerOrder int `db:"mdl.layer_order" json:"layer_order"` LayerOrder string `db:"mdl.layer_order" json:"layer_order"`
LayerType LayerType `db:"mdl.layer_type" json:"layer_type"` LayerType LayerType `db:"mdl.layer_type" json:"layer_type"`
Shape string `db:"mdl.shape" json:"shape"` Shape string `db:"mdl.shape" json:"shape"`
ExpType int `db:"mdl.exp_type" json:"exp_type"` ExpType string `db:"mdl.exp_type" json:"exp_type"`
}
func (x *Layer) ShapeToSize() {
v := x.GetShape()
switch x.LayerType {
case LAYER_INPUT:
x.Shape = fmt.Sprintf("%d,%d", v[1], v[2])
case LAYER_DENSE:
x.Shape = fmt.Sprintf("(%d)", v[0])
default:
x.Shape = "ERROR"
}
} }
func ShapeToString(args ...int) string { func ShapeToString(args ...int) string {
text, err := json.Marshal(args) text, err := json.Marshal(args)
if err != nil { if err != nil {
log.Error("json err!", "err", err)
panic("Could not generate Shape") panic("Could not generate Shape")
} }
return string(text) return string(text)
@ -50,16 +35,12 @@ func ShapeToString(args ...int) string {
func StringToShape(str string) (shape []int64) { func StringToShape(str string) (shape []int64) {
err := json.Unmarshal([]byte(str), &shape) err := json.Unmarshal([]byte(str), &shape)
if err != nil { if err != nil {
log.Error("json err!", "err", err)
panic("Could not parse Shape") panic("Could not parse Shape")
} }
return return
} }
func (l Layer) GetShape() []int64 { func (l Layer) GetShape() []int64 {
if l.Shape == "" {
return []int64{}
}
return StringToShape(l.Shape) return StringToShape(l.Shape)
} }

View File

@ -3,6 +3,7 @@ package dbtypes
import ( import (
"errors" "errors"
"fmt" "fmt"
"path"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db" "git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
) )
@ -50,16 +51,17 @@ const (
) )
type BaseModel struct { type BaseModel struct {
Name string `json:"name"` Name string
Status int `json:"status"` Status int
Id string `json:"id"` Id string
ModelType int `db:"model_type" json:"model_type"`
ImageModeRaw string `db:"color_mode" json:"image_more_raw"` ModelType int `db:"model_type"`
ImageMode int `db:"0" json:"image_mode"` ImageModeRaw string `db:"color_mode"`
Width int `json:"width"` ImageMode int `db:"0"`
Height int `json:"height"` Width int
Format string `json:"format"` Height int
CanTrain int `db:"can_train" json:"can_train"` Format string
CanTrain int `db:"can_train"`
} }
var ModelNotFoundError = errors.New("Model not found error") var ModelNotFoundError = errors.New("Model not found error")
@ -100,7 +102,6 @@ func (m *BaseModel) UpdateStatus(db db.Db, status ModelStatus) (err error) {
} }
type DataPoint struct { type DataPoint struct {
Id string `json:"id"`
Class int `json:"class"` Class int `json:"class"`
Path string `json:"path"` Path string `json:"path"`
} }
@ -125,11 +126,14 @@ func (m BaseModel) DataPoints(db db.Db, mode DATA_POINT_MODE) (data []DataPoint,
if err = rows.Scan(&id, &class_order, &file_path); err != nil { if err = rows.Scan(&id, &class_order, &file_path); err != nil {
return return
} }
if file_path == "id://" {
data = append(data, DataPoint{ data = append(data, DataPoint{
Id: id, Path: path.Join("./savedData", m.Id, "data", id+"."+m.Format),
Path: file_path,
Class: class_order, Class: class_order,
}) })
} else {
panic("TODO remote file path")
}
} }
return return
} }

View File

@ -14,19 +14,10 @@ const (
) )
type User struct { type User struct {
Id string `db:"u.id" json:"id"` Id string `db:"u.id"`
Username string `db:"u.username" json:"username"` Username string `db:"u.username"`
Email string `db:"u.email" json:"email"` Email string `db:"u.email"`
UserType int `db:"u.user_type" json:"user_type"` UserType int `db:"u.user_type"`
}
func UserFromId(db db.Db, id string) (*User, error) {
var user User
err := GetDBOnce(db, &user, "users as u where u.id=$1", id)
if err != nil {
return nil, err
}
return &user, nil
} }
func UserFromToken(db db.Db, token string) (*User, error) { func UserFromToken(db db.Db, token string) (*User, error) {

View File

@ -16,6 +16,7 @@ import (
) )
func loadBaseImage(c *Context, id string) { func loadBaseImage(c *Context, id string) {
// TODO handle more types than png
infile, err := os.Open(path.Join("savedData", id, "baseimage.png")) infile, err := os.Open(path.Join("savedData", id, "baseimage.png"))
if err != nil { if err != nil {
c.Logger.Errorf("Failed to read image for model with id %s\n", id) c.Logger.Errorf("Failed to read image for model with id %s\n", id)
@ -53,29 +54,21 @@ func loadBaseImage(c *Context, id string) {
model_color = "greyscale" model_color = "greyscale"
case color.NRGBAModel: case color.NRGBAModel:
fallthrough fallthrough
case color.RGBAModel:
fallthrough
case color.YCbCrModel: case color.YCbCrModel:
model_color = "rgb" model_color = "rgb"
default: default:
c.Logger.Error("Do not know how to handle this color model") c.Logger.Error("Do not know how to handle this color model")
if src.ColorModel() == color.RGBA64Model { if src.ColorModel() == color.RGBA64Model {
c.Logger.Error("Color is rgb 64") c.Logger.Error("Color is rgb")
} else if src.ColorModel() == color.NRGBA64Model { } else if src.ColorModel() == color.NRGBA64Model {
c.Logger.Error("Color is nrgb 64") c.Logger.Error("Color is nrgb 64")
} else if src.ColorModel() == color.AlphaModel { } else if src.ColorModel() == color.AlphaModel {
c.Logger.Error("Color is alpha") c.Logger.Error("Color is alpha")
} else if src.ColorModel() == color.CMYKModel { } else if src.ColorModel() == color.CMYKModel {
c.Logger.Error("Color is cmyk") c.Logger.Error("Color is cmyk")
} else if src.ColorModel() == color.NRGBA64Model {
c.Logger.Error("Color is cmyk")
} else if src.ColorModel() == color.NYCbCrAModel {
c.Logger.Error("Color is cmyk a")
} else if src.ColorModel() == color.Alpha16Model {
c.Logger.Error("Color is cmyk a")
} else { } else {
c.Logger.Error("Other so assuming color", "color mode", src.ColorModel()) c.Logger.Error("Other so assuming color")
} }
ModelUpdateStatus(c, id, FAILED_PREPARING) ModelUpdateStatus(c, id, FAILED_PREPARING)

View File

@ -136,16 +136,17 @@ func processZipFile(c *Context, model *BaseModel) {
return return
} }
if paths[0] == "training" { if paths[0] != "training" {
training = InsertIfNotPresent(training, paths[1]) training = InsertIfNotPresent(training, paths[1])
} else if paths[0] == "testing" { } else if paths[0] != "testing" {
testing = InsertIfNotPresent(testing, paths[1]) testing = InsertIfNotPresent(testing, paths[1])
} }
} }
if !reflect.DeepEqual(testing, training) { if !reflect.DeepEqual(testing, training) {
c.Logger.Warn("Diff", "testing", testing, "training", training) c.Logger.Info("Diff", "testing", testing, "training", training)
c.Logger.Warn("Testing and traing datasets differ") failed("Testing and Training datesets are diferent")
return
} }
base_path := path.Join("savedData", model.Id, "data") base_path := path.Join("savedData", model.Id, "data")
@ -265,15 +266,16 @@ func processZipFileExpand(c *Context, model *BaseModel) {
return return
} }
if paths[0] == "training" { if paths[0] != "training" {
training = InsertIfNotPresent(training, paths[1]) training = InsertIfNotPresent(training, paths[1])
} else if paths[0] == "testing" { } else if paths[0] != "testing" {
testing = InsertIfNotPresent(testing, paths[1]) testing = InsertIfNotPresent(testing, paths[1])
} }
} }
if !reflect.DeepEqual(testing, training) { if !reflect.DeepEqual(testing, training) {
c.GetLogger().Warn("testing and training differ", "testing", testing, "training", training) failed("testing and training are diferent")
return
} }
base_path := path.Join("savedData", model.Id, "data") base_path := path.Join("savedData", model.Id, "data")
@ -634,8 +636,7 @@ func handleDataUpload(handle *Handle) {
// TODO work in allowing the model to add new in the pre ready moment // TODO work in allowing the model to add new in the pre ready moment
if model.Status != READY { if model.Status != READY {
c.GetLogger().Error("Model not in the ready status", "status", model.Status) return c.JsonBadRequest("Model not in the correct state to add a more classes")
return c.JsonBadRequest("Model not in the correct state to add more classes")
} }
// TODO mk this path configurable // TODO mk this path configurable

View File

@ -4,7 +4,6 @@ import (
"errors" "errors"
"os" "os"
"path" "path"
"runtime/debug"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
@ -38,19 +37,11 @@ func ReadJPG(scope *op.Scope, imagePath string, channels int64) *image.Image {
return image.Scale(0, 255) return image.Scale(0, 255)
} }
func runModelNormal(model *BaseModel, def_id string, inputImage *tf.Tensor, data *RunnerModelData) (order int, confidence float32, err error) { func runModelNormal(base BasePack, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, confidence float32, err error) {
order = 0 order = 0
err = nil err = nil
var tf_model *tg.Model = nil tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
if data.Id != nil && *data.Id == def_id {
tf_model = data.Model
} else {
tf_model = tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
data.Model = tf_model
data.Id = &def_id
}
results := tf_model.Exec([]tf.Output{ results := tf_model.Exec([]tf.Output{
tf_model.Op("StatefulPartitionedCall", 0), tf_model.Op("StatefulPartitionedCall", 0),
@ -134,15 +125,10 @@ func runModelExp(base BasePack, model *BaseModel, def_id string, inputImage *tf.
return return
} }
type RunnerModelData struct { func ClassifyTask(base BasePack, task Task) (err error) {
Id *string
Model *tg.Model
}
func ClassifyTask(base BasePack, task Task, data *RunnerModelData) (err error) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
base.GetLogger().Error("Task failed due to", "error", r, "stack", string(debug.Stack())) base.GetLogger().Error("Task failed due to", "error", r)
task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Task failed running") task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Task failed running")
} }
}() }()
@ -200,8 +186,6 @@ func ClassifyTask(base BasePack, task Task, data *RunnerModelData) (err error) {
if model.ModelType == 2 { if model.ModelType == 2 {
base.GetLogger().Info("Running model normal", "model", model.Id, "def", def_id) base.GetLogger().Info("Running model normal", "model", model.Id, "def", def_id)
data.Model = nil
data.Id = nil
vi, confidence, err = runModelExp(base, model, def_id, inputImage) vi, confidence, err = runModelExp(base, model, def_id, inputImage)
if err != nil { if err != nil {
task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to run model") task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to run model")
@ -209,7 +193,7 @@ func ClassifyTask(base BasePack, task Task, data *RunnerModelData) (err error) {
} }
} else { } else {
base.GetLogger().Info("Running model normal", "model", model.Id, "def", def_id) base.GetLogger().Info("Running model normal", "model", model.Id, "def", def_id)
vi, confidence, err = runModelNormal(model, def_id, inputImage, data) vi, confidence, err = runModelNormal(base, model, def_id, inputImage)
if err != nil { if err != nil {
task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to run model") task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to run model")
return return

View File

@ -38,8 +38,6 @@ func TestImgForModel(c *Context, model *BaseModel, path string) (result bool) {
model_color = "greyscale" model_color = "greyscale"
case color.NRGBAModel: case color.NRGBAModel:
fallthrough fallthrough
case color.RGBAModel:
fallthrough
case color.YCbCrModel: case color.YCbCrModel:
model_color = "rgb" model_color = "rgb"
default: default:

View File

@ -6,7 +6,6 @@ import (
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
"github.com/charmbracelet/log"
"github.com/goccy/go-json" "github.com/goccy/go-json"
) )
@ -40,6 +39,7 @@ func PrepareTraining(handler *Handle, b BasePack, task Task, runner_id string) (
} }
if model.ModelType == 2 { if model.ModelType == 2 {
panic("TODO")
full_error := generateExpandableDefinitions(b, model, dat.Accuracy, dat.NumberOfModels) full_error := generateExpandableDefinitions(b, model, dat.Accuracy, dat.NumberOfModels)
if full_error != nil { if full_error != nil {
l.Error("Failed to generate defintions", "err", full_error) l.Error("Failed to generate defintions", "err", full_error)
@ -57,6 +57,7 @@ func PrepareTraining(handler *Handle, b BasePack, task Task, runner_id string) (
runners := handler.DataMap["runners"].(map[string]interface{}) runners := handler.DataMap["runners"].(map[string]interface{})
runner := runners[runner_id].(map[string]interface{}) runner := runners[runner_id].(map[string]interface{})
runner["task"] = &task runner["task"] = &task
runners[runner_id] = runner runners[runner_id] = runner
handler.DataMap["runners"] = runners handler.DataMap["runners"] = runners
@ -75,44 +76,4 @@ func CleanUpFailed(b BasePack, task *Task) {
l.Error("Failed to get status", err) l.Error("Failed to get status", err)
} }
} }
// Set the class status to trained
err = SetModelClassStatus(b, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
if err != nil {
l.Error("Failed to set class status")
return
}
}
func CleanUpFailedRetrain(b BasePack, task *Task) {
db := b.GetDb()
l := b.GetLogger()
model, err := GetBaseModel(db, *task.ModelId)
if err != nil {
l.Error("Failed to get model", "err", err)
} else {
err = model.UpdateStatus(db, FAILED_TRAINING)
if err != nil {
l.Error("Failed to get status", err)
}
}
ResetClasses(b, model)
ModelUpdateStatus(b, model.Id, READY_RETRAIN_FAILED)
var defData struct {
Id string `db:"md.id"`
TargetAcuuracy float64 `db:"md.target_accuracy"`
}
err = GetDBOnce(db, &defData, "models as m inner join model_definition as md on m.id = md.model_id where m.id=$1;", task.ModelId)
if err != nil {
log.Error("failed to get def data", err)
return
}
_, err_ := db.Exec("delete from exp_model_head where def_id=$1 and status in (2,3)", defData.Id)
if err_ != nil {
panic(err_)
}
} }

View File

@ -101,10 +101,6 @@ func setModelClassStatus(c BasePack, status ModelClassStatus, filter string, arg
return return
} }
func SetModelClassStatus(c BasePack, status ModelClassStatus, filter string, args ...any) (err error) {
return setModelClassStatus(c, status, filter, args...)
}
func generateCvsExp(c BasePack, run_path string, model_id string, doPanic bool) (count int, err error) { func generateCvsExp(c BasePack, run_path string, model_id string, doPanic bool) (count int, err error) {
db := c.GetDb() db := c.GetDb()
@ -161,23 +157,40 @@ func generateCvsExp(c BasePack, run_path string, model_id string, doPanic bool)
return return
} }
func trainDefinition(c BasePack, model *BaseModel, def Definition, load_prev bool) (accuracy float64, err error) { func trainDefinition(c BasePack, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
l := c.GetLogger() l := c.GetLogger()
db := c.GetDb()
l.Warn("About to start training definition") l.Warn("About to start training definition")
accuracy = 0 accuracy = 0
layers, err := db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
layers, err := def.GetLayers(c.GetDb(), " order by layer_order asc;")
if err != nil { if err != nil {
return return
} }
defer layers.Close()
for _, layer := range layers { type layerrow struct {
layer.ShapeToSize() LayerType int
Shape string
LayerNum int
}
got := []layerrow{}
i := 1
for layers.Next() {
var row = layerrow{}
if err = layers.Scan(&row.LayerType, &row.Shape); err != nil {
return
}
row.Shape = shapeToSize(row.Shape)
row.LayerNum = 1
got = append(got, row)
i = i + 1
} }
// Generate run folder // Generate run folder
run_path := path.Join("/tmp", model.Id, "defs", def.Id) run_path := path.Join("/tmp", model.Id, "defs", definition_id)
err = os.MkdirAll(run_path, os.ModePerm) err = os.MkdirAll(run_path, os.ModePerm)
if err != nil { if err != nil {
@ -202,17 +215,17 @@ func trainDefinition(c BasePack, model *BaseModel, def Definition, load_prev boo
} }
// Copy result around // Copy result around
result_path := path.Join("savedData", model.Id, "defs", def.Id) result_path := path.Join("savedData", model.Id, "defs", definition_id)
if err = tmpl.Execute(f, AnyMap{ if err = tmpl.Execute(f, AnyMap{
"Layers": layers, "Layers": got,
"Size": layers[0].Shape, "Size": got[0].Shape,
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"), "DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
"RunPath": run_path, "RunPath": run_path,
"ColorMode": model.ImageMode, "ColorMode": model.ImageMode,
"Model": model, "Model": model,
"EPOCH_PER_RUN": EPOCH_PER_RUN, "EPOCH_PER_RUN": EPOCH_PER_RUN,
"DefId": def.Id, "DefId": definition_id,
"LoadPrev": load_prev, "LoadPrev": load_prev,
"LastModelRunPath": path.Join(getDir(), result_path, "model.keras"), "LastModelRunPath": path.Join(getDir(), result_path, "model.keras"),
"SaveModelPath": path.Join(getDir(), result_path), "SaveModelPath": path.Join(getDir(), result_path),
@ -339,7 +352,7 @@ func generateCvsExpandExp(c BasePack, run_path string, model_id string, offset i
return return
} }
func trainDefinitionExpandExp(c BasePack, model *BaseModel, def Definition, load_prev bool) (accuracy float64, err error) { func trainDefinitionExpandExp(c BasePack, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
accuracy = 0 accuracy = 0
l := c.GetLogger() l := c.GetLogger()
@ -354,7 +367,7 @@ func trainDefinitionExpandExp(c BasePack, model *BaseModel, def Definition, load
} }
// status = 2 (INIT) 3 (TRAINING) // status = 2 (INIT) 3 (TRAINING)
heads, err := GetDbMultitple[ExpHead](c.GetDb(), "exp_model_head where def_id=$1 and (status = 2 or status = 3)", def.Id) heads, err := GetDbMultitple[ExpHead](c.GetDb(), "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
if err != nil { if err != nil {
return return
} else if len(heads) == 0 { } else if len(heads) == 0 {
@ -373,49 +386,62 @@ func trainDefinitionExpandExp(c BasePack, model *BaseModel, def Definition, load
return return
} }
layers, err := def.GetLayers(c.GetDb(), " order by layer_order asc;") layers, err := c.GetDb().Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
if err != nil { if err != nil {
return return
} }
defer layers.Close()
var got []*Layer type layerrow struct {
LayerType int
Shape string
ExpType int
LayerNum int
}
got := []layerrow{}
i := 1 i := 1
var last *Layer = nil var last *layerrow = nil
got_2 := false got_2 := false
var first *Layer = nil var first *layerrow = nil
for _, layer := range layers { for layers.Next() {
layer.ShapeToSize() var row = layerrow{}
if err = layers.Scan(&row.LayerType, &row.Shape, &row.ExpType); err != nil {
return
}
// Keep track of the first layer so we can keep the size of the image // Keep track of the first layer so we can keep the size of the image
if first == nil { if first == nil {
first = layer first = &row
} }
if layer.ExpType == 2 { row.LayerNum = i
row.Shape = shapeToSize(row.Shape)
if row.ExpType == 2 {
if !got_2 { if !got_2 {
got = append(got, last) got = append(got, *last)
got_2 = true got_2 = true
} }
got = append(got, layer) got = append(got, row)
} }
last = layer last = &row
i += 1 i += 1
} }
got = append(got, &Layer{ got = append(got, layerrow{
LayerType: LAYER_DENSE, LayerType: LAYER_DENSE,
Shape: fmt.Sprintf("%d", exp.End-exp.Start+1), Shape: fmt.Sprintf("%d", exp.End-exp.Start+1),
ExpType: 2, ExpType: 2,
LayerOrder: len(got), LayerNum: i,
}) })
l.Info("Got layers", "layers", got) l.Info("Got layers", "layers", got)
// Generate run folder // Generate run folder
run_path := path.Join("/tmp", model.Id+"-defs-"+def.Id+"-retrain") run_path := path.Join("/tmp", model.Id+"-defs-"+definition_id+"-retrain")
err = os.MkdirAll(run_path, os.ModePerm) err = os.MkdirAll(run_path, os.ModePerm)
if err != nil { if err != nil {
@ -446,7 +472,7 @@ func trainDefinitionExpandExp(c BasePack, model *BaseModel, def Definition, load
} }
// Copy result around // Copy result around
result_path := path.Join("savedData", model.Id, "defs", def.Id) result_path := path.Join("savedData", model.Id, "defs", definition_id)
if err = tmpl.Execute(f, AnyMap{ if err = tmpl.Execute(f, AnyMap{
"Layers": got, "Layers": got,
@ -502,7 +528,7 @@ func trainDefinitionExpandExp(c BasePack, model *BaseModel, def Definition, load
return return
} }
func trainDefinitionExp(c BasePack, model *BaseModel, def Definition, load_prev bool) (accuracy float64, err error) { func trainDefinitionExp(c BasePack, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
accuracy = 0 accuracy = 0
l := c.GetLogger() l := c.GetLogger()
db := c.GetDb() db := c.GetDb()
@ -518,7 +544,7 @@ func trainDefinitionExp(c BasePack, model *BaseModel, def Definition, load_prev
} }
// status = 2 (INIT) 3 (TRAINING) // status = 2 (INIT) 3 (TRAINING)
heads, err := GetDbMultitple[ExpHead](db, "exp_model_head where def_id=$1 and (status = 2 or status = 3)", def.Id) heads, err := GetDbMultitple[ExpHead](db, "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
if err != nil { if err != nil {
return return
} else if len(heads) == 0 { } else if len(heads) == 0 {
@ -536,24 +562,42 @@ func trainDefinitionExp(c BasePack, model *BaseModel, def Definition, load_prev
return return
} }
layers, err := def.GetLayers(db, " order by layer_order asc;") layers, err := db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
if err != nil { if err != nil {
return return
} }
defer layers.Close()
for _, layer := range layers { type layerrow struct {
layer.ShapeToSize() LayerType int
Shape string
ExpType int
LayerNum int
} }
layers = append(layers, &Layer{ got := []layerrow{}
i := 1
for layers.Next() {
var row = layerrow{}
if err = layers.Scan(&row.LayerType, &row.Shape, &row.ExpType); err != nil {
return
}
row.LayerNum = i
row.Shape = shapeToSize(row.Shape)
got = append(got, row)
i += 1
}
got = append(got, layerrow{
LayerType: LAYER_DENSE, LayerType: LAYER_DENSE,
Shape: fmt.Sprintf("%d", exp.End-exp.Start+1), Shape: fmt.Sprintf("%d", exp.End-exp.Start+1),
ExpType: 2, ExpType: 2,
LayerOrder: len(layers), LayerNum: i,
}) })
// Generate run folder // Generate run folder
run_path := path.Join("/tmp", model.Id+"-defs-"+def.Id) run_path := path.Join("/tmp", model.Id+"-defs-"+definition_id)
err = os.MkdirAll(run_path, os.ModePerm) err = os.MkdirAll(run_path, os.ModePerm)
if err != nil { if err != nil {
@ -580,11 +624,11 @@ func trainDefinitionExp(c BasePack, model *BaseModel, def Definition, load_prev
} }
// Copy result around // Copy result around
result_path := path.Join("savedData", model.Id, "defs", def.Id) result_path := path.Join("savedData", model.Id, "defs", definition_id)
if err = tmpl.Execute(f, AnyMap{ if err = tmpl.Execute(f, AnyMap{
"Layers": layers, "Layers": got,
"Size": layers[0].Shape, "Size": got[0].Shape,
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"), "DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
"HeadId": exp.Id, "HeadId": exp.Id,
"RunPath": run_path, "RunPath": run_path,
@ -652,6 +696,21 @@ func remove[T interface{}](lst []T, i int) []T {
return append(lst[:i], lst[i+1:]...) return append(lst[:i], lst[i+1:]...)
} }
type TrainModelRow struct {
id string
target_accuracy int
epoch int
acuracy float64
}
type TraingModelRowDefinitions []TrainModelRow
func (nf TraingModelRowDefinitions) Len() int { return len(nf) }
func (nf TraingModelRowDefinitions) Swap(i, j int) { nf[i], nf[j] = nf[j], nf[i] }
func (nf TraingModelRowDefinitions) Less(i, j int) bool {
return nf[i].acuracy < nf[j].acuracy
}
type ToRemoveList []int type ToRemoveList []int
func (nf ToRemoveList) Len() int { return len(nf) } func (nf ToRemoveList) Len() int { return len(nf) }
@ -664,16 +723,30 @@ func trainModel(c BasePack, model *BaseModel) (err error) {
db := c.GetDb() db := c.GetDb()
l := c.GetLogger() l := c.GetLogger()
defs_, err := model.GetDefinitions(db, "and md.status=$2", MODEL_DEFINITION_STATUS_INIT) definitionsRows, err := db.Query("select id, target_accuracy, epoch from model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id)
if err != nil { if err != nil {
l.Error("Failed to train Model!", "err", err) l.Error("Failed to train Model! Err:")
l.Error(err)
ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING)) ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING))
return return
} }
defer definitionsRows.Close()
var defs SortByAccuracyDefinitions = defs_ var definitions TraingModelRowDefinitions = []TrainModelRow{}
if len(defs) == 0 { for definitionsRows.Next() {
var rowv TrainModelRow
rowv.acuracy = 0
if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy, &rowv.epoch); err != nil {
l.Error("Failed to train Model Could not read definition from db!Err:")
l.Error(err)
ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING))
return
}
definitions = append(definitions, rowv)
}
if len(definitions) == 0 {
l.Error("No Definitions defined!") l.Error("No Definitions defined!")
ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING)) ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING))
return return
@ -684,29 +757,32 @@ func trainModel(c BasePack, model *BaseModel) (err error) {
for { for {
var toRemove ToRemoveList = []int{} var toRemove ToRemoveList = []int{}
for i, def := range defs { for i, def := range definitions {
ModelDefinitionUpdateStatus(c, def.Id, MODEL_DEFINITION_STATUS_TRAINING) ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING)
accuracy, err := trainDefinition(c, model, *def, !firstRound) accuracy, err := trainDefinition(c, model, def.id, !firstRound)
if err != nil { if err != nil {
l.Error("Failed to train definition!Err:", "err", err) l.Error("Failed to train definition!Err:", "err", err)
ModelDefinitionUpdateStatus(c, def.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
toRemove = append(toRemove, i) toRemove = append(toRemove, i)
continue continue
} }
def.Epoch += EPOCH_PER_RUN def.epoch += EPOCH_PER_RUN
accuracy = accuracy * 100 accuracy = accuracy * 100
def.Accuracy = float64(accuracy) def.acuracy = float64(accuracy)
if accuracy >= float64(def.TargetAccuracy) { definitions[i].epoch += EPOCH_PER_RUN
definitions[i].acuracy = accuracy
if accuracy >= float64(def.target_accuracy) {
l.Info("Found a definition that reaches target_accuracy!") l.Info("Found a definition that reaches target_accuracy!")
_, err = db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.Epoch, def.Id) _, err = db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.epoch, def.id)
if err != nil { if err != nil {
l.Error("Failed to train definition!Err:\n", "err", err) l.Error("Failed to train definition!Err:\n", "err", err)
ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING)) ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING))
return err return err
} }
_, err = db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", MODEL_DEFINITION_STATUS_CANCELD_TRAINING, def.Id, model.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) _, err = db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", MODEL_DEFINITION_STATUS_CANCELD_TRAINING, def.id, model.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
if err != nil { if err != nil {
l.Error("Failed to train definition!Err:\n", "err", err) l.Error("Failed to train definition!Err:\n", "err", err)
ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING)) ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING))
@ -717,14 +793,14 @@ func trainModel(c BasePack, model *BaseModel) (err error) {
break break
} }
if def.Epoch > MAX_EPOCH { if def.epoch > MAX_EPOCH {
fmt.Printf("Failed to train definition! Accuracy less %f < %d\n", accuracy, def.TargetAccuracy) fmt.Printf("Failed to train definition! Accuracy less %f < %d\n", accuracy, def.target_accuracy)
ModelDefinitionUpdateStatus(c, def.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
toRemove = append(toRemove, i) toRemove = append(toRemove, i)
continue continue
} }
_, err = db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.Epoch, MODEL_DEFINITION_STATUS_PAUSED_TRAINING, def.Id) _, err = db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.epoch, MODEL_DEFINITION_STATUS_PAUSED_TRAINING, def.id)
if err != nil { if err != nil {
l.Error("Failed to train definition!Err:\n", "err", err) l.Error("Failed to train definition!Err:\n", "err", err)
ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING)) ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING))
@ -742,26 +818,28 @@ func trainModel(c BasePack, model *BaseModel) (err error) {
l.Info("Round done", "toRemove", toRemove) l.Info("Round done", "toRemove", toRemove)
for _, n := range toRemove { for _, n := range toRemove {
defs = remove(defs, n) definitions = remove(definitions, n)
} }
len_def := len(defs) len_def := len(definitions)
if len_def == 0 { if len_def == 0 {
break break
} else if len_def == 1 { }
if len_def == 1 {
continue continue
} }
sort.Sort(sort.Reverse(defs)) sort.Sort(sort.Reverse(definitions))
acc := defs[0].Accuracy - 20.0 acc := definitions[0].acuracy - 20.0
l.Info("Training models, Highest acc", "acc", defs[0].Accuracy, "mod_acc", acc) l.Info("Training models, Highest acc", "acc", definitions[0].acuracy, "mod_acc", acc)
toRemove = []int{} toRemove = []int{}
for i, def := range defs { for i, def := range definitions {
if def.Accuracy < acc { if def.acuracy < acc {
toRemove = append(toRemove, i) toRemove = append(toRemove, i)
} }
} }
@ -771,8 +849,8 @@ func trainModel(c BasePack, model *BaseModel) (err error) {
sort.Sort(sort.Reverse(toRemove)) sort.Sort(sort.Reverse(toRemove))
for _, n := range toRemove { for _, n := range toRemove {
l.Warn("Removing definition not fast enough learning", "n", n) l.Warn("Removing definition not fast enough learning", "n", n)
ModelDefinitionUpdateStatus(c, defs[n].Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) ModelDefinitionUpdateStatus(c, definitions[n].id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
defs = remove(defs, n) definitions = remove(definitions, n)
} }
} }
@ -840,18 +918,33 @@ func trainModel(c BasePack, model *BaseModel) (err error) {
return return
} }
type TrainModelRowUsable struct {
Id string
TargetAccuracy int `db:"target_accuracy"`
Epoch int
Acuracy float64 `db:"0"`
}
type TrainModelRowUsables []*TrainModelRowUsable
func (nf TrainModelRowUsables) Len() int { return len(nf) }
func (nf TrainModelRowUsables) Swap(i, j int) { nf[i], nf[j] = nf[j], nf[i] }
func (nf TrainModelRowUsables) Less(i, j int) bool {
return nf[i].Acuracy < nf[j].Acuracy
}
func trainModelExp(c BasePack, model *BaseModel) (err error) { func trainModelExp(c BasePack, model *BaseModel) (err error) {
l := c.GetLogger() l := c.GetLogger()
db := c.GetDb() db := c.GetDb()
defs_, err := model.GetDefinitions(db, " and status=$2;", MODEL_DEFINITION_STATUS_INIT) var definitions TrainModelRowUsables
definitions, err = GetDbMultitple[TrainModelRowUsable](db, "model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id)
if err != nil { if err != nil {
l.Error("Failed to get definitions") l.Error("Failed to get definitions")
return return
} }
var defs SortByAccuracyDefinitions = defs_ if len(definitions) == 0 {
if len(defs) == 0 {
l.Error("No Definitions defined!") l.Error("No Definitions defined!")
return errors.New("No Definitions found") return errors.New("No Definitions found")
} }
@ -861,9 +954,9 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
for { for {
var toRemove ToRemoveList = []int{} var toRemove ToRemoveList = []int{}
for i, def := range defs { for i, def := range definitions {
ModelDefinitionUpdateStatus(c, def.Id, MODEL_DEFINITION_STATUS_TRAINING) ModelDefinitionUpdateStatus(c, def.Id, MODEL_DEFINITION_STATUS_TRAINING)
accuracy, err := trainDefinitionExp(c, model, *def, !firstRound) accuracy, err := trainDefinitionExp(c, model, def.Id, !firstRound)
if err != nil { if err != nil {
l.Error("Failed to train definition!Err:", "err", err) l.Error("Failed to train definition!Err:", "err", err)
ModelDefinitionUpdateStatus(c, def.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) ModelDefinitionUpdateStatus(c, def.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
@ -872,10 +965,10 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
} }
def.Epoch += EPOCH_PER_RUN def.Epoch += EPOCH_PER_RUN
accuracy = accuracy * 100 accuracy = accuracy * 100
def.Accuracy = float64(accuracy) def.Acuracy = float64(accuracy)
defs[i].Epoch += EPOCH_PER_RUN definitions[i].Epoch += EPOCH_PER_RUN
defs[i].Accuracy = accuracy definitions[i].Acuracy = accuracy
if accuracy >= float64(def.TargetAccuracy) { if accuracy >= float64(def.TargetAccuracy) {
l.Info("Found a definition that reaches target_accuracy!") l.Info("Found a definition that reaches target_accuracy!")
@ -925,10 +1018,10 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
l.Info("Round done", "toRemove", toRemove) l.Info("Round done", "toRemove", toRemove)
for _, n := range toRemove { for _, n := range toRemove {
defs = remove(defs, n) definitions = remove(definitions, n)
} }
len_def := len(defs) len_def := len(definitions)
if len_def == 0 { if len_def == 0 {
break break
@ -936,14 +1029,14 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
continue continue
} }
sort.Sort(sort.Reverse(defs)) sort.Sort(sort.Reverse(definitions))
acc := defs[0].Accuracy - 20.0 acc := definitions[0].Acuracy - 20.0
l.Info("Training models, Highest acc", "acc", defs[0].Accuracy, "mod_acc", acc) l.Info("Training models, Highest acc", "acc", definitions[0].Acuracy, "mod_acc", acc)
toRemove = []int{} toRemove = []int{}
for i, def := range defs { for i, def := range definitions {
if def.Accuracy < acc { if def.Acuracy < acc {
toRemove = append(toRemove, i) toRemove = append(toRemove, i)
} }
} }
@ -953,8 +1046,8 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
sort.Sort(sort.Reverse(toRemove)) sort.Sort(sort.Reverse(toRemove))
for _, n := range toRemove { for _, n := range toRemove {
l.Warn("Removing definition not fast enough learning", "n", n) l.Warn("Removing definition not fast enough learning", "n", n)
ModelDefinitionUpdateStatus(c, defs[n].Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) ModelDefinitionUpdateStatus(c, definitions[n].Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
defs = remove(defs, n) definitions = remove(definitions, n)
} }
} }
@ -969,12 +1062,6 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
return err return err
} }
err = model.UpdateStatus(db, FAILED_TRAINING)
if err != nil {
l.Error("All definitions failed to train! And Failed to set model status")
return err
}
l.Error("All definitions failed to train!") l.Error("All definitions failed to train!")
return err return err
} else if err != nil { } else if err != nil {
@ -1003,7 +1090,7 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
return err return err
} }
if err = SplitModel(c, model); err != nil { if err = splitModel(c, model); err != nil {
err = setModelClassStatus(c, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING) err = setModelClassStatus(c, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
if err != nil { if err != nil {
l.Error("Failed to split the model! And Failed to set class status") l.Error("Failed to split the model! And Failed to set class status")
@ -1036,7 +1123,7 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
return return
} }
func SplitModel(c BasePack, model *BaseModel) (err error) { func splitModel(c BasePack, model *BaseModel) (err error) {
db := c.GetDb() db := c.GetDb()
l := c.GetLogger() l := c.GetLogger()
@ -1173,6 +1260,7 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe
} }
db := c.GetDb() db := c.GetDb()
l := c.GetLogger()
def, err := MakeDefenition(db, model.Id, target_accuracy) def, err := MakeDefenition(db, model.Id, target_accuracy)
if err != nil { if err != nil {
@ -1191,7 +1279,34 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe
} }
order++ order++
loop := max(1, int(math.Ceil((math.Log(float64(model.Width))/math.Log(float64(10)))))+1) if complexity == 0 {
/*
_, err = def.MakeLayer(db, order, LAYER_SIMPLE_BLOCK, "")
if err != nil {
failed()
return
}
order++
*/
_, err = def.MakeLayer(db, order, LAYER_FLATTEN, "")
if err != nil {
failed()
return
}
order++
loop := int(math.Log2(float64(number_of_classes)))
for i := 0; i < loop; i++ {
_, err = def.MakeLayer(db, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i)))
order++
if err != nil {
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
return
}
}
} else if complexity == 1 || complexity == 2 {
loop := max(1, int((math.Log(float64(model.Width)) / math.Log(float64(10)))))
for i := 0; i < loop; i++ { for i := 0; i < loop; i++ {
_, err = def.MakeLayer(db, order, LAYER_SIMPLE_BLOCK, "") _, err = def.MakeLayer(db, order, LAYER_SIMPLE_BLOCK, "")
order++ order++
@ -1220,6 +1335,11 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe
return return
} }
} }
} else {
l.Error("Unkown complexity", "complexity", complexity)
failed()
return
}
return def.UpdateStatus(db, DEFINITION_STATUS_INIT) return def.UpdateStatus(db, DEFINITION_STATUS_INIT)
} }
@ -1290,16 +1410,29 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
order := 1 order := 1
err = MakeLayerExpandable(c.GetDb(), def_id, order, LAYER_INPUT, ShapeToString(3, model.Width, model.Height), 1) width := model.Width
height := model.Height
// Note the shape of the first layer defines the import size
if complexity == 2 {
// Note the shape for now is no used
width := int(math.Pow(2, math.Floor(math.Log(float64(model.Width))/math.Log(2.0))))
height := int(math.Pow(2, math.Floor(math.Log(float64(model.Height))/math.Log(2.0))))
l.Warn("Complexity 2 creating model with smaller size", "width", width, "height", height)
}
err = MakeLayerExpandable(c.GetDb(), def_id, order, LAYER_INPUT, fmt.Sprintf("%d,%d,1", width, height), 1)
order++
// handle the errors inside the pervious if block
if err != nil { if err != nil {
failed() failed()
return return
} }
order++
// Create the blocks // Create the blocks
loop := int(math.Ceil((math.Log(float64(model.Width)) / math.Log(float64(10))))) + 1 loop := int((math.Log(float64(model.Width)) / math.Log(float64(10))))
/*if model.Width < 50 && model.Height < 50 { /*if model.Width < 50 && model.Height < 50 {
loop = 0 loop = 0
@ -1307,7 +1440,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
log.Info("Size of the simple block", "loop", loop) log.Info("Size of the simple block", "loop", loop)
loop = max(loop, min(2, model.ImageMode)) //loop = max(loop, 3)
for i := 0; i < loop; i++ { for i := 0; i < loop; i++ {
err = MakeLayerExpandable(db, def_id, order, LAYER_SIMPLE_BLOCK, "", 1) err = MakeLayerExpandable(db, def_id, order, LAYER_SIMPLE_BLOCK, "", 1)
@ -1327,7 +1460,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
order++ order++
// Flatten the blocks into dense // Flatten the blocks into dense
err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, ShapeToString(number_of_classes*2), 1) err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*2), 1)
if err != nil { if err != nil {
failed() failed()
return return
@ -1341,7 +1474,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
loop = max(loop, 3) loop = max(loop, 3)
for i := 0; i < loop; i++ { for i := 0; i < loop; i++ {
err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i)*2), 2) err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)*2), 2)
order++ order++
if err != nil { if err != nil {
failed() failed()
@ -1424,31 +1557,31 @@ func trainExpandable(c *Context, model *BaseModel) {
ResetClasses(c, model) ResetClasses(c, model)
} }
defs_, err := model.GetDefinitions(c, " and status=$2", MODEL_DEFINITION_STATUS_READY) var definitions TrainModelRowUsables
definitions, err = GetDbMultitple[TrainModelRowUsable](c, "model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id)
if err != nil { if err != nil {
failed("Failed to get definitions") failed("Failed to get definitions")
return return
} }
var defs SortByAccuracyDefinitions = defs_ if len(definitions) != 1 {
if len(defs) != 1 {
failed("There should only be one definition available!") failed("There should only be one definition available!")
return return
} }
firstRound := true firstRound := true
def := defs[0] def := definitions[0]
epoch := 0 epoch := 0
for { for {
acc, err := trainDefinitionExp(c, model, *def, !firstRound) acc, err := trainDefinitionExp(c, model, def.Id, !firstRound)
if err != nil { if err != nil {
failed("Failed to train definition!") failed("Failed to train definition!")
return return
} }
epoch += EPOCH_PER_RUN epoch += EPOCH_PER_RUN
if float64(acc*100) >= float64(def.Accuracy) { if float64(acc*100) >= float64(def.Acuracy) {
c.Logger.Info("Found a definition that reaches target_accuracy!") c.Logger.Info("Found a definition that reaches target_accuracy!")
_, err = c.Db.Exec("update exp_model_head set status=$1 where def_id=$2 and status=$3;", MODEL_HEAD_STATUS_READY, def.Id, MODEL_HEAD_STATUS_TRAINING) _, err = c.Db.Exec("update exp_model_head set status=$1 where def_id=$2 and status=$3;", MODEL_HEAD_STATUS_READY, def.Id, MODEL_HEAD_STATUS_TRAINING)
@ -1553,18 +1686,22 @@ func RunTaskRetrain(b BasePack, task Task) (err error) {
task.UpdateStatusLog(b, TASK_RUNNING, "Model retraining") task.UpdateStatusLog(b, TASK_RUNNING, "Model retraining")
defs, err := model.GetDefinitions(db, "") var defData struct {
Id string `db:"md.id"`
TargetAcuuracy float64 `db:"md.target_accuracy"`
}
err = GetDBOnce(db, &defData, "models as m inner join model_definition as md on m.id = md.model_id where m.id=$1;", task.ModelId)
if err != nil { if err != nil {
failed() failed()
return return
} }
def := *defs[0]
failed = func() { failed = func() {
ResetClasses(b, model) ResetClasses(b, model)
ModelUpdateStatus(b, model.Id, READY_RETRAIN_FAILED) ModelUpdateStatus(b, model.Id, READY_RETRAIN_FAILED)
task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Model failed retraining") task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Model failed retraining")
_, err_ := db.Exec("delete from exp_model_head where def_id=$1 and status in (2,3)", def.Id) _, err_ := db.Exec("delete from exp_model_head where def_id=$1 and status in (2,3)", defData.Id)
if err_ != nil { if err_ != nil {
panic(err_) panic(err_)
} }
@ -1575,21 +1712,21 @@ func RunTaskRetrain(b BasePack, task Task) (err error) {
var epocs = 0 var epocs = 0
// TODO make max epochs come from db // TODO make max epochs come from db
// TODO re increase the target accuracy // TODO re increase the target accuracy
for acc*100 < float64(def.TargetAccuracy)-5 && epocs < 10 { for acc*100 < defData.TargetAcuuracy-5 && epocs < 10 {
// This is something I have to check // This is something I have to check
acc, err = trainDefinitionExpandExp(b, model, def, epocs > 0) acc, err = trainDefinitionExpandExp(b, model, defData.Id, epocs > 0)
if err != nil { if err != nil {
failed() failed()
return return
} }
l.Info("Retrained model", "accuracy", acc, "target", def.TargetAccuracy) l.Info("Retrained model", "accuracy", acc, "target", defData.TargetAcuuracy)
epocs += 1 epocs += 1
} }
if acc*100 < float64(def.TargetAccuracy)-5 { if acc*100 < defData.TargetAcuuracy {
l.Error("Model never achived targetd accuracy", "acc", acc*100, "target", def.TargetAccuracy) l.Error("Model never achived targetd accuracy", "acc", acc*100, "target", defData.TargetAcuuracy)
failed() failed()
return return
} }
@ -1610,13 +1747,6 @@ func RunTaskRetrain(b BasePack, task Task) (err error) {
return return
} }
_, err = db.Exec("update exp_model_head set status=$1 where status=$2 and def_id=$3", MODEL_HEAD_STATUS_READY, MODEL_HEAD_STATUS_TRAINING, def.Id)
if err != nil {
l.Error("Error while updating the classes", "error", err)
failed()
return
}
task.UpdateStatusLog(b, TASK_DONE, "Model finished retraining") task.UpdateStatusLog(b, TASK_DONE, "Model finished retraining")
return return

View File

@ -68,7 +68,7 @@ func handleTasksStats(handle *Handle) {
} else if task.Status < 2 { } else if task.Status < 2 {
total.Classfication_pre += 1 total.Classfication_pre += 1
hours[hour].Classfication_pre += 1 hours[hour].Classfication_pre += 1
} else if task.Status < 4 || task.Status == 5 { } else if task.Status < 4 {
total.Classfication_running += 1 total.Classfication_running += 1
hours[hour].Classfication_running += 1 hours[hour].Classfication_running += 1
} }

7
logic/tasks/README.md Normal file
View File

@ -0,0 +1,7 @@
# Runner Protocol
```
/----\
\/ |
Register -> Init -> Active ---> Ready -> Info
```

View File

@ -9,5 +9,4 @@ func HandleTasks(handle *Handle) {
handleList(handle) handleList(handle)
handleRequests(handle) handleRequests(handle)
handleRemoteRunner(handle) handleRemoteRunner(handle)
handleRunnerData(handle)
} }

View File

@ -1,8 +1,6 @@
package tasks package tasks
import ( import (
"os"
"path"
"sync" "sync"
"time" "time"
@ -10,7 +8,6 @@ import (
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
"github.com/charmbracelet/log"
) )
func verifyRunner(c *Context, dat *JustId) (runner *Runner, e *Error) { func verifyRunner(c *Context, dat *JustId) (runner *Runner, e *Error) {
@ -35,12 +32,6 @@ type VerifyTask struct {
TaskId string `json:"taskId" validate:"required"` TaskId string `json:"taskId" validate:"required"`
} }
type RunnerTrainDef struct {
Id string `json:"id" validate:"required"`
TaskId string `json:"taskId" validate:"required"`
DefId string `json:"defId" validate:"required"`
}
func verifyTask(x *Handle, c *Context, dat *VerifyTask) (task *Task, error *Error) { func verifyTask(x *Handle, c *Context, dat *VerifyTask) (task *Task, error *Error) {
mutex := x.DataMap["runners_mutex"].(*sync.Mutex) mutex := x.DataMap["runners_mutex"].(*sync.Mutex)
mutex.Lock() mutex.Lock()
@ -60,18 +51,6 @@ func verifyTask(x *Handle, c *Context, dat *VerifyTask) (task *Task, error *Erro
return runner_data["task"].(*Task), nil return runner_data["task"].(*Task), nil
} }
func clearRunnerTask(x *Handle, runner_id string) {
mutex := x.DataMap["runners_mutex"].(*sync.Mutex)
mutex.Lock()
defer mutex.Unlock()
var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{})
var runner_data map[string]interface{} = runners[runner_id].(map[string]interface{})
runner_data["task"] = nil
runners[runner_id] = runner_data
x.DataMap["runners"] = runners
}
func handleRemoteRunner(x *Handle) { func handleRemoteRunner(x *Handle) {
type RegisterRunner struct { type RegisterRunner struct {
@ -221,10 +200,6 @@ func handleRemoteRunner(x *Handle) {
switch task.TaskType { switch task.TaskType {
case int(TASK_TYPE_TRAINING): case int(TASK_TYPE_TRAINING):
CleanUpFailed(c, task) CleanUpFailed(c, task)
case int(TASK_TYPE_RETRAINING):
CleanUpFailedRetrain(c, task)
case int(TASK_TYPE_CLASSIFICATION):
// DO nothing
default: default:
panic("Do not know how to handle this") panic("Do not know how to handle this")
} }
@ -243,7 +218,7 @@ func handleRemoteRunner(x *Handle) {
return c.SendJSON("Ok") return c.SendJSON("Ok")
}) })
PostAuthJson(x, "/tasks/runner/defs", User_Normal, func(c *Context, dat *VerifyTask) *Error { PostAuthJson(x, "/tasks/runner/train/defs", User_Normal, func(c *Context, dat *VerifyTask) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id}) _, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil { if error != nil {
return error return error
@ -254,15 +229,7 @@ func handleRemoteRunner(x *Handle) {
return error return error
} }
var status DefinitionStatus if task.TaskType != int(TASK_TYPE_TRAINING) {
switch task.TaskType {
case int(TASK_TYPE_TRAINING):
status = DEFINITION_STATUS_INIT
case int(TASK_TYPE_RETRAINING):
fallthrough
case int(TASK_TYPE_CLASSIFICATION):
status = DEFINITION_STATUS_READY
default:
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions") return c.JsonBadRequest("Task is not the right type go get the definitions")
} }
@ -272,7 +239,7 @@ func handleRemoteRunner(x *Handle) {
return c.E500M("Failed to get model information", err) return c.E500M("Failed to get model information", err)
} }
defs, err := model.GetDefinitions(c, "and md.status=$2", status) defs, err := model.GetDefinitions(c, "and md.status=$2", DEFINITION_STATUS_INIT)
if err != nil { if err != nil {
return c.E500M("Failed to get the model definitions", err) return c.E500M("Failed to get the model definitions", err)
} }
@ -280,7 +247,7 @@ func handleRemoteRunner(x *Handle) {
return c.SendJSON(defs) return c.SendJSON(defs)
}) })
PostAuthJson(x, "/tasks/runner/classes", User_Normal, func(c *Context, dat *VerifyTask) *Error { PostAuthJson(x, "/tasks/runner/train/classes", User_Normal, func(c *Context, dat *VerifyTask) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id}) _, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil { if error != nil {
return error return error
@ -291,35 +258,22 @@ func handleRemoteRunner(x *Handle) {
return error return error
} }
if task.TaskType != int(TASK_TYPE_TRAINING) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
model, err := GetBaseModel(c, *task.ModelId) model, err := GetBaseModel(c, *task.ModelId)
if err != nil { if err != nil {
return c.E500M("Failed to get model information", err) return c.E500M("Failed to get model information", err)
} }
switch task.TaskType { classes, err := model.GetClasses(c, "and status=$2 order by mc.class_order asc", CLASS_STATUS_TO_TRAIN)
case int(TASK_TYPE_TRAINING):
classes, err := model.GetClasses(c, "and status in ($2, $3) order by mc.class_order asc", CLASS_STATUS_TO_TRAIN, CLASS_STATUS_TRAINING)
if err != nil { if err != nil {
return c.E500M("Failed to get the model classes", err) return c.E500M("Failed to get the model classes", err)
} }
return c.SendJSON(classes)
case int(TASK_TYPE_RETRAINING):
classes, err := model.GetClasses(c, "and status=$2 order by mc.class_order asc", CLASS_STATUS_TRAINING)
if err != nil {
return c.E500M("Failed to get the model classes", err)
}
return c.SendJSON(classes)
case int(TASK_TYPE_CLASSIFICATION):
classes, err := model.GetClasses(c, "and status=$2 order by mc.class_order asc", CLASS_STATUS_TRAINED)
if err != nil {
return c.E500M("Failed to get the model classes", err)
}
return c.SendJSON(classes)
default:
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
return c.SendJSON(classes)
}) })
type RunnerTrainDefStatus struct { type RunnerTrainDefStatus struct {
@ -357,13 +311,12 @@ func handleRemoteRunner(x *Handle) {
return c.SendJSON("Ok") return c.SendJSON("Ok")
}) })
type RunnerTrainDefHeadStatus struct { type RunnerTrainDefLayers struct {
Id string `json:"id" validate:"required"` Id string `json:"id" validate:"required"`
TaskId string `json:"taskId" validate:"required"` TaskId string `json:"taskId" validate:"required"`
DefId string `json:"defId" validate:"required"` DefId string `json:"defId" validate:"required"`
Status ModelHeadStatus `json:"status" validate:"required"`
} }
PostAuthJson(x, "/tasks/runner/train/def/head/status", User_Normal, func(c *Context, dat *RunnerTrainDefHeadStatus) *Error { PostAuthJson(x, "/tasks/runner/train/def/layers", User_Normal, func(c *Context, dat *RunnerTrainDefLayers) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id}) _, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil { if error != nil {
return error return error
@ -384,69 +337,6 @@ func handleRemoteRunner(x *Handle) {
return c.E500M("Failed to get definition information", err) return c.E500M("Failed to get definition information", err)
} }
_, err = c.Exec("update exp_model_head set status=$1 where def_id=$2;", dat.Status, def.Id)
if err != nil {
log.Error("Failed to train definition!")
return c.E500M("Failed to train definition", err)
}
return c.SendJSON("Ok")
})
type RunnerRetrainDefHeadStatus struct {
Id string `json:"id" validate:"required"`
TaskId string `json:"taskId" validate:"required"`
HeadId string `json:"defId" validate:"required"`
Status ModelHeadStatus `json:"status" validate:"required"`
}
PostAuthJson(x, "/tasks/runner/retrain/def/head/status", User_Normal, func(c *Context, dat *RunnerRetrainDefHeadStatus) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, &VerifyTask{Id: dat.Id, TaskId: dat.TaskId})
if error != nil {
return error
}
if task.TaskType != int(TASK_TYPE_RETRAINING) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
if err := UpdateStatus(c.GetDb(), "exp_model_head", dat.HeadId, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
return c.E500M("Failed to update head status", err)
}
return c.SendJSON("Ok")
})
PostAuthJson(x, "/tasks/runner/train/def/layers", User_Normal, func(c *Context, dat *RunnerTrainDef) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, &VerifyTask{Id: dat.Id, TaskId: dat.TaskId})
if error != nil {
return error
}
switch task.TaskType {
case int(TASK_TYPE_TRAINING):
// Do nothing
case int(TASK_TYPE_RETRAINING):
// Do nothing
default:
c.Logger.Error("Task not is not the right type to get the layers", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the layers")
}
def, err := GetDefinition(c, dat.DefId)
if err != nil {
return c.E500M("Failed to get definition information", err)
}
layers, err := def.GetLayers(c, " order by layer_order asc") layers, err := def.GetLayers(c, " order by layer_order asc")
if err != nil { if err != nil {
return c.E500M("Failed to get layers", err) return c.E500M("Failed to get layers", err)
@ -466,12 +356,7 @@ func handleRemoteRunner(x *Handle) {
return error return error
} }
switch task.TaskType { if task.TaskType != int(TASK_TYPE_TRAINING) {
case int(TASK_TYPE_TRAINING):
// DO nothing
case int(TASK_TYPE_RETRAINING):
// DO nothing
default:
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions") return c.JsonBadRequest("Task is not the right type go get the definitions")
} }
@ -498,463 +383,4 @@ func handleRemoteRunner(x *Handle) {
Training: training_points, Training: training_points,
}) })
}) })
type RunnerTrainDefEpoch struct {
Id string `json:"id" validate:"required"`
TaskId string `json:"taskId" validate:"required"`
DefId string `json:"defId" validate:"required"`
Epoch int `json:"epoch" validate:"required"`
Accuracy float64 `json:"accuracy" validate:"required"`
}
PostAuthJson(x, "/tasks/runner/train/epoch", User_Normal, func(c *Context, dat *RunnerTrainDefEpoch) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, &VerifyTask{
Id: dat.Id,
TaskId: dat.TaskId,
})
if error != nil {
return error
}
if task.TaskType != int(TASK_TYPE_TRAINING) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
def, err := GetDefinition(c, dat.DefId)
if err != nil {
return c.E500M("Failed to get definition information", err)
}
err = def.UpdateAfterEpoch(c, dat.Accuracy, dat.Epoch)
if err != nil {
return c.E500M("Failed to update model", err)
}
return c.SendJSON("Ok")
})
PostAuthJson(x, "/tasks/runner/train/mark-failed", User_Normal, func(c *Context, dat *VerifyTask) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, &VerifyTask{
Id: dat.Id,
TaskId: dat.TaskId,
})
if error != nil {
return error
}
if task.TaskType != int(TASK_TYPE_TRAINING) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
_, err := c.Exec(
"update model_definition set status=$1 "+
"where model_id=$2 and status in ($3, $4)",
MODEL_DEFINITION_STATUS_CANCELD_TRAINING,
task.ModelId,
MODEL_DEFINITION_STATUS_TRAINING,
MODEL_DEFINITION_STATUS_PAUSED_TRAINING,
)
if err != nil {
return c.E500M("Failed to mark definition as failed", err)
}
return c.SendJSON("Ok")
})
PostAuthJson(x, "/tasks/runner/model", User_Normal, func(c *Context, dat *VerifyTask) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, dat)
if error != nil {
return error
}
switch task.TaskType {
case int(TASK_TYPE_TRAINING):
//DO NOTHING
case int(TASK_TYPE_RETRAINING):
//DO NOTHING
case int(TASK_TYPE_CLASSIFICATION):
//DO NOTHING
default:
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
model, err := GetBaseModel(c, *task.ModelId)
if err != nil {
return c.E500M("Failed to get model information", err)
}
return c.SendJSON(model)
})
PostAuthJson(x, "/tasks/runner/heads", User_Normal, func(c *Context, dat *RunnerTrainDef) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, &VerifyTask{Id: dat.Id, TaskId: dat.TaskId})
if error != nil {
return error
}
type ExpHead struct {
Id string `json:"id"`
Start int `db:"range_start" json:"start"`
End int `db:"range_end" json:"end"`
}
switch task.TaskType {
case int(TASK_TYPE_TRAINING):
fallthrough
case int(TASK_TYPE_RETRAINING):
// status = 2 (INIT) 3 (TRAINING)
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1 and status in (2,3)", dat.DefId)
if err != nil {
return c.E500M("Failed getting active heads", err)
}
return c.SendJSON(heads)
case int(TASK_TYPE_CLASSIFICATION):
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1", dat.DefId)
if err != nil {
return c.E500M("Failed getting active heads", err)
}
return c.SendJSON(heads)
default:
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
})
PostAuthJson(x, "/tasks/runner/train_exp/class/status/train", User_Normal, func(c *Context, dat *VerifyTask) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, dat)
if error != nil {
return error
}
if task.TaskType != int(TASK_TYPE_TRAINING) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
model, err := GetBaseModel(c, *task.ModelId)
if err != nil {
return c.E500M("Failed to get model", err)
}
err = SetModelClassStatus(c, CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TO_TRAIN)
if err != nil {
return c.E500M("Failed update status", err)
}
return c.SendJSON("Ok")
})
PostAuthJson(x, "/tasks/runner/train/done", User_Normal, func(c *Context, dat *VerifyTask) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, dat)
if error != nil {
return error
}
if task.TaskType != int(TASK_TYPE_TRAINING) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
model, err := GetBaseModel(c, *task.ModelId)
if err != nil {
c.Logger.Error("Failed to get model", "err", err)
return c.E500M("Failed to get mode", err)
}
var def Definition
err = GetDBOnce(c, &def, "model_definition as md where model_id=$1 and status=$2 order by accuracy desc limit 1;", task.ModelId, DEFINITION_STATUS_TRANIED)
if err == NotFoundError {
// TODO Make the Model status have a message
c.Logger.Error("All definitions failed to train!")
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "All definition failed to train!")
return c.SendJSON("Ok")
} else if err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to get model definition")
return c.E500M("Failed to get model definition", err)
}
if err = def.UpdateStatus(c, DEFINITION_STATUS_READY); err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to update model definition")
return c.E500M("Failed to update model definition", err)
}
to_delete, err := c.Query("select id from model_definition where status != $1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id)
if err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
return c.E500M("Failed to delete unsed definitions", err)
}
defer to_delete.Close()
for to_delete.Next() {
var id string
if err = to_delete.Scan(&id); err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
return c.E500M("Failed to delete unsed definitions", err)
}
os.RemoveAll(path.Join("savedData", model.Id, "defs", id))
}
// TODO Check if returning also works here
if _, err = c.Exec("delete from model_definition where status!=$1 and model_id=$2;", MODEL_DEFINITION_STATUS_READY, model.Id); err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
return c.E500M("Failed to delete unsed definitions", err)
}
// Set the class status to trained
err = SetModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1;", model.Id)
if err != nil {
c.Logger.Error("Failed to set class status")
return c.E500M("Failed to set class status", err)
}
if err = model.UpdateStatus(c, READY); err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
return c.E500M("Failed to update status of model", err)
}
task.UpdateStatusLog(c, TASK_DONE, "Model finished training")
clearRunnerTask(x, dat.Id)
return c.SendJSON("Ok")
})
type RunnerClassDone struct {
Id string `json:"id" validate:"required"`
TaskId string `json:"taskId" validate:"required"`
Result string `json:"result" validate:"required"`
}
PostAuthJson(x, "/tasks/runner/class/done", User_Normal, func(c *Context, dat *RunnerClassDone) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, &VerifyTask{
Id: dat.Id,
TaskId: dat.TaskId,
})
if error != nil {
return error
}
if task.TaskType != int(TASK_TYPE_CLASSIFICATION) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
err := task.SetResultText(c, dat.Result)
if err != nil {
return c.E500M("Failed to update the task", err)
}
err = task.UpdateStatus(c, TASK_DONE, "Task completed")
if err != nil {
return c.E500M("Failed to update task", err)
}
mutex := x.DataMap["runners_mutex"].(*sync.Mutex)
mutex.Lock()
defer mutex.Unlock()
var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{})
var runner_data map[string]interface{} = runners[dat.Id].(map[string]interface{})
runner_data["task"] = nil
runners[dat.Id] = runner_data
x.DataMap["runners"] = runners
return c.SendJSON("Ok")
})
PostAuthJson(x, "/tasks/runner/train_exp/done", User_Normal, func(c *Context, dat *VerifyTask) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, dat)
if error != nil {
return error
}
if task.TaskType != int(TASK_TYPE_TRAINING) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
model, err := GetBaseModel(c, *task.ModelId)
if err != nil {
c.Logger.Error("Failed to get model", "err", err)
return c.E500M("Failed to get mode", err)
}
// TODO add check the to the model
var def Definition
err = GetDBOnce(c, &def, "model_definition as md where model_id=$1 and status=$2 order by accuracy desc limit 1;", task.ModelId, DEFINITION_STATUS_TRANIED)
if err == NotFoundError {
// TODO Make the Model status have a message
c.Logger.Error("All definitions failed to train!")
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "All definition failed to train!")
clearRunnerTask(x, dat.Id)
return c.SendJSON("Ok")
} else if err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to get model definition")
return c.E500M("Failed to get model definition", err)
}
if err = def.UpdateStatus(c, DEFINITION_STATUS_READY); err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to update model definition")
return c.E500M("Failed to update model definition", err)
}
to_delete, err := GetDbMultitple[JustId](c, "model_definition where status!=$1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id)
if err != nil {
c.GetLogger().Error("Failed to select model_definition to delete")
return c.E500M("Failed to select model definition to delete", err)
}
for _, d := range to_delete {
os.RemoveAll(path.Join("savedData", model.Id, "defs", d.Id))
}
// TODO Check if returning also works here
if _, err = c.Exec("delete from model_definition where status!=$1 and model_id=$2;", MODEL_DEFINITION_STATUS_READY, model.Id); err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
return c.E500M("Failed to delete unsed definitions", err)
}
if err = SplitModel(c, model); err != nil {
err = SetModelClassStatus(c, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
if err != nil {
c.Logger.Error("Failed to split the model! And Failed to set class status")
return c.E500M("Failed to split the model", err)
}
c.Logger.Error("Failed to split the model")
return c.E500M("Failed to split the model", err)
}
// Set the class status to trained
err = SetModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
if err != nil {
c.Logger.Error("Failed to set class status")
return c.E500M("Failed to set class status", err)
}
c.Logger.Warn("Removing base model for", "model", model.Id, "def", def.Id)
os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model"))
os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model.keras"))
if err = model.UpdateStatus(c, READY); err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
return c.E500M("Failed to update status of model", err)
}
task.UpdateStatusLog(c, TASK_DONE, "Model finished training")
mutex := x.DataMap["runners_mutex"].(*sync.Mutex)
mutex.Lock()
defer mutex.Unlock()
var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{})
var runner_data map[string]interface{} = runners[dat.Id].(map[string]interface{})
runner_data["task"] = nil
runners[dat.Id] = runner_data
x.DataMap["runners"] = runners
return c.SendJSON("Ok")
})
PostAuthJson(x, "/tasks/runner/retrain/done", User_Normal, func(c *Context, dat *VerifyTask) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, dat)
if error != nil {
return error
}
if task.TaskType != int(TASK_TYPE_RETRAINING) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
model, err := GetBaseModel(c, *task.ModelId)
if err != nil {
c.Logger.Error("Failed to get model", "err", err)
return c.E500M("Failed to get mode", err)
}
err = SetModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
if err != nil {
return c.E500M("Failed to set class status", err)
}
defs, err := model.GetDefinitions(c, "")
if err != nil {
return c.E500M("Failed to get definitions", err)
}
_, err = c.Exec("update exp_model_head set status=$1 where status=$2 and def_id=$3", MODEL_HEAD_STATUS_READY, MODEL_HEAD_STATUS_TRAINING, defs[0].Id)
if err != nil {
return c.E500M("Failed to set head status", err)
}
err = model.UpdateStatus(c, READY)
if err != nil {
return c.E500M("Failed to set class status", err)
}
task.UpdateStatusLog(c, TASK_DONE, "Model finished training")
clearRunnerTask(x, dat.Id)
return c.SendJSON("Ok")
})
} }

View File

@ -19,8 +19,6 @@ import (
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
) )
var QUEUE_SIZE = 10
/** /**
* Actually runs the code * Actually runs the code
*/ */
@ -49,28 +47,17 @@ func runner(config Config, db db.Db, task_channel chan Task, index int, back_cha
Host: config.Hostname, Host: config.Hostname,
} }
loaded_model := RunnerModelData{
Id: nil,
Model: nil,
}
count := 0
for task := range task_channel { for task := range task_channel {
logger.Info("Got task", "task", task) logger.Info("Got task", "task", task)
task.UpdateStatusLog(base, TASK_PICKED_UP, "Runner picked up task") task.UpdateStatusLog(base, TASK_PICKED_UP, "Runner picked up task")
if task.TaskType == int(TASK_TYPE_CLASSIFICATION) { if task.TaskType == int(TASK_TYPE_CLASSIFICATION) {
logger.Info("Classification Task") logger.Info("Classification Task")
if err = ClassifyTask(base, task, &loaded_model); err != nil { if err = ClassifyTask(base, task); err != nil {
logger.Error("Classification task failed", "error", err) logger.Error("Classification task failed", "error", err)
} }
if count == QUEUE_SIZE {
back_channel <- index back_channel <- index
count = 0
} else {
count += 1
}
continue continue
} else if task.TaskType == int(TASK_TYPE_TRAINING) { } else if task.TaskType == int(TASK_TYPE_TRAINING) {
logger.Info("Training Task") logger.Info("Training Task")
@ -78,12 +65,7 @@ func runner(config Config, db db.Db, task_channel chan Task, index int, back_cha
logger.Error("Failed to tain the model", "error", err) logger.Error("Failed to tain the model", "error", err)
} }
if count == QUEUE_SIZE {
back_channel <- index back_channel <- index
count = 0
} else {
count += 1
}
continue continue
} else if task.TaskType == int(TASK_TYPE_RETRAINING) { } else if task.TaskType == int(TASK_TYPE_RETRAINING) {
logger.Info("Retraining Task") logger.Info("Retraining Task")
@ -91,12 +73,7 @@ func runner(config Config, db db.Db, task_channel chan Task, index int, back_cha
logger.Error("Failed to tain the model", "error", err) logger.Error("Failed to tain the model", "error", err)
} }
if count == QUEUE_SIZE {
back_channel <- index back_channel <- index
count = 0
} else {
count += 1
}
continue continue
} else if task.TaskType == int(TASK_TYPE_DELETE_USER) { } else if task.TaskType == int(TASK_TYPE_DELETE_USER) {
logger.Warn("User deleting Task") logger.Warn("User deleting Task")
@ -104,23 +81,13 @@ func runner(config Config, db db.Db, task_channel chan Task, index int, back_cha
logger.Error("Failed to tain the model", "error", err) logger.Error("Failed to tain the model", "error", err)
} }
if count == QUEUE_SIZE {
back_channel <- index back_channel <- index
count = 0
} else {
count += 1
}
continue continue
} }
logger.Error("Do not know how to route task", "task", task) logger.Error("Do not know how to route task", "task", task)
task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Do not know how to route task") task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Do not know how to route task")
if count == QUEUE_SIZE {
back_channel <- index back_channel <- index
count = 0
} else {
count += 1
}
} }
} }
@ -153,22 +120,10 @@ func handleRemoteTask(handler *Handle, base BasePack, runner_id string, task Tas
defer mutex.Unlock() defer mutex.Unlock()
switch task.TaskType { switch task.TaskType {
case int(TASK_TYPE_RETRAINING):
runners := handler.DataMap["runners"].(map[string]interface{})
runner := runners[runner_id].(map[string]interface{})
runner["task"] = &task
runners[runner_id] = runner
handler.DataMap["runners"] = runners
case int(TASK_TYPE_TRAINING): case int(TASK_TYPE_TRAINING):
if err := PrepareTraining(handler, base, task, runner_id); err != nil { if err := PrepareTraining(handler, base, task, runner_id); err != nil {
logger.Error("Failed to prepare for training", "err", err) logger.Error("Failed to prepare for training", "err", err)
} }
case int(TASK_TYPE_CLASSIFICATION):
runners := handler.DataMap["runners"].(map[string]interface{})
runner := runners[runner_id].(map[string]interface{})
runner["task"] = &task
runners[runner_id] = runner
handler.DataMap["runners"] = runners
default: default:
logger.Error("Not sure what to do panicing", "taskType", task.TaskType) logger.Error("Not sure what to do panicing", "taskType", task.TaskType)
panic("not sure what to do") panic("not sure what to do")
@ -178,7 +133,7 @@ func handleRemoteTask(handler *Handle, base BasePack, runner_id string, task Tas
/** /**
* Tells the orcchestator to look at the task list from time to time * Tells the orcchestator to look at the task list from time to time
*/ */
func attentionSeeker(config Config, db db.Db, back_channel chan int) { func attentionSeeker(config Config, back_channel chan int) {
logger := log.NewWithOptions(os.Stdout, log.Options{ logger := log.NewWithOptions(os.Stdout, log.Options{
ReportCaller: true, ReportCaller: true,
ReportTimestamp: true, ReportTimestamp: true,
@ -203,20 +158,6 @@ func attentionSeeker(config Config, db db.Db, back_channel chan int) {
for true { for true {
back_channel <- 0 back_channel <- 0
for {
var s struct {
Count int `json:"count(*)"`
}
err := GetDBOnce(db, &s, "tasks where stauts = 5 or status = 3")
if err != nil {
break
}
if s.Count == 0 {
break
}
time.Sleep(t)
}
time.Sleep(t) time.Sleep(t)
} }
} }
@ -232,7 +173,9 @@ func RunnerOrchestrator(db db.Db, config Config, handler *Handle) {
Prefix: "Runner Orchestrator Logger", Prefix: "Runner Orchestrator Logger",
}) })
setupHandle(handler) // Setup vars
handler.DataMap["runners"] = map[string]interface{}{}
handler.DataMap["runners_mutex"] = &sync.Mutex{}
base := BasePackStruct{ base := BasePackStruct{
Db: db, Db: db,
@ -241,22 +184,17 @@ func RunnerOrchestrator(db db.Db, config Config, handler *Handle) {
} }
gpu_workers := config.GpuWorker.NumberOfWorkers gpu_workers := config.GpuWorker.NumberOfWorkers
def_wait, err := time.ParseDuration(config.GpuWorker.Pulling)
if err != nil {
logger.Error("Failed to load", "error", err)
return
}
logger.Info("Starting runners") logger.Info("Starting runners")
task_runners := make([]chan Task, gpu_workers) task_runners := make([]chan Task, gpu_workers)
task_runners_used := make([]int, gpu_workers) task_runners_used := make([]bool, gpu_workers)
// One more to accomudate the Attention Seeker channel // One more to accomudate the Attention Seeker channel
back_channel := make(chan int, gpu_workers+1) back_channel := make(chan int, gpu_workers+1)
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
logger.Error("Recovered in Orchestrator restarting", "due to", r, "stack", string(debug.Stack())) logger.Error("Recovered in Orchestrator restarting", "due to", r)
for x := range task_runners { for x := range task_runners {
close(task_runners[x]) close(task_runners[x])
} }
@ -265,70 +203,65 @@ func RunnerOrchestrator(db db.Db, config Config, handler *Handle) {
} }
}() }()
// go attentionSeeker(config, db, back_channel) go attentionSeeker(config, back_channel)
// Start the runners // Start the runners
for i := 0; i < gpu_workers; i++ { for i := 0; i < gpu_workers; i++ {
task_runners[i] = make(chan Task, QUEUE_SIZE) task_runners[i] = make(chan Task, 10)
task_runners_used[i] = 0 task_runners_used[i] = false
AddLocalRunner(handler, LocalRunner{
RunnerNum: i + 1,
Task: nil,
})
go runner(config, db, task_runners[i], i+1, back_channel) go runner(config, db, task_runners[i], i+1, back_channel)
} }
used := 0 var task_to_dispatch *Task = nil
wait := time.Nanosecond * 100
for { for i := range back_channel {
out := true
for out {
select {
case i := <-back_channel:
if i != 0 {
if i > 0 { if i > 0 {
logger.Info("Runner freed", "runner", i) logger.Info("Runner freed", "runner", i)
task_runners_used[i-1] = 0 task_runners_used[i-1] = false
used = 0
} else if i < 0 { } else if i < 0 {
logger.Error("Runner died! Restarting!", "runner", i) logger.Error("Runner died! Restarting!", "runner", i)
i = int(math.Abs(float64(i)) - 1) i = int(math.Abs(float64(i)) - 1)
task_runners_used[i] = 0 task_runners_used[i] = false
used = 0
go runner(config, db, task_runners[i], i+1, back_channel) go runner(config, db, task_runners[i], i+1, back_channel)
} }
AddLocalTask(handler, int(math.Abs(float64(i))), nil)
} else if used == len(task_runners_used) {
continue
}
case <-time.After(wait):
if wait == time.Nanosecond*100 {
wait = def_wait
}
out = false
}
}
for { if task_to_dispatch == nil {
tasks, err := GetDbMultitple[TaskT](db, "tasks as t "+ var task TaskT
err := GetDBOnce(db, &task, "tasks as t "+
// Get depenencies // Get depenencies
"left join tasks_dependencies as td on t.id=td.main_id "+ "left join tasks_dependencies as td on t.id=td.main_id "+
// Get the task that the depencey resolves to // Get the task that the depencey resolves to
"left join tasks as t2 on t2.id=td.dependent_id "+ "left join tasks as t2 on t2.id=td.dependent_id "+
"where t.status=1 "+ "where t.status=1 "+
"group by t.id having count(td.id) filter (where t2.status in (0,1,2,3)) = 0 limit 20;") "group by t.id having count(td.id) filter (where t2.status in (0,1,2,3)) = 0;")
if err != NotFoundError && err != nil { if err != NotFoundError && err != nil {
log.Error("Failed to get tasks from db", "err", err) log.Error("Failed to get tasks from db", "err", err)
continue continue
} }
if err == NotFoundError || len(tasks) == 0 { if err == NotFoundError {
break task_to_dispatch = nil
} else {
temp := Task(task)
task_to_dispatch = &temp
}
}
if task_to_dispatch != nil {
// Only let CPU tasks be done by the local users
if task_to_dispatch.TaskType == int(TASK_TYPE_DELETE_USER) {
for i := 0; i < len(task_runners_used); i += 1 {
if !task_runners_used[i] {
task_runners[i] <- *task_to_dispatch
task_runners_used[i] = true
task_to_dispatch = nil
break
}
}
continue
} }
for _, task_to_dispatch := range tasks {
ttd := Task(*task_to_dispatch)
if task_to_dispatch != nil && task_to_dispatch.TaskType != int(TASK_TYPE_DELETE_USER) {
// TODO split tasks into cpu tasks and GPU tasks
mutex := handler.DataMap["runners_mutex"].(*sync.Mutex) mutex := handler.DataMap["runners_mutex"].(*sync.Mutex)
mutex.Lock() mutex.Lock()
remote_runners := handler.DataMap["runners"].(map[string]interface{}) remote_runners := handler.DataMap["runners"].(map[string]interface{})
@ -341,44 +274,16 @@ func RunnerOrchestrator(db db.Db, config Config, handler *Handle) {
continue continue
} }
if runner_info.UserId != task_to_dispatch.UserId { if runner_info.UserId == task_to_dispatch.UserId {
continue go handleRemoteTask(handler, base, k, *task_to_dispatch)
}
go handleRemoteTask(handler, base, k, ttd)
task_to_dispatch = nil task_to_dispatch = nil
break break
} }
}
mutex.Unlock() mutex.Unlock()
} }
used = 0
if task_to_dispatch != nil {
for i := 0; i < len(task_runners_used); i += 1 {
if task_runners_used[i] <= QUEUE_SIZE {
ttd.UpdateStatusLog(base, TASK_QUEUED, "Runner picked up task")
task_runners[i] <- ttd
task_runners_used[i] += 1
AddLocalTask(handler, i+1, &ttd)
task_to_dispatch = nil
wait = time.Nanosecond * 100
break
} else {
used += 1
}
}
}
if used == len(task_runners_used) {
break
}
}
if used == len(task_runners_used) {
break
}
}
} }
} }

View File

@ -1,51 +0,0 @@
package task_runner
import (
"sync"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
)
type LocalRunner struct {
RunnerNum int `json:"id"`
Task *Task `json:"task"`
}
type LocalRunners map[int]*LocalRunner
func LockRunners(handler *Handle, t string) *sync.Mutex {
req := t + "_runners_mutex"
if t == "" {
req = "runners_mutex"
}
mutex := handler.DataMap[req].(*sync.Mutex)
mutex.Lock()
return mutex
}
func setupHandle(handler *Handle) {
// Setup Remote Runner data
handler.DataMap["runners"] = map[string]interface{}{}
handler.DataMap["runners_mutex"] = &sync.Mutex{}
// Setup Local Runner data
handler.DataMap["local_runners"] = &LocalRunners{}
handler.DataMap["local_runners_mutex"] = &sync.Mutex{}
}
func AddLocalRunner(handler *Handle, runner LocalRunner) {
mutex := LockRunners(handler, "local")
defer mutex.Unlock()
runners := handler.DataMap["local_runners"].(*LocalRunners)
(*runners)[runner.RunnerNum] = &runner
}
func AddLocalTask(handler *Handle, runner_id int, task *Task) {
mutex := LockRunners(handler, "local")
defer mutex.Unlock()
runners := handler.DataMap["local_runners"].(*LocalRunners)
(*(*runners)[runner_id]).Task = task
}

View File

@ -1,25 +0,0 @@
package tasks
import (
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/runner"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
)
func handleRunnerData(x *Handle) {
type NonType struct{}
PostAuthJson(x, "/tasks/runner/info", User_Admin, func(c *Context, dat *NonType) *Error {
mutex_remote := LockRunners(x, "")
defer mutex_remote.Unlock()
mutex_local := LockRunners(x, "local")
defer mutex_local.Unlock()
return c.SendJSON(struct {
RemoteRunners map[string]interface{} `json:"remoteRunners"`
LocalRunner *LocalRunners `json:"localRunners"`
}{
RemoteRunners: x.DataMap["runners"].(map[string]interface{}),
LocalRunner: x.DataMap["local_runners"].(*LocalRunners),
})
})
}

View File

@ -50,7 +50,6 @@ const (
TASK_PREPARING = 0 TASK_PREPARING = 0
TASK_TODO = 1 TASK_TODO = 1
TASK_PICKED_UP = 2 TASK_PICKED_UP = 2
TASK_QUEUED = 5
TASK_RUNNING = 3 TASK_RUNNING = 3
TASK_DONE = 4 TASK_DONE = 4
) )
@ -102,11 +101,7 @@ func (t Task) SetResult(base BasePack, result any) (err error) {
if err != nil { if err != nil {
return return
} }
return t.SetResultText(base, string(text)) _, err = base.GetDb().Exec("update tasks set result=$1 where id=$2", text, t.Id)
}
func (t Task) SetResultText(base BasePack, text string) (err error) {
_, err = base.GetDb().Exec("update tasks set result=$1 where id=$2", []byte(text), t.Id)
return return
} }

View File

@ -241,17 +241,6 @@ func UsersEndpints(db db.Db, handle *Handle) {
return c.SendJSON(userReturn) return c.SendJSON(userReturn)
}) })
PostAuthJson(handle, "/user/info/get", User_Admin, func(c *Context, dat *JustId) *Error {
var user *User
user, err := UserFromId(c, dat.Id)
if err == NotFoundError {
return c.SendJSONStatus(404, "User not found")
} else if err != nil {
return c.E500M("Could not get user", err)
}
return c.SendJSON(user)
})
// Handles updating users // Handles updating users
type UpdateUserData struct { type UpdateUserData struct {
Id string `json:"id"` Id string `json:"id"`

View File

@ -24,11 +24,6 @@ type ServiceUser struct {
type DbInfo struct { type DbInfo struct {
MaxConnections int `toml:"max_connections"` MaxConnections int `toml:"max_connections"`
Host string `toml:"host"`
Port int `toml:"port"`
User string `toml:"user"`
Password string `toml:"password"`
Dbname string `toml:"dbname"`
} }
type Config struct { type Config struct {
@ -102,7 +97,7 @@ func (c *Config) Cleanup(db db.Db) {
failLog(err) failLog(err)
_, err = db.Exec("update models set status=$1 where status=$2", FAILED_PREPARING, PREPARING) _, err = db.Exec("update models set status=$1 where status=$2", FAILED_PREPARING, PREPARING)
failLog(err) failLog(err)
_, err = db.Exec("update tasks set status=$1 where status=$2 or status=$3", TASK_TODO, TASK_PICKED_UP, TASK_QUEUED) _, err = db.Exec("update tasks set status=$1 where status=$2", TASK_TODO, TASK_PICKED_UP)
failLog(err) failLog(err)
tasks, err := GetDbMultitple[Task](db, "tasks where status=$1", TASK_RUNNING) tasks, err := GetDbMultitple[Task](db, "tasks where status=$1", TASK_RUNNING)
@ -119,16 +114,12 @@ func (c *Config) Cleanup(db db.Db) {
tasks[i].UpdateStatus(base, TASK_FAILED_RUNNING, "Task inturupted by server restart please try again") tasks[i].UpdateStatus(base, TASK_FAILED_RUNNING, "Task inturupted by server restart please try again")
_, err = db.Exec("update models set status=$1 where id=$2", READY_RETRAIN_FAILED, tasks[i].ModelId) _, err = db.Exec("update models set status=$1 where id=$2", READY_RETRAIN_FAILED, tasks[i].ModelId)
failLog(err) failLog(err)
_, err = db.Exec("update model_classes set status=$1 where model_id=$2 and status=$3", CLASS_STATUS_TO_TRAIN, tasks[i].ModelId, CLASS_STATUS_TRAINING)
failLog(err)
continue continue
} }
if tasks[i].TaskType == int(TASK_TYPE_TRAINING) { if tasks[i].TaskType == int(TASK_TYPE_TRAINING) {
tasks[i].UpdateStatus(base, TASK_FAILED_RUNNING, "Task inturupted by server restart please try again") tasks[i].UpdateStatus(base, TASK_FAILED_RUNNING, "Task inturupted by server restart please try again")
_, err = db.Exec("update models set status=$1 where id=$2", FAILED_TRAINING, tasks[i].ModelId) _, err = db.Exec("update models set status=$1 where id=$2", FAILED_TRAINING, tasks[i].ModelId)
failLog(err) failLog(err)
_, err = db.Exec("update model_classes set status=$1 where model_id=$2 and status=$3", CLASS_STATUS_TO_TRAIN, tasks[i].ModelId, CLASS_STATUS_TRAINING)
failLog(err)
continue continue
} }
} }

View File

@ -449,7 +449,7 @@ func (x Handle) createContext(handler *Handle, r *http.Request, w http.ResponseW
logger := log.NewWithOptions(os.Stdout, log.Options{ logger := log.NewWithOptions(os.Stdout, log.Options{
ReportCaller: true, ReportCaller: true,
ReportTimestamp: true, ReportTimestamp: true,
TimeFormat: time.DateTime, TimeFormat: time.Kitchen,
Prefix: r.URL.Path, Prefix: r.URL.Path,
}) })

15
main.go
View File

@ -15,18 +15,25 @@ import (
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
) )
func main() { const (
host = "localhost"
port = 5432
user = "postgres"
password = "verysafepassword"
dbname = "aistuff"
)
config := LoadConfig() func main() {
log.Info("Config loaded!", "config", config)
psqlInfo := fmt.Sprintf("host=%s port=%d user=%s "+ psqlInfo := fmt.Sprintf("host=%s port=%d user=%s "+
"password=%s dbname=%s sslmode=disable", "password=%s dbname=%s sslmode=disable",
config.DbInfo.Host, config.DbInfo.Port, config.DbInfo.User, config.DbInfo.Password, config.DbInfo.Dbname) host, port, user, password, dbname)
db := db.StartUp(psqlInfo) db := db.StartUp(psqlInfo)
defer db.Close() defer db.Close()
config := LoadConfig()
log.Info("Config loaded!", "config", config)
config.GenerateToken(db) config.GenerateToken(db)
//TODO check if file structure exists to save data //TODO check if file structure exists to save data

View File

@ -1,6 +1,6 @@
events { events {
worker_connections 2024; worker_connections 1024;
} }
http { http {
@ -17,7 +17,7 @@ http {
location / { location / {
proxy_http_version 1.1; proxy_http_version 1.1;
proxy_pass http://webpage:5001; proxy_pass http://127.0.0.1:5001;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header Upgrade $http_upgrade; proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection $connection_upgrade; proxy_set_header Connection $connection_upgrade;
@ -25,7 +25,7 @@ http {
location /api { location /api {
proxy_pass http://server:5002; proxy_pass http://127.0.0.1:5002;
} }
} }
} }

View File

@ -1,4 +1,5 @@
tensorflow[and-cuda] == 2.15.1 # tensorflow[and-cuda] == 2.15.1
tensorflow[and-cuda] == 2.9.1
pandas pandas
# Make sure to install the nvidia pyindex first # Make sure to install the nvidia pyindex first
# nvidia-pyindex # nvidia-pyindex

2
run.sh
View File

@ -1,2 +1,2 @@
#!/bin/bash #!/bin/bash
podman run --network host --gpus all --replace --name fyp-server --ulimit=nofile=100000:100000 -it -v $(pwd):/app -e "TERM=xterm-256color" --restart=always andre-fyp-server podman run --rm --network host --gpus all -ti -v $(pwd):/app -e "TERM=xterm-256color" fyp-server bash

1
runner/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
target/

1936
runner/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

17
runner/Cargo.toml Normal file
View File

@ -0,0 +1,17 @@
[package]
name = "runner"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.82"
serde = { version = "1.0.200", features = ["derive"] }
toml = "0.8.12"
reqwest = { version = "0.12", features = ["json"] }
tokio = { version = "1", features = ["full"] }
serde_json = "1.0.116"
serde_repr = "0.1"
tch = { version = "0.16.0", features = ["download-libtorch"] }
rand = "0.8.5"

12
runner/Dockerfile Normal file
View File

@ -0,0 +1,12 @@
FROM docker.io/nvidia/cuda:11.7.1-devel-ubuntu22.04
RUN apt-get update
RUN apt-get install -y curl
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
ENV PATH="$PATH:/root/.cargo/bin"
RUN rustup toolchain install stable
RUN apt-get install -y pkg-config libssl-dev
WORKDIR /app

3
runner/config.toml Normal file
View File

@ -0,0 +1,3 @@
hostname = "https://testing.andr3h3nriqu3s.com/api"
token = "d2bc41e8293937bcd9397870c98f97acc9603f742924b518e193cd1013e45d57897aa302b364001c72b458afcfb34239dfaf38a66b318e5cbc973eea"
data_path = "/home/andr3/Documents/my-repos/fyp"

1
runner/data.toml Normal file
View File

@ -0,0 +1 @@
id = "a7cec9e9-1d05-4633-8bc5-6faabe4fd5a3"

2
runner/run.sh Executable file
View File

@ -0,0 +1,2 @@
#!/bin/bash
podman run --rm --network host --gpus all -ti -v $(pwd):/app -e "TERM=xterm-256color" fyp-runner bash

115
runner/src/dataloader.rs Normal file
View File

@ -0,0 +1,115 @@
use crate::{model::DataPoint, settings::ConfigFile};
use std::{path::Path, sync::Arc};
use tch::Tensor;
pub struct DataLoader {
pub batch_size: i64,
pub len: usize,
pub inputs: Vec<Tensor>,
pub labels: Vec<Tensor>,
pub pos: usize,
}
fn import_image(
item: &DataPoint,
base_path: &Path,
classes_len: i64,
inputs: &mut Vec<Tensor>,
labels: &mut Vec<Tensor>,
) {
inputs.push(
tch::vision::image::load(base_path.join(&item.path))
.ok()
.unwrap()
.unsqueeze(0),
);
if item.class >= 0 {
let t = tch::Tensor::from_slice(&[item.class]).onehot(classes_len as i64);
labels.push(t);
} else {
labels.push(tch::Tensor::zeros(
[1, classes_len as i64],
(tch::Kind::Float, tch::Device::Cpu),
))
}
}
impl DataLoader {
pub fn new(
config: Arc<ConfigFile>,
data: Vec<DataPoint>,
classes_len: i64,
batch_size: i64,
) -> DataLoader {
let len: f64 = (data.len() as f64) / (batch_size as f64);
let min_len: i64 = len.floor() as i64;
let max_len: i64 = len.ceil() as i64;
println!(
"Creating dataloader data len: {} len: {} min_len: {} max_len:{}",
data.len(),
len,
min_len,
max_len
);
let base_path = Path::new(&config.data_path);
let mut inputs: Vec<Tensor> = Vec::new();
let mut all_labels: Vec<Tensor> = Vec::new();
for batch in 0..min_len {
let mut batch_acc: Vec<Tensor> = Vec::new();
let mut labels: Vec<Tensor> = Vec::new();
for image in 0..batch_size {
let i: usize = (batch * batch_size + image).try_into().unwrap();
let item = &data[i];
import_image(item, base_path, classes_len, &mut batch_acc, &mut labels)
}
inputs.push(tch::Tensor::cat(&batch_acc[0..], 0));
all_labels.push(tch::Tensor::cat(&labels[0..], 0));
}
// Import the last batch that has irregular sizing
if min_len != max_len {
let mut batch_acc: Vec<Tensor> = Vec::new();
let mut labels: Vec<Tensor> = Vec::new();
for image in 0..(data.len() - (batch_size * min_len) as usize) {
let i: usize = (min_len * batch_size + (image as i64)) as usize;
let item = &data[i];
import_image(item, base_path, classes_len, &mut batch_acc, &mut labels);
}
inputs.push(tch::Tensor::cat(&batch_acc[0..], 0));
all_labels.push(tch::Tensor::cat(&labels[0..], 0));
}
println!("ins shape: {:?}", inputs[0].size());
return DataLoader {
batch_size,
inputs,
labels: all_labels,
len: max_len as usize,
pos: 0,
};
}
pub fn restart(self: &mut DataLoader) {
self.pos = 0;
}
pub fn next(self: &mut DataLoader) -> Option<(Tensor, Tensor)> {
if self.pos >= self.len {
return None;
}
let input = self.inputs[self.pos].empty_like();
self.inputs[self.pos] = self.inputs[self.pos].clone(&input);
let label = self.labels[self.pos].empty_like();
self.labels[self.pos] = self.labels[self.pos].clone(&label);
self.pos += 1;
return Some((input, label));
}
}

206
runner/src/main.rs Normal file
View File

@ -0,0 +1,206 @@
mod dataloader;
mod model;
mod settings;
mod tasks;
mod training;
mod types;
use crate::settings::*;
use crate::tasks::{fail_task, Task, TaskType};
use crate::training::handle_train;
use anyhow::{bail, Result};
use reqwest::StatusCode;
use serde_json::json;
use std::{fs, process::exit, sync::Arc, time::Duration};
enum ResultAlive {
Ok,
Error,
NotInit,
}
async fn send_keep_alive_message(
config: Arc<ConfigFile>,
runner_data: Arc<RunnerData>,
) -> ResultAlive {
let client = reqwest::Client::new();
let to_send = json!({
"id": runner_data.id,
});
let resp = client
.post(format!("{}/tasks/runner/beat", config.hostname))
.header("token", &config.token)
.body(to_send.to_string())
.send()
.await;
if resp.is_err() {
return ResultAlive::Error;
}
let resp = resp.ok();
if resp.is_none() {
return ResultAlive::Error;
}
let resp = resp.unwrap();
// TODO see if the message is related to not being inited
if resp.status() != 200 {
println!("Could not connect with the status");
return ResultAlive::Error;
}
ResultAlive::Ok
}
async fn keep_alive(config: Arc<ConfigFile>, runner_data: Arc<RunnerData>) -> Result<()> {
let mut failed = 0;
loop {
match send_keep_alive_message(config.clone(), runner_data.clone()).await {
ResultAlive::Error => failed += 1,
ResultAlive::NotInit => {
println!("Runner not inited! Restarting!");
exit(1)
}
ResultAlive::Ok => failed = 0,
}
// TODO move to config
if failed > 20 {
println!("Failed to connect to API! More than 20 times in a row stoping");
exit(1)
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
async fn handle_task(
task: Task,
config: Arc<ConfigFile>,
runner_data: Arc<RunnerData>,
) -> Result<()> {
let res = match task.task_type {
TaskType::Training => handle_train(&task, config.clone(), runner_data.clone()).await,
_ => {
println!("Do not know how to handle this task #{:?}", task);
bail!("Failed")
}
};
if res.is_err() {
println!("task failed #{:?}", res);
fail_task(
&task,
config,
runner_data,
"Do not know how to handle this kind of task",
)
.await?
}
Ok(())
}
#[tokio::main]
async fn main() -> Result<()> {
// Load config file
let config_data = fs::read_to_string("./config.toml")?;
let mut config: ConfigFile = toml::from_str(&config_data)?;
let client = reqwest::Client::new();
if config.config_path == None {
config.config_path = Some(String::from("./data.toml"))
}
let runner_data: RunnerData = load_runner_data(&config).await?;
let to_send = json!({
"id": runner_data.id,
});
// Inform the server that the runner is available
let resp = client
.post(format!("{}/tasks/runner/init", config.hostname))
.header("token", &config.token)
.body(to_send.to_string())
.send()
.await?;
if resp.status() != 200 {
println!(
"Could not connect with the api: status {} body {}",
resp.status(),
resp.text().await?
);
return Ok(());
}
let res = resp.json::<String>().await?;
if res != "Ok" {
print!("Unexpected problem: {}", res);
return Ok(());
}
let config = Arc::new(config);
let runner_data = Arc::new(runner_data);
let config_alive = config.clone();
let runner_data_alive = runner_data.clone();
std::thread::spawn(move || keep_alive(config_alive.clone(), runner_data_alive.clone()));
println!("Started main loop");
loop {
//TODO move time to config
tokio::time::sleep(Duration::from_secs(1)).await;
let to_send = json!({ "id": runner_data.id });
let resp = client
.post(format!("{}/tasks/runner/active", config.hostname))
.header("token", &config.token)
.body(to_send.to_string())
.send()
.await;
if resp.is_err() || resp.as_ref().ok().is_none() {
println!("Failed to get info from server {:?}", resp);
continue;
}
let resp = resp?;
match resp.status() {
// No active task
StatusCode::NOT_FOUND => (),
StatusCode::OK => {
println!("Found task!");
let task: Result<Task, reqwest::Error> = resp.json().await;
if task.is_err() || task.as_ref().ok().is_none() {
println!("Failed to resolve the json {:?}", task);
continue;
}
let task = task?;
let res = handle_task(task, config.clone(), runner_data.clone()).await;
if res.is_err() || res.as_ref().ok().is_none() {
println!("Failed to run the task");
}
_ = res;
()
}
_ => {
println!("Unexpected error #{:?}", resp);
exit(1)
}
}
}
}

117
runner/src/model/mod.rs Normal file
View File

@ -0,0 +1,117 @@
use anyhow::bail;
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use tch::{
nn::{self, Module},
Device,
};
#[derive(Debug)]
pub struct Model {
pub vs: nn::VarStore,
pub seq: nn::Sequential,
pub layers: Vec<Layer>,
}
#[derive(Debug, Clone, Copy, Serialize_repr, Deserialize_repr)]
#[repr(i8)]
pub enum LayerType {
Input = 1,
Dense = 2,
Flatten = 3,
SimpleBlock = 4,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Layer {
pub id: String,
pub definition_id: String,
pub layer_order: String,
pub layer_type: LayerType,
pub shape: String,
pub exp_type: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct DataPoint {
pub class: i64,
pub path: String,
}
pub fn build_model(layers: Vec<Layer>, last_linear_size: i64, add_sigmoid: bool) -> Model {
let vs = nn::VarStore::new(Device::Cuda(0));
let mut seq = nn::seq();
let mut last_linear_size = last_linear_size;
let mut last_linear_conv: Vec<i64> = Vec::new();
for layer in layers.iter() {
match layer.layer_type {
LayerType::Input => {
last_linear_conv = serde_json::from_str(&layer.shape).ok().unwrap();
println!("Layer: Input, In: {:?}", last_linear_conv);
}
LayerType::Dense => {
let shape: Vec<i64> = serde_json::from_str(&layer.shape).ok().unwrap();
println!("Layer: Dense, In: {}, Out: {}", last_linear_size, shape[0]);
seq = seq
.add(nn::linear(
&vs.root(),
last_linear_size,
shape[0],
Default::default(),
))
.add_fn(|xs| xs.relu());
last_linear_size = shape[0];
}
LayerType::Flatten => {
seq = seq.add_fn(|xs| xs.flatten(1, -1));
last_linear_size = 1;
for i in &last_linear_conv {
last_linear_size *= i;
}
println!(
"Layer: flatten, In: {:?}, Out: {}",
last_linear_conv, last_linear_size
)
}
LayerType::SimpleBlock => {
let new_last_linear_conv =
vec![128, last_linear_conv[1] / 2, last_linear_conv[2] / 2];
println!(
"Layer: block, In: {:?}, Put: {:?}",
last_linear_conv, new_last_linear_conv,
);
let out_size = vec![new_last_linear_conv[1], new_last_linear_conv[2]];
seq = seq
.add(nn::conv2d(
&vs.root(),
last_linear_conv[0],
128,
3,
nn::ConvConfig::default(),
))
.add_fn(|xs| xs.relu())
.add(nn::conv2d(
&vs.root(),
128,
128,
3,
nn::ConvConfig::default(),
))
.add_fn(|xs| xs.relu())
.add_fn(move |xs| xs.adaptive_avg_pool2d([out_size[1], out_size[1]]))
.add_fn(|xs| xs.leaky_relu());
//m_layers = append(m_layers, NewSimpleBlock(vs, lastLinearConv[0]))
last_linear_conv = new_last_linear_conv;
}
}
}
if add_sigmoid {
seq = seq.add_fn(|xs| xs.sigmoid());
}
return Model { vs, layers, seq };
}

57
runner/src/settings.rs Normal file
View File

@ -0,0 +1,57 @@
use anyhow::{bail, Result};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::{fs, path};
#[derive(Deserialize)]
pub struct ConfigFile {
// Hostname to connect with the api
pub hostname: String,
// Token used in the api to authenticate
pub token: String,
// Path to where to store some generated configuration values
// defaults to ./data.toml
pub config_path: Option<String>,
// Data Path
// Path to where the data is mounted
pub data_path: String,
}
#[derive(Deserialize, Serialize)]
pub struct RunnerData {
pub id: String,
}
pub async fn load_runner_data(config: &ConfigFile) -> Result<RunnerData> {
let data_path = config.config_path.as_ref().unwrap();
let data_path = path::Path::new(&*data_path);
if data_path.exists() {
let runner_data = fs::read_to_string(data_path)?;
Ok(toml::from_str(&runner_data)?)
} else {
let client = reqwest::Client::new();
let to_send = json!({
"token": config.token,
"type": 1,
});
let register_resp = client
.post(format!("{}/tasks/runner/register", config.hostname))
.header("token", &config.token)
.body(to_send.to_string())
.send()
.await?;
if register_resp.status() != 200 {
bail!(format!("Could not create runner {:#?}", register_resp));
}
let runner_data: RunnerData = register_resp.json().await?;
fs::write(data_path, toml::to_string(&runner_data)?)
.expect("Faield to save data for runner");
Ok(runner_data)
}
}

90
runner/src/tasks.rs Normal file
View File

@ -0,0 +1,90 @@
use std::sync::Arc;
use anyhow::{bail, Result};
use serde::Deserialize;
use serde_json::json;
use serde_repr::Deserialize_repr;
use crate::{ConfigFile, RunnerData};
#[derive(Clone, Copy, Deserialize_repr, Debug)]
#[repr(i8)]
pub enum TaskStatus {
FailedRunning = -2,
FailedCreation = -1,
Preparing = 0,
Todo = 1,
PickedUp = 2,
Running = 3,
Done = 4,
}
#[derive(Clone, Copy, Deserialize_repr, Debug)]
#[repr(i8)]
pub enum TaskType {
Classification = 1,
Training = 2,
Retraining = 3,
DeleteUser = 4,
}
#[derive(Deserialize, Debug)]
pub struct Task {
pub id: String,
pub user_id: String,
pub model_id: String,
pub status: TaskStatus,
pub status_message: String,
pub user_confirmed: i8,
pub compacted: i8,
#[serde(alias = "type")]
pub task_type: TaskType,
pub extra_task_info: String,
pub result: String,
pub created: String,
}
pub async fn fail_task(
task: &Task,
config: Arc<ConfigFile>,
runner_data: Arc<RunnerData>,
reason: &str,
) -> Result<()> {
println!("Marking Task as failed");
let client = reqwest::Client::new();
let to_send = json!({
"id": runner_data.id,
"taskId": task.id,
"reason": reason
});
let resp = client
.post(format!("{}/tasks/runner/fail", config.hostname))
.header("token", &config.token)
.body(to_send.to_string())
.send()
.await?;
if resp.status() != 200 {
println!("Failed to update status of task");
bail!("Failed to update status of task");
}
Ok(())
}
impl Task {
pub async fn fail(
self: &mut Task,
config: Arc<ConfigFile>,
runner_data: Arc<RunnerData>,
reason: &str,
) -> Result<()> {
fail_task(self, config, runner_data, reason).await?;
self.status = TaskStatus::FailedRunning;
self.status_message = reason.to_string();
Ok(())
}
}

599
runner/src/training.rs Normal file
View File

@ -0,0 +1,599 @@
use crate::{
dataloader::DataLoader,
model::{self, build_model},
settings::{ConfigFile, RunnerData},
tasks::{fail_task, Task},
types::{DataPointRequest, Definition, ModelClass},
};
use std::{
io::{self, Write},
sync::Arc,
};
use anyhow::Result;
use rand::{seq::SliceRandom, thread_rng};
use serde_json::json;
use tch::{
nn::{self, Module, OptimizerConfig},
Cuda, Tensor,
};
pub async fn handle_train(
task: &Task,
config: Arc<ConfigFile>,
runner_data: Arc<RunnerData>,
) -> Result<()> {
let client = reqwest::Client::new();
println!("Preparing to train a model");
let to_send = json!({
"id": runner_data.id,
"taskId": task.id,
});
let mut defs: Vec<Definition> = client
.post(format!("{}/tasks/runner/train/defs", config.hostname))
.header("token", &config.token)
.body(to_send.to_string())
.send()
.await?
.json()
.await?;
if defs.len() == 0 {
println!("No defs found");
fail_task(task, config, runner_data, "No definitions found").await?;
return Ok(());
}
let classes: Vec<ModelClass> = client
.post(format!("{}/tasks/runner/train/classes", config.hostname))
.header("token", &config.token)
.body(to_send.to_string())
.send()
.await?
.json()
.await?;
let data: DataPointRequest = client
.post(format!("{}/tasks/runner/train/datapoints", config.hostname))
.header("token", &config.token)
.body(to_send.to_string())
.send()
.await?
.json()
.await?;
let mut testing = data.testing;
testing.shuffle(&mut thread_rng());
let mut data_loader = DataLoader::new(config.clone(), testing, classes.len() as i64, 64);
// TODO make this a vec
let mut model: Option<model::Model> = None;
loop {
let config = config.clone();
let runner_data = runner_data.clone();
let mut to_remove: Vec<usize> = Vec::new();
let mut def_iter = defs.iter_mut();
let mut i: usize = 0;
while let Some(def) = def_iter.next() {
def.updateStatus(
task,
config.clone(),
runner_data.clone(),
crate::types::DefinitionStatus::Training,
)
.await?;
let model_err = train_definition(
def,
&mut data_loader,
model,
config.clone(),
runner_data.clone(),
&task,
)
.await;
if model_err.is_err() {
println!("Failed to create model {:?}", model_err);
model = None;
to_remove.push(i);
continue;
}
model = model_err?;
i += 1;
}
defs = defs
.into_iter()
.enumerate()
.filter(|&(i, _)| to_remove.iter().any(|b| *b == i))
.map(|(_, e)| e)
.collect();
break;
}
fail_task(task, config, runner_data, "TODO").await?;
Ok(())
/*
for {
// Keep track of definitions that did not train fast enough
var toRemove ToRemoveList = []int{}
for i, def := range definitions {
accuracy, ml_model, err := trainDefinition(c, model, def, models[def.Id], classes)
if err != nil {
log.Error("Failed to train definition!Err:", "err", err)
def.UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING)
toRemove = append(toRemove, i)
continue
}
models[def.Id] = ml_model
if accuracy >= float64(def.TargetAccuracy) {
log.Info("Found a definition that reaches target_accuracy!")
_, err = db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, DEFINITION_STATUS_TRANIED, def.Epoch, def.Id)
if err != nil {
log.Error("Failed to train definition!Err:\n", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return err
}
_, err = db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", DEFINITION_STATUS_CANCELD_TRAINING, def.Id, model.Id, DEFINITION_STATUS_FAILED_TRAINING)
if err != nil {
log.Error("Failed to train definition!Err:\n", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return err
}
finished = true
break
}
if def.Epoch > MAX_EPOCH {
fmt.Printf("Failed to train definition! Accuracy less %f < %d\n", accuracy, def.TargetAccuracy)
def.UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING)
toRemove = append(toRemove, i)
continue
}
_, err = db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.Epoch, DEFINITION_STATUS_PAUSED_TRAINING, def.Id)
if err != nil {
log.Error("Failed to train definition!Err:\n", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return err
}
}
if finished {
break
}
sort.Sort(sort.Reverse(toRemove))
log.Info("Round done", "toRemove", toRemove)
for _, n := range toRemove {
// Clean up unsed models
models[definitions[n].Id] = nil
definitions = remove(definitions, n)
}
len_def := len(definitions)
if len_def == 0 {
break
}
if len_def == 1 {
continue
}
sort.Sort(sort.Reverse(definitions))
acc := definitions[0].Accuracy - 20.0
log.Info("Training models, Highest acc", "acc", definitions[0].Accuracy, "mod_acc", acc)
toRemove = []int{}
for i, def := range definitions {
if def.Accuracy < acc {
toRemove = append(toRemove, i)
}
}
log.Info("Removing due to accuracy", "toRemove", toRemove)
sort.Sort(sort.Reverse(toRemove))
for _, n := range toRemove {
log.Warn("Removing definition not fast enough learning", "n", n)
definitions[n].UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING)
models[definitions[n].Id] = nil
definitions = remove(definitions, n)
}
}
var def Definition
err = GetDBOnce(c, &def, "model_definition as md where md.model_id=$1 and md.status=$2 order by md.accuracy desc limit 1;", model.Id, DEFINITION_STATUS_TRANIED)
if err != nil {
if err == NotFoundError {
log.Error("All definitions failed to train!")
} else {
log.Error("DB: failed to read definition", "err", err)
}
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
if err = def.UpdateStatus(c, DEFINITION_STATUS_READY); err != nil {
log.Error("Failed to update model definition", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
to_delete, err := db.Query("select id from model_definition where status != $1 and model_id=$2", DEFINITION_STATUS_READY, model.Id)
if err != nil {
log.Error("Failed to select model_definition to delete")
log.Error(err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
defer to_delete.Close()
for to_delete.Next() {
var id string
if err = to_delete.Scan(&id); err != nil {
log.Error("Failed to scan the id of a model_definition to delete", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
os.RemoveAll(path.Join("savedData", model.Id, "defs", id))
}
// TODO Check if returning also works here
if _, err = db.Exec("delete from model_definition where status!=$1 and model_id=$2;", DEFINITION_STATUS_READY, model.Id); err != nil {
log.Error("Failed to delete model_definition")
log.Error(err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
ModelUpdateStatus(c, model.Id, READY)
return
*/
}
async fn train_definition(
def: &Definition,
data_loader: &mut DataLoader,
model: Option<model::Model>,
config: Arc<ConfigFile>,
runner_data: Arc<RunnerData>,
task: &Task,
) -> Result<Option<model::Model>> {
let client = reqwest::Client::new();
println!("About to start training definition");
let mut accuracy = 0;
let model = model.unwrap_or({
let layers: Vec<model::Layer> = client
.post(format!("{}/tasks/runner/train/def/layers", config.hostname))
.header("token", &config.token)
.body(
json!({
"id": runner_data.id,
"taskId": task.id,
"defId": def.id,
})
.to_string(),
)
.send()
.await?
.json()
.await?;
build_model(layers, 0, true)
});
// TODO CUDA
// get device
// Move model to cuda
let mut opt = nn::Adam::default().build(&model.vs, 1e-3)?;
let mut last_acc = 0.0;
for epoch in 1..40 {
data_loader.restart();
let mut mean_loss: f64 = 0.0;
let mut mean_acc: f64 = 0.0;
while let Some((inputs, labels)) = data_loader.next() {
let inputs = inputs
.to_kind(tch::Kind::Float)
.to_device(tch::Device::Cuda(0));
let labels = labels
.to_kind(tch::Kind::Float)
.to_device(tch::Device::Cuda(0));
let out = model.seq.forward(&inputs);
let weight: Option<Tensor> = None;
let loss = out.binary_cross_entropy(&labels, weight, tch::Reduction::Mean);
opt.backward_step(&loss);
mean_loss += loss
.to_device(tch::Device::Cpu)
.unsqueeze(0)
.double_value(&[0]);
let out = out.to_device(tch::Device::Cpu);
let test = out.empty_like();
_ = out.clone(&test);
let out = test.argmax(1, true);
let mut labels = labels.to_device(tch::Device::Cpu);
labels = labels.unsqueeze(-1);
let size = out.size()[0];
let mut acc = 0;
for i in 0..size {
let res = out.double_value(&[i]);
let exp = labels.double_value(&[i, res as i64]);
if exp == 1.0 {
acc += 1;
}
}
mean_acc += acc as f64 / size as f64;
last_acc = acc as f64 / size as f64;
}
print!(
"\repoch: {} loss: {} acc: {} l acc: {} ",
epoch,
mean_loss / data_loader.len as f64,
mean_acc / data_loader.len as f64,
last_acc
);
io::stdout().flush().expect("Unable to flush stdout");
}
println!("\nlast acc: {}", last_acc);
return Ok(Some(model));
/*
opt, err := my_nn.DefaultAdamConfig().Build(model.Vs, 0.001)
if err != nil {
return
}
for epoch := 0; epoch < EPOCH_PER_RUN; epoch++ {
var trainIter *torch.Iter2
trainIter, err = ds.TrainIter(32)
if err != nil {
return
}
// trainIter.ToDevice(device)
log.Info("epoch", "epoch", epoch)
var trainLoss float64 = 0
var trainCorrect float64 = 0
ok := true
for ok {
var item torch.Iter2Item
var loss *torch.Tensor
item, ok = trainIter.Next()
if !ok {
continue
}
data := item.Data
data, err = data.ToDevice(device, gotch.Float, false, true, false)
if err != nil {
return
}
var size []int64
size, err = data.Size()
if err != nil {
return
}
var zeros *torch.Tensor
zeros, err = torch.Zeros(size, gotch.Float, device)
if err != nil {
return
}
data, err = zeros.Add(data, true)
if err != nil {
return
}
log.Info("\n\nhere 1, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
data, err = data.SetRequiresGrad(true, false)
if err != nil {
return
}
log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
err = data.RetainGrad(false)
if err != nil {
return
}
log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
pred := model.ForwardT(data, true)
pred, err = pred.SetRequiresGrad(true, true)
if err != nil {
return
}
err = pred.RetainGrad(false)
if err != nil {
return
}
label := item.Label
label, err = label.ToDevice(device, gotch.Float, false, true, false)
if err != nil {
return
}
label, err = label.SetRequiresGrad(true, true)
if err != nil {
return
}
err = label.RetainGrad(false)
if err != nil {
return
}
// Calculate loss
loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 2, false)
if err != nil {
return
}
loss, err = loss.SetRequiresGrad(true, false)
if err != nil {
return
}
err = loss.RetainGrad(false)
if err != nil {
return
}
err = opt.ZeroGrad()
if err != nil {
return
}
err = loss.Backward()
if err != nil {
return
}
log.Info("pred grad", "pred", pred.MustGrad(false).MustMax(false).Float64Values())
log.Info("pred grad", "outs", label.MustGrad(false).MustMax(false).Float64Values())
log.Info("pred grad", "data", data.MustGrad(false).MustMax(false).Float64Values(), "lol", data.MustRetainsGrad(false))
vars := model.Vs.Variables()
for k, v := range vars {
log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(false).Float64Values(), "lol", v.MustRetainsGrad(false))
}
model.Debug()
err = opt.Step()
if err != nil {
return
}
trainLoss = loss.Float64Values()[0]
// Calculate accuracy
/ *var p_pred, p_labels *torch.Tensor
p_pred, err = pred.Argmax([]int64{1}, true, false)
if err != nil {
return
}
p_labels, err = item.Label.Argmax([]int64{1}, true, false)
if err != nil {
return
}
floats := p_pred.Float64Values()
floats_labels := p_labels.Float64Values()
for i := range floats {
if floats[i] == floats_labels[i] {
trainCorrect += 1
}
} * /
panic("fornow")
}
//v := []float64{}
log.Info("model training epoch done loss", "loss", trainLoss, "correct", trainCorrect, "out", ds.TrainImagesSize, "accuracy", trainCorrect/float64(ds.TrainImagesSize))
/ *correct := int64(0)
//torch.NoGrad(func() {
ok = true
testIter := ds.TestIter(64)
for ok {
var item torch.Iter2Item
item, ok = testIter.Next()
if !ok {
continue
}
output := model.Forward(item.Data)
var pred, labels *torch.Tensor
pred, err = output.Argmax([]int64{1}, true, false)
if err != nil {
return
}
labels, err = item.Label.Argmax([]int64{1}, true, false)
if err != nil {
return
}
floats := pred.Float64Values()
floats_labels := labels.Float64Values()
for i := range floats {
if floats[i] == floats_labels[i] {
correct += 1
}
}
}
accuracy = float64(correct) / float64(ds.TestImagesSize)
log.Info("Eval accuracy", "accuracy", accuracy)
err = def.UpdateAfterEpoch(db, accuracy*100)
if err != nil {
return
}* /
//})
}
result_path := path.Join(getDir(), "savedData", m.Id, "defs", def.Id)
err = os.MkdirAll(result_path, os.ModePerm)
if err != nil {
return
}
err = my_torch.SaveModel(model, path.Join(result_path, "model.dat"))
if err != nil {
return
}
log.Info("Model finished training!", "accuracy", accuracy)
return
*/
}

89
runner/src/types.rs Normal file
View File

@ -0,0 +1,89 @@
use crate::{model, tasks::Task, ConfigFile, RunnerData};
use anyhow::{bail, Result};
use serde::Deserialize;
use serde_json::json;
use serde_repr::{Deserialize_repr, Serialize_repr};
use std::sync::Arc;
#[derive(Clone, Copy, Deserialize_repr, Serialize_repr, Debug)]
#[repr(i8)]
pub enum DefinitionStatus {
CanceldTraining = -4,
FailedTraining = -3,
PreInit = 1,
Init = 2,
Training = 3,
PausedTraining = 6,
Tranied = 4,
Ready = 5,
}
#[derive(Deserialize, Debug)]
pub struct Definition {
pub id: String,
pub model_id: String,
pub accuracy: f64,
pub target_accuracy: i64,
pub epoch: i64,
pub status: i64,
pub created: String,
pub epoch_progress: i64,
}
impl Definition {
pub async fn updateStatus(
self: &mut Definition,
task: &Task,
config: Arc<ConfigFile>,
runner_data: Arc<RunnerData>,
status: DefinitionStatus,
) -> Result<()> {
println!("Marking Task as faield");
let client = reqwest::Client::new();
let to_send = json!({
"id": runner_data.id,
"taskId": task.id,
"defId": self.id,
"status": status,
});
let resp = client
.post(format!("{}/tasks/runner/train/def/status", config.hostname))
.header("token", &config.token)
.body(to_send.to_string())
.send()
.await?;
if resp.status() != 200 {
println!("Failed to update status of task");
bail!("Failed to update status of task");
}
Ok(())
}
}
#[derive(Clone, Copy, Deserialize_repr, Debug)]
#[repr(i8)]
pub enum ModelClassStatus {
ToTrain = 1,
Training = 2,
Trained = 3,
}
#[derive(Deserialize, Debug)]
pub struct ModelClass {
pub id: String,
pub model_id: String,
pub name: String,
pub class_order: i64,
pub status: ModelClassStatus,
}
#[derive(Deserialize, Debug)]
pub struct DataPointRequest {
pub testing: Vec<model::DataPoint>,
pub training: Vec<model::DataPoint>,
}

View File

@ -1 +1 @@
CREATE DATABASE fyp; CREATE DATABASE aistuff;

View File

@ -35,7 +35,7 @@ create table if not exists model_classes (
-- 1: to_train -- 1: to_train
-- 2: training -- 2: training
-- 3: trained -- 3: trained
status integer default 1 status integer default 1,
); );
-- drop table if exists model_data_point; -- drop table if exists model_data_point;
@ -59,6 +59,7 @@ create table if not exists model_definition (
accuracy real default 0, accuracy real default 0,
target_accuracy integer not null, target_accuracy integer not null,
epoch integer default 0, epoch integer default 0,
-- TODO add max epoch
-- 1: Pre Init -- 1: Pre Init
-- 2: Init -- 2: Init
-- 3: Training -- 3: Training
@ -77,7 +78,7 @@ create table if not exists model_definition_layer (
-- 1: input -- 1: input
-- 2: dense -- 2: dense
-- 3: flatten -- 3: flatten
-- 4: block -- TODO add conv
layer_type integer not null, layer_type integer not null,
-- ei 28,28,1 -- ei 28,28,1
-- a 28x28 grayscale image -- a 28x28 grayscale image
@ -101,6 +102,7 @@ create table if not exists exp_model_head (
accuracy real default 0, accuracy real default 0,
-- TODO add max epoch
-- 1: Pre Init -- 1: Pre Init
-- 2: Init -- 2: Init
-- 3: Training -- 3: Training

View File

@ -143,15 +143,6 @@ def addBlock(
model.add(layers.Dropout(0.4)) model.add(layers.Dropout(0.4))
return model return model
def resblock(x, kernelsize = 3, filters = 128):
fx = layers.Conv2D(filters, kernelsize, activation='relu', padding='same')(x)
fx = layers.BatchNormalization()(fx)
fx = layers.Conv2D(filters, kernelsize, padding='same')(fx)
out = layers.Add()([x,fx])
out = layers.ReLU()(out)
out = layers.BatchNormalization()(out)
return out
{{ if .LoadPrev }} {{ if .LoadPrev }}
model = tf.keras.saving.load_model('{{.LastModelRunPath}}') model = tf.keras.saving.load_model('{{.LastModelRunPath}}')

View File

@ -1 +0,0 @@
.gitignore

View File

@ -27,11 +27,5 @@ module.exports = {
parser: '@typescript-eslint/parser' parser: '@typescript-eslint/parser'
} }
} }
], ]
rules: {
'svelte/no-at-html-tags': 'off',
// TODO remove this
'@typescript-eslint/no-explicit-any': 'off'
}
}; };

View File

@ -1,9 +0,0 @@
FROM docker.io/node:22
ADD . .
RUN npm install
RUN npm run build
CMD ["npm", "run", "preview"]

Binary file not shown.

4125
webpage/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@ -6,7 +6,7 @@
"dev:raw": "vite dev", "dev:raw": "vite dev",
"dev": "vite dev --port 5001 --host", "dev": "vite dev --port 5001 --host",
"build": "vite build", "build": "vite build",
"preview": "vite preview --port 5001 --host", "preview": "vite preview",
"check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json", "check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json",
"check:watch": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json --watch", "check:watch": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json --watch",
"lint": "prettier --check . && eslint .", "lint": "prettier --check . && eslint .",
@ -15,8 +15,7 @@
"devDependencies": { "devDependencies": {
"@sveltejs/adapter-auto": "^3.2.0", "@sveltejs/adapter-auto": "^3.2.0",
"@sveltejs/kit": "^2.5.6", "@sveltejs/kit": "^2.5.6",
"@sveltejs/vite-plugin-svelte": "^3.0.0", "@sveltejs/vite-plugin-svelte": "3.0.0",
"@types/d3": "^7.4.3",
"@types/eslint": "^8.56.9", "@types/eslint": "^8.56.9",
"@typescript-eslint/eslint-plugin": "^7.7.0", "@typescript-eslint/eslint-plugin": "^7.7.0",
"@typescript-eslint/parser": "^7.7.0", "@typescript-eslint/parser": "^7.7.0",
@ -26,7 +25,7 @@
"prettier": "^3.2.5", "prettier": "^3.2.5",
"prettier-plugin-svelte": "^3.2.3", "prettier-plugin-svelte": "^3.2.3",
"sass": "^1.75.0", "sass": "^1.75.0",
"svelte": "^5.0.0-next.104", "svelte": "5.0.0-next.104",
"svelte-check": "^3.6.9", "svelte-check": "^3.6.9",
"tslib": "^2.6.2", "tslib": "^2.6.2",
"typescript": "^5.4.5", "typescript": "^5.4.5",
@ -34,8 +33,6 @@
}, },
"type": "module", "type": "module",
"dependencies": { "dependencies": {
"chart.js": "^4.4.2", "chart.js": "^4.4.2"
"d3": "^7.9.0",
"highlight.js": "^11.9.0"
} }
} }

View File

@ -15,18 +15,8 @@
{/if} {/if}
<li class="expand"></li> <li class="expand"></li>
{#if userStore.user} {#if userStore.user}
{#if userStore.user.user_type == 2}
<li>
<a href="/admin/runners">
<span class="bi bi-cpu-fill"></span>
Runner
</a>
</li>
{/if}
<li> <li>
<a href="/user/info"> <span class="bi bi-person-fill"></span> {userStore.user.username} </a> <a href="/user/info"> <span class="bi bi-person-fill"></span> {userStore.user.username} </a>
</li>
<li>
<a href="/logout"> <span class="bi bi-box-arrow-right"></span> Logout </a> <a href="/logout"> <span class="bi bi-box-arrow-right"></span> Logout </a>
</li> </li>
{:else} {:else}

View File

@ -9,8 +9,6 @@
<link rel="preconnect" href="https://fonts.googleapis.com" /> <link rel="preconnect" href="https://fonts.googleapis.com" />
<link rel="preconnect" href="https://fonts.googleapis.com" /> <link rel="preconnect" href="https://fonts.googleapis.com" />
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin /> <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin />
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/github.min.css">
<link href="https://fonts.googleapis.com/css2?family=Andada+Pro:ital,wght@0,400..840;1,400..840&family=Bebas+Neue&family=Fira+Code:wght@300..700&display=swap" rel="stylesheet">
<link <link
href="https://fonts.googleapis.com/css2?family=Andada+Pro:ital,wght@0,400..840;1,400..840&family=Bebas+Neue&display=swap" href="https://fonts.googleapis.com/css2?family=Andada+Pro:ital,wght@0,400..840;1,400..840&family=Bebas+Neue&display=swap"
rel="stylesheet" rel="stylesheet"

View File

@ -36,7 +36,7 @@
class="icon" class="icon"
class:adapt={replace_slot && file && !notExpand} class:adapt={replace_slot && file && !notExpand}
type="button" type="button"
onclick={() => fileInput.click()} on:click={() => fileInput.click()}
> >
{#if replace_slot && file} {#if replace_slot && file}
<slot name="replaced" {file}> <slot name="replaced" {file}>
@ -54,6 +54,6 @@
required required
{accept} {accept}
bind:this={fileInput} bind:this={fileInput}
onchange={onChange} on:change={onChange}
/> />
</div> </div>

View File

@ -0,0 +1,57 @@
<script context="module" lang="ts">
export type DisplayFn = (
msg: string,
options?: {
type?: 'error' | 'success';
timeToShow?: number;
}
) => void;
</script>
<script lang="ts">
let message = $state<string | undefined>(undefined);
let type = $state<'error' | 'success'>('error');
let timeout: number | undefined = undefined;
export function clear() {
if (timeout) clearTimeout(timeout);
message = undefined;
}
export function display(
msg: string,
options?: {
type?: 'error' | 'success';
timeToShow?: number;
}
) {
if (timeout) clearTimeout(timeout);
if (!msg) {
message = undefined;
return;
}
let { type: l_type, timeToShow } = options ?? { type: 'error', timeToShow: undefined };
if (l_type) {
type = l_type;
}
message = msg;
if (timeToShow) {
timeout = setTimeout(() => {
message = undefined;
timeout = undefined;
}, timeToShow);
}
}
</script>
{#if message}
<div class="form-msg {type}">
{message}
</div>
{/if}

View File

@ -1,5 +1,5 @@
<script lang="ts"> <script lang="ts">
let { title }: { title: string } = $props(); let { title } = $props<{ title: string }>();
let isHovered = $state(false); let isHovered = $state(false);
let x = $state(0); let x = $state(0);
@ -30,10 +30,10 @@
<div <div
bind:this={div} bind:this={div}
onmouseover={mouseOver} on:mouseover={mouseOver}
onmouseleave={mouseLeave} on:mouseleave={mouseLeave}
onmousemove={mouseMove} on:mousemove={mouseMove}
onfocus={focus} on:focus={focus}
role="tooltip" role="tooltip"
class="tooltipContainer" class="tooltipContainer"
> >

View File

@ -1,6 +0,0 @@
export function preventDefault(fn: any) {
return function (event: Event) {
event.preventDefault();
fn.call(this, event);
};
}

View File

@ -1,7 +1,7 @@
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { rdelete } from '$lib/requests.svelte'; import { rdelete } from '$lib/requests.svelte';
export type User = { type User = {
token: string; token: string;
id: string; id: string;
user_type: number; user_type: number;

View File

@ -1,350 +0,0 @@
<script lang="ts">
import { goto } from '$app/navigation';
import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import { post, showMessage } from 'src/lib/requests.svelte';
import { userStore } from 'src/routes/UserStore.svelte';
import { onMount } from 'svelte';
import * as d3 from 'd3';
import type { Base } from './types';
import CardInfo from './CardInfo.svelte';
let width = $state(0);
let height = $state(0);
function drag(simulation: d3.Simulation<d3.HierarchyNode<Base>, undefined>) {
function dragstarted(event: any, d: any) {
if (!event.active) simulation.alphaTarget(0.3).restart();
d.fx = d.x;
d.fy = d.y;
selected = d.data;
}
function dragged(event: any, d: any) {
d.fx = event.x;
d.fy = event.y;
}
function dragended(event: any, d: any) {
if (!event.active) simulation.alphaTarget(0);
d.fx = null;
d.fy = null;
}
return d3.drag().on('start', dragstarted).on('drag', dragged).on('end', dragended);
}
let graph: HTMLDivElement;
let selected: Base | undefined = $state();
async function getData() {
const dataObj: Base = {
name: 'API',
type: 'api',
children: []
};
if (!dataObj.children) throw new Error();
const localRunners: Base[] = [];
const remotePairs: Record<string, Base[]> = {};
try {
let data = await post('tasks/runner/info', {});
if (Object.keys(data.localRunners).length > 0) {
for (const objId of Object.keys(data.localRunners)) {
localRunners.push({ name: objId, type: 'local_runner', task: data.localRunners[objId] });
}
dataObj.children.push({
name: 'local runners',
type: 'runner_group',
children: localRunners
});
}
if (Object.keys(data.remoteRunners).length > 0) {
for (const objId of Object.keys(data.remoteRunners)) {
let obj = data.remoteRunners[objId];
if (remotePairs[obj.runner_info.user_id as string]) {
remotePairs[obj.runner_info.user_id as string].push({
name: objId,
type: 'runner',
task: obj.task,
parent: data.remoteRunners[objId].runner_info.user_id
});
} else {
remotePairs[data.remoteRunners[objId].runner_info.user_id] = [
{
name: objId,
type: 'runner',
task: obj.task,
parent: data.remoteRunners[objId].runner_info.user_id
}
];
}
}
}
dataObj.children.push({
name: 'remote runners',
type: 'runner_group',
task: undefined,
children: Object.keys(remotePairs).map(
(name) =>
({
name,
type: 'user_group',
task: undefined,
children: remotePairs[name]
}) as Base
)
});
} catch (ex) {
showMessage(ex, notificationStore, 'Failed to get Runner information');
return;
}
const root = d3.hierarchy(dataObj);
const links = root.links();
const nodes = root.descendants();
console.log(root, links, nodes);
const simulation = d3
.forceSimulation(nodes)
.force(
'link',
d3
.forceLink(links)
.id((d: any) => d.id)
.distance((d: any) => {
let data = d.source.data as Base;
switch (data.type) {
case 'api':
return 150;
case 'runner_group':
return 90;
case 'user_group':
return 80;
case 'runner':
case 'local_runner':
return 20;
default:
throw new Error();
}
})
.strength(1)
)
.force('charge', d3.forceManyBody().strength(-1000))
.force('x', d3.forceX())
.force('y', d3.forceY());
const svg = d3
.create('svg')
.attr('width', width)
.attr('height', height - 62)
.attr('viewBox', [-width / 2, -height / 2, width, height])
.attr('style', 'max-width: 100%; height: auto;');
// Append links.
const link = svg
.append('g')
.attr('stroke', '#999')
.attr('stroke-opacity', 0.6)
.selectAll('line')
.data(links)
.join('line');
const database_svg = `
<svg xmlns="http://www.w3.org/2000/svg" stroke-width="0.2" width="32" height="32" fill="currentColor" class="bi bi-database" viewBox="0 0 32 32">
<path transform="scale(2)" d="M4.318 2.687C5.234 2.271 6.536 2 8 2s2.766.27 3.682.687C12.644 3.125 13 3.627 13 4c0 .374-.356.875-1.318 1.313C10.766 5.729 9.464 6 8 6s-2.766-.27-3.682-.687C3.356 4.875 3 4.373 3 4c0-.374.356-.875 1.318-1.313M13 5.698V7c0 .374-.356.875-1.318 1.313C10.766 8.729 9.464 9 8 9s-2.766-.27-3.682-.687C3.356 7.875 3 7.373 3 7V5.698c.271.202.58.378.904.525C4.978 6.711 6.427 7 8 7s3.022-.289 4.096-.777A5 5 0 0 0 13 5.698M14 4c0-1.007-.875-1.755-1.904-2.223C11.022 1.289 9.573 1 8 1s-3.022.289-4.096.777C2.875 2.245 2 2.993 2 4v9c0 1.007.875 1.755 1.904 2.223C4.978 15.71 6.427 16 8 16s3.022-.289 4.096-.777C13.125 14.755 14 14.007 14 13zm-1 4.698V10c0 .374-.356.875-1.318 1.313C10.766 11.729 9.464 12 8 12s-2.766-.27-3.682-.687C3.356 10.875 3 10.373 3 10V8.698c.271.202.58.378.904.525C4.978 9.71 6.427 10 8 10s3.022-.289 4.096-.777A5 5 0 0 0 13 8.698m0 3V13c0 .374-.356.875-1.318 1.313C10.766 14.729 9.464 15 8 15s-2.766-.27-3.682-.687C3.356 13.875 3 13.373 3 13v-1.302c.271.202.58.378.904.525C4.978 12.71 6.427 13 8 13s3.022-.289 4.096-.777c.324-.147.633-.323.904-.525"/>
</svg>
`;
const cpu_svg = `
<svg stroke="white" fill="white" xmlns="http://www.w3.org/2000/svg" stroke-width="0.2" width="32" height="32" fill="currentColor" class="bi bi-cpu-fill" viewBox="0 0 32 32">
<path transform="scale(2)" d="M6.5 6a.5.5 0 0 0-.5.5v3a.5.5 0 0 0 .5.5h3a.5.5 0 0 0 .5-.5v-3a.5.5 0 0 0-.5-.5z"/>
<path transform="scale(2)" d="M5.5.5a.5.5 0 0 0-1 0V2A2.5 2.5 0 0 0 2 4.5H.5a.5.5 0 0 0 0 1H2v1H.5a.5.5 0 0 0 0 1H2v1H.5a.5.5 0 0 0 0 1H2v1H.5a.5.5 0 0 0 0 1H2A2.5 2.5 0 0 0 4.5 14v1.5a.5.5 0 0 0 1 0V14h1v1.5a.5.5 0 0 0 1 0V14h1v1.5a.5.5 0 0 0 1 0V14h1v1.5a.5.5 0 0 0 1 0V14a2.5 2.5 0 0 0 2.5-2.5h1.5a.5.5 0 0 0 0-1H14v-1h1.5a.5.5 0 0 0 0-1H14v-1h1.5a.5.5 0 0 0 0-1H14v-1h1.5a.5.5 0 0 0 0-1H14A2.5 2.5 0 0 0 11.5 2V.5a.5.5 0 0 0-1 0V2h-1V.5a.5.5 0 0 0-1 0V2h-1V.5a.5.5 0 0 0-1 0V2h-1zm1 4.5h3A1.5 1.5 0 0 1 11 6.5v3A1.5 1.5 0 0 1 9.5 11h-3A1.5 1.5 0 0 1 5 9.5v-3A1.5 1.5 0 0 1 6.5 5"/>
</svg>
`;
const user_svg = `
<svg fill="white" stroke="white" xmlns="http://www.w3.org/2000/svg" stroke-width="0.2" width="32" height="32" fill="currentColor" class="bi bi-person-fill" viewBox="0 0 32 32">
<path transform="scale(2)" d="M3 14s-1 0-1-1 1-4 6-4 6 3 6 4-1 1-1 1zm5-6a3 3 0 1 0 0-6 3 3 0 0 0 0 6"/>
</svg>
`;
const inbox_fill = `
<svg fill="white" stroke="white" xmlns="http://www.w3.org/2000/svg" stroke-width="0.2" width="32" height="32" fill="currentColor" class="bi bi-inbox-fill" viewBox="0 0 32 32">
<path transform="scale(2)" d="M4.98 4a.5.5 0 0 0-.39.188L1.54 8H6a.5.5 0 0 1 .5.5 1.5 1.5 0 1 0 3 0A.5.5 0 0 1 10 8h4.46l-3.05-3.812A.5.5 0 0 0 11.02 4zm-1.17-.437A1.5 1.5 0 0 1 4.98 3h6.04a1.5 1.5 0 0 1 1.17.563l3.7 4.625a.5.5 0 0 1 .106.374l-.39 3.124A1.5 1.5 0 0 1 14.117 13H1.883a1.5 1.5 0 0 1-1.489-1.314l-.39-3.124a.5.5 0 0 1 .106-.374z"/>
</svg>
`;
const node = svg
.append('g')
.attr('fill', '#fff')
.attr('stroke', '#000')
.attr('stroke-width', 1.5)
.selectAll('g')
.data(nodes)
.join('g')
.attr('style', 'cursor: pointer;')
.call(drag(simulation) as any)
.on('click', (e) => {
console.log('test');
function findData(obj: HTMLElement) {
if ((obj as any).__data__) {
return (obj as any).__data__;
}
if (!obj.parentElement) {
throw new Error();
}
return findData(obj.parentElement);
}
let obj = findData(e.srcElement);
console.log(obj);
selected = obj.data;
});
node
.append('circle')
.attr('fill', (d: any) => {
let data = d.data as Base;
switch (data.type) {
case 'api':
return '#caf0f8';
case 'runner_group':
return '#00b4d8';
case 'user_group':
return '#0000ff';
case 'runner':
case 'local_runner':
return '#03045e';
default:
throw new Error();
}
})
.attr('stroke', (d: any) => {
let data = d.data as Base;
switch (data.type) {
case 'api':
case 'user_group':
case 'runner_group':
return '#fff';
case 'runner':
case 'local_runner':
// TODO make this relient on the stauts
return '#000';
default:
throw new Error();
}
})
.attr('r', (d: any) => {
let data = d.data as Base;
switch (data.type) {
case 'api':
return 30;
case 'runner_group':
return 20;
case 'user_group':
return 25;
case 'runner':
case 'local_runner':
return 30;
default:
throw new Error();
}
})
.append('title')
.text((d: any) => d.data.name);
node
.filter((d) => {
return ['api', 'local_runner', 'runner', 'user_group', 'runner_group'].includes(
d.data.type
);
})
.append('g')
.html((d) => {
switch (d.data.type) {
case 'api':
return database_svg;
case 'user_group':
return user_svg;
case 'runner_group':
return inbox_fill;
case 'local_runner':
case 'runner':
return cpu_svg;
default:
throw new Error();
}
});
simulation.on('tick', () => {
link
.attr('x1', (d: any) => d.source.x)
.attr('y1', (d: any) => d.source.y)
.attr('x2', (d: any) => d.target.x)
.attr('y2', (d: any) => d.target.y);
node
.select('circle')
.attr('cx', (d: any) => d.x)
.attr('cy', (d: any) => d.y);
node
.select('svg')
.attr('x', (d: any) => d.x - 16)
.attr('y', (d: any) => d.y - 16);
});
//invalidation.then(() => simulation.stop());
graph.appendChild(svg.node() as any);
}
$effect(() => {
console.log(selected);
});
onMount(() => {
// Check if logged in and admin
if (!userStore.user || userStore.user.user_type != 2) {
goto('/');
return;
}
getData();
});
</script>
<svelte:window bind:innerWidth={width} bind:innerHeight={height} />
<svelte:head>
<title>Runners</title>
</svelte:head>
<div class="graph-container">
<div class="graph" bind:this={graph}></div>
{#if selected}
<div class="selected">
<CardInfo item={selected} />
</div>
{/if}
</div>
<style lang="css">
.graph-container {
position: relative;
.selected {
position: absolute;
right: 40px;
top: 40px;
width: 20%;
height: auto;
padding: 20px;
background: white;
border-radius: 20px;
box-shadow: 1px 1px 8px 2px #22222244;
}
}
</style>

View File

@ -1,90 +0,0 @@
<script lang="ts">
import { post, showMessage } from 'src/lib/requests.svelte';
import type { Base } from './types';
import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import type { User } from 'src/routes/UserStore.svelte';
import Spinner from 'src/lib/Spinner.svelte';
import Tooltip from 'src/lib/Tooltip.svelte';
let { item }: { item: Base } = $props();
let user_data: User | undefined = $state();
async function getUserData(id: string) {
try {
user_data = await post('user/info/get', { id });
console.log(user_data);
} catch (ex) {
showMessage(ex, notificationStore, 'Could not get user information');
}
}
$effect(() => {
user_data = undefined;
if (item.type == 'user_group') {
getUserData(item.name);
} else if (item.type == 'runner') {
getUserData(item.parent ?? '');
}
});
</script>
{#if item.type == 'api'}
<h3>API</h3>
{:else if item.type == 'runner_group'}
<h3>Runner Group</h3>
This reprents a the group of {item.name}.
{:else if item.type == 'user_group'}
<h3>User</h3>
{#if user_data}
All Runners connected to this node bellong to <span class="accent">{user_data.username}</span>
{:else}
<div style="text-align: center;">
<Spinner />
</div>
{/if}
{:else if item.type == 'local_runner'}
<h3>Local Runner</h3>
This is a local runner
<div>
{#if item.task}
This runner is runing a <Tooltip title={item.task.id}>task</Tooltip>
{:else}
Not running any task
{/if}
</div>
{:else if item.type == 'runner'}
<h3>Runner</h3>
{#if user_data}
<p>
This is a remote runner. This runner is owned by<span class="accent"
>{user_data?.username}</span
>
</p>
<div>
{#if item.task}
This runner is runing a <Tooltip title={item.task.id}>task</Tooltip>
{:else}
Not running any task
{/if}
</div>
{:else}
<div style="text-align: center;">
<Spinner />
</div>
{/if}
{:else}
{item.type}
{/if}
<style lang="scss">
h3 {
text-align: center;
margin: 0;
}
.accent {
background: #22222222;
padding: 1px;
border-radius: 5px;
}
</style>

View File

@ -1,10 +0,0 @@
import type { Task } from 'src/routes/models/edit/tasks/types';
export type BaseType = 'api' | 'runner_group' | 'user_group' | 'runner' | 'local_runner';
export type Base = {
name: string;
type: BaseType;
children?: Base[];
task?: Task;
parent?: string;
};

View File

@ -3,7 +3,6 @@
import 'src/styles/forms.css'; import 'src/styles/forms.css';
import { userStore } from '../UserStore.svelte'; import { userStore } from '../UserStore.svelte';
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { preventDefault } from 'src/lib/utils';
let submitted = $state(false); let submitted = $state(false);
@ -40,7 +39,7 @@
<div class="login-page"> <div class="login-page">
<div> <div>
<h1>Login</h1> <h1>Login</h1>
<form onsubmit={preventDefault(onSubmit)} class:submitted> <form on:submit|preventDefault={onSubmit} class:submitted>
<fieldset> <fieldset>
<label for="email">Email</label> <label for="email">Email</label>
<input type="email" required name="email" bind:value={loginData.email} /> <input type="email" required name="email" bind:value={loginData.email} />

View File

@ -1,22 +1,26 @@
<script lang="ts"> <script lang="ts">
import MessageSimple from 'src/lib/MessageSimple.svelte';
import { onMount } from 'svelte'; import { onMount } from 'svelte';
import { get, showMessage } from '$lib/requests.svelte'; import { get } from '$lib/requests.svelte';
import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import Spinner from 'src/lib/Spinner.svelte';
let list = $state< let list = $state<
| { {
name: string; name: string;
id: string; id: string;
}[] }[]
| undefined >([]);
>(undefined);
let message: MessageSimple;
onMount(async () => { onMount(async () => {
try { try {
list = await get('models'); list = await get('models');
} catch (e) { } catch (e) {
showMessage(e, notificationStore, 'Could not request list of models'); if (e instanceof Response) {
message.display(await e.json());
} else {
message.display('Could not request list of models');
}
} }
}); });
</script> </script>
@ -26,7 +30,7 @@
</svelte:head> </svelte:head>
<main> <main>
{#if list} <MessageSimple bind:this={message} />
{#if list.length > 0} {#if list.length > 0}
<div class="list-header"> <div class="list-header">
<h2>My Models</h2> <h2>My Models</h2>
@ -61,11 +65,6 @@
<a class="button padded" href="/models/add"> Create a new model </a> <a class="button padded" href="/models/add"> Create a new model </a>
</div> </div>
{/if} {/if}
{:else}
<div style="text-align: center;">
<Spinner />
</div>
{/if}
</main> </main>
<style lang="scss"> <style lang="scss">
@ -85,4 +84,11 @@
.list-header .expand { .list-header .expand {
flex-grow: 1; flex-grow: 1;
} }
.list-header .button,
.list-header button {
padding: 10px 10px;
height: calc(100% - 20px);
margin-top: 5px;
}
</style> </style>

View File

@ -1,14 +1,15 @@
<script lang="ts"> <script lang="ts">
import FileUpload from 'src/lib/FileUpload.svelte'; import FileUpload from 'src/lib/FileUpload.svelte';
import { postFormData, showMessage } from 'src/lib/requests.svelte'; import MessageSimple from 'src/lib/MessageSimple.svelte';
import { postFormData } from 'src/lib/requests.svelte';
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import 'src/styles/forms.css'; import 'src/styles/forms.css';
import { preventDefault } from 'src/lib/utils';
let submitted = $state(false); let submitted = $state(false);
let message: MessageSimple;
let buttonClicked: Promise<void> = $state(Promise.resolve()); let buttonClicked: Promise<void> = $state(Promise.resolve());
let data = $state<{ let data = $state<{
@ -20,6 +21,7 @@
}); });
async function onSubmit() { async function onSubmit() {
message.display('');
buttonClicked = new Promise<void>(() => {}); buttonClicked = new Promise<void>(() => {});
if (!data.file || !data.name) return; if (!data.file || !data.name) return;
@ -32,7 +34,11 @@
let id = await postFormData('models/add', formData); let id = await postFormData('models/add', formData);
goto(`/models/edit?id=${id}`); goto(`/models/edit?id=${id}`);
} catch (e) { } catch (e) {
showMessage(e, notificationStore, 'Was not able to create model'); if (e instanceof Response) {
message.display(await e.json());
} else {
message.display('Was not able to create model');
}
} }
buttonClicked = Promise.resolve(); buttonClicked = Promise.resolve();
@ -45,7 +51,7 @@
<main> <main>
<h1>Create new Model</h1> <h1>Create new Model</h1>
<form class:submitted onsubmit={preventDefault(onSubmit)}> <form class:submitted on:submit|preventDefault={onSubmit}>
<fieldset> <fieldset>
<label for="name">Name</label> <label for="name">Name</label>
<input id="name" name="name" required bind:value={data.name} /> <input id="name" name="name" required bind:value={data.name} />
@ -69,6 +75,7 @@
</div> </div>
</FileUpload> </FileUpload>
</fieldset> </fieldset>
<MessageSimple bind:this={message} />
{#await buttonClicked} {#await buttonClicked}
<div class="text-center">File Uploading</div> <div class="text-center">File Uploading</div>
{:then} {:then}

View File

@ -30,19 +30,18 @@
import BaseModelInfo from './BaseModelInfo.svelte'; import BaseModelInfo from './BaseModelInfo.svelte';
import DeleteModel from './DeleteModel.svelte'; import DeleteModel from './DeleteModel.svelte';
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { get, rdelete, showMessage } from 'src/lib/requests.svelte'; import { get, rdelete } from 'src/lib/requests.svelte';
import { preventDefault } from 'src/lib/utils'; import MessageSimple from '$lib/MessageSimple.svelte';
import ModelData from './ModelData.svelte'; import ModelData from './ModelData.svelte';
import DeleteZip from './DeleteZip.svelte'; import DeleteZip from './DeleteZip.svelte';
import RunModel from './RunModel.svelte'; import RunModel from './RunModel.svelte';
import Tabs from 'src/lib/Tabs.svelte'; import Tabs from 'src/lib/Tabs.svelte';
import TasksDataPage from './TasksDataPage.svelte'; import TasksDataPage from './TasksDataPage.svelte';
import ModelDataPage from './ModelDataPage.svelte'; import ModelDataPage from './ModelDataPage.svelte';
import 'src/styles/forms.css'; import 'src/styles/forms.css';
import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import Spinner from 'src/lib/Spinner.svelte';
let model: Promise<Model> = $state(new Promise(() => {})); let model: Promise<Model> = $state(new Promise(() => {}));
let _model: Model | undefined = $state(undefined); let _model: Model | undefined = $state(undefined);
@ -93,7 +92,10 @@
getModel(); getModel();
}); });
let resetMessages: MessageSimple;
async function resetModel() { async function resetModel() {
resetMessages.display('');
let _model = await model; let _model = await model;
try { try {
@ -103,7 +105,11 @@
getModel(); getModel();
} catch (e) { } catch (e) {
showMessage(e, notificationStore, 'Could not reset model!'); if (e instanceof Response) {
resetMessages.display(await e.json());
} else {
resetMessages.display('Could not reset model!');
}
} }
} }
@ -141,8 +147,7 @@
<div slot="buttons" let:setActive let:isActive> <div slot="buttons" let:setActive let:isActive>
<button <button
class="tab" class="tab"
type="button" on:click|preventDefault={setActive('model')}
onclick={setActive('model')}
class:selected={isActive('model')} class:selected={isActive('model')}
> >
Model Model
@ -150,8 +155,7 @@
{#if _model && [2, 3, 4, 5, 6, 7, -6, -7].includes(_model.status)} {#if _model && [2, 3, 4, 5, 6, 7, -6, -7].includes(_model.status)}
<button <button
class="tab" class="tab"
type="button" on:click|preventDefault={setActive('model-data')}
onclick={setActive('model-data')}
class:selected={isActive('model-data')} class:selected={isActive('model-data')}
> >
Model Data Model Data
@ -160,8 +164,7 @@
{#if _model && [5, 6, 7, -6, -7].includes(_model.status)} {#if _model && [5, 6, 7, -6, -7].includes(_model.status)}
<button <button
class="tab" class="tab"
type="button" on:click|preventDefault={setActive('tasks')}
onclick={setActive('tasks')}
class:selected={isActive('tasks')} class:selected={isActive('tasks')}
> >
Tasks Tasks
@ -169,7 +172,7 @@
{/if} {/if}
</div> </div>
{#if _model} {#if _model}
<ModelDataPage model={_model} onreload={getModel} active={isActive('model-data')} /> <ModelDataPage model={_model} on:reload={getModel} active={isActive('model-data')} />
<TasksDataPage model={_model} active={isActive('tasks')} /> <TasksDataPage model={_model} active={isActive('tasks')} />
{/if} {/if}
<div class="content" class:selected={isActive('model')}> <div class="content" class:selected={isActive('model')}>
@ -189,6 +192,7 @@
<h1 class="text-center"> <h1 class="text-center">
{m.name} {m.name}
</h1> </h1>
<!-- TODO improve message -->
<h2 class="text-center">Failed to prepare model</h2> <h2 class="text-center">Failed to prepare model</h2>
<DeleteModel model={m} /> <DeleteModel model={m} />
@ -196,23 +200,25 @@
<!-- PRE TRAINING STATUS --> <!-- PRE TRAINING STATUS -->
{:else if m.status == 2} {:else if m.status == 2}
<BaseModelInfo model={m} /> <BaseModelInfo model={m} />
<ModelData model={m} onreload={getModel} /> <ModelData model={m} on:reload={getModel} />
<!-- {{ template "train-model-card" . }} --> <!-- {{ template "train-model-card" . }} -->
<DeleteModel model={m} /> <DeleteModel model={m} />
{:else if m.status == -2} {:else if m.status == -2}
<BaseModelInfo model={m} /> <BaseModelInfo model={m} />
<DeleteZip model={m} onreload={getModel} /> <DeleteZip model={m} on:reload={getModel} />
<DeleteModel model={m} /> <DeleteModel model={m} />
{:else if m.status == 3} {:else if m.status == 3}
<BaseModelInfo model={m} /> <BaseModelInfo model={m} />
<div class="card"> <div class="card">
Processing zip file... <Spinner /> <!-- TODO improve this -->
Processing zip file...
</div> </div>
{:else if m.status == -3 || m.status == -4} {:else if m.status == -3 || m.status == -4}
<BaseModelInfo model={m} /> <BaseModelInfo model={m} />
<form onsubmit={preventDefault(resetModel)}> <form on:submit={resetModel}>
Failed Prepare for training.<br /> Failed Prepare for training.<br />
<div class="spacer"></div> <div class="spacer"></div>
<MessageSimple bind:this={resetMessages} />
<button class="danger"> Try Again </button> <button class="danger"> Try Again </button>
</form> </form>
<DeleteModel model={m} /> <DeleteModel model={m} />
@ -331,7 +337,7 @@
<div class="card">Model expading... Processing ZIP file</div> <div class="card">Model expading... Processing ZIP file</div>
{/if} {/if}
{#if m.status == -6} {#if m.status == -6}
<DeleteZip model={m} onreload={getModel} expand /> <DeleteZip model={m} on:reload={getModel} expand />
{/if} {/if}
{#if m.status == -7} {#if m.status == -7}
<form> <form>
@ -340,7 +346,7 @@
</form> </form>
{/if} {/if}
{#if m.model_type == 2} {#if m.model_type == 2}
<ModelData simple model={m} onreload={getModel} /> <ModelData simple model={m} on:reload={getModel} />
{/if} {/if}
<DeleteModel model={m} /> <DeleteModel model={m} />
{:else} {:else}
@ -377,4 +383,10 @@
table tr th:first-child { table tr th:first-child {
border-left: none; border-left: none;
} }
table tr td button,
table tr td .button {
padding: 5px 10px;
box-shadow: 0 2px 5px 1px #66666655;
}
</style> </style>

View File

@ -1,6 +1,6 @@
<script lang="ts"> <script lang="ts">
import type { Model } from './+page.svelte'; import type { Model } from './+page.svelte';
let { model }: { model: Model } = $props(); let { model } = $props<{ model: Model }>();
</script> </script>
<div class="card model-card"> <div class="card model-card">

View File

@ -1,31 +1,39 @@
<script lang="ts"> <script lang="ts">
import MessageSimple from 'src/lib/MessageSimple.svelte';
import type { Model } from './+page.svelte'; import type { Model } from './+page.svelte';
import { rdelete, showMessage } from '$lib/requests.svelte'; import { rdelete } from '$lib/requests.svelte';
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { notificationStore } from 'src/lib/NotificationsStore.svelte';
let { model }: { model: Model } = $props(); let { model } = $props<{ model: Model }>();
let name: string = $state(''); let name: string = $state('');
let submmited: boolean = $state(false); let submmited: boolean = $state(false);
let messageSimple: MessageSimple;
async function deleteModel() { async function deleteModel() {
submmited = true; submmited = true;
messageSimple.display('');
try { try {
await rdelete('models/delete', { id: model.id, name }); await rdelete('models/delete', { id: model.id, name });
goto('/models'); goto('/models');
} catch (e) { } catch (e) {
showMessage(e, notificationStore, 'Could not delete the model'); if (e instanceof Response) {
messageSimple.display(await e.json());
} else {
messageSimple.display('Could not delete the model');
}
} }
} }
</script> </script>
<form onsubmit={deleteModel} class:submmited class="danger-bg"> <form on:submit|preventDefault={deleteModel} class:submmited class="danger-bg">
<fieldset> <fieldset>
<label for="name"> <label for="name">
To delete this model please type "{model.name}": To delete this model please type "{model.name}":
</label> </label>
<input name="name" id="name" required bind:value={name} /> <input name="name" id="name" required bind:value={name} />
</fieldset> </fieldset>
<MessageSimple bind:this={messageSimple} />
<button class="danger"> Delete </button> <button class="danger"> Delete </button>
</form> </form>

View File

@ -1,30 +1,32 @@
<script lang="ts"> <script lang="ts">
import { rdelete, showMessage } from 'src/lib/requests.svelte'; import { rdelete } from 'src/lib/requests.svelte';
import type { Model } from './+page.svelte'; import type { Model } from './+page.svelte';
import { notificationStore } from 'src/lib/NotificationsStore.svelte'; import MessageSimple from 'src/lib/MessageSimple.svelte';
import { preventDefault } from 'src/lib/utils'; import { createEventDispatcher } from 'svelte';
let { let message: MessageSimple;
model,
expand, let { model, expand } = $props<{ model: Model; expand?: boolean }>();
onreload = () => {}
}: { const dispatch = createEventDispatcher<{ reload: void }>();
model: Model;
expand?: boolean;
onreload?: () => void;
} = $props();
async function deleteZip() { async function deleteZip() {
message.clear();
try { try {
await rdelete('models/data/delete-zip-file', { id: model.id }); await rdelete('models/data/delete-zip-file', { id: model.id });
onreload(); dispatch('reload');
} catch (e) { } catch (e) {
showMessage(e, notificationStore, 'Could not delete the zip file'); if (e instanceof Response) {
message.display(await e.json());
} else {
message.display('Could not delete the zip file');
}
} }
} }
</script> </script>
<form onsubmit={preventDefault(deleteZip)}> <form on:submit|preventDefault={deleteZip}>
{#if expand} {#if expand}
Failed to proccess the zip file.<br /> Failed to proccess the zip file.<br />
Delete file and upload a correct version do add more classes.<br /> Delete file and upload a correct version do add more classes.<br />
@ -35,5 +37,6 @@
<br /> <br />
{/if} {/if}
<div class="spacer"></div> <div class="spacer"></div>
<MessageSimple bind:this={message} />
<button class="danger"> Delete Zip File </button> <button class="danger"> Delete Zip File </button>
</form> </form>

View File

@ -1,29 +1,38 @@
<script lang="ts" context="module">
export type Class = {
name: string;
id: string;
status: number;
};
</script>
<script lang="ts"> <script lang="ts">
import FileUpload from 'src/lib/FileUpload.svelte'; import FileUpload from 'src/lib/FileUpload.svelte';
import Tabs from 'src/lib/Tabs.svelte'; import Tabs from 'src/lib/Tabs.svelte';
import type { Model } from './+page.svelte'; import type { Model } from './+page.svelte';
import type { Class } from './types'; import { postFormData, get } from 'src/lib/requests.svelte';
import { postFormData, get, showMessage } from 'src/lib/requests.svelte'; import MessageSimple from 'src/lib/MessageSimple.svelte';
import { createEventDispatcher } from 'svelte';
import ModelTable from './ModelTable.svelte'; import ModelTable from './ModelTable.svelte';
import TrainModel from './TrainModel.svelte'; import TrainModel from './TrainModel.svelte';
import ZipStructure from './ZipStructure.svelte'; import ZipStructure from './ZipStructure.svelte';
import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import { preventDefault } from 'src/lib/utils';
let { let { model, simple } = $props<{ model: Model; simple?: boolean }>();
model,
simple,
onreload = () => {}
}: { model: Model; simple?: boolean; onreload?: () => void } = $props();
let classes: Class[] = $state([]); let classes: Class[] = $state([]);
let has_data: boolean = $state(false); let has_data: boolean = $state(false);
let file: File | undefined = $state(); let file: File | undefined = $state();
const dispatch = createEventDispatcher<{
reload: void;
}>();
let uploading: Promise<void> = $state(Promise.resolve()); let uploading: Promise<void> = $state(Promise.resolve());
let numberOfInvalidImages = $state(0); let numberOfInvalidImages = $state(0);
let uploadImage: MessageSimple;
async function uploadZip() { async function uploadZip() {
if (!file) return; if (!file) return;
@ -35,9 +44,13 @@
try { try {
await postFormData('models/data/upload', form); await postFormData('models/data/upload', form);
onreload(); dispatch('reload');
} catch (e) { } catch (e) {
showMessage(e, notificationStore, 'Could not upload data'); if (e instanceof Response) {
uploadImage.display(await e.json());
} else {
uploadImage.display('');
}
} }
uploading = Promise.resolve(); uploading = Promise.resolve();
@ -54,8 +67,8 @@
classes = data.classes; classes = data.classes;
numberOfInvalidImages = data.number_of_invalid_images; numberOfInvalidImages = data.number_of_invalid_images;
has_data = data.has_data; has_data = data.has_data;
} catch (e) { } catch {
showMessage(e, notificationStore, 'Could not get information on classes'); return;
} }
} }
</script> </script>
@ -67,22 +80,22 @@
<p>You need to upload data so the model can train.</p> <p>You need to upload data so the model can train.</p>
<Tabs active="upload" let:isActive> <Tabs active="upload" let:isActive>
<div slot="buttons" let:setActive let:isActive> <div slot="buttons" let:setActive let:isActive>
<button class="tab" class:selected={isActive('upload')} onclick={setActive('upload')}> <button class="tab" class:selected={isActive('upload')} on:click={setActive('upload')}>
Upload Upload
</button> </button>
<!--button <button
class="tab" class="tab"
class:selected={isActive('create-class')} class:selected={isActive('create-class')}
onclick={setActive('create-class')} on:click={setActive('create-class')}
> >
Create Class Create Class
</button--> </button>
<!--button class="tab" class:selected={isActive('api')} onclick={setActive('api')}> <button class="tab" class:selected={isActive('api')} on:click={setActive('api')}>
Api Api
</button--> </button>
</div> </div>
<div class="content" class:selected={isActive('upload')}> <div class="content" class:selected={isActive('upload')}>
<form onsubmit={preventDefault(uploadZip)}> <form on:submit|preventDefault={uploadZip}>
<fieldset class="file-upload"> <fieldset class="file-upload">
<label for="file">Data file</label> <label for="file">Data file</label>
<div class="form-msg"> <div class="form-msg">
@ -102,6 +115,7 @@
</div> </div>
</FileUpload> </FileUpload>
</fieldset> </fieldset>
<MessageSimple bind:this={uploadImage} />
{#if file} {#if file}
{#await uploading} {#await uploading}
<button disabled> Uploading </button> <button disabled> Uploading </button>
@ -111,10 +125,10 @@
{/if} {/if}
</form> </form>
</div> </div>
<!--div class="content" class:selected={isActive('create-class')}> <div class="content" class:selected={isActive('create-class')}>
<ModelTable {classes} {model} {onreload} /> <ModelTable {classes} {model} on:reload={() => dispatch('reload')} />
</div--> </div>
<!--div class="content" class:selected={isActive('api')}>TODO</div--> <div class="content" class:selected={isActive('api')}>TODO</div>
</Tabs> </Tabs>
<div class="tabs"></div> <div class="tabs"></div>
{:else} {:else}
@ -122,14 +136,36 @@
{#if numberOfInvalidImages > 0} {#if numberOfInvalidImages > 0}
<p class="danger"> <p class="danger">
There are images {numberOfInvalidImages} that were loaded that do not have the correct format. There are images {numberOfInvalidImages} that were loaded that do not have the correct format.
These images will be deleted when the model trains. These images will be delete when the model trains.
</p> </p>
{/if} {/if}
<ModelTable {classes} {model} {onreload} /> <Tabs active="create-class" let:isActive>
<div slot="buttons" let:setActive let:isActive>
<button
class="tab"
class:selected={isActive('create-class')}
on:click={setActive('create-class')}
>
Create Class
</button>
<button class="tab" class:selected={isActive('api')} on:click={setActive('api')}>
Api
</button>
</div>
<div class="content" class:selected={isActive('create-class')}>
<ModelTable {classes} {model} on:reload={() => dispatch('reload')} />
</div>
<div class="content" class:selected={isActive('api')}>TODO</div>
</Tabs>
{/if} {/if}
</div> </div>
{/if} {/if}
{#if classes.some((item) => item.status == 1) && ![-6, 6].includes(model.status)} {#if classes.some((item) => item.status == 1) && ![-6, 6].includes(model.status)}
<TrainModel number_of_invalid_images={numberOfInvalidImages} {model} {has_data} {onreload} /> <TrainModel
number_of_invalid_images={numberOfInvalidImages}
{model}
{has_data}
on:reload={() => dispatch('reload')}
/>
{/if} {/if}

View File

@ -1,4 +1,5 @@
<script lang="ts"> <script lang="ts">
import { createEventDispatcher } from 'svelte';
import type { Model } from './+page.svelte'; import type { Model } from './+page.svelte';
import ModelData from './ModelData.svelte'; import ModelData from './ModelData.svelte';
import { post, showMessage } from 'src/lib/requests.svelte'; import { post, showMessage } from 'src/lib/requests.svelte';
@ -6,11 +7,9 @@
import type { ModelStats } from './types'; import type { ModelStats } from './types';
import DeleteZip from './DeleteZip.svelte'; import DeleteZip from './DeleteZip.svelte';
let { let { model, active }: { model: Model; active?: boolean } = $props();
model,
active, const dispatch = createEventDispatcher<{ reload: void }>();
onreload = () => {}
}: { model: Model; active?: boolean; onreload?: () => void } = $props();
$effect(() => { $effect(() => {
if (active) getData(); if (active) getData();
@ -35,14 +34,14 @@
{/if} {/if}
{#if [-6, -2].includes(model.status)} {#if [-6, -2].includes(model.status)}
<DeleteZip {model} {onreload} expand /> <DeleteZip {model} on:reload={() => dispatch('reload')} expand />
{/if} {/if}
<ModelData <ModelData
{model} {model}
onreload={() => { on:reload={() => {
getData(); getData();
onreload(); dispatch('reload');
}} }}
/> />
</div> </div>

View File

@ -97,7 +97,7 @@
}); });
</script> </script>
<div><canvas bind:this={ctx}></canvas></div> <div><canvas bind:this={ctx} /></div>
<style lang="scss"> <style lang="scss">
canvas { canvas {

View File

@ -1,24 +1,33 @@
<script lang="ts" context="module">
export type Image = {
file_path: string;
mode: number;
status: number;
id: string;
};
</script>
<script lang="ts"> <script lang="ts">
import Tabs from 'src/lib/Tabs.svelte'; import Tabs from 'src/lib/Tabs.svelte';
import type { Class, Image } from './types'; import type { Class } from './ModelData.svelte';
import { post, postFormData, rdelete, showMessage } from 'src/lib/requests.svelte'; import { post, postFormData, rdelete, showMessage } from 'src/lib/requests.svelte';
import type { Model } from './+page.svelte'; import type { Model } from './+page.svelte';
import FileUpload from 'src/lib/FileUpload.svelte'; import FileUpload from 'src/lib/FileUpload.svelte';
import MessageSimple from 'src/lib/MessageSimple.svelte';
import { createEventDispatcher } from 'svelte';
import ZipStructure from './ZipStructure.svelte'; import ZipStructure from './ZipStructure.svelte';
import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import { preventDefault } from 'src/lib/utils.js'; const dispatch = createEventDispatcher<{ reload: void }>();
import CreateNewClass from './api/CreateNewClass.svelte';
let selected_class: Class | undefined = $state(); let selected_class: Class | undefined = $state();
let { classes, model, onreload }: { classes: Class[]; model: Model; onreload?: () => void } = let { classes, model }: { classes: Class[]; model: Model } = $props();
$props();
let createClass: { className: string } = $state({ let createClass: { className: string } = $state({
className: '' className: ''
}); });
let page = $state(-1); let page = $state(0);
let showNext = $state(false); let showNext = $state(false);
let image_list = $state<Image[]>([]); let image_list = $state<Image[]>([]);
@ -32,10 +41,9 @@
}); });
async function getList() { async function getList() {
if (!selected_class) return;
try { try {
let res = await post('models/data/list', { let res = await post('models/data/list', {
id: selected_class.id, id: selected_class?.id ?? '',
page: page page: page
}); });
showNext = res.showNext; showNext = res.showNext;
@ -45,20 +53,19 @@
} }
} }
$effect(() => {
getList();
});
$effect(() => { $effect(() => {
if (selected_class) { if (selected_class) {
page = 0; page = 0;
getList();
} }
}); });
let file: File | undefined = $state(); let file: File | undefined = $state();
let uploadImage: MessageSimple;
let uploading = $state(Promise.resolve()); let uploading = $state(Promise.resolve());
async function uploadZip() { async function uploadZip() {
uploadImage.clear();
if (!file) return; if (!file) return;
uploading = new Promise(() => {}); uploading = new Promise(() => {});
@ -69,14 +76,19 @@
try { try {
await postFormData('models/data/class/upload', form); await postFormData('models/data/class/upload', form);
if (onreload) onreload(); dispatch('reload');
} catch (e) { } catch (e) {
showMessage(e, notificationStore, 'Failed to upload'); if (e instanceof Response) {
uploadImage.display(await e.json());
} else {
uploadImage.display('');
}
} }
uploading = Promise.resolve(); uploading = Promise.resolve();
} }
let createNewClassMessages: MessageSimple;
async function createNewClass() { async function createNewClass() {
try { try {
const r = await post('models/data/class/new', { const r = await post('models/data/class/new', {
@ -88,7 +100,7 @@
classes = classes; classes = classes;
getList(); getList();
} catch (e) { } catch (e) {
showMessage(e, notificationStore); showMessage(e, createNewClassMessages);
} }
} }
@ -97,11 +109,12 @@
rdelete('models/data/point', { id }); rdelete('models/data/point', { id });
getList(); getList();
} catch (e) { } catch (e) {
showMessage(e, notificationStore); console.error('TODO notify user', e);
} }
} }
let addFile: File | undefined = $state(); let addFile: File | undefined = $state();
let addImageMessages: MessageSimple;
let adding = $state(Promise.resolve()); let adding = $state(Promise.resolve());
let uploadImageDialog: HTMLDialogElement; let uploadImageDialog: HTMLDialogElement;
async function addImage() { async function addImage() {
@ -123,7 +136,7 @@
addFile = undefined; addFile = undefined;
getList(); getList();
} catch (e) { } catch (e) {
showMessage(e, notificationStore); showMessage(e, addImageMessages);
} }
} }
</script> </script>
@ -138,30 +151,30 @@
{#each classes as item} {#each classes as item}
<button <button
style="width: auto; white-space: nowrap;" style="width: auto; white-space: nowrap;"
onclick={() => setActiveClass(item, setActive)} on:click={() => setActiveClass(item, setActive)}
class="tab" class="tab"
class:selected={isActive(item.name)} class:selected={isActive(item.name)}
> >
{item.name} {item.name}
{#if model.model_type == 2} {#if model.model_type == 2}
{#if item.status == 1} {#if item.status == 1}
<span class="bi bi-book" style="color: orange;"></span> <span class="bi bi-book" style="color: orange;" />
{:else if item.status == 2} {:else if item.status == 2}
<span class="bi bi-book" style="color: green;"></span> <span class="bi bi-book" style="color: green;" />
{:else if item.status == 3} {:else if item.status == 3}
<span class="bi bi-check" style="color: green;"></span> <span class="bi bi-check" style="color: green;" />
{/if} {/if}
{/if} {/if}
</button> </button>
{/each} {/each}
</div> </div>
<button <button
onclick={() => { on:click={() => {
setActive('-----New Class-----')(); setActive('-----New Class-----')();
selected_class = undefined; selected_class = undefined;
}} }}
> >
<span class="bi bi-plus"></span> <span class="bi bi-plus" />
</button> </button>
</div> </div>
{#if selected_class == undefined && isActive('-----New Class-----')} {#if selected_class == undefined && isActive('-----New Class-----')}
@ -171,31 +184,21 @@
<div slot="buttons" let:setActive let:isActive> <div slot="buttons" let:setActive let:isActive>
<button <button
class="tab" class="tab"
type="button" on:click|preventDefault={setActive('zip')}
onclick={setActive('zip')}
class:selected={isActive('zip')} class:selected={isActive('zip')}
> >
Zip Zip
</button> </button>
<button <button
class="tab" class="tab"
type="button" on:click|preventDefault={setActive('empty')}
onclick={setActive('empty')}
class:selected={isActive('empty')} class:selected={isActive('empty')}
> >
Empty Class Empty Class
</button> </button>
<button
class="tab"
type="button"
onclick={setActive('api')}
class:selected={isActive('api')}
>
API
</button>
</div> </div>
<div class="content" class:selected={isActive('zip')}> <div class="content" class:selected={isActive('zip')}>
<form onsubmit={preventDefault(uploadZip)}> <form on:submit|preventDefault={uploadZip}>
<fieldset class="file-upload"> <fieldset class="file-upload">
<label for="file">Data file</label> <label for="file">Data file</label>
<div class="form-msg"> <div class="form-msg">
@ -215,6 +218,7 @@
</div> </div>
</FileUpload> </FileUpload>
</fieldset> </fieldset>
<MessageSimple bind:this={uploadImage} />
{#if file} {#if file}
{#await uploading} {#await uploading}
<button disabled> Uploading </button> <button disabled> Uploading </button>
@ -225,7 +229,7 @@
</form> </form>
</div> </div>
<div class="content" class:selected={isActive('empty')}> <div class="content" class:selected={isActive('empty')}>
<form onsubmit={preventDefault(createNewClass)}> <form on:submit|preventDefault={createNewClass}>
<div class="form-msg"> <div class="form-msg">
This Creates an empty class that allows images to be added after This Creates an empty class that allows images to be added after
</div> </div>
@ -233,12 +237,10 @@
<label for="className">Class Name</label> <label for="className">Class Name</label>
<input required name="className" bind:value={createClass.className} /> <input required name="className" bind:value={createClass.className} />
</fieldset> </fieldset>
<MessageSimple bind:this={createNewClassMessages} />
<button> Create New Class </button> <button> Create New Class </button>
</form> </form>
</div> </div>
<div class="content" class:selected={isActive('api')}>
<CreateNewClass {model} />
</div>
</Tabs> </Tabs>
</div> </div>
{/if} {/if}
@ -256,7 +258,7 @@
{:else} {:else}
Class to train Class to train
{/if} {/if}
<button onclick={() => uploadImageDialog.showModal()}> Upload Image </button> <button on:click={() => uploadImageDialog.showModal()}> Upload Image </button>
</h2> </h2>
<table> <table>
<thead> <thead>
@ -312,7 +314,7 @@
{/if} {/if}
</td> </td>
<td style="width: 3ch"> <td style="width: 3ch">
<button class="danger" onclick={() => deleteDataPoint(image.id)}> <button class="danger" on:click={() => deleteDataPoint(image.id)}>
<span class="bi bi-trash"></span> <span class="bi bi-trash"></span>
</button> </button>
</td> </td>
@ -323,7 +325,7 @@
<div class="flex justify-center align-center"> <div class="flex justify-center align-center">
<div class="grow-1 flex justify-end align-center"> <div class="grow-1 flex justify-end align-center">
{#if page > 0} {#if page > 0}
<button onclick={() => (page -= 1)}> Prev </button> <button on:click={() => (page -= 1)}> Prev </button>
{/if} {/if}
</div> </div>
@ -333,7 +335,7 @@
<div class="grow-1 flex justify-start align-center"> <div class="grow-1 flex justify-start align-center">
{#if showNext} {#if showNext}
<button onclick={() => (page += 1)}> Next </button> <button on:click={() => (page += 1)}> Next </button>
{/if} {/if}
</div> </div>
</div> </div>
@ -343,7 +345,7 @@
{/if} {/if}
<dialog class="newImageDialog" bind:this={uploadImageDialog}> <dialog class="newImageDialog" bind:this={uploadImageDialog}>
<form onsubmit={preventDefault(addImage)}> <form on:submit|preventDefault={addImage}>
<fieldset class="file-upload"> <fieldset class="file-upload">
<label for="file">Data file</label> <label for="file">Data file</label>
<div class="form-msg"> <div class="form-msg">
@ -363,6 +365,7 @@
</div> </div>
</FileUpload> </FileUpload>
</fieldset> </fieldset>
<MessageSimple bind:this={addImageMessages} />
{#if addFile} {#if addFile}
{#await adding} {#await adding}
<button disabled> Uploading </button> <button disabled> Uploading </button>
@ -412,4 +415,10 @@
table tr th:first-child { table tr th:first-child {
border-left: none; border-left: none;
} }
table tr td button,
table tr td .button {
padding: 5px 10px;
box-shadow: 0 2px 5px 1px #66666655;
}
</style> </style>

View File

@ -1,29 +1,26 @@
<script lang="ts"> <script lang="ts">
import { post, postFormData, showMessage } from 'src/lib/requests.svelte'; import { post, postFormData } from 'src/lib/requests.svelte';
import type { Model } from './+page.svelte'; import type { Model } from './+page.svelte';
import FileUpload from 'src/lib/FileUpload.svelte'; import FileUpload from 'src/lib/FileUpload.svelte';
import { onDestroy } from 'svelte'; import MessageSimple from 'src/lib/MessageSimple.svelte';
import { createEventDispatcher, onDestroy } from 'svelte';
import Spinner from 'src/lib/Spinner.svelte'; import Spinner from 'src/lib/Spinner.svelte';
import type { Task } from './tasks/types'; import type { Task } from './TasksTable.svelte';
import Tabs from 'src/lib/Tabs.svelte';
import hljs from 'highlight.js';
import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import { preventDefault } from 'src/lib/utils';
let { let { model } = $props<{ model: Model }>();
model,
onupload = () => {},
ontaskReload = () => {}
}: { model: Model; onupload?: () => void; ontaskReload?: () => void } = $props();
let file: File | undefined = $state(); let file: File | undefined = $state();
const dispatch = createEventDispatcher<{ upload: void; taskReload: void }>();
let _result: Promise<Task> = $state(new Promise(() => {})); let _result: Promise<Task> = $state(new Promise(() => {}));
let run = $state(false); let run = $state(false);
let last_task: string | undefined = $state(); let last_task: string | undefined = $state();
let last_task_timeout: number | null = null; let last_task_timeout: number | null = null;
let messages: MessageSimple;
async function reloadLastTimeout() { async function reloadLastTimeout() {
if (!last_task) { if (!last_task) {
return; return;
@ -34,7 +31,7 @@
const r = await post('tasks/task', { id: last_task }); const r = await post('tasks/task', { id: last_task });
if ([0, 1, 2, 3].includes(r.status)) { if ([0, 1, 2, 3].includes(r.status)) {
setTimeout(reloadLastTimeout, 500); setTimeout(reloadLastTimeout, 500);
setTimeout(ontaskReload, 500); setTimeout(() => dispatch('taskReload'), 500);
} else { } else {
_result = Promise.resolve(r); _result = Promise.resolve(r);
} }
@ -45,6 +42,7 @@
async function submit() { async function submit() {
if (!file) return; if (!file) return;
messages.clear();
let form = new FormData(); let form = new FormData();
form.append('json_data', JSON.stringify({ id: model.id })); form.append('json_data', JSON.stringify({ id: model.id }));
@ -58,10 +56,14 @@
file = undefined; file = undefined;
last_task_timeout = setTimeout(() => reloadLastTimeout()); last_task_timeout = setTimeout(() => reloadLastTimeout());
} catch (e) { } catch (e) {
showMessage(e, notificationStore, 'Could not run the model'); if (e instanceof Response) {
messages.display(await e.json());
} else {
messages.display('Could not run the model');
}
} }
onupload(); dispatch('upload');
} }
onDestroy(() => { onDestroy(() => {
@ -71,72 +73,7 @@
}); });
</script> </script>
<Tabs active="upload" let:isActive> <form on:submit|preventDefault={submit}>
<div class="buttons" slot="buttons" let:setActive let:isActive>
<button class="tab" class:selected={isActive('upload')} onclick={setActive('upload')}>
Upload
</button>
<button class="tab" class:selected={isActive('api')} onclick={setActive('api')}> Api </button>
</div>
<div class="content" class:selected={isActive('api')}>
<div class="codeinfo">
To perform an image classfication please follow the example bellow:
<pre style="font-family: Fira Code;">{@html hljs.highlight(
`let form = new FormData();
form.append('json_data', JSON.stringify({ id: '${model.id}' }));
form.append('file', file, 'file');
const headers = new Headers();
headers.append('response-type', 'application/json');
headers.append('token', token);
const r = await fetch('${window.location.protocol}//${window.location.hostname}/api/tasks/start/image', {
method: 'POST',
headers: headers,
body: form
});`,
{ language: 'javascript' }
).value}</pre>
On Success the request will return a json with this format:
<pre style="font-family: Fira Code;">{@html hljs.highlight(
`{ id "00000000-0000-0000-0000-000000000000" }`,
{ language: 'json' }
).value}</pre>
This id can be used to query the API for the result of the task:
<pre style="font-family: Fira Code;">{@html hljs.highlight(
`const headers = new Headers();
headers.append('content-type', 'application/json');
headers.append('token', token);
const r = await fetch('${window.location.protocol}//${window.location.hostname}/api/tasks/task', {
method: 'POST',
headers: headers,
body: JSON.stringify({ id: '00000000-0000-0000-0000-000000000000' })
});`,
{ language: 'javascript' }
).value}</pre>
Once the task shows the status as 4 then the data can be obatined in the result field: The successful
return value has this type:
<pre style="font-family: Fira Code;">{@html hljs.highlight(
`{
"id": string,
"user_id": string,
"model_id": string,
"status": number,
"status_message": string,
"user_confirmed": number,
"compacted": number,
"type": number,
"extra_task_info": string,
"result": string,
"created": string
}`,
{ language: 'javascript' }
).value}</pre>
</div>
</div>
<div class="content" class:selected={isActive('upload')}>
<form onsubmit={preventDefault(submit)} style="box-shadow: none;">
<fieldset class="file-upload"> <fieldset class="file-upload">
<label for="file">Image</label> <label for="file">Image</label>
<div class="form-msg">Run image through them model and get the result</div> <div class="form-msg">Run image through them model and get the result</div>
@ -149,6 +86,7 @@ const r = await fetch('${window.location.protocol}//${window.location.hostname}/
</div> </div>
</FileUpload> </FileUpload>
</fieldset> </fieldset>
<MessageSimple bind:this={messages} />
<button> Run </button> <button> Run </button>
{#if run} {#if run}
{#await _result} {#await _result}
@ -171,11 +109,3 @@ const r = await fetch('${window.location.protocol}//${window.location.hostname}/
{/await} {/await}
{/if} {/if}
</form> </form>
</div>
</Tabs>
<style lang="scss">
.codeinfo {
padding: 20px;
}
</style>

View File

@ -1,4 +1,5 @@
<script lang="ts"> <script lang="ts">
import { post } from 'src/lib/requests.svelte';
import type { Model } from 'src/routes/models/edit/+page.svelte'; import type { Model } from 'src/routes/models/edit/+page.svelte';
import RunModel from './RunModel.svelte'; import RunModel from './RunModel.svelte';
import TasksTable from './tasks/TasksTable.svelte'; import TasksTable from './tasks/TasksTable.svelte';
@ -11,7 +12,7 @@
{#if active} {#if active}
<div class="content selected"> <div class="content selected">
<RunModel {model} onupload={() => table.getList()} ontaskReload={() => table.getList()} /> <RunModel {model} on:upload={() => table.getList()} on:taskReload={() => table.getList()} />
<TasksTable {model} bind:this={table} /> <TasksTable {model} bind:this={table} />
<Stats {model} /> <Stats {model} />
</div> </div>

View File

@ -1,20 +1,14 @@
<script lang="ts"> <script lang="ts">
import { notificationStore } from 'src/lib/NotificationsStore.svelte'; import MessageSimple from 'src/lib/MessageSimple.svelte';
import type { Model } from './+page.svelte'; import type { Model } from './+page.svelte';
import { post, showMessage } from 'src/lib/requests.svelte'; import { post } from 'src/lib/requests.svelte';
import { preventDefault } from 'src/lib/utils'; import { createEventDispatcher } from 'svelte';
let { let { number_of_invalid_images, has_data, model } = $props<{
number_of_invalid_images,
has_data,
model,
onreload = () => {}
}: {
number_of_invalid_images: number; number_of_invalid_images: number;
has_data: boolean; has_data: boolean;
model: Model; model: Model;
onreload?: () => void; }>();
} = $props();
let data = $state({ let data = $state({
model_type: 'simple', model_type: 'simple',
@ -24,39 +18,54 @@
let submitted = $state(false); let submitted = $state(false);
let dispatch = createEventDispatcher<{ reload: void }>();
let messages: MessageSimple;
async function submit() { async function submit() {
messages.clear();
submitted = true; submitted = true;
try { try {
await post('models/train', { await post('models/train', {
id: model.id, id: model.id,
...data ...data
}); });
onreload(); dispatch('reload');
} catch (e) { } catch (e) {
showMessage(e, notificationStore, 'Could not start the training of the model'); if (e instanceof Response) {
messages.display(await e.json());
} else {
messages.display('Could not start the training of the model');
}
} }
} }
async function submitRetrain() { async function submitRetrain() {
messages.clear();
submitted = true; submitted = true;
try { try {
await post('model/train/retrain', { id: model.id }); await post('model/train/retrain', { id: model.id });
onreload(); dispatch('reload');
} catch (e) { } catch (e) {
showMessage(e, notificationStore, 'Could not start the training of the model'); if (e instanceof Response) {
messages.display(await e.json());
} else {
messages.display('Could not start the training of the model');
}
} }
} }
</script> </script>
{#if model.status == 2} {#if model.status == 2}
<form class:submitted onsubmit={preventDefault(submit)}> <form class:submitted on:submit|preventDefault={submit}>
{#if has_data} {#if has_data}
{#if number_of_invalid_images > 0} {#if number_of_invalid_images > 0}
<p class="danger"> <p class="danger">
There are images {number_of_invalid_images} that were loaded that do not have the correct format.DeleteZip There are images {number_of_invalid_images} that were loaded that do not have the correct format.DeleteZip
These images will be deleted when the model trains. These images will be delete when the model trains.
</p> </p>
{/if} {/if}
<MessageSimple bind:this={messages} />
<!-- TODO expading mode --> <!-- TODO expading mode -->
<fieldset> <fieldset>
<legend> Model Type </legend> <legend> Model Type </legend>
@ -101,16 +110,17 @@
<h2>To train the model please provide data to the model first</h2> <h2>To train the model please provide data to the model first</h2>
{/if} {/if}
</form> </form>
{:else if ![4, 6, 7].includes(model.status)} {:else}
<form class:submitted onsubmit={submitRetrain}> <form class:submitted on:submit|preventDefault={submitRetrain}>
{#if has_data} {#if has_data}
<h2>This model has new classes and can be expanded</h2> <h2>This model has new classes and can be expanded</h2>
{#if number_of_invalid_images > 0} {#if number_of_invalid_images > 0}
<p class="danger"> <p class="danger">
There are images {number_of_invalid_images} that were loaded that do not have the correct format.DeleteZip There are images {number_of_invalid_images} that were loaded that do not have the correct format.DeleteZip
These images will be deleted when the model trains. These images will be delete when the model trains.
</p> </p>
{/if} {/if}
<MessageSimple bind:this={messages} />
<button> Retrain </button> <button> Retrain </button>
{:else} {:else}
<h2>To train the model please provide data to the model first</h2> <h2>To train the model please provide data to the model first</h2>

View File

@ -1,27 +0,0 @@
<script lang="ts">
import hljs from 'highlight.js';
import type { Model } from '../+page.svelte';
let { model }: { model: Model } = $props();
</script>
To create a new class via the API you can:
<pre style="font-family: Fira Code;">{@html hljs.highlight(
`let form = new FormData();
form.append('json_data', JSON.stringify({
id: '${model.id}',
name: 'New class name'
}));
form.append('file', file, 'file');
const headers = new Headers();
headers.append('response-type', 'application/json');
headers.append('token', token);
const r = await fetch('${window.location.protocol}//${window.location.hostname}/models/data/class/new', {
method: 'POST',
headers: headers,
body: form
});`,
{ language: 'javascript' }
).value}</pre>

View File

@ -1,5 +1,5 @@
<script lang="ts"> <script lang="ts">
import { onDestroy } from 'svelte'; import { onDestroy, onMount } from 'svelte';
import type { Model } from '../+page.svelte'; import type { Model } from '../+page.svelte';
import { post, showMessage } from 'src/lib/requests.svelte'; import { post, showMessage } from 'src/lib/requests.svelte';
import type { DataPoint, TasksStatsDay } from 'src/types/stats/task'; import type { DataPoint, TasksStatsDay } from 'src/types/stats/task';
@ -18,6 +18,7 @@
PointElement, PointElement,
LineElement LineElement
} from 'chart.js'; } from 'chart.js';
import ModelData from '../ModelData.svelte';
Chart.register( Chart.register(
Title, Title,
@ -56,9 +57,7 @@
} }
let pie: HTMLCanvasElement; let pie: HTMLCanvasElement;
let pie2: HTMLCanvasElement;
let pieChart: Chart<'pie'> | undefined; let pieChart: Chart<'pie'> | undefined;
let pie2Chart: Chart<'pie'> | undefined;
function createPie(s: TasksStatsDay) { function createPie(s: TasksStatsDay) {
if (pieChart) { if (pieChart) {
pieChart.destroy(); pieChart.destroy();
@ -73,31 +72,23 @@
'Classfication Failure', 'Classfication Failure',
'Classfication Preparing', 'Classfication Preparing',
'Classfication Running', 'Classfication Running',
'Classfication Unknown' 'Classfication Unknown',
'Non Classfication Error',
'Non Classfication Success'
], ],
datasets: [ datasets: [
{ {
label: 'Total', label: 'Total',
data: [t.c_error, t.c_success, t.c_failure, t.c_pre_running, t.c_running, t.c_unknown] data: [
} t.c_error,
t.c_success,
t.c_failure,
t.c_pre_running,
t.c_running,
t.c_unknown,
t.nc_error,
t.nc_success
] ]
},
options: {
animation: false
}
});
if (pie2Chart) {
pieChart.destroy();
}
pie2Chart = new Chart(pie2, {
type: 'pie',
data: {
labels: ['Non Classfication Error', 'Non Classfication Success'],
datasets: [
{
label: 'Total',
data: [t.nc_error, t.nc_success]
} }
] ]
}, },
@ -124,6 +115,7 @@
nc_error: 'Non Classfication Error', nc_error: 'Non Classfication Error',
nc_success: 'Non Classfication Success' nc_success: 'Non Classfication Success'
}; };
let t = s.total;
let labels = new Array(24).fill(0).map((_, i) => i); let labels = new Array(24).fill(0).map((_, i) => i);
lineChart = new Chart(line, { lineChart = new Chart(line, {
type: 'line', type: 'line',
@ -155,14 +147,9 @@
<h1>Statistics (Day)</h1> <h1>Statistics (Day)</h1>
<h2>Total</h2> <h2>Total</h2>
<div class="pies">
<div> <div>
<canvas bind:this={pie}></canvas> <canvas bind:this={pie}></canvas>
</div> </div>
<div>
<canvas bind:this={pie2}></canvas>
</div>
</div>
<h2>Hourly</h2> <h2>Hourly</h2>
<div> <div>
@ -173,12 +160,4 @@
canvas { canvas {
width: 100%; width: 100%;
} }
.pies {
display: flex;
align-content: stretch;
div {
width: 50%;
}
}
</style> </style>

View File

@ -1,4 +1,16 @@
<script lang="ts" context="module"> <script lang="ts" context="module">
export type Task = {
id: string;
user_id: string;
model_id: string;
status: number;
status_message: string;
user_confirmed: number;
compacted: number;
type: number;
created: string;
result: string;
};
export const TaskType = { export const TaskType = {
TASK_FAILED_RUNNING: -2, TASK_FAILED_RUNNING: -2,
TASK_FAILED_CREATION: -1, TASK_FAILED_CREATION: -1,
@ -40,10 +52,9 @@
<script lang="ts"> <script lang="ts">
import { post, showMessage } from 'src/lib/requests.svelte'; import { post, showMessage } from 'src/lib/requests.svelte';
import type { Model } from '../+page.svelte'; import type { Model } from '../+page.svelte';
import MessageSimple from 'src/lib/MessageSimple.svelte';
import Tooltip from 'src/lib/Tooltip.svelte'; import Tooltip from 'src/lib/Tooltip.svelte';
import type { Task } from './types';
let { model }: { model: Model } = $props(); let { model }: { model: Model } = $props();
let page = $state(0); let page = $state(0);
@ -69,6 +80,7 @@
} }
}); });
let userPreceptionMessages: MessageSimple;
// This returns a function that performs the call and does not do the call it self // This returns a function that performs the call and does not do the call it self
function userPreception(task: string, agree: number) { function userPreception(task: string, agree: number) {
return async function () { return async function () {
@ -87,6 +99,7 @@
<div> <div>
<h2>Tasks</h2> <h2>Tasks</h2>
<MessageSimple bind:this={userPreceptionMessages} />
<table> <table>
<thead> <thead>
<tr> <tr>
@ -143,14 +156,14 @@
<div> <div>
{#if task.user_confirmed != 1} {#if task.user_confirmed != 1}
<Tooltip title="Agree with the result of the task"> <Tooltip title="Agree with the result of the task">
<button type="button" onclick={userPreception(task.id, 1)}> <button type="button" on:click={userPreception(task.id, 1)}>
<span class="bi bi-check"></span> <span class="bi bi-check"></span>
</button> </button>
</Tooltip> </Tooltip>
{/if} {/if}
{#if task.user_confirmed != -1} {#if task.user_confirmed != -1}
<Tooltip title="Disagree with the result"> <Tooltip title="Disagree with the result">
<button class="danger" type="button" onclick={userPreception(task.id, -1)}> <button class="danger" type="button" on:click={userPreception(task.id, -1)}>
<span class="bi bi-x-lg"></span> <span class="bi bi-x-lg"></span>
</button> </button>
</Tooltip> </Tooltip>
@ -195,7 +208,7 @@
<div class="flex justify-center align-center"> <div class="flex justify-center align-center">
<div class="grow-1 flex justify-end align-center"> <div class="grow-1 flex justify-end align-center">
{#if page > 0} {#if page > 0}
<button onclick={() => (page -= 1)}> Prev </button> <button on:click={() => (page -= 1)}> Prev </button>
{/if} {/if}
</div> </div>
@ -205,7 +218,7 @@
<div class="grow-1 flex justify-start align-center"> <div class="grow-1 flex justify-start align-center">
{#if showNext} {#if showNext}
<button onclick={() => (page += 1)}> Next </button> <button on:click={() => (page += 1)}> Next </button>
{/if} {/if}
</div> </div>
</div> </div>
@ -216,6 +229,10 @@
width: 100%; width: 100%;
display: flex; display: flex;
justify-content: space-between; justify-content: space-between;
& > button {
margin: 3px 5px;
}
} }
table { table {

View File

@ -1,12 +0,0 @@
export type Task = {
id: string;
user_id: string;
model_id: string;
status: number;
status_message: string;
user_confirmed: number;
compacted: number;
type: number;
created: string;
result: string;
};

View File

@ -3,16 +3,3 @@ export type ModelStats = Array<{
training: number; training: number;
testing: number; testing: number;
}>; }>;
export type Class = {
name: string;
id: string;
status: number;
};
export type Image = {
file_path: string;
mode: number;
status: number;
id: string;
};

View File

@ -3,7 +3,6 @@
import 'src/styles/forms.css'; import 'src/styles/forms.css';
import { userStore } from '../UserStore.svelte'; import { userStore } from '../UserStore.svelte';
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { preventDefault } from 'src/lib/utils';
let submitted = $state(false); let submitted = $state(false);
@ -40,7 +39,7 @@
<div class="login-page"> <div class="login-page">
<div> <div>
<h1>Register</h1> <h1>Register</h1>
<form onsubmit={preventDefault(onSubmit)} class:submitted> <form on:submit|preventDefault={onSubmit} class:submitted>
<fieldset> <fieldset>
<label for="username">Username</label> <label for="username">Username</label>
<input required name="username" bind:value={loginData.username} /> <input required name="username" bind:value={loginData.username} />

View File

@ -4,11 +4,10 @@
import { onMount } from 'svelte'; import { onMount } from 'svelte';
import 'src/styles/forms.css'; import 'src/styles/forms.css';
import { post, showMessage } from 'src/lib/requests.svelte'; import { post } from 'src/lib/requests.svelte';
import MessageSimple, { type DisplayFn } from 'src/lib/MessageSimple.svelte';
import TokenTable from './TokenTable.svelte'; import TokenTable from './TokenTable.svelte';
import DeleteUser from './DeleteUser.svelte'; import DeleteUser from './DeleteUser.svelte';
import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import { preventDefault } from 'src/lib/utils';
onMount(() => { onMount(() => {
if (!userStore.isLogin()) { if (!userStore.isLogin()) {
@ -27,8 +26,12 @@
let submiitedEmail = $state(false); let submiitedEmail = $state(false);
let submiitedPassword = $state(false); let submiitedPassword = $state(false);
let msgEmail: MessageSimple;
let msgPassword: MessageSimple;
async function onSubmitEmail() { async function onSubmitEmail() {
submiitedEmail = true; submiitedEmail = true;
msgEmail.display('');
if (!userStore.user) return; if (!userStore.user) return;
@ -41,31 +44,31 @@
...userStore.user, ...userStore.user,
...req ...req
}; };
notificationStore.add({ msgEmail.display('User updated successufly!', { type: 'success', timeToShow: 10000 });
message: 'User updated successufly!',
type: 'success',
timeToLive: 10000
});
} catch (e) { } catch (e) {
showMessage(e, notificationStore, 'Could not update email'); if (e instanceof Response) {
msgEmail.display(await e.json());
} else {
msgEmail.display('Could not update email');
}
} }
} }
async function onSubmitPassword() { async function onSubmitPassword() {
submiitedPassword = true; submiitedPassword = true;
msgPassword.display('');
if (!userStore.user) return; if (!userStore.user) return;
try { try {
await post('user/info/password', passwordData); await post('user/info/password', passwordData);
passwordData = { old_password: '', password: '', password2: '' }; passwordData = { old_password: '', password: '', password2: '' };
msgPassword.display('Password updated successufly!', { type: 'success', timeToShow: 10000 });
notificationStore.add({
message: 'Password updated successufly!',
type: 'success',
timeToLive: 10000
});
} catch (e) { } catch (e) {
showMessage(e, notificationStore, 'Could not update password'); if (e instanceof Response) {
msgPassword.display(await e.json());
} else {
msgPassword.display('Could not update password');
}
} }
} }
</script> </script>
@ -77,14 +80,15 @@
<div class="login-page"> <div class="login-page">
<div> <div>
<h1>User Infomation</h1> <h1>User Infomation</h1>
<form onsubmit={onSubmitEmail} class:submiitedEmail> <form on:submit|preventDefault={onSubmitEmail} class:submiitedEmail>
<fieldset> <fieldset>
<label for="email">Email</label> <label for="email">Email</label>
<input type="email" required name="email" bind:value={email} /> <input type="email" required name="email" bind:value={email} />
</fieldset> </fieldset>
<MessageSimple bind:this={msgEmail} />
<button> Update </button> <button> Update </button>
</form> </form>
<form onsubmit={preventDefault(onSubmitPassword)} class:submiitedPassword> <form on:submit|preventDefault={onSubmitPassword} class:submiitedPassword>
<fieldset> <fieldset>
<label for="old_password">Old Password</label> <label for="old_password">Old Password</label>
<input <input
@ -102,6 +106,7 @@
<label for="password2">Repeat New Password</label> <label for="password2">Repeat New Password</label>
<input required bind:value={passwordData.password2} name="password2" type="password" /> <input required bind:value={passwordData.password2} name="password2" type="password" />
</fieldset> </fieldset>
<MessageSimple bind:this={msgPassword} />
<div> <div>
<button> Update </button> <button> Update </button>
</div> </div>

View File

@ -1,21 +1,25 @@
<script lang="ts"> <script lang="ts">
import {createEventDispatcher} from 'svelte';
import MessageSimple from 'src/lib/MessageSimple.svelte';
import Tooltip from 'src/lib/Tooltip.svelte'; import Tooltip from 'src/lib/Tooltip.svelte';
import 'src/styles/forms.css'; import 'src/styles/forms.css';
import {post} from 'src/lib/requests.svelte'; import {post} from 'src/lib/requests.svelte';
import Spinner from 'src/lib/Spinner.svelte'; import Spinner from 'src/lib/Spinner.svelte';
import { preventDefault } from 'src/lib/utils';
let { onreload = () => {} }: { onreload?: () => void } = $props();
const dispatch = createEventDispatcher<{reload: void}>();
let addNewToken = $state(false); let addNewToken = $state(false);
let messages: MessageSimple;
let expiry_date: HTMLInputElement = $state(undefined as any); let expiry_date: HTMLInputElement = $state(undefined as any);
type NewToken = { type NewToken = {
name: string; name: string,
expiry: number; expiry: number,
token: string; token: string,
}; }
let token: Promise<NewToken> | undefined = $state(); let token: Promise<NewToken> | undefined = $state();
@ -26,7 +30,7 @@
} = $state({ } = $state({
name: '', name: '',
expiry: '', expiry: '',
password: '' password: '',
}); });
async function createToken(e: SubmitEvent & { currentTarget: HTMLFormElement }) { async function createToken(e: SubmitEvent & { currentTarget: HTMLFormElement }) {
@ -44,16 +48,16 @@
} }
try { try {
const r = await post('user/token/add', { const r = await post("user/token/add", {
name: newToken.name, name: newToken.name,
expiry: expiry, expiry: expiry,
password: newToken.password password: newToken.password,
}); });
token = Promise.resolve(r); token = Promise.resolve(r)
setTimeout(onreload, 500); setTimeout(() => dispatch('reload'), 500)
} catch (e) { } catch (e) {
token = undefined; token = undefined;
console.error('Notify user', e); console.error("Notify user", e)
} }
} }
@ -72,7 +76,7 @@
<div> <div>
<h2>Add New Token</h2> <h2>Add New Token</h2>
{#if !token} {#if !token}
<form onsubmit={preventDefault(createToken)}> <form on:submit|preventDefault={createToken}>
<fieldset> <fieldset>
<label for="name">Name</label> <label for="name">Name</label>
<input required bind:value={newToken.name} name="name" /> <input required bind:value={newToken.name} name="name" />
@ -82,7 +86,7 @@
<div class="flex"> <div class="flex">
<input bind:this={expiry_date} bind:value={newToken.expiry} name="expiry_date" /> <input bind:this={expiry_date} bind:value={newToken.expiry} name="expiry_date" />
<Tooltip title="Time in seconds. Leave empty to last forever"> <Tooltip title="Time in seconds. Leave empty to last forever">
<span class="center-question bi bi-question-circle-fill"></span> <span class="center-question bi bi-question-circle-fill" />
</Tooltip> </Tooltip>
</div> </div>
</fieldset> </fieldset>
@ -90,6 +94,7 @@
<label for="password">Password</label> <label for="password">Password</label>
<input required bind:value={newToken.password} name="password" /> <input required bind:value={newToken.password} name="password" />
</fieldset> </fieldset>
<MessageSimple bind:this={messages} />
<div> <div>
<button> Update </button> <button> Update </button>
</div> </div>
@ -99,20 +104,20 @@
<Spinner /> Generating <Spinner /> Generating
{:then t} {:then t}
<h3> Token generated </h3> <h3> Token generated </h3>
<form onsubmit={preventDefault(() => {})}> <form on:submit|preventDefault={() => {}}>
<fieldset> <fieldset>
<label for="token">Token</label> <label for="token">Token</label>
<div class="flex"> <div class="flex">
<input value={t.token} oninput={(e) => e.preventDefault()} name="token" /> <input value={t.token} on:input={(e) => e.preventDefault() } name="token" />
<div style="width: 5em;"> <div style="width: 5em;">
<button onclick={() => navigator.clipboard.writeText(t.token)}> <button on:click={() => navigator.clipboard.writeText(t.token)} >
<span class="bi bi-clipboard"></span> <span class="bi bi-clipboard" />
</button> </button>
</div> </div>
</div> </div>
</fieldset> </fieldset>
<div> <div>
<button onclick={() => (token = undefined)}> Generate new token </button> <button on:click={() => token = undefined}> Generate new token </button>
</div> </div>
</form> </form>
{:catch e} {:catch e}
@ -122,6 +127,6 @@
</div> </div>
{:else} {:else}
<div> <div>
<button class="expander" onclick={() => (addNewToken = true)}> Add New Token </button> <button class="expander" on:click={() => (addNewToken = true)}> Add New Token </button>
</div> </div>
{/if} {/if}

View File

@ -2,7 +2,6 @@
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { notificationStore } from 'src/lib/NotificationsStore.svelte'; import { notificationStore } from 'src/lib/NotificationsStore.svelte';
import { rdelete, showMessage } from 'src/lib/requests.svelte'; import { rdelete, showMessage } from 'src/lib/requests.svelte';
import { preventDefault } from 'src/lib/utils';
import { userStore } from 'src/routes/UserStore.svelte'; import { userStore } from 'src/routes/UserStore.svelte';
let data = $state({ password: '' }); let data = $state({ password: '' });
@ -24,7 +23,7 @@
} }
</script> </script>
<form class="danger-bg" onsubmit={preventDefault(deleteUser)}> <form class="danger-bg" on:submit|preventDefault={deleteUser}>
<h2 class="no-top-margin">Delete user</h2> <h2 class="no-top-margin">Delete user</h2>
Deleting the user will delete all your data stored in the service including the images. Deleting the user will delete all your data stored in the service including the images.
<fieldset> <fieldset>

View File

@ -70,7 +70,7 @@
{new Date(token.create_date).toLocaleString()} {new Date(token.create_date).toLocaleString()}
</td> </td>
<td> <td>
<button class="danger" onclick={() => removeToken(token)}> <button class="danger" on:click={() => removeToken(token)}>
<span class="bi bi-trash"></span> <span class="bi bi-trash"></span>
</button> </button>
</td> </td>
@ -81,7 +81,7 @@
<div class="flex justify-center align-center"> <div class="flex justify-center align-center">
<div class="grow-1 flex justify-end align-center"> <div class="grow-1 flex justify-end align-center">
{#if page > 0} {#if page > 0}
<button onclick={() => (page -= 1)}> Prev </button> <button on:click={() => (page -= 1)}> Prev </button>
{/if} {/if}
</div> </div>
@ -91,15 +91,25 @@
<div class="grow-1 flex justify-start align-center"> <div class="grow-1 flex justify-start align-center">
{#if showNext} {#if showNext}
<button onclick={() => (page += 1)}> Next </button> <button on:click={() => (page += 1)}> Next </button>
{/if} {/if}
</div> </div>
</div> </div>
</div> </div>
<AddToken onreload={getList} /> <AddToken on:reload={getList} />
<style lang="scss"> <style lang="scss">
.buttons {
width: 100%;
display: flex;
justify-content: space-between;
& > button {
margin: 3px 5px;
}
}
table { table {
width: 100%; width: 100%;
box-shadow: 0 2px 8px 1px #66666622; box-shadow: 0 2px 8px 1px #66666622;
@ -122,4 +132,10 @@
table tr th:first-child { table tr th:first-child {
border-left: none; border-left: none;
} }
table tr td button,
table tr td .button {
padding: 5px 10px;
box-shadow: 0 2px 5px 1px #66666655;
}
</style> </style>

View File

@ -65,7 +65,6 @@ button.expander::after {
a.button { a.button {
text-decoration: none; text-decoration: none;
height: 1.6em;
} }
.flex { .flex {
@ -191,12 +190,3 @@ form.danger-bg {
background-color: var(--danger-transparent); background-color: var(--danger-transparent);
border: 1px solid var(--danger); border: 1px solid var(--danger);
} }
pre {
font-family: 'Fira Code';
background-color: #f6f8fa;
word-break: break-word;
white-space: pre-wrap;
border-radius: 10px;
padding: 10px;
}

View File

@ -13,8 +13,8 @@ const config = {
// See https://kit.svelte.dev/docs/adapters for more information about adapters. // See https://kit.svelte.dev/docs/adapters for more information about adapters.
adapter: adapter(), adapter: adapter(),
alias: { alias: {
src: 'src', src: "src",
routes: 'src/routes' routes: "src/routes",
} }
} }
}; };

View File

@ -2,10 +2,5 @@ import { sveltekit } from '@sveltejs/kit/vite';
import { defineConfig } from 'vite'; import { defineConfig } from 'vite';
export default defineConfig({ export default defineConfig({
plugins: [sveltekit()], plugins: [sveltekit()]
build: {
commonjsOptions: {
esmExternals: true
}
}
}); });