More work on tring to make torch work
This commit is contained in:
@@ -35,12 +35,12 @@ func (n *ContainerModel) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
|
||||
for i := 0; i < len(n.Layers); i++ {
|
||||
if i == 0 {
|
||||
outs[0] = n.Layers[i].ForwardT(x, train)
|
||||
defer outs[0].MustDrop()
|
||||
//defer outs[0].MustDrop()
|
||||
} else if i == len(n.Layers)-1 {
|
||||
return n.Layers[i].ForwardT(outs[i-1], train)
|
||||
} else {
|
||||
outs[i] = n.Layers[i].ForwardT(outs[i-1], train)
|
||||
defer outs[i].MustDrop()
|
||||
//defer outs[i].MustDrop()
|
||||
}
|
||||
}
|
||||
panic("Do not reach here")
|
||||
@@ -107,6 +107,12 @@ func BuildModel(layers []*types.Layer, _lastLinearSize int64, addSigmoid bool) *
|
||||
return b
|
||||
}
|
||||
|
||||
func (model *ContainerModel) Debug() {
|
||||
for _, v := range model.Layers {
|
||||
v.Debug()
|
||||
}
|
||||
}
|
||||
|
||||
func SaveModel(model *ContainerModel, modelFn string) (err error) {
|
||||
model.Vs.ToDevice(gotch.CPU)
|
||||
return model.Vs.Save(modelFn)
|
||||
|
||||
Reference in New Issue
Block a user