feat: more work done on the torch branch
This commit is contained in:
parent
527b57a111
commit
568be78723
49
test.go
49
test.go
@ -6,7 +6,7 @@ import (
|
|||||||
dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch"
|
"git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch"
|
||||||
|
|
||||||
my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn"
|
//my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn"
|
||||||
|
|
||||||
torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"
|
torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"
|
||||||
"github.com/charmbracelet/log"
|
"github.com/charmbracelet/log"
|
||||||
@ -46,25 +46,24 @@ func main() {
|
|||||||
|
|
||||||
m.To(d)
|
m.To(d)
|
||||||
|
|
||||||
var ones_grad float64 = 0
|
|
||||||
var count = 0
|
var count = 0
|
||||||
|
|
||||||
vars1 := m.Vs.Variables()
|
// vars1 := m.Vs.Variables()
|
||||||
|
//
|
||||||
for k, v := range vars1 {
|
// for k, v := range vars1 {
|
||||||
ones := torch.MustOnes(v.MustSize(), gotch.Float, d)
|
// ones := torch.MustOnes(v.MustSize(), gotch.Float, d)
|
||||||
v := ones.MustSetRequiresGrad(true, false)
|
// v := ones.MustSetRequiresGrad(true, false)
|
||||||
v.MustDrop()
|
// v.MustDrop()
|
||||||
ones.RetainGrad(false)
|
// ones.RetainGrad(false)
|
||||||
|
//
|
||||||
m.Vs.UpdateVarTensor(k, ones, true)
|
// m.Vs.UpdateVarTensor(k, ones, true)
|
||||||
m.Refresh()
|
// m.Refresh()
|
||||||
}
|
// }
|
||||||
|
//
|
||||||
opt, err := my_nn.DefaultAdamConfig().Build(m.Vs, 0.001)
|
// opt, err := my_nn.DefaultAdamConfig().Build(m.Vs, 0.001)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return
|
// return
|
||||||
}
|
// }
|
||||||
|
|
||||||
log.Info("start")
|
log.Info("start")
|
||||||
|
|
||||||
@ -75,8 +74,8 @@ func main() {
|
|||||||
// ones.RetainGrad(false)
|
// ones.RetainGrad(false)
|
||||||
|
|
||||||
res := m.ForwardT(ones, true)
|
res := m.ForwardT(ones, true)
|
||||||
res = res.MustSetRequiresGrad(true, true)
|
//res = res.MustSetRequiresGrad(true, true)
|
||||||
res.RetainGrad(false)
|
//res.RetainGrad(false)
|
||||||
|
|
||||||
outs := torch.MustZeros([]int64{1, 18}, gotch.Float, d)
|
outs := torch.MustZeros([]int64{1, 18}, gotch.Float, d)
|
||||||
|
|
||||||
@ -84,21 +83,21 @@ func main() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
loss = loss.MustSetRequiresGrad(true, true)
|
// loss = loss.MustSetRequiresGrad(true, true)
|
||||||
|
|
||||||
opt.ZeroGrad()
|
//opt.ZeroGrad()
|
||||||
|
|
||||||
log.Info("loss", "loss", loss.Float64Values())
|
log.Info("loss", "loss", loss.Float64Values())
|
||||||
|
|
||||||
loss.MustBackward()
|
loss.MustBackward()
|
||||||
|
|
||||||
opt.Step()
|
//opt.Step()
|
||||||
|
|
||||||
// log.Info(mean.MustGrad(false).Float64Values())
|
// log.Info(mean.MustGrad(false).Float64Values())
|
||||||
//ones_grad = ones.MustGrad(true).MustMax(true).Float64Values()[0]
|
//ones_grad = ones.MustGrad(true).MustMax(true).Float64Values()[0]
|
||||||
log.Info(res.MustGrad(true).MustMax(true).Float64Values())
|
// log.Info(res.MustGrad(true).MustMax(true).Float64Values())
|
||||||
|
|
||||||
log.Info(ones_grad)
|
// log.Info(ones_grad)
|
||||||
|
|
||||||
vars := m.Vs.Variables()
|
vars := m.Vs.Variables()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user