From 527b57a1112abe24f84a2597a4f73d537c0d34f6 Mon Sep 17 00:00:00 2001
From: Andre Henriques
Date: Tue, 23 Apr 2024 11:54:30 +0100
Subject: [PATCH] More work on the torch version

---
 logic/models/train/torch/torch.go  |   1 +
 logic/models/train/train_normal.go |  85 ++++++++++----------
 main.go                            |   2 +-
 test.go                            | 123 ++++++++++++++++-------------
 4 files changed, 109 insertions(+), 102 deletions(-)

diff --git a/logic/models/train/torch/torch.go b/logic/models/train/torch/torch.go
index d0f591e..0ac857b 100644
--- a/logic/models/train/torch/torch.go
+++ b/logic/models/train/torch/torch.go
@@ -27,6 +27,7 @@ func (n *ContainerModel) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
 	}
 
 	if len(n.Layers) == 1 {
+		log.Info("here")
 		return n.Layers[0].ForwardT(x, train)
 	}
 
diff --git a/logic/models/train/train_normal.go b/logic/models/train/train_normal.go
index ddc7217..b50fc79 100644
--- a/logic/models/train/train_normal.go
+++ b/logic/models/train/train_normal.go
@@ -182,53 +182,54 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 			}
 
 			data := item.Data
-			data, err = data.ToDevice(device, gotch.Float, true, true, false)
+			data, err = data.ToDevice(device, gotch.Float, false, true, false)
 			if err != nil {
 				return
 			}
 
-			data, err = data.SetRequiresGrad(true, true)
+			var size []int64
+			size, err = data.Size()
 			if err != nil {
 				return
 			}
+
+			var zeros *torch.Tensor
+			zeros, err = torch.Zeros(size, gotch.Float, device)
+			if err != nil {
+				return
+			}
+
+			data, err = zeros.Add(data, true)
+			if err != nil {
+				return
+			}
+
+			log.Info("\n\nhere 1, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
+
+			data, err = data.SetRequiresGrad(true, false)
+			if err != nil {
+				return
+			}
+
+			log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
+
 			err = data.RetainGrad(false)
 			if err != nil {
 				return
 			}
 
-			var size []int64
-			size, err = data.Size()
-			if err != nil {
-				return
-			}
+			log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
 
-			var ones *torch.Tensor
-			ones, err = torch.Ones(size, gotch.Float, device)
-			if err != nil {
-				return
-			}
-
-			ones, err = ones.SetRequiresGrad(true, true)
-			if err != nil {
-				return
-			}
-
-			err = ones.RetainGrad(false)
-			if err != nil {
-				return
-			}
-
-			//pred := model.ForwardT(data, true)
-			pred := model.ForwardT(ones, true)
+			pred := model.ForwardT(data, true)
 			pred, err = pred.SetRequiresGrad(true, true)
 			if err != nil {
 				return
 			}
 			err = pred.RetainGrad(false)
-			if err != nil {
-				return
-			}
+			if err != nil {
+				return
+			}
 
 			label := item.Label
 			label, err = label.ToDevice(device, gotch.Float, false, true, false)
@@ -240,9 +241,9 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 				return
 			}
 			err = label.RetainGrad(false)
-			if err != nil {
-				return
-			}
+			if err != nil {
+				return
+			}
 
 			// Calculate loss
 			loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 2, false)
@@ -253,11 +254,10 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 			if err != nil {
 				return
 			}
-			err = loss.RetainGrad(false)
-			if err != nil {
-				return
-			}
-
+			err = loss.RetainGrad(false)
+			if err != nil {
+				return
+			}
 
 			err = opt.ZeroGrad()
 			if err != nil {
 				return
 			}
@@ -269,20 +269,17 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 				return
 			}
 
-
-			log.Info("pred grad", "pred", pred.MustGrad(false).MustMax(false).Float64Values() )
"ones", ones.MustGrad(false).MustMax(false).Float64Values(), "lol", ones.MustRetainsGrad(false) ) - log.Info("pred grad", "data", data.MustGrad(false).MustMax(false).Float64Values(), "lol", data.MustRetainsGrad(false) ) - log.Info("pred grad", "outs", label.MustGrad(false).MustMax(false).Float64Values() ) + log.Info("pred grad", "pred", pred.MustGrad(false).MustMax(false).Float64Values()) + log.Info("pred grad", "outs", label.MustGrad(false).MustMax(false).Float64Values()) + log.Info("pred grad", "data", data.MustGrad(false).MustMax(false).Float64Values(), "lol", data.MustRetainsGrad(false)) vars := model.Vs.Variables() for k, v := range vars { - log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(false).Float64Values(), "lol", v.MustRetainsGrad(false) ) + log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(false).Float64Values(), "lol", v.MustRetainsGrad(false)) } - model.Debug() - + model.Debug() err = opt.Step() if err != nil { diff --git a/main.go b/main.go index 2433118..4b2237c 100644 --- a/main.go +++ b/main.go @@ -23,7 +23,7 @@ const ( dbname = "aistuff" ) -func main() { +func main_() { psqlInfo := fmt.Sprintf("host=%s port=%d user=%s "+ "password=%s dbname=%s sslmode=disable", diff --git a/test.go b/test.go index 6bfdaa7..94e7d56 100644 --- a/test.go +++ b/test.go @@ -5,108 +5,117 @@ import ( dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch" + my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn" torch "git.andr3h3nriqu3s.com/andr3/gotch/ts" "github.com/charmbracelet/log" ) -func _main() { +func main() { log.Info("Hello world") m := train.BuildModel([]*dbtypes.Layer{ &dbtypes.Layer{ LayerType: dbtypes.LAYER_INPUT, - Shape: "[ 2, 3, 3 ]", + Shape: "[ 3, 28, 28 ]", }, &dbtypes.Layer{ LayerType: dbtypes.LAYER_FLATTEN, }, &dbtypes.Layer{ LayerType: dbtypes.LAYER_DENSE, - Shape: "[ 10 ]", + Shape: "[ 27 ]", }, &dbtypes.Layer{ LayerType: dbtypes.LAYER_DENSE, - Shape: "[ 10 ]", + Shape: "[ 18 ]", }, + // &dbtypes.Layer{ + // LayerType: dbtypes.LAYER_DENSE, + // Shape: "[ 9 ]", + // }, }, 0, true) - var err error + //var err error d := gotch.CudaIfAvailable() - log.Info("device", "d", d) + log.Info("device", "d", d) - m.To(d) + m.To(d) + var ones_grad float64 = 0 + var count = 0 + + vars1 := m.Vs.Variables() + + for k, v := range vars1 { + ones := torch.MustOnes(v.MustSize(), gotch.Float, d) + v := ones.MustSetRequiresGrad(true, false) + v.MustDrop() + ones.RetainGrad(false) + + m.Vs.UpdateVarTensor(k, ones, true) + m.Refresh() + } opt, err := my_nn.DefaultAdamConfig().Build(m.Vs, 0.001) if err != nil { return } - ones := torch.MustOnes([]int64{1, 2, 3, 3}, gotch.Float, d) - ones = ones.MustSetRequiresGrad(true, true) - ones.RetainGrad(false) + log.Info("start") - res := m.ForwardT(ones, true) - res = res.MustSetRequiresGrad(true, true) - res.RetainGrad(false) + for count < 100 { - outs := torch.MustOnes([]int64{1, 10}, gotch.Float, d) - outs = outs.MustSetRequiresGrad(true, true) - outs.RetainsGrad(false) + ones := torch.MustOnes([]int64{1, 3, 28, 28}, gotch.Float, d) + // ones = ones.MustSetRequiresGrad(true, true) + // ones.RetainGrad(false) + res := m.ForwardT(ones, true) + res = res.MustSetRequiresGrad(true, true) + res.RetainGrad(false) - loss, err := res.BinaryCrossEntropyWithLogits(outs, &torch.Tensor{}, &torch.Tensor{}, 1, false) - if err != nil { - return - } - loss = loss.MustSetRequiresGrad(true, false) + outs := torch.MustZeros([]int64{1, 18}, gotch.Float, d) - 
-	opt.ZeroGrad()
-
-
-	log.Info("loss", "loss", loss.Float64Values())
-
-	loss.MustBackward()
-
-
-	opt.Step()
-
-	// log.Info(mean.MustGrad(false).Float64Values())
-	log.Info(res.MustGrad(false).Float64Values())
-	log.Info(ones.MustGrad(false).Float64Values())
-	log.Info(outs.MustGrad(false).Float64Values())
-
-	vars := m.Vs.Variables()
-
-	for k, v := range vars {
-
-		log.Info("[grad check]", "k", k)
-
-		var grad *torch.Tensor
-		grad, err = v.Grad(false)
+		loss, err := res.BinaryCrossEntropyWithLogits(outs, &torch.Tensor{}, &torch.Tensor{}, 2, false)
 		if err != nil {
-			log.Error(err)
-			return
+			log.Fatal(err)
+		}
+		loss = loss.MustSetRequiresGrad(true, true)
+
+		opt.ZeroGrad()
+
+		log.Info("loss", "loss", loss.Float64Values())
+
+		loss.MustBackward()
+
+		opt.Step()
+
+		// log.Info(mean.MustGrad(false).Float64Values())
+		//ones_grad = ones.MustGrad(true).MustMax(true).Float64Values()[0]
+		log.Info(res.MustGrad(true).MustMax(true).Float64Values())
+
+		log.Info(ones_grad)
+
+		vars := m.Vs.Variables()
+
+		for k, v := range vars {
+			log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(true).Float64Values())
 		}
 
-		grad, err = grad.Abs(false)
-		if err != nil {
-			log.Error(err)
-			return
-		}
+		m.Debug()
 
-		grad, err = grad.Max(false)
-		if err != nil {
-			log.Error(err)
-			return
-		}
+		outs.MustDrop()
+
+		count += 1
+
+		log.Fatal("grad zero")
 
-		log.Info("[grad check]", "k", k, "grad", grad.Float64Values())
 	}
+
+	log.Warn("out")
+
 }
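
The loop in the new test.go is, at its core, a gradient-flow probe: run one forward and backward pass, then confirm that the gradients reaching the prediction and the model variables are not all zero. A condensed sketch of that probe follows. It reuses only calls that already appear in this diff (ForwardT, MustSetRequiresGrad, RetainGrad, BinaryCrossEntropyWithLogits, ZeroGrad, MustBackward, Step, MustGrad, MustMax, Float64Values); the forwarder and stepper interfaces, the vars map element type, and the gotch root import path are assumptions standing in for the fork's ContainerModel, my_nn optimizer, and VarStore types, so treat it as illustrative rather than as the project's API.

// gradcheck.go, meant to sit alongside test.go in package main.
package main

import (
	"git.andr3h3nriqu3s.com/andr3/gotch" // assumed root import path of the gotch fork
	torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"

	"github.com/charmbracelet/log"
)

// forwarder is the slice of the project's ContainerModel used here
// (signature copied from torch.go in this patch).
type forwarder interface {
	ForwardT(x *torch.Tensor, train bool) *torch.Tensor
}

// stepper is the slice of the my_nn optimizer used here; ZeroGrad and Step
// return errors, as in train_normal.go.
type stepper interface {
	ZeroGrad() error
	Step() error
}

// gradCheckStep runs one forward/backward/step pass on a dummy batch and logs
// the largest gradient value seen on the prediction and on every trainable
// variable. If every value printed is 0, the graph is detached somewhere
// between the input and the loss. The vars map element type is an assumption
// about what m.Vs.Variables() returns in this fork.
func gradCheckStep(m forwarder, opt stepper, vars map[string]*torch.Tensor, device gotch.Device) error {
	// Dummy input and target, matching the layer shapes built in test.go.
	input := torch.MustOnes([]int64{1, 3, 28, 28}, gotch.Float, device)
	target := torch.MustZeros([]int64{1, 18}, gotch.Float, device)

	// Keep the gradient on the prediction so it can be read after Backward.
	pred := m.ForwardT(input, true)
	pred = pred.MustSetRequiresGrad(true, true)
	pred.RetainGrad(false)

	loss, err := pred.BinaryCrossEntropyWithLogits(target, &torch.Tensor{}, &torch.Tensor{}, 2, false)
	if err != nil {
		return err
	}

	if err := opt.ZeroGrad(); err != nil {
		return err
	}
	loss.MustBackward()
	if err := opt.Step(); err != nil {
		return err
	}

	log.Info("loss", "loss", loss.Float64Values())
	log.Info("pred grad", "max", pred.MustGrad(false).MustMax(false).Float64Values())
	for k, v := range vars {
		log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(false).Float64Values())
	}

	target.MustDrop()
	return nil
}

Called from the main in test.go, this would replace the inlined loop body with something like gradCheckStep(m, opt, m.Vs.Variables(), d), so the "grad zero" case shows up as one line of zeros per variable instead of being spread across ad hoc log.Info calls.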