More work on the torch version
This commit is contained in:
parent
a4a9ade71f
commit
527b57a111
@ -27,6 +27,7 @@ func (n *ContainerModel) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(n.Layers) == 1 {
|
if len(n.Layers) == 1 {
|
||||||
|
log.Info("here")
|
||||||
return n.Layers[0].ForwardT(x, train)
|
return n.Layers[0].ForwardT(x, train)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -182,16 +182,7 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
|
|||||||
}
|
}
|
||||||
|
|
||||||
data := item.Data
|
data := item.Data
|
||||||
data, err = data.ToDevice(device, gotch.Float, true, true, false)
|
data, err = data.ToDevice(device, gotch.Float, false, true, false)
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err = data.SetRequiresGrad(true, true)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = data.RetainGrad(false)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -202,24 +193,34 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var ones *torch.Tensor
|
var zeros *torch.Tensor
|
||||||
ones, err = torch.Ones(size, gotch.Float, device)
|
zeros, err = torch.Zeros(size, gotch.Float, device)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ones, err = ones.SetRequiresGrad(true, true)
|
data, err = zeros.Add(data, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = ones.RetainGrad(false)
|
log.Info("\n\nhere 1, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
|
||||||
|
|
||||||
|
data, err = data.SetRequiresGrad(true, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//pred := model.ForwardT(data, true)
|
log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
|
||||||
pred := model.ForwardT(ones, true)
|
|
||||||
|
err = data.RetainGrad(false)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("\n\nhere 2, data\n\n", "retains", data.MustRetainsGrad(false), "requires", data.MustRequiresGrad())
|
||||||
|
|
||||||
|
pred := model.ForwardT(data, true)
|
||||||
pred, err = pred.SetRequiresGrad(true, true)
|
pred, err = pred.SetRequiresGrad(true, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@ -258,7 +259,6 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
err = opt.ZeroGrad()
|
err = opt.ZeroGrad()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@ -269,11 +269,9 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
log.Info("pred grad", "pred", pred.MustGrad(false).MustMax(false).Float64Values())
|
log.Info("pred grad", "pred", pred.MustGrad(false).MustMax(false).Float64Values())
|
||||||
log.Info("pred grad", "ones", ones.MustGrad(false).MustMax(false).Float64Values(), "lol", ones.MustRetainsGrad(false) )
|
|
||||||
log.Info("pred grad", "data", data.MustGrad(false).MustMax(false).Float64Values(), "lol", data.MustRetainsGrad(false) )
|
|
||||||
log.Info("pred grad", "outs", label.MustGrad(false).MustMax(false).Float64Values())
|
log.Info("pred grad", "outs", label.MustGrad(false).MustMax(false).Float64Values())
|
||||||
|
log.Info("pred grad", "data", data.MustGrad(false).MustMax(false).Float64Values(), "lol", data.MustRetainsGrad(false))
|
||||||
|
|
||||||
vars := model.Vs.Variables()
|
vars := model.Vs.Variables()
|
||||||
|
|
||||||
@ -283,7 +281,6 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
|
|||||||
|
|
||||||
model.Debug()
|
model.Debug()
|
||||||
|
|
||||||
|
|
||||||
err = opt.Step()
|
err = opt.Step()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
2
main.go
2
main.go
@ -23,7 +23,7 @@ const (
|
|||||||
dbname = "aistuff"
|
dbname = "aistuff"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main_() {
|
||||||
|
|
||||||
psqlInfo := fmt.Sprintf("host=%s port=%d user=%s "+
|
psqlInfo := fmt.Sprintf("host=%s port=%d user=%s "+
|
||||||
"password=%s dbname=%s sslmode=disable",
|
"password=%s dbname=%s sslmode=disable",
|
||||||
|
89
test.go
89
test.go
@ -5,35 +5,40 @@ import (
|
|||||||
|
|
||||||
dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch"
|
"git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch"
|
||||||
|
|
||||||
my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn"
|
my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn"
|
||||||
|
|
||||||
torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"
|
torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"
|
||||||
"github.com/charmbracelet/log"
|
"github.com/charmbracelet/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func _main() {
|
func main() {
|
||||||
|
|
||||||
log.Info("Hello world")
|
log.Info("Hello world")
|
||||||
|
|
||||||
m := train.BuildModel([]*dbtypes.Layer{
|
m := train.BuildModel([]*dbtypes.Layer{
|
||||||
&dbtypes.Layer{
|
&dbtypes.Layer{
|
||||||
LayerType: dbtypes.LAYER_INPUT,
|
LayerType: dbtypes.LAYER_INPUT,
|
||||||
Shape: "[ 2, 3, 3 ]",
|
Shape: "[ 3, 28, 28 ]",
|
||||||
},
|
},
|
||||||
&dbtypes.Layer{
|
&dbtypes.Layer{
|
||||||
LayerType: dbtypes.LAYER_FLATTEN,
|
LayerType: dbtypes.LAYER_FLATTEN,
|
||||||
},
|
},
|
||||||
&dbtypes.Layer{
|
&dbtypes.Layer{
|
||||||
LayerType: dbtypes.LAYER_DENSE,
|
LayerType: dbtypes.LAYER_DENSE,
|
||||||
Shape: "[ 10 ]",
|
Shape: "[ 27 ]",
|
||||||
},
|
},
|
||||||
&dbtypes.Layer{
|
&dbtypes.Layer{
|
||||||
LayerType: dbtypes.LAYER_DENSE,
|
LayerType: dbtypes.LAYER_DENSE,
|
||||||
Shape: "[ 10 ]",
|
Shape: "[ 18 ]",
|
||||||
},
|
},
|
||||||
|
// &dbtypes.Layer{
|
||||||
|
// LayerType: dbtypes.LAYER_DENSE,
|
||||||
|
// Shape: "[ 9 ]",
|
||||||
|
// },
|
||||||
}, 0, true)
|
}, 0, true)
|
||||||
|
|
||||||
var err error
|
//var err error
|
||||||
|
|
||||||
d := gotch.CudaIfAvailable()
|
d := gotch.CudaIfAvailable()
|
||||||
|
|
||||||
@ -41,72 +46,76 @@ func _main() {
|
|||||||
|
|
||||||
m.To(d)
|
m.To(d)
|
||||||
|
|
||||||
|
var ones_grad float64 = 0
|
||||||
|
var count = 0
|
||||||
|
|
||||||
|
vars1 := m.Vs.Variables()
|
||||||
|
|
||||||
|
for k, v := range vars1 {
|
||||||
|
ones := torch.MustOnes(v.MustSize(), gotch.Float, d)
|
||||||
|
v := ones.MustSetRequiresGrad(true, false)
|
||||||
|
v.MustDrop()
|
||||||
|
ones.RetainGrad(false)
|
||||||
|
|
||||||
|
m.Vs.UpdateVarTensor(k, ones, true)
|
||||||
|
m.Refresh()
|
||||||
|
}
|
||||||
|
|
||||||
opt, err := my_nn.DefaultAdamConfig().Build(m.Vs, 0.001)
|
opt, err := my_nn.DefaultAdamConfig().Build(m.Vs, 0.001)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ones := torch.MustOnes([]int64{1, 2, 3, 3}, gotch.Float, d)
|
log.Info("start")
|
||||||
ones = ones.MustSetRequiresGrad(true, true)
|
|
||||||
ones.RetainGrad(false)
|
for count < 100 {
|
||||||
|
|
||||||
|
ones := torch.MustOnes([]int64{1, 3, 28, 28}, gotch.Float, d)
|
||||||
|
// ones = ones.MustSetRequiresGrad(true, true)
|
||||||
|
// ones.RetainGrad(false)
|
||||||
|
|
||||||
res := m.ForwardT(ones, true)
|
res := m.ForwardT(ones, true)
|
||||||
res = res.MustSetRequiresGrad(true, true)
|
res = res.MustSetRequiresGrad(true, true)
|
||||||
res.RetainGrad(false)
|
res.RetainGrad(false)
|
||||||
|
|
||||||
outs := torch.MustOnes([]int64{1, 10}, gotch.Float, d)
|
outs := torch.MustZeros([]int64{1, 18}, gotch.Float, d)
|
||||||
outs = outs.MustSetRequiresGrad(true, true)
|
|
||||||
outs.RetainsGrad(false)
|
|
||||||
|
|
||||||
|
loss, err := res.BinaryCrossEntropyWithLogits(outs, &torch.Tensor{}, &torch.Tensor{}, 2, false)
|
||||||
loss, err := res.BinaryCrossEntropyWithLogits(outs, &torch.Tensor{}, &torch.Tensor{}, 1, false)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
loss = loss.MustSetRequiresGrad(true, false)
|
loss = loss.MustSetRequiresGrad(true, true)
|
||||||
|
|
||||||
opt.ZeroGrad()
|
opt.ZeroGrad()
|
||||||
|
|
||||||
|
|
||||||
log.Info("loss", "loss", loss.Float64Values())
|
log.Info("loss", "loss", loss.Float64Values())
|
||||||
|
|
||||||
loss.MustBackward()
|
loss.MustBackward()
|
||||||
|
|
||||||
|
|
||||||
opt.Step()
|
opt.Step()
|
||||||
|
|
||||||
// log.Info(mean.MustGrad(false).Float64Values())
|
// log.Info(mean.MustGrad(false).Float64Values())
|
||||||
log.Info(res.MustGrad(false).Float64Values())
|
//ones_grad = ones.MustGrad(true).MustMax(true).Float64Values()[0]
|
||||||
log.Info(ones.MustGrad(false).Float64Values())
|
log.Info(res.MustGrad(true).MustMax(true).Float64Values())
|
||||||
log.Info(outs.MustGrad(false).Float64Values())
|
|
||||||
|
log.Info(ones_grad)
|
||||||
|
|
||||||
vars := m.Vs.Variables()
|
vars := m.Vs.Variables()
|
||||||
|
|
||||||
for k, v := range vars {
|
for k, v := range vars {
|
||||||
|
log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(true).Float64Values())
|
||||||
log.Info("[grad check]", "k", k)
|
|
||||||
|
|
||||||
var grad *torch.Tensor
|
|
||||||
grad, err = v.Grad(false)
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
grad, err = grad.Abs(false)
|
m.Debug()
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
grad, err = grad.Max(false)
|
outs.MustDrop()
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info("[grad check]", "k", k, "grad", grad.Float64Values())
|
count += 1
|
||||||
}
|
|
||||||
|
log.Fatal("grad zero")
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Warn("out")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user