175 lines
4.2 KiB
Go
175 lines
4.2 KiB
Go
package my_nn
|
|
|
|
// Linear is a fully-connected layer.
|
|
|
|
import (
|
|
"math"
|
|
|
|
"git.andr3h3nriqu3s.com/andr3/gotch/nn"
|
|
"git.andr3h3nriqu3s.com/andr3/gotch/ts"
|
|
"github.com/charmbracelet/log"
|
|
)
|
|
|
|
// LinearConfig is a configuration for a linear layer
|
|
type LinearConfig struct {
|
|
WsInit nn.Init // iniital weights
|
|
BsInit nn.Init // optional initial bias
|
|
Bias bool
|
|
}
|
|
|
|
// DefaultLinearConfig creates default LinearConfig with
|
|
// weights initiated using KaimingUniform and Bias is set to true
|
|
func DefaultLinearConfig() *LinearConfig {
|
|
negSlope := math.Sqrt(5)
|
|
return &LinearConfig{
|
|
// NOTE. KaimingUniform cause mem leak due to ts.Uniform()!!!
|
|
// Avoid using it now.
|
|
WsInit: nn.NewKaimingUniformInit(nn.WithKaimingNegativeSlope(negSlope)),
|
|
BsInit: nil,
|
|
Bias: true,
|
|
}
|
|
}
|
|
|
|
// Linear is a linear fully-connected layer
|
|
type Linear struct {
|
|
Ws *ts.Tensor
|
|
weight_name string
|
|
Bs *ts.Tensor
|
|
bias_name string
|
|
}
|
|
|
|
// NewLinear creates a new linear layer
|
|
// y = x*wT + b
|
|
// inDim - input dimension (x) [input features - columns]
|
|
// outDim - output dimension (y) [output features - columns]
|
|
// NOTE: w will have shape{outDim, inDim}; b will have shape{outDim}
|
|
func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
|
|
var bias_name string
|
|
var bs *ts.Tensor
|
|
var err error
|
|
if c.Bias {
|
|
switch {
|
|
case c.BsInit == nil:
|
|
shape := []int64{inDim, outDim}
|
|
fanIn, _, err := nn.CalculateFans(shape)
|
|
or_panic(err)
|
|
bound := 0.0
|
|
if fanIn > 0 {
|
|
bound = 1 / math.Sqrt(float64(fanIn))
|
|
}
|
|
bsInit := nn.NewUniformInit(-bound, bound)
|
|
bs, bias_name, err = vs.NewVarNamed("bias", []int64{outDim}, bsInit)
|
|
or_panic(err)
|
|
|
|
// Find better way to do this
|
|
bs, err = bs.T(true)
|
|
or_panic(err)
|
|
bs, err = bs.T(true)
|
|
or_panic(err)
|
|
|
|
bs, err = bs.SetRequiresGrad(true, true)
|
|
or_panic(err)
|
|
|
|
err = bs.RetainGrad(false)
|
|
or_panic(err)
|
|
|
|
vs.varstore.UpdateVarTensor(bias_name, bs, true)
|
|
|
|
case c.BsInit != nil:
|
|
bs, bias_name, err = vs.NewVarNamed("bias", []int64{outDim}, c.BsInit)
|
|
or_panic(err)
|
|
}
|
|
}
|
|
|
|
ws, weight_name, err := vs.NewVarNamed("weight", []int64{outDim, inDim}, c.WsInit)
|
|
or_panic(err)
|
|
|
|
ws, err = ws.T(true)
|
|
or_panic(err)
|
|
|
|
ws, err = ws.SetRequiresGrad(true, true)
|
|
or_panic(err)
|
|
|
|
err = ws.RetainGrad(false)
|
|
or_panic(err)
|
|
|
|
|
|
vs.varstore.UpdateVarTensor(weight_name, ws, true)
|
|
|
|
|
|
return &Linear{
|
|
Ws: ws,
|
|
weight_name: weight_name,
|
|
Bs: bs,
|
|
bias_name: bias_name,
|
|
}
|
|
}
|
|
|
|
func (l *Linear) Debug() {
|
|
log.Info("Ws", "ws", l.Ws.MustGrad(false).MustMax(false).Float64Values())
|
|
log.Info("Bs", "bs", l.Bs.MustGrad(false).MustMax(false).Float64Values())
|
|
}
|
|
|
|
func (l *Linear) ExtractFromVarstore(vs *VarStore) {
|
|
l.Ws = vs.GetTensorOfVar(l.weight_name)
|
|
l.Bs = vs.GetTensorOfVar(l.bias_name)
|
|
}
|
|
|
|
// Implement `Module` for `Linear` struct:
|
|
// =======================================
|
|
|
|
// Forward proceeds input node through linear layer.
|
|
// NOTE:
|
|
// - It assumes that node has dimensions of 2 (matrix).
|
|
// To make it work for matrix multiplication, input node should
|
|
// has same number of **column** as number of **column** in
|
|
// `LinearLayer` `Ws` property as weights matrix will be
|
|
// transposed before multiplied to input node. (They are all used `inDim`)
|
|
// - Input node should have shape of `shape{batch size, input features}`.
|
|
// (shape{batchSize, inDim}). The input features is `inDim` while the
|
|
// output feature is `outDim` in `LinearConfig` struct.
|
|
//
|
|
// Example:
|
|
//
|
|
// inDim := 3
|
|
// outDim := 2
|
|
// batchSize := 4
|
|
// weights: 2x3
|
|
// [ 1 1 1
|
|
// 1 1 1 ]
|
|
//
|
|
// input node: 3x4
|
|
// [ 1 1 1
|
|
// 1 1 1
|
|
// 1 1 1
|
|
// 1 1 1 ]
|
|
func (l *Linear) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
|
|
mul, err := xs.Matmul(l.Ws, false)
|
|
or_panic(err)
|
|
if l.Bs != nil {
|
|
mul, err = mul.Add(l.Bs, false)
|
|
or_panic(err)
|
|
}
|
|
|
|
out, err := mul.Relu(false)
|
|
or_panic(err)
|
|
|
|
return out
|
|
}
|
|
|
|
// ForwardT implements ModuleT interface for Linear layer.
|
|
//
|
|
// NOTE: train param will not be used.
|
|
func (l *Linear) ForwardT(xs *ts.Tensor, train bool) (retVal *ts.Tensor) {
|
|
mul, err := xs.Matmul(l.Ws, true)
|
|
or_panic(err)
|
|
|
|
|
|
mul, err = mul.Add(l.Bs, true)
|
|
or_panic(err)
|
|
|
|
out, err := mul.Relu(true)
|
|
or_panic(err)
|
|
return out
|
|
}
|