diff --git a/test.go b/test.go index 94e7d56..fb317dd 100644 --- a/test.go +++ b/test.go @@ -6,7 +6,7 @@ import ( dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch" - my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn" + //my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn" torch "git.andr3h3nriqu3s.com/andr3/gotch/ts" "github.com/charmbracelet/log" @@ -46,25 +46,24 @@ func main() { m.To(d) - var ones_grad float64 = 0 var count = 0 - vars1 := m.Vs.Variables() - - for k, v := range vars1 { - ones := torch.MustOnes(v.MustSize(), gotch.Float, d) - v := ones.MustSetRequiresGrad(true, false) - v.MustDrop() - ones.RetainGrad(false) - - m.Vs.UpdateVarTensor(k, ones, true) - m.Refresh() - } - - opt, err := my_nn.DefaultAdamConfig().Build(m.Vs, 0.001) - if err != nil { - return - } +// vars1 := m.Vs.Variables() +// +// for k, v := range vars1 { +// ones := torch.MustOnes(v.MustSize(), gotch.Float, d) +// v := ones.MustSetRequiresGrad(true, false) +// v.MustDrop() +// ones.RetainGrad(false) +// +// m.Vs.UpdateVarTensor(k, ones, true) +// m.Refresh() +// } +// +// opt, err := my_nn.DefaultAdamConfig().Build(m.Vs, 0.001) +// if err != nil { +// return +// } log.Info("start") @@ -75,8 +74,8 @@ func main() { // ones.RetainGrad(false) res := m.ForwardT(ones, true) - res = res.MustSetRequiresGrad(true, true) - res.RetainGrad(false) + //res = res.MustSetRequiresGrad(true, true) + //res.RetainGrad(false) outs := torch.MustZeros([]int64{1, 18}, gotch.Float, d) @@ -84,21 +83,21 @@ func main() { if err != nil { log.Fatal(err) } - loss = loss.MustSetRequiresGrad(true, true) + // loss = loss.MustSetRequiresGrad(true, true) - opt.ZeroGrad() + //opt.ZeroGrad() log.Info("loss", "loss", loss.Float64Values()) loss.MustBackward() - opt.Step() + //opt.Step() // log.Info(mean.MustGrad(false).Float64Values()) //ones_grad = ones.MustGrad(true).MustMax(true).Float64Values()[0] - log.Info(res.MustGrad(true).MustMax(true).Float64Values()) + // log.Info(res.MustGrad(true).MustMax(true).Float64Values()) - log.Info(ones_grad) +// log.Info(ones_grad) vars := m.Vs.Variables()