feat(sysinfo): added sysinfo helper

This commit is contained in:
sugarme 2020-07-08 20:03:34 +10:00
parent 8107d429b6
commit 787734f624
3 changed files with 167 additions and 17 deletions

View File

@ -36,7 +36,6 @@ func convBn(p nn.Path, cIn, cOut int64) (retVal nn.SequentialT) {
func layer(p nn.Path, cIn, cOut int64) (retVal nn.FuncT) {
pre := convBn(p.Sub("pre"), cIn, cOut)
block1 := convBn(p.Sub("b1"), cOut, cOut)
block2 := convBn(p.Sub("b2"), cOut, cOut)
@ -101,9 +100,16 @@ func main() {
fmt.Printf("TestLabel shape: %v\n", ds.TestLabels.MustSize())
fmt.Printf("Number of labels: %v\n", ds.Labels)
cuda := gotch.CudaBuilder(0)
device := cuda.CudaIfAvailable()
// device := gotch.CPU
var si *gotch.SI
si = gotch.GetSysInfo()
fmt.Printf("Total RAM (MB):\t %8.2f\n", float64(si.TotalRam)/1024)
fmt.Printf("Used RAM (MB):\t %8.2f\n", float64(si.TotalRam-si.FreeRam)/1024)
startRAM := si.TotalRam - si.FreeRam
// cuda := gotch.CudaBuilder(0)
// device := cuda.CudaIfAvailable()
device := gotch.CPU
vs := nn.NewVarStore(device)
@ -122,8 +128,8 @@ func main() {
opt.SetLR(learningRate(epoch))
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(64))
iter.Shuffle()
iter = iter.ToDevice(device)
// iter.Shuffle()
// iter = iter.ToDevice(device)
for {
item, ok := iter.Next()
@ -134,26 +140,25 @@ func main() {
// bimages := vision.Augmentation(item.Data, true, 4, 8)
// logits := net.ForwardT(bimages, true)
bImages := item.Data.MustTo(vs.Device(), true)
bLabels := item.Label.MustTo(vs.Device(), true)
// // logits := net.ForwardT(item.Data, true)
logits := net.ForwardT(bImages, true)
// // loss := logits.CrossEntropyForLogits(item.Label)
loss := logits.CrossEntropyForLogits(bLabels)
logits := net.ForwardT(item.Data, false)
loss := logits.CrossEntropyForLogits(item.Label)
opt.BackwardStep(loss)
lossVal = loss.Values()[0]
// logits.MustDrop()
bImages.MustDrop()
bLabels.MustDrop()
item.Data.MustDrop()
item.Label.MustDrop()
loss.MustDrop()
}
fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, lossVal)
si = gotch.GetSysInfo()
fmt.Printf("Epoch %v\t Used: [%8.2f MiB]\n", epoch, (float64(si.TotalRam-si.FreeRam)-float64(startRAM))/1024)
iter.Drop()
}
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)

126
sysinfo.go Normal file
View File

@ -0,0 +1,126 @@
// A wrapper around the linux syscall sysinfo(2).
package gotch
// helper to debug memory blow-up
import (
"fmt"
"sync"
"syscall"
"time"
)
// SI is a Go-ized version of the Linux sysinfo structure, see
// http://man7.org/linux/man-pages/man2/sysinfo.2.html
//
// GetSysInfo fills it in: Uptime is stored as a time.Duration, the load
// averages as floats, and the memory sizes in kB.
type SI struct {
Uptime time.Duration // time since boot
Loads [3]float64 // 1, 5, and 15 minute load averages, see e.g. UPTIME(1)
Procs uint64 // number of current processes
TotalRam uint64 // total usable main memory size [kB]
FreeRam uint64 // available memory size [kB]
SharedRam uint64 // amount of shared memory [kB]
BufferRam uint64 // memory used by buffers [kB]
TotalSwap uint64 // total swap space size [kB]
FreeSwap uint64 // swap space still available [kB]
TotalHighRam uint64 // total high memory size [kB]
FreeHighRam uint64 // available high memory size [kB]
mu sync.Mutex // guards the fields above; held by GetSysInfo while writing and by ToString while reading
}
// sis is the single package-level SI that GetSysInfo overwrites and returns
// on every call.
var sis = &SI{}
// GetSysInfo reads the Linux sysinfo(2) data and returns it in Go-friendly
// units: uptime as a time.Duration, load averages as floats and every memory
// size in kB.
//
// The returned pointer refers to a package-level SI that is overwritten (while
// holding its mutex) on every call, so values that must survive a later call
// should be copied out by the caller. The function panics if the syscall
// fails.
//
// Useful links in the wild web:
// http://man7.org/linux/man-pages/man2/sysinfo.2.html
// http://man7.org/linux/man-pages/man1/uptime.1.html
// https://github.com/capnm/golang/blob/go1.1.1/src/pkg/syscall/zsyscall_linux_amd64.go#L1050
// https://github.com/capnm/golang/blob/go1.1.1/src/pkg/syscall/ztypes_linux_amd64.go#L528
// https://github.com/capnm/golang/blob/go1.1.1/src/pkg/syscall/ztypes_linux_arm.go#L502
func GetSysInfo() *SI {
	// Layout of the raw struct filled in by the kernel, for reference.
	// Note: uint64 is uint32 on 32 bit CPUs.
	//
	//	type Sysinfo_t struct {
	//		Uptime    int64     // Seconds since boot
	//		Loads     [3]uint64 // 1, 5, and 15 minute load averages
	//		Totalram  uint64    // Total usable main memory size
	//		Freeram   uint64    // Available memory size
	//		Sharedram uint64    // Amount of shared memory
	//		Bufferram uint64    // Memory used by buffers
	//		Totalswap uint64    // Total swap space size
	//		Freeswap  uint64    // swap space still available
	//		Procs     uint16    // Number of current processes
	//		Pad       uint16
	//		Pad_cgo_0 [4]byte
	//		Totalhigh uint64    // Total high memory size
	//		Freehigh  uint64    // Available high memory size
	//		Unit      uint32    // Memory unit size in bytes
	//		X_f       [0]byte
	//		Pad_cgo_1 [4]byte   // Padding to 64 bytes
	//	}

	// The syscall only writes into this local value, so it is performed
	// before taking the lock on the shared result.
	raw := &syscall.Sysinfo_t{}
	if err := syscall.Sysinfo(raw); err != nil {
		panic("Commander, we have a problem. syscall.Sysinfo:" + err.Error())
	}

	// The kernel reports load averages as fixed-point numbers scaled by
	// 1 << 16 (SI_LOAD_SHIFT).
	const loadScale = 65536.0

	// Memory fields are counted in blocks of raw.Unit bytes, so the size in
	// kB is blocks*Unit/1024. Dividing by Unit*1024 instead would only be
	// right when Unit happens to be 1.
	unit := uint64(raw.Unit)
	toKB := func(blocks uint64) uint64 { return blocks * unit / 1024 }

	sis.mu.Lock()
	defer sis.mu.Unlock()

	sis.Uptime = time.Duration(raw.Uptime) * time.Second
	for i := range sis.Loads {
		sis.Loads[i] = float64(raw.Loads[i]) / loadScale
	}
	sis.Procs = uint64(raw.Procs)

	sis.TotalRam = toKB(uint64(raw.Totalram))
	sis.FreeRam = toKB(uint64(raw.Freeram))
	sis.SharedRam = toKB(uint64(raw.Sharedram))
	sis.BufferRam = toKB(uint64(raw.Bufferram))
	sis.TotalSwap = toKB(uint64(raw.Totalswap))
	sis.FreeSwap = toKB(uint64(raw.Freeswap))
	sis.TotalHighRam = toKB(uint64(raw.Totalhigh))
	sis.FreeHighRam = toKB(uint64(raw.Freehigh))

	return sis
}
// String implements fmt.Stringer. It renders uptime, load averages, process
// count and the RAM/swap figures as a tab-separated, multi-line report.
//
// The value receiver means the method formats a copy of the SI. Making that
// copy is not synchronized, so without an outer lock (see ToString) the
// report may contain junk if the SI is being updated at the same time.
func (si SI) String() string {
	const layout = "uptime\t\t%v\n" +
		"load\t\t%2.2f %2.2f %2.2f\n" +
		"procs\t\t%d\n" +
		"ram total\t%d kB\n" +
		"ram free\t%d kB\n" +
		"ram buffer\t%d kB\n" +
		"swap total\t%d kB\n" +
		"swap free\t%d kB"

	// The archaic high-memory figures (TotalHighRam, FreeHighRam) are
	// deliberately left out of the report.
	load1, load5, load15 := si.Loads[0], si.Loads[1], si.Loads[2]

	return fmt.Sprintf(layout,
		si.Uptime,
		load1, load5, load15,
		si.Procs,
		si.TotalRam, si.FreeRam, si.BufferRam,
		si.TotalSwap, si.FreeSwap,
	)
}
// ToString converts the SI to a string in a thread safe way: the SI's mutex
// is held while String formats the fields.
//
// Output:
//
//	uptime		279h6m21s
//	load		0.12 0.04 0.05
//	procs		143
//	ram total	383752 kB
//	ram free	254980 kB
//	ram buffer	7640 kB
//	swap total	887800 kB
//	swap free	879356 kB
func (si *SI) ToString() string {
	si.mu.Lock()
	defer si.mu.Unlock()

	report := si.String()
	return report
}

View File

@ -43,9 +43,23 @@ func NewIter2(xs, ys Tensor, batchSize int64) (retVal Iter2, err error) {
return retVal, err
}
xsClone, err := xs.ZerosLike(false)
if err != nil {
log.Fatal(err)
}
xsClone.Copy_(xs)
ysClone, err := ys.ZerosLike(false)
if err != nil {
log.Fatal(err)
}
ysClone.Copy_(ys)
retVal = Iter2{
xs: xs.MustShallowClone(),
ys: ys.MustShallowClone(),
// xs: xs.MustShallowClone(),
// ys: ys.MustShallowClone(),
xs: xsClone,
ys: ysClone,
batchIndex: 0,
batchSize: batchSize,
totalSize: totalSize,
@ -134,3 +148,8 @@ func (it *Iter2) Next() (item Iter2Item, ok bool) {
}, true
}
}
// Drop calls MustDrop on the iterator's xs and ys fields.
// NOTE(review): presumably this frees the tensor copies held by the iterator
// so they do not accumulate across epochs — confirm against MustDrop.
func (it Iter2) Drop() {
it.xs.MustDrop()
it.ys.MustDrop()
}