More work done on torch

2024-04-22 00:09:07 +01:00
parent 28707b3f1b
commit 703fea46f2
13 changed files with 2435 additions and 96 deletions


@@ -16,16 +16,17 @@ import (
 	"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
 	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
 	my_torch "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch"
 	modelloader "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/modelloader"
+	my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn"
 	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
 	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
+	"git.andr3h3nriqu3s.com/andr3/gotch"
+	torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"
 	"github.com/charmbracelet/log"
 	"github.com/goccy/go-json"
-	"github.com/sugarme/gotch"
-	"github.com/sugarme/gotch/nn"
-	torch "github.com/sugarme/gotch/ts"
 )

 const EPOCH_PER_RUN = 20
@@ -132,11 +133,12 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 	}
 		model = my_torch.BuildModel(layers, 0, true)
 	}

 	// TODO Make the runner provide this
 	// device := gotch.CudaIfAvailable()
-	device := gotch.CPU
+	device := gotch.CudaIfAvailable()
+	// device := gotch.CPU

 	result_path := path.Join(getDir(), "savedData", m.Id, "defs", def.Id)
 	err = os.MkdirAll(result_path, os.ModePerm)
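
The hunk above flips the training device from a hard-coded gotch.CPU to gotch.CudaIfAvailable(). A minimal standalone sketch of that selection, written against the upstream sugarme/gotch API (the commit itself uses the git.andr3h3nriqu3s.com/andr3/gotch fork, assumed to expose the same call):

package main

import (
	"fmt"

	"github.com/sugarme/gotch"
)

func main() {
	// CudaIfAvailable returns the CUDA device when libtorch was built
	// with CUDA and a GPU is visible; otherwise it falls back to
	// gotch.CPU, which is what the removed line hard-coded.
	device := gotch.CudaIfAvailable()
	fmt.Printf("training device: %+v\n", device)
}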
@@ -144,6 +146,16 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 		return
 	}

+	/* opt1, err := my_nn.DefaultAdamConfig().Build(model.Vs, 0.001)
+	if err != nil {
+		return
+	}
+	opt1.Debug() */
+
+	//log.Info("\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n")
+	// TODO remove this
+	model.To(device)
+	defer model.To(gotch.CPU)
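
model.To here is the repo's own my_torch method, not upstream gotch API. A hypothetical sketch of the move-then-restore pattern the added lines follow, with deviceMover standing in for the model type:

package main

import "github.com/sugarme/gotch"

// deviceMover is a stand-in for the repo's my_torch model type; its To
// method is assumed to move all weights to the given device.
type deviceMover interface {
	To(device gotch.Device)
}

// trainOnDevice mirrors the added lines: push the model to the training
// device, and defer the move back so the weights land on the CPU no
// matter how the surrounding function returns.
func trainOnDevice(m deviceMover, device gotch.Device, train func() error) error {
	m.To(device)
	defer m.To(gotch.CPU)
	return train()
}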
@@ -153,23 +165,18 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 		return
 	}

 	err = ds.To(device)
 	if err != nil {
 		return
 	}

-	opt, err := nn.DefaultAdamConfig().Build(model.Vs, 0.001)
+	opt, err := my_nn.DefaultAdamConfig().Build(model.Vs, 0.001)
 	if err != nil {
 		return
 	}

 	for epoch := 0; epoch < EPOCH_PER_RUN; epoch++ {

-		var trainIter *torch.Iter2
-		trainIter, err = ds.TrainIter(64)
-		if err != nil {
-			return
-		}
-		// trainIter.ToDevice(device)
+		var trainIter *torch.Iter2
+		trainIter, err = ds.TrainIter(32)
+		if err != nil {
+			return
+		}
+		// trainIter.ToDevice(device)

 		log.Info("epoch", "epoch", epoch)
@@ -184,19 +191,49 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 			continue
 		}

-		pred := model.ForwardT(item.Data, true)
-
-		// Calculate loss
-		loss, err = pred.BinaryCrossEntropyWithLogits(item.Label, &torch.Tensor{}, &torch.Tensor{}, 1, false)
+		data := item.Data
+		data, err = data.ToDevice(device, gotch.Float, false, true, false)
+		if err != nil {
+			return
+		}
+		data, err = data.SetRequiresGrad(true, true)
+		if err != nil {
+			return
+		}
+		err = data.RetainGrad(false)
+		if err != nil {
+			return
+		}
+
+		pred := model.ForwardT(data, true)
+		pred, err = pred.SetRequiresGrad(true, true)
+		if err != nil {
+			return
+		}
+		pred.RetainGrad(false)
+
+		label := item.Label
+		label, err = label.ToDevice(device, gotch.Float, false, true, false)
+		if err != nil {
+			return
+		}
+		label, err = label.SetRequiresGrad(true, true)
+		if err != nil {
+			return
+		}
+		label.RetainGrad(false)
+
+		// Calculate loss
+		loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 1, false)
+		if err != nil {
+			return
+		}
+		loss, err = loss.SetRequiresGrad(true, false)
+		if err != nil {
+			return
+		}
 		if err != nil {
 			return
 		}

 		err = opt.ZeroGrad()
 		if err != nil {
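
The bulk of this hunk applies the same three-step preparation to item.Data and item.Label before the forward pass: move the tensor to the device as float32, enable gradient tracking, and retain the gradient. A standalone sketch of that sequence for a single tensor, mirroring the exact calls the hunk makes and assuming the fork keeps gotch's generated tensor signatures:

package main

import (
	"log"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/ts"
)

func main() {
	device := gotch.CudaIfAvailable()

	x, err := ts.OfSlice([]float64{1, 2, 3})
	if err != nil {
		log.Fatal(err)
	}

	// 1) Move to the training device as float32, as the hunk does for
	//    both the batch data and the labels.
	x, err = x.ToDevice(device, gotch.Float, false, true, false)
	if err != nil {
		log.Fatal(err)
	}
	// 2) Mark the tensor as requiring grad so autograd tracks it.
	x, err = x.SetRequiresGrad(true, true)
	if err != nil {
		log.Fatal(err)
	}
	// 3) Retain the gradient so it is still readable after backward.
	if err = x.RetainGrad(false); err != nil {
		log.Fatal(err)
	}
}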
@@ -213,11 +250,32 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 			return
 		}

+		vars := model.Vs.Variables()
+
+		for k, v := range vars {
+			var grad *torch.Tensor
+			grad, err = v.Grad(false)
+			if err != nil {
+				return
+			}
+			grad, err = grad.Abs(false)
+			if err != nil {
+				return
+			}
+			grad, err = grad.Max(false)
+			if err != nil {
+				return
+			}
+			log.Info("[grad check]", "k", k, "grad", grad.Float64Values())
+		}
+
 		trainLoss = loss.Float64Values()[0]

 		// Calculate accuracy
-		var p_pred, p_labels *torch.Tensor
+		/*var p_pred, p_labels *torch.Tensor
 		p_pred, err = pred.Argmax([]int64{1}, true, false)
 		if err != nil {
 			return
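
The loop added above is a debugging probe: after the backward pass it prints, for every trainable variable, the largest absolute gradient entry, so a silently detached graph (all zeros) shows up immediately. The same probe as a helper, written against upstream sugarme/gotch:

package main

import (
	"log"

	"github.com/sugarme/gotch/nn"
)

// gradMaxAbs logs max(|grad|) for each variable in the store, matching
// the "[grad check]" loop in the hunk above.
func gradMaxAbs(vs *nn.VarStore) error {
	for name, v := range vs.Variables() {
		grad, err := v.Grad(false) // gradient written by the last backward pass
		if err != nil {
			return err
		}
		grad, err = grad.Abs(false)
		if err != nil {
			return err
		}
		grad, err = grad.Max(false)
		if err != nil {
			return err
		}
		// All-zero maxima across every variable are the classic sign of
		// a detached graph.
		log.Printf("[grad check] %s max|grad| = %v", name, grad.Float64Values())
	}
	return nil
}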
@@ -235,9 +293,13 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 			if floats[i] == floats_labels[i] {
 				trainCorrect += 1
 			}
-		}
+		} */
+
+		// panic("fornow")
 	}

 	//v := []float64{}
 	log.Info("model training epoch done loss", "loss", trainLoss, "correct", trainCorrect, "out", ds.TrainImagesSize, "accuracy", trainCorrect/float64(ds.TrainImagesSize))

 	/*correct := int64(0)
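
The block this commit comments out computed training accuracy by comparing the per-row argmax of predictions and labels. For reference, the same computation as a self-contained helper, mirroring the disabled calls and names:

package main

import "github.com/sugarme/gotch/ts"

// countCorrect reproduces the commented-out accuracy check: argmax over
// dim 1 for predictions and labels, then count the rows that agree.
func countCorrect(pred, labels *ts.Tensor) (int64, error) {
	p_pred, err := pred.Argmax([]int64{1}, true, false)
	if err != nil {
		return 0, err
	}
	p_labels, err := labels.Argmax([]int64{1}, true, false)
	if err != nil {
		return 0, err
	}
	floats := p_pred.Float64Values()
	floats_labels := p_labels.Float64Values()

	var correct int64
	for i := range floats {
		if floats[i] == floats_labels[i] {
			correct++
		}
	}
	return correct, nil
}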