More work on tring to make torch work

This commit is contained in:
2024-04-23 00:14:35 +01:00
parent 703fea46f2
commit a4a9ade71f
7 changed files with 109 additions and 43 deletions

View File

@@ -7,6 +7,7 @@ import (
"git.andr3h3nriqu3s.com/andr3/gotch/nn"
"git.andr3h3nriqu3s.com/andr3/gotch/ts"
"github.com/charmbracelet/log"
)
// LinearConfig is a configuration for a linear layer
@@ -104,6 +105,11 @@ func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
}
}
func (l *Linear) Debug() {
log.Info("Ws", "ws", l.Ws.MustGrad(false).MustMax(false).Float64Values())
log.Info("Bs", "bs", l.Bs.MustGrad(false).MustMax(false).Float64Values())
}
func (l *Linear) ExtractFromVarstore(vs *VarStore) {
l.Ws = vs.GetTensorOfVar(l.weight_name)
l.Bs = vs.GetTensorOfVar(l.bias_name)

View File

@@ -14,4 +14,5 @@ type MyLayer interface {
torch.ModuleT
ExtractFromVarstore(vs *VarStore)
Debug()
}