2024-04-19 15:39:51 +01:00
|
|
|
package train
|
|
|
|
|
|
|
|
import (
|
2024-04-22 00:09:07 +01:00
|
|
|
"unsafe"
|
|
|
|
|
|
|
|
my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn"
|
|
|
|
|
2024-04-19 15:39:51 +01:00
|
|
|
"github.com/charmbracelet/log"
|
|
|
|
|
2024-04-22 00:09:07 +01:00
|
|
|
"git.andr3h3nriqu3s.com/andr3/gotch/nn"
|
|
|
|
torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"
|
2024-04-19 15:39:51 +01:00
|
|
|
)
|
|
|
|
|
|
|
|
// or_panic is the package's error sink: a nil err is ignored, anything else
// is handed to log.Fatal. Despite its name it does not call panic itself.
func or_panic(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
|
|
|
|
|
|
|
|
type SimpleBlock struct {
|
|
|
|
C1, C2 *nn.Conv2D
|
|
|
|
BN1 *nn.BatchNorm
|
|
|
|
}
|
|
|
|
|
|
|
|
// BasicBlock returns a BasicBlockModule instance
|
2024-04-22 00:09:07 +01:00
|
|
|
func NewSimpleBlock(_vs *my_nn.Path, inplanes int64) *SimpleBlock {
|
|
|
|
vs := (*nn.Path)(unsafe.Pointer(_vs))
|
|
|
|
|
2024-04-19 15:39:51 +01:00
|
|
|
conf1 := nn.DefaultConv2DConfig()
|
|
|
|
conf1.Stride = []int64{2, 2}
|
|
|
|
|
|
|
|
conf2 := nn.DefaultConv2DConfig()
|
|
|
|
conf2.Padding = []int64{2, 2}
|
|
|
|
|
|
|
|
b := &SimpleBlock{
|
|
|
|
C1: nn.NewConv2D(vs, inplanes, 128, 3, conf1),
|
|
|
|
C2: nn.NewConv2D(vs, 128, 128, 3, conf2),
|
|
|
|
BN1: nn.NewBatchNorm(vs, 2, 128, nn.DefaultBatchNormConfig()),
|
|
|
|
}
|
|
|
|
return b
|
|
|
|
}
|
|
|
|
|
|
|
|
// Forward method
|
|
|
|
func (b *SimpleBlock) Forward(x *torch.Tensor) *torch.Tensor {
|
|
|
|
identity := x
|
|
|
|
|
|
|
|
out := b.C1.Forward(x)
|
|
|
|
out = out.MustRelu(false)
|
|
|
|
|
|
|
|
out = b.C2.Forward(out)
|
|
|
|
out = out.MustRelu(false)
|
|
|
|
|
|
|
|
shape, err := out.Size()
|
|
|
|
or_panic(err)
|
|
|
|
|
|
|
|
out, err = out.AdaptiveAvgPool2d(shape, false)
|
|
|
|
or_panic(err)
|
|
|
|
|
|
|
|
out = b.BN1.Forward(out)
|
|
|
|
out, err = out.LeakyRelu(false)
|
|
|
|
or_panic(err)
|
|
|
|
|
|
|
|
out = out.MustAdd(identity, false)
|
|
|
|
out = out.MustRelu(false)
|
|
|
|
|
|
|
|
return out
|
|
|
|
}
|
|
|
|
|
|
|
|
func (b *SimpleBlock) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
|
|
|
|
identity := x
|
|
|
|
|
|
|
|
out := b.C1.ForwardT(x, train)
|
|
|
|
out = out.MustRelu(false)
|
|
|
|
|
|
|
|
out = b.C2.ForwardT(out, train)
|
|
|
|
out = out.MustRelu(false)
|
|
|
|
|
|
|
|
shape, err := out.Size()
|
|
|
|
or_panic(err)
|
|
|
|
|
|
|
|
out, err = out.AdaptiveAvgPool2d(shape, false)
|
|
|
|
or_panic(err)
|
|
|
|
|
|
|
|
out = b.BN1.ForwardT(out, train)
|
|
|
|
out, err = out.LeakyRelu(false)
|
|
|
|
or_panic(err)
|
|
|
|
|
|
|
|
out = out.MustAdd(identity, false)
|
|
|
|
out = out.MustRelu(false)
|
|
|
|
|
|
|
|
return out
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// BasicBlock returns a BasicBlockModule instance
|
2024-04-22 00:09:07 +01:00
|
|
|
func NewLinear(vs *my_nn.Path, in, out int64) *my_nn.Linear {
|
|
|
|
config := my_nn.DefaultLinearConfig()
|
|
|
|
return my_nn.NewLinear(vs, in, out, config)
|
2024-04-19 15:39:51 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
// Flatten is a stateless layer: it has no fields and therefore no variables
// of its own.
type Flatten struct{}

// NewFlatten returns a pointer to a new Flatten layer.
func NewFlatten() *Flatten {
	return new(Flatten)
}
|
|
|
|
|
2024-04-22 00:09:07 +01:00
|
|
|
// The flatten layer does not to move anything to the device
|
|
|
|
func (b *Flatten) ExtractFromVarstore(vs *my_nn.VarStore) {}
|
2024-04-23 00:14:35 +01:00
|
|
|
// Debug is a no-op for the flatten layer.
func (b *Flatten) Debug() {
	// Intentionally empty: Flatten has no fields to report.
}
|
2024-04-22 00:09:07 +01:00
|
|
|
|
2024-04-19 15:39:51 +01:00
|
|
|
// Forward method
|
|
|
|
func (b *Flatten) Forward(x *torch.Tensor) *torch.Tensor {
|
|
|
|
|
|
|
|
out, err := x.Flatten(1, -1, false)
|
|
|
|
or_panic(err)
|
|
|
|
|
|
|
|
return out
|
|
|
|
}
|
|
|
|
|
|
|
|
func (b *Flatten) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
|
|
|
|
|
|
|
|
out, err := x.Flatten(1, -1, false)
|
|
|
|
or_panic(err)
|
|
|
|
|
|
|
|
return out
|
|
|
|
}
|
|
|
|
|
|
|
|
// Sigmoid is a stateless layer: it has no fields and therefore no variables
// of its own.
type Sigmoid struct{}

// NewSigmoid returns a pointer to a new Sigmoid layer.
func NewSigmoid() *Sigmoid {
	return new(Sigmoid)
}
|
|
|
|
|
2024-04-22 00:09:07 +01:00
|
|
|
// The sigmoid layer does not need to move anything to another device
|
|
|
|
func (b *Sigmoid) ExtractFromVarstore(vs *my_nn.VarStore) {}
|
2024-04-23 00:14:35 +01:00
|
|
|
// Debug is a no-op for the sigmoid layer.
func (b *Sigmoid) Debug() {
	// Intentionally empty: Sigmoid has no fields to report.
}
|
2024-04-22 00:09:07 +01:00
|
|
|
|
2024-04-19 15:39:51 +01:00
|
|
|
func (b *Sigmoid) Forward(x *torch.Tensor) *torch.Tensor {
|
|
|
|
out, err := x.Sigmoid(false)
|
|
|
|
or_panic(err)
|
|
|
|
|
|
|
|
return out
|
|
|
|
}
|
|
|
|
|
|
|
|
func (b *Sigmoid) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
|
|
|
|
out, err := x.Sigmoid(false)
|
|
|
|
or_panic(err)
|
|
|
|
|
|
|
|
return out
|
|
|
|
}
|
|
|
|
|