Started working on moving to torch
This commit is contained in:
167
logic/models/train/torch/utils.go
Normal file
167
logic/models/train/torch/utils.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package train
|
||||
|
||||
import (
|
||||
"github.com/charmbracelet/log"
|
||||
|
||||
"github.com/sugarme/gotch/nn"
|
||||
torch "github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
func or_panic(err error) {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
type SimpleBlock struct {
|
||||
C1, C2 *nn.Conv2D
|
||||
BN1 *nn.BatchNorm
|
||||
}
|
||||
|
||||
// BasicBlock returns a BasicBlockModule instance
|
||||
func NewSimpleBlock(vs *nn.Path, inplanes int64) *SimpleBlock {
|
||||
conf1 := nn.DefaultConv2DConfig()
|
||||
conf1.Stride = []int64{2, 2}
|
||||
|
||||
conf2 := nn.DefaultConv2DConfig()
|
||||
conf2.Padding = []int64{2, 2}
|
||||
|
||||
b := &SimpleBlock{
|
||||
C1: nn.NewConv2D(vs, inplanes, 128, 3, conf1),
|
||||
C2: nn.NewConv2D(vs, 128, 128, 3, conf2),
|
||||
BN1: nn.NewBatchNorm(vs, 2, 128, nn.DefaultBatchNormConfig()),
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Forward method
|
||||
func (b *SimpleBlock) Forward(x *torch.Tensor) *torch.Tensor {
|
||||
identity := x
|
||||
|
||||
out := b.C1.Forward(x)
|
||||
out = out.MustRelu(false)
|
||||
|
||||
out = b.C2.Forward(out)
|
||||
out = out.MustRelu(false)
|
||||
|
||||
shape, err := out.Size()
|
||||
or_panic(err)
|
||||
|
||||
out, err = out.AdaptiveAvgPool2d(shape, false)
|
||||
or_panic(err)
|
||||
|
||||
out = b.BN1.Forward(out)
|
||||
out, err = out.LeakyRelu(false)
|
||||
or_panic(err)
|
||||
|
||||
out = out.MustAdd(identity, false)
|
||||
out = out.MustRelu(false)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (b *SimpleBlock) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
|
||||
identity := x
|
||||
|
||||
out := b.C1.ForwardT(x, train)
|
||||
out = out.MustRelu(false)
|
||||
|
||||
out = b.C2.ForwardT(out, train)
|
||||
out = out.MustRelu(false)
|
||||
|
||||
shape, err := out.Size()
|
||||
or_panic(err)
|
||||
|
||||
out, err = out.AdaptiveAvgPool2d(shape, false)
|
||||
or_panic(err)
|
||||
|
||||
out = b.BN1.ForwardT(out, train)
|
||||
out, err = out.LeakyRelu(false)
|
||||
or_panic(err)
|
||||
|
||||
out = out.MustAdd(identity, false)
|
||||
out = out.MustRelu(false)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
type MyLinear struct {
|
||||
FC1 *nn.Linear
|
||||
}
|
||||
|
||||
// BasicBlock returns a BasicBlockModule instance
|
||||
func NewLinear(vs *nn.Path, in, out int64) *MyLinear {
|
||||
config := nn.DefaultLinearConfig()
|
||||
b := &MyLinear{
|
||||
FC1: nn.NewLinear(vs, in, out, config),
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Forward method
|
||||
func (b *MyLinear) Forward(x *torch.Tensor) *torch.Tensor {
|
||||
var err error
|
||||
|
||||
out := b.FC1.Forward(x)
|
||||
|
||||
out, err = out.Relu(false)
|
||||
or_panic(err)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (b *MyLinear) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
|
||||
var err error
|
||||
|
||||
out := b.FC1.ForwardT(x, train)
|
||||
|
||||
out, err = out.Relu(false)
|
||||
or_panic(err)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
type Flatten struct{}
|
||||
|
||||
// BasicBlock returns a BasicBlockModule instance
|
||||
func NewFlatten() *Flatten {
|
||||
return &Flatten{}
|
||||
}
|
||||
|
||||
// Forward method
|
||||
func (b *Flatten) Forward(x *torch.Tensor) *torch.Tensor {
|
||||
|
||||
out, err := x.Flatten(1, -1, false)
|
||||
or_panic(err)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (b *Flatten) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
|
||||
|
||||
out, err := x.Flatten(1, -1, false)
|
||||
or_panic(err)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
type Sigmoid struct{}
|
||||
|
||||
func NewSigmoid() *Sigmoid {
|
||||
return &Sigmoid{}
|
||||
}
|
||||
|
||||
func (b *Sigmoid) Forward(x *torch.Tensor) *torch.Tensor {
|
||||
out, err := x.Sigmoid(false)
|
||||
or_panic(err)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (b *Sigmoid) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
|
||||
out, err := x.Sigmoid(false)
|
||||
or_panic(err)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user