More work on the torch version

Andre Henriques 2024-04-23 11:54:30 +01:00
parent a4a9ade71f
commit 527b57a111
4 changed files with 109 additions and 102 deletions


@@ -27,6 +27,7 @@ func (n *ContainerModel) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
 	}
 	if len(n.Layers) == 1 {
+		log.Info("here")
 		return n.Layers[0].ForwardT(x, train)
 	}


@@ -182,53 +182,54 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 			}
 			data := item.Data
-			data, err = data.ToDevice(device, gotch.Float, true, true, false)
+			data, err = data.ToDevice(device, gotch.Float, false, true, false)
 			if err != nil {
 				return
 			}
-			data, err = data.SetRequiresGrad(true, true)
+			var size []int64
+			size, err = data.Size()
 			if err != nil {
 				return
 			}
+			var zeros *torch.Tensor
+			zeros, err = torch.Zeros(size, gotch.Float, device)
+			if err != nil {
+				return
+			}
+			data, err = zeros.Add(data, true)
+			if err != nil {
+				return
+			}
+			log.Info("\n\nhere 1, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
+			data, err = data.SetRequiresGrad(true, false)
+			if err != nil {
+				return
+			}
+			log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
 			err = data.RetainGrad(false)
 			if err != nil {
 				return
 			}
-			var size []int64
-			size, err = data.Size()
-			if err != nil {
-				return
-			}
-			var ones *torch.Tensor
-			ones, err = torch.Ones(size, gotch.Float, device)
-			if err != nil {
-				return
-			}
-			ones, err = ones.SetRequiresGrad(true, true)
-			if err != nil {
-				return
-			}
-			err = ones.RetainGrad(false)
-			if err != nil {
-				return
-			}
-			//pred := model.ForwardT(data, true)
-			pred := model.ForwardT(ones, true)
+			log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
+
+			pred := model.ForwardT(data, true)
 
 			pred, err = pred.SetRequiresGrad(true, true)
 			if err != nil {
 				return
 			}
 			err = pred.RetainGrad(false)
 			if err != nil {
 				return
 			}
 
 			label := item.Label
 			label, err = label.ToDevice(device, gotch.Float, false, true, false)
@@ -240,9 +241,9 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 				return
 			}
 			err = label.RetainGrad(false)
 			if err != nil {
 				return
 			}
 
 			// Calculate loss
 			loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 2, false)
@@ -253,11 +254,10 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 			if err != nil {
 				return
 			}
 			err = loss.RetainGrad(false)
 			if err != nil {
 				return
 			}
 
 			err = opt.ZeroGrad()
 			if err != nil {
@@ -269,20 +269,17 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 				return
 			}
 
-			log.Info("pred grad", "pred", pred.MustGrad(false).MustMax(false).Float64Values() )
-			log.Info("pred grad", "ones", ones.MustGrad(false).MustMax(false).Float64Values(), "lol", ones.MustRetainsGrad(false) )
-			log.Info("pred grad", "data", data.MustGrad(false).MustMax(false).Float64Values(), "lol", data.MustRetainsGrad(false) )
-			log.Info("pred grad", "outs", label.MustGrad(false).MustMax(false).Float64Values() )
+			log.Info("pred grad", "pred", pred.MustGrad(false).MustMax(false).Float64Values())
+			log.Info("pred grad", "outs", label.MustGrad(false).MustMax(false).Float64Values())
+			log.Info("pred grad", "data", data.MustGrad(false).MustMax(false).Float64Values(), "lol", data.MustRetainsGrad(false))
 
 			vars := model.Vs.Variables()
 
 			for k, v := range vars {
-				log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(false).Float64Values(), "lol", v.MustRetainsGrad(false) )
+				log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(false).Float64Values(), "lol", v.MustRetainsGrad(false))
 			}
 
 			model.Debug()
 
 			err = opt.Step()
 			if err != nil {
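The hunks above settle on a pattern for checking whether gradients reach the input: a tensor that comes straight off the data loader is not a graph leaf that records a grad, so the new code adds it to a fresh zero tensor, flips requires_grad, and retains the grad so that data.MustGrad() is populated after the backward pass. A standalone sketch of that pattern, reusing only the gotch calls visible in this diff (the helper name makeGradLeaf, the demo main, and the root gotch import path are assumptions, not code from this repository):

// Sketch only: assumes the fork's API matches how it is used in this commit.
package main

import (
	"git.andr3h3nriqu3s.com/andr3/gotch" // assumed root package for gotch.Float / gotch.Device
	torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"

	"github.com/charmbracelet/log"
)

// makeGradLeaf (hypothetical helper) wraps an input tensor so autograd will
// record a gradient for it: add it to a zero tensor of the same size, set
// requires_grad, and retain the grad so it survives Backward.
func makeGradLeaf(data *torch.Tensor, device gotch.Device) (*torch.Tensor, error) {
	size, err := data.Size()
	if err != nil {
		return nil, err
	}
	zeros, err := torch.Zeros(size, gotch.Float, device)
	if err != nil {
		return nil, err
	}
	// data = zeros + data makes the result a node of the autograd graph.
	data, err = zeros.Add(data, true)
	if err != nil {
		return nil, err
	}
	data, err = data.SetRequiresGrad(true, false)
	if err != nil {
		return nil, err
	}
	if err = data.RetainGrad(false); err != nil {
		return nil, err
	}
	log.Info("grad leaf", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
	return data, nil
}

func main() {
	d := gotch.CudaIfAvailable()
	x := torch.MustOnes([]int64{1, 3, 28, 28}, gotch.Float, d)
	leaf, err := makeGradLeaf(x, d)
	if err != nil {
		log.Fatal(err)
	}
	// After a forward pass and Backward on some loss, leaf.MustGrad(false)
	// should hold a gradient instead of an undefined tensor.
	_ = leaf
}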


@@ -23,7 +23,7 @@ const (
 	dbname = "aistuff"
 )
 
-func main() {
+func main_() {
 	psqlInfo := fmt.Sprintf("host=%s port=%d user=%s "+
 		"password=%s dbname=%s sslmode=disable",

test.go

@@ -5,108 +5,117 @@ import (
 	dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
 	"git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch"
 	my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn"
 	torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"
 	"github.com/charmbracelet/log"
 )
 
-func _main() {
+func main() {
 	log.Info("Hello world")
 
 	m := train.BuildModel([]*dbtypes.Layer{
 		&dbtypes.Layer{
 			LayerType: dbtypes.LAYER_INPUT,
-			Shape:     "[ 2, 3, 3 ]",
+			Shape:     "[ 3, 28, 28 ]",
 		},
 		&dbtypes.Layer{
 			LayerType: dbtypes.LAYER_FLATTEN,
 		},
 		&dbtypes.Layer{
 			LayerType: dbtypes.LAYER_DENSE,
-			Shape:     "[ 10 ]",
+			Shape:     "[ 27 ]",
 		},
 		&dbtypes.Layer{
 			LayerType: dbtypes.LAYER_DENSE,
-			Shape:     "[ 10 ]",
+			Shape:     "[ 18 ]",
 		},
+		// &dbtypes.Layer{
+		// 	LayerType: dbtypes.LAYER_DENSE,
+		// 	Shape:     "[ 9 ]",
+		// },
 	}, 0, true)
 
-	var err error
+	//var err error
 
 	d := gotch.CudaIfAvailable()
 
 	log.Info("device", "d", d)
 
 	m.To(d)
 
+	var ones_grad float64 = 0
+	var count = 0
+
+	vars1 := m.Vs.Variables()
+
+	for k, v := range vars1 {
+		ones := torch.MustOnes(v.MustSize(), gotch.Float, d)
+		v := ones.MustSetRequiresGrad(true, false)
+		v.MustDrop()
+		ones.RetainGrad(false)
+		m.Vs.UpdateVarTensor(k, ones, true)
+		m.Refresh()
+	}
+
 	opt, err := my_nn.DefaultAdamConfig().Build(m.Vs, 0.001)
 	if err != nil {
 		return
 	}
 
-	ones := torch.MustOnes([]int64{1, 2, 3, 3}, gotch.Float, d)
-	ones = ones.MustSetRequiresGrad(true, true)
-	ones.RetainGrad(false)
-
-	res := m.ForwardT(ones, true)
-	res = res.MustSetRequiresGrad(true, true)
-	res.RetainGrad(false)
-
-	outs := torch.MustOnes([]int64{1, 10}, gotch.Float, d)
-	outs = outs.MustSetRequiresGrad(true, true)
-	outs.RetainsGrad(false)
-
-	loss, err := res.BinaryCrossEntropyWithLogits(outs, &torch.Tensor{}, &torch.Tensor{}, 1, false)
-	if err != nil {
-		return
-	}
-	loss = loss.MustSetRequiresGrad(true, false)
-
-	opt.ZeroGrad()
-
-	log.Info("loss", "loss", loss.Float64Values())
-
-	loss.MustBackward()
-
-	opt.Step()
-
-	// log.Info(mean.MustGrad(false).Float64Values())
-	log.Info(res.MustGrad(false).Float64Values())
-	log.Info(ones.MustGrad(false).Float64Values())
-	log.Info(outs.MustGrad(false).Float64Values())
-
-	vars := m.Vs.Variables()
-
-	for k, v := range vars {
-		log.Info("[grad check]", "k", k)
-		var grad *torch.Tensor
-		grad, err = v.Grad(false)
-		if err != nil {
-			log.Error(err)
-			return
-		}
-		grad, err = grad.Abs(false)
-		if err != nil {
-			log.Error(err)
-			return
-		}
-		grad, err = grad.Max(false)
-		if err != nil {
-			log.Error(err)
-			return
-		}
-		log.Info("[grad check]", "k", k, "grad", grad.Float64Values())
-	}
+	log.Info("start")
+
+	for count < 100 {
+
+		ones := torch.MustOnes([]int64{1, 3, 28, 28}, gotch.Float, d)
+		// ones = ones.MustSetRequiresGrad(true, true)
+		// ones.RetainGrad(false)
+
+		res := m.ForwardT(ones, true)
+		res = res.MustSetRequiresGrad(true, true)
+		res.RetainGrad(false)
+
+		outs := torch.MustZeros([]int64{1, 18}, gotch.Float, d)
+
+		loss, err := res.BinaryCrossEntropyWithLogits(outs, &torch.Tensor{}, &torch.Tensor{}, 2, false)
+		if err != nil {
+			log.Fatal(err)
+		}
+		loss = loss.MustSetRequiresGrad(true, true)
+
+		opt.ZeroGrad()
+
+		log.Info("loss", "loss", loss.Float64Values())
+
+		loss.MustBackward()
+
+		opt.Step()
+
+		// log.Info(mean.MustGrad(false).Float64Values())
+		//ones_grad = ones.MustGrad(true).MustMax(true).Float64Values()[0]
+		log.Info(res.MustGrad(true).MustMax(true).Float64Values())
+		log.Info(ones_grad)
+
+		vars := m.Vs.Variables()
+
+		for k, v := range vars {
+			log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(true).Float64Values())
+		}
+
+		m.Debug()
+
+		outs.MustDrop()
+
+		count += 1
+
+		log.Fatal("grad zero")
+	}
+
+	log.Warn("out")
 }
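The per-variable check inside the new loop reduces to one expression: read the variable's gradient, take its max, and look for a non-zero value (the log.Fatal("grad zero") above is hunting for the case where it stays at 0). Pulled out as a helper, with torch aliased to git.andr3h3nriqu3s.com/andr3/gotch/ts as in the imports above, it might read as follows (maxGrad is a hypothetical name; the calls are the ones used in the loop):

// maxGrad (hypothetical helper, sketch only) returns the largest value in a
// variable's recorded gradient; a 0 here means backprop never reached it.
func maxGrad(v *torch.Tensor) []float64 {
	return v.MustGrad(false).MustMax(true).Float64Values()
}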