More work on tring to make torch work

This commit is contained in:
2024-04-23 00:14:35 +01:00
parent 703fea46f2
commit a4a9ade71f
7 changed files with 109 additions and 43 deletions

View File

@@ -35,12 +35,12 @@ func (n *ContainerModel) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
for i := 0; i < len(n.Layers); i++ {
if i == 0 {
outs[0] = n.Layers[i].ForwardT(x, train)
defer outs[0].MustDrop()
//defer outs[0].MustDrop()
} else if i == len(n.Layers)-1 {
return n.Layers[i].ForwardT(outs[i-1], train)
} else {
outs[i] = n.Layers[i].ForwardT(outs[i-1], train)
defer outs[i].MustDrop()
//defer outs[i].MustDrop()
}
}
panic("Do not reach here")
@@ -107,6 +107,12 @@ func BuildModel(layers []*types.Layer, _lastLinearSize int64, addSigmoid bool) *
return b
}
func (model *ContainerModel) Debug() {
for _, v := range model.Layers {
v.Debug()
}
}
func SaveModel(model *ContainerModel, modelFn string) (err error) {
model.Vs.ToDevice(gotch.CPU)
return model.Vs.Save(modelFn)