More work on trying to make torch work
parent 703fea46f2
commit a4a9ade71f
@@ -35,6 +35,7 @@ func LoadImagesAndLables(db db.Db, m *types.BaseModel, mode types.DATA_POINT_MOD
 	if err != nil {
 		return
 	}
 
 	pimgs = append(pimgs, img)
 
 	t_label := make([]int, size)
@@ -46,6 +47,7 @@ func LoadImagesAndLables(db db.Db, m *types.BaseModel, mode types.DATA_POINT_MOD
 	if err != nil {
 		return
 	}
 
 	plabels = append(plabels, label)
 }
 
@@ -62,13 +64,16 @@ func LoadImagesAndLables(db db.Db, m *types.BaseModel, mode types.DATA_POINT_MOD
 	count = len(pimgs)
 
 	imgs, err = torch.Stack(pimgs, 0)
+	if err != nil {
+		return
+	}
 
-	labels, err = labels.ToDtype(gotch.Float, false, false, true)
+	imgs, err = imgs.ToDtype(gotch.Float, false, false, true)
 	if err != nil {
 		return
 	}
 
-	imgs, err = imgs.ToDtype(gotch.Float, false, false, true)
+	labels, err = labels.ToDtype(gotch.Float, false, false, true)
 	if err != nil {
 		return
 	}
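Note: two fixes hide in this hunk: the error from torch.Stack was previously unchecked, and the ToDtype calls are reordered so the freshly stacked image tensor is converted first. A minimal sketch of the resulting flow, assuming this gotch fork keeps upstream gotch's Stack and ToDtype signatures and that torch aliases the ts package (as the rest of the diff suggests); the package and helper names are illustrative:

package data

import (
	"git.andr3h3nriqu3s.com/andr3/gotch"
	torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"
)

// loadFloatBatch mirrors the pattern the hunk settles on: stack the
// per-image tensors, check the error straight away, then convert the
// stacked tensor to float32.
func loadFloatBatch(pimgs []*torch.Tensor) (imgs *torch.Tensor, err error) {
	imgs, err = torch.Stack(pimgs, 0)
	if err != nil {
		return
	}
	// ToDtype(dtype, nonBlocking, copy, del): the trailing true drops
	// the pre-conversion tensor, as in the diff.
	imgs, err = imgs.ToDtype(gotch.Float, false, false, true)
	return
}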
@@ -129,21 +134,46 @@ func (ds *Dataset) TestIter(batchSize int64) *torch.Iter2 {
 }
 
 func (ds *Dataset) TrainIter(batchSize int64) (iter *torch.Iter2, err error) {
 
-	train_images, err := ds.TrainImages.DetachCopy(false)
+	// Create a clone of the trainimages
+	size, err := ds.TrainImages.Size()
 	if err != nil {
 		return
 	}
 
-	train_labels, err := ds.TrainLabels.DetachCopy(false)
+	train_images, err := torch.Zeros(size, gotch.Float, ds.Device)
 	if err != nil {
 		return
 	}
 
-	iter, err = torch.NewIter2(train_images, train_labels, batchSize)
+	ds.TrainImages, err = ds.TrainImages.Clone(train_images, false)
 	if err != nil {
 		return
 	}
 
+	// Create a clone of the labels
+	size, err = ds.TrainLabels.Size()
+	if err != nil {
+		return
+	}
+
+	train_labels, err := torch.Zeros(size, gotch.Float, ds.Device)
+	if err != nil {
+		return
+	}
+
+	ds.TrainLabels, err = ds.TrainLabels.Clone(train_labels, false)
+	if err != nil {
+		return
+	}
+
+	iter, err = torch.NewIter2(train_images, train_labels, batchSize)
+	if err != nil {
+		return
+	}
 
 	return
 }
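Note: TrainIter no longer hands the dataset's own tensors to the iterator via DetachCopy; it allocates zeroed tensors of the same shape on ds.Device and clones into them first. A hedged usage sketch, assuming the fork keeps gotch's Iter2.Next protocol (an item with Data/Label fields, ok=false at end of epoch); oneEpoch, the batch size, and the loop body are illustrative:

// oneEpoch drives the rewritten TrainIter for a single pass over the
// training set.
func oneEpoch(ds *Dataset, model torch.ModuleT) error {
	iter, err := ds.TrainIter(32)
	if err != nil {
		return err
	}
	for {
		item, ok := iter.Next()
		if !ok {
			break // epoch exhausted
		}
		// item.Data and item.Label each hold one mini-batch.
		pred := model.ForwardT(item.Data, true)
		_ = pred
	}
	return nil
}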
@@ -7,6 +7,7 @@ import (
 
 	"git.andr3h3nriqu3s.com/andr3/gotch/nn"
 	"git.andr3h3nriqu3s.com/andr3/gotch/ts"
+	"github.com/charmbracelet/log"
 )
 
 // LinearConfig is a configuration for a linear layer
@@ -104,6 +105,11 @@ func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
 	}
 }
 
+func (l *Linear) Debug() {
+	log.Info("Ws", "ws", l.Ws.MustGrad(false).MustMax(false).Float64Values())
+	log.Info("Bs", "bs", l.Bs.MustGrad(false).MustMax(false).Float64Values())
+}
+
 func (l *Linear) ExtractFromVarstore(vs *VarStore) {
 	l.Ws = vs.GetTensorOfVar(l.weight_name)
 	l.Bs = vs.GetTensorOfVar(l.bias_name)
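Note: Linear.Debug reduces each parameter's gradient to its single largest element (MustGrad then MustMax), so a vanishing or exploding gradient shows up as one number per parameter in the structured charmbracelet/log output. A sketch of where such a call fits, assuming the fork keeps gotch's MustBackward; debugStep and its parameter types are illustrative:

// debugStep shows the ordering Debug depends on: gradients exist only
// after a backward pass and before the next ZeroGrad.
func debugStep(loss *ts.Tensor, linear *Linear, opt interface{ ZeroGrad() error }) error {
	loss.MustBackward() // populate .grad on Ws and Bs
	linear.Debug()      // logs max(grad(Ws)) and max(grad(Bs))
	return opt.ZeroGrad()
}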
@@ -14,4 +14,5 @@ type MyLayer interface {
 	torch.ModuleT
 
 	ExtractFromVarstore(vs *VarStore)
+	Debug()
 }
@@ -35,12 +35,12 @@ func (n *ContainerModel) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
 	for i := 0; i < len(n.Layers); i++ {
 		if i == 0 {
 			outs[0] = n.Layers[i].ForwardT(x, train)
-			defer outs[0].MustDrop()
+			//defer outs[0].MustDrop()
 		} else if i == len(n.Layers)-1 {
 			return n.Layers[i].ForwardT(outs[i-1], train)
 		} else {
 			outs[i] = n.Layers[i].ForwardT(outs[i-1], train)
-			defer outs[i].MustDrop()
+			//defer outs[i].MustDrop()
 		}
 	}
 	panic("Do not reach here")
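Note: the deferred MustDrop calls fire as soon as ForwardT returns, i.e. before the caller has computed a loss, let alone run a backward pass; commenting them out keeps the Go handles on the intermediate activations alive (at the cost of leaking them), presumably to rule out premature frees while chasing the zero gradients this commit is hunting. The timing, schematically; computeLoss is a hypothetical stand-in:

out := model.ForwardT(x, true)
// <- with the defers active, every intermediate's MustDrop has already
//    fired by this point, before any backward pass could run.
loss := computeLoss(out) // hypothetical
loss.MustBackward()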
@@ -107,6 +107,12 @@ func BuildModel(layers []*types.Layer, _lastLinearSize int64, addSigmoid bool) *
 	return b
 }
 
+func (model *ContainerModel) Debug() {
+	for _, v := range model.Layers {
+		v.Debug()
+	}
+}
+
 func SaveModel(model *ContainerModel, modelFn string) (err error) {
 	model.Vs.ToDevice(gotch.CPU)
 	return model.Vs.Save(modelFn)
@@ -107,6 +107,7 @@ func NewFlatten() *Flatten {
 
 // The flatten layer does not need to move anything to the device
 func (b *Flatten) ExtractFromVarstore(vs *my_nn.VarStore) {}
+func (b *Flatten) Debug() {}
 
 // Forward method
 func (b *Flatten) Forward(x *torch.Tensor) *torch.Tensor {
@@ -133,6 +134,7 @@ func NewSigmoid() *Sigmoid {
 
 // The sigmoid layer does not need to move anything to another device
 func (b *Sigmoid) ExtractFromVarstore(vs *my_nn.VarStore) {}
+func (b *Sigmoid) Debug() {}
 
 func (b *Sigmoid) Forward(x *torch.Tensor) *torch.Tensor {
 	out, err := x.Sigmoid(false)
@@ -146,16 +146,6 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 		return
 	}
 
-	/* opt1, err := my_nn.DefaultAdamConfig().Build(model.Vs, 0.001)
-	if err != nil {
-		return
-	}
-
-	opt1.Debug() */
-
-	//log.Info("\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n")
-
-	// TODO remove this
 	model.To(device)
 	defer model.To(gotch.CPU)
 
@@ -192,7 +182,7 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 	}
 
 	data := item.Data
-	data, err = data.ToDevice(device, gotch.Float, false, true, false)
+	data, err = data.ToDevice(device, gotch.Float, true, true, false)
 	if err != nil {
 		return
 	}
@@ -206,13 +196,39 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 		return
 	}
 
-	pred := model.ForwardT(data, true)
+	var size []int64
+	size, err = data.Size()
+	if err != nil {
+		return
+	}
+
+	var ones *torch.Tensor
+	ones, err = torch.Ones(size, gotch.Float, device)
+	if err != nil {
+		return
+	}
+
+	ones, err = ones.SetRequiresGrad(true, true)
+	if err != nil {
+		return
+	}
+
+	err = ones.RetainGrad(false)
+	if err != nil {
+		return
+	}
+
+	//pred := model.ForwardT(data, true)
+	pred := model.ForwardT(ones, true)
 	pred, err = pred.SetRequiresGrad(true, true)
 	if err != nil {
 		return
 	}
 
-	pred.RetainGrad(false)
+	err = pred.RetainGrad(false)
+	if err != nil {
+		return
+	}
 
 	label := item.Label
 	label, err = label.ToDevice(device, gotch.Float, false, true, false)
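Note: driving the model with a constant all-ones batch is a blunt but effective autograd probe: if even d(pred)/d(ones) comes back zero, gradient flow is broken inside the network rather than in the data pipeline. SetRequiresGrad marks the tensor for autograd; RetainGrad is needed on top because, in PyTorch semantics (which gotch wraps), .grad is only kept for leaf tensors by default, and the probes below read gradients off non-leaf tensors such as pred. The same probe in isolation, using only calls that appear in this diff; probeOnes, the shape, and the torch.ModuleT parameter type are illustrative assumptions:

// probeOnes feeds a constant batch of ones through any ModuleT and
// wires the tensor up so its gradient survives the backward pass.
func probeOnes(model torch.ModuleT, device gotch.Device) (err error) {
	ones, err := torch.Ones([]int64{1, 28 * 28}, gotch.Float, device)
	if err != nil {
		return
	}
	ones, err = ones.SetRequiresGrad(true, true)
	if err != nil {
		return
	}
	if err = ones.RetainGrad(false); err != nil {
		return
	}
	pred := model.ForwardT(ones, true)
	_ = pred // after a backward pass, ones.MustGrad(false) holds d(loss)/d(ones)
	return
}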
@@ -223,10 +239,13 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 	if err != nil {
 		return
 	}
-	label.RetainGrad(false)
+	err = label.RetainGrad(false)
+	if err != nil {
+		return
+	}
 
 	// Calculate loss
-	loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 1, false)
+	loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 2, false)
 	if err != nil {
 		return
 	}
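Note: the fourth argument of BinaryCrossEntropyWithLogits is libtorch's reduction enum, which gotch passes through unchanged: 0 = none, 1 = mean, 2 = sum. Switching 1 to 2 therefore turns the batch-mean loss into a batch sum, scaling the loss and every gradient by the element count, which makes near-zero gradients easier to spot. The call, annotated (mirroring the diff; the empty Tensor literals mean "no weight"):

loss, err = pred.BinaryCrossEntropyWithLogits(
	label,           // targets
	&torch.Tensor{}, // per-element weight: unset
	&torch.Tensor{}, // positive-class weight: unset
	2,               // reduction: 0=none, 1=mean, 2=sum
	false,           // del: keep pred alive
)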
@@ -234,6 +253,11 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 	if err != nil {
 		return
 	}
+	err = loss.RetainGrad(false)
+	if err != nil {
+		return
+	}
+
 
 	err = opt.ZeroGrad()
 	if err != nil {
@@ -245,31 +269,24 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 		return
 	}
 
-	err = opt.Step()
-	if err != nil {
-		return
-	}
+	log.Info("pred grad", "pred", pred.MustGrad(false).MustMax(false).Float64Values())
+	log.Info("pred grad", "ones", ones.MustGrad(false).MustMax(false).Float64Values(), "lol", ones.MustRetainsGrad(false))
+	log.Info("pred grad", "data", data.MustGrad(false).MustMax(false).Float64Values(), "lol", data.MustRetainsGrad(false))
+	log.Info("pred grad", "outs", label.MustGrad(false).MustMax(false).Float64Values())
 
 	vars := model.Vs.Variables()
 
 	for k, v := range vars {
-		var grad *torch.Tensor
-		grad, err = v.Grad(false)
-		if err != nil {
-			return
-		}
-
-		grad, err = grad.Abs(false)
-		if err != nil {
-			return
-		}
-
-		grad, err = grad.Max(false)
-		if err != nil {
-			return
-		}
-
-		log.Info("[grad check]", "k", k, "grad", grad.Float64Values())
+		log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(false).Float64Values(), "lol", v.MustRetainsGrad(false))
 	}
 
+	model.Debug()
+
+	err = opt.Step()
+	if err != nil {
+		return
+	}
 
 	trainLoss = loss.Float64Values()[0]
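Note: the per-variable probe collapses the old Grad / Abs / Max error plumbing into the panicking Must* variants, and opt.Step is deferred until after the logging (plus the new model.Debug()) so gradients are inspected before the step. One [grad check] line in long-hand form, using only calls this diff already relies on; the loop itself is a sketch, not part of the commit:

for k, v := range model.Vs.Variables() {
	grad := v.MustGrad(false)      // panics on error instead of returning it
	maxGrad := grad.MustMax(false) // one-element tensor: the largest entry
	log.Info("[grad check]", "k", k, "grad", maxGrad.Float64Values())
}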
@@ -295,7 +312,7 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
 		}
 	} */
 
-	// panic("fornow")
+	panic("fornow")
 }
 
 	//v := []float64{}
test.go (6 changed lines)
@@ -11,7 +11,7 @@ import (
 	"github.com/charmbracelet/log"
 )
 
-func main_() {
+func _main() {
 
 	log.Info("Hello world")
 
@@ -27,6 +27,10 @@ func main_() {
 			LayerType: dbtypes.LAYER_DENSE,
 			Shape: "[ 10 ]",
 		},
+		&dbtypes.Layer{
+			LayerType: dbtypes.LAYER_DENSE,
+			Shape: "[ 10 ]",
+		},
 	}, 0, true)
 
 	var err error