added nn.Path.Paths() and nn.VarStore.Summary() methods

This commit is contained in:
sugarme 2021-11-13 22:23:38 +11:00
parent 6221c32cea
commit 653caf4be5
2 changed files with 30 additions and 0 deletions

View File

@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
- added `nn.Path.Paths()` method
- added `nn.VarStore.Summary()` method
## [Nofix]
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.

View File

@ -4,6 +4,7 @@ import (
"fmt"
"log"
"reflect"
"sort"
"strings"
"sync"
@ -322,6 +323,28 @@ func (vs *VarStore) Copy(src VarStore) error {
return nil
}
// Summary prints a simple list of all named variables with their shapes.
func (vs *VarStore) Summary() {
namedTensors := vs.Vars.NamedVariables
layers := make([]string, 0, len(namedTensors))
for tsName := range namedTensors {
layers = append(layers, tsName)
}
sort.Strings(layers)
for _, l := range layers {
var x *ts.Tensor
for tsName, tsVal := range namedTensors {
if tsName == l {
x = tsVal
break
}
}
fmt.Printf("%s - %+v\n", l, x.MustSize())
}
fmt.Printf("Num of layers: %v\n", len(namedTensors))
}
// Path methods:
// =============
@ -341,6 +364,11 @@ func (p *Path) Sub(str string) *Path {
}
}
// Paths returns all sub paths from current path.
func (p *Path) Paths() []string {
return p.path
}
// Device gets the device where the var-store variables are stored.
func (p *Path) Device() gotch.Device {