fixed Linear.Forward with bias is nil and clean up

This commit is contained in:
sugarme 2023-08-12 22:09:34 +10:00
parent b3d821d34e
commit 1cffab577c
3 changed files with 9 additions and 3 deletions

View File

@ -245,7 +245,7 @@ func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device, dtype
retVal.Uniform_(-bound, bound) retVal.Uniform_(-bound, bound)
*/ */
// For now, just make a random norm // NOTE. For now, just make a random norm
retVal = ts.MustRandn(dims, dtype, device) retVal = ts.MustRandn(dims, dtype, device)
return retVal return retVal

View File

@ -100,7 +100,11 @@ func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
// 1 1 1 ] // 1 1 1 ]
func (l *Linear) Forward(xs *ts.Tensor) (retVal *ts.Tensor) { func (l *Linear) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
mul := xs.MustMatmul(l.Ws, false) mul := xs.MustMatmul(l.Ws, false)
return mul.MustAdd(l.Bs, true) if l.Bs != nil {
return mul.MustAdd(l.Bs, true)
} else {
return mul
}
} }
// ForwardT implements ModuleT interface for Linear layer. // ForwardT implements ModuleT interface for Linear layer.

View File

@ -129,7 +129,9 @@ func freeCTensor(ts *Tensor) error {
// Just return if it has been deleted previously! // Just return if it has been deleted previously!
if unsafe.Pointer(ts.ctensor) == nil { if unsafe.Pointer(ts.ctensor) == nil {
log.Printf("INFO: ctensor is nil. Nothing to delete here...\n") if gotch.Debug {
log.Printf("INFO: ctensor is nil. Nothing to delete here...\n")
}
return nil return nil
} }