More work on the torch version
This commit is contained in:
123
test.go
123
test.go
@@ -5,108 +5,117 @@ import (
|
||||
|
||||
dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch"
|
||||
|
||||
my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn"
|
||||
|
||||
torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"
|
||||
"github.com/charmbracelet/log"
|
||||
)
|
||||
|
||||
func _main() {
|
||||
func main() {
|
||||
|
||||
log.Info("Hello world")
|
||||
|
||||
m := train.BuildModel([]*dbtypes.Layer{
|
||||
&dbtypes.Layer{
|
||||
LayerType: dbtypes.LAYER_INPUT,
|
||||
Shape: "[ 2, 3, 3 ]",
|
||||
Shape: "[ 3, 28, 28 ]",
|
||||
},
|
||||
&dbtypes.Layer{
|
||||
LayerType: dbtypes.LAYER_FLATTEN,
|
||||
},
|
||||
&dbtypes.Layer{
|
||||
LayerType: dbtypes.LAYER_DENSE,
|
||||
Shape: "[ 10 ]",
|
||||
Shape: "[ 27 ]",
|
||||
},
|
||||
&dbtypes.Layer{
|
||||
LayerType: dbtypes.LAYER_DENSE,
|
||||
Shape: "[ 10 ]",
|
||||
Shape: "[ 18 ]",
|
||||
},
|
||||
// &dbtypes.Layer{
|
||||
// LayerType: dbtypes.LAYER_DENSE,
|
||||
// Shape: "[ 9 ]",
|
||||
// },
|
||||
}, 0, true)
|
||||
|
||||
var err error
|
||||
//var err error
|
||||
|
||||
d := gotch.CudaIfAvailable()
|
||||
|
||||
log.Info("device", "d", d)
|
||||
log.Info("device", "d", d)
|
||||
|
||||
m.To(d)
|
||||
m.To(d)
|
||||
|
||||
var ones_grad float64 = 0
|
||||
var count = 0
|
||||
|
||||
vars1 := m.Vs.Variables()
|
||||
|
||||
for k, v := range vars1 {
|
||||
ones := torch.MustOnes(v.MustSize(), gotch.Float, d)
|
||||
v := ones.MustSetRequiresGrad(true, false)
|
||||
v.MustDrop()
|
||||
ones.RetainGrad(false)
|
||||
|
||||
m.Vs.UpdateVarTensor(k, ones, true)
|
||||
m.Refresh()
|
||||
}
|
||||
|
||||
opt, err := my_nn.DefaultAdamConfig().Build(m.Vs, 0.001)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ones := torch.MustOnes([]int64{1, 2, 3, 3}, gotch.Float, d)
|
||||
ones = ones.MustSetRequiresGrad(true, true)
|
||||
ones.RetainGrad(false)
|
||||
log.Info("start")
|
||||
|
||||
res := m.ForwardT(ones, true)
|
||||
res = res.MustSetRequiresGrad(true, true)
|
||||
res.RetainGrad(false)
|
||||
for count < 100 {
|
||||
|
||||
outs := torch.MustOnes([]int64{1, 10}, gotch.Float, d)
|
||||
outs = outs.MustSetRequiresGrad(true, true)
|
||||
outs.RetainsGrad(false)
|
||||
ones := torch.MustOnes([]int64{1, 3, 28, 28}, gotch.Float, d)
|
||||
// ones = ones.MustSetRequiresGrad(true, true)
|
||||
// ones.RetainGrad(false)
|
||||
|
||||
res := m.ForwardT(ones, true)
|
||||
res = res.MustSetRequiresGrad(true, true)
|
||||
res.RetainGrad(false)
|
||||
|
||||
loss, err := res.BinaryCrossEntropyWithLogits(outs, &torch.Tensor{}, &torch.Tensor{}, 1, false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
loss = loss.MustSetRequiresGrad(true, false)
|
||||
outs := torch.MustZeros([]int64{1, 18}, gotch.Float, d)
|
||||
|
||||
opt.ZeroGrad()
|
||||
|
||||
|
||||
log.Info("loss", "loss", loss.Float64Values())
|
||||
|
||||
loss.MustBackward()
|
||||
|
||||
|
||||
opt.Step()
|
||||
|
||||
// log.Info(mean.MustGrad(false).Float64Values())
|
||||
log.Info(res.MustGrad(false).Float64Values())
|
||||
log.Info(ones.MustGrad(false).Float64Values())
|
||||
log.Info(outs.MustGrad(false).Float64Values())
|
||||
|
||||
vars := m.Vs.Variables()
|
||||
|
||||
for k, v := range vars {
|
||||
|
||||
log.Info("[grad check]", "k", k)
|
||||
|
||||
var grad *torch.Tensor
|
||||
grad, err = v.Grad(false)
|
||||
loss, err := res.BinaryCrossEntropyWithLogits(outs, &torch.Tensor{}, &torch.Tensor{}, 2, false)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
log.Fatal(err)
|
||||
}
|
||||
loss = loss.MustSetRequiresGrad(true, true)
|
||||
|
||||
opt.ZeroGrad()
|
||||
|
||||
log.Info("loss", "loss", loss.Float64Values())
|
||||
|
||||
loss.MustBackward()
|
||||
|
||||
opt.Step()
|
||||
|
||||
// log.Info(mean.MustGrad(false).Float64Values())
|
||||
//ones_grad = ones.MustGrad(true).MustMax(true).Float64Values()[0]
|
||||
log.Info(res.MustGrad(true).MustMax(true).Float64Values())
|
||||
|
||||
log.Info(ones_grad)
|
||||
|
||||
vars := m.Vs.Variables()
|
||||
|
||||
for k, v := range vars {
|
||||
log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(true).Float64Values())
|
||||
}
|
||||
|
||||
grad, err = grad.Abs(false)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
m.Debug()
|
||||
|
||||
grad, err = grad.Max(false)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
outs.MustDrop()
|
||||
|
||||
count += 1
|
||||
|
||||
log.Fatal("grad zero")
|
||||
|
||||
log.Info("[grad check]", "k", k, "grad", grad.Float64Values())
|
||||
}
|
||||
|
||||
log.Warn("out")
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user