More work done on torch
This commit is contained in:
168
logic/models/train/torch/nn/linear.go
Normal file
168
logic/models/train/torch/nn/linear.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package my_nn
|
||||
|
||||
// linear is a fully-connected layer
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"git.andr3h3nriqu3s.com/andr3/gotch/nn"
|
||||
"git.andr3h3nriqu3s.com/andr3/gotch/ts"
|
||||
)
|
||||
|
||||
// LinearConfig is a configuration for a linear layer
|
||||
type LinearConfig struct {
|
||||
WsInit nn.Init // iniital weights
|
||||
BsInit nn.Init // optional initial bias
|
||||
Bias bool
|
||||
}
|
||||
|
||||
// DefaultLinearConfig creates default LinearConfig with
|
||||
// weights initiated using KaimingUniform and Bias is set to true
|
||||
func DefaultLinearConfig() *LinearConfig {
|
||||
negSlope := math.Sqrt(5)
|
||||
return &LinearConfig{
|
||||
// NOTE. KaimingUniform cause mem leak due to ts.Uniform()!!!
|
||||
// Avoid using it now.
|
||||
WsInit: nn.NewKaimingUniformInit(nn.WithKaimingNegativeSlope(negSlope)),
|
||||
BsInit: nil,
|
||||
Bias: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Linear is a linear fully-connected layer
|
||||
type Linear struct {
|
||||
Ws *ts.Tensor
|
||||
weight_name string
|
||||
Bs *ts.Tensor
|
||||
bias_name string
|
||||
}
|
||||
|
||||
// NewLinear creates a new linear layer
|
||||
// y = x*wT + b
|
||||
// inDim - input dimension (x) [input features - columns]
|
||||
// outDim - output dimension (y) [output features - columns]
|
||||
// NOTE: w will have shape{outDim, inDim}; b will have shape{outDim}
|
||||
func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
|
||||
var bias_name string
|
||||
var bs *ts.Tensor
|
||||
var err error
|
||||
if c.Bias {
|
||||
switch {
|
||||
case c.BsInit == nil:
|
||||
shape := []int64{inDim, outDim}
|
||||
fanIn, _, err := nn.CalculateFans(shape)
|
||||
or_panic(err)
|
||||
bound := 0.0
|
||||
if fanIn > 0 {
|
||||
bound = 1 / math.Sqrt(float64(fanIn))
|
||||
}
|
||||
bsInit := nn.NewUniformInit(-bound, bound)
|
||||
bs, bias_name, err = vs.NewVarNamed("bias", []int64{outDim}, bsInit)
|
||||
or_panic(err)
|
||||
|
||||
// Find better way to do this
|
||||
bs, err = bs.T(true)
|
||||
or_panic(err)
|
||||
bs, err = bs.T(true)
|
||||
or_panic(err)
|
||||
|
||||
bs, err = bs.SetRequiresGrad(true, true)
|
||||
or_panic(err)
|
||||
|
||||
err = bs.RetainGrad(false)
|
||||
or_panic(err)
|
||||
|
||||
vs.varstore.UpdateVarTensor(bias_name, bs, true)
|
||||
|
||||
case c.BsInit != nil:
|
||||
bs, bias_name, err = vs.NewVarNamed("bias", []int64{outDim}, c.BsInit)
|
||||
or_panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
ws, weight_name, err := vs.NewVarNamed("weight", []int64{outDim, inDim}, c.WsInit)
|
||||
or_panic(err)
|
||||
|
||||
ws, err = ws.T(true)
|
||||
or_panic(err)
|
||||
|
||||
ws, err = ws.SetRequiresGrad(true, true)
|
||||
or_panic(err)
|
||||
|
||||
err = ws.RetainGrad(false)
|
||||
or_panic(err)
|
||||
|
||||
|
||||
vs.varstore.UpdateVarTensor(weight_name, ws, true)
|
||||
|
||||
|
||||
return &Linear{
|
||||
Ws: ws,
|
||||
weight_name: weight_name,
|
||||
Bs: bs,
|
||||
bias_name: bias_name,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Linear) ExtractFromVarstore(vs *VarStore) {
|
||||
l.Ws = vs.GetTensorOfVar(l.weight_name)
|
||||
l.Bs = vs.GetTensorOfVar(l.bias_name)
|
||||
}
|
||||
|
||||
// Implement `Module` for `Linear` struct:
|
||||
// =======================================
|
||||
|
||||
// Forward proceeds input node through linear layer.
|
||||
// NOTE:
|
||||
// - It assumes that node has dimensions of 2 (matrix).
|
||||
// To make it work for matrix multiplication, input node should
|
||||
// has same number of **column** as number of **column** in
|
||||
// `LinearLayer` `Ws` property as weights matrix will be
|
||||
// transposed before multiplied to input node. (They are all used `inDim`)
|
||||
// - Input node should have shape of `shape{batch size, input features}`.
|
||||
// (shape{batchSize, inDim}). The input features is `inDim` while the
|
||||
// output feature is `outDim` in `LinearConfig` struct.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// inDim := 3
|
||||
// outDim := 2
|
||||
// batchSize := 4
|
||||
// weights: 2x3
|
||||
// [ 1 1 1
|
||||
// 1 1 1 ]
|
||||
//
|
||||
// input node: 3x4
|
||||
// [ 1 1 1
|
||||
// 1 1 1
|
||||
// 1 1 1
|
||||
// 1 1 1 ]
|
||||
func (l *Linear) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
|
||||
mul, err := xs.Matmul(l.Ws, false)
|
||||
or_panic(err)
|
||||
if l.Bs != nil {
|
||||
mul, err = mul.Add(l.Bs, false)
|
||||
or_panic(err)
|
||||
}
|
||||
|
||||
out, err := mul.Relu(false)
|
||||
or_panic(err)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// ForwardT implements ModuleT interface for Linear layer.
|
||||
//
|
||||
// NOTE: train param will not be used.
|
||||
func (l *Linear) ForwardT(xs *ts.Tensor, train bool) (retVal *ts.Tensor) {
|
||||
mul, err := xs.Matmul(l.Ws, true)
|
||||
or_panic(err)
|
||||
|
||||
|
||||
mul, err = mul.Add(l.Bs, true)
|
||||
or_panic(err)
|
||||
|
||||
out, err := mul.Relu(true)
|
||||
or_panic(err)
|
||||
return out
|
||||
}
|
||||
Reference in New Issue
Block a user