fyp/logic/models/train/torch/utils.go

168 lines
2.9 KiB
Go

package train
import (
"github.com/charmbracelet/log"
"github.com/sugarme/gotch/nn"
torch "github.com/sugarme/gotch/ts"
)
// or_panic aborts the program if err is non-nil and does nothing otherwise.
//
// Despite its name it does not panic. charmbracelet's log.Fatal logs err and
// then calls os.Exit, so deferred calls in the caller do not run and the
// failure cannot be intercepted with recover.
func or_panic(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
// SimpleBlock is a small residual convolution block. Its forward pass is:
//
//	x -> C1 (3x3, stride 2) -> ReLU -> C2 (3x3, padding 2) -> ReLU
//	  -> adaptive average pool to x's height/width -> BN1 -> LeakyReLU
//	  -> + x -> ReLU
type SimpleBlock struct {
	C1, C2 *nn.Conv2D
	BN1    *nn.BatchNorm
}

// NewSimpleBlock creates a SimpleBlock whose parameters are registered under vs.
//
// inplanes is the channel count of the block input. Both convolutions output
// 128 channels, so the residual addition in Forward/ForwardT only succeeds when
// the input's channel dimension is compatible with 128 (in practice
// inplanes == 128). NOTE(review): confirm callers respect this.
func NewSimpleBlock(vs *nn.Path, inplanes int64) *SimpleBlock {
	downsample := nn.DefaultConv2DConfig()
	downsample.Stride = []int64{2, 2}
	padded := nn.DefaultConv2DConfig()
	padded.Padding = []int64{2, 2}
	return &SimpleBlock{
		C1:  nn.NewConv2D(vs, inplanes, 128, 3, downsample),
		C2:  nn.NewConv2D(vs, 128, 128, 3, padded),
		BN1: nn.NewBatchNorm(vs, 2, 128, nn.DefaultBatchNormConfig()),
	}
}

// poolToInput adaptive-average-pools out to the spatial size (the last two
// dimensions) of input.
//
// C1 (stride 2) and C2 (padding 2) change the height/width of the activations,
// so they must be brought back to the input's spatial size before the residual
// addition. AdaptiveAvgPool2d takes exactly two output sizes (H, W), not the
// full tensor shape.
func (b *SimpleBlock) poolToInput(out, input *torch.Tensor) *torch.Tensor {
	shape, err := input.Size()
	or_panic(err)
	pooled, err := out.AdaptiveAvgPool2d(shape[len(shape)-2:], false)
	or_panic(err)
	return pooled
}

// Forward applies the block to x using each layer's Forward method and adds
// x back as a residual before the final ReLU.
func (b *SimpleBlock) Forward(x *torch.Tensor) *torch.Tensor {
	out := b.C1.Forward(x).MustRelu(false)
	out = b.C2.Forward(out).MustRelu(false)
	out = b.poolToInput(out, x)
	out, err := b.BN1.Forward(out).LeakyRelu(false)
	or_panic(err)
	return out.MustAdd(x, false).MustRelu(false)
}

// ForwardT is Forward with the train flag passed through to every layer
// (C1, C2 and BN1).
func (b *SimpleBlock) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
	out := b.C1.ForwardT(x, train).MustRelu(false)
	out = b.C2.ForwardT(out, train).MustRelu(false)
	out = b.poolToInput(out, x)
	out, err := b.BN1.ForwardT(out, train).LeakyRelu(false)
	or_panic(err)
	return out.MustAdd(x, false).MustRelu(false)
}
// MyLinear is a fully connected layer followed by a ReLU activation.
type MyLinear struct {
	FC1 *nn.Linear
}

// NewLinear creates a MyLinear mapping `in` input features to `out` output
// features. It registers the parameters under vs and uses the default linear
// configuration.
func NewLinear(vs *nn.Path, in, out int64) *MyLinear {
	return &MyLinear{FC1: nn.NewLinear(vs, in, out, nn.DefaultLinearConfig())}
}

// Forward applies the linear layer to x and then ReLU.
func (b *MyLinear) Forward(x *torch.Tensor) *torch.Tensor {
	activated, err := b.FC1.Forward(x).Relu(false)
	or_panic(err)
	return activated
}

// ForwardT is Forward with the train flag passed through to the linear layer.
func (b *MyLinear) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
	activated, err := b.FC1.ForwardT(x, train).Relu(false)
	or_panic(err)
	return activated
}
// Flatten is a parameter-free layer that merges every dimension from 1 through
// the last into a single dimension and keeps dimension 0 unchanged. For
// example, an (N, C, H, W) input becomes (N, C*H*W).
type Flatten struct{}

// NewFlatten returns a Flatten layer.
func NewFlatten() *Flatten {
	return new(Flatten)
}

// Forward flattens x from dimension 1 through its last dimension.
func (b *Flatten) Forward(x *torch.Tensor) *torch.Tensor {
	flat, err := x.Flatten(1, -1, false)
	or_panic(err)
	return flat
}

// ForwardT ignores train, because flattening is the same in both modes, and
// delegates to Forward.
func (b *Flatten) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
	return b.Forward(x)
}
// Sigmoid is a parameter-free layer that applies the element-wise logistic
// sigmoid to its input.
type Sigmoid struct{}

// NewSigmoid returns a Sigmoid layer.
func NewSigmoid() *Sigmoid {
	return new(Sigmoid)
}

// Forward returns sigmoid(x) as a new tensor and leaves x untouched.
func (b *Sigmoid) Forward(x *torch.Tensor) *torch.Tensor {
	activated, err := x.Sigmoid(false)
	or_panic(err)
	return activated
}

// ForwardT ignores train, because the activation is the same in both modes,
// and delegates to Forward.
func (b *Sigmoid) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
	return b.Forward(x)
}