More work on trying to make torch work

2024-04-23 00:14:35 +01:00
parent 703fea46f2
commit a4a9ade71f
7 changed files with 109 additions and 43 deletions


@@ -146,16 +146,6 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
return
}
-/* opt1, err := my_nn.DefaultAdamConfig().Build(model.Vs, 0.001)
-if err != nil {
-return
-}
-opt1.Debug() */
-//log.Info("\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n")
// TODO remove this
model.To(device)
defer model.To(gotch.CPU)
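The hunk above drops the commented-out Adam experiment and keeps the device-placement pattern used for training: the model is moved onto the selected device for the duration of the call and deferred back to the CPU. A minimal sketch of that pattern, reusing only calls visible in the diff (model, device and gotch.CPU come from the surrounding trainDefinition scope):

	// Sketch (not part of this commit): keep the model on the training device
	// only while this function runs; the deferred call restores it to the CPU
	// on every return path.
	model.To(device)
	defer model.To(gotch.CPU)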
@@ -192,7 +182,7 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
}
data := item.Data
-data, err = data.ToDevice(device, gotch.Float, false, true, false)
+data, err = data.ToDevice(device, gotch.Float, true, true, false)
if err != nil {
return
}
@@ -206,13 +196,39 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
return
}
-pred := model.ForwardT(data, true)
+var size []int64
+size, err = data.Size()
+if err != nil {
+return
+}
+var ones *torch.Tensor
+ones, err = torch.Ones(size, gotch.Float, device)
+if err != nil {
+return
+}
+ones, err = ones.SetRequiresGrad(true, true)
+if err != nil {
+return
+}
+err = ones.RetainGrad(false)
+if err != nil {
+return
+}
+//pred := model.ForwardT(data, true)
+pred := model.ForwardT(ones, true)
pred, err = pred.SetRequiresGrad(true, true)
if err != nil {
return
}
-pred.RetainGrad(false)
+err = pred.RetainGrad(false)
+if err != nil {
+return
+}
label := item.Label
label, err = label.ToDevice(device, gotch.Float, false, true, false)
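Most of this hunk is a gradient-flow probe: instead of the real batch, the model is fed an all-ones tensor of the same shape that explicitly requires and retains its gradient, and the prediction gets the same treatment so the backward pass can be inspected later. A condensed sketch of that pattern, assuming the trainDefinition scope and using only calls that appear in the diff (torch.Ones, SetRequiresGrad, RetainGrad, ForwardT), with the repetitive error checks collapsed:

	// Sketch (not part of this commit): build an input probe that tracks
	// gradients end to end; errors are ignored here for brevity.
	size, _ := data.Size()                           // shape of the real batch
	ones, _ := torch.Ones(size, gotch.Float, device) // all-ones stand-in input
	ones, _ = ones.SetRequiresGrad(true, true)       // track gradients on the probe
	_ = ones.RetainGrad(false)                       // keep its .grad after backward
	pred := model.ForwardT(ones, true)               // forward pass in training mode
	pred, _ = pred.SetRequiresGrad(true, true)
	_ = pred.RetainGrad(false)                       // pred is non-leaf, so retain explicitly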
@@ -223,10 +239,13 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
if err != nil {
return
}
-label.RetainGrad(false)
+err = label.RetainGrad(false)
+if err != nil {
+return
+}
// Calculate loss
-loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 1, false)
+loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 2, false)
if err != nil {
return
}
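The loss call itself changes only in its fourth positional argument, from 1 to 2. The diff does not name the arguments; if they follow libtorch's binary_cross_entropy_with_logits layout (weight, pos_weight, reduction) and its Reduction numbering (0 = none, 1 = mean, 2 = sum), this switches the loss from mean to sum reduction (an assumption, not something the commit states). A sketch with the magic number named:

	// Sketch (not part of this commit). The constants mirror libtorch's
	// Reduction enum; whether this fork uses the same numbering is an assumption.
	const (
		reductionNone = 0
		reductionMean = 1
		reductionSum  = 2
	)
	// Empty tensors stand in for the optional weight and pos_weight, as in the diff.
	loss, err := pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, reductionSum, false)
	if err != nil {
		return
	}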
@@ -234,6 +253,11 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
if err != nil {
return
}
+err = loss.RetainGrad(false)
+if err != nil {
+return
+}
err = opt.ZeroGrad()
if err != nil {
@@ -245,31 +269,24 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
return
}
-err = opt.Step()
-if err != nil {
-return
-}
+log.Info("pred grad", "pred", pred.MustGrad(false).MustMax(false).Float64Values() )
+log.Info("pred grad", "ones", ones.MustGrad(false).MustMax(false).Float64Values(), "lol", ones.MustRetainsGrad(false) )
+log.Info("pred grad", "data", data.MustGrad(false).MustMax(false).Float64Values(), "lol", data.MustRetainsGrad(false) )
+log.Info("pred grad", "outs", label.MustGrad(false).MustMax(false).Float64Values() )
vars := model.Vs.Variables()
for k, v := range vars {
-var grad *torch.Tensor
-grad, err = v.Grad(false)
-if err != nil {
-return
-}
+log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(false).Float64Values(), "lol", v.MustRetainsGrad(false) )
}
-grad, err = grad.Abs(false)
-if err != nil {
-return
-}
+model.Debug()
-grad, err = grad.Max(false)
-if err != nil {
-return
-}
-log.Info("[grad check]", "k", k, "grad", grad.Float64Values())
+err = opt.Step()
+if err != nil {
+return
+}
trainLoss = loss.Float64Values()[0]
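After the backward pass, the new code logs the maximum gradient of the probe tensors and then walks every variable in the model's var store, logging each parameter's largest gradient before the optimizer finally steps. A compact sketch of that per-parameter gradient check, assuming the trainDefinition scope and built from calls shown in the hunk (model.Vs.Variables, MustGrad, MustMax, Float64Values, MustRetainsGrad, model.Debug, opt.Step):

	// Sketch (not part of this commit): after the loss has been backpropagated
	// and before opt.Step(), log the largest gradient of every trainable tensor.
	for k, v := range model.Vs.Variables() {
		log.Info("[grad check]",
			"k", k,
			"grad", v.MustGrad(false).MustMax(false).Float64Values(),
			"retains_grad", v.MustRetainsGrad(false))
	}
	model.Debug()    // dump the model's own debug output as well
	err = opt.Step() // apply the update only after the gradients were inspected
	if err != nil {
		return
	}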
@@ -295,7 +312,7 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
}
} */
// panic("fornow")
panic("fornow")
}
//v := []float64{}