diff --git a/go.mod b/go.mod index 9ac9940..b697e3c 100644 --- a/go.mod +++ b/go.mod @@ -9,10 +9,11 @@ require ( github.com/google/uuid v1.6.0 github.com/lib/pq v1.10.9 golang.org/x/crypto v0.19.0 + github.com/BurntSushi/toml v1.3.2 + github.com/goccy/go-json v0.10.2 ) require ( - github.com/BurntSushi/toml v1.3.2 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/charmbracelet/lipgloss v0.9.1 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect @@ -20,7 +21,6 @@ require ( github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.19.0 // indirect - github.com/goccy/go-json v0.10.2 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgx v3.6.2+incompatible // indirect diff --git a/logic/db_types/definitions.go b/logic/db_types/definitions.go index d511d32..28392e0 100644 --- a/logic/db_types/definitions.go +++ b/logic/db_types/definitions.go @@ -87,9 +87,9 @@ func (d Definition) GetLayers(db db.Db, filter string, args ...any) (layer []*La return GetDbMultitple[Layer](db, "model_definition_layer as mdl where mdl.def_id=$1 "+filter, args...) } -func (d *Definition) UpdateAfterEpoch(db db.Db, accuracy float64) (err error) { +func (d *Definition) UpdateAfterEpoch(db db.Db, accuracy float64, epoch int) (err error) { d.Accuracy = accuracy - d.Epoch += 1 + d.Epoch += epoch _, err = db.Exec("update model_definition set epoch=$1, accuracy=$2 where id=$3", d.Epoch, d.Accuracy, d.Id) return } diff --git a/logic/tasks/runner.go b/logic/tasks/runner.go index 9f55e3b..bfce53f 100644 --- a/logic/tasks/runner.go +++ b/logic/tasks/runner.go @@ -1,6 +1,8 @@ package tasks import ( + "os" + "path" "sync" "time" @@ -383,4 +385,149 @@ func handleRemoteRunner(x *Handle) { Training: training_points, }) }) + + type RunnerTrainDefEpoch struct { + Id string `json:"id" validate:"required"` + TaskId string `json:"taskId" validate:"required"` + DefId string `json:"defId" validate:"required"` + Epoch int `json:"epoch" validate:"required"` + Accuracy float64 `json:"accuracy" validate:"required"` + } + PostAuthJson(x, "/tasks/runner/train/epoch", User_Normal, func(c *Context, dat *RunnerTrainDefEpoch) *Error { + _, error := verifyRunner(c, &JustId{Id: dat.Id}) + if error != nil { + return error + } + + task, error := verifyTask(x, c, &VerifyTask{ + Id: dat.Id, + TaskId: dat.TaskId, + }) + if error != nil { + return error + } + + if task.TaskType != int(TASK_TYPE_TRAINING) { + c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) + return c.JsonBadRequest("Task is not the right type go get the definitions") + } + + def, err := GetDefinition(c, dat.DefId) + if err != nil { + return c.E500M("Failed to get definition information", err) + } + + err = def.UpdateAfterEpoch(c, dat.Accuracy, dat.Epoch) + if err != nil { + return c.E500M("Failed to update model", err) + } + + return c.SendJSON("Ok") + }) + + PostAuthJson(x, "/task/runner/train/mark-failed", User_Normal, func(c *Context, dat *VerifyTask) *Error { + _, error := verifyRunner(c, &JustId{Id: dat.Id}) + if error != nil { + return error + } + + task, error := verifyTask(x, c, &VerifyTask{ + Id: dat.Id, + TaskId: dat.TaskId, + }) + if error != nil { + return error + } + + if task.TaskType != int(TASK_TYPE_TRAINING) { + c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) + return c.JsonBadRequest("Task is not the right type go get the definitions") + } + + _, err := c.Exec( + "update model_definition set status=$1 "+ + "where model_id=$2 and status in ($3, $4)", + MODEL_DEFINITION_STATUS_CANCELD_TRAINING, + task.ModelId, + MODEL_DEFINITION_STATUS_TRAINING, + MODEL_DEFINITION_STATUS_PAUSED_TRAINING, + ) + if err != nil { + return c.E500M("Failed to mark definition as failed", err) + } + + return c.SendJSON("Ok") + }) + + PostAuthJson(x, "/task/runner/train/done", User_Normal, func(c *Context, dat *VerifyTask) *Error { + _, error := verifyRunner(c, &JustId{Id: dat.Id}) + if error != nil { + return error + } + + task, error := verifyTask(x, c, dat) + if error != nil { + return error + } + + if task.TaskType != int(TASK_TYPE_TRAINING) { + c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) + return c.JsonBadRequest("Task is not the right type go get the definitions") + } + + model, err := GetBaseModel(c, *task.ModelId) + if err != nil { + c.Logger.Error("Failed to get model", "err", err) + return c.E500M("Failed to get mode", err) + } + + var def Definition + err = GetDBOnce(c, &def, "from model_definition as md where model_id=$1 and status=$2 order by accuracy desc limit 1;", task.ModelId, DEFINITION_STATUS_TRANIED) + if err == NotFoundError { + // TODO Make the Model status have a message + c.Logger.Error("All definitions failed to train!") + model.UpdateStatus(c, FAILED_TRAINING) + task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "All definition failed to train!") + return c.SendJSON("Ok") + } else if err != nil { + model.UpdateStatus(c, FAILED_TRAINING) + task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to get model definition") + return c.E500M("Failed to get model definition", err) + } + + if err = def.UpdateStatus(c, DEFINITION_STATUS_READY); err != nil { + model.UpdateStatus(c, FAILED_TRAINING) + task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to update model definition") + return c.E500M("Failed to update model definition", err) + } + + to_delete, err := c.Query("select id from model_definition where status != $1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id) + if err != nil { + model.UpdateStatus(c, FAILED_TRAINING) + task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions") + return c.E500M("Failed to delete unsed definitions", err) + } + defer to_delete.Close() + + for to_delete.Next() { + var id string + if err = to_delete.Scan(&id); err != nil { + model.UpdateStatus(c, FAILED_TRAINING) + task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions") + return c.E500M("Failed to delete unsed definitions", err) + } + os.RemoveAll(path.Join("savedData", model.Id, "defs", id)) + } + + // TODO Check if returning also works here + if _, err = c.Exec("delete from model_definition where status!=$1 and model_id=$2;", MODEL_DEFINITION_STATUS_READY, model.Id); err != nil { + model.UpdateStatus(c, FAILED_TRAINING) + task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions") + return c.E500M("Failed to delete unsed definitions", err) + } + + model.UpdateStatus(c, READY) + + return c.SendJSON("Ok") + }) } diff --git a/runner/.gitignore b/runner/.gitignore deleted file mode 100644 index 2f7896d..0000000 --- a/runner/.gitignore +++ /dev/null @@ -1 +0,0 @@ -target/ diff --git a/runner/Cargo.lock b/runner/Cargo.lock deleted file mode 100644 index c2b5d00..0000000 --- a/runner/Cargo.lock +++ /dev/null @@ -1,1936 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 3 - -[[package]] -name = "addr2line" -version = "0.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" -dependencies = [ - "gimli", -] - -[[package]] -name = "adler" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" - -[[package]] -name = "aes" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" -dependencies = [ - "cfg-if", - "cipher", - "cpufeatures", -] - -[[package]] -name = "anyhow" -version = "1.0.82" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519" - -[[package]] -name = "autocfg" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" - -[[package]] -name = "backtrace" -version = "0.3.71" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" -dependencies = [ - "addr2line", - "cc", - "cfg-if", - "libc", - "miniz_oxide", - "object", - "rustc-demangle", -] - -[[package]] -name = "base64" -version = "0.22.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" - -[[package]] -name = "base64ct" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" - -[[package]] -name = "bitflags" -version = "1.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - -[[package]] -name = "bitflags" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" - -[[package]] -name = "block-buffer" -version = "0.10.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" -dependencies = [ - "generic-array", -] - -[[package]] -name = "bumpalo" -version = "3.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" - -[[package]] -name = "byteorder" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" - -[[package]] -name = "bytes" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" - -[[package]] -name = "bzip2" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" -dependencies = [ - "bzip2-sys", - "libc", -] - -[[package]] -name = "bzip2-sys" -version = "0.1.11+1.0.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" -dependencies = [ - "cc", - "libc", - "pkg-config", -] - -[[package]] -name = "cc" -version = "1.0.96" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "065a29261d53ba54260972629f9ca6bffa69bac13cd1fed61420f7fa68b9f8bd" -dependencies = [ - "jobserver", - "libc", - "once_cell", -] - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "cipher" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" -dependencies = [ - "crypto-common", - "inout", -] - -[[package]] -name = "constant_time_eq" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" - -[[package]] -name = "core-foundation" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" -dependencies = [ - "core-foundation-sys", - "libc", -] - -[[package]] -name = "core-foundation-sys" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" - -[[package]] -name = "cpufeatures" -version = "0.2.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" -dependencies = [ - "libc", -] - -[[package]] -name = "crc32fast" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" -dependencies = [ - "cfg-if", -] - -[[package]] -name = "crossbeam-utils" -version = "0.8.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" - -[[package]] -name = "crunchy" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" - -[[package]] -name = "crypto-common" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" -dependencies = [ - "generic-array", - "typenum", -] - -[[package]] -name = "deranged" -version = "0.3.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" -dependencies = [ - "powerfmt", -] - -[[package]] -name = "digest" -version = "0.10.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" -dependencies = [ - "block-buffer", - "crypto-common", - "subtle", -] - -[[package]] -name = "encoding_rs" -version = "0.8.34" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59" -dependencies = [ - "cfg-if", -] - -[[package]] -name = "equivalent" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" - -[[package]] -name = "errno" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" -dependencies = [ - "libc", - "windows-sys 0.52.0", -] - -[[package]] -name = "fastrand" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" - -[[package]] -name = "flate2" -version = "1.0.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" -dependencies = [ - "crc32fast", - "miniz_oxide", -] - -[[package]] -name = "fnv" -version = "1.0.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" - -[[package]] -name = "foreign-types" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" -dependencies = [ - "foreign-types-shared", -] - -[[package]] -name = "foreign-types-shared" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" - -[[package]] -name = "form_urlencoded" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" -dependencies = [ - "percent-encoding", -] - -[[package]] -name = "futures-channel" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" -dependencies = [ - "futures-core", -] - -[[package]] -name = "futures-core" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" - -[[package]] -name = "futures-sink" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" - -[[package]] -name = "futures-task" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" - -[[package]] -name = "futures-util" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" -dependencies = [ - "futures-core", - "futures-task", - "pin-project-lite", - "pin-utils", -] - -[[package]] -name = "generic-array" -version = "0.14.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" -dependencies = [ - "typenum", - "version_check", -] - -[[package]] -name = "getrandom" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c" -dependencies = [ - "cfg-if", - "libc", - "wasi", -] - -[[package]] -name = "gimli" -version = "0.28.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" - -[[package]] -name = "h2" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "816ec7294445779408f36fe57bc5b7fc1cf59664059096c65f905c1c61f58069" -dependencies = [ - "bytes", - "fnv", - "futures-core", - "futures-sink", - "futures-util", - "http", - "indexmap", - "slab", - "tokio", - "tokio-util", - "tracing", -] - -[[package]] -name = "half" -version = "2.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" -dependencies = [ - "cfg-if", - "crunchy", -] - -[[package]] -name = "hashbrown" -version = "0.14.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" - -[[package]] -name = "hermit-abi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" - -[[package]] -name = "hmac" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" -dependencies = [ - "digest", -] - -[[package]] -name = "http" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" -dependencies = [ - "bytes", - "fnv", - "itoa", -] - -[[package]] -name = "http-body" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" -dependencies = [ - "bytes", - "http", -] - -[[package]] -name = "http-body-util" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0475f8b2ac86659c21b64320d5d653f9efe42acd2a4e560073ec61a155a34f1d" -dependencies = [ - "bytes", - "futures-core", - "http", - "http-body", - "pin-project-lite", -] - -[[package]] -name = "httparse" -version = "1.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" - -[[package]] -name = "hyper" -version = "1.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe575dd17d0862a9a33781c8c4696a55c320909004a67a00fb286ba8b1bc496d" -dependencies = [ - "bytes", - "futures-channel", - "futures-util", - "h2", - "http", - "http-body", - "httparse", - "itoa", - "pin-project-lite", - "smallvec", - "tokio", - "want", -] - -[[package]] -name = "hyper-tls" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" -dependencies = [ - "bytes", - "http-body-util", - "hyper", - "hyper-util", - "native-tls", - "tokio", - "tokio-native-tls", - "tower-service", -] - -[[package]] -name = "hyper-util" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa" -dependencies = [ - "bytes", - "futures-channel", - "futures-util", - "http", - "http-body", - "hyper", - "pin-project-lite", - "socket2", - "tokio", - "tower", - "tower-service", - "tracing", -] - -[[package]] -name = "idna" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" -dependencies = [ - "unicode-bidi", - "unicode-normalization", -] - -[[package]] -name = "indexmap" -version = "2.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" -dependencies = [ - "equivalent", - "hashbrown", -] - -[[package]] -name = "inout" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" -dependencies = [ - "generic-array", -] - -[[package]] -name = "ipnet" -version = "2.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" - -[[package]] -name = "itoa" -version = "1.0.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" - -[[package]] -name = "jobserver" -version = "0.1.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2b099aaa34a9751c5bf0878add70444e1ed2dd73f347be99003d4577277de6e" -dependencies = [ - "libc", -] - -[[package]] -name = "js-sys" -version = "0.3.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" -dependencies = [ - "wasm-bindgen", -] - -[[package]] -name = "lazy_static" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" - -[[package]] -name = "libc" -version = "0.2.154" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae743338b92ff9146ce83992f766a31066a91a8c84a45e0e9f21e7cf6de6d346" - -[[package]] -name = "linux-raw-sys" -version = "0.4.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" - -[[package]] -name = "lock_api" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" -dependencies = [ - "autocfg", - "scopeguard", -] - -[[package]] -name = "log" -version = "0.4.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" - -[[package]] -name = "matrixmultiply" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2" -dependencies = [ - "autocfg", - "rawpointer", -] - -[[package]] -name = "memchr" -version = "2.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" - -[[package]] -name = "mime" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" - -[[package]] -name = "miniz_oxide" -version = "0.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" -dependencies = [ - "adler", -] - -[[package]] -name = "mio" -version = "0.8.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" -dependencies = [ - "libc", - "wasi", - "windows-sys 0.48.0", -] - -[[package]] -name = "native-tls" -version = "0.2.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e" -dependencies = [ - "lazy_static", - "libc", - "log", - "openssl", - "openssl-probe", - "openssl-sys", - "schannel", - "security-framework", - "security-framework-sys", - "tempfile", -] - -[[package]] -name = "ndarray" -version = "0.15.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" -dependencies = [ - "matrixmultiply", - "num-complex", - "num-integer", - "num-traits", - "rawpointer", -] - -[[package]] -name = "num-complex" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6" -dependencies = [ - "num-traits", -] - -[[package]] -name = "num-conv" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" - -[[package]] -name = "num-integer" -version = "0.1.46" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" -dependencies = [ - "num-traits", -] - -[[package]] -name = "num-traits" -version = "0.2.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" -dependencies = [ - "autocfg", -] - -[[package]] -name = "num_cpus" -version = "1.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" -dependencies = [ - "hermit-abi", - "libc", -] - -[[package]] -name = "object" -version = "0.32.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" -dependencies = [ - "memchr", -] - -[[package]] -name = "once_cell" -version = "1.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" - -[[package]] -name = "openssl" -version = "0.10.64" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" -dependencies = [ - "bitflags 2.5.0", - "cfg-if", - "foreign-types", - "libc", - "once_cell", - "openssl-macros", - "openssl-sys", -] - -[[package]] -name = "openssl-macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "openssl-probe" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" - -[[package]] -name = "openssl-sys" -version = "0.9.102" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c597637d56fbc83893a35eb0dd04b2b8e7a50c91e64e9493e398b5df4fb45fa2" -dependencies = [ - "cc", - "libc", - "pkg-config", - "vcpkg", -] - -[[package]] -name = "parking_lot" -version = "0.12.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e4af0ca4f6caed20e900d564c242b8e5d4903fdacf31d3daf527b66fe6f42fb" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-targets 0.52.5", -] - -[[package]] -name = "password-hash" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700" -dependencies = [ - "base64ct", - "rand_core", - "subtle", -] - -[[package]] -name = "pbkdf2" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917" -dependencies = [ - "digest", - "hmac", - "password-hash", - "sha2", -] - -[[package]] -name = "percent-encoding" -version = "2.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" - -[[package]] -name = "pin-project" -version = "1.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3" -dependencies = [ - "pin-project-internal", -] - -[[package]] -name = "pin-project-internal" -version = "1.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "pin-project-lite" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" - -[[package]] -name = "pin-utils" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" - -[[package]] -name = "pkg-config" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" - -[[package]] -name = "powerfmt" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" - -[[package]] -name = "ppv-lite86" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" - -[[package]] -name = "proc-macro2" -version = "1.0.81" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "quote" -version = "1.0.36" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "rand" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" -dependencies = [ - "libc", - "rand_chacha", - "rand_core", -] - -[[package]] -name = "rand_chacha" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" -dependencies = [ - "ppv-lite86", - "rand_core", -] - -[[package]] -name = "rand_core" -version = "0.6.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" -dependencies = [ - "getrandom", -] - -[[package]] -name = "rawpointer" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" - -[[package]] -name = "redox_syscall" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e" -dependencies = [ - "bitflags 2.5.0", -] - -[[package]] -name = "reqwest" -version = "0.12.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "566cafdd92868e0939d3fb961bd0dc25fcfaaed179291093b3d43e6b3150ea10" -dependencies = [ - "base64", - "bytes", - "encoding_rs", - "futures-core", - "futures-util", - "h2", - "http", - "http-body", - "http-body-util", - "hyper", - "hyper-tls", - "hyper-util", - "ipnet", - "js-sys", - "log", - "mime", - "native-tls", - "once_cell", - "percent-encoding", - "pin-project-lite", - "rustls-pemfile", - "serde", - "serde_json", - "serde_urlencoded", - "sync_wrapper", - "system-configuration", - "tokio", - "tokio-native-tls", - "tower-service", - "url", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", - "winreg", -] - -[[package]] -name = "ring" -version = "0.17.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" -dependencies = [ - "cc", - "cfg-if", - "getrandom", - "libc", - "spin", - "untrusted", - "windows-sys 0.52.0", -] - -[[package]] -name = "runner" -version = "0.1.0" -dependencies = [ - "anyhow", - "rand", - "reqwest", - "serde", - "serde_json", - "serde_repr", - "tch", - "tokio", - "toml", -] - -[[package]] -name = "rustc-demangle" -version = "0.1.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" - -[[package]] -name = "rustix" -version = "0.38.34" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" -dependencies = [ - "bitflags 2.5.0", - "errno", - "libc", - "linux-raw-sys", - "windows-sys 0.52.0", -] - -[[package]] -name = "rustls" -version = "0.22.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" -dependencies = [ - "log", - "ring", - "rustls-pki-types", - "rustls-webpki", - "subtle", - "zeroize", -] - -[[package]] -name = "rustls-pemfile" -version = "2.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" -dependencies = [ - "base64", - "rustls-pki-types", -] - -[[package]] -name = "rustls-pki-types" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "beb461507cee2c2ff151784c52762cf4d9ff6a61f3e80968600ed24fa837fa54" - -[[package]] -name = "rustls-webpki" -version = "0.102.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3bce581c0dd41bce533ce695a1437fa16a7ab5ac3ccfa99fe1a620a7885eabf" -dependencies = [ - "ring", - "rustls-pki-types", - "untrusted", -] - -[[package]] -name = "ryu" -version = "1.0.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" - -[[package]] -name = "safetensors" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d93279b86b3de76f820a8854dd06cbc33cfa57a417b19c47f6a25280112fb1df" -dependencies = [ - "serde", - "serde_json", -] - -[[package]] -name = "schannel" -version = "0.1.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" -dependencies = [ - "windows-sys 0.52.0", -] - -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - -[[package]] -name = "security-framework" -version = "2.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "770452e37cad93e0a50d5abc3990d2bc351c36d0328f86cefec2f2fb206eaef6" -dependencies = [ - "bitflags 1.3.2", - "core-foundation", - "core-foundation-sys", - "libc", - "security-framework-sys", -] - -[[package]] -name = "security-framework-sys" -version = "2.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41f3cc463c0ef97e11c3461a9d3787412d30e8e7eb907c79180c4a57bf7c04ef" -dependencies = [ - "core-foundation-sys", - "libc", -] - -[[package]] -name = "serde" -version = "1.0.200" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddc6f9cc94d67c0e21aaf7eda3a010fd3af78ebf6e096aa6e2e13c79749cce4f" -dependencies = [ - "serde_derive", -] - -[[package]] -name = "serde_derive" -version = "1.0.200" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "856f046b9400cee3c8c94ed572ecdb752444c24528c035cd35882aad6f492bcb" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "serde_json" -version = "1.0.116" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813" -dependencies = [ - "itoa", - "ryu", - "serde", -] - -[[package]] -name = "serde_repr" -version = "0.1.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c64451ba24fc7a6a2d60fc75dd9c83c90903b19028d4eff35e88fc1e86564e9" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "serde_spanned" -version = "0.6.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb3622f419d1296904700073ea6cc23ad690adbd66f13ea683df73298736f0c1" -dependencies = [ - "serde", -] - -[[package]] -name = "serde_urlencoded" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" -dependencies = [ - "form_urlencoded", - "itoa", - "ryu", - "serde", -] - -[[package]] -name = "sha1" -version = "0.10.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - -[[package]] -name = "sha2" -version = "0.10.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - -[[package]] -name = "signal-hook-registry" -version = "1.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" -dependencies = [ - "libc", -] - -[[package]] -name = "slab" -version = "0.4.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" -dependencies = [ - "autocfg", -] - -[[package]] -name = "smallvec" -version = "1.13.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" - -[[package]] -name = "socket2" -version = "0.5.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" -dependencies = [ - "libc", - "windows-sys 0.52.0", -] - -[[package]] -name = "spin" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" - -[[package]] -name = "subtle" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" - -[[package]] -name = "syn" -version = "2.0.60" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "sync_wrapper" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" - -[[package]] -name = "system-configuration" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" -dependencies = [ - "bitflags 1.3.2", - "core-foundation", - "system-configuration-sys", -] - -[[package]] -name = "system-configuration-sys" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" -dependencies = [ - "core-foundation-sys", - "libc", -] - -[[package]] -name = "tch" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61fd89a98303b22acd6d4969b4c8940f7a30ba79af32b744a2028375d156e95a" -dependencies = [ - "half", - "lazy_static", - "libc", - "ndarray", - "rand", - "safetensors", - "thiserror", - "torch-sys", - "zip", -] - -[[package]] -name = "tempfile" -version = "3.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" -dependencies = [ - "cfg-if", - "fastrand", - "rustix", - "windows-sys 0.52.0", -] - -[[package]] -name = "thiserror" -version = "1.0.59" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0126ad08bff79f29fc3ae6a55cc72352056dfff61e3ff8bb7129476d44b23aa" -dependencies = [ - "thiserror-impl", -] - -[[package]] -name = "thiserror-impl" -version = "1.0.59" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1cd413b5d558b4c5bf3680e324a6fa5014e7b7c067a51e69dbdf47eb7148b66" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "time" -version = "0.3.36" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" -dependencies = [ - "deranged", - "num-conv", - "powerfmt", - "serde", - "time-core", -] - -[[package]] -name = "time-core" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" - -[[package]] -name = "tinyvec" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" -dependencies = [ - "tinyvec_macros", -] - -[[package]] -name = "tinyvec_macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" - -[[package]] -name = "tokio" -version = "1.37.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787" -dependencies = [ - "backtrace", - "bytes", - "libc", - "mio", - "num_cpus", - "parking_lot", - "pin-project-lite", - "signal-hook-registry", - "socket2", - "tokio-macros", - "windows-sys 0.48.0", -] - -[[package]] -name = "tokio-macros" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "tokio-native-tls" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" -dependencies = [ - "native-tls", - "tokio", -] - -[[package]] -name = "tokio-util" -version = "0.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15" -dependencies = [ - "bytes", - "futures-core", - "futures-sink", - "pin-project-lite", - "tokio", - "tracing", -] - -[[package]] -name = "toml" -version = "0.8.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9dd1545e8208b4a5af1aa9bbd0b4cf7e9ea08fabc5d0a5c67fcaafa17433aa3" -dependencies = [ - "serde", - "serde_spanned", - "toml_datetime", - "toml_edit", -] - -[[package]] -name = "toml_datetime" -version = "0.6.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1" -dependencies = [ - "serde", -] - -[[package]] -name = "toml_edit" -version = "0.22.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3328d4f68a705b2a4498da1d580585d39a6510f98318a2cec3018a7ec61ddef" -dependencies = [ - "indexmap", - "serde", - "serde_spanned", - "toml_datetime", - "winnow", -] - -[[package]] -name = "torch-sys" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5997681f7f3700fa475f541fcda44c8959ea42a724194316fe7297cb96ebb08" -dependencies = [ - "anyhow", - "cc", - "libc", - "serde", - "serde_json", - "ureq", - "zip", -] - -[[package]] -name = "tower" -version = "0.4.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" -dependencies = [ - "futures-core", - "futures-util", - "pin-project", - "pin-project-lite", - "tokio", - "tower-layer", - "tower-service", - "tracing", -] - -[[package]] -name = "tower-layer" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" - -[[package]] -name = "tower-service" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" - -[[package]] -name = "tracing" -version = "0.1.40" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" -dependencies = [ - "log", - "pin-project-lite", - "tracing-core", -] - -[[package]] -name = "tracing-core" -version = "0.1.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" -dependencies = [ - "once_cell", -] - -[[package]] -name = "try-lock" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" - -[[package]] -name = "typenum" -version = "1.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" - -[[package]] -name = "unicode-bidi" -version = "0.3.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" - -[[package]] -name = "unicode-ident" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" - -[[package]] -name = "unicode-normalization" -version = "0.1.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" -dependencies = [ - "tinyvec", -] - -[[package]] -name = "untrusted" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" - -[[package]] -name = "ureq" -version = "2.9.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d11a831e3c0b56e438a28308e7c810799e3c118417f342d30ecec080105395cd" -dependencies = [ - "base64", - "flate2", - "log", - "once_cell", - "rustls", - "rustls-pki-types", - "rustls-webpki", - "serde", - "serde_json", - "url", - "webpki-roots", -] - -[[package]] -name = "url" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" -dependencies = [ - "form_urlencoded", - "idna", - "percent-encoding", -] - -[[package]] -name = "vcpkg" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" - -[[package]] -name = "version_check" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" - -[[package]] -name = "want" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" -dependencies = [ - "try-lock", -] - -[[package]] -name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" - -[[package]] -name = "wasm-bindgen" -version = "0.2.92" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" -dependencies = [ - "cfg-if", - "wasm-bindgen-macro", -] - -[[package]] -name = "wasm-bindgen-backend" -version = "0.2.92" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" -dependencies = [ - "bumpalo", - "log", - "once_cell", - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-futures" -version = "0.4.42" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76bc14366121efc8dbb487ab05bcc9d346b3b5ec0eaa76e46594cabbe51762c0" -dependencies = [ - "cfg-if", - "js-sys", - "wasm-bindgen", - "web-sys", -] - -[[package]] -name = "wasm-bindgen-macro" -version = "0.2.92" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" -dependencies = [ - "quote", - "wasm-bindgen-macro-support", -] - -[[package]] -name = "wasm-bindgen-macro-support" -version = "0.2.92" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-backend", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-shared" -version = "0.2.92" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" - -[[package]] -name = "web-sys" -version = "0.3.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" -dependencies = [ - "js-sys", - "wasm-bindgen", -] - -[[package]] -name = "webpki-roots" -version = "0.26.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009" -dependencies = [ - "rustls-pki-types", -] - -[[package]] -name = "windows-sys" -version = "0.48.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" -dependencies = [ - "windows-targets 0.48.5", -] - -[[package]] -name = "windows-sys" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" -dependencies = [ - "windows-targets 0.52.5", -] - -[[package]] -name = "windows-targets" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" -dependencies = [ - "windows_aarch64_gnullvm 0.48.5", - "windows_aarch64_msvc 0.48.5", - "windows_i686_gnu 0.48.5", - "windows_i686_msvc 0.48.5", - "windows_x86_64_gnu 0.48.5", - "windows_x86_64_gnullvm 0.48.5", - "windows_x86_64_msvc 0.48.5", -] - -[[package]] -name = "windows-targets" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" -dependencies = [ - "windows_aarch64_gnullvm 0.52.5", - "windows_aarch64_msvc 0.52.5", - "windows_i686_gnu 0.52.5", - "windows_i686_gnullvm", - "windows_i686_msvc 0.52.5", - "windows_x86_64_gnu 0.52.5", - "windows_x86_64_gnullvm 0.52.5", - "windows_x86_64_msvc 0.52.5", -] - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" - -[[package]] -name = "windows_i686_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" - -[[package]] -name = "windows_i686_gnu" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" - -[[package]] -name = "windows_i686_gnullvm" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" - -[[package]] -name = "windows_i686_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" - -[[package]] -name = "windows_i686_msvc" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.52.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" - -[[package]] -name = "winnow" -version = "0.6.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14b9415ee827af173ebb3f15f9083df5a122eb93572ec28741fb153356ea2578" -dependencies = [ - "memchr", -] - -[[package]] -name = "winreg" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a277a57398d4bfa075df44f501a17cfdf8542d224f0d36095a2adc7aee4ef0a5" -dependencies = [ - "cfg-if", - "windows-sys 0.48.0", -] - -[[package]] -name = "zeroize" -version = "1.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" - -[[package]] -name = "zip" -version = "0.6.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" -dependencies = [ - "aes", - "byteorder", - "bzip2", - "constant_time_eq", - "crc32fast", - "crossbeam-utils", - "flate2", - "hmac", - "pbkdf2", - "sha1", - "time", - "zstd", -] - -[[package]] -name = "zstd" -version = "0.11.2+zstd.1.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" -dependencies = [ - "zstd-safe", -] - -[[package]] -name = "zstd-safe" -version = "5.0.2+zstd.1.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db" -dependencies = [ - "libc", - "zstd-sys", -] - -[[package]] -name = "zstd-sys" -version = "2.0.10+zstd.1.5.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c253a4914af5bafc8fa8c86ee400827e83cf6ec01195ec1f1ed8441bf00d65aa" -dependencies = [ - "cc", - "pkg-config", -] diff --git a/runner/Cargo.toml b/runner/Cargo.toml deleted file mode 100644 index f9cb801..0000000 --- a/runner/Cargo.toml +++ /dev/null @@ -1,17 +0,0 @@ -[package] -name = "runner" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -anyhow = "1.0.82" -serde = { version = "1.0.200", features = ["derive"] } -toml = "0.8.12" -reqwest = { version = "0.12", features = ["json"] } -tokio = { version = "1", features = ["full"] } -serde_json = "1.0.116" -serde_repr = "0.1" -tch = { version = "0.16.0", features = ["download-libtorch"] } -rand = "0.8.5" diff --git a/runner/Dockerfile b/runner/Dockerfile deleted file mode 100644 index 5685a0c..0000000 --- a/runner/Dockerfile +++ /dev/null @@ -1,12 +0,0 @@ -FROM docker.io/nvidia/cuda:11.7.1-devel-ubuntu22.04 - -RUN apt-get update -RUN apt-get install -y curl - -RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y -ENV PATH="$PATH:/root/.cargo/bin" -RUN rustup toolchain install stable - -RUN apt-get install -y pkg-config libssl-dev - -WORKDIR /app diff --git a/runner/config.toml b/runner/config.toml deleted file mode 100644 index 659a748..0000000 --- a/runner/config.toml +++ /dev/null @@ -1,3 +0,0 @@ -hostname = "https://testing.andr3h3nriqu3s.com/api" -token = "d2bc41e8293937bcd9397870c98f97acc9603f742924b518e193cd1013e45d57897aa302b364001c72b458afcfb34239dfaf38a66b318e5cbc973eea" -data_path = "/home/andr3/Documents/my-repos/fyp" diff --git a/runner/data.toml b/runner/data.toml deleted file mode 100644 index a52e09b..0000000 --- a/runner/data.toml +++ /dev/null @@ -1 +0,0 @@ -id = "a7cec9e9-1d05-4633-8bc5-6faabe4fd5a3" diff --git a/runner/run.sh b/runner/run.sh deleted file mode 100755 index 4b5a346..0000000 --- a/runner/run.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -podman run --rm --network host --gpus all -ti -v $(pwd):/app -e "TERM=xterm-256color" fyp-runner bash diff --git a/runner/src/dataloader.rs b/runner/src/dataloader.rs deleted file mode 100644 index 281e4f9..0000000 --- a/runner/src/dataloader.rs +++ /dev/null @@ -1,115 +0,0 @@ -use crate::{model::DataPoint, settings::ConfigFile}; -use std::{path::Path, sync::Arc}; -use tch::Tensor; - -pub struct DataLoader { - pub batch_size: i64, - pub len: usize, - pub inputs: Vec, - pub labels: Vec, - pub pos: usize, -} - -fn import_image( - item: &DataPoint, - base_path: &Path, - classes_len: i64, - inputs: &mut Vec, - labels: &mut Vec, -) { - inputs.push( - tch::vision::image::load(base_path.join(&item.path)) - .ok() - .unwrap() - .unsqueeze(0), - ); - - if item.class >= 0 { - let t = tch::Tensor::from_slice(&[item.class]).onehot(classes_len as i64); - labels.push(t); - } else { - labels.push(tch::Tensor::zeros( - [1, classes_len as i64], - (tch::Kind::Float, tch::Device::Cpu), - )) - } -} - -impl DataLoader { - pub fn new( - config: Arc, - data: Vec, - classes_len: i64, - batch_size: i64, - ) -> DataLoader { - let len: f64 = (data.len() as f64) / (batch_size as f64); - let min_len: i64 = len.floor() as i64; - let max_len: i64 = len.ceil() as i64; - - println!( - "Creating dataloader data len: {} len: {} min_len: {} max_len:{}", - data.len(), - len, - min_len, - max_len - ); - - let base_path = Path::new(&config.data_path); - - let mut inputs: Vec = Vec::new(); - let mut all_labels: Vec = Vec::new(); - - for batch in 0..min_len { - let mut batch_acc: Vec = Vec::new(); - let mut labels: Vec = Vec::new(); - for image in 0..batch_size { - let i: usize = (batch * batch_size + image).try_into().unwrap(); - let item = &data[i]; - import_image(item, base_path, classes_len, &mut batch_acc, &mut labels) - } - inputs.push(tch::Tensor::cat(&batch_acc[0..], 0)); - all_labels.push(tch::Tensor::cat(&labels[0..], 0)); - } - - // Import the last batch that has irregular sizing - if min_len != max_len { - let mut batch_acc: Vec = Vec::new(); - let mut labels: Vec = Vec::new(); - for image in 0..(data.len() - (batch_size * min_len) as usize) { - let i: usize = (min_len * batch_size + (image as i64)) as usize; - let item = &data[i]; - import_image(item, base_path, classes_len, &mut batch_acc, &mut labels); - } - inputs.push(tch::Tensor::cat(&batch_acc[0..], 0)); - all_labels.push(tch::Tensor::cat(&labels[0..], 0)); - } - - println!("ins shape: {:?}", inputs[0].size()); - - return DataLoader { - batch_size, - inputs, - labels: all_labels, - len: max_len as usize, - pos: 0, - }; - } - - pub fn restart(self: &mut DataLoader) { - self.pos = 0; - } - - pub fn next(self: &mut DataLoader) -> Option<(Tensor, Tensor)> { - if self.pos >= self.len { - return None; - } - let input = self.inputs[self.pos].empty_like(); - self.inputs[self.pos] = self.inputs[self.pos].clone(&input); - let label = self.labels[self.pos].empty_like(); - self.labels[self.pos] = self.labels[self.pos].clone(&label); - - self.pos += 1; - - return Some((input, label)); - } -} diff --git a/runner/src/main.rs b/runner/src/main.rs deleted file mode 100644 index b3ae49b..0000000 --- a/runner/src/main.rs +++ /dev/null @@ -1,206 +0,0 @@ -mod dataloader; -mod model; -mod settings; -mod tasks; -mod training; -mod types; - -use crate::settings::*; -use crate::tasks::{fail_task, Task, TaskType}; -use crate::training::handle_train; -use anyhow::{bail, Result}; -use reqwest::StatusCode; -use serde_json::json; -use std::{fs, process::exit, sync::Arc, time::Duration}; - -enum ResultAlive { - Ok, - Error, - NotInit, -} - -async fn send_keep_alive_message( - config: Arc, - runner_data: Arc, -) -> ResultAlive { - let client = reqwest::Client::new(); - - let to_send = json!({ - "id": runner_data.id, - }); - - let resp = client - .post(format!("{}/tasks/runner/beat", config.hostname)) - .header("token", &config.token) - .body(to_send.to_string()) - .send() - .await; - - if resp.is_err() { - return ResultAlive::Error; - } - - let resp = resp.ok(); - - if resp.is_none() { - return ResultAlive::Error; - } - - let resp = resp.unwrap(); - - // TODO see if the message is related to not being inited - if resp.status() != 200 { - println!("Could not connect with the status"); - return ResultAlive::Error; - } - - ResultAlive::Ok -} - -async fn keep_alive(config: Arc, runner_data: Arc) -> Result<()> { - let mut failed = 0; - loop { - match send_keep_alive_message(config.clone(), runner_data.clone()).await { - ResultAlive::Error => failed += 1, - ResultAlive::NotInit => { - println!("Runner not inited! Restarting!"); - exit(1) - } - ResultAlive::Ok => failed = 0, - } - - // TODO move to config - if failed > 20 { - println!("Failed to connect to API! More than 20 times in a row stoping"); - exit(1) - } - - tokio::time::sleep(Duration::from_secs(1)).await; - } -} - -async fn handle_task( - task: Task, - config: Arc, - runner_data: Arc, -) -> Result<()> { - let res = match task.task_type { - TaskType::Training => handle_train(&task, config.clone(), runner_data.clone()).await, - _ => { - println!("Do not know how to handle this task #{:?}", task); - bail!("Failed") - } - }; - - if res.is_err() { - println!("task failed #{:?}", res); - fail_task( - &task, - config, - runner_data, - "Do not know how to handle this kind of task", - ) - .await? - } - - Ok(()) -} - -#[tokio::main] -async fn main() -> Result<()> { - // Load config file - let config_data = fs::read_to_string("./config.toml")?; - let mut config: ConfigFile = toml::from_str(&config_data)?; - - let client = reqwest::Client::new(); - if config.config_path == None { - config.config_path = Some(String::from("./data.toml")) - } - - let runner_data: RunnerData = load_runner_data(&config).await?; - - let to_send = json!({ - "id": runner_data.id, - }); - - // Inform the server that the runner is available - let resp = client - .post(format!("{}/tasks/runner/init", config.hostname)) - .header("token", &config.token) - .body(to_send.to_string()) - .send() - .await?; - - if resp.status() != 200 { - println!( - "Could not connect with the api: status {} body {}", - resp.status(), - resp.text().await? - ); - return Ok(()); - } - - let res = resp.json::().await?; - if res != "Ok" { - print!("Unexpected problem: {}", res); - return Ok(()); - } - - let config = Arc::new(config); - let runner_data = Arc::new(runner_data); - - let config_alive = config.clone(); - let runner_data_alive = runner_data.clone(); - std::thread::spawn(move || keep_alive(config_alive.clone(), runner_data_alive.clone())); - - println!("Started main loop"); - loop { - //TODO move time to config - tokio::time::sleep(Duration::from_secs(1)).await; - - let to_send = json!({ "id": runner_data.id }); - - let resp = client - .post(format!("{}/tasks/runner/active", config.hostname)) - .header("token", &config.token) - .body(to_send.to_string()) - .send() - .await; - - if resp.is_err() || resp.as_ref().ok().is_none() { - println!("Failed to get info from server {:?}", resp); - continue; - } - - let resp = resp?; - - match resp.status() { - // No active task - StatusCode::NOT_FOUND => (), - StatusCode::OK => { - println!("Found task!"); - - let task: Result = resp.json().await; - if task.is_err() || task.as_ref().ok().is_none() { - println!("Failed to resolve the json {:?}", task); - continue; - } - - let task = task?; - - let res = handle_task(task, config.clone(), runner_data.clone()).await; - - if res.is_err() || res.as_ref().ok().is_none() { - println!("Failed to run the task"); - } - - _ = res; - () - } - _ => { - println!("Unexpected error #{:?}", resp); - exit(1) - } - } - } -} diff --git a/runner/src/model/mod.rs b/runner/src/model/mod.rs deleted file mode 100644 index 38feaa4..0000000 --- a/runner/src/model/mod.rs +++ /dev/null @@ -1,117 +0,0 @@ -use anyhow::bail; -use serde::{Deserialize, Serialize}; -use serde_repr::{Deserialize_repr, Serialize_repr}; -use tch::{ - nn::{self, Module}, - Device, -}; - -#[derive(Debug)] -pub struct Model { - pub vs: nn::VarStore, - pub seq: nn::Sequential, - pub layers: Vec, -} - -#[derive(Debug, Clone, Copy, Serialize_repr, Deserialize_repr)] -#[repr(i8)] -pub enum LayerType { - Input = 1, - Dense = 2, - Flatten = 3, - SimpleBlock = 4, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct Layer { - pub id: String, - pub definition_id: String, - pub layer_order: String, - pub layer_type: LayerType, - pub shape: String, - pub exp_type: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct DataPoint { - pub class: i64, - pub path: String, -} - -pub fn build_model(layers: Vec, last_linear_size: i64, add_sigmoid: bool) -> Model { - let vs = nn::VarStore::new(Device::Cuda(0)); - - let mut seq = nn::seq(); - - let mut last_linear_size = last_linear_size; - let mut last_linear_conv: Vec = Vec::new(); - - for layer in layers.iter() { - match layer.layer_type { - LayerType::Input => { - last_linear_conv = serde_json::from_str(&layer.shape).ok().unwrap(); - println!("Layer: Input, In: {:?}", last_linear_conv); - } - LayerType::Dense => { - let shape: Vec = serde_json::from_str(&layer.shape).ok().unwrap(); - println!("Layer: Dense, In: {}, Out: {}", last_linear_size, shape[0]); - seq = seq - .add(nn::linear( - &vs.root(), - last_linear_size, - shape[0], - Default::default(), - )) - .add_fn(|xs| xs.relu()); - last_linear_size = shape[0]; - } - LayerType::Flatten => { - seq = seq.add_fn(|xs| xs.flatten(1, -1)); - last_linear_size = 1; - for i in &last_linear_conv { - last_linear_size *= i; - } - println!( - "Layer: flatten, In: {:?}, Out: {}", - last_linear_conv, last_linear_size - ) - } - LayerType::SimpleBlock => { - let new_last_linear_conv = - vec![128, last_linear_conv[1] / 2, last_linear_conv[2] / 2]; - println!( - "Layer: block, In: {:?}, Put: {:?}", - last_linear_conv, new_last_linear_conv, - ); - let out_size = vec![new_last_linear_conv[1], new_last_linear_conv[2]]; - seq = seq - .add(nn::conv2d( - &vs.root(), - last_linear_conv[0], - 128, - 3, - nn::ConvConfig::default(), - )) - .add_fn(|xs| xs.relu()) - .add(nn::conv2d( - &vs.root(), - 128, - 128, - 3, - nn::ConvConfig::default(), - )) - .add_fn(|xs| xs.relu()) - .add_fn(move |xs| xs.adaptive_avg_pool2d([out_size[1], out_size[1]])) - .add_fn(|xs| xs.leaky_relu()); - //m_layers = append(m_layers, NewSimpleBlock(vs, lastLinearConv[0])) - last_linear_conv = new_last_linear_conv; - } - } - } - - if add_sigmoid { - seq = seq.add_fn(|xs| xs.sigmoid()); - } - - return Model { vs, layers, seq }; -} diff --git a/runner/src/settings.rs b/runner/src/settings.rs deleted file mode 100644 index a9c3603..0000000 --- a/runner/src/settings.rs +++ /dev/null @@ -1,57 +0,0 @@ -use anyhow::{bail, Result}; -use serde::{Deserialize, Serialize}; -use serde_json::json; -use std::{fs, path}; - -#[derive(Deserialize)] -pub struct ConfigFile { - // Hostname to connect with the api - pub hostname: String, - // Token used in the api to authenticate - pub token: String, - // Path to where to store some generated configuration values - // defaults to ./data.toml - pub config_path: Option, - // Data Path - // Path to where the data is mounted - pub data_path: String, -} - -#[derive(Deserialize, Serialize)] -pub struct RunnerData { - pub id: String, -} - -pub async fn load_runner_data(config: &ConfigFile) -> Result { - let data_path = config.config_path.as_ref().unwrap(); - let data_path = path::Path::new(&*data_path); - - if data_path.exists() { - let runner_data = fs::read_to_string(data_path)?; - Ok(toml::from_str(&runner_data)?) - } else { - let client = reqwest::Client::new(); - let to_send = json!({ - "token": config.token, - "type": 1, - }); - - let register_resp = client - .post(format!("{}/tasks/runner/register", config.hostname)) - .header("token", &config.token) - .body(to_send.to_string()) - .send() - .await?; - - if register_resp.status() != 200 { - bail!(format!("Could not create runner {:#?}", register_resp)); - } - - let runner_data: RunnerData = register_resp.json().await?; - - fs::write(data_path, toml::to_string(&runner_data)?) - .expect("Faield to save data for runner"); - - Ok(runner_data) - } -} diff --git a/runner/src/tasks.rs b/runner/src/tasks.rs deleted file mode 100644 index 9b54157..0000000 --- a/runner/src/tasks.rs +++ /dev/null @@ -1,90 +0,0 @@ -use std::sync::Arc; - -use anyhow::{bail, Result}; -use serde::Deserialize; -use serde_json::json; -use serde_repr::Deserialize_repr; - -use crate::{ConfigFile, RunnerData}; - -#[derive(Clone, Copy, Deserialize_repr, Debug)] -#[repr(i8)] -pub enum TaskStatus { - FailedRunning = -2, - FailedCreation = -1, - Preparing = 0, - Todo = 1, - PickedUp = 2, - Running = 3, - Done = 4, -} - -#[derive(Clone, Copy, Deserialize_repr, Debug)] -#[repr(i8)] -pub enum TaskType { - Classification = 1, - Training = 2, - Retraining = 3, - DeleteUser = 4, -} - -#[derive(Deserialize, Debug)] -pub struct Task { - pub id: String, - pub user_id: String, - pub model_id: String, - pub status: TaskStatus, - pub status_message: String, - pub user_confirmed: i8, - pub compacted: i8, - #[serde(alias = "type")] - pub task_type: TaskType, - pub extra_task_info: String, - pub result: String, - pub created: String, -} - -pub async fn fail_task( - task: &Task, - config: Arc, - runner_data: Arc, - reason: &str, -) -> Result<()> { - println!("Marking Task as failed"); - - let client = reqwest::Client::new(); - - let to_send = json!({ - "id": runner_data.id, - "taskId": task.id, - "reason": reason - }); - - let resp = client - .post(format!("{}/tasks/runner/fail", config.hostname)) - .header("token", &config.token) - .body(to_send.to_string()) - .send() - .await?; - - if resp.status() != 200 { - println!("Failed to update status of task"); - bail!("Failed to update status of task"); - } - - Ok(()) -} - -impl Task { - pub async fn fail( - self: &mut Task, - config: Arc, - runner_data: Arc, - reason: &str, - ) -> Result<()> { - fail_task(self, config, runner_data, reason).await?; - self.status = TaskStatus::FailedRunning; - self.status_message = reason.to_string(); - Ok(()) - } -} diff --git a/runner/src/training.rs b/runner/src/training.rs deleted file mode 100644 index 99423ed..0000000 --- a/runner/src/training.rs +++ /dev/null @@ -1,599 +0,0 @@ -use crate::{ - dataloader::DataLoader, - model::{self, build_model}, - settings::{ConfigFile, RunnerData}, - tasks::{fail_task, Task}, - types::{DataPointRequest, Definition, ModelClass}, -}; -use std::{ - io::{self, Write}, - sync::Arc, -}; - -use anyhow::Result; -use rand::{seq::SliceRandom, thread_rng}; -use serde_json::json; -use tch::{ - nn::{self, Module, OptimizerConfig}, - Cuda, Tensor, -}; - -pub async fn handle_train( - task: &Task, - config: Arc, - runner_data: Arc, -) -> Result<()> { - let client = reqwest::Client::new(); - println!("Preparing to train a model"); - - let to_send = json!({ - "id": runner_data.id, - "taskId": task.id, - }); - - let mut defs: Vec = client - .post(format!("{}/tasks/runner/train/defs", config.hostname)) - .header("token", &config.token) - .body(to_send.to_string()) - .send() - .await? - .json() - .await?; - - if defs.len() == 0 { - println!("No defs found"); - fail_task(task, config, runner_data, "No definitions found").await?; - return Ok(()); - } - - let classes: Vec = client - .post(format!("{}/tasks/runner/train/classes", config.hostname)) - .header("token", &config.token) - .body(to_send.to_string()) - .send() - .await? - .json() - .await?; - - let data: DataPointRequest = client - .post(format!("{}/tasks/runner/train/datapoints", config.hostname)) - .header("token", &config.token) - .body(to_send.to_string()) - .send() - .await? - .json() - .await?; - - let mut testing = data.testing; - - testing.shuffle(&mut thread_rng()); - - let mut data_loader = DataLoader::new(config.clone(), testing, classes.len() as i64, 64); - - // TODO make this a vec - let mut model: Option = None; - - loop { - let config = config.clone(); - let runner_data = runner_data.clone(); - let mut to_remove: Vec = Vec::new(); - - let mut def_iter = defs.iter_mut(); - - let mut i: usize = 0; - while let Some(def) = def_iter.next() { - def.updateStatus( - task, - config.clone(), - runner_data.clone(), - crate::types::DefinitionStatus::Training, - ) - .await?; - - let model_err = train_definition( - def, - &mut data_loader, - model, - config.clone(), - runner_data.clone(), - &task, - ) - .await; - - if model_err.is_err() { - println!("Failed to create model {:?}", model_err); - model = None; - to_remove.push(i); - continue; - } - - model = model_err?; - - i += 1; - } - - defs = defs - .into_iter() - .enumerate() - .filter(|&(i, _)| to_remove.iter().any(|b| *b == i)) - .map(|(_, e)| e) - .collect(); - - break; - } - - fail_task(task, config, runner_data, "TODO").await?; - Ok(()) - - /* - for { - // Keep track of definitions that did not train fast enough - var toRemove ToRemoveList = []int{} - - for i, def := range definitions { - - accuracy, ml_model, err := trainDefinition(c, model, def, models[def.Id], classes) - if err != nil { - log.Error("Failed to train definition!Err:", "err", err) - def.UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING) - toRemove = append(toRemove, i) - continue - } - models[def.Id] = ml_model - - if accuracy >= float64(def.TargetAccuracy) { - log.Info("Found a definition that reaches target_accuracy!") - _, err = db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, DEFINITION_STATUS_TRANIED, def.Epoch, def.Id) - if err != nil { - log.Error("Failed to train definition!Err:\n", "err", err) - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) - return err - } - - _, err = db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", DEFINITION_STATUS_CANCELD_TRAINING, def.Id, model.Id, DEFINITION_STATUS_FAILED_TRAINING) - if err != nil { - log.Error("Failed to train definition!Err:\n", "err", err) - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) - return err - } - - finished = true - break - } - - if def.Epoch > MAX_EPOCH { - fmt.Printf("Failed to train definition! Accuracy less %f < %d\n", accuracy, def.TargetAccuracy) - def.UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING) - toRemove = append(toRemove, i) - continue - } - - _, err = db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.Epoch, DEFINITION_STATUS_PAUSED_TRAINING, def.Id) - if err != nil { - log.Error("Failed to train definition!Err:\n", "err", err) - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) - return err - } - } - - if finished { - break - } - - sort.Sort(sort.Reverse(toRemove)) - - log.Info("Round done", "toRemove", toRemove) - - for _, n := range toRemove { - // Clean up unsed models - models[definitions[n].Id] = nil - definitions = remove(definitions, n) - } - - len_def := len(definitions) - - if len_def == 0 { - break - } - - if len_def == 1 { - continue - } - - sort.Sort(sort.Reverse(definitions)) - - acc := definitions[0].Accuracy - 20.0 - - log.Info("Training models, Highest acc", "acc", definitions[0].Accuracy, "mod_acc", acc) - - toRemove = []int{} - for i, def := range definitions { - if def.Accuracy < acc { - toRemove = append(toRemove, i) - } - } - - log.Info("Removing due to accuracy", "toRemove", toRemove) - - sort.Sort(sort.Reverse(toRemove)) - for _, n := range toRemove { - log.Warn("Removing definition not fast enough learning", "n", n) - definitions[n].UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING) - models[definitions[n].Id] = nil - definitions = remove(definitions, n) - } - } - - var def Definition - err = GetDBOnce(c, &def, "model_definition as md where md.model_id=$1 and md.status=$2 order by md.accuracy desc limit 1;", model.Id, DEFINITION_STATUS_TRANIED) - if err != nil { - if err == NotFoundError { - log.Error("All definitions failed to train!") - } else { - log.Error("DB: failed to read definition", "err", err) - } - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) - return - } - - if err = def.UpdateStatus(c, DEFINITION_STATUS_READY); err != nil { - log.Error("Failed to update model definition", "err", err) - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) - return - } - - to_delete, err := db.Query("select id from model_definition where status != $1 and model_id=$2", DEFINITION_STATUS_READY, model.Id) - if err != nil { - log.Error("Failed to select model_definition to delete") - log.Error(err) - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) - return - } - defer to_delete.Close() - - for to_delete.Next() { - var id string - if err = to_delete.Scan(&id); err != nil { - log.Error("Failed to scan the id of a model_definition to delete", "err", err) - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) - return - } - os.RemoveAll(path.Join("savedData", model.Id, "defs", id)) - } - - // TODO Check if returning also works here - if _, err = db.Exec("delete from model_definition where status!=$1 and model_id=$2;", DEFINITION_STATUS_READY, model.Id); err != nil { - log.Error("Failed to delete model_definition") - log.Error(err) - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) - return - } - - ModelUpdateStatus(c, model.Id, READY) - - return - */ -} - -async fn train_definition( - def: &Definition, - data_loader: &mut DataLoader, - model: Option, - config: Arc, - runner_data: Arc, - task: &Task, -) -> Result> { - let client = reqwest::Client::new(); - println!("About to start training definition"); - - let mut accuracy = 0; - - let model = model.unwrap_or({ - let layers: Vec = client - .post(format!("{}/tasks/runner/train/def/layers", config.hostname)) - .header("token", &config.token) - .body( - json!({ - "id": runner_data.id, - "taskId": task.id, - "defId": def.id, - }) - .to_string(), - ) - .send() - .await? - .json() - .await?; - - build_model(layers, 0, true) - }); - - // TODO CUDA - // get device - // Move model to cuda - - let mut opt = nn::Adam::default().build(&model.vs, 1e-3)?; - - let mut last_acc = 0.0; - - for epoch in 1..40 { - data_loader.restart(); - let mut mean_loss: f64 = 0.0; - let mut mean_acc: f64 = 0.0; - while let Some((inputs, labels)) = data_loader.next() { - let inputs = inputs - .to_kind(tch::Kind::Float) - .to_device(tch::Device::Cuda(0)); - let labels = labels - .to_kind(tch::Kind::Float) - .to_device(tch::Device::Cuda(0)); - let out = model.seq.forward(&inputs); - let weight: Option = None; - let loss = out.binary_cross_entropy(&labels, weight, tch::Reduction::Mean); - opt.backward_step(&loss); - mean_loss += loss - .to_device(tch::Device::Cpu) - .unsqueeze(0) - .double_value(&[0]); - - let out = out.to_device(tch::Device::Cpu); - - let test = out.empty_like(); - _ = out.clone(&test); - - let out = test.argmax(1, true); - - let mut labels = labels.to_device(tch::Device::Cpu); - - labels = labels.unsqueeze(-1); - - let size = out.size()[0]; - - let mut acc = 0; - for i in 0..size { - let res = out.double_value(&[i]); - let exp = labels.double_value(&[i, res as i64]); - if exp == 1.0 { - acc += 1; - } - } - - mean_acc += acc as f64 / size as f64; - last_acc = acc as f64 / size as f64; - } - print!( - "\repoch: {} loss: {} acc: {} l acc: {} ", - epoch, - mean_loss / data_loader.len as f64, - mean_acc / data_loader.len as f64, - last_acc - ); - io::stdout().flush().expect("Unable to flush stdout"); - } - - println!("\nlast acc: {}", last_acc); - - return Ok(Some(model)); - /* - - opt, err := my_nn.DefaultAdamConfig().Build(model.Vs, 0.001) - if err != nil { - return - } - - for epoch := 0; epoch < EPOCH_PER_RUN; epoch++ { - var trainIter *torch.Iter2 - trainIter, err = ds.TrainIter(32) - if err != nil { - return - } - // trainIter.ToDevice(device) - - log.Info("epoch", "epoch", epoch) - - var trainLoss float64 = 0 - var trainCorrect float64 = 0 - ok := true - for ok { - var item torch.Iter2Item - var loss *torch.Tensor - item, ok = trainIter.Next() - if !ok { - continue - } - - data := item.Data - data, err = data.ToDevice(device, gotch.Float, false, true, false) - if err != nil { - return - } - - var size []int64 - size, err = data.Size() - if err != nil { - return - } - - var zeros *torch.Tensor - zeros, err = torch.Zeros(size, gotch.Float, device) - if err != nil { - return - } - - data, err = zeros.Add(data, true) - if err != nil { - return - } - - log.Info("\n\nhere 1, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad()) - - data, err = data.SetRequiresGrad(true, false) - if err != nil { - return - } - - log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad()) - - err = data.RetainGrad(false) - if err != nil { - return - } - - log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad()) - - pred := model.ForwardT(data, true) - pred, err = pred.SetRequiresGrad(true, true) - if err != nil { - return - } - - err = pred.RetainGrad(false) - if err != nil { - return - } - - label := item.Label - label, err = label.ToDevice(device, gotch.Float, false, true, false) - if err != nil { - return - } - label, err = label.SetRequiresGrad(true, true) - if err != nil { - return - } - err = label.RetainGrad(false) - if err != nil { - return - } - - // Calculate loss - loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 2, false) - if err != nil { - return - } - loss, err = loss.SetRequiresGrad(true, false) - if err != nil { - return - } - err = loss.RetainGrad(false) - if err != nil { - return - } - - err = opt.ZeroGrad() - if err != nil { - return - } - - err = loss.Backward() - if err != nil { - return - } - - log.Info("pred grad", "pred", pred.MustGrad(false).MustMax(false).Float64Values()) - log.Info("pred grad", "outs", label.MustGrad(false).MustMax(false).Float64Values()) - log.Info("pred grad", "data", data.MustGrad(false).MustMax(false).Float64Values(), "lol", data.MustRetainsGrad(false)) - - vars := model.Vs.Variables() - - for k, v := range vars { - log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(false).Float64Values(), "lol", v.MustRetainsGrad(false)) - } - - model.Debug() - - err = opt.Step() - if err != nil { - return - } - - trainLoss = loss.Float64Values()[0] - - // Calculate accuracy - / *var p_pred, p_labels *torch.Tensor - p_pred, err = pred.Argmax([]int64{1}, true, false) - if err != nil { - return - } - - p_labels, err = item.Label.Argmax([]int64{1}, true, false) - if err != nil { - return - } - - floats := p_pred.Float64Values() - floats_labels := p_labels.Float64Values() - - for i := range floats { - if floats[i] == floats_labels[i] { - trainCorrect += 1 - } - } * / - - panic("fornow") - } - - //v := []float64{} - - log.Info("model training epoch done loss", "loss", trainLoss, "correct", trainCorrect, "out", ds.TrainImagesSize, "accuracy", trainCorrect/float64(ds.TrainImagesSize)) - - / *correct := int64(0) - //torch.NoGrad(func() { - ok = true - testIter := ds.TestIter(64) - for ok { - var item torch.Iter2Item - item, ok = testIter.Next() - if !ok { - continue - } - - output := model.Forward(item.Data) - - var pred, labels *torch.Tensor - pred, err = output.Argmax([]int64{1}, true, false) - if err != nil { - return - } - - labels, err = item.Label.Argmax([]int64{1}, true, false) - if err != nil { - return - } - - floats := pred.Float64Values() - floats_labels := labels.Float64Values() - - for i := range floats { - if floats[i] == floats_labels[i] { - correct += 1 - } - } - } - - accuracy = float64(correct) / float64(ds.TestImagesSize) - - log.Info("Eval accuracy", "accuracy", accuracy) - - err = def.UpdateAfterEpoch(db, accuracy*100) - if err != nil { - return - }* / - //}) - } - - result_path := path.Join(getDir(), "savedData", m.Id, "defs", def.Id) - err = os.MkdirAll(result_path, os.ModePerm) - if err != nil { - return - } - - err = my_torch.SaveModel(model, path.Join(result_path, "model.dat")) - if err != nil { - return - } - - log.Info("Model finished training!", "accuracy", accuracy) - return - */ -} diff --git a/runner/src/types.rs b/runner/src/types.rs deleted file mode 100644 index b5fd4a4..0000000 --- a/runner/src/types.rs +++ /dev/null @@ -1,89 +0,0 @@ -use crate::{model, tasks::Task, ConfigFile, RunnerData}; -use anyhow::{bail, Result}; -use serde::Deserialize; -use serde_json::json; -use serde_repr::{Deserialize_repr, Serialize_repr}; -use std::sync::Arc; - -#[derive(Clone, Copy, Deserialize_repr, Serialize_repr, Debug)] -#[repr(i8)] -pub enum DefinitionStatus { - CanceldTraining = -4, - FailedTraining = -3, - PreInit = 1, - Init = 2, - Training = 3, - PausedTraining = 6, - Tranied = 4, - Ready = 5, -} - -#[derive(Deserialize, Debug)] -pub struct Definition { - pub id: String, - pub model_id: String, - pub accuracy: f64, - pub target_accuracy: i64, - pub epoch: i64, - pub status: i64, - pub created: String, - pub epoch_progress: i64, -} - -impl Definition { - pub async fn updateStatus( - self: &mut Definition, - task: &Task, - config: Arc, - runner_data: Arc, - status: DefinitionStatus, - ) -> Result<()> { - println!("Marking Task as faield"); - - let client = reqwest::Client::new(); - - let to_send = json!({ - "id": runner_data.id, - "taskId": task.id, - "defId": self.id, - "status": status, - }); - - let resp = client - .post(format!("{}/tasks/runner/train/def/status", config.hostname)) - .header("token", &config.token) - .body(to_send.to_string()) - .send() - .await?; - - if resp.status() != 200 { - println!("Failed to update status of task"); - bail!("Failed to update status of task"); - } - - Ok(()) - } -} - -#[derive(Clone, Copy, Deserialize_repr, Debug)] -#[repr(i8)] -pub enum ModelClassStatus { - ToTrain = 1, - Training = 2, - Trained = 3, -} - -#[derive(Deserialize, Debug)] -pub struct ModelClass { - pub id: String, - pub model_id: String, - pub name: String, - pub class_order: i64, - pub status: ModelClassStatus, -} - -#[derive(Deserialize, Debug)] -pub struct DataPointRequest { - pub testing: Vec, - pub training: Vec, -}