More work done on torch
This commit is contained in:
@@ -1,10 +1,14 @@
|
||||
package train
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
|
||||
"github.com/sugarme/gotch/nn"
|
||||
torch "github.com/sugarme/gotch/ts"
|
||||
"git.andr3h3nriqu3s.com/andr3/gotch/nn"
|
||||
torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"
|
||||
)
|
||||
|
||||
func or_panic(err error) {
|
||||
@@ -19,7 +23,9 @@ type SimpleBlock struct {
|
||||
}
|
||||
|
||||
// BasicBlock returns a BasicBlockModule instance
|
||||
func NewSimpleBlock(vs *nn.Path, inplanes int64) *SimpleBlock {
|
||||
func NewSimpleBlock(_vs *my_nn.Path, inplanes int64) *SimpleBlock {
|
||||
vs := (*nn.Path)(unsafe.Pointer(_vs))
|
||||
|
||||
conf1 := nn.DefaultConv2DConfig()
|
||||
conf1.Stride = []int64{2, 2}
|
||||
|
||||
@@ -85,40 +91,11 @@ func (b *SimpleBlock) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
|
||||
return out
|
||||
}
|
||||
|
||||
type MyLinear struct {
|
||||
FC1 *nn.Linear
|
||||
}
|
||||
|
||||
// BasicBlock returns a BasicBlockModule instance
|
||||
func NewLinear(vs *nn.Path, in, out int64) *MyLinear {
|
||||
config := nn.DefaultLinearConfig()
|
||||
b := &MyLinear{
|
||||
FC1: nn.NewLinear(vs, in, out, config),
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Forward method
|
||||
func (b *MyLinear) Forward(x *torch.Tensor) *torch.Tensor {
|
||||
var err error
|
||||
|
||||
out := b.FC1.Forward(x)
|
||||
|
||||
out, err = out.Relu(false)
|
||||
or_panic(err)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (b *MyLinear) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
|
||||
var err error
|
||||
|
||||
out := b.FC1.ForwardT(x, train)
|
||||
|
||||
out, err = out.Relu(false)
|
||||
or_panic(err)
|
||||
|
||||
return out
|
||||
func NewLinear(vs *my_nn.Path, in, out int64) *my_nn.Linear {
|
||||
config := my_nn.DefaultLinearConfig()
|
||||
return my_nn.NewLinear(vs, in, out, config)
|
||||
}
|
||||
|
||||
type Flatten struct{}
|
||||
@@ -128,6 +105,9 @@ func NewFlatten() *Flatten {
|
||||
return &Flatten{}
|
||||
}
|
||||
|
||||
// The flatten layer does not to move anything to the device
|
||||
func (b *Flatten) ExtractFromVarstore(vs *my_nn.VarStore) {}
|
||||
|
||||
// Forward method
|
||||
func (b *Flatten) Forward(x *torch.Tensor) *torch.Tensor {
|
||||
|
||||
@@ -151,6 +131,9 @@ func NewSigmoid() *Sigmoid {
|
||||
return &Sigmoid{}
|
||||
}
|
||||
|
||||
// The sigmoid layer does not need to move anything to another device
|
||||
func (b *Sigmoid) ExtractFromVarstore(vs *my_nn.VarStore) {}
|
||||
|
||||
func (b *Sigmoid) Forward(x *torch.Tensor) *torch.Tensor {
|
||||
out, err := x.Sigmoid(false)
|
||||
or_panic(err)
|
||||
|
||||
Reference in New Issue
Block a user