More work on trying to make torch work

Andre Henriques 2024-04-23 00:14:35 +01:00
parent 703fea46f2
commit a4a9ade71f
7 changed files with 109 additions and 43 deletions

View File

@@ -35,6 +35,7 @@ func LoadImagesAndLables(db db.Db, m *types.BaseModel, mode types.DATA_POINT_MOD
     if err != nil {
         return
     }
     pimgs = append(pimgs, img)
     t_label := make([]int, size)
@@ -46,6 +47,7 @@ func LoadImagesAndLables(db db.Db, m *types.BaseModel, mode types.DATA_POINT_MOD
     if err != nil {
         return
     }
     plabels = append(plabels, label)
     }
@@ -62,13 +64,16 @@ func LoadImagesAndLables(db db.Db, m *types.BaseModel, mode types.DATA_POINT_MOD
     count = len(pimgs)
     imgs, err = torch.Stack(pimgs, 0)
+    if err != nil {
+        return
+    }
-    labels, err = labels.ToDtype(gotch.Float, false, false, true)
+    imgs, err = imgs.ToDtype(gotch.Float, false, false, true)
     if err != nil {
         return
     }
-    imgs, err = imgs.ToDtype(gotch.Float, false, false, true)
+    labels, err = labels.ToDtype(gotch.Float, false, false, true)
     if err != nil {
         return
     }
@@ -129,21 +134,46 @@ func (ds *Dataset) TestIter(batchSize int64) *torch.Iter2 {
 }

 func (ds *Dataset) TrainIter(batchSize int64) (iter *torch.Iter2, err error) {
-    train_images, err := ds.TrainImages.DetachCopy(false)
+    // Create a clone of the trainimages
+    size, err := ds.TrainImages.Size()
     if err != nil {
         return
     }
-    train_labels, err := ds.TrainLabels.DetachCopy(false)
+    train_images, err := torch.Zeros(size, gotch.Float, ds.Device)
     if err != nil {
         return
     }
-    iter, err = torch.NewIter2(train_images, train_labels, batchSize)
+    ds.TrainImages, err = ds.TrainImages.Clone(train_images, false)
     if err != nil {
         return
     }
+
+    // Create a clone of the labels
+    size, err = ds.TrainLabels.Size()
+    if err != nil {
+        return
+    }
+    train_labels, err := torch.Zeros(size, gotch.Float, ds.Device)
+    if err != nil {
+        return
+    }
+    ds.TrainLabels, err = ds.TrainLabels.Clone(train_labels, false)
+    if err != nil {
+        return
+    }
+
+    iter, err = torch.NewIter2(train_images, train_labels, batchSize)
+    if err != nil {
+        return
+    }
     return
 }
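TrainIter now sizes a zero tensor on the dataset's device, clones the stored tensor into it, and hands that copy to torch.NewIter2 instead of a DetachCopy, so the dataset presumably keeps a usable copy of its training tensors across iterations. A minimal sketch of that pattern in isolation, using only the calls visible in the hunk above; the helper name, the package, the import aliases, and the gotch.Device parameter type are assumptions, not part of the commit:

package dataset // assumption: the package that owns Dataset

import (
    "git.andr3h3nriqu3s.com/andr3/gotch"
    torch "git.andr3h3nriqu3s.com/andr3/gotch/ts" // assumed alias, matching the calls above
)

// cloneForIter mirrors the new TrainIter logic for one tensor: allocate a zero
// tensor of the same shape on the target device, clone src into it, and return
// both the tensor to keep in the dataset and the copy to feed the iterator.
func cloneForIter(src *torch.Tensor, device gotch.Device) (keep, give *torch.Tensor, err error) {
    size, err := src.Size()
    if err != nil {
        return
    }
    give, err = torch.Zeros(size, gotch.Float, device)
    if err != nil {
        return
    }
    keep, err = src.Clone(give, false)
    return
}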

View File

@@ -7,6 +7,7 @@ import (
     "git.andr3h3nriqu3s.com/andr3/gotch/nn"
     "git.andr3h3nriqu3s.com/andr3/gotch/ts"
+    "github.com/charmbracelet/log"
 )

 // LinearConfig is a configuration for a linear layer
@@ -104,6 +105,11 @@ func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
     }
 }

+func (l *Linear) Debug() {
+    log.Info("Ws", "ws", l.Ws.MustGrad(false).MustMax(false).Float64Values())
+    log.Info("Bs", "bs", l.Bs.MustGrad(false).MustMax(false).Float64Values())
+}
+
 func (l *Linear) ExtractFromVarstore(vs *VarStore) {
     l.Ws = vs.GetTensorOfVar(l.weight_name)
     l.Bs = vs.GetTensorOfVar(l.bias_name)

View File

@@ -14,4 +14,5 @@ type MyLayer interface {
     torch.ModuleT
     ExtractFromVarstore(vs *VarStore)
+    Debug()
 }
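Because Debug() is now part of MyLayer, every layer implementation needs at least a stub. A minimal sketch of a stateless implementation against the enlarged interface, written as if it lived in the same package as MyLayer and VarStore; the Identity type is hypothetical and only mirrors the no-op stubs this commit adds to Flatten and Sigmoid further down, and it assumes torch.ModuleT only requires ForwardT:

// Identity is a hypothetical pass-through layer, shown only to illustrate the interface.
type Identity struct{}

// ForwardT returns the input unchanged.
func (l *Identity) ForwardT(x *torch.Tensor, train bool) *torch.Tensor { return x }

// Nothing to pull out of the VarStore, and nothing worth logging.
func (l *Identity) ExtractFromVarstore(vs *VarStore) {}
func (l *Identity) Debug()                           {}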

View File

@@ -35,12 +35,12 @@ func (n *ContainerModel) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
     for i := 0; i < len(n.Layers); i++ {
         if i == 0 {
             outs[0] = n.Layers[i].ForwardT(x, train)
-            defer outs[0].MustDrop()
+            //defer outs[0].MustDrop()
         } else if i == len(n.Layers)-1 {
             return n.Layers[i].ForwardT(outs[i-1], train)
         } else {
             outs[i] = n.Layers[i].ForwardT(outs[i-1], train)
-            defer outs[i].MustDrop()
+            //defer outs[i].MustDrop()
         }
     }
     panic("Do not reach here")
@@ -107,6 +107,12 @@ func BuildModel(layers []*types.Layer, _lastLinearSize int64, addSigmoid bool) *
     return b
 }

+func (model *ContainerModel) Debug() {
+    for _, v := range model.Layers {
+        v.Debug()
+    }
+}
+
 func SaveModel(model *ContainerModel, modelFn string) (err error) {
     model.Vs.ToDevice(gotch.CPU)
     return model.Vs.Save(modelFn)

View File

@@ -107,6 +107,7 @@ func NewFlatten() *Flatten {
 // The flatten layer does not to move anything to the device
 func (b *Flatten) ExtractFromVarstore(vs *my_nn.VarStore) {}
+func (b *Flatten) Debug() {}

 // Forward method
 func (b *Flatten) Forward(x *torch.Tensor) *torch.Tensor {
@@ -133,6 +134,7 @@ func NewSigmoid() *Sigmoid {
 // The sigmoid layer does not need to move anything to another device
 func (b *Sigmoid) ExtractFromVarstore(vs *my_nn.VarStore) {}
+func (b *Sigmoid) Debug() {}

 func (b *Sigmoid) Forward(x *torch.Tensor) *torch.Tensor {
     out, err := x.Sigmoid(false)

View File

@@ -146,16 +146,6 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
         return
     }

-    /* opt1, err := my_nn.DefaultAdamConfig().Build(model.Vs, 0.001)
-    if err != nil {
-        return
-    }
-    opt1.Debug() */
-
-    //log.Info("\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n")
-
-    // TODO remove this
-
     model.To(device)
     defer model.To(gotch.CPU)
@@ -192,7 +182,7 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
     }

     data := item.Data
-    data, err = data.ToDevice(device, gotch.Float, false, true, false)
+    data, err = data.ToDevice(device, gotch.Float, true, true, false)
     if err != nil {
         return
     }
@@ -206,13 +196,39 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
         return
     }

-    pred := model.ForwardT(data, true)
+    var size []int64
+    size, err = data.Size()
+    if err != nil {
+        return
+    }
+
+    var ones *torch.Tensor
+    ones, err = torch.Ones(size, gotch.Float, device)
+    if err != nil {
+        return
+    }
+
+    ones, err = ones.SetRequiresGrad(true, true)
+    if err != nil {
+        return
+    }
+
+    err = ones.RetainGrad(false)
+    if err != nil {
+        return
+    }
+
+    //pred := model.ForwardT(data, true)
+    pred := model.ForwardT(ones, true)

     pred, err = pred.SetRequiresGrad(true, true)
     if err != nil {
         return
     }

-    pred.RetainGrad(false)
+    err = pred.RetainGrad(false)
+    if err != nil {
+        return
+    }

     label := item.Label
     label, err = label.ToDevice(device, gotch.Float, false, true, false)
@@ -223,10 +239,13 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
     if err != nil {
         return
     }
-    label.RetainGrad(false)
+    err = label.RetainGrad(false)
+    if err != nil {
+        return
+    }

     // Calculate loss
-    loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 1, false)
+    loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 2, false)
     if err != nil {
         return
     }
@@ -234,6 +253,11 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
     if err != nil {
         return
     }

+    err = loss.RetainGrad(false)
+    if err != nil {
+        return
+    }
+
     err = opt.ZeroGrad()
     if err != nil {
@@ -245,31 +269,24 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
         return
     }

-    err = opt.Step()
-    if err != nil {
-        return
-    }
+    log.Info("pred grad", "pred", pred.MustGrad(false).MustMax(false).Float64Values())
+    log.Info("pred grad", "ones", ones.MustGrad(false).MustMax(false).Float64Values(), "lol", ones.MustRetainsGrad(false))
+    log.Info("pred grad", "data", data.MustGrad(false).MustMax(false).Float64Values(), "lol", data.MustRetainsGrad(false))
+    log.Info("pred grad", "outs", label.MustGrad(false).MustMax(false).Float64Values())

     vars := model.Vs.Variables()

     for k, v := range vars {
-        var grad *torch.Tensor
-        grad, err = v.Grad(false)
-        if err != nil {
-            return
-        }
-        grad, err = grad.Abs(false)
-        if err != nil {
-            return
-        }
-        grad, err = grad.Max(false)
-        if err != nil {
-            return
-        }
-        log.Info("[grad check]", "k", k, "grad", grad.Float64Values())
+        log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(false).Float64Values(), "lol", v.MustRetainsGrad(false))
     }
+
+    model.Debug()
+
+    err = opt.Step()
+    if err != nil {
+        return
+    }

     trainLoss = loss.Float64Values()[0]
@@ -295,7 +312,7 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
         }
     } */

-    // panic("fornow")
+    panic("fornow")
 }

 //v := []float64{}
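All of the new logging above follows one pattern: after the backward pass, read each variable's retained gradient with MustGrad and log its maximum to check whether anything non-zero is flowing. A sketch of that check factored into a helper, written as if it sat next to trainDefinition so the existing log import covers it; the helper name is hypothetical, the var store type is assumed to be the my_nn.VarStore used elsewhere in this commit, and the accessor calls are the ones used in the loop above:

// logMaxGrads reports, for every variable in the store, the largest gradient
// value and whether the tensor retains its grad, using the same
// MustGrad / MustMax / MustRetainsGrad accessors as the training loop.
func logMaxGrads(vs *my_nn.VarStore) {
    for k, v := range vs.Variables() {
        log.Info("[grad check]", "k", k,
            "grad", v.MustGrad(false).MustMax(false).Float64Values(),
            "retains", v.MustRetainsGrad(false))
    }
}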

View File

@@ -11,7 +11,7 @@ import (
     "github.com/charmbracelet/log"
 )

-func main_() {
+func _main() {
     log.Info("Hello world")
@@ -27,6 +27,10 @@ func main_() {
         LayerType: dbtypes.LAYER_DENSE,
         Shape: "[ 10 ]",
     },
+    &dbtypes.Layer{
+        LayerType: dbtypes.LAYER_DENSE,
+        Shape: "[ 10 ]",
+    },
 }, 0, true)

 var err error