More work on the torch version

2024-04-23 11:54:30 +01:00
parent a4a9ade71f
commit 527b57a111
4 changed files with 109 additions and 102 deletions


@@ -182,53 +182,54 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 	}
 	data := item.Data
-	data, err = data.ToDevice(device, gotch.Float, true, true, false)
+	data, err = data.ToDevice(device, gotch.Float, false, true, false)
 	if err != nil {
 		return
 	}
-	data, err = data.SetRequiresGrad(true, true)
+	var size []int64
+	size, err = data.Size()
+	if err != nil {
+		return
+	}
+	var zeros *torch.Tensor
+	zeros, err = torch.Zeros(size, gotch.Float, device)
+	if err != nil {
+		return
+	}
+	data, err = zeros.Add(data, true)
+	if err != nil {
+		return
+	}
+	log.Info("\n\nhere 1, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
+	data, err = data.SetRequiresGrad(true, false)
 	if err != nil {
 		return
 	}
+	log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
 	err = data.RetainGrad(false)
 	if err != nil {
 		return
 	}
-	var size []int64
-	size, err = data.Size()
-	if err != nil {
-		return
-	}
+	log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
-	var ones *torch.Tensor
-	ones, err = torch.Ones(size, gotch.Float, device)
-	if err != nil {
-		return
-	}
-	ones, err = ones.SetRequiresGrad(true, true)
-	if err != nil {
-		return
-	}
-	err = ones.RetainGrad(false)
-	if err != nil {
-		return
-	}
-	//pred := model.ForwardT(data, true)
-	pred := model.ForwardT(ones, true)
+	pred := model.ForwardT(data, true)
 	pred, err = pred.SetRequiresGrad(true, true)
 	if err != nil {
 		return
 	}
 	err = pred.RetainGrad(false)
 	if err != nil {
 		return
 	}
-	if err != nil {
-		return
-	}
 	label := item.Label
 	label, err = label.ToDevice(device, gotch.Float, false, true, false)
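
Note on the hunk above: the earlier experiment fed an all-ones probe tensor through the model; this change goes back to the real input, but first routes it through zeros.Add so the result is a non-leaf tensor on the autograd graph, then flips requires_grad on and asks autograd to retain its gradient. A minimal sketch of that pattern, assuming torch aliases gotch's tensor package (the repo appears to use its own wrapper, so treat the import path and helper name as hypothetical):

package train

import (
	"github.com/sugarme/gotch"
	torch "github.com/sugarme/gotch/ts" // assumption: the repo aliases gotch's tensor package like this
)

// prepareInput is a hypothetical helper mirroring the calls in the hunk:
// move the batch to the device, add it to a zeros tensor so it becomes a
// non-leaf node on the autograd graph, then mark it to retain its gradient.
func prepareInput(data *torch.Tensor, device gotch.Device) (*torch.Tensor, error) {
	data, err := data.ToDevice(device, gotch.Float, false, true, false)
	if err != nil {
		return nil, err
	}
	size, err := data.Size()
	if err != nil {
		return nil, err
	}
	zeros, err := torch.Zeros(size, gotch.Float, device)
	if err != nil {
		return nil, err
	}
	// zeros + data: the output participates in autograd even though the
	// raw input batch was a plain leaf tensor.
	data, err = zeros.Add(data, true)
	if err != nil {
		return nil, err
	}
	data, err = data.SetRequiresGrad(true, false)
	if err != nil {
		return nil, err
	}
	// Non-leaf tensors normally drop their grad after backward; RetainGrad
	// keeps it around so the grad-check logging later can read it.
	if err := data.RetainGrad(false); err != nil {
		return nil, err
	}
	return data, nil
}
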
@@ -240,9 +241,9 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 		return
 	}
 	err = label.RetainGrad(false)
-		if err != nil {
-			return
-		}
+	if err != nil {
+		return
+	}
 	// Calculate loss
 	loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 2, false)
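
A note on the loss call kept as context here: the two empty tensors stand in for the optional weight and pos_weight arguments, and 2 selects the reduction. In libtorch's reduction enum 0 is none, 1 is mean and 2 is sum, so this sums the per-element losses, assuming gotch forwards the enum unchanged. Condensed (pred and label are same-shaped float tensors on the same device):

	// &torch.Tensor{} means "no weight" / "no pos_weight" here.
	loss, err := pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 2, false)
	if err != nil {
		return
	}
	log.Info("loss", "value", loss.Float64Values())
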
@@ -253,11 +254,10 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 	if err != nil {
 		return
 	}
-		err = loss.RetainGrad(false)
-		if err != nil {
-			return
-		}
+	err = loss.RetainGrad(false)
+	if err != nil {
+		return
+	}
 	err = opt.ZeroGrad()
 	if err != nil {
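
loss.RetainGrad keeps the loss tensor's own gradient inspectable, and the optimizer calls around it follow the usual libtorch ordering; the Backward call itself presumably sits in the unchanged lines between this hunk and the next, since the next hunk reads gradients. A condensed sketch of the step, assuming gotch's error-returning Backward variant:

	if err = opt.ZeroGrad(); err != nil { // clear stale grads from the last step
		return
	}
	if err = loss.Backward(); err != nil { // populate .grad on retained tensors
		return
	}
	if err = opt.Step(); err != nil { // apply the optimizer update
		return
	}
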
@@ -269,20 +269,17 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 		return
 	}
-	log.Info("pred grad", "pred", pred.MustGrad(false).MustMax(false).Float64Values() )
-	log.Info("pred grad", "ones", ones.MustGrad(false).MustMax(false).Float64Values(), "lol", ones.MustRetainsGrad(false) )
-	log.Info("pred grad", "data", data.MustGrad(false).MustMax(false).Float64Values(), "lol", data.MustRetainsGrad(false))
-	log.Info("pred grad", "outs", label.MustGrad(false).MustMax(false).Float64Values() )
+	log.Info("pred grad", "pred", pred.MustGrad(false).MustMax(false).Float64Values())
+	log.Info("pred grad", "outs", label.MustGrad(false).MustMax(false).Float64Values())
+	log.Info("pred grad", "data", data.MustGrad(false).MustMax(false).Float64Values(), "lol", data.MustRetainsGrad(false))
 	vars := model.Vs.Variables()
 	for k, v := range vars {
-		log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(false).Float64Values(), "lol", v.MustRetainsGrad(false) )
+		log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(false).Float64Values(), "lol", v.MustRetainsGrad(false))
 	}
-		model.Debug()
+	model.Debug()
 	err = opt.Step()
 	if err != nil {
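
The grad-check loop in the final hunk is worth noting as a pattern: after backward, walk every variable in the var store and log the max of its gradient, which quickly shows where gradient flow dies. A standalone sketch (logGrads is a hypothetical helper; Variables() returning a map of named tensors matches its use in the diff):

	// logGrads dumps the largest gradient value of every model variable,
	// plus whether autograd retained its grad, to spot dead gradient flow.
	func logGrads(vars map[string]*torch.Tensor) {
		for k, v := range vars {
			log.Info("[grad check]",
				"k", k,
				"grad", v.MustGrad(false).MustMax(false).Float64Values(),
				"retains", v.MustRetainsGrad(false))
		}
	}

	// usage, mirroring the diff:
	//	logGrads(model.Vs.Variables())
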