use crate::{ dataloader::DataLoader, model::{self, build_model}, settings::{ConfigFile, RunnerData}, tasks::{fail_task, Task}, types::{DataPointRequest, Definition, ModelClass}, }; use std::{ io::{self, Write}, sync::Arc, }; use anyhow::Result; use rand::{seq::SliceRandom, thread_rng}; use serde_json::json; use tch::{ nn::{self, Module, OptimizerConfig}, Cuda, Tensor, }; pub async fn handle_train( task: &Task, config: Arc, runner_data: Arc, ) -> Result<()> { let client = reqwest::Client::new(); println!("Preparing to train a model"); let to_send = json!({ "id": runner_data.id, "taskId": task.id, }); let mut defs: Vec = client .post(format!("{}/tasks/runner/train/defs", config.hostname)) .header("token", &config.token) .body(to_send.to_string()) .send() .await? .json() .await?; if defs.len() == 0 { println!("No defs found"); fail_task(task, config, runner_data, "No definitions found").await?; return Ok(()); } let classes: Vec = client .post(format!("{}/tasks/runner/train/classes", config.hostname)) .header("token", &config.token) .body(to_send.to_string()) .send() .await? .json() .await?; let data: DataPointRequest = client .post(format!("{}/tasks/runner/train/datapoints", config.hostname)) .header("token", &config.token) .body(to_send.to_string()) .send() .await? .json() .await?; let mut testing = data.testing; testing.shuffle(&mut thread_rng()); let mut data_loader = DataLoader::new(config.clone(), testing, classes.len() as i64, 64); // TODO make this a vec let mut model: Option = None; loop { let config = config.clone(); let runner_data = runner_data.clone(); let mut to_remove: Vec = Vec::new(); let mut def_iter = defs.iter_mut(); let mut i: usize = 0; while let Some(def) = def_iter.next() { def.updateStatus( task, config.clone(), runner_data.clone(), crate::types::DefinitionStatus::Training, ) .await?; let model_err = train_definition( def, &mut data_loader, model, config.clone(), runner_data.clone(), &task, ) .await; if model_err.is_err() { println!("Failed to create model {:?}", model_err); model = None; to_remove.push(i); continue; } model = model_err?; i += 1; } defs = defs .into_iter() .enumerate() .filter(|&(i, _)| to_remove.iter().any(|b| *b == i)) .map(|(_, e)| e) .collect(); break; } fail_task(task, config, runner_data, "TODO").await?; Ok(()) /* for { // Keep track of definitions that did not train fast enough var toRemove ToRemoveList = []int{} for i, def := range definitions { accuracy, ml_model, err := trainDefinition(c, model, def, models[def.Id], classes) if err != nil { log.Error("Failed to train definition!Err:", "err", err) def.UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING) toRemove = append(toRemove, i) continue } models[def.Id] = ml_model if accuracy >= float64(def.TargetAccuracy) { log.Info("Found a definition that reaches target_accuracy!") _, err = db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, DEFINITION_STATUS_TRANIED, def.Epoch, def.Id) if err != nil { log.Error("Failed to train definition!Err:\n", "err", err) ModelUpdateStatus(c, model.Id, FAILED_TRAINING) return err } _, err = db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", DEFINITION_STATUS_CANCELD_TRAINING, def.Id, model.Id, DEFINITION_STATUS_FAILED_TRAINING) if err != nil { log.Error("Failed to train definition!Err:\n", "err", err) ModelUpdateStatus(c, model.Id, FAILED_TRAINING) return err } finished = true break } if def.Epoch > MAX_EPOCH { fmt.Printf("Failed to train definition! Accuracy less %f < %d\n", accuracy, def.TargetAccuracy) def.UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING) toRemove = append(toRemove, i) continue } _, err = db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.Epoch, DEFINITION_STATUS_PAUSED_TRAINING, def.Id) if err != nil { log.Error("Failed to train definition!Err:\n", "err", err) ModelUpdateStatus(c, model.Id, FAILED_TRAINING) return err } } if finished { break } sort.Sort(sort.Reverse(toRemove)) log.Info("Round done", "toRemove", toRemove) for _, n := range toRemove { // Clean up unsed models models[definitions[n].Id] = nil definitions = remove(definitions, n) } len_def := len(definitions) if len_def == 0 { break } if len_def == 1 { continue } sort.Sort(sort.Reverse(definitions)) acc := definitions[0].Accuracy - 20.0 log.Info("Training models, Highest acc", "acc", definitions[0].Accuracy, "mod_acc", acc) toRemove = []int{} for i, def := range definitions { if def.Accuracy < acc { toRemove = append(toRemove, i) } } log.Info("Removing due to accuracy", "toRemove", toRemove) sort.Sort(sort.Reverse(toRemove)) for _, n := range toRemove { log.Warn("Removing definition not fast enough learning", "n", n) definitions[n].UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING) models[definitions[n].Id] = nil definitions = remove(definitions, n) } } var def Definition err = GetDBOnce(c, &def, "model_definition as md where md.model_id=$1 and md.status=$2 order by md.accuracy desc limit 1;", model.Id, DEFINITION_STATUS_TRANIED) if err != nil { if err == NotFoundError { log.Error("All definitions failed to train!") } else { log.Error("DB: failed to read definition", "err", err) } ModelUpdateStatus(c, model.Id, FAILED_TRAINING) return } if err = def.UpdateStatus(c, DEFINITION_STATUS_READY); err != nil { log.Error("Failed to update model definition", "err", err) ModelUpdateStatus(c, model.Id, FAILED_TRAINING) return } to_delete, err := db.Query("select id from model_definition where status != $1 and model_id=$2", DEFINITION_STATUS_READY, model.Id) if err != nil { log.Error("Failed to select model_definition to delete") log.Error(err) ModelUpdateStatus(c, model.Id, FAILED_TRAINING) return } defer to_delete.Close() for to_delete.Next() { var id string if err = to_delete.Scan(&id); err != nil { log.Error("Failed to scan the id of a model_definition to delete", "err", err) ModelUpdateStatus(c, model.Id, FAILED_TRAINING) return } os.RemoveAll(path.Join("savedData", model.Id, "defs", id)) } // TODO Check if returning also works here if _, err = db.Exec("delete from model_definition where status!=$1 and model_id=$2;", DEFINITION_STATUS_READY, model.Id); err != nil { log.Error("Failed to delete model_definition") log.Error(err) ModelUpdateStatus(c, model.Id, FAILED_TRAINING) return } ModelUpdateStatus(c, model.Id, READY) return */ } async fn train_definition( def: &Definition, data_loader: &mut DataLoader, model: Option, config: Arc, runner_data: Arc, task: &Task, ) -> Result> { let client = reqwest::Client::new(); println!("About to start training definition"); let mut accuracy = 0; let model = model.unwrap_or({ let layers: Vec = client .post(format!("{}/tasks/runner/train/def/layers", config.hostname)) .header("token", &config.token) .body( json!({ "id": runner_data.id, "taskId": task.id, "defId": def.id, }) .to_string(), ) .send() .await? .json() .await?; build_model(layers, 0, true) }); // TODO CUDA // get device // Move model to cuda let mut opt = nn::Adam::default().build(&model.vs, 1e-3)?; let mut last_acc = 0.0; for epoch in 1..40 { data_loader.restart(); let mut mean_loss: f64 = 0.0; let mut mean_acc: f64 = 0.0; while let Some((inputs, labels)) = data_loader.next() { let inputs = inputs .to_kind(tch::Kind::Float) .to_device(tch::Device::Cuda(0)); let labels = labels .to_kind(tch::Kind::Float) .to_device(tch::Device::Cuda(0)); let out = model.seq.forward(&inputs); let weight: Option = None; let loss = out.binary_cross_entropy(&labels, weight, tch::Reduction::Mean); opt.backward_step(&loss); mean_loss += loss .to_device(tch::Device::Cpu) .unsqueeze(0) .double_value(&[0]); let out = out.to_device(tch::Device::Cpu); let test = out.empty_like(); _ = out.clone(&test); let out = test.argmax(1, true); let mut labels = labels.to_device(tch::Device::Cpu); labels = labels.unsqueeze(-1); let size = out.size()[0]; let mut acc = 0; for i in 0..size { let res = out.double_value(&[i]); let exp = labels.double_value(&[i, res as i64]); if exp == 1.0 { acc += 1; } } mean_acc += acc as f64 / size as f64; last_acc = acc as f64 / size as f64; } print!( "\repoch: {} loss: {} acc: {} l acc: {} ", epoch, mean_loss / data_loader.len as f64, mean_acc / data_loader.len as f64, last_acc ); io::stdout().flush().expect("Unable to flush stdout"); } println!("\nlast acc: {}", last_acc); return Ok(Some(model)); /* opt, err := my_nn.DefaultAdamConfig().Build(model.Vs, 0.001) if err != nil { return } for epoch := 0; epoch < EPOCH_PER_RUN; epoch++ { var trainIter *torch.Iter2 trainIter, err = ds.TrainIter(32) if err != nil { return } // trainIter.ToDevice(device) log.Info("epoch", "epoch", epoch) var trainLoss float64 = 0 var trainCorrect float64 = 0 ok := true for ok { var item torch.Iter2Item var loss *torch.Tensor item, ok = trainIter.Next() if !ok { continue } data := item.Data data, err = data.ToDevice(device, gotch.Float, false, true, false) if err != nil { return } var size []int64 size, err = data.Size() if err != nil { return } var zeros *torch.Tensor zeros, err = torch.Zeros(size, gotch.Float, device) if err != nil { return } data, err = zeros.Add(data, true) if err != nil { return } log.Info("\n\nhere 1, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad()) data, err = data.SetRequiresGrad(true, false) if err != nil { return } log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad()) err = data.RetainGrad(false) if err != nil { return } log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad()) pred := model.ForwardT(data, true) pred, err = pred.SetRequiresGrad(true, true) if err != nil { return } err = pred.RetainGrad(false) if err != nil { return } label := item.Label label, err = label.ToDevice(device, gotch.Float, false, true, false) if err != nil { return } label, err = label.SetRequiresGrad(true, true) if err != nil { return } err = label.RetainGrad(false) if err != nil { return } // Calculate loss loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 2, false) if err != nil { return } loss, err = loss.SetRequiresGrad(true, false) if err != nil { return } err = loss.RetainGrad(false) if err != nil { return } err = opt.ZeroGrad() if err != nil { return } err = loss.Backward() if err != nil { return } log.Info("pred grad", "pred", pred.MustGrad(false).MustMax(false).Float64Values()) log.Info("pred grad", "outs", label.MustGrad(false).MustMax(false).Float64Values()) log.Info("pred grad", "data", data.MustGrad(false).MustMax(false).Float64Values(), "lol", data.MustRetainsGrad(false)) vars := model.Vs.Variables() for k, v := range vars { log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(false).Float64Values(), "lol", v.MustRetainsGrad(false)) } model.Debug() err = opt.Step() if err != nil { return } trainLoss = loss.Float64Values()[0] // Calculate accuracy / *var p_pred, p_labels *torch.Tensor p_pred, err = pred.Argmax([]int64{1}, true, false) if err != nil { return } p_labels, err = item.Label.Argmax([]int64{1}, true, false) if err != nil { return } floats := p_pred.Float64Values() floats_labels := p_labels.Float64Values() for i := range floats { if floats[i] == floats_labels[i] { trainCorrect += 1 } } * / panic("fornow") } //v := []float64{} log.Info("model training epoch done loss", "loss", trainLoss, "correct", trainCorrect, "out", ds.TrainImagesSize, "accuracy", trainCorrect/float64(ds.TrainImagesSize)) / *correct := int64(0) //torch.NoGrad(func() { ok = true testIter := ds.TestIter(64) for ok { var item torch.Iter2Item item, ok = testIter.Next() if !ok { continue } output := model.Forward(item.Data) var pred, labels *torch.Tensor pred, err = output.Argmax([]int64{1}, true, false) if err != nil { return } labels, err = item.Label.Argmax([]int64{1}, true, false) if err != nil { return } floats := pred.Float64Values() floats_labels := labels.Float64Values() for i := range floats { if floats[i] == floats_labels[i] { correct += 1 } } } accuracy = float64(correct) / float64(ds.TestImagesSize) log.Info("Eval accuracy", "accuracy", accuracy) err = def.UpdateAfterEpoch(db, accuracy*100) if err != nil { return }* / //}) } result_path := path.Join(getDir(), "savedData", m.Id, "defs", def.Id) err = os.MkdirAll(result_path, os.ModePerm) if err != nil { return } err = my_torch.SaveModel(model, path.Join(result_path, "model.dat")) if err != nil { return } log.Info("Model finished training!", "accuracy", accuracy) return */ }