diff --git a/DockerfileServer b/DockerfileServer new file mode 100644 index 0000000..a6fdaa6 --- /dev/null +++ b/DockerfileServer @@ -0,0 +1,37 @@ +FROM docker.io/nvidia/cuda:12.3.2-devel-ubuntu22.04 + +ENV DEBIAN_FRONTEND=noninteractive + +# apt-get update is run twice on purpose: the first run occasionally fails transiently on this base image +RUN apt-get update +RUN apt-get update + +RUN apt-get install -y wget unzip python3-pip vim python3 python3-pip curl + +RUN wget https://go.dev/dl/go1.22.2.linux-amd64.tar.gz +RUN tar -xvf go1.22.2.linux-amd64.tar.gz -C /usr/local +ENV PATH=$PATH:/usr/local/go/bin +ENV GOPATH=/go + +RUN bash -c 'curl -L "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.9.1.tar.gz" | tar -C /usr/local -xz' +# RUN bash -c 'curl -L "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.13.1.tar.gz" | tar -C /usr/local -xz' +RUN ldconfig + +RUN ln -s /usr/bin/python3 /usr/bin/python +RUN python -m pip install nvidia-pyindex +ADD requirements.txt . +RUN python -m pip install -r requirements.txt + +ENV CUDNN_PATH=/usr/local/lib/python3.10/dist-packages/nvidia/cudnn +ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages/nvidia/cudnn/lib + +WORKDIR /app + +ADD go.mod . +ADD go.sum . +ADD main.go . 
+ADD logic logic + +RUN go install || true + +CMD ["go", "run", "."] diff --git a/logic/db_types/classes.go b/logic/db_types/classes.go index 49eee43..b74e5a5 100644 --- a/logic/db_types/classes.go +++ b/logic/db_types/classes.go @@ -6,3 +6,19 @@ const ( DATA_POINT_MODE_TRAINING DATA_POINT_MODE = 1 DATA_POINT_MODE_TESTING = 2 ) + +type ModelClassStatus int + +const ( + CLASS_STATUS_TO_TRAIN ModelClassStatus = iota + 1 + CLASS_STATUS_TRAINING + CLASS_STATUS_TRAINED +) + +type ModelClass struct { + Id string `db:"mc.id" json:"id"` + ModelId string `db:"mc.model_id" json:"model_id"` + Name string `db:"mc.name" json:"name"` + ClassOrder int `db:"mc.class_order" json:"class_order"` + Status int `db:"mc.status" json:"status"` +} diff --git a/logic/db_types/definitions.go b/logic/db_types/definitions.go new file mode 100644 index 0000000..d511d32 --- /dev/null +++ b/logic/db_types/definitions.go @@ -0,0 +1,95 @@ +package dbtypes + +import ( + "time" + + "git.andr3h3nriqu3s.com/andr3/fyp/logic/db" +) + +type DefinitionStatus int + +const ( + DEFINITION_STATUS_CANCELD_TRAINING DefinitionStatus = -4 + DEFINITION_STATUS_FAILED_TRAINING = -3 + DEFINITION_STATUS_PRE_INIT = 1 + DEFINITION_STATUS_INIT = 2 + DEFINITION_STATUS_TRAINING = 3 + DEFINITION_STATUS_PAUSED_TRAINING = 6 + DEFINITION_STATUS_TRANIED = 4 + DEFINITION_STATUS_READY = 5 +) + +type Definition struct { + Id string `db:"md.id" json:"id"` + ModelId string `db:"md.model_id" json:"model_id"` + Accuracy float64 `db:"md.accuracy" json:"accuracy"` + TargetAccuracy int `db:"md.target_accuracy" json:"target_accuracy"` + Epoch int `db:"md.epoch" json:"epoch"` + Status int `db:"md.status" json:"status"` + CreatedOn time.Time `db:"md.created_on" json:"created"` + EpochProgress int `db:"md.epoch_progress" json:"epoch_progress"` +} + +type SortByAccuracyDefinitions []*Definition + +func (nf SortByAccuracyDefinitions) Len() int { return len(nf) } +func (nf SortByAccuracyDefinitions) Swap(i, j int) { nf[i], nf[j] = nf[j], nf[i] } 
+func (nf SortByAccuracyDefinitions) Less(i, j int) bool { + return nf[i].Accuracy < nf[j].Accuracy +} + +func GetDefinition(db db.Db, definition_id string) (definition Definition, err error) { + err = GetDBOnce(db, &definition, "model_definition as md where id=$1;", definition_id) + return +} + +func MakeDefenition(db db.Db, model_id string, target_accuracy int) (definition Definition, err error) { + var NewDefinition = struct { + ModelId string `db:"model_id"` + TargetAccuracy int `db:"target_accuracy"` + }{ModelId: model_id, TargetAccuracy: target_accuracy} + + id, err := InsertReturnId(db, &NewDefinition, "model_definition", "id") + if err != nil { + return + } + return GetDefinition(db, id) +} + +func (d Definition) UpdateStatus(db db.Db, status DefinitionStatus) (err error) { + _, err = db.Exec("update model_definition set status=$1 where id=$2", status, d.Id) + return +} + +func (d Definition) MakeLayer(db db.Db, layer_order int, layer_type LayerType, shape string) (layer Layer, err error) { + var NewLayer = struct { + DefinitionId string `db:"def_id"` + LayerOrder int `db:"layer_order"` + LayerType LayerType `db:"layer_type"` + Shape string `db:"shape"` + }{ + DefinitionId: d.Id, + LayerOrder: layer_order, + LayerType: layer_type, + Shape: shape, + } + + id, err := InsertReturnId(db, &NewLayer, "model_definition_layer", "id") + if err != nil { + return + } + + return GetLayer(db, id) +} + +func (d Definition) GetLayers(db db.Db, filter string, args ...any) (layer []*Layer, err error) { + args = append(args, d.Id) + return GetDbMultitple[Layer](db, "model_definition_layer as mdl where mdl.def_id=$1 "+filter, args...) 
+} + +func (d *Definition) UpdateAfterEpoch(db db.Db, accuracy float64) (err error) { + d.Accuracy = accuracy + d.Epoch += 1 + _, err = db.Exec("update model_definition set epoch=$1, accuracy=$2 where id=$3", d.Epoch, d.Accuracy, d.Id) + return +} diff --git a/logic/db_types/layer.go b/logic/db_types/layer.go new file mode 100644 index 0000000..a82bd7a --- /dev/null +++ b/logic/db_types/layer.go @@ -0,0 +1,50 @@ +package dbtypes + +import ( + "encoding/json" + + "git.andr3h3nriqu3s.com/andr3/fyp/logic/db" +) + +type LayerType int + +const ( + LAYER_INPUT LayerType = 1 + LAYER_DENSE = 2 + LAYER_FLATTEN = 3 + LAYER_SIMPLE_BLOCK = 4 +) + +type Layer struct { + Id string `db:"mdl.id" json:"id"` + DefinitionId string `db:"mdl.def_id" json:"definition_id"` + LayerOrder string `db:"mdl.layer_order" json:"layer_order"` + LayerType LayerType `db:"mdl.layer_type" json:"layer_type"` + Shape string `db:"mdl.shape" json:"shape"` + ExpType string `db:"mdl.exp_type" json:"exp_type"` +} + +func ShapeToString(args ...int) string { + text, err := json.Marshal(args) + if err != nil { + panic("Could not generate Shape") + } + return string(text) +} + +func StringToShape(str string) (shape []int64) { + err := json.Unmarshal([]byte(str), &shape) + if err != nil { + panic("Could not parse Shape") + } + return +} + +func (l Layer) GetShape() []int64 { + return StringToShape(l.Shape) +} + +func GetLayer(db db.Db, layer_id string) (layer Layer, err error) { + err = GetDBOnce(db, &layer, "model_definition_layer as mdl where mdl.id=$1", layer_id) + return +} diff --git a/logic/db_types/types.go b/logic/db_types/types.go index 97fb993..0bd3537 100644 --- a/logic/db_types/types.go +++ b/logic/db_types/types.go @@ -2,23 +2,26 @@ package dbtypes import ( "errors" + "fmt" + "path" "git.andr3h3nriqu3s.com/andr3/fyp/logic/db" ) -const ( - FAILED_TRAINING = -4 - FAILED_PREPARING_TRAINING = -3 - FAILED_PREPARING_ZIP_FILE = -2 - FAILED_PREPARING = -1 +type ModelStatus int - PREPARING = 1 - 
CONFIRM_PRE_TRAINING = 2 - PREPARING_ZIP_FILE = 3 - TRAINING = 4 - READY = 5 - READY_ALTERATION = 6 - READY_ALTERATION_FAILED = -6 +const ( + FAILED_TRAINING ModelStatus = -4 + FAILED_PREPARING_TRAINING = -3 + FAILED_PREPARING_ZIP_FILE = -2 + FAILED_PREPARING = -1 + PREPARING = 1 + CONFIRM_PRE_TRAINING = 2 + PREPARING_ZIP_FILE = 3 + TRAINING = 4 + READY = 5 + READY_ALTERATION = 6 + READY_ALTERATION_FAILED = -6 READY_RETRAIN = 7 READY_RETRAIN_FAILED = -7 @@ -26,15 +29,6 @@ const ( type ModelDefinitionStatus int -type LayerType int - -const ( - LAYER_INPUT LayerType = 1 - LAYER_DENSE = 2 - LAYER_FLATTEN = 3 - LAYER_SIMPLE_BLOCK = 4 -) - const ( MODEL_DEFINITION_STATUS_CANCELD_TRAINING ModelDefinitionStatus = -4 MODEL_DEFINITION_STATUS_FAILED_TRAINING = -3 @@ -46,14 +40,6 @@ const ( MODEL_DEFINITION_STATUS_READY = 5 ) -type ModelClassStatus int - -const ( - MODEL_CLASS_STATUS_TO_TRAIN ModelClassStatus = 1 - MODEL_CLASS_STATUS_TRAINING = 2 - MODEL_CLASS_STATUS_TRAINED = 3 -) - type ModelHeadStatus int const ( @@ -97,6 +83,61 @@ func (m BaseModel) CanEval() bool { return true } +// DO NOT Pass un filtered data on filters +func (m BaseModel) GetDefinitions(db db.Db, filters string, args ...any) ([]*Definition, error) { + n_args := []any{m.Id} + n_args = append(n_args, args...) + return GetDbMultitple[Definition](db, fmt.Sprintf("model_definition as md where md.model_id=$1 %s", filters), n_args...) +} + +func (m BaseModel) GetClasses(db db.Db, filters string, args ...any) ([]*ModelClass, error) { + n_args := []any{m.Id} + n_args = append(n_args, args...) + return GetDbMultitple[ModelClass](db, fmt.Sprintf("model_classes as mc where mc.model_id=$1 %s", filters), n_args...) 
+} + +func (m *BaseModel) UpdateStatus(db db.Db, status ModelStatus) (err error) { + _, err = db.Exec("update models set status=$1 where id=$2", status, m.Id) + return +} + +type DataPoint struct { + Class int `json:"class"` + Path string `json:"path"` +} + +func (m BaseModel) DataPoints(db db.Db, mode DATA_POINT_MODE) (data []DataPoint, err error) { + rows, err := db.Query( + "select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner "+ + "join model_classes as mc on mc.id = mdp.class_id "+ + "where mc.model_id = $1 and mdp.model_mode=$2;", + m.Id, mode) + if err != nil { + return + } + defer rows.Close() + + data = []DataPoint{} + + for rows.Next() { + var id string + var class_order int + var file_path string + if err = rows.Scan(&id, &class_order, &file_path); err != nil { + return + } + if file_path == "id://" { + data = append(data, DataPoint{ + Path: path.Join("./savedData", m.Id, "data", id+"."+m.Format), + Class: class_order, + }) + } else { + panic("TODO remote file path") + } + } + return +} + func StringToImageMode(colorMode string) int { switch colorMode { case "greyscale": diff --git a/logic/models/classes/main.go b/logic/models/classes/main.go index 81d7563..8f91403 100644 --- a/logic/models/classes/main.go +++ b/logic/models/classes/main.go @@ -7,15 +7,15 @@ import ( . 
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" ) -type ModelClass struct { +type ModelClassJSON struct { Id string `json:"id"` ModelId string `json:"model_id" db:"model_id"` Name string `json:"name"` Status int `json:"status"` } -func ListClasses(c BasePack, model_id string) (cls []*ModelClass, err error) { - return GetDbMultitple[ModelClass](c.GetDb(), "model_classes where model_id=$1", model_id) +func ListClasses(c BasePack, model_id string) (cls []*ModelClassJSON, err error) { + return GetDbMultitple[ModelClassJSON](c.GetDb(), "model_classes where model_id=$1", model_id) } func ModelHasDataPoints(db db.Db, model_id string) (result bool, err error) { diff --git a/logic/models/data.go b/logic/models/data.go index 47112a0..1e6a89c 100644 --- a/logic/models/data.go +++ b/logic/models/data.go @@ -495,7 +495,7 @@ func handleDataUpload(handle *Handle) { return c.E500M("Could not create class", err) } - var modelClass model_classes.ModelClass + var modelClass model_classes.ModelClassJSON err = GetDBOnce(c, &modelClass, "model_classes where id=$1;", id) if err != nil { return c.E500M("Failed to get class information but class was creted", err) @@ -704,7 +704,7 @@ func handleDataUpload(handle *Handle) { return c.Error500(err) } } else { - _, err = handle.Db.Exec("delete from model_classes where model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TO_TRAIN) + _, err = handle.Db.Exec("delete from model_classes where model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TO_TRAIN) if err != nil { return c.Error500(err) } diff --git a/logic/models/delete.go b/logic/models/delete.go index ed0bf27..7abc28f 100644 --- a/logic/models/delete.go +++ b/logic/models/delete.go @@ -51,7 +51,7 @@ func handleDelete(handle *Handle) { return c.E500M("Faield to get model", err) } - switch model.Status { + switch ModelStatus(model.Status) { case FAILED_TRAINING: fallthrough case FAILED_PREPARING_ZIP_FILE: diff --git a/logic/models/edit.go b/logic/models/edit.go index 0a6efa8..6530d93 
100644 --- a/logic/models/edit.go +++ b/logic/models/edit.go @@ -35,9 +35,9 @@ func handleEdit(handle *Handle) { } type ReturnType struct { - Classes []*model_classes.ModelClass `json:"classes"` - HasData bool `json:"has_data"` - NumberOfInvalidImages int `json:"number_of_invalid_images"` + Classes []*model_classes.ModelClassJSON `json:"classes"` + HasData bool `json:"has_data"` + NumberOfInvalidImages int `json:"number_of_invalid_images"` } c.ShowMessage = false diff --git a/logic/models/train/remote_train.go b/logic/models/train/remote_train.go new file mode 100644 index 0000000..501a707 --- /dev/null +++ b/logic/models/train/remote_train.go @@ -0,0 +1,79 @@ +package models_train + +import ( + "errors" + + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" + "github.com/goccy/go-json" +) + +func PrepareTraining(handler *Handle, b BasePack, task Task, runner_id string) (err error) { + l := b.GetLogger() + + model, err := GetBaseModel(b.GetDb(), *task.ModelId) + if err != nil { + task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed to get model information") + l.Error("Failed to get model information", "err", err) + return err + } + + if model.Status != TRAINING { + task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Model not in the correct status for training") + return errors.New("Model not in the right status") + } + + // TODO do this when the runner says it's OK + //task.UpdateStatusLog(b, TASK_RUNNING, "Training model") + + // TODO move this to the runner part as well + var dat struct { + NumberOfModels int + Accuracy int + } + + err = json.Unmarshal([]byte(task.ExtraTaskInfo), &dat) + if err != nil { + task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed to get model extra information") + } + + if model.ModelType == 2 { + panic("TODO") + full_error := generateExpandableDefinitions(b, model, dat.Accuracy, dat.NumberOfModels) + if full_error != nil { + 
l.Error("Failed to generate definitions", "err", full_error) + task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed generate model") + return errors.New("Failed to generate definitions") + } + } else { + error := generateDefinitions(b, model, dat.Accuracy, dat.NumberOfModels) + if error != nil { + task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed generate model") + return errors.New("Failed to generate definitions") + } + } + + runners := handler.DataMap["runners"].(map[string]interface{}) + runner := runners[runner_id].(map[string]interface{}) + runner["task"] = &task + + runners[runner_id] = runner + handler.DataMap["runners"] = runners + + return +} + +func CleanUpFailed(b BasePack, task *Task) { + db := b.GetDb() + l := b.GetLogger() + model, err := GetBaseModel(db, *task.ModelId) + if err != nil { + l.Error("Failed to get model", "err", err) + } else { + err = model.UpdateStatus(db, FAILED_TRAINING) + if err != nil { + l.Error("Failed to update status", err) + } + } +} diff --git a/logic/models/train/reset.go b/logic/models/train/reset.go index bc92eeb..655bcc1 100644 --- a/logic/models/train/reset.go +++ b/logic/models/train/reset.go @@ -17,7 +17,7 @@ func handleRest(handle *Handle) { return c.E500M("Failed to get model", err) } - if model.Status != FAILED_PREPARING_TRAINING && model.Status != FAILED_TRAINING { + if model.Status != FAILED_PREPARING_TRAINING && model.Status != int(FAILED_TRAINING) { return c.JsonBadRequest("Model is not in status that be reset") } diff --git a/logic/models/train/train.go b/logic/models/train/train.go index 8c4fb32..1457747 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -39,16 +39,6 @@ func getDir() string { return dir } -// This function creates a new model_definition -func MakeDefenition(db db.Db, model_id string, target_accuracy int) (id string, err error) { - var NewDefinition = struct { - ModelId string `db:"model_id"` - TargetAccuracy int `db:"target_accuracy"` - }{ModelId: model_id, 
TargetAccuracy: target_accuracy} - - return InsertReturnId(db, &NewDefinition, "model_definition", "id") -} - func ModelDefinitionUpdateStatus(c BasePack, id string, status ModelDefinitionStatus) (err error) { _, err = c.GetDb().Exec("update model_definition set status = $1 where id = $2", status, id) return @@ -118,14 +108,14 @@ func generateCvsExp(c BasePack, run_path string, model_id string, doPanic bool) var co struct { Count int `db:"count(*)"` } - err = GetDBOnce(db, &co, "model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING) + err = GetDBOnce(db, &co, "model_classes where model_id=$1 and status=$2;", model_id, CLASS_STATUS_TRAINING) if err != nil { return } count = co.Count if count == 0 { - err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TO_TRAIN) + err = setModelClassStatus(c, CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, CLASS_STATUS_TO_TRAIN) if err != nil { return } @@ -137,7 +127,7 @@ func generateCvsExp(c BasePack, run_path string, model_id string, doPanic bool) return generateCvsExp(c, run_path, model_id, true) } - data, err := db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING) + data, err := db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, DATA_POINT_MODE_TRAINING, CLASS_STATUS_TRAINING) if err != nil { return } @@ -287,7 +277,7 @@ func generateCvsExpandExp(c BasePack, run_path string, model_id string, offset i var co struct { Count int `db:"count(*)"` } - err = GetDBOnce(db, &co, "model_classes where model_id=$1 and status=$2;", model_id, 
MODEL_CLASS_STATUS_TRAINING) + err = GetDBOnce(db, &co, "model_classes where model_id=$1 and status=$2;", model_id, CLASS_STATUS_TRAINING) if err != nil { return } @@ -296,7 +286,7 @@ func generateCvsExpandExp(c BasePack, run_path string, model_id string, offset i count := co.Count if count == 0 { - err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TO_TRAIN) + err = setModelClassStatus(c, CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, CLASS_STATUS_TO_TRAIN) if err != nil { return } else if doPanic { @@ -305,7 +295,7 @@ func generateCvsExpandExp(c BasePack, run_path string, model_id string, offset i return generateCvsExpandExp(c, run_path, model_id, offset, true) } - data, err := db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING) + data, err := db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, DATA_POINT_MODE_TRAINING, CLASS_STATUS_TRAINING) if err != nil { return } @@ -339,7 +329,7 @@ func generateCvsExpandExp(c BasePack, run_path string, model_id string, offset i // This is to load some extra data so that the model has more things to train on // - data_other, err := db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINED, count*10) + data_other, err := db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id 
where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, DATA_POINT_MODE_TRAINING, CLASS_STATUS_TRAINED, count*10) if err != nil { return } @@ -737,7 +727,7 @@ func trainModel(c BasePack, model *BaseModel) (err error) { if err != nil { l.Error("Failed to train Model! Err:") l.Error(err) - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING)) return } defer definitionsRows.Close() @@ -750,7 +740,7 @@ func trainModel(c BasePack, model *BaseModel) (err error) { if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy, &rowv.epoch); err != nil { l.Error("Failed to train Model Could not read definition from db!Err:") l.Error(err) - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING)) return } definitions = append(definitions, rowv) @@ -758,7 +748,7 @@ func trainModel(c BasePack, model *BaseModel) (err error) { if len(definitions) == 0 { l.Error("No Definitions defined!") - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING)) return } @@ -788,14 +778,14 @@ func trainModel(c BasePack, model *BaseModel) (err error) { _, err = db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.epoch, def.id) if err != nil { l.Error("Failed to train definition!Err:\n", "err", err) - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING)) return err } _, err = db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", MODEL_DEFINITION_STATUS_CANCELD_TRAINING, def.id, model.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) if err != nil { l.Error("Failed to train definition!Err:\n", "err", err) - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING)) return err } @@ -813,7 +803,7 @@ func trainModel(c BasePack, model 
*BaseModel) (err error) { _, err = db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.epoch, MODEL_DEFINITION_STATUS_PAUSED_TRAINING, def.id) if err != nil { l.Error("Failed to train definition!Err:\n", "err", err) - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING)) return err } } @@ -868,7 +858,7 @@ func trainModel(c BasePack, model *BaseModel) (err error) { if err != nil { l.Error("DB: failed to read definition") l.Error(err) - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING)) return } defer rows.Close() @@ -876,7 +866,7 @@ func trainModel(c BasePack, model *BaseModel) (err error) { if !rows.Next() { // TODO Make the Model status have a message l.Error("All definitions failed to train!") - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING)) return } @@ -884,14 +874,14 @@ func trainModel(c BasePack, model *BaseModel) (err error) { if err = rows.Scan(&id); err != nil { l.Error("Failed to read id:") l.Error(err) - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING)) return } if _, err = db.Exec("update model_definition set status=$1 where id=$2;", MODEL_DEFINITION_STATUS_READY, id); err != nil { l.Error("Failed to update model definition") l.Error(err) - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING)) return } @@ -899,7 +889,7 @@ func trainModel(c BasePack, model *BaseModel) (err error) { if err != nil { l.Error("Failed to select model_definition to delete") l.Error(err) - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING)) return } defer to_delete.Close() @@ -909,7 +899,7 @@ func trainModel(c BasePack, model *BaseModel) (err error) { if err = to_delete.Scan(&id); err != nil { l.Error("Failed to scan the id of a 
model_definition to delete") l.Error(err) - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING)) return } os.RemoveAll(path.Join("savedData", model.Id, "defs", id)) @@ -919,7 +909,7 @@ func trainModel(c BasePack, model *BaseModel) (err error) { if _, err = db.Exec("delete from model_definition where status!=$1 and model_id=$2;", MODEL_DEFINITION_STATUS_READY, model.Id); err != nil { l.Error("Failed to delete model_definition") l.Error(err) - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING)) return } @@ -1066,7 +1056,7 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) { err = GetDBOnce(db, &dat, "model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, MODEL_DEFINITION_STATUS_TRANIED) if err == NotFoundError { // Set the class status to trained - err = setModelClassStatus(c, MODEL_CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING) + err = setModelClassStatus(c, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING) if err != nil { l.Error("All definitions failed to train! And Failed to set class status") return err @@ -1101,7 +1091,7 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) { } if err = splitModel(c, model); err != nil { - err = setModelClassStatus(c, MODEL_CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING) + err = setModelClassStatus(c, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING) if err != nil { l.Error("Failed to split the model! 
And Failed to set class status") return err @@ -1112,7 +1102,7 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) { } // Set the class status to trained - err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING) + err = setModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING) if err != nil { l.Error("Failed to set class status") return err @@ -1272,7 +1262,7 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe db := c.GetDb() l := c.GetLogger() - def_id, err := MakeDefenition(db, model.Id, target_accuracy) + def, err := MakeDefenition(db, model.Id, target_accuracy) if err != nil { failed() return @@ -1281,60 +1271,68 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe order := 1 // Note the shape of the first layer defines the import size - if complexity == 2 { - // Note the shape for now is no used - width := int(math.Pow(2, math.Floor(math.Log(float64(model.Width))/math.Log(2.0)))) - height := int(math.Pow(2, math.Floor(math.Log(float64(model.Height))/math.Log(2.0)))) - l.Warn("Complexity 2 creating model with smaller size", "width", width, "height", height) - err = MakeLayer(db, def_id, order, LAYER_INPUT, fmt.Sprintf("%d,%d,1", width, height)) - if err != nil { - failed() - return - } - order++ - } else { - err = MakeLayer(db, def_id, order, LAYER_INPUT, fmt.Sprintf("%d,%d,1", model.Width, model.Height)) - if err != nil { - failed() - return - } - order++ - } - - loop := max(int((math.Log(float64(model.Width)) / math.Log(float64(10)))), 1) - for i := 0; i < loop; i++ { - err = MakeLayer(db, def_id, order, LAYER_SIMPLE_BLOCK, "") - order++ - if err != nil { - failed() - return - } - } - - err = MakeLayer(db, def_id, order, LAYER_FLATTEN, "") + //_, err = def.MakeLayer(db, order, LAYER_INPUT, ShapeToString(model.Width, model.Height, model.ImageMode)) + _, err = 
def.MakeLayer(db, order, LAYER_INPUT, ShapeToString(3, model.Width, model.Height)) if err != nil { failed() return } order++ - loop = max(int((math.Log(float64(number_of_classes))/math.Log(float64(10)))/2), 1) - for i := 0; i < loop; i++ { - err = MakeLayer(db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i))) - order++ + if complexity == 0 { + _, err = def.MakeLayer(db, order, LAYER_FLATTEN, "") if err != nil { failed() return } - } + order++ - err = ModelDefinitionUpdateStatus(c, def_id, MODEL_DEFINITION_STATUS_INIT) - if err != nil { + loop := int(math.Log2(float64(number_of_classes))) + for i := 0; i < loop; i++ { + _, err = def.MakeLayer(db, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i))) + order++ + if err != nil { + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) + return + } + } + } else if complexity == 1 || complexity == 2 { + loop := max(1, int((math.Log(float64(model.Width)) / math.Log(float64(10))))) + for i := 0; i < loop; i++ { + _, err = def.MakeLayer(db, order, LAYER_SIMPLE_BLOCK, "") + order++ + if err != nil { + failed() + return + } + } + + _, err = def.MakeLayer(db, order, LAYER_FLATTEN, "") + if err != nil { + failed() + return + } + order++ + + loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2) + if loop == 0 { + loop = 1 + } + for i := 0; i < loop; i++ { + _, err = def.MakeLayer(db, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i))) + order++ + if err != nil { + failed() + return + } + } + } else { + l.Error("Unkown complexity", "complexity", complexity) failed() return } - return nil + return def.UpdateStatus(db, DEFINITION_STATUS_INIT) } func generateDefinitions(c BasePack, model *BaseModel, target_accuracy int, number_of_models int) (err error) { @@ -1393,12 +1391,14 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy return } - def_id, err := MakeDefenition(c.GetDb(), model.Id, target_accuracy) + def, err := 
MakeDefenition(c.GetDb(), model.Id, target_accuracy) if err != nil { failed() return } + def_id := def.Id + order := 1 width := model.Width @@ -1533,7 +1533,7 @@ func generateExpandableDefinitions(c BasePack, model *BaseModel, target_accuracy } func ResetClasses(c BasePack, model *BaseModel) { - _, err := c.GetDb().Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TO_TRAIN, MODEL_CLASS_STATUS_TRAINING, model.Id) + _, err := c.GetDb().Exec("update model_classes set status=$1 where status=$2 and model_id=$3", CLASS_STATUS_TO_TRAIN, CLASS_STATUS_TRAINING, model.Id) if err != nil { c.GetLogger().Error("Error while reseting the classes", "error", err) } @@ -1544,7 +1544,7 @@ func trainExpandable(c *Context, model *BaseModel) { failed := func(msg string) { c.Logger.Error(msg, "err", err) - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + ModelUpdateStatus(c, model.Id, int(FAILED_TRAINING)) ResetClasses(c, model) } @@ -1588,7 +1588,7 @@ func trainExpandable(c *Context, model *BaseModel) { } // Set the class status to trained - err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING) + err = setModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING) if err != nil { failed("Failed to set class status") return @@ -1648,7 +1648,7 @@ func RunTaskTrain(b BasePack, task Task) (err error) { if err != nil { l.Error("Failed to train model", "err", err) task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed generate model") - ModelUpdateStatus(b, model.Id, FAILED_TRAINING) + ModelUpdateStatus(b, model.Id, int(FAILED_TRAINING)) return } @@ -1731,7 +1731,7 @@ func RunTaskRetrain(b BasePack, task Task) (err error) { l.Info("Model updaded") - _, err = db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINED, MODEL_CLASS_STATUS_TRAINING, model.Id) + _, err = db.Exec("update 
model_classes set status=$1 where status=$2 and model_id=$3", CLASS_STATUS_TRAINED, CLASS_STATUS_TRAINING, model.Id) if err != nil { l.Error("Error while updating the classes", "error", err) failed() @@ -1861,7 +1861,7 @@ func handleTrain(handle *Handle) { c, "model_classes where model_id=$1 and status=$2 order by class_order asc", model.Id, - MODEL_CLASS_STATUS_TO_TRAIN, + CLASS_STATUS_TO_TRAIN, ) if err != nil { _err := c.RollbackTx() @@ -1882,7 +1882,7 @@ func handleTrain(handle *Handle) { //Update the classes { - _, err = c.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id) + _, err = c.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", CLASS_STATUS_TRAINING, CLASS_STATUS_TO_TRAIN, model.Id) if err != nil { _err := c.RollbackTx() if _err != nil { diff --git a/logic/tasks/README.md b/logic/tasks/README.md new file mode 100644 index 0000000..2e60da7 --- /dev/null +++ b/logic/tasks/README.md @@ -0,0 +1,7 @@ +# Runner Protocol + +``` + /----\ + \/ | +Register -> Init -> Active ---> Ready -> Info +``` diff --git a/logic/tasks/index.go b/logic/tasks/index.go index 3c8d59e..8a3a5ec 100644 --- a/logic/tasks/index.go +++ b/logic/tasks/index.go @@ -8,4 +8,5 @@ func HandleTasks(handle *Handle) { handleUpload(handle) handleList(handle) handleRequests(handle) + handleRemoteRunner(handle) } diff --git a/logic/tasks/runner.go b/logic/tasks/runner.go new file mode 100644 index 0000000..9f55e3b --- /dev/null +++ b/logic/tasks/runner.go @@ -0,0 +1,386 @@ +package tasks + +import ( + "sync" + "time" + + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" + . 
"git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" +) + +func verifyRunner(c *Context, dat *JustId) (runner *Runner, e *Error) { + runner, err := GetRunner(c, dat.Id) + if err == NotFoundError { + e = c.JsonBadRequest("Could not find runner, please register runner first") + return + } else if err != nil { + e = c.E500M("Failed to get information about the runner", err) + return + } + + if runner.Token != *c.Token { + return nil, c.SendJSONStatus(401, "Only runners can use this funcion") + } + + return +} + +type VerifyTask struct { + Id string `json:"id" validate:"required"` + TaskId string `json:"taskId" validate:"required"` +} + +func verifyTask(x *Handle, c *Context, dat *VerifyTask) (task *Task, error *Error) { + mutex := x.DataMap["runners_mutex"].(*sync.Mutex) + mutex.Lock() + defer mutex.Unlock() + + var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{}) + if runners[dat.Id] == nil { + return nil, c.JsonBadRequest("Runner not active") + } + + var runner_data map[string]interface{} = runners[dat.Id].(map[string]interface{}) + + if runner_data["task"] == nil { + return nil, c.SendJSONStatus(404, "No active task") + } + + return runner_data["task"].(*Task), nil +} + +func handleRemoteRunner(x *Handle) { + + type RegisterRunner struct { + Token string `json:"token" validate:"required"` + Type RunnerType `json:"type" validate:"required"` + } + PostAuthJson(x, "/tasks/runner/register", User_Normal, func(c *Context, dat *RegisterRunner) *Error { + if *c.Token != dat.Token { + // TODO do admin + return c.E500M("Please make sure that the token is the same that is being registered", nil) + } + + c.Logger.Info("test", "dat", dat) + + var runner Runner + err := GetDBOnce(c, &runner, "remote_runner as ru where token=$1", dat.Token) + if err != NotFoundError && err != nil { + return c.E500M("Failed to get information remote runners", err) + } + if err != NotFoundError { + return c.JsonBadRequest("Token is already registered by a runner") + } + + // 
TODO get id from token passed by when doing admin + var userId = c.User.Id + + var new_runner = struct { + Type RunnerType + UserId string `db:"user_id"` + Token string + }{ + Type: dat.Type, + Token: dat.Token, + UserId: userId, + } + + id, err := InsertReturnId(c, &new_runner, "remote_runner", "id") + if err != nil { + return c.E500M("Failed to create remote runner", err) + } + + return c.SendJSON(struct { + Id string `json:"id"` + }{ + Id: id, + }) + }) + + // TODO remove runner + + PostAuthJson(x, "/tasks/runner/init", User_Normal, func(c *Context, dat *JustId) *Error { + runner, error := verifyRunner(c, dat) + if error != nil { + return error + } + + mutex := x.DataMap["runners_mutex"].(*sync.Mutex) + mutex.Lock() + defer mutex.Unlock() + + var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{}) + if runners[dat.Id] != nil { + c.Logger.Info("Logger trying to register but already registerd") + c.ShowMessage = false + return c.SendJSON("Ok") + } + + var new_runner = map[string]interface{}{} + new_runner["last_time_check"] = time.Now() + new_runner["runner_info"] = runner + + runners[dat.Id] = new_runner + + x.DataMap["runners"] = runners + + return c.SendJSON("Ok") + }) + + PostAuthJson(x, "/tasks/runner/active", User_Normal, func(c *Context, dat *JustId) *Error { + _, error := verifyRunner(c, dat) + if error != nil { + return error + } + + mutex := x.DataMap["runners_mutex"].(*sync.Mutex) + mutex.Lock() + defer mutex.Unlock() + + var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{}) + if runners[dat.Id] == nil { + return c.JsonBadRequest("Runner not active") + } + + var runner_data map[string]interface{} = runners[dat.Id].(map[string]interface{}) + + if runner_data["task"] == nil { + c.ShowMessage = false + return c.SendJSONStatus(404, "No active task") + } + + c.ShowMessage = false + // This should be a task obj + return c.SendJSON(runner_data["task"]) + }) + + PostAuthJson(x, "/tasks/runner/ready", 
User_Normal, func(c *Context, dat *VerifyTask) *Error { + _, error := verifyRunner(c, &JustId{Id: dat.Id}) + if error != nil { + return error + } + + task, error := verifyTask(x, c, dat) + if error != nil { + return error + } + + err := task.UpdateStatus(c, TASK_RUNNING, "Task Running on Runner") + if err != nil { + return c.E500M("Failed to set task status", err) + } + + return c.SendJSON("Ok") + }) + + type TaskFail struct { + Id string `json:"id" validate:"required"` + TaskId string `json:"taskId" validate:"required"` + Reason string `json:"reason" validate:"required"` + } + PostAuthJson(x, "/tasks/runner/fail", User_Normal, func(c *Context, dat *TaskFail) *Error { + _, error := verifyRunner(c, &JustId{Id: dat.Id}) + if error != nil { + return error + } + + task, error := verifyTask(x, c, &VerifyTask{Id: dat.Id, TaskId: dat.TaskId}) + if error != nil { + return error + } + + err := task.UpdateStatus(c, TASK_FAILED_RUNNING, dat.Reason) + if err != nil { + return c.E500M("Failed to set task status", err) + } + + // Do extra clean up on tasks + switch task.TaskType { + case int(TASK_TYPE_TRAINING): + CleanUpFailed(c, task) + default: + panic("Do not know how to handle this") + } + + mutex := x.DataMap["runners_mutex"].(*sync.Mutex) + mutex.Lock() + defer mutex.Unlock() + + var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{}) + var runner_data map[string]interface{} = runners[dat.Id].(map[string]interface{}) + runner_data["task"] = nil + + runners[dat.Id] = runner_data + x.DataMap["runners"] = runners + + return c.SendJSON("Ok") + }) + + PostAuthJson(x, "/tasks/runner/train/defs", User_Normal, func(c *Context, dat *VerifyTask) *Error { + _, error := verifyRunner(c, &JustId{Id: dat.Id}) + if error != nil { + return error + } + + task, error := verifyTask(x, c, dat) + if error != nil { + return error + } + + if task.TaskType != int(TASK_TYPE_TRAINING) { + c.Logger.Error("Task not is not the right type to get the definitions", "task type", 
task.TaskType) + return c.JsonBadRequest("Task is not the right type go get the definitions") + } + + model, err := GetBaseModel(c, *task.ModelId) + if err != nil { + return c.E500M("Failed to get model information", err) + } + + defs, err := model.GetDefinitions(c, "and md.status=$2", DEFINITION_STATUS_INIT) + if err != nil { + return c.E500M("Failed to get the model definitions", err) + } + + return c.SendJSON(defs) + }) + + PostAuthJson(x, "/tasks/runner/train/classes", User_Normal, func(c *Context, dat *VerifyTask) *Error { + _, error := verifyRunner(c, &JustId{Id: dat.Id}) + if error != nil { + return error + } + + task, error := verifyTask(x, c, dat) + if error != nil { + return error + } + + if task.TaskType != int(TASK_TYPE_TRAINING) { + c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) + return c.JsonBadRequest("Task is not the right type go get the definitions") + } + + model, err := GetBaseModel(c, *task.ModelId) + if err != nil { + return c.E500M("Failed to get model information", err) + } + + classes, err := model.GetClasses(c, "and status=$2 order by mc.class_order asc", CLASS_STATUS_TO_TRAIN) + if err != nil { + return c.E500M("Failed to get the model classes", err) + } + + return c.SendJSON(classes) + }) + + type RunnerTrainDefStatus struct { + Id string `json:"id" validate:"required"` + TaskId string `json:"taskId" validate:"required"` + DefId string `json:"defId" validate:"required"` + Status DefinitionStatus `json:"status" validate:"required"` + } + PostAuthJson(x, "/tasks/runner/train/def/status", User_Normal, func(c *Context, dat *RunnerTrainDefStatus) *Error { + _, error := verifyRunner(c, &JustId{Id: dat.Id}) + if error != nil { + return error + } + + task, error := verifyTask(x, c, &VerifyTask{Id: dat.Id, TaskId: dat.TaskId}) + if error != nil { + return error + } + + if task.TaskType != int(TASK_TYPE_TRAINING) { + c.Logger.Error("Task not is not the right type to get the definitions", "task 
type", task.TaskType) + return c.JsonBadRequest("Task is not the right type go get the definitions") + } + + def, err := GetDefinition(c, dat.DefId) + if err != nil { + return c.E500M("Failed to get definition information", err) + } + + err = def.UpdateStatus(c, dat.Status) + if err != nil { + return c.E500M("Failed to update model status", err) + } + + return c.SendJSON("Ok") + }) + + type RunnerTrainDefLayers struct { + Id string `json:"id" validate:"required"` + TaskId string `json:"taskId" validate:"required"` + DefId string `json:"defId" validate:"required"` + } + PostAuthJson(x, "/tasks/runner/train/def/layers", User_Normal, func(c *Context, dat *RunnerTrainDefLayers) *Error { + _, error := verifyRunner(c, &JustId{Id: dat.Id}) + if error != nil { + return error + } + + task, error := verifyTask(x, c, &VerifyTask{Id: dat.Id, TaskId: dat.TaskId}) + if error != nil { + return error + } + + if task.TaskType != int(TASK_TYPE_TRAINING) { + c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) + return c.JsonBadRequest("Task is not the right type go get the definitions") + } + + def, err := GetDefinition(c, dat.DefId) + if err != nil { + return c.E500M("Failed to get definition information", err) + } + + layers, err := def.GetLayers(c, " order by layer_order asc") + if err != nil { + return c.E500M("Failed to get layers", err) + } + + return c.SendJSON(layers) + }) + + PostAuthJson(x, "/tasks/runner/train/datapoints", User_Normal, func(c *Context, dat *VerifyTask) *Error { + _, error := verifyRunner(c, &JustId{Id: dat.Id}) + if error != nil { + return error + } + + task, error := verifyTask(x, c, dat) + if error != nil { + return error + } + + if task.TaskType != int(TASK_TYPE_TRAINING) { + c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) + return c.JsonBadRequest("Task is not the right type go get the definitions") + } + + model, err := GetBaseModel(c, *task.ModelId) + 
if err != nil { + return c.E500M("Failed to get model information", err) + } + + training_points, err := model.DataPoints(c, DATA_POINT_MODE_TRAINING) + if err != nil { + return c.E500M("Failed to get the model classes", err) + } + testing_points, err := model.DataPoints(c, DATA_POINT_MODE_TRAINING) + if err != nil { + return c.E500M("Failed to get the model classes", err) + } + + return c.SendJSON(struct { + Testing []DataPoint `json:"testing"` + Training []DataPoint `json:"training"` + }{ + Testing: testing_points, + Training: training_points, + }) + }) +} diff --git a/logic/tasks/runner/runner.go b/logic/tasks/runner/runner.go index 4e6f967..7f53e48 100644 --- a/logic/tasks/runner/runner.go +++ b/logic/tasks/runner/runner.go @@ -5,6 +5,7 @@ import ( "math" "os" "runtime/debug" + "sync" "time" "github.com/charmbracelet/log" @@ -90,6 +91,45 @@ func runner(config Config, db db.Db, task_channel chan Task, index int, back_cha } } +/** +* Handle remote runner + */ +func handleRemoteTask(handler *Handle, base BasePack, runner_id string, task Task) { + logger := log.NewWithOptions(os.Stdout, log.Options{ + ReportCaller: true, + ReportTimestamp: true, + TimeFormat: time.Kitchen, + Prefix: fmt.Sprintf("Runner pre %s", runner_id), + }) + defer func() { + if r := recover(); r != nil { + logger.Error("Runner failed to setup for runner", "due to", r, "stack", string(debug.Stack())) + // TODO maybe create better failed task + task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to setup task for runner") + } + }() + + err := task.UpdateStatus(base, TASK_PICKED_UP, "Failed to setup task for runner") + if err != nil { + logger.Error("Failed to mark task as PICK UP") + return + } + + mutex := handler.DataMap["runners_mutex"].(*sync.Mutex) + mutex.Lock() + defer mutex.Unlock() + + switch task.TaskType { + case int(TASK_TYPE_TRAINING): + if err := PrepareTraining(handler, base, task, runner_id); err != nil { + logger.Error("Failed to prepare for training", "err", err) + } + 
default: + logger.Error("Not sure what to do panicing", "taskType", task.TaskType) + panic("not sure what to do") + } +} + /** * Tells the orcchestator to look at the task list from time to time */ @@ -125,7 +165,7 @@ func attentionSeeker(config Config, back_channel chan int) { /** * Manages what worker should to Work */ -func RunnerOrchestrator(db db.Db, config Config) { +func RunnerOrchestrator(db db.Db, config Config, handler *Handle) { logger := log.NewWithOptions(os.Stdout, log.Options{ ReportCaller: true, ReportTimestamp: true, @@ -133,6 +173,16 @@ func RunnerOrchestrator(db db.Db, config Config) { Prefix: "Runner Orchestrator Logger", }) + // Setup vars + handler.DataMap["runners"] = map[string]interface{}{} + handler.DataMap["runners_mutex"] = &sync.Mutex{} + + base := BasePackStruct{ + Db: db, + Logger: logger, + Host: config.Hostname, + } + gpu_workers := config.GpuWorker.NumberOfWorkers logger.Info("Starting runners") @@ -149,7 +199,7 @@ func RunnerOrchestrator(db db.Db, config Config) { close(task_runners[x]) } close(back_channel) - go RunnerOrchestrator(db, config) + go RunnerOrchestrator(db, config, handler) } }() @@ -198,19 +248,45 @@ func RunnerOrchestrator(db db.Db, config Config) { } if task_to_dispatch != nil { - for i := 0; i < len(task_runners_used); i += 1 { - if !task_runners_used[i] { - task_runners[i] <- *task_to_dispatch - task_runners_used[i] = true + + // Only let CPU tasks be done by the local users + if task_to_dispatch.TaskType == int(TASK_TYPE_DELETE_USER) { + for i := 0; i < len(task_runners_used); i += 1 { + if !task_runners_used[i] { + task_runners[i] <- *task_to_dispatch + task_runners_used[i] = true + task_to_dispatch = nil + break + } + } + continue + } + + mutex := handler.DataMap["runners_mutex"].(*sync.Mutex) + mutex.Lock() + remote_runners := handler.DataMap["runners"].(map[string]interface{}) + + for k, v := range remote_runners { + runner_data := v.(map[string]interface{}) + runner_info := 
runner_data["runner_info"].(*Runner) + + if runner_data["task"] != nil { + continue + } + + if runner_info.UserId == task_to_dispatch.UserId { + go handleRemoteTask(handler, base, k, *task_to_dispatch) task_to_dispatch = nil break } } + + mutex.Unlock() } } } -func StartRunners(db db.Db, config Config) { - go RunnerOrchestrator(db, config) +func StartRunners(db db.Db, config Config, handler *Handle) { + go RunnerOrchestrator(db, config, handler) } diff --git a/logic/tasks/utils/runner.go b/logic/tasks/utils/runner.go new file mode 100644 index 0000000..9521644 --- /dev/null +++ b/logic/tasks/utils/runner.go @@ -0,0 +1,29 @@ +package tasks_utils + +import ( + "time" + + "git.andr3h3nriqu3s.com/andr3/fyp/logic/db" + dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" +) + +type RunnerType int64 + +const ( + RUNNER_TYPE_GPU RunnerType = iota + 1 +) + +type Runner struct { + Id string `json:"id" db:"ru.id"` + UserId string `json:"user_id" db:"ru.user_id"` + Token string `json:"token" db:"ru.token"` + Type RunnerType `json:"type" db:"ru.type"` + CreateOn time.Time `json:"createOn" db:"ru.created_on"` +} + +func GetRunner(db db.Db, id string) (ru *Runner, err error) { + var runner Runner + err = dbtypes.GetDBOnce(db, &runner, "remote_runner as ru where ru.id=$1", id) + ru = &runner + return +} diff --git a/logic/utils/handler.go b/logic/utils/handler.go index 4b35f81..7ec981b 100644 --- a/logic/utils/handler.go +++ b/logic/utils/handler.go @@ -374,7 +374,7 @@ func (c Context) JsonBadRequest(dat any) *Error { c.SetReportCaller(true) c.Logger.Warn("Request failed with a bad request", "dat", dat) c.SetReportCaller(false) - return c.ErrorCode(nil, 404, dat) + return c.SendJSONStatus(http.StatusBadRequest, dat) } func (c Context) JsonErrorBadRequest(err error, dat any) *Error { diff --git a/main.go b/main.go index 2433118..2ba1973 100644 --- a/main.go +++ b/main.go @@ -36,11 +36,11 @@ func main() { log.Info("Config loaded!", "config", config) config.GenerateToken(db) - 
StartRunners(db, config) - //TODO check if file structure exists to save data handle := NewHandler(db, config) + StartRunners(db, config, handle) + config.Cleanup(db) // TODO Handle this in other way diff --git a/nginx.dev.conf b/nginx.dev.conf index 242a37b..96e6229 100644 --- a/nginx.dev.conf +++ b/nginx.dev.conf @@ -13,7 +13,7 @@ http { server { listen 8000; - client_max_body_size 1G; + client_max_body_size 5G; location / { proxy_http_version 1.1; diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..e0db44b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +# tensorflow[and-cuda] == 2.15.1 +tensorflow[and-cuda] == 2.9.1 +pandas +# Make sure to install the nvidia pyindex first +# nvidia-pyindex +nvidia-cudnn diff --git a/run.sh b/run.sh new file mode 100755 index 0000000..c887593 --- /dev/null +++ b/run.sh @@ -0,0 +1,2 @@ +#!/bin/bash +podman run --rm --network host --gpus all -ti -v $(pwd):/app -e "TERM=xterm-256color" fyp-server bash diff --git a/runner/.gitignore b/runner/.gitignore new file mode 100644 index 0000000..2f7896d --- /dev/null +++ b/runner/.gitignore @@ -0,0 +1 @@ +target/ diff --git a/runner/Cargo.lock b/runner/Cargo.lock new file mode 100644 index 0000000..4305286 --- /dev/null +++ b/runner/Cargo.lock @@ -0,0 +1,1935 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 3 + +[[package]] +name = "addr2line" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + +[[package]] +name = "anyhow" +version = "1.0.82" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519" + +[[package]] +name = "autocfg" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" + +[[package]] +name = "backtrace" +version = "0.3.71" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + 
+[[package]] +name = "bitflags" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "bytes" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" + +[[package]] +name = "bzip2" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" +dependencies = [ + "bzip2-sys", + "libc", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.11+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + +[[package]] +name = "cc" +version = "1.0.96" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "065a29261d53ba54260972629f9ca6bffa69bac13cd1fed61420f7fa68b9f8bd" +dependencies = [ + "jobserver", + "libc", + "once_cell", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] 
+name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + +[[package]] +name = "constant_time_eq" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" + +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + +[[package]] +name = "cpufeatures" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +dependencies = [ + "libc", +] + +[[package]] +name = "crc32fast" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + 
"generic-array", + "typenum", +] + +[[package]] +name = "deranged" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +dependencies = [ + "powerfmt", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", + "subtle", +] + +[[package]] +name = "encoding_rs" +version = "0.8.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + +[[package]] +name = "errno" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "fastrand" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" + +[[package]] +name = "flate2" +version = "1.0.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures-channel" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +dependencies = [ + "futures-core", +] + +[[package]] +name = "futures-core" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" + +[[package]] +name = "futures-sink" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" + +[[package]] +name = "futures-task" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" + +[[package]] +name = "futures-util" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +dependencies = [ + "futures-core", + "futures-task", + "pin-project-lite", + "pin-utils", +] + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] 
+name = "getrandom" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "gimli" +version = "0.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" + +[[package]] +name = "h2" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "816ec7294445779408f36fe57bc5b7fc1cf59664059096c65f905c1c61f58069" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "half" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +dependencies = [ + "cfg-if", + "crunchy", +] + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + +[[package]] +name = "http" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0475f8b2ac86659c21b64320d5d653f9efe42acd2a4e560073ec61a155a34f1d" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" + +[[package]] +name = "hyper" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe575dd17d0862a9a33781c8c4696a55c320909004a67a00fb286ba8b1bc496d" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "h2", + "http", + "http-body", + "httparse", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http", + "http-body", + "hyper", + "pin-project-lite", + "socket2", + "tokio", + "tower", + "tower-service", + "tracing", +] + +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" 
+dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + +[[package]] +name = "indexmap" +version = "2.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +dependencies = [ + "equivalent", + "hashbrown", +] + +[[package]] +name = "inout" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" +dependencies = [ + "generic-array", +] + +[[package]] +name = "ipnet" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" + +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + +[[package]] +name = "jobserver" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2b099aaa34a9751c5bf0878add70444e1ed2dd73f347be99003d4577277de6e" +dependencies = [ + "libc", +] + +[[package]] +name = "js-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + +[[package]] +name = "libc" +version = "0.2.154" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae743338b92ff9146ce83992f766a31066a91a8c84a45e0e9f21e7cf6de6d346" + +[[package]] +name = "linux-raw-sys" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" + +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" + +[[package]] +name = "matrixmultiply" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "memchr" +version = "2.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" + +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "miniz_oxide" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" +dependencies = [ + "adler", +] + +[[package]] +name = "mio" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.48.0", +] + +[[package]] +name = "native-tls" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e" +dependencies = [ + "lazy_static", + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + 
"security-framework-sys", + "tempfile", +] + +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", +] + +[[package]] +name = "num-complex" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "object" +version = "0.32.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" +dependencies = [ + "memchr", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "openssl" +version = "0.10.64" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" +dependencies = [ + "bitflags 2.5.0", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + +[[package]] +name = "openssl-sys" +version = "0.9.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c597637d56fbc83893a35eb0dd04b2b8e7a50c91e64e9493e398b5df4fb45fa2" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "parking_lot" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e4af0ca4f6caed20e900d564c242b8e5d4903fdacf31d3daf527b66fe6f42fb" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets 0.52.5", +] + +[[package]] +name = "password-hash" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700" +dependencies = [ + "base64ct", + "rand_core", + "subtle", +] + +[[package]] +name = "pbkdf2" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917" +dependencies = [ + "digest", + "hmac", + "password-hash", + "sha2", +] + +[[package]] +name = "percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + +[[package]] +name = "pin-project" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "pkg-config" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" + +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "proc-macro2" +version = "1.0.81" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum 
= "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "redox_syscall" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e" +dependencies = [ + "bitflags 2.5.0", +] + +[[package]] +name = "reqwest" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "566cafdd92868e0939d3fb961bd0dc25fcfaaed179291093b3d43e6b3150ea10" +dependencies = [ + "base64", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-tls", + "hyper-util", + "ipnet", + "js-sys", + "log", + "mime", + "native-tls", + "once_cell", + 
"percent-encoding", + "pin-project-lite", + "rustls-pemfile", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "system-configuration", + "tokio", + "tokio-native-tls", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "winreg", +] + +[[package]] +name = "ring" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +dependencies = [ + "cc", + "cfg-if", + "getrandom", + "libc", + "spin", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "runner" +version = "0.1.0" +dependencies = [ + "anyhow", + "reqwest", + "serde", + "serde_json", + "serde_repr", + "tch", + "tokio", + "toml", +] + +[[package]] +name = "rustc-demangle" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" + +[[package]] +name = "rustix" +version = "0.38.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" +dependencies = [ + "bitflags 2.5.0", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustls" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" +dependencies = [ + "log", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pemfile" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" +dependencies = [ + "base64", + "rustls-pki-types", +] + +[[package]] +name = "rustls-pki-types" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"beb461507cee2c2ff151784c52762cf4d9ff6a61f3e80968600ed24fa837fa54" + +[[package]] +name = "rustls-webpki" +version = "0.102.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3bce581c0dd41bce533ce695a1437fa16a7ab5ac3ccfa99fe1a620a7885eabf" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + +[[package]] +name = "ryu" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" + +[[package]] +name = "safetensors" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d93279b86b3de76f820a8854dd06cbc33cfa57a417b19c47f6a25280112fb1df" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "schannel" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "security-framework" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "770452e37cad93e0a50d5abc3990d2bc351c36d0328f86cefec2f2fb206eaef6" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f3cc463c0ef97e11c3461a9d3787412d30e8e7eb907c79180c4a57bf7c04ef" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "serde" +version = "1.0.200" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ddc6f9cc94d67c0e21aaf7eda3a010fd3af78ebf6e096aa6e2e13c79749cce4f" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.200" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "856f046b9400cee3c8c94ed572ecdb752444c24528c035cd35882aad6f492bcb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.116" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "serde_repr" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c64451ba24fc7a6a2d60fc75dd9c83c90903b19028d4eff35e88fc1e86564e9" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_spanned" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb3622f419d1296904700073ea6cc23ad690adbd66f13ea683df73298736f0c1" +dependencies = [ + "serde", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + +[[package]] +name = "smallvec" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + +[[package]] +name = "socket2" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + +[[package]] +name = "subtle" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" + +[[package]] +name = "syn" +version = "2.0.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys", 
+] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "tch" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61fd89a98303b22acd6d4969b4c8940f7a30ba79af32b744a2028375d156e95a" +dependencies = [ + "half", + "lazy_static", + "libc", + "ndarray", + "rand", + "safetensors", + "thiserror", + "torch-sys", + "zip", +] + +[[package]] +name = "tempfile" +version = "3.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +dependencies = [ + "cfg-if", + "fastrand", + "rustix", + "windows-sys 0.52.0", +] + +[[package]] +name = "thiserror" +version = "1.0.59" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0126ad08bff79f29fc3ae6a55cc72352056dfff61e3ff8bb7129476d44b23aa" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.59" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1cd413b5d558b4c5bf3680e324a6fa5014e7b7c067a51e69dbdf47eb7148b66" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "time" +version = "0.3.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" +dependencies = [ + "deranged", + "num-conv", + "powerfmt", + "serde", + "time-core", +] + +[[package]] +name = "time-core" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" + +[[package]] +name = "tinyvec" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "tokio" +version = "1.37.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787" +dependencies = [ + "backtrace", + "bytes", + "libc", + "mio", + "num_cpus", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys 0.48.0", +] + +[[package]] +name = "tokio-macros" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", + "tracing", +] + +[[package]] +name = "toml" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9dd1545e8208b4a5af1aa9bbd0b4cf7e9ea08fabc5d0a5c67fcaafa17433aa3" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3328d4f68a705b2a4498da1d580585d39a6510f98318a2cec3018a7ec61ddef" +dependencies = [ + "indexmap", + "serde", + "serde_spanned", + "toml_datetime", + "winnow", +] + +[[package]] +name = "torch-sys" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5997681f7f3700fa475f541fcda44c8959ea42a724194316fe7297cb96ebb08" +dependencies = [ + "anyhow", + "cc", + "libc", + "serde", + "serde_json", + "ureq", + "zip", +] + +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", + "pin-project", + "pin-project-lite", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" + +[[package]] +name = "tower-service" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" + +[[package]] +name = "tracing" +version = "0.1.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +dependencies = [ + "log", + "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +dependencies = [ + "once_cell", +] + +[[package]] +name = "try-lock" +version 
= "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + +[[package]] +name = "unicode-bidi" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "unicode-normalization" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "ureq" +version = "2.9.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d11a831e3c0b56e438a28308e7c810799e3c118417f342d30ecec080105395cd" +dependencies = [ + "base64", + "flate2", + "log", + "once_cell", + "rustls", + "rustls-pki-types", + "rustls-webpki", + "serde", + "serde_json", + "url", + "webpki-roots", +] + +[[package]] +name = "url" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76bc14366121efc8dbb487ab05bcc9d346b3b5ec0eaa76e46594cabbe51762c0" +dependencies = [ + "cfg-if", + "js-sys", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.92" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" + +[[package]] +name = "web-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "webpki-roots" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.5", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +dependencies = [ + "windows_aarch64_gnullvm 0.52.5", + "windows_aarch64_msvc 0.52.5", + "windows_i686_gnu 0.52.5", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.5", + "windows_x86_64_gnu 0.52.5", + "windows_x86_64_gnullvm 0.52.5", + "windows_x86_64_msvc 0.52.5", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" + +[[package]] +name = "winnow" +version = "0.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14b9415ee827af173ebb3f15f9083df5a122eb93572ec28741fb153356ea2578" +dependencies = [ + "memchr", +] + +[[package]] +name = "winreg" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a277a57398d4bfa075df44f501a17cfdf8542d224f0d36095a2adc7aee4ef0a5" +dependencies = [ + "cfg-if", + "windows-sys 0.48.0", +] + +[[package]] +name = "zeroize" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" + +[[package]] +name = "zip" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" +dependencies = [ + "aes", + "byteorder", + "bzip2", + "constant_time_eq", + "crc32fast", + "crossbeam-utils", + "flate2", + "hmac", + "pbkdf2", + "sha1", + "time", + "zstd", +] + +[[package]] +name = "zstd" +version = "0.11.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "5.0.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db" +dependencies = [ + "libc", + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.10+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c253a4914af5bafc8fa8c86ee400827e83cf6ec01195ec1f1ed8441bf00d65aa" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/runner/Cargo.toml b/runner/Cargo.toml new file mode 100644 index 0000000..459ed1e --- /dev/null +++ b/runner/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "runner" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +anyhow = "1.0.82" +serde = { version = "1.0.200", features = ["derive"] } +toml = "0.8.12" +reqwest = { version = "0.12", features = ["json"] } +tokio = { version = "1", features = ["full"] } +serde_json = "1.0.116" 
+serde_repr = "0.1" +tch = { version = "0.16.0", features = ["download-libtorch"] } diff --git a/runner/Dockerfile b/runner/Dockerfile new file mode 100644 index 0000000..5685a0c --- /dev/null +++ b/runner/Dockerfile @@ -0,0 +1,12 @@ +FROM docker.io/nvidia/cuda:11.7.1-devel-ubuntu22.04 + +RUN apt-get update +RUN apt-get install -y curl + +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y +ENV PATH="$PATH:/root/.cargo/bin" +RUN rustup toolchain install stable + +RUN apt-get install -y pkg-config libssl-dev + +WORKDIR /app diff --git a/runner/config.toml b/runner/config.toml new file mode 100644 index 0000000..659a748 --- /dev/null +++ b/runner/config.toml @@ -0,0 +1,3 @@ +hostname = "https://testing.andr3h3nriqu3s.com/api" +token = "d2bc41e8293937bcd9397870c98f97acc9603f742924b518e193cd1013e45d57897aa302b364001c72b458afcfb34239dfaf38a66b318e5cbc973eea" +data_path = "/home/andr3/Documents/my-repos/fyp" diff --git a/runner/data.toml b/runner/data.toml new file mode 100644 index 0000000..a52e09b --- /dev/null +++ b/runner/data.toml @@ -0,0 +1 @@ +id = "a7cec9e9-1d05-4633-8bc5-6faabe4fd5a3" diff --git a/runner/run.sh b/runner/run.sh new file mode 100755 index 0000000..4b5a346 --- /dev/null +++ b/runner/run.sh @@ -0,0 +1,2 @@ +#!/bin/bash +podman run --rm --network host --gpus all -ti -v $(pwd):/app -e "TERM=xterm-256color" fyp-runner bash diff --git a/runner/src/dataloader.rs b/runner/src/dataloader.rs new file mode 100644 index 0000000..0640af8 --- /dev/null +++ b/runner/src/dataloader.rs @@ -0,0 +1,106 @@ +use crate::{model::DataPoint, settings::ConfigFile}; +use std::{path::Path, sync::Arc}; +use tch::Tensor; + +pub struct DataLoader { + pub batch_size: i64, + pub len: usize, + pub inputs: Vec, + pub labels: Vec, + pub pos: usize, +} + +impl DataLoader { + pub fn new( + config: Arc, + data: Vec, + classes_len: i64, + batch_size: i64, + ) -> DataLoader { + let len: f64 = (data.len() as f64) / (batch_size as f64); + let min_len: i64 = 
len.floor() as i64; + let max_len: i64 = len.ceil() as i64; + + let base_path = Path::new(&config.data_path); + + let mut inputs: Vec = Vec::new(); + let mut all_labels: Vec = Vec::new(); + + for batch in 0..min_len { + let mut batch_acc: Vec = Vec::new(); + let mut labels: Vec = Vec::new(); + for image in 0..batch_size { + let i: usize = (batch * batch_size + image).try_into().unwrap(); + let item = &data[i]; + batch_acc.push( + tch::vision::image::load(base_path.join(&item.path)) + .ok() + .unwrap(), + ); + + if item.class >= 0 { + let t = tch::Tensor::from_slice(&[item.class]) + .onehot(classes_len.try_into().unwrap()); + labels.push(t); + } else { + labels.push(tch::Tensor::zeros( + (classes_len), + (tch::Kind::Float, tch::Device::Cpu), + )) + } + } + inputs.push(tch::Tensor::cat(&batch_acc[0..], 0)); + all_labels.push(tch::Tensor::cat(&labels[0..], 0)); + } + + if min_len != max_len { + let mut batch_acc: Vec = Vec::new(); + let mut labels: Vec = Vec::new(); + for image in 0..(data.len() - (batch_size * min_len) as usize) { + let i: usize = (min_len * batch_size + (image as i64)) as usize; + let item = &data[i]; + batch_acc.push( + tch::vision::image::load(base_path.join(&item.path)) + .ok() + .unwrap(), + ); + + if item.class >= 0 { + let t = tch::Tensor::from_slice(&[item.class]).onehot(classes_len); + labels.push(t); + } else { + labels.push(tch::Tensor::zeros( + classes_len, + (tch::Kind::Float, tch::Device::Cpu), + )) + } + } + inputs.push(tch::Tensor::cat(&batch_acc[0..], 0)); + all_labels.push(tch::Tensor::cat(&labels[0..], 0)); + } + + return DataLoader { + batch_size, + inputs, + labels: all_labels, + len: max_len as usize, + pos: 0, + }; + } + + pub fn restart(self: &mut DataLoader) { + self.pos = 0; + } + + pub fn next(self: &mut DataLoader) -> Option<(Tensor, Tensor)> { + if self.pos >= self.len { + return None; + } + let input = self.inputs[self.pos].empty_like(); + self.inputs[self.pos] = self.inputs[self.pos].clone(&input); + let label = 
self.labels[self.pos].empty_like(); + self.labels[self.pos] = self.labels[self.pos].clone(&label); + + return Some((input, label)); + } +} diff --git a/runner/src/main.rs b/runner/src/main.rs new file mode 100644 index 0000000..b3ae49b --- /dev/null +++ b/runner/src/main.rs @@ -0,0 +1,206 @@ +mod dataloader; +mod model; +mod settings; +mod tasks; +mod training; +mod types; + +use crate::settings::*; +use crate::tasks::{fail_task, Task, TaskType}; +use crate::training::handle_train; +use anyhow::{bail, Result}; +use reqwest::StatusCode; +use serde_json::json; +use std::{fs, process::exit, sync::Arc, time::Duration}; + +enum ResultAlive { + Ok, + Error, + NotInit, +} + +async fn send_keep_alive_message( + config: Arc, + runner_data: Arc, +) -> ResultAlive { + let client = reqwest::Client::new(); + + let to_send = json!({ + "id": runner_data.id, + }); + + let resp = client + .post(format!("{}/tasks/runner/beat", config.hostname)) + .header("token", &config.token) + .body(to_send.to_string()) + .send() + .await; + + if resp.is_err() { + return ResultAlive::Error; + } + + let resp = resp.ok(); + + if resp.is_none() { + return ResultAlive::Error; + } + + let resp = resp.unwrap(); + + // TODO see if the message is related to not being inited + if resp.status() != 200 { + println!("Could not connect with the status"); + return ResultAlive::Error; + } + + ResultAlive::Ok +} + +async fn keep_alive(config: Arc, runner_data: Arc) -> Result<()> { + let mut failed = 0; + loop { + match send_keep_alive_message(config.clone(), runner_data.clone()).await { + ResultAlive::Error => failed += 1, + ResultAlive::NotInit => { + println!("Runner not inited! Restarting!"); + exit(1) + } + ResultAlive::Ok => failed = 0, + } + + // TODO move to config + if failed > 20 { + println!("Failed to connect to API! 
More than 20 times in a row stoping"); + exit(1) + } + + tokio::time::sleep(Duration::from_secs(1)).await; + } +} + +async fn handle_task( + task: Task, + config: Arc, + runner_data: Arc, +) -> Result<()> { + let res = match task.task_type { + TaskType::Training => handle_train(&task, config.clone(), runner_data.clone()).await, + _ => { + println!("Do not know how to handle this task #{:?}", task); + bail!("Failed") + } + }; + + if res.is_err() { + println!("task failed #{:?}", res); + fail_task( + &task, + config, + runner_data, + "Do not know how to handle this kind of task", + ) + .await? + } + + Ok(()) +} + +#[tokio::main] +async fn main() -> Result<()> { + // Load config file + let config_data = fs::read_to_string("./config.toml")?; + let mut config: ConfigFile = toml::from_str(&config_data)?; + + let client = reqwest::Client::new(); + if config.config_path == None { + config.config_path = Some(String::from("./data.toml")) + } + + let runner_data: RunnerData = load_runner_data(&config).await?; + + let to_send = json!({ + "id": runner_data.id, + }); + + // Inform the server that the runner is available + let resp = client + .post(format!("{}/tasks/runner/init", config.hostname)) + .header("token", &config.token) + .body(to_send.to_string()) + .send() + .await?; + + if resp.status() != 200 { + println!( + "Could not connect with the api: status {} body {}", + resp.status(), + resp.text().await? 
+ ); + return Ok(()); + } + + let res = resp.json::().await?; + if res != "Ok" { + print!("Unexpected problem: {}", res); + return Ok(()); + } + + let config = Arc::new(config); + let runner_data = Arc::new(runner_data); + + let config_alive = config.clone(); + let runner_data_alive = runner_data.clone(); + std::thread::spawn(move || keep_alive(config_alive.clone(), runner_data_alive.clone())); + + println!("Started main loop"); + loop { + //TODO move time to config + tokio::time::sleep(Duration::from_secs(1)).await; + + let to_send = json!({ "id": runner_data.id }); + + let resp = client + .post(format!("{}/tasks/runner/active", config.hostname)) + .header("token", &config.token) + .body(to_send.to_string()) + .send() + .await; + + if resp.is_err() || resp.as_ref().ok().is_none() { + println!("Failed to get info from server {:?}", resp); + continue; + } + + let resp = resp?; + + match resp.status() { + // No active task + StatusCode::NOT_FOUND => (), + StatusCode::OK => { + println!("Found task!"); + + let task: Result = resp.json().await; + if task.is_err() || task.as_ref().ok().is_none() { + println!("Failed to resolve the json {:?}", task); + continue; + } + + let task = task?; + + let res = handle_task(task, config.clone(), runner_data.clone()).await; + + if res.is_err() || res.as_ref().ok().is_none() { + println!("Failed to run the task"); + } + + _ = res; + () + } + _ => { + println!("Unexpected error #{:?}", resp); + exit(1) + } + } + } +} diff --git a/runner/src/model/mod.rs b/runner/src/model/mod.rs new file mode 100644 index 0000000..2b0536d --- /dev/null +++ b/runner/src/model/mod.rs @@ -0,0 +1,99 @@ +use anyhow::bail; +use serde::{Deserialize, Serialize}; +use serde_repr::{Deserialize_repr, Serialize_repr}; +use tch::{ + nn::{self, Module}, + Device, +}; + +#[derive(Debug)] +pub struct Model { + pub vs: nn::VarStore, + pub seq: nn::Sequential, + pub layers: Vec, +} + +#[derive(Debug, Clone, Copy, Serialize_repr, Deserialize_repr)] +#[repr(i8)] +pub 
enum LayerType { + Input = 1, + Dense = 2, + Flatten = 3, + SimpleBlock = 4, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Layer { + pub id: String, + pub definition_id: String, + pub layer_order: String, + pub layer_type: LayerType, + pub shape: String, + pub exp_type: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct DataPoint { + pub class: i64, + pub path: String, +} + +pub fn build_model(layers: Vec, last_linear_size: i64, add_sigmoid: bool) -> Model { + let vs = nn::VarStore::new(Device::Cpu); + + let mut seq = nn::seq(); + + let mut last_linear_size = last_linear_size; + let mut last_linear_conv: Vec = Vec::new(); + + for layer in layers.iter() { + match layer.layer_type { + LayerType::Input => { + last_linear_conv = serde_json::from_str(&layer.shape).ok().unwrap(); + println!("Layer: Input, In: {:?}", last_linear_conv); + } + LayerType::Dense => { + let shape: Vec = serde_json::from_str(&layer.shape).ok().unwrap(); + println!("Layer: Dense, In: {}, Out: {}", last_linear_size, shape[0]); + seq = seq + .add(nn::linear( + &vs.root(), + last_linear_size, + shape[0], + Default::default(), + )) + .add_fn(|xs| xs.relu()); + last_linear_size = shape[0]; + } + LayerType::Flatten => { + seq = seq.add_fn(|xs| xs.flatten(1, -1)); + last_linear_size = 1; + for i in &last_linear_conv { + last_linear_size *= i; + } + println!( + "Layer: flatten, In: {:?}, Out: {}", + last_linear_conv, last_linear_size + ) + } + LayerType::SimpleBlock => { + panic!("DO not create Simple blocks yet"); + let new_last_linear_conv = + vec![128, last_linear_conv[1] / 2, last_linear_conv[2] / 2]; + println!( + "Layer: block, In: {:?}, Put: {:?}", + last_linear_conv, new_last_linear_conv, + ); + //TODO + //m_layers = append(m_layers, NewSimpleBlock(vs, lastLinearConv[0])) + last_linear_conv = new_last_linear_conv; + } + } + } + + if add_sigmoid { + seq = seq.add_fn(|xs| xs.sigmoid()); + } + + return Model { vs, layers, seq }; +} diff --git a/runner/src/settings.rs 
b/runner/src/settings.rs new file mode 100644 index 0000000..a9c3603 --- /dev/null +++ b/runner/src/settings.rs @@ -0,0 +1,57 @@ +use anyhow::{bail, Result}; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use std::{fs, path}; + +#[derive(Deserialize)] +pub struct ConfigFile { + // Hostname to connect with the api + pub hostname: String, + // Token used in the api to authenticate + pub token: String, + // Path to where to store some generated configuration values + // defaults to ./data.toml + pub config_path: Option, + // Data Path + // Path to where the data is mounted + pub data_path: String, +} + +#[derive(Deserialize, Serialize)] +pub struct RunnerData { + pub id: String, +} + +pub async fn load_runner_data(config: &ConfigFile) -> Result { + let data_path = config.config_path.as_ref().unwrap(); + let data_path = path::Path::new(&*data_path); + + if data_path.exists() { + let runner_data = fs::read_to_string(data_path)?; + Ok(toml::from_str(&runner_data)?) + } else { + let client = reqwest::Client::new(); + let to_send = json!({ + "token": config.token, + "type": 1, + }); + + let register_resp = client + .post(format!("{}/tasks/runner/register", config.hostname)) + .header("token", &config.token) + .body(to_send.to_string()) + .send() + .await?; + + if register_resp.status() != 200 { + bail!(format!("Could not create runner {:#?}", register_resp)); + } + + let runner_data: RunnerData = register_resp.json().await?; + + fs::write(data_path, toml::to_string(&runner_data)?) 
+ .expect("Faield to save data for runner"); + + Ok(runner_data) + } +} diff --git a/runner/src/tasks.rs b/runner/src/tasks.rs new file mode 100644 index 0000000..60d61dc --- /dev/null +++ b/runner/src/tasks.rs @@ -0,0 +1,90 @@ +use std::sync::Arc; + +use anyhow::{bail, Result}; +use serde::Deserialize; +use serde_json::json; +use serde_repr::Deserialize_repr; + +use crate::{ConfigFile, RunnerData}; + +#[derive(Clone, Copy, Deserialize_repr, Debug)] +#[repr(i8)] +pub enum TaskStatus { + FailedRunning = -2, + FailedCreation = -1, + Preparing = 0, + Todo = 1, + PickedUp = 2, + Running = 3, + Done = 4, +} + +#[derive(Clone, Copy, Deserialize_repr, Debug)] +#[repr(i8)] +pub enum TaskType { + Classification = 1, + Training = 2, + Retraining = 3, + DeleteUser = 4, +} + +#[derive(Deserialize, Debug)] +pub struct Task { + pub id: String, + pub user_id: String, + pub model_id: String, + pub status: TaskStatus, + pub status_message: String, + pub user_confirmed: i8, + pub compacted: i8, + #[serde(alias = "type")] + pub task_type: TaskType, + pub extra_task_info: String, + pub result: String, + pub created: String, +} + +pub async fn fail_task( + task: &Task, + config: Arc, + runner_data: Arc, + reason: &str, +) -> Result<()> { + println!("Marking Task as faield"); + + let client = reqwest::Client::new(); + + let to_send = json!({ + "id": runner_data.id, + "taskId": task.id, + "reason": reason + }); + + let resp = client + .post(format!("{}/tasks/runner/fail", config.hostname)) + .header("token", &config.token) + .body(to_send.to_string()) + .send() + .await?; + + if resp.status() != 200 { + println!("Failed to update status of task"); + bail!("Failed to update status of task"); + } + + Ok(()) +} + +impl Task { + pub async fn fail( + self: &mut Task, + config: Arc, + runner_data: Arc, + reason: &str, + ) -> Result<()> { + fail_task(self, config, runner_data, reason).await?; + self.status = TaskStatus::FailedRunning; + self.status_message = reason.to_string(); + Ok(()) + } +} 
diff --git a/runner/src/training.rs b/runner/src/training.rs new file mode 100644 index 0000000..848a7cf --- /dev/null +++ b/runner/src/training.rs @@ -0,0 +1,544 @@ +use crate::{ + dataloader::DataLoader, + model::{self, build_model}, + settings::{ConfigFile, RunnerData}, + tasks::{fail_task, Task}, + types::{DataPointRequest, Definition, ModelClass}, +}; +use std::sync::Arc; + +use anyhow::Result; +use serde_json::json; +use tch::{ + nn::{self, Module, OptimizerConfig}, + Tensor, +}; + +pub async fn handle_train( + task: &Task, + config: Arc, + runner_data: Arc, +) -> Result<()> { + let client = reqwest::Client::new(); + println!("Preparing to train a model"); + + let to_send = json!({ + "id": runner_data.id, + "taskId": task.id, + }); + + let mut defs: Vec = client + .post(format!("{}/tasks/runner/train/defs", config.hostname)) + .header("token", &config.token) + .body(to_send.to_string()) + .send() + .await? + .json() + .await?; + + let classes: Vec = client + .post(format!("{}/tasks/runner/train/classes", config.hostname)) + .header("token", &config.token) + .body(to_send.to_string()) + .send() + .await? + .json() + .await?; + + let data: DataPointRequest = client + .post(format!("{}/tasks/runner/train/datapoints", config.hostname)) + .header("token", &config.token) + .body(to_send.to_string()) + .send() + .await? 
+ .json() + .await?; + + let mut data_loader = DataLoader::new(config.clone(), data.testing, classes.len() as i64, 64); + + // TODO make this a vec + let mut model: Option = None; + + loop { + let config = config.clone(); + let runner_data = runner_data.clone(); + let mut to_remove: Vec = Vec::new(); + + let mut def_iter = defs.iter_mut(); + + let mut i: usize = 0; + while let Some(def) = def_iter.next() { + def.updateStatus( + task, + config.clone(), + runner_data.clone(), + crate::types::DefinitionStatus::Training, + ) + .await?; + + let model_err = train_definition( + def, + &mut data_loader, + model, + config.clone(), + runner_data.clone(), + &task, + ) + .await; + + if model_err.is_err() { + println!("Failed to create model {:?}", model_err); + model = None; + to_remove.push(i); + continue; + } + + model = model_err?; + + i += 1; + } + + defs = defs + .into_iter() + .enumerate() + .filter(|&(i, _)| to_remove.iter().any(|b| *b == i)) + .map(|(_, e)| e) + .collect(); + + break; + } + + fail_task(task, config, runner_data, "TODO").await?; + Ok(()) + + /* + for { + // Keep track of definitions that did not train fast enough + var toRemove ToRemoveList = []int{} + + for i, def := range definitions { + + accuracy, ml_model, err := trainDefinition(c, model, def, models[def.Id], classes) + if err != nil { + log.Error("Failed to train definition!Err:", "err", err) + def.UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING) + toRemove = append(toRemove, i) + continue + } + models[def.Id] = ml_model + + if accuracy >= float64(def.TargetAccuracy) { + log.Info("Found a definition that reaches target_accuracy!") + _, err = db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, DEFINITION_STATUS_TRANIED, def.Epoch, def.Id) + if err != nil { + log.Error("Failed to train definition!Err:\n", "err", err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return err + } + + _, err = db.Exec("update model_definition set status=$1 where id!=$2 
and model_id=$3 and status!=$4", DEFINITION_STATUS_CANCELD_TRAINING, def.Id, model.Id, DEFINITION_STATUS_FAILED_TRAINING) + if err != nil { + log.Error("Failed to train definition!Err:\n", "err", err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return err + } + + finished = true + break + } + + if def.Epoch > MAX_EPOCH { + fmt.Printf("Failed to train definition! Accuracy less %f < %d\n", accuracy, def.TargetAccuracy) + def.UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING) + toRemove = append(toRemove, i) + continue + } + + _, err = db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.Epoch, DEFINITION_STATUS_PAUSED_TRAINING, def.Id) + if err != nil { + log.Error("Failed to train definition!Err:\n", "err", err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return err + } + } + + if finished { + break + } + + sort.Sort(sort.Reverse(toRemove)) + + log.Info("Round done", "toRemove", toRemove) + + for _, n := range toRemove { + // Clean up unsed models + models[definitions[n].Id] = nil + definitions = remove(definitions, n) + } + + len_def := len(definitions) + + if len_def == 0 { + break + } + + if len_def == 1 { + continue + } + + sort.Sort(sort.Reverse(definitions)) + + acc := definitions[0].Accuracy - 20.0 + + log.Info("Training models, Highest acc", "acc", definitions[0].Accuracy, "mod_acc", acc) + + toRemove = []int{} + for i, def := range definitions { + if def.Accuracy < acc { + toRemove = append(toRemove, i) + } + } + + log.Info("Removing due to accuracy", "toRemove", toRemove) + + sort.Sort(sort.Reverse(toRemove)) + for _, n := range toRemove { + log.Warn("Removing definition not fast enough learning", "n", n) + definitions[n].UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING) + models[definitions[n].Id] = nil + definitions = remove(definitions, n) + } + } + + var def Definition + err = GetDBOnce(c, &def, "model_definition as md where md.model_id=$1 and md.status=$2 order by md.accuracy desc limit 
1;", model.Id, DEFINITION_STATUS_TRANIED) + if err != nil { + if err == NotFoundError { + log.Error("All definitions failed to train!") + } else { + log.Error("DB: failed to read definition", "err", err) + } + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return + } + + if err = def.UpdateStatus(c, DEFINITION_STATUS_READY); err != nil { + log.Error("Failed to update model definition", "err", err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return + } + + to_delete, err := db.Query("select id from model_definition where status != $1 and model_id=$2", DEFINITION_STATUS_READY, model.Id) + if err != nil { + log.Error("Failed to select model_definition to delete") + log.Error(err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return + } + defer to_delete.Close() + + for to_delete.Next() { + var id string + if err = to_delete.Scan(&id); err != nil { + log.Error("Failed to scan the id of a model_definition to delete", "err", err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return + } + os.RemoveAll(path.Join("savedData", model.Id, "defs", id)) + } + + // TODO Check if returning also works here + if _, err = db.Exec("delete from model_definition where status!=$1 and model_id=$2;", DEFINITION_STATUS_READY, model.Id); err != nil { + log.Error("Failed to delete model_definition") + log.Error(err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return + } + + ModelUpdateStatus(c, model.Id, READY) + + return + */ +} + +async fn train_definition( + def: &Definition, + data_loader: &mut DataLoader, + model: Option, + config: Arc, + runner_data: Arc, + task: &Task, +) -> Result> { + let client = reqwest::Client::new(); + println!("About to start training definition"); + + let mut accuracy = 0; + + let model = model.unwrap_or({ + let layers: Vec = client + .post(format!("{}/tasks/runner/train/def/layers", config.hostname)) + .header("token", &config.token) + .body( + json!({ + "id": runner_data.id, + "taskId": task.id, + "defId": def.id, + }) + 
.to_string(), + ) + .send() + .await? + .json() + .await?; + + build_model(layers, 0, true) + }); + + println!("here1!"); + + // TODO CUDA + // get device + // Move model to cuda + + let mut opt = nn::Adam::default().build(&model.vs, 1e-5)?; + + println!("here2!"); + + for epoch in 1..20 { + data_loader.restart(); + while let Some((inputs, labels)) = data_loader.next() { + let inputs = inputs.to_kind(tch::Kind::Float); + let labels = labels.to_kind(tch::Kind::Float); + println!("ins: {:?} labels: {:?}", inputs.size(), labels.size()); + let out = model.seq.forward(&inputs); + let weight: Option = None; + let loss = out.binary_cross_entropy(&labels, weight, tch::Reduction::Mean); + opt.backward_step(&loss); + println!("out: {:?}", out); + } + } + + return Ok(Some(model)); + /* + + opt, err := my_nn.DefaultAdamConfig().Build(model.Vs, 0.001) + if err != nil { + return + } + + for epoch := 0; epoch < EPOCH_PER_RUN; epoch++ { + var trainIter *torch.Iter2 + trainIter, err = ds.TrainIter(32) + if err != nil { + return + } + // trainIter.ToDevice(device) + + log.Info("epoch", "epoch", epoch) + + var trainLoss float64 = 0 + var trainCorrect float64 = 0 + ok := true + for ok { + var item torch.Iter2Item + var loss *torch.Tensor + item, ok = trainIter.Next() + if !ok { + continue + } + + data := item.Data + data, err = data.ToDevice(device, gotch.Float, false, true, false) + if err != nil { + return + } + + var size []int64 + size, err = data.Size() + if err != nil { + return + } + + var zeros *torch.Tensor + zeros, err = torch.Zeros(size, gotch.Float, device) + if err != nil { + return + } + + data, err = zeros.Add(data, true) + if err != nil { + return + } + + log.Info("\n\nhere 1, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad()) + + data, err = data.SetRequiresGrad(true, false) + if err != nil { + return + } + + log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad()) + + err = 
data.RetainGrad(false) + if err != nil { + return + } + + log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad()) + + pred := model.ForwardT(data, true) + pred, err = pred.SetRequiresGrad(true, true) + if err != nil { + return + } + + err = pred.RetainGrad(false) + if err != nil { + return + } + + label := item.Label + label, err = label.ToDevice(device, gotch.Float, false, true, false) + if err != nil { + return + } + label, err = label.SetRequiresGrad(true, true) + if err != nil { + return + } + err = label.RetainGrad(false) + if err != nil { + return + } + + // Calculate loss + loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 2, false) + if err != nil { + return + } + loss, err = loss.SetRequiresGrad(true, false) + if err != nil { + return + } + err = loss.RetainGrad(false) + if err != nil { + return + } + + err = opt.ZeroGrad() + if err != nil { + return + } + + err = loss.Backward() + if err != nil { + return + } + + log.Info("pred grad", "pred", pred.MustGrad(false).MustMax(false).Float64Values()) + log.Info("pred grad", "outs", label.MustGrad(false).MustMax(false).Float64Values()) + log.Info("pred grad", "data", data.MustGrad(false).MustMax(false).Float64Values(), "lol", data.MustRetainsGrad(false)) + + vars := model.Vs.Variables() + + for k, v := range vars { + log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(false).Float64Values(), "lol", v.MustRetainsGrad(false)) + } + + model.Debug() + + err = opt.Step() + if err != nil { + return + } + + trainLoss = loss.Float64Values()[0] + + // Calculate accuracy + / *var p_pred, p_labels *torch.Tensor + p_pred, err = pred.Argmax([]int64{1}, true, false) + if err != nil { + return + } + + p_labels, err = item.Label.Argmax([]int64{1}, true, false) + if err != nil { + return + } + + floats := p_pred.Float64Values() + floats_labels := p_labels.Float64Values() + + for i := range floats { + if floats[i] == 
floats_labels[i] { + trainCorrect += 1 + } + } * / + + panic("fornow") + } + + //v := []float64{} + + log.Info("model training epoch done loss", "loss", trainLoss, "correct", trainCorrect, "out", ds.TrainImagesSize, "accuracy", trainCorrect/float64(ds.TrainImagesSize)) + + / *correct := int64(0) + //torch.NoGrad(func() { + ok = true + testIter := ds.TestIter(64) + for ok { + var item torch.Iter2Item + item, ok = testIter.Next() + if !ok { + continue + } + + output := model.Forward(item.Data) + + var pred, labels *torch.Tensor + pred, err = output.Argmax([]int64{1}, true, false) + if err != nil { + return + } + + labels, err = item.Label.Argmax([]int64{1}, true, false) + if err != nil { + return + } + + floats := pred.Float64Values() + floats_labels := labels.Float64Values() + + for i := range floats { + if floats[i] == floats_labels[i] { + correct += 1 + } + } + } + + accuracy = float64(correct) / float64(ds.TestImagesSize) + + log.Info("Eval accuracy", "accuracy", accuracy) + + err = def.UpdateAfterEpoch(db, accuracy*100) + if err != nil { + return + }* / + //}) + } + + result_path := path.Join(getDir(), "savedData", m.Id, "defs", def.Id) + err = os.MkdirAll(result_path, os.ModePerm) + if err != nil { + return + } + + err = my_torch.SaveModel(model, path.Join(result_path, "model.dat")) + if err != nil { + return + } + + log.Info("Model finished training!", "accuracy", accuracy) + return + */ +} diff --git a/runner/src/types.rs b/runner/src/types.rs new file mode 100644 index 0000000..b5fd4a4 --- /dev/null +++ b/runner/src/types.rs @@ -0,0 +1,89 @@ +use crate::{model, tasks::Task, ConfigFile, RunnerData}; +use anyhow::{bail, Result}; +use serde::Deserialize; +use serde_json::json; +use serde_repr::{Deserialize_repr, Serialize_repr}; +use std::sync::Arc; + +#[derive(Clone, Copy, Deserialize_repr, Serialize_repr, Debug)] +#[repr(i8)] +pub enum DefinitionStatus { + CanceldTraining = -4, + FailedTraining = -3, + PreInit = 1, + Init = 2, + Training = 3, + PausedTraining 
= 6, + Tranied = 4, + Ready = 5, +} + +#[derive(Deserialize, Debug)] +pub struct Definition { + pub id: String, + pub model_id: String, + pub accuracy: f64, + pub target_accuracy: i64, + pub epoch: i64, + pub status: i64, + pub created: String, + pub epoch_progress: i64, +} + +impl Definition { + pub async fn updateStatus( + self: &mut Definition, + task: &Task, + config: Arc, + runner_data: Arc, + status: DefinitionStatus, + ) -> Result<()> { + println!("Marking Task as faield"); + + let client = reqwest::Client::new(); + + let to_send = json!({ + "id": runner_data.id, + "taskId": task.id, + "defId": self.id, + "status": status, + }); + + let resp = client + .post(format!("{}/tasks/runner/train/def/status", config.hostname)) + .header("token", &config.token) + .body(to_send.to_string()) + .send() + .await?; + + if resp.status() != 200 { + println!("Failed to update status of task"); + bail!("Failed to update status of task"); + } + + Ok(()) + } +} + +#[derive(Clone, Copy, Deserialize_repr, Debug)] +#[repr(i8)] +pub enum ModelClassStatus { + ToTrain = 1, + Training = 2, + Trained = 3, +} + +#[derive(Deserialize, Debug)] +pub struct ModelClass { + pub id: String, + pub model_id: String, + pub name: String, + pub class_order: i64, + pub status: ModelClassStatus, +} + +#[derive(Deserialize, Debug)] +pub struct DataPointRequest { + pub testing: Vec, + pub training: Vec, +} diff --git a/sql/tasks.sql b/sql/tasks.sql index 8248ade..e9aa445 100644 --- a/sql/tasks.sql +++ b/sql/tasks.sql @@ -38,3 +38,14 @@ create table if not exists tasks_dependencies ( main_id uuid references tasks (id) on delete cascade not null, dependent_id uuid references tasks (id) on delete cascade not null ); + +create table if not exists remote_runner ( + id uuid primary key default gen_random_uuid(), + user_id uuid references users (id) on delete cascade not null, + token text not null, + + -- 1: GPU + type integer, + + created_on timestamp default current_timestamp +); diff --git 
a/views/py/python_model_template.py b/views/py/python_model_template.py index 9542109..1935e98 100644 --- a/views/py/python_model_template.py +++ b/views/py/python_model_template.py @@ -82,7 +82,7 @@ def prepare_dataset(ds: tf.data.Dataset, size: int) -> tf.data.Dataset: def filterDataset(path): path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "") - + {{ if eq .Model.Format "png" }} path = tf.strings.regex_replace(path, ".png", "") {{ else if eq .Model.Format "jpeg" }} @@ -90,7 +90,7 @@ def filterDataset(path): {{ else }} ERROR {{ end }} - + return tf.reshape(table.lookup(tf.strings.as_string([path])), []) != -1 seed = random.randint(0, 100000000) @@ -135,9 +135,9 @@ def addBlock( model.add(layers.ReLU()) if top: if pooling_same: - model.add(pool_func(padding="same", strides=(1, 1))) + model.add(pool_func(pool_size=(2,2), padding="same", strides=(1, 1))) else: - model.add(pool_func()) + model.add(pool_func(pool_size=(2,2))) model.add(layers.BatchNormalization()) model.add(layers.LeakyReLU()) model.add(layers.Dropout(0.4)) @@ -172,7 +172,7 @@ model.compile( his = model.fit(dataset, validation_data= dataset_validation, epochs={{.EPOCH_PER_RUN}}, callbacks=[ NotifyServerCallback(), - tf.keras.callbacks.EarlyStopping("loss", mode="min", patience=5)], use_multiprocessing = True) + tf.keras.callbacks.EarlyStopping("loss", mode="min", patience=5)]) acc = his.history["accuracy"]