added nn.Path.Paths() and nn.VarStore.Summary() methods
This commit is contained in:
parent
6221c32cea
commit
653caf4be5
|
@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
|||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
- added `nn.Path.Paths()` method
|
||||
- added `nn.VarStore.Summary()` method
|
||||
|
||||
## [Nofix]
|
||||
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
|
@ -322,6 +323,28 @@ func (vs *VarStore) Copy(src VarStore) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Summary prints a simple list of all named variables with their shapes.
|
||||
func (vs *VarStore) Summary() {
|
||||
namedTensors := vs.Vars.NamedVariables
|
||||
layers := make([]string, 0, len(namedTensors))
|
||||
for tsName := range namedTensors {
|
||||
layers = append(layers, tsName)
|
||||
}
|
||||
sort.Strings(layers)
|
||||
for _, l := range layers {
|
||||
var x *ts.Tensor
|
||||
for tsName, tsVal := range namedTensors {
|
||||
if tsName == l {
|
||||
x = tsVal
|
||||
break
|
||||
}
|
||||
}
|
||||
fmt.Printf("%s - %+v\n", l, x.MustSize())
|
||||
}
|
||||
|
||||
fmt.Printf("Num of layers: %v\n", len(namedTensors))
|
||||
}
|
||||
|
||||
// Path methods:
|
||||
// =============
|
||||
|
||||
|
@ -341,6 +364,11 @@ func (p *Path) Sub(str string) *Path {
|
|||
}
|
||||
}
|
||||
|
||||
// Paths returns all sub paths from current path.
|
||||
func (p *Path) Paths() []string {
|
||||
return p.path
|
||||
}
|
||||
|
||||
// Device gets the device where the var-store variables are stored.
|
||||
func (p *Path) Device() gotch.Device {
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user