More work done on torch
@@ -16,16 +16,17 @@ import (
 	"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
 	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
 
 	my_torch "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch"
 	modelloader "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/modelloader"
+	my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn"
 	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
 	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
 
+	"git.andr3h3nriqu3s.com/andr3/gotch"
+	torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"
+
 	"github.com/charmbracelet/log"
 	"github.com/goccy/go-json"
-	"github.com/sugarme/gotch"
-	"github.com/sugarme/gotch/nn"
-	torch "github.com/sugarme/gotch/ts"
 )
 
 const EPOCH_PER_RUN = 20
@@ -132,11 +133,12 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 		}
 
 		model = my_torch.BuildModel(layers, 0, true)
 	}
 
 	// TODO Make the runner provide this
-	// device := gotch.CudaIfAvailable()
-	device := gotch.CPU
+	device := gotch.CudaIfAvailable()
+	// device := gotch.CPU
 
 	result_path := path.Join(getDir(), "savedData", m.Id, "defs", def.Id)
 	err = os.MkdirAll(result_path, os.ModePerm)
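
An aside on the device flip above: the commit hard-codes gotch.CudaIfAvailable() where the CPU device used to be, and the TODO says the runner should eventually supply this. A minimal sketch of what that could look like, assuming the andr3/gotch fork keeps sugarme/gotch's Device API; the FORCE_CPU environment variable is hypothetical, not something the repo defines:

    package main

    import (
        "fmt"
        "os"

        "git.andr3h3nriqu3s.com/andr3/gotch"
    )

    // pickDevice is what the runner could hand to trainDefinition instead of
    // the hard-coded choice: a hypothetical env override on top of the same
    // CUDA-with-CPU-fallback call the hunk uses.
    func pickDevice() gotch.Device {
        if os.Getenv("FORCE_CPU") != "" {
            return gotch.CPU
        }
        return gotch.CudaIfAvailable() // falls back to CPU when no GPU is present
    }

    func main() {
        fmt.Printf("training on %v\n", pickDevice())
    }
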
@@ -144,6 +146,16 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 		return
 	}
 
+	/* opt1, err := my_nn.DefaultAdamConfig().Build(model.Vs, 0.001)
+	if err != nil {
+		return
+	}
+
+	opt1.Debug() */
+
+	//log.Info("\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n")
+
+	// TODO remove this
 	model.To(device)
 	defer model.To(gotch.CPU)
 
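
The model.To(device) / defer model.To(gotch.CPU) pair above is the move-for-training, restore-on-exit pattern: weights live on the GPU for the duration of the function and come back to the CPU before anything downstream saves or reuses them. A sketch of the same pattern as a helper; the Mover interface is hypothetical and only names the one method the sketch needs from the diff's model value:

    package main

    import "git.andr3h3nriqu3s.com/andr3/gotch"

    // Mover is a hypothetical stand-in for the diff's model type.
    type Mover interface {
        To(device gotch.Device)
    }

    // withDevice moves m to device, runs train, and restores m to the CPU
    // even if train panics — the same guarantee the defer in the hunk gives.
    func withDevice(m Mover, device gotch.Device, train func()) {
        m.To(device)
        defer m.To(gotch.CPU)
        train()
    }

    func main() {}
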
@@ -153,23 +165,18 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 		return
 	}
 
-	err = ds.To(device)
-	if err != nil {
-		return
-	}
-
-	opt, err := nn.DefaultAdamConfig().Build(model.Vs, 0.001)
+	opt, err := my_nn.DefaultAdamConfig().Build(model.Vs, 0.001)
 	if err != nil {
 		return
 	}
 
 	for epoch := 0; epoch < EPOCH_PER_RUN; epoch++ {
-		var trainIter *torch.Iter2
-		trainIter, err = ds.TrainIter(64)
-		if err != nil {
-			return
-		}
-		// trainIter.ToDevice(device)
+		var trainIter *torch.Iter2
+		trainIter, err = ds.TrainIter(32)
+		if err != nil {
+			return
+		}
+		// trainIter.ToDevice(device)
 
 		log.Info("epoch", "epoch", epoch)
 
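
Two things change in this hunk: the optimizer now comes from the local my_nn package instead of upstream gotch/nn, and the dataset-wide ds.To(device) is dropped (batches are moved individually in the next hunk). The training batch also shrinks from 64 to 32, which halves per-step activation memory and doubles the number of optimizer steps per epoch. A throwaway illustration of the step-count arithmetic:

    package main

    import "fmt"

    // stepsPerEpoch is ceil(n/batch): how many optimizer steps one pass over
    // n training images takes at a given batch size.
    func stepsPerEpoch(n, batch int) int {
        return (n + batch - 1) / batch
    }

    func main() {
        fmt.Println(stepsPerEpoch(1000, 64)) // 16
        fmt.Println(stepsPerEpoch(1000, 32)) // 32
    }
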
@@ -184,19 +191,49 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 				continue
 			}
 
-			pred := model.ForwardT(item.Data, true)
-
-			// Calculate loss
-
-			loss, err = pred.BinaryCrossEntropyWithLogits(item.Label, &torch.Tensor{}, &torch.Tensor{}, 1, false)
+			data := item.Data
+			data, err = data.ToDevice(device, gotch.Float, false, true, false)
+			if err != nil {
+				return
+			}
+
+			data, err = data.SetRequiresGrad(true, true)
+			if err != nil {
+				return
+			}
+			err = data.RetainGrad(false)
+			if err != nil {
+				return
+			}
+
+			pred := model.ForwardT(data, true)
+			pred, err = pred.SetRequiresGrad(true, true)
+			if err != nil {
+				return
+			}
+
+			pred.RetainGrad(false)
+
+			label := item.Label
+			label, err = label.ToDevice(device, gotch.Float, false, true, false)
+			if err != nil {
+				return
+			}
+			label, err = label.SetRequiresGrad(true, true)
+			if err != nil {
+				return
+			}
+			label.RetainGrad(false)
+
+			// Calculate loss
+			loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 1, false)
+			if err != nil {
+				return
+			}
+			loss, err = loss.SetRequiresGrad(true, false)
+			if err != nil {
+				return
+			}
 			if err != nil {
 				return
 			}
 
 			err = opt.ZeroGrad()
 			if err != nil {
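
The rewritten loop moves each batch and its labels to the device individually (replacing the removed ds.To) and forces requires_grad on data, pred, and even the loss, with RetainGrad added for good measure — belt-and-braces attempts to keep the autograd graph alive while debugging. The loss itself is unchanged: binary cross-entropy with the sigmoid folded in. For reference, a stand-alone sketch of the math behind BinaryCrossEntropyWithLogits with unit weights and mean reduction, in the numerically stable form max(x,0) - x·y + log(1+e^-|x|):

    package main

    import (
        "fmt"
        "math"
    )

    // bceWithLogits mirrors the math of BCE-with-logits for weight =
    // pos_weight = 1: the sigmoid is folded into the loss so large |x| never
    // overflows, then the per-element losses are averaged.
    func bceWithLogits(logits, labels []float64) float64 {
        sum := 0.0
        for i, x := range logits {
            y := labels[i]
            sum += math.Max(x, 0) - x*y + math.Log1p(math.Exp(-math.Abs(x)))
        }
        return sum / float64(len(logits))
    }

    func main() {
        // A confident correct prediction costs little; a confident wrong one a lot.
        fmt.Println(bceWithLogits([]float64{3.0}, []float64{1})) // ~0.049
        fmt.Println(bceWithLogits([]float64{3.0}, []float64{0})) // ~3.049
    }
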
@@ -213,11 +250,32 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 				return
 			}
 
+			vars := model.Vs.Variables()
+
+			for k, v := range vars {
+				var grad *torch.Tensor
+				grad, err = v.Grad(false)
+				if err != nil {
+					return
+				}
+
+				grad, err = grad.Abs(false)
+				if err != nil {
+					return
+				}
+
+				grad, err = grad.Max(false)
+				if err != nil {
+					return
+				}
+
+				log.Info("[grad check]", "k", k, "grad", grad.Float64Values())
+			}
+
 			trainLoss = loss.Float64Values()[0]
 
 			// Calculate accuracy
 
-			var p_pred, p_labels *torch.Tensor
+			/*var p_pred, p_labels *torch.Tensor
 			p_pred, err = pred.Argmax([]int64{1}, true, false)
 			if err != nil {
 				return
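
The new loop is a gradient sanity check: for every variable in the VarStore it reduces the gradient with Abs then Max and logs the single largest-magnitude entry, so a variable whose gradient never materialized shows up as 0 (and an exploding one as a huge value). The same reduction on a plain slice, for illustration:

    package main

    import (
        "fmt"
        "math"
    )

    // maxAbsGrad mirrors the Abs→Max reduction the loop applies to each
    // variable's gradient tensor, here on a plain slice.
    func maxAbsGrad(grad []float64) float64 {
        m := 0.0
        for _, g := range grad {
            if a := math.Abs(g); a > m {
                m = a
            }
        }
        return m
    }

    func main() {
        fmt.Println(maxAbsGrad([]float64{0.01, -0.5, 0.2})) // 0.5
        fmt.Println(maxAbsGrad([]float64{0, 0, 0}))         // 0 → gradient never reached this variable
    }
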
@@ -235,9 +293,13 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 				if floats[i] == floats_labels[i] {
 					trainCorrect += 1
 				}
-			}
+			} */
+
+			// panic("fornow")
 		}
 
+		//v := []float64{}
+
 		log.Info("model training epoch done loss", "loss", trainLoss, "correct", trainCorrect, "out", ds.TrainImagesSize, "accuracy", trainCorrect/float64(ds.TrainImagesSize))
 
 		/*correct := int64(0)