remove linear.bias init when not required
This commit is contained in:
parent
e9278816b2
commit
ef00723027
10
nn/linear.go
10
nn/linear.go
|
@ -6,7 +6,6 @@ import (
|
|||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
|
@ -40,15 +39,8 @@ type Linear struct {
|
|||
// outDim - output dimension (y) [output features - columns]
|
||||
// NOTE: w will have shape{outDim, inDim}; b will have shape{outDim}
|
||||
func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
|
||||
dtype := gotch.DefaultDType
|
||||
var bs *ts.Tensor
|
||||
// bs has size of output dimension
|
||||
switch c.Bias {
|
||||
case false:
|
||||
// FIXME. do we need this? or just remove it and in the `Forward` creating on-fly
|
||||
// with same dtype and device to the input.
|
||||
bs = ts.MustZeros([]int64{outDim}, dtype, vs.Device())
|
||||
case true:
|
||||
if c.Bias {
|
||||
switch {
|
||||
case c.BsInit == nil:
|
||||
shape := []int64{inDim, outDim}
|
||||
|
|
Loading…
Reference in New Issue
Block a user