More work on the torch version
This commit is contained in:
@@ -182,53 +182,54 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
|
||||
}
|
||||
|
||||
data := item.Data
|
||||
data, err = data.ToDevice(device, gotch.Float, true, true, false)
|
||||
data, err = data.ToDevice(device, gotch.Float, false, true, false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
data, err = data.SetRequiresGrad(true, true)
|
||||
var size []int64
|
||||
size, err = data.Size()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var zeros *torch.Tensor
|
||||
zeros, err = torch.Zeros(size, gotch.Float, device)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
data, err = zeros.Add(data, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("\n\nhere 1, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
|
||||
|
||||
data, err = data.SetRequiresGrad(true, false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
|
||||
|
||||
err = data.RetainGrad(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var size []int64
|
||||
size, err = data.Size()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
|
||||
|
||||
var ones *torch.Tensor
|
||||
ones, err = torch.Ones(size, gotch.Float, device)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ones, err = ones.SetRequiresGrad(true, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = ones.RetainGrad(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
//pred := model.ForwardT(data, true)
|
||||
pred := model.ForwardT(ones, true)
|
||||
pred := model.ForwardT(data, true)
|
||||
pred, err = pred.SetRequiresGrad(true, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = pred.RetainGrad(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
label := item.Label
|
||||
label, err = label.ToDevice(device, gotch.Float, false, true, false)
|
||||
@@ -240,9 +241,9 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
|
||||
return
|
||||
}
|
||||
err = label.RetainGrad(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate loss
|
||||
loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 2, false)
|
||||
@@ -253,11 +254,10 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = loss.RetainGrad(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = loss.RetainGrad(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = opt.ZeroGrad()
|
||||
if err != nil {
|
||||
@@ -269,20 +269,17 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
log.Info("pred grad", "pred", pred.MustGrad(false).MustMax(false).Float64Values() )
|
||||
log.Info("pred grad", "ones", ones.MustGrad(false).MustMax(false).Float64Values(), "lol", ones.MustRetainsGrad(false) )
|
||||
log.Info("pred grad", "data", data.MustGrad(false).MustMax(false).Float64Values(), "lol", data.MustRetainsGrad(false) )
|
||||
log.Info("pred grad", "outs", label.MustGrad(false).MustMax(false).Float64Values() )
|
||||
log.Info("pred grad", "pred", pred.MustGrad(false).MustMax(false).Float64Values())
|
||||
log.Info("pred grad", "outs", label.MustGrad(false).MustMax(false).Float64Values())
|
||||
log.Info("pred grad", "data", data.MustGrad(false).MustMax(false).Float64Values(), "lol", data.MustRetainsGrad(false))
|
||||
|
||||
vars := model.Vs.Variables()
|
||||
|
||||
for k, v := range vars {
|
||||
log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(false).Float64Values(), "lol", v.MustRetainsGrad(false) )
|
||||
log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(false).Float64Values(), "lol", v.MustRetainsGrad(false))
|
||||
}
|
||||
|
||||
model.Debug()
|
||||
|
||||
model.Debug()
|
||||
|
||||
err = opt.Step()
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user