fyp/test.go

121 lines
2.4 KiB
Go
Raw Permalink Normal View History

2024-04-22 00:09:07 +01:00
package main
import (
"git.andr3h3nriqu3s.com/andr3/gotch"
dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch"
2024-04-23 11:54:30 +01:00
//my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn"
2024-04-22 00:09:07 +01:00
torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"
"github.com/charmbracelet/log"
)
2024-04-23 11:54:30 +01:00
func main() {
2024-04-22 00:09:07 +01:00
log.Info("Hello world")
m := train.BuildModel([]*dbtypes.Layer{
&dbtypes.Layer{
LayerType: dbtypes.LAYER_INPUT,
2024-04-23 11:54:30 +01:00
Shape: "[ 3, 28, 28 ]",
2024-04-22 00:09:07 +01:00
},
&dbtypes.Layer{
LayerType: dbtypes.LAYER_FLATTEN,
},
&dbtypes.Layer{
LayerType: dbtypes.LAYER_DENSE,
2024-04-23 11:54:30 +01:00
Shape: "[ 27 ]",
2024-04-22 00:09:07 +01:00
},
2024-04-23 00:14:35 +01:00
&dbtypes.Layer{
LayerType: dbtypes.LAYER_DENSE,
2024-04-23 11:54:30 +01:00
Shape: "[ 18 ]",
2024-04-23 00:14:35 +01:00
},
2024-04-23 11:54:30 +01:00
// &dbtypes.Layer{
// LayerType: dbtypes.LAYER_DENSE,
// Shape: "[ 9 ]",
// },
2024-04-22 00:09:07 +01:00
}, 0, true)
2024-04-23 11:54:30 +01:00
//var err error
2024-04-22 00:09:07 +01:00
d := gotch.CudaIfAvailable()
2024-04-23 11:54:30 +01:00
log.Info("device", "d", d)
m.To(d)
var count = 0
// vars1 := m.Vs.Variables()
//
// for k, v := range vars1 {
// ones := torch.MustOnes(v.MustSize(), gotch.Float, d)
// v := ones.MustSetRequiresGrad(true, false)
// v.MustDrop()
// ones.RetainGrad(false)
//
// m.Vs.UpdateVarTensor(k, ones, true)
// m.Refresh()
// }
//
// opt, err := my_nn.DefaultAdamConfig().Build(m.Vs, 0.001)
// if err != nil {
// return
// }
2024-04-22 00:09:07 +01:00
2024-04-23 11:54:30 +01:00
log.Info("start")
2024-04-22 00:09:07 +01:00
2024-04-23 11:54:30 +01:00
for count < 100 {
2024-04-22 00:09:07 +01:00
2024-04-23 11:54:30 +01:00
ones := torch.MustOnes([]int64{1, 3, 28, 28}, gotch.Float, d)
// ones = ones.MustSetRequiresGrad(true, true)
// ones.RetainGrad(false)
2024-04-22 00:09:07 +01:00
2024-04-23 11:54:30 +01:00
res := m.ForwardT(ones, true)
//res = res.MustSetRequiresGrad(true, true)
//res.RetainGrad(false)
2024-04-22 00:09:07 +01:00
2024-04-23 11:54:30 +01:00
outs := torch.MustZeros([]int64{1, 18}, gotch.Float, d)
2024-04-22 00:09:07 +01:00
2024-04-23 11:54:30 +01:00
loss, err := res.BinaryCrossEntropyWithLogits(outs, &torch.Tensor{}, &torch.Tensor{}, 2, false)
if err != nil {
log.Fatal(err)
}
// loss = loss.MustSetRequiresGrad(true, true)
2024-04-22 00:09:07 +01:00
//opt.ZeroGrad()
2024-04-22 00:09:07 +01:00
2024-04-23 11:54:30 +01:00
log.Info("loss", "loss", loss.Float64Values())
2024-04-22 00:09:07 +01:00
2024-04-23 11:54:30 +01:00
loss.MustBackward()
2024-04-22 00:09:07 +01:00
//opt.Step()
2024-04-22 00:09:07 +01:00
2024-04-23 11:54:30 +01:00
// log.Info(mean.MustGrad(false).Float64Values())
//ones_grad = ones.MustGrad(true).MustMax(true).Float64Values()[0]
// log.Info(res.MustGrad(true).MustMax(true).Float64Values())
2024-04-22 00:09:07 +01:00
// log.Info(ones_grad)
2024-04-22 00:09:07 +01:00
2024-04-23 11:54:30 +01:00
vars := m.Vs.Variables()
2024-04-22 00:09:07 +01:00
2024-04-23 11:54:30 +01:00
for k, v := range vars {
log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(true).Float64Values())
}
2024-04-22 00:09:07 +01:00
2024-04-23 11:54:30 +01:00
m.Debug()
2024-04-22 00:09:07 +01:00
2024-04-23 11:54:30 +01:00
outs.MustDrop()
2024-04-22 00:09:07 +01:00
2024-04-23 11:54:30 +01:00
count += 1
2024-04-22 00:09:07 +01:00
2024-04-23 11:54:30 +01:00
log.Fatal("grad zero")
2024-04-22 00:09:07 +01:00
}
2024-04-23 11:54:30 +01:00
log.Warn("out")
2024-04-22 00:09:07 +01:00
}