// fyp/runner/src/training.rs (Rust, 600 lines, 19 KiB)

use crate::{
dataloader::DataLoader,
model::{self, build_model},
settings::{ConfigFile, RunnerData},
tasks::{fail_task, Task},
types::{DataPointRequest, Definition, ModelClass},
};
use std::{
io::{self, Write},
sync::Arc,
};
use anyhow::Result;
use rand::{seq::SliceRandom, thread_rng};
use serde_json::json;
use tch::{
nn::{self, Module, OptimizerConfig},
Cuda, Tensor,
};
/// Runs a training task for this runner.
///
/// The function:
/// 1. Fetches the model definitions, the classes and the data points of the task
///    from the server (`/tasks/runner/train/{defs,classes,datapoints}`).
/// 2. Fails the task right away when the server returned no definitions.
/// 3. Trains every definition once, in order, marking it as `Training` first.
///    Definitions whose training returned an error are dropped from the list.
/// 4. Currently always finishes by failing the task with the message `"TODO"`,
///    because the rest of the algorithm (see the Go reference at the bottom of
///    the function) has not been ported yet.
///
/// # Errors
///
/// Returns an error when any HTTP request fails or its body cannot be decoded,
/// when updating a definition's status fails, or when `fail_task` fails.
pub async fn handle_train(
    task: &Task,
    config: Arc<ConfigFile>,
    runner_data: Arc<RunnerData>,
) -> Result<()> {
    let client = reqwest::Client::new();
    println!("Preparing to train a model");

    // All three endpoints below take the same runner/task identification body.
    let payload = json!({
        "id": runner_data.id,
        "taskId": task.id,
    })
    .to_string();

    let mut defs: Vec<Definition> = client
        .post(format!("{}/tasks/runner/train/defs", config.hostname))
        .header("token", &config.token)
        .body(payload.clone())
        .send()
        .await?
        .json()
        .await?;

    if defs.is_empty() {
        println!("No defs found");
        fail_task(task, config, runner_data, "No definitions found").await?;
        return Ok(());
    }

    let classes: Vec<ModelClass> = client
        .post(format!("{}/tasks/runner/train/classes", config.hostname))
        .header("token", &config.token)
        .body(payload.clone())
        .send()
        .await?
        .json()
        .await?;

    let data: DataPointRequest = client
        .post(format!("{}/tasks/runner/train/datapoints", config.hostname))
        .header("token", &config.token)
        .body(payload)
        .send()
        .await?
        .json()
        .await?;

    // NOTE(review): the loader that is used for training is fed the `testing`
    // split of the response — confirm that this is intended.
    let mut testing = data.testing;
    testing.shuffle(&mut thread_rng());
    let mut data_loader = DataLoader::new(config.clone(), testing, classes.len() as i64, 64);

    // TODO make this a vec
    // Until then the model returned by one definition is handed to the next one.
    let mut model: Option<model::Model> = None;

    // TODO repeat this round until a definition reaches its target accuracy
    // (see the Go reference below); for now a single round is executed.
    let mut failed: Vec<usize> = Vec::new();
    for (index, def) in defs.iter_mut().enumerate() {
        def.updateStatus(
            task,
            config.clone(),
            runner_data.clone(),
            crate::types::DefinitionStatus::Training,
        )
        .await?;

        let outcome = train_definition(
            def,
            &mut data_loader,
            // `take` leaves `None` behind, so a failed training resets the model.
            model.take(),
            config.clone(),
            runner_data.clone(),
            task,
        )
        .await;

        match outcome {
            Ok(trained) => model = trained,
            Err(err) => {
                println!("Failed to create model {:?}", err);
                failed.push(index);
            }
        }
    }

    // Drop the definitions that failed to train. `retain` visits the elements
    // exactly once and in order, so a running counter gives the original index.
    let mut position = 0;
    defs.retain(|_| {
        let keep = !failed.contains(&position);
        position += 1;
        keep
    });

    fail_task(task, config, runner_data, "TODO").await?;
    Ok(())
    /*
for {
// Keep track of definitions that did not train fast enough
var toRemove ToRemoveList = []int{}
for i, def := range definitions {
accuracy, ml_model, err := trainDefinition(c, model, def, models[def.Id], classes)
if err != nil {
log.Error("Failed to train definition!Err:", "err", err)
def.UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING)
toRemove = append(toRemove, i)
continue
}
models[def.Id] = ml_model
if accuracy >= float64(def.TargetAccuracy) {
log.Info("Found a definition that reaches target_accuracy!")
_, err = db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, DEFINITION_STATUS_TRANIED, def.Epoch, def.Id)
if err != nil {
log.Error("Failed to train definition!Err:\n", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return err
}
_, err = db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", DEFINITION_STATUS_CANCELD_TRAINING, def.Id, model.Id, DEFINITION_STATUS_FAILED_TRAINING)
if err != nil {
log.Error("Failed to train definition!Err:\n", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return err
}
finished = true
break
}
if def.Epoch > MAX_EPOCH {
fmt.Printf("Failed to train definition! Accuracy less %f < %d\n", accuracy, def.TargetAccuracy)
def.UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING)
toRemove = append(toRemove, i)
continue
}
_, err = db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.Epoch, DEFINITION_STATUS_PAUSED_TRAINING, def.Id)
if err != nil {
log.Error("Failed to train definition!Err:\n", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return err
}
}
if finished {
break
}
sort.Sort(sort.Reverse(toRemove))
log.Info("Round done", "toRemove", toRemove)
for _, n := range toRemove {
// Clean up unsed models
models[definitions[n].Id] = nil
definitions = remove(definitions, n)
}
len_def := len(definitions)
if len_def == 0 {
break
}
if len_def == 1 {
continue
}
sort.Sort(sort.Reverse(definitions))
acc := definitions[0].Accuracy - 20.0
log.Info("Training models, Highest acc", "acc", definitions[0].Accuracy, "mod_acc", acc)
toRemove = []int{}
for i, def := range definitions {
if def.Accuracy < acc {
toRemove = append(toRemove, i)
}
}
log.Info("Removing due to accuracy", "toRemove", toRemove)
sort.Sort(sort.Reverse(toRemove))
for _, n := range toRemove {
log.Warn("Removing definition not fast enough learning", "n", n)
definitions[n].UpdateStatus(c, DEFINITION_STATUS_FAILED_TRAINING)
models[definitions[n].Id] = nil
definitions = remove(definitions, n)
}
}
var def Definition
err = GetDBOnce(c, &def, "model_definition as md where md.model_id=$1 and md.status=$2 order by md.accuracy desc limit 1;", model.Id, DEFINITION_STATUS_TRANIED)
if err != nil {
if err == NotFoundError {
log.Error("All definitions failed to train!")
} else {
log.Error("DB: failed to read definition", "err", err)
}
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
if err = def.UpdateStatus(c, DEFINITION_STATUS_READY); err != nil {
log.Error("Failed to update model definition", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
to_delete, err := db.Query("select id from model_definition where status != $1 and model_id=$2", DEFINITION_STATUS_READY, model.Id)
if err != nil {
log.Error("Failed to select model_definition to delete")
log.Error(err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
defer to_delete.Close()
for to_delete.Next() {
var id string
if err = to_delete.Scan(&id); err != nil {
log.Error("Failed to scan the id of a model_definition to delete", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
os.RemoveAll(path.Join("savedData", model.Id, "defs", id))
}
// TODO Check if returning also works here
if _, err = db.Exec("delete from model_definition where status!=$1 and model_id=$2;", DEFINITION_STATUS_READY, model.Id); err != nil {
log.Error("Failed to delete model_definition")
log.Error(err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
ModelUpdateStatus(c, model.Id, READY)
return
*/
}
/// Trains one model definition and returns the trained model.
///
/// When `model` is `Some`, that model is trained further. When it is `None`,
/// the layers of `def` are requested from the server
/// (`/tasks/runner/train/def/layers`) and a new model is built from them.
///
/// Each epoch restarts `data_loader`, runs one optimisation step per batch
/// (Adam, learning rate `1e-3`, binary cross entropy loss) and prints the mean
/// loss, the mean batch accuracy and the accuracy of the last batch to stdout.
///
/// # Errors
///
/// Returns an error when the layers request fails or cannot be decoded, when
/// the optimiser cannot be built, or when stdout cannot be flushed.
async fn train_definition(
    def: &Definition,
    data_loader: &mut DataLoader,
    model: Option<model::Model>,
    config: Arc<ConfigFile>,
    runner_data: Arc<RunnerData>,
    task: &Task,
) -> Result<Option<model::Model>> {
    println!("About to start training definition");

    // Only hit the server and build a model when none was handed in. (An
    // `unwrap_or` argument would be evaluated even for `Some`.)
    let model = match model {
        Some(existing) => existing,
        None => {
            let layers: Vec<model::Layer> = reqwest::Client::new()
                .post(format!("{}/tasks/runner/train/def/layers", config.hostname))
                .header("token", &config.token)
                .body(
                    json!({
                        "id": runner_data.id,
                        "taskId": task.id,
                        "defId": def.id,
                    })
                    .to_string(),
                )
                .send()
                .await?
                .json()
                .await?;
            build_model(layers, 0, true)
        }
    };

    // TODO CUDA
    // Move model to cuda
    // Batches are sent to whatever device the model's variables live on, so
    // the inputs and the weights can never end up on different devices.
    let device = model.vs.device();

    let mut opt = nn::Adam::default().build(&model.vs, 1e-3)?;
    let mut last_acc = 0.0;

    // The range end is exclusive: epochs 1 through 39 are run.
    for epoch in 1..40 {
        data_loader.restart();
        let mut mean_loss: f64 = 0.0;
        let mut mean_acc: f64 = 0.0;

        while let Some((inputs, labels)) = data_loader.next() {
            let inputs = inputs.to_kind(tch::Kind::Float).to_device(device);
            let labels = labels.to_kind(tch::Kind::Float).to_device(device);

            let out = model.seq.forward(&inputs);
            let weight: Option<Tensor> = None;
            let loss = out.binary_cross_entropy(&labels, weight, tch::Reduction::Mean);
            opt.backward_step(&loss);

            mean_loss += loss
                .to_device(tch::Device::Cpu)
                .unsqueeze(0)
                .double_value(&[0]);

            // Batch accuracy: a row counts as correct when the label value at
            // the predicted (argmax) position is exactly 1.
            let predicted = out.to_device(tch::Device::Cpu).argmax(1, true);
            let labels = labels.to_device(tch::Device::Cpu).unsqueeze(-1);
            let batch_size = predicted.size()[0];
            let correct = (0..batch_size)
                .filter(|&row| {
                    let class = predicted.double_value(&[row]) as i64;
                    labels.double_value(&[row, class]) == 1.0
                })
                .count();

            let batch_acc = correct as f64 / batch_size as f64;
            mean_acc += batch_acc;
            last_acc = batch_acc;
        }

        // `\r` rewrites the same terminal line on every epoch.
        print!(
            "\repoch: {} loss: {} acc: {} l acc: {} ",
            epoch,
            mean_loss / data_loader.len as f64,
            mean_acc / data_loader.len as f64,
            last_acc
        );
        io::stdout().flush()?;
    }

    println!("\nlast acc: {}", last_acc);
    Ok(Some(model))
    /*
opt, err := my_nn.DefaultAdamConfig().Build(model.Vs, 0.001)
if err != nil {
return
}
for epoch := 0; epoch < EPOCH_PER_RUN; epoch++ {
var trainIter *torch.Iter2
trainIter, err = ds.TrainIter(32)
if err != nil {
return
}
// trainIter.ToDevice(device)
log.Info("epoch", "epoch", epoch)
var trainLoss float64 = 0
var trainCorrect float64 = 0
ok := true
for ok {
var item torch.Iter2Item
var loss *torch.Tensor
item, ok = trainIter.Next()
if !ok {
continue
}
data := item.Data
data, err = data.ToDevice(device, gotch.Float, false, true, false)
if err != nil {
return
}
var size []int64
size, err = data.Size()
if err != nil {
return
}
var zeros *torch.Tensor
zeros, err = torch.Zeros(size, gotch.Float, device)
if err != nil {
return
}
data, err = zeros.Add(data, true)
if err != nil {
return
}
log.Info("\n\nhere 1, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
data, err = data.SetRequiresGrad(true, false)
if err != nil {
return
}
log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
err = data.RetainGrad(false)
if err != nil {
return
}
log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
pred := model.ForwardT(data, true)
pred, err = pred.SetRequiresGrad(true, true)
if err != nil {
return
}
err = pred.RetainGrad(false)
if err != nil {
return
}
label := item.Label
label, err = label.ToDevice(device, gotch.Float, false, true, false)
if err != nil {
return
}
label, err = label.SetRequiresGrad(true, true)
if err != nil {
return
}
err = label.RetainGrad(false)
if err != nil {
return
}
// Calculate loss
loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 2, false)
if err != nil {
return
}
loss, err = loss.SetRequiresGrad(true, false)
if err != nil {
return
}
err = loss.RetainGrad(false)
if err != nil {
return
}
err = opt.ZeroGrad()
if err != nil {
return
}
err = loss.Backward()
if err != nil {
return
}
log.Info("pred grad", "pred", pred.MustGrad(false).MustMax(false).Float64Values())
log.Info("pred grad", "outs", label.MustGrad(false).MustMax(false).Float64Values())
log.Info("pred grad", "data", data.MustGrad(false).MustMax(false).Float64Values(), "lol", data.MustRetainsGrad(false))
vars := model.Vs.Variables()
for k, v := range vars {
log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(false).Float64Values(), "lol", v.MustRetainsGrad(false))
}
model.Debug()
err = opt.Step()
if err != nil {
return
}
trainLoss = loss.Float64Values()[0]
// Calculate accuracy
/ *var p_pred, p_labels *torch.Tensor
p_pred, err = pred.Argmax([]int64{1}, true, false)
if err != nil {
return
}
p_labels, err = item.Label.Argmax([]int64{1}, true, false)
if err != nil {
return
}
floats := p_pred.Float64Values()
floats_labels := p_labels.Float64Values()
for i := range floats {
if floats[i] == floats_labels[i] {
trainCorrect += 1
}
} * /
panic("fornow")
}
//v := []float64{}
log.Info("model training epoch done loss", "loss", trainLoss, "correct", trainCorrect, "out", ds.TrainImagesSize, "accuracy", trainCorrect/float64(ds.TrainImagesSize))
/ *correct := int64(0)
//torch.NoGrad(func() {
ok = true
testIter := ds.TestIter(64)
for ok {
var item torch.Iter2Item
item, ok = testIter.Next()
if !ok {
continue
}
output := model.Forward(item.Data)
var pred, labels *torch.Tensor
pred, err = output.Argmax([]int64{1}, true, false)
if err != nil {
return
}
labels, err = item.Label.Argmax([]int64{1}, true, false)
if err != nil {
return
}
floats := pred.Float64Values()
floats_labels := labels.Float64Values()
for i := range floats {
if floats[i] == floats_labels[i] {
correct += 1
}
}
}
accuracy = float64(correct) / float64(ds.TestImagesSize)
log.Info("Eval accuracy", "accuracy", accuracy)
err = def.UpdateAfterEpoch(db, accuracy*100)
if err != nil {
return
}* /
//})
}
result_path := path.Join(getDir(), "savedData", m.Id, "defs", def.Id)
err = os.MkdirAll(result_path, os.ModePerm)
if err != nil {
return
}
err = my_torch.SaveModel(model, path.Join(result_path, "model.dat"))
if err != nil {
return
}
log.Info("Model finished training!", "accuracy", accuracy)
return
*/
}