// Debug harness: builds a small dense model with the FYP train package,
// runs a forward/backward pass on dummy data, and logs the maximum
// gradient of every variable to track down gradients that stay at zero.
package main

import (
	"git.andr3h3nriqu3s.com/andr3/gotch"

	dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
	// Aliased explicitly: the import path ends in /torch but the package is
	// used as train, and an unnamed import would be easy to confuse with
	// the gotch/ts alias below.
	train "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch"
	//my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn"
	torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"

	"github.com/charmbracelet/log"
)

func main() {
	log.Info("Hello world")

	// Build the model: 3x28x28 input, flattened, then two dense layers.
	m := train.BuildModel([]*dbtypes.Layer{
		{
			LayerType: dbtypes.LAYER_INPUT,
			Shape:     "[ 3, 28, 28 ]",
		},
		{
			LayerType: dbtypes.LAYER_FLATTEN,
		},
		{
			LayerType: dbtypes.LAYER_DENSE,
			Shape:     "[ 27 ]",
		},
		{
			LayerType: dbtypes.LAYER_DENSE,
			Shape:     "[ 18 ]",
		},
		// {
		// 	LayerType: dbtypes.LAYER_DENSE,
		// 	Shape:     "[ 9 ]",
		// },
	}, 0, true)

	//var err error

	// Use the GPU when one is available, otherwise fall back to the CPU.
	d := gotch.CudaIfAvailable()
	log.Info("device", "d", d)
	m.To(d)

	count := 0

	// Earlier experiment: overwrite every variable with a fresh ones tensor
	// and rebuild an Adam optimizer on top of the var store.
	// vars1 := m.Vs.Variables()
	//
	// for k, v := range vars1 {
	// 	ones := torch.MustOnes(v.MustSize(), gotch.Float, d)
	// 	v := ones.MustSetRequiresGrad(true, false)
	// 	v.MustDrop()
	// 	ones.RetainGrad(false)
	//
	// 	m.Vs.UpdateVarTensor(k, ones, true)
	// 	m.Refresh()
	// }
	//
	// opt, err := my_nn.DefaultAdamConfig().Build(m.Vs, 0.001)
	// if err != nil {
	// 	return
	// }

	log.Info("start")

	for count < 100 {
		// Dummy input batch: a single all-ones 3x28x28 image.
		ones := torch.MustOnes([]int64{1, 3, 28, 28}, gotch.Float, d)
		// ones = ones.MustSetRequiresGrad(true, true)
		// ones.RetainGrad(false)

		// Forward pass in training mode.
		res := m.ForwardT(ones, true)
		//res = res.MustSetRequiresGrad(true, true)
		//res.RetainGrad(false)

		// Dummy target: all zeros, matching the 18-unit output layer.
		outs := torch.MustZeros([]int64{1, 18}, gotch.Float, d)

		// Reduction 2 = sum; empty tensors mean no element or class weights.
		loss, err := res.BinaryCrossEntropyWithLogits(outs, &torch.Tensor{}, &torch.Tensor{}, 2, false)
		if err != nil {
			log.Fatal(err)
		}

		// loss = loss.MustSetRequiresGrad(true, true)

		//opt.ZeroGrad()

		log.Info("loss", "loss", loss.Float64Values())

		loss.MustBackward()

		//opt.Step()

		// log.Info(mean.MustGrad(false).Float64Values())
		//ones_grad = ones.MustGrad(true).MustMax(true).Float64Values()[0]
		// log.Info(res.MustGrad(true).MustMax(true).Float64Values())
		// log.Info(ones_grad)

		// Log the largest gradient of each variable; an all-zero value here
		// means the backward pass is not reaching that variable.
		vars := m.Vs.Variables()
		for k, v := range vars {
			log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(true).Float64Values())
		}

		m.Debug()

		outs.MustDrop()

		count++

		// Exit after the first iteration while debugging the zero-grad issue.
		log.Fatal("grad zero")
	}

	log.Warn("out")
}