More work on trying to make torch work
This commit is contained in:
parent 703fea46f2
commit a4a9ade71f
@@ -35,6 +35,7 @@ func LoadImagesAndLables(db db.Db, m *types.BaseModel, mode types.DATA_POINT_MOD
if err != nil {
return
}

pimgs = append(pimgs, img)

t_label := make([]int, size)
@@ -46,6 +47,7 @@ func LoadImagesAndLables(db db.Db, m *types.BaseModel, mode types.DATA_POINT_MOD
if err != nil {
return
}

plabels = append(plabels, label)
}
@@ -62,8 +64,6 @@ func LoadImagesAndLables(db db.Db, m *types.BaseModel, mode types.DATA_POINT_MOD
count = len(pimgs)

imgs, err = torch.Stack(pimgs, 0)

labels, err = labels.ToDtype(gotch.Float, false, false, true)
if err != nil {
return
}
@@ -73,6 +73,11 @@ func LoadImagesAndLables(db db.Db, m *types.BaseModel, mode types.DATA_POINT_MOD
return
}

labels, err = labels.ToDtype(gotch.Float, false, false, true)
if err != nil {
return
}

return
}
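
The hunks above change LoadImagesAndLables so that per-image tensors are collected first and only stacked and cast once at the end. A minimal sketch of that pattern, assuming torch aliases this repo's gotch/ts package and that pimgs/plabels are the slices built in the loops above (the helper name is mine, not part of the commit):

// stackAndCast turns the collected per-image and per-label tensors into two
// batch tensors: images are stacked along a new leading dimension, labels are
// stacked the same way and then cast to float for the loss further down.
func stackAndCast(pimgs, plabels []*torch.Tensor) (imgs, labels *torch.Tensor, err error) {
	imgs, err = torch.Stack(pimgs, 0)
	if err != nil {
		return
	}
	labels, err = torch.Stack(plabels, 0)
	if err != nil {
		return
	}
	// The ToDtype arguments mirror the call in the hunk above.
	labels, err = labels.ToDtype(gotch.Float, false, false, true)
	return
}
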
@@ -130,16 +135,41 @@ func (ds *Dataset) TestIter(batchSize int64) *torch.Iter2 {

func (ds *Dataset) TrainIter(batchSize int64) (iter *torch.Iter2, err error) {

train_images, err := ds.TrainImages.DetachCopy(false)
// Create a clone of the trainimages
size, err := ds.TrainImages.Size()
if err != nil {
return
}

train_labels, err := ds.TrainLabels.DetachCopy(false)
train_images, err := torch.Zeros(size, gotch.Float, ds.Device)
if err != nil {
return
}

ds.TrainImages, err = ds.TrainImages.Clone(train_images, false)
if err != nil {
return
}

// Create a clone of the labels
size, err = ds.TrainLabels.Size()
if err != nil {
return
}

train_labels, err := torch.Zeros(size, gotch.Float, ds.Device)
if err != nil {
return
}

ds.TrainLabels, err = ds.TrainLabels.Clone(train_labels, false)
if err != nil {
return
}

iter, err = torch.NewIter2(train_images, train_labels, batchSize)
if err != nil {
return
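
In the new TrainIter, the DetachCopy calls are dropped in favour of allocating a zero tensor of the right shape on the dataset's device and cloning the source into it. A condensed sketch of that allocate-then-clone step, using only calls that appear in the hunk above (the helper name is mine; the second argument to Clone is assumed to be the same delete flag used elsewhere in this diff):

// cloneToDevice allocates zeros with the source's shape on the target device
// and clones the source tensor into that freshly allocated storage.
func cloneToDevice(src *torch.Tensor, device gotch.Device) (dst *torch.Tensor, err error) {
	size, err := src.Size()
	if err != nil {
		return
	}
	dst, err = torch.Zeros(size, gotch.Float, device)
	if err != nil {
		return
	}
	return src.Clone(dst, false)
}
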
@@ -7,6 +7,7 @@ import (

"git.andr3h3nriqu3s.com/andr3/gotch/nn"
"git.andr3h3nriqu3s.com/andr3/gotch/ts"
"github.com/charmbracelet/log"
)

// LinearConfig is a configuration for a linear layer
@@ -104,6 +105,11 @@ func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
}
}

func (l *Linear) Debug() {
log.Info("Ws", "ws", l.Ws.MustGrad(false).MustMax(false).Float64Values())
log.Info("Bs", "bs", l.Bs.MustGrad(false).MustMax(false).Float64Values())
}

func (l *Linear) ExtractFromVarstore(vs *VarStore) {
l.Ws = vs.GetTensorOfVar(l.weight_name)
l.Bs = vs.GetTensorOfVar(l.bias_name)
@@ -14,4 +14,5 @@ type MyLayer interface {
torch.ModuleT

ExtractFromVarstore(vs *VarStore)
Debug()
}
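
Every layer type now has to satisfy the extended MyLayer interface: the ForwardT method it already needed from torch.ModuleT, plus ExtractFromVarstore and the new Debug. A hypothetical stateless layer satisfying that contract (the type is mine, used only to illustrate; parameterless layers get away with no-ops, as Flatten and Sigmoid do further down):

// Identity is a made-up pass-through layer showing the minimum a MyLayer needs.
type Identity struct{}

// ForwardT comes from torch.ModuleT and just returns the input unchanged.
func (l *Identity) ForwardT(x *torch.Tensor, train bool) *torch.Tensor { return x }

// No weights live in the VarStore, so there is nothing to extract.
func (l *Identity) ExtractFromVarstore(vs *my_nn.VarStore) {}

// Nothing to report during gradient debugging either.
func (l *Identity) Debug() {}
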
@@ -35,12 +35,12 @@ func (n *ContainerModel) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
for i := 0; i < len(n.Layers); i++ {
if i == 0 {
outs[0] = n.Layers[i].ForwardT(x, train)
defer outs[0].MustDrop()
//defer outs[0].MustDrop()
} else if i == len(n.Layers)-1 {
return n.Layers[i].ForwardT(outs[i-1], train)
} else {
outs[i] = n.Layers[i].ForwardT(outs[i-1], train)
defer outs[i].MustDrop()
//defer outs[i].MustDrop()
}
}
panic("Do not reach here")
@@ -107,6 +107,12 @@ func BuildModel(layers []*types.Layer, _lastLinearSize int64, addSigmoid bool) *
return b
}

func (model *ContainerModel) Debug() {
for _, v := range model.Layers {
v.Debug()
}
}

func SaveModel(model *ContainerModel, modelFn string) (err error) {
model.Vs.ToDevice(gotch.CPU)
return model.Vs.Save(modelFn)
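
ContainerModel.Debug just fans out to every layer, which is what the training loop further down relies on to print per-layer gradient maxima. A hypothetical call site, assuming a backward pass has already populated the gradients; the checkpoint path is a placeholder, not something from this commit:

// After backward, dump each layer's max weight/bias gradient, then checkpoint
// the VarStore to disk (SaveModel moves it to the CPU first).
model.Debug()
if err := SaveModel(model, "/tmp/model-checkpoint.bin"); err != nil {
	log.Error("could not save model", "err", err)
}
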
@@ -107,6 +107,7 @@ func NewFlatten() *Flatten {

// The flatten layer does not to move anything to the device
func (b *Flatten) ExtractFromVarstore(vs *my_nn.VarStore) {}
func (b *Flatten) Debug() {}

// Forward method
func (b *Flatten) Forward(x *torch.Tensor) *torch.Tensor {
@@ -133,6 +134,7 @@ func NewSigmoid() *Sigmoid {

// The sigmoid layer does not need to move anything to another device
func (b *Sigmoid) ExtractFromVarstore(vs *my_nn.VarStore) {}
func (b *Sigmoid) Debug() {}

func (b *Sigmoid) Forward(x *torch.Tensor) *torch.Tensor {
out, err := x.Sigmoid(false)
@@ -146,16 +146,6 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
return
}

/* opt1, err := my_nn.DefaultAdamConfig().Build(model.Vs, 0.001)
if err != nil {
return
}

opt1.Debug() */

//log.Info("\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n")

// TODO remove this
model.To(device)
defer model.To(gotch.CPU)
@@ -192,7 +182,7 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
}

data := item.Data
data, err = data.ToDevice(device, gotch.Float, false, true, false)
data, err = data.ToDevice(device, gotch.Float, true, true, false)
if err != nil {
return
}
@@ -206,13 +196,39 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
return
}

pred := model.ForwardT(data, true)
var size []int64
size, err = data.Size()
if err != nil {
return
}

var ones *torch.Tensor
ones, err = torch.Ones(size, gotch.Float, device)
if err != nil {
return
}

ones, err = ones.SetRequiresGrad(true, true)
if err != nil {
return
}

err = ones.RetainGrad(false)
if err != nil {
return
}

//pred := model.ForwardT(data, true)
pred := model.ForwardT(ones, true)
pred, err = pred.SetRequiresGrad(true, true)
if err != nil {
return
}

pred.RetainGrad(false)
err = pred.RetainGrad(false)
if err != nil {
return
}

label := item.Label
label, err = label.ToDevice(device, gotch.Float, false, true, false)
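
Most of the lines added in this hunk are gradient-debugging scaffolding: the batch is swapped for an all-ones tensor and every intermediate tensor is marked so its gradient survives the backward pass and can be read back later. The distilled pattern, using only the two calls seen above (the helper name is mine; the boolean flags are copied from the hunk and assumed to keep their meaning):

// retainForDebug flags a tensor as requiring grad and asks autograd to keep
// that grad around after backward, so MustGrad can read it in the log lines below.
func retainForDebug(t *torch.Tensor) (*torch.Tensor, error) {
	t, err := t.SetRequiresGrad(true, true)
	if err != nil {
		return nil, err
	}
	if err := t.RetainGrad(false); err != nil {
		return nil, err
	}
	return t, nil
}
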
@@ -223,10 +239,13 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
if err != nil {
return
}
label.RetainGrad(false)
err = label.RetainGrad(false)
if err != nil {
return
}

// Calculate loss
loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 1, false)
loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 2, false)
if err != nil {
return
}
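
The substantive change in this hunk is the fourth argument to BinaryCrossEntropyWithLogits, which moves from 1 to 2. In upstream libtorch that position is the reduction enum (0 = none, 1 = mean, 2 = sum); assuming this gotch fork keeps the same numbering, the loss switches from a per-batch mean to a sum, which also rescales the gradients the rest of the commit is trying to inspect:

// Fragment only: old mean-reduced call vs new sum-reduced call, with the
// weight and posWeight arguments left as empty tensors exactly as above.
lossMean, err := pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 1, false) // reduction = mean
lossSum, err := pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 2, false)  // reduction = sum
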
@@ -234,6 +253,11 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
if err != nil {
return
}
err = loss.RetainGrad(false)
if err != nil {
return
}

err = opt.ZeroGrad()
if err != nil {
@@ -245,33 +269,26 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
return
}

err = opt.Step()
if err != nil {
return
}

log.Info("pred grad", "pred", pred.MustGrad(false).MustMax(false).Float64Values() )
log.Info("pred grad", "ones", ones.MustGrad(false).MustMax(false).Float64Values(), "lol", ones.MustRetainsGrad(false) )
log.Info("pred grad", "data", data.MustGrad(false).MustMax(false).Float64Values(), "lol", data.MustRetainsGrad(false) )
log.Info("pred grad", "outs", label.MustGrad(false).MustMax(false).Float64Values() )

vars := model.Vs.Variables()

for k, v := range vars {
var grad *torch.Tensor
grad, err = v.Grad(false)
if err != nil {
return
log.Info("[grad check]", "k", k, "grad", v.MustGrad(false).MustMax(false).Float64Values(), "lol", v.MustRetainsGrad(false) )
}

grad, err = grad.Abs(false)
model.Debug()

err = opt.Step()
if err != nil {
return
}

grad, err = grad.Max(false)
if err != nil {
return
}

log.Info("[grad check]", "k", k, "grad", grad.Float64Values())
}

trainLoss = loss.Float64Values()[0]

// Calculate accuracy
@@ -295,7 +312,7 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor
}
} */

// panic("fornow")
panic("fornow")
}

//v := []float64{}
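
Read together with the removed lines, the per-variable check in the hunks above boils down to: after opt.Step(), walk everything in the VarStore and log the largest gradient each variable received. A compact, error-checked version of that loop, using only calls that appear in this diff (it logs and skips a variable instead of returning, which is my choice, not the repo's):

// Log the maximum absolute gradient of every variable the optimizer updates.
for k, v := range model.Vs.Variables() {
	grad, err := v.Grad(false)
	if err != nil {
		log.Error("variable has no gradient", "k", k, "err", err)
		continue
	}
	if grad, err = grad.Abs(false); err != nil {
		continue
	}
	if grad, err = grad.Max(false); err != nil {
		continue
	}
	log.Info("[grad check]", "k", k, "grad", grad.Float64Values())
}
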
test.go
@@ -11,7 +11,7 @@ import (
"github.com/charmbracelet/log"
)

func main_() {
func _main() {

log.Info("Hello world")
@@ -27,6 +27,10 @@ func main_() {
LayerType: dbtypes.LAYER_DENSE,
Shape: "[ 10 ]",
},
&dbtypes.Layer{
LayerType: dbtypes.LAYER_DENSE,
Shape: "[ 10 ]",
},
}, 0, true)

var err error