package main import ( "git.andr3h3nriqu3s.com/andr3/gotch" dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch" my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn" torch "git.andr3h3nriqu3s.com/andr3/gotch/ts" "github.com/charmbracelet/log" ) func _main() { log.Info("Hello world") m := train.BuildModel([]*dbtypes.Layer{ &dbtypes.Layer{ LayerType: dbtypes.LAYER_INPUT, Shape: "[ 2, 3, 3 ]", }, &dbtypes.Layer{ LayerType: dbtypes.LAYER_FLATTEN, }, &dbtypes.Layer{ LayerType: dbtypes.LAYER_DENSE, Shape: "[ 10 ]", }, &dbtypes.Layer{ LayerType: dbtypes.LAYER_DENSE, Shape: "[ 10 ]", }, }, 0, true) var err error d := gotch.CudaIfAvailable() log.Info("device", "d", d) m.To(d) opt, err := my_nn.DefaultAdamConfig().Build(m.Vs, 0.001) if err != nil { return } ones := torch.MustOnes([]int64{1, 2, 3, 3}, gotch.Float, d) ones = ones.MustSetRequiresGrad(true, true) ones.RetainGrad(false) res := m.ForwardT(ones, true) res = res.MustSetRequiresGrad(true, true) res.RetainGrad(false) outs := torch.MustOnes([]int64{1, 10}, gotch.Float, d) outs = outs.MustSetRequiresGrad(true, true) outs.RetainsGrad(false) loss, err := res.BinaryCrossEntropyWithLogits(outs, &torch.Tensor{}, &torch.Tensor{}, 1, false) if err != nil { return } loss = loss.MustSetRequiresGrad(true, false) opt.ZeroGrad() log.Info("loss", "loss", loss.Float64Values()) loss.MustBackward() opt.Step() // log.Info(mean.MustGrad(false).Float64Values()) log.Info(res.MustGrad(false).Float64Values()) log.Info(ones.MustGrad(false).Float64Values()) log.Info(outs.MustGrad(false).Float64Values()) vars := m.Vs.Variables() for k, v := range vars { log.Info("[grad check]", "k", k) var grad *torch.Tensor grad, err = v.Grad(false) if err != nil { log.Error(err) return } grad, err = grad.Abs(false) if err != nil { log.Error(err) return } grad, err = grad.Max(false) if err != nil { log.Error(err) return } log.Info("[grad check]", "k", k, "grad", grad.Float64Values()) } }