fyp/logic/models/train/torch/utils.go

153 lines
3.0 KiB
Go
Raw Normal View History

2024-04-19 15:39:51 +01:00
package train
import (
2024-04-22 00:09:07 +01:00
"unsafe"
my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn"
2024-04-19 15:39:51 +01:00
"github.com/charmbracelet/log"
2024-04-22 00:09:07 +01:00
"git.andr3h3nriqu3s.com/andr3/gotch/nn"
torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"
2024-04-19 15:39:51 +01:00
)
// or_panic reports a non-nil err through log.Fatal and does nothing otherwise.
//
// NOTE(review): despite its name this never calls panic. log.Fatal is expected
// to terminate the process (confirm for charmbracelet/log). If so, deferred
// cleanup does not run and callers cannot recover from the failure.
func or_panic(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
// SimpleBlock holds the layers of a small convolutional block that adds its
// input back onto its output (see Forward / ForwardT).
type SimpleBlock struct {
	// C1 is the first convolution (built with stride 2 by NewSimpleBlock).
	C1 *nn.Conv2D
	// C2 is the second convolution (built with padding 2 by NewSimpleBlock).
	C2 *nn.Conv2D
	// BN1 is the batch-norm layer applied after pooling.
	BN1 *nn.BatchNorm
}
// BasicBlock returns a BasicBlockModule instance
2024-04-22 00:09:07 +01:00
func NewSimpleBlock(_vs *my_nn.Path, inplanes int64) *SimpleBlock {
vs := (*nn.Path)(unsafe.Pointer(_vs))
2024-04-19 15:39:51 +01:00
conf1 := nn.DefaultConv2DConfig()
conf1.Stride = []int64{2, 2}
conf2 := nn.DefaultConv2DConfig()
conf2.Padding = []int64{2, 2}
b := &SimpleBlock{
C1: nn.NewConv2D(vs, inplanes, 128, 3, conf1),
C2: nn.NewConv2D(vs, 128, 128, 3, conf2),
BN1: nn.NewBatchNorm(vs, 2, 128, nn.DefaultBatchNormConfig()),
}
return b
}
// Forward runs x through the block.
//
// Pipeline: C1 -> ReLU -> C2 -> ReLU -> adaptive average pool -> BN1 ->
// LeakyReLU. The original input x is then added back and a final ReLU is
// applied. Errors from the fallible tensor ops are handed to or_panic.
func (b *SimpleBlock) Forward(x *torch.Tensor) *torch.Tensor {
	identity := x
	out := b.C1.Forward(x)
	out = out.MustRelu(false)
	out = b.C2.Forward(out)
	out = out.MustRelu(false)

	shape, err := out.Size()
	or_panic(err)
	// AdaptiveAvgPool2d takes only the target spatial size (H, W). Passing the
	// full (N, C, H, W) size is rejected by libtorch, so keep the last two dims.
	out, err = out.AdaptiveAvgPool2d(shape[len(shape)-2:], false)
	or_panic(err)

	out = b.BN1.Forward(out)
	out, err = out.LeakyRelu(false)
	or_panic(err)

	// NOTE(review): this add requires identity to be broadcast-compatible with
	// out (128 channels, post-convolution spatial size). C1 uses stride 2, so
	// that generally does not hold. Confirm whether the pool above should
	// target identity's spatial size instead.
	out = out.MustAdd(identity, false)
	out = out.MustRelu(false)
	return out
}
// ForwardT is the train-aware variant of Forward. It has the same pipeline,
// with train forwarded to C1, C2 and BN1.
//
// Errors from the fallible tensor ops are handed to or_panic.
func (b *SimpleBlock) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
	identity := x
	out := b.C1.ForwardT(x, train)
	out = out.MustRelu(false)
	out = b.C2.ForwardT(out, train)
	out = out.MustRelu(false)

	shape, err := out.Size()
	or_panic(err)
	// AdaptiveAvgPool2d takes only the target spatial size (H, W). Passing the
	// full (N, C, H, W) size is rejected by libtorch, so keep the last two dims.
	out, err = out.AdaptiveAvgPool2d(shape[len(shape)-2:], false)
	or_panic(err)

	out = b.BN1.ForwardT(out, train)
	out, err = out.LeakyRelu(false)
	or_panic(err)

	// NOTE(review): this add requires identity to be broadcast-compatible with
	// out (128 channels, post-convolution spatial size). C1 uses stride 2, so
	// that generally does not hold. Confirm whether the pool above should
	// target identity's spatial size instead.
	out = out.MustAdd(identity, false)
	out = out.MustRelu(false)
	return out
}
// BasicBlock returns a BasicBlockModule instance
2024-04-22 00:09:07 +01:00
func NewLinear(vs *my_nn.Path, in, out int64) *my_nn.Linear {
config := my_nn.DefaultLinearConfig()
return my_nn.NewLinear(vs, in, out, config)
2024-04-19 15:39:51 +01:00
}
type Flatten struct{}
// BasicBlock returns a BasicBlockModule instance
func NewFlatten() *Flatten {
return &Flatten{}
}
2024-04-22 00:09:07 +01:00
// The flatten layer does not to move anything to the device
func (b *Flatten) ExtractFromVarstore(vs *my_nn.VarStore) {}
2024-04-23 00:14:35 +01:00
func (b *Flatten) Debug() {}
2024-04-22 00:09:07 +01:00
2024-04-19 15:39:51 +01:00
// Forward method
func (b *Flatten) Forward(x *torch.Tensor) *torch.Tensor {
out, err := x.Flatten(1, -1, false)
or_panic(err)
return out
}
func (b *Flatten) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
out, err := x.Flatten(1, -1, false)
or_panic(err)
return out
}
type Sigmoid struct{}
func NewSigmoid() *Sigmoid {
return &Sigmoid{}
}
2024-04-22 00:09:07 +01:00
// The sigmoid layer does not need to move anything to another device
func (b *Sigmoid) ExtractFromVarstore(vs *my_nn.VarStore) {}
2024-04-23 00:14:35 +01:00
func (b *Sigmoid) Debug() {}
2024-04-22 00:09:07 +01:00
2024-04-19 15:39:51 +01:00
func (b *Sigmoid) Forward(x *torch.Tensor) *torch.Tensor {
out, err := x.Sigmoid(false)
or_panic(err)
return out
}
func (b *Sigmoid) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
out, err := x.Sigmoid(false)
or_panic(err)
return out
}