More work done on torch

This commit is contained in:
2024-04-22 00:09:07 +01:00
parent 28707b3f1b
commit 703fea46f2
13 changed files with 2435 additions and 96 deletions

View File

@@ -1,10 +1,14 @@
package train
import (
"unsafe"
my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn"
"github.com/charmbracelet/log"
"github.com/sugarme/gotch/nn"
torch "github.com/sugarme/gotch/ts"
"git.andr3h3nriqu3s.com/andr3/gotch/nn"
torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"
)
func or_panic(err error) {
@@ -19,7 +23,9 @@ type SimpleBlock struct {
}
// BasicBlock returns a BasicBlockModule instance
func NewSimpleBlock(vs *nn.Path, inplanes int64) *SimpleBlock {
func NewSimpleBlock(_vs *my_nn.Path, inplanes int64) *SimpleBlock {
vs := (*nn.Path)(unsafe.Pointer(_vs))
conf1 := nn.DefaultConv2DConfig()
conf1.Stride = []int64{2, 2}
@@ -85,40 +91,11 @@ func (b *SimpleBlock) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
return out
}
// MyLinear is a thin wrapper around a single fully connected layer.
// Its Forward/ForwardT methods apply the layer followed by a ReLU.
type MyLinear struct {
	// FC1 is the underlying fully connected (linear) layer.
	FC1 *nn.Linear
}
// NewLinear returns a MyLinear whose FC1 is a fully connected layer
// of shape in→out, registered under vs with the default linear config.
// (The previous comment mentioning BasicBlock was a copy-paste error.)
func NewLinear(vs *nn.Path, in, out int64) *MyLinear {
	config := nn.DefaultLinearConfig()
	b := &MyLinear{
		FC1: nn.NewLinear(vs, in, out, config),
	}
	return b
}
// Forward runs the input through the fully connected layer and then a
// ReLU activation, panicking (via or_panic) if the activation fails.
func (b *MyLinear) Forward(x *torch.Tensor) *torch.Tensor {
	activated, err := b.FC1.Forward(x).Relu(false)
	or_panic(err)
	return activated
}
func (b *MyLinear) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
var err error
out := b.FC1.ForwardT(x, train)
out, err = out.Relu(false)
or_panic(err)
return out
// NewLinear builds a my_nn.Linear of shape in→out under vs,
// using the package's default linear configuration.
func NewLinear(vs *my_nn.Path, in, out int64) *my_nn.Linear {
	return my_nn.NewLinear(vs, in, out, my_nn.DefaultLinearConfig())
}
type Flatten struct{}
@@ -128,6 +105,9 @@ func NewFlatten() *Flatten {
return &Flatten{}
}
// ExtractFromVarstore is a no-op: the flatten layer holds no learnable
// parameters, so there is nothing to move out of the variable store.
func (b *Flatten) ExtractFromVarstore(vs *my_nn.VarStore) {}
// Forward method
func (b *Flatten) Forward(x *torch.Tensor) *torch.Tensor {
@@ -151,6 +131,9 @@ func NewSigmoid() *Sigmoid {
return &Sigmoid{}
}
// ExtractFromVarstore is a no-op: the sigmoid layer holds no learnable
// parameters, so there is nothing to move out of the variable store.
func (b *Sigmoid) ExtractFromVarstore(vs *my_nn.VarStore) {}
func (b *Sigmoid) Forward(x *torch.Tensor) *torch.Tensor {
out, err := x.Sigmoid(false)
or_panic(err)