More work on tring to make torch work
This commit is contained in:
@@ -146,16 +146,6 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
|
||||
return
|
||||
}
|
||||
|
||||
/* opt1, err := my_nn.DefaultAdamConfig().Build(model.Vs, 0.001)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
opt1.Debug() */
|
||||
|
||||
//log.Info("\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n")
|
||||
|
||||
// TODO remove this
|
||||
model.To(device)
|
||||
defer model.To(gotch.CPU)
|
||||
|
||||
@@ -192,7 +182,7 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
|
||||
}
|
||||
|
||||
data := item.Data
|
||||
data, err = data.ToDevice(device, gotch.Float, false, true, false)
|
||||
data, err = data.ToDevice(device, gotch.Float, true, true, false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -206,13 +196,39 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
|
||||
return
|
||||
}
|
||||
|
||||
pred := model.ForwardT(data, true)
|
||||
var size []int64
|
||||
size, err = data.Size()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var ones *torch.Tensor
|
||||
ones, err = torch.Ones(size, gotch.Float, device)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ones, err = ones.SetRequiresGrad(true, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = ones.RetainGrad(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
//pred := model.ForwardT(data, true)
|
||||
pred := model.ForwardT(ones, true)
|
||||
pred, err = pred.SetRequiresGrad(true, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
pred.RetainGrad(false)
|
||||
err = pred.RetainGrad(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
label := item.Label
|
||||
label, err = label.ToDevice(device, gotch.Float, false, true, false)
|
||||
@@ -223,10 +239,13 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
label.RetainGrad(false)
|
||||
err = label.RetainGrad(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate loss
|
||||
loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 1, false)
|
||||
loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 2, false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -234,6 +253,11 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = loss.RetainGrad(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
err = opt.ZeroGrad()
|
||||
if err != nil {
|
||||
@@ -245,31 +269,24 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
|
||||
return
|
||||
}
|
||||
|
||||
err = opt.Step()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("pred grad", "pred", pred.MustGrad(false).MustMax(false).Float64Values() )
|
||||
log.Info("pred grad", "ones", ones.MustGrad(false).MustMax(false).Float64Values(), "lol", ones.MustRetainsGrad(false) )
|
||||
log.Info("pred grad", "data", data.MustGrad(false).MustMax(false).Float64Values(), "lol", data.MustRetainsGrad(false) )
|
||||
log.Info("pred grad", "outs", label.MustGrad(false).MustMax(false).Float64Values() )
|
||||
|
||||
vars := model.Vs.Variables()
|
||||
|
||||
for k, v := range vars {
|
||||
var grad *torch.Tensor
|
||||
grad, err = v.Grad(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(false).Float64Values(), "lol", v.MustRetainsGrad(false) )
|
||||
}
|
||||
|
||||
grad, err = grad.Abs(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
model.Debug()
|
||||
|
||||
grad, err = grad.Max(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("[grad check]", "k", k, "grad", grad.Float64Values())
|
||||
err = opt.Step()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
trainLoss = loss.Float64Values()[0]
|
||||
@@ -295,7 +312,7 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
|
||||
}
|
||||
} */
|
||||
|
||||
// panic("fornow")
|
||||
panic("fornow")
|
||||
}
|
||||
|
||||
//v := []float64{}
|
||||
|
||||
Reference in New Issue
Block a user