package train import ( "github.com/charmbracelet/log" "github.com/sugarme/gotch/nn" torch "github.com/sugarme/gotch/ts" ) func or_panic(err error) { if err != nil { log.Fatal(err) } } type SimpleBlock struct { C1, C2 *nn.Conv2D BN1 *nn.BatchNorm } // BasicBlock returns a BasicBlockModule instance func NewSimpleBlock(vs *nn.Path, inplanes int64) *SimpleBlock { conf1 := nn.DefaultConv2DConfig() conf1.Stride = []int64{2, 2} conf2 := nn.DefaultConv2DConfig() conf2.Padding = []int64{2, 2} b := &SimpleBlock{ C1: nn.NewConv2D(vs, inplanes, 128, 3, conf1), C2: nn.NewConv2D(vs, 128, 128, 3, conf2), BN1: nn.NewBatchNorm(vs, 2, 128, nn.DefaultBatchNormConfig()), } return b } // Forward method func (b *SimpleBlock) Forward(x *torch.Tensor) *torch.Tensor { identity := x out := b.C1.Forward(x) out = out.MustRelu(false) out = b.C2.Forward(out) out = out.MustRelu(false) shape, err := out.Size() or_panic(err) out, err = out.AdaptiveAvgPool2d(shape, false) or_panic(err) out = b.BN1.Forward(out) out, err = out.LeakyRelu(false) or_panic(err) out = out.MustAdd(identity, false) out = out.MustRelu(false) return out } func (b *SimpleBlock) ForwardT(x *torch.Tensor, train bool) *torch.Tensor { identity := x out := b.C1.ForwardT(x, train) out = out.MustRelu(false) out = b.C2.ForwardT(out, train) out = out.MustRelu(false) shape, err := out.Size() or_panic(err) out, err = out.AdaptiveAvgPool2d(shape, false) or_panic(err) out = b.BN1.ForwardT(out, train) out, err = out.LeakyRelu(false) or_panic(err) out = out.MustAdd(identity, false) out = out.MustRelu(false) return out } type MyLinear struct { FC1 *nn.Linear } // BasicBlock returns a BasicBlockModule instance func NewLinear(vs *nn.Path, in, out int64) *MyLinear { config := nn.DefaultLinearConfig() b := &MyLinear{ FC1: nn.NewLinear(vs, in, out, config), } return b } // Forward method func (b *MyLinear) Forward(x *torch.Tensor) *torch.Tensor { var err error out := b.FC1.Forward(x) out, err = out.Relu(false) or_panic(err) return out } func (b *MyLinear) ForwardT(x *torch.Tensor, train bool) *torch.Tensor { var err error out := b.FC1.ForwardT(x, train) out, err = out.Relu(false) or_panic(err) return out } type Flatten struct{} // BasicBlock returns a BasicBlockModule instance func NewFlatten() *Flatten { return &Flatten{} } // Forward method func (b *Flatten) Forward(x *torch.Tensor) *torch.Tensor { out, err := x.Flatten(1, -1, false) or_panic(err) return out } func (b *Flatten) ForwardT(x *torch.Tensor, train bool) *torch.Tensor { out, err := x.Flatten(1, -1, false) or_panic(err) return out } type Sigmoid struct{} func NewSigmoid() *Sigmoid { return &Sigmoid{} } func (b *Sigmoid) Forward(x *torch.Tensor) *torch.Tensor { out, err := x.Sigmoid(false) or_panic(err) return out } func (b *Sigmoid) ForwardT(x *torch.Tensor, train bool) *torch.Tensor { out, err := x.Sigmoid(false) or_panic(err) return out }