package train

import (
	"unsafe"

	my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn"
	"git.andr3h3nriqu3s.com/andr3/gotch/nn"
	torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"
	"github.com/charmbracelet/log"
)

// or_panic aborts through log.Fatal when err is non-nil and does nothing
// otherwise. Despite its name it calls log.Fatal rather than panic.
func or_panic(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}

// SimpleBlock groups two 2D convolutions and one batch-normalisation layer
// that Forward and ForwardT apply together with a skip connection.
type SimpleBlock struct {
	C1  *nn.Conv2D
	C2  *nn.Conv2D
	BN1 *nn.BatchNorm
}

// NewSimpleBlock builds a SimpleBlock whose layers are registered on _vs.
//
// C1 maps inplanes channels to 128 with a 3x3 kernel and stride 2; C2 maps
// 128 channels to 128 with a 3x3 kernel and padding 2; BN1 normalises 128
// features.
func NewSimpleBlock(_vs *my_nn.Path, inplanes int64) *SimpleBlock {
	// The gotch constructors expect a gotch nn.Path, so the project Path is
	// reinterpreted in place.
	// NOTE(review): this assumes my_nn.Path and nn.Path share the same memory
	// layout - confirm against both definitions.
	path := (*nn.Path)(unsafe.Pointer(_vs))

	stridedCfg := nn.DefaultConv2DConfig()
	stridedCfg.Stride = []int64{2, 2}

	paddedCfg := nn.DefaultConv2DConfig()
	paddedCfg.Padding = []int64{2, 2}

	// Layers are created in the order C1, C2, BN1.
	first := nn.NewConv2D(path, inplanes, 128, 3, stridedCfg)
	second := nn.NewConv2D(path, 128, 128, 3, paddedCfg)
	norm := nn.NewBatchNorm(path, 2, 128, nn.DefaultBatchNormConfig())

	return &SimpleBlock{C1: first, C2: second, BN1: norm}
}

// Forward runs x through C1, ReLU, C2, ReLU, an adaptive average pool, BN1
// and LeakyReLU, then adds x back onto the result and applies a final ReLU.
// Any error reported by a tensor operation is passed to or_panic.
//
// NOTE(review): every tensor op is called with false; presumably that is the
// binding's "delete the receiver" flag, in which case intermediates are not
// freed here - confirm.
func (b *SimpleBlock) Forward(x *torch.Tensor) *torch.Tensor {
	hidden := b.C1.Forward(x).MustRelu(false)
	hidden = b.C2.Forward(hidden).MustRelu(false)

	// The pooling target is the tensor's own size.
	// NOTE(review): Size() is passed through unmodified, including any
	// non-spatial dimensions - confirm AdaptiveAvgPool2d accepts that.
	dims, err := hidden.Size()
	or_panic(err)
	pooled, err := hidden.AdaptiveAvgPool2d(dims, false)
	or_panic(err)

	activated, err := b.BN1.Forward(pooled).LeakyRelu(false)
	or_panic(err)

	// Skip connection with the untouched input.
	// NOTE(review): relies on x and the block output having compatible shapes
	// - confirm against callers.
	return activated.MustAdd(x, false).MustRelu(false)
}

// ForwardT performs the same sequence as Forward but calls the ForwardT
// variant of C1, C2 and BN1, handing train to each of them.
func (b *SimpleBlock) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
	hidden := b.C1.ForwardT(x, train).MustRelu(false)
	hidden = b.C2.ForwardT(hidden, train).MustRelu(false)

	// Same pooling call as in Forward: the target is the tensor's own size.
	dims, err := hidden.Size()
	or_panic(err)
	pooled, err := hidden.AdaptiveAvgPool2d(dims, false)
	or_panic(err)

	activated, err := b.BN1.ForwardT(pooled, train).LeakyRelu(false)
	or_panic(err)

	// Skip connection with the untouched input.
	return activated.MustAdd(x, false).MustRelu(false)
}

// NewLinear creates a my_nn.Linear layer from in to out features on vs using
// the default linear configuration.
func NewLinear(vs *my_nn.Path, in, out int64) *my_nn.Linear {
	return my_nn.NewLinear(vs, in, out, my_nn.DefaultLinearConfig())
}

// Flatten is a stateless layer that flattens its input from dimension 1
// through the last dimension.
type Flatten struct{}

// NewFlatten returns a new Flatten layer.
func NewFlatten() *Flatten {
	return &Flatten{}
}

// ExtractFromVarstore is a no-op: the flatten layer does not need to move
// anything to the device.
func (f *Flatten) ExtractFromVarstore(vs *my_nn.VarStore) {}

// Forward returns x flattened from dimension 1 to dimension -1. An error from
// the tensor operation is passed to or_panic.
func (f *Flatten) Forward(x *torch.Tensor) *torch.Tensor {
	flat, err := x.Flatten(1, -1, false)
	or_panic(err)
	return flat
}

// ForwardT is identical to Forward; train is ignored.
func (f *Flatten) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
	return f.Forward(x)
}

// Sigmoid is a stateless layer that applies the sigmoid function.
type Sigmoid struct{}

// NewSigmoid returns a new Sigmoid layer.
func NewSigmoid() *Sigmoid {
	return &Sigmoid{}
}

// ExtractFromVarstore is a no-op: the sigmoid layer does not need to move
// anything to another device.
func (s *Sigmoid) ExtractFromVarstore(vs *my_nn.VarStore) {}

// Forward returns the sigmoid of x. An error from the tensor operation is
// passed to or_panic.
func (s *Sigmoid) Forward(x *torch.Tensor) *torch.Tensor {
	activated, err := x.Sigmoid(false)
	or_panic(err)
	return activated
}

// ForwardT is identical to Forward; train is ignored.
func (s *Sigmoid) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
	return s.Forward(x)
}