// fyp/test.go (113 lines, 2.2 KiB, Go)
// Captured from a repository web view ("Raw / Normal View / History").
// Blame timestamp for the lines below: 2024-04-22 00:09:07 +01:00

package main
import (
"git.andr3h3nriqu3s.com/andr3/gotch"
dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch"
my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn"
torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"
"github.com/charmbracelet/log"
)
// Blame timestamp for the lines below: 2024-04-23 00:14:35 +01:00

func _main() {
2024-04-22 00:09:07 +01:00
log.Info("Hello world")
m := train.BuildModel([]*dbtypes.Layer{
&dbtypes.Layer{
LayerType: dbtypes.LAYER_INPUT,
Shape: "[ 2, 3, 3 ]",
},
&dbtypes.Layer{
LayerType: dbtypes.LAYER_FLATTEN,
},
&dbtypes.Layer{
LayerType: dbtypes.LAYER_DENSE,
Shape: "[ 10 ]",
},
2024-04-23 00:14:35 +01:00
&dbtypes.Layer{
LayerType: dbtypes.LAYER_DENSE,
Shape: "[ 10 ]",
},
2024-04-22 00:09:07 +01:00
}, 0, true)
var err error
d := gotch.CudaIfAvailable()
log.Info("device", "d", d)
m.To(d)
opt, err := my_nn.DefaultAdamConfig().Build(m.Vs, 0.001)
if err != nil {
return
}
ones := torch.MustOnes([]int64{1, 2, 3, 3}, gotch.Float, d)
ones = ones.MustSetRequiresGrad(true, true)
ones.RetainGrad(false)
res := m.ForwardT(ones, true)
res = res.MustSetRequiresGrad(true, true)
res.RetainGrad(false)
outs := torch.MustOnes([]int64{1, 10}, gotch.Float, d)
outs = outs.MustSetRequiresGrad(true, true)
outs.RetainsGrad(false)
loss, err := res.BinaryCrossEntropyWithLogits(outs, &torch.Tensor{}, &torch.Tensor{}, 1, false)
if err != nil {
return
}
loss = loss.MustSetRequiresGrad(true, false)
opt.ZeroGrad()
log.Info("loss", "loss", loss.Float64Values())
loss.MustBackward()
opt.Step()
// log.Info(mean.MustGrad(false).Float64Values())
log.Info(res.MustGrad(false).Float64Values())
log.Info(ones.MustGrad(false).Float64Values())
log.Info(outs.MustGrad(false).Float64Values())
vars := m.Vs.Variables()
for k, v := range vars {
log.Info("[grad check]", "k", k)
var grad *torch.Tensor
grad, err = v.Grad(false)
if err != nil {
log.Error(err)
return
}
grad, err = grad.Abs(false)
if err != nil {
log.Error(err)
return
}
grad, err = grad.Max(false)
if err != nil {
log.Error(err)
return
}
log.Info("[grad check]", "k", k, "grad", grad.Float64Values())
}
}