feat: more work done on the torch branch

This commit is contained in:
Andre Henriques 2024-05-02 16:38:29 +01:00
parent 527b57a111
commit 568be78723

49
test.go
View File

@ -6,7 +6,7 @@ import (
dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch" "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch"
my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn" //my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn"
torch "git.andr3h3nriqu3s.com/andr3/gotch/ts" torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"
"github.com/charmbracelet/log" "github.com/charmbracelet/log"
@ -46,25 +46,24 @@ func main() {
m.To(d) m.To(d)
var ones_grad float64 = 0
var count = 0 var count = 0
vars1 := m.Vs.Variables() // vars1 := m.Vs.Variables()
//
for k, v := range vars1 { // for k, v := range vars1 {
ones := torch.MustOnes(v.MustSize(), gotch.Float, d) // ones := torch.MustOnes(v.MustSize(), gotch.Float, d)
v := ones.MustSetRequiresGrad(true, false) // v := ones.MustSetRequiresGrad(true, false)
v.MustDrop() // v.MustDrop()
ones.RetainGrad(false) // ones.RetainGrad(false)
//
m.Vs.UpdateVarTensor(k, ones, true) // m.Vs.UpdateVarTensor(k, ones, true)
m.Refresh() // m.Refresh()
} // }
//
opt, err := my_nn.DefaultAdamConfig().Build(m.Vs, 0.001) // opt, err := my_nn.DefaultAdamConfig().Build(m.Vs, 0.001)
if err != nil { // if err != nil {
return // return
} // }
log.Info("start") log.Info("start")
@ -75,8 +74,8 @@ func main() {
// ones.RetainGrad(false) // ones.RetainGrad(false)
res := m.ForwardT(ones, true) res := m.ForwardT(ones, true)
res = res.MustSetRequiresGrad(true, true) //res = res.MustSetRequiresGrad(true, true)
res.RetainGrad(false) //res.RetainGrad(false)
outs := torch.MustZeros([]int64{1, 18}, gotch.Float, d) outs := torch.MustZeros([]int64{1, 18}, gotch.Float, d)
@ -84,21 +83,21 @@ func main() {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
loss = loss.MustSetRequiresGrad(true, true) // loss = loss.MustSetRequiresGrad(true, true)
opt.ZeroGrad() //opt.ZeroGrad()
log.Info("loss", "loss", loss.Float64Values()) log.Info("loss", "loss", loss.Float64Values())
loss.MustBackward() loss.MustBackward()
opt.Step() //opt.Step()
// log.Info(mean.MustGrad(false).Float64Values()) // log.Info(mean.MustGrad(false).Float64Values())
//ones_grad = ones.MustGrad(true).MustMax(true).Float64Values()[0] //ones_grad = ones.MustGrad(true).MustMax(true).Float64Values()[0]
log.Info(res.MustGrad(true).MustMax(true).Float64Values()) // log.Info(res.MustGrad(true).MustMax(true).Float64Values())
log.Info(ones_grad) // log.Info(ones_grad)
vars := m.Vs.Variables() vars := m.Vs.Variables()