reworked ts.Format()
This commit is contained in:
parent
9121872ceb
commit
ce421dd4c5
|
@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
- Fixed incorrect indexing at `dutil/Dataset.Next()`
|
- Fixed incorrect indexing at `dutil/Dataset.Next()`
|
||||||
- Added `nn.MSELoss()`
|
- Added `nn.MSELoss()`
|
||||||
|
- reworked `ts.Format()`
|
||||||
|
|
||||||
## [Nofix]
|
## [Nofix]
|
||||||
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.
|
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.
|
||||||
|
|
744
ts/print.go
744
ts/print.go
|
@ -6,14 +6,10 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
"github.com/sugarme/gotch"
|
||||||
)
|
)
|
||||||
|
|
||||||
var fmtByte = []byte("%")
|
|
||||||
var precByte = []byte(".")
|
|
||||||
|
|
||||||
func (ts *Tensor) ValueGo() interface{} {
|
func (ts *Tensor) ValueGo() interface{} {
|
||||||
dtype := ts.DType()
|
dtype := ts.DType()
|
||||||
numel := ts.Numel()
|
numel := ts.Numel()
|
||||||
|
@ -51,288 +47,6 @@ func (ts *Tensor) ValueGo() interface{} {
|
||||||
|
|
||||||
return dst
|
return dst
|
||||||
}
|
}
|
||||||
func (ts *Tensor) ToSlice() reflect.Value {
|
|
||||||
// Create a 1-dimensional slice of the base large enough for the data and
|
|
||||||
// copy the data in.
|
|
||||||
shape := ts.MustSize()
|
|
||||||
dt := ts.DType()
|
|
||||||
n := int(numElements(shape))
|
|
||||||
var (
|
|
||||||
slice reflect.Value
|
|
||||||
typ reflect.Type
|
|
||||||
)
|
|
||||||
if dt.String() == "String" {
|
|
||||||
panic("Unsupported 'String' type")
|
|
||||||
} else {
|
|
||||||
gtyp, err := gotch.ToGoType(dt)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
typ = reflect.SliceOf(gtyp)
|
|
||||||
slice = reflect.MakeSlice(typ, n, n)
|
|
||||||
data := ts.ValueGo()
|
|
||||||
slice = reflect.ValueOf(data)
|
|
||||||
}
|
|
||||||
// Now we have the data in place in the base slice we can add the
|
|
||||||
// dimensions. We want to walk backwards through the shape. If the shape is
|
|
||||||
// length 1 or 0 then we're already done.
|
|
||||||
if len(shape) == 0 {
|
|
||||||
return slice.Index(0)
|
|
||||||
}
|
|
||||||
if len(shape) == 1 {
|
|
||||||
return slice
|
|
||||||
}
|
|
||||||
// We have a special case if the tensor has no data. Our backing slice is
|
|
||||||
// empty, but we still want to create slices following the shape. In this
|
|
||||||
// case only the final part of the shape will be 0 and we want to recalculate
|
|
||||||
// n at this point ignoring that 0.
|
|
||||||
// For example if our shape is 3 * 2 * 0 then n will be zero, but we still
|
|
||||||
// want 6 zero length slices to group as follows.
|
|
||||||
// {{} {}} {{} {}} {{} {}}
|
|
||||||
if n == 0 {
|
|
||||||
n = int(numElements(shape[:len(shape)-1]))
|
|
||||||
}
|
|
||||||
for i := len(shape) - 2; i >= 0; i-- {
|
|
||||||
underlyingSize := typ.Elem().Size()
|
|
||||||
typ = reflect.SliceOf(typ)
|
|
||||||
subsliceLen := int(shape[i+1])
|
|
||||||
if subsliceLen != 0 {
|
|
||||||
n = n / subsliceLen
|
|
||||||
}
|
|
||||||
// Just using reflection it is difficult to avoid unnecessary
|
|
||||||
// allocations while setting up the sub-slices as the Slice function on
|
|
||||||
// a slice Value allocates. So we end up doing pointer arithmetic!
|
|
||||||
// Pointer() on a slice gives us access to the data backing the slice.
|
|
||||||
// We insert slice headers directly into this data.
|
|
||||||
data := unsafe.Pointer(slice.Pointer())
|
|
||||||
nextSlice := reflect.MakeSlice(typ, n, n)
|
|
||||||
for j := 0; j < n; j++ {
|
|
||||||
// This is equivalent to nSlice[j] = slice[j*subsliceLen: (j+1)*subsliceLen]
|
|
||||||
setSliceInSlice(nextSlice, j, sliceHeader{
|
|
||||||
Data: unsafe.Pointer(uintptr(data) + (uintptr(j*subsliceLen) * underlyingSize)),
|
|
||||||
Len: subsliceLen,
|
|
||||||
Cap: subsliceLen,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("nextSlice length: %v\n", nextSlice.Len())
|
|
||||||
fmt.Printf("%v\n\n", nextSlice)
|
|
||||||
|
|
||||||
slice = nextSlice
|
|
||||||
}
|
|
||||||
return slice
|
|
||||||
}
|
|
||||||
|
|
||||||
// setSliceInSlice sets slice[index] = content.
|
|
||||||
func setSliceInSlice(slice reflect.Value, index int, content sliceHeader) {
|
|
||||||
const sliceSize = unsafe.Sizeof(sliceHeader{})
|
|
||||||
// We must cast slice.Pointer to uninptr & back again to avoid GC issues.
|
|
||||||
// See https://github.com/google/go-cmp/issues/167#issuecomment-546093202
|
|
||||||
*(*sliceHeader)(unsafe.Pointer(uintptr(unsafe.Pointer(slice.Pointer())) + (uintptr(index) * sliceSize))) = content
|
|
||||||
}
|
|
||||||
func numElements(shape []int64) int64 {
|
|
||||||
n := int64(1)
|
|
||||||
for _, d := range shape {
|
|
||||||
n *= d
|
|
||||||
}
|
|
||||||
return n
|
|
||||||
}
|
|
||||||
|
|
||||||
// It isn't safe to use reflect.SliceHeader as it uses a uintptr for Data and
|
|
||||||
// this is not inspected by the garbage collector
|
|
||||||
type sliceHeader struct {
|
|
||||||
Data unsafe.Pointer
|
|
||||||
Len int
|
|
||||||
Cap int
|
|
||||||
}
|
|
||||||
|
|
||||||
// Format implements fmt.Formatter interface so that we can use
|
|
||||||
// fmt.Print... and verbs to print out Tensor value in different formats.
|
|
||||||
func (ts *Tensor) Format(s fmt.State, c rune) {
|
|
||||||
shape := ts.MustSize()
|
|
||||||
device := ts.MustDevice()
|
|
||||||
dtype := ts.DType()
|
|
||||||
if c == 'i' {
|
|
||||||
fmt.Fprintf(s, "\nTENSOR INFO:\n\tShape:\t\t%v\n\tDType:\t\t%v\n\tDevice:\t\t%v\n\tDefined:\t%v\n", shape, dtype, device, ts.MustDefined())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
data := ts.ValueGo()
|
|
||||||
|
|
||||||
f := newFmtState(s, c, shape)
|
|
||||||
f.setWidth(data)
|
|
||||||
f.makePad()
|
|
||||||
|
|
||||||
// 0d (scalar)
|
|
||||||
if len(shape) == 0 {
|
|
||||||
fmt.Printf("%v", data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 1d (slice)
|
|
||||||
if len(shape) == 1 {
|
|
||||||
f.writeSlice(data)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2d (matrix)
|
|
||||||
if len(shape) == 2 {
|
|
||||||
f.writeMatrix(data, shape)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// >= 3d (tensor)
|
|
||||||
mSize := int(shape[len(shape)-2] * shape[len(shape)-1])
|
|
||||||
mShape := shape[len(shape)-2:]
|
|
||||||
dims := shape[:len(shape)-2]
|
|
||||||
var rdims []int64
|
|
||||||
for d := len(dims) - 1; d >= 0; d-- {
|
|
||||||
rdims = append(rdims, dims[d])
|
|
||||||
}
|
|
||||||
|
|
||||||
f.writeTensor(0, rdims, 0, mSize, mShape, data, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
// fmtState is a struct that implements fmt.State interface
|
|
||||||
type fmtState struct {
|
|
||||||
fmt.State
|
|
||||||
c rune // format verb
|
|
||||||
pad []byte // padding
|
|
||||||
w int // width
|
|
||||||
p int // precision
|
|
||||||
shape []int64
|
|
||||||
buf *bytes.Buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
func newFmtState(s fmt.State, c rune, shape []int64) *fmtState {
|
|
||||||
w, _ := s.Width()
|
|
||||||
p, _ := s.Precision()
|
|
||||||
|
|
||||||
return &fmtState{
|
|
||||||
State: s,
|
|
||||||
c: c,
|
|
||||||
w: w,
|
|
||||||
p: p,
|
|
||||||
shape: shape,
|
|
||||||
buf: bytes.NewBuffer(make([]byte, 0)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// writeTensor iterates recursively through a reversed shape of tensor starting from axis 3 and
|
|
||||||
// and prints out matrices (of last two dims size).
|
|
||||||
func (f *fmtState) writeTensor(d int, dims []int64, offset int, mSize int, mShape []int64, data interface{}, mName string) {
|
|
||||||
|
|
||||||
offsetSize := product(dims[:d]) * mSize
|
|
||||||
for i := 0; i < int(dims[d]); i++ {
|
|
||||||
name := fmt.Sprintf("%v,%v", i+1, mName)
|
|
||||||
if d >= len(dims)-1 { // last dim, let's print out
|
|
||||||
// write matrix name
|
|
||||||
nameStr := fmt.Sprintf("(%v.,.) =\n", name)
|
|
||||||
f.Write([]byte(nameStr))
|
|
||||||
// write matrix values
|
|
||||||
slice := reflect.ValueOf(data).Slice(offset, offset+mSize).Interface()
|
|
||||||
f.writeMatrix(slice, mShape)
|
|
||||||
} else { // recursively loop
|
|
||||||
f.writeTensor(d+1, dims, offset, mSize, mShape, data, name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// update offset
|
|
||||||
offset += offsetSize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func product(dims []int64) int {
|
|
||||||
var p int = 1
|
|
||||||
if len(dims) == 0 {
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, d := range dims {
|
|
||||||
p = p * int(d)
|
|
||||||
}
|
|
||||||
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fmtState) writeMatrix(data interface{}, shape []int64) {
|
|
||||||
n := shapeToSize(shape)
|
|
||||||
dataLen := reflect.ValueOf(data).Len()
|
|
||||||
if dataLen != n {
|
|
||||||
log.Fatalf("mismatched: slice data has %v elements - shape: %v\n", dataLen, n)
|
|
||||||
}
|
|
||||||
if len(shape) != 2 {
|
|
||||||
log.Fatal("Shape must have length of 2.\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
stride := int(shape[1])
|
|
||||||
currIdx := 0
|
|
||||||
nextIdx := stride
|
|
||||||
for i := 0; i < int(shape[0]); i++ {
|
|
||||||
slice := reflect.ValueOf(data).Slice(currIdx, nextIdx)
|
|
||||||
f.writeSlice(slice.Interface())
|
|
||||||
currIdx = nextIdx
|
|
||||||
nextIdx += stride
|
|
||||||
}
|
|
||||||
|
|
||||||
// 1 line between matrix
|
|
||||||
f.Write([]byte("\n"))
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fmtState) writeSlice(data interface{}) {
|
|
||||||
format := f.cleanFmt()
|
|
||||||
dataLen := reflect.ValueOf(data).Len()
|
|
||||||
for i := 0; i < dataLen; i++ {
|
|
||||||
el := reflect.ValueOf(data).Index(i).Interface()
|
|
||||||
|
|
||||||
// TODO: more format options here
|
|
||||||
w, _ := fmt.Fprintf(f.buf, format, el)
|
|
||||||
f.Write(f.buf.Bytes())
|
|
||||||
f.Write(f.pad[:f.w-w]) // prepad
|
|
||||||
f.Write(f.pad[:2]) // pad
|
|
||||||
f.buf.Reset()
|
|
||||||
}
|
|
||||||
|
|
||||||
f.Write([]byte("\n"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fmtState) cleanFmt() string {
|
|
||||||
buf := bytes.NewBuffer(fmtByte)
|
|
||||||
|
|
||||||
// width
|
|
||||||
if w, ok := f.Width(); ok {
|
|
||||||
buf.WriteString(strconv.Itoa(w))
|
|
||||||
}
|
|
||||||
|
|
||||||
// precision
|
|
||||||
if p, ok := f.Precision(); ok {
|
|
||||||
buf.Write(precByte)
|
|
||||||
buf.WriteString(strconv.Itoa(p))
|
|
||||||
}
|
|
||||||
|
|
||||||
buf.WriteRune(f.c)
|
|
||||||
return buf.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fmtState) makePad() {
|
|
||||||
f.pad = make([]byte, maxInt(f.w, 4))
|
|
||||||
for i := range f.pad {
|
|
||||||
f.pad[i] = ' '
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// setWidth determines maximal width from input data and set to `w` field
|
|
||||||
func (f *fmtState) setWidth(data interface{}) {
|
|
||||||
format := f.cleanFmt()
|
|
||||||
f.w = 0
|
|
||||||
for i := 0; i < reflect.ValueOf(data).Len(); i++ {
|
|
||||||
el := reflect.ValueOf(data).Index(i).Interface()
|
|
||||||
w, _ := fmt.Fprintf(f.buf, format, el)
|
|
||||||
if w > f.w {
|
|
||||||
f.w = w
|
|
||||||
}
|
|
||||||
f.buf.Reset()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func shapeToSize(shape []int64) int {
|
func shapeToSize(shape []int64) int {
|
||||||
n := 1
|
n := 1
|
||||||
|
@ -345,9 +59,467 @@ func shapeToSize(shape []int64) int {
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shapeToNumels(shape []int) int {
|
||||||
|
n := 1
|
||||||
|
for _, d := range shape {
|
||||||
|
n *= int(d)
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func shapeToStrides(shape []int) []int {
|
||||||
|
numel := shapeToNumels(shape)
|
||||||
|
var strides []int
|
||||||
|
for _, v := range shape {
|
||||||
|
numel /= int(v)
|
||||||
|
strides = append(strides, numel)
|
||||||
|
}
|
||||||
|
|
||||||
|
return strides
|
||||||
|
}
|
||||||
|
|
||||||
|
func toSliceInt(in []int64) []int {
|
||||||
|
out := make([]int, len(in))
|
||||||
|
for i := 0; i < len(in); i++ {
|
||||||
|
out[i] = int(in[i])
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func maxInt(a, b int) int {
|
func maxInt(a, b int) int {
|
||||||
if a >= b {
|
if a >= b {
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func minInt(a, b int) int {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
} else {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toSlice(input interface{}) []interface{} {
|
||||||
|
vlen := reflect.ValueOf(input).Len()
|
||||||
|
out := make([]interface{}, vlen)
|
||||||
|
for i := 0; i < vlen; i++ {
|
||||||
|
out[i] = reflect.ValueOf(input).Index(i).Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func sliceInterface(data interface{}, start, end int) []interface{} {
|
||||||
|
return toSlice(data)[start:end]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement Format interface for Tensor:
|
||||||
|
// ======================================
|
||||||
|
var (
|
||||||
|
fmtByte = []byte("%")
|
||||||
|
precByte = []byte(".")
|
||||||
|
fmtFlags = [...]rune{'+', '-', '#', ' ', '0'}
|
||||||
|
|
||||||
|
ufVec = []byte("Vector")
|
||||||
|
ufMat = []byte("Matrix")
|
||||||
|
ufTensor = []byte("Tensor")
|
||||||
|
)
|
||||||
|
|
||||||
|
// fmtState is a custom formatter for Tensor that implements fmt.State interface
|
||||||
|
type fmtState struct {
|
||||||
|
fmt.State
|
||||||
|
verb rune // format verb
|
||||||
|
flat bool // whether to print tensor in flatten format
|
||||||
|
meta bool // whether to print out meta data
|
||||||
|
ext bool // whether to print full tensor data (no truncation)
|
||||||
|
pad []byte // padding space
|
||||||
|
w int // width - total calculated space to print out tensor values
|
||||||
|
p int // precision for float dtype
|
||||||
|
base int // integer counting base for integer dtype cases
|
||||||
|
htrunc []byte // horizontal truncation symbol space
|
||||||
|
vtrunc []byte // vertical truncation symbol space
|
||||||
|
rows int // total rows
|
||||||
|
cols int // total columns
|
||||||
|
printRows int // rows to print
|
||||||
|
printCols int // columns to print
|
||||||
|
shape []int // shape of tensor to print out
|
||||||
|
buf *bytes.Buffer // memory to hold formated tensor data to print
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFmtState(s fmt.State, verb rune, shape []int) *fmtState {
|
||||||
|
w, _ := s.Width()
|
||||||
|
p, _ := s.Precision()
|
||||||
|
|
||||||
|
return &fmtState{
|
||||||
|
State: s,
|
||||||
|
verb: verb,
|
||||||
|
flat: s.Flag('-'),
|
||||||
|
meta: s.Flag('+'),
|
||||||
|
ext: s.Flag('#'),
|
||||||
|
w: w,
|
||||||
|
p: p,
|
||||||
|
htrunc: []byte("..., "),
|
||||||
|
vtrunc: []byte("...,\n"),
|
||||||
|
shape: shape,
|
||||||
|
buf: bytes.NewBuffer(make([]byte, 0)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// originalFmt returns original format.
|
||||||
|
func (f *fmtState) originalFmt() string {
|
||||||
|
// write format symbol and verbs
|
||||||
|
buf := bytes.NewBuffer(fmtByte) // '%'
|
||||||
|
for _, flag := range fmtFlags {
|
||||||
|
if f.Flag(int(flag)) {
|
||||||
|
buf.WriteRune(flag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// write width format
|
||||||
|
if w, ok := f.Width(); ok {
|
||||||
|
buf.WriteString(strconv.Itoa(w))
|
||||||
|
}
|
||||||
|
|
||||||
|
// write precision verb
|
||||||
|
if p, ok := f.Precision(); ok {
|
||||||
|
buf.Write(precByte)
|
||||||
|
buf.WriteString(strconv.Itoa(p))
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteRune(f.verb)
|
||||||
|
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanFmt returns a start of the format.
|
||||||
|
func (f *fmtState) initFmt() string {
|
||||||
|
buf := bytes.NewBuffer(fmtByte)
|
||||||
|
|
||||||
|
// write width format
|
||||||
|
if w, ok := f.Width(); ok {
|
||||||
|
buf.WriteString(strconv.Itoa(w))
|
||||||
|
}
|
||||||
|
|
||||||
|
// write precision verb
|
||||||
|
if p, ok := f.Precision(); ok {
|
||||||
|
buf.Write(precByte)
|
||||||
|
buf.WriteString(strconv.Itoa(p))
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteRune(f.verb)
|
||||||
|
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// cast casts tensor data to the formatter.
|
||||||
|
func (f *fmtState) cast(ts *Tensor) {
|
||||||
|
// rows and columns
|
||||||
|
if ts.Dim() == 1 {
|
||||||
|
f.rows = 1
|
||||||
|
f.cols = int(ts.Numel())
|
||||||
|
} else {
|
||||||
|
shape := ts.MustSize()
|
||||||
|
f.rows = int(shape[len(shape)-2])
|
||||||
|
f.cols = int(shape[len(shape)-1])
|
||||||
|
}
|
||||||
|
|
||||||
|
// printRows and printCols
|
||||||
|
switch {
|
||||||
|
case f.flat && f.ext:
|
||||||
|
f.printCols = int(ts.Numel())
|
||||||
|
case f.flat:
|
||||||
|
f.printCols = 10
|
||||||
|
case f.ext:
|
||||||
|
f.printCols = f.cols
|
||||||
|
f.printRows = f.rows
|
||||||
|
default:
|
||||||
|
f.printCols = minInt(f.cols, 6)
|
||||||
|
f.printRows = minInt(f.rows, 6)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fmtVerb formats verbs.
|
||||||
|
func (f *fmtState) fmtVerb(ts *Tensor) {
|
||||||
|
if f.verb == 'H' { // print out only header.
|
||||||
|
f.meta = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// var typ T
|
||||||
|
typ := ts.DType()
|
||||||
|
|
||||||
|
switch typ.String() {
|
||||||
|
case "float32", "float64":
|
||||||
|
switch f.verb {
|
||||||
|
case 'f', 'e', 'E', 'G', 'b':
|
||||||
|
// accepted. Do nothing
|
||||||
|
default:
|
||||||
|
f.verb = 'g'
|
||||||
|
}
|
||||||
|
|
||||||
|
case "uint8", "int8", "int16", "int32", "int64":
|
||||||
|
switch f.verb {
|
||||||
|
case 'b':
|
||||||
|
f.base = 2
|
||||||
|
case 'd':
|
||||||
|
f.base = 10
|
||||||
|
case 'o':
|
||||||
|
f.base = 8
|
||||||
|
case 'x', 'X':
|
||||||
|
f.base = 16
|
||||||
|
default:
|
||||||
|
f.base = 10
|
||||||
|
f.verb = 'd'
|
||||||
|
}
|
||||||
|
case "bool":
|
||||||
|
f.verb = 't'
|
||||||
|
default:
|
||||||
|
f.verb = 'v'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// computeWidth computes a width that can fit for every element.
|
||||||
|
func (f *fmtState) computeWidth(values interface{}) {
|
||||||
|
format := f.initFmt()
|
||||||
|
vlen := reflect.ValueOf(values).Len()
|
||||||
|
f.w = 0
|
||||||
|
for i := 0; i < vlen; i++ {
|
||||||
|
val := reflect.ValueOf(values).Index(i)
|
||||||
|
w, _ := fmt.Fprintf(f.buf, format, val)
|
||||||
|
|
||||||
|
if w > f.w {
|
||||||
|
f.w = w
|
||||||
|
}
|
||||||
|
f.buf.Reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// makePad prepares white spaces for print-out format.
|
||||||
|
func (f *fmtState) makePad() {
|
||||||
|
f.pad = make([]byte, maxInt(f.w, 2))
|
||||||
|
for i := range f.pad {
|
||||||
|
f.pad[i] = ' ' // one white space
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fmtState) writeHTrunc() {
|
||||||
|
f.Write(f.htrunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fmtState) writeVTrunc() {
|
||||||
|
f.Write(f.vtrunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format implements fmt.Formatter interface so that we can use
|
||||||
|
// fmt.Print... and verbs to print out Tensor value in different formats.
|
||||||
|
func (ts *Tensor) Format(s fmt.State, verb rune) {
|
||||||
|
shape := toSliceInt(ts.MustSize())
|
||||||
|
strides := shapeToStrides(shape)
|
||||||
|
device := ts.MustDevice()
|
||||||
|
dtype := ts.DType().String()
|
||||||
|
if verb == 'i' {
|
||||||
|
fmt.Fprintf(
|
||||||
|
s,
|
||||||
|
"\nTENSOR META:\n\tShape:\t\t%v\n\tDType:\t\t%v\n\tDevice:\t\t%v\n",
|
||||||
|
shape,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data := ts.ValueGo()
|
||||||
|
|
||||||
|
f := newFmtState(s, verb, shape)
|
||||||
|
f.computeWidth(data)
|
||||||
|
f.makePad()
|
||||||
|
f.cast(ts)
|
||||||
|
|
||||||
|
// Tensor meta data
|
||||||
|
if f.meta {
|
||||||
|
switch ts.Dim() {
|
||||||
|
case 1:
|
||||||
|
f.Write(ufVec)
|
||||||
|
case 2:
|
||||||
|
f.Write(ufMat)
|
||||||
|
default:
|
||||||
|
f.Write(ufTensor)
|
||||||
|
fmt.Fprintf(f, ": Dim=%d, ", ts.Dim())
|
||||||
|
}
|
||||||
|
fmt.Fprintf(f, "Shape=%v, Strides=%v\n", shape, strides)
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.verb == 'H' {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.flat {
|
||||||
|
// TODO.
|
||||||
|
// writeFlatTensor()
|
||||||
|
log.Printf("WARNING: f.writeFlatTensor() NotImplemedted.\n")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 0d (scalar)
|
||||||
|
if len(shape) == 0 {
|
||||||
|
fmt.Printf("%v", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1d (slice)
|
||||||
|
if len(shape) == 1 {
|
||||||
|
values := sliceInterface(data, 0, shape[0])
|
||||||
|
f.writeVector(values)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2d (matrix)
|
||||||
|
if len(shape) == 2 {
|
||||||
|
vlen := shape[0] * shape[1]
|
||||||
|
values := sliceInterface(data, 0, vlen)
|
||||||
|
f.writeMatrix(values, shape)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// >= 3d (tensor)
|
||||||
|
f.Write([]byte("\n"))
|
||||||
|
f.writeTensor(ts, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeTensor writes input tensor in specified format.
|
||||||
|
func (f *fmtState) writeTensor(ts *Tensor, values interface{}) {
|
||||||
|
shape := toSliceInt(ts.MustSize())
|
||||||
|
strides := shapeToStrides(shape)
|
||||||
|
size := shapeToNumels(shape)
|
||||||
|
mSize := shape[len(shape)-1] * shape[len(shape)-2]
|
||||||
|
var (
|
||||||
|
offset int
|
||||||
|
printOne = false
|
||||||
|
)
|
||||||
|
|
||||||
|
for i := 0; i < size; i += int(mSize) {
|
||||||
|
dims := make([]int, len(strides)-2)
|
||||||
|
for n := 0; n < len(strides[:len(strides)-2]); n++ {
|
||||||
|
stride := strides[n]
|
||||||
|
var dim int = offset
|
||||||
|
for _, s := range strides[:n] {
|
||||||
|
dim = dim % s
|
||||||
|
}
|
||||||
|
dim = dim / stride
|
||||||
|
dims[n] = dim
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
vlimit = f.printCols
|
||||||
|
shouldPrint = true
|
||||||
|
printVTrunc bool
|
||||||
|
conds []bool
|
||||||
|
)
|
||||||
|
for i, val := range dims {
|
||||||
|
maxDim := shape[i]
|
||||||
|
if (val < vlimit/2) || (val >= maxDim-vlimit/2) {
|
||||||
|
conds = append(conds, true)
|
||||||
|
shouldPrint = shouldPrint && true
|
||||||
|
} else {
|
||||||
|
conds = append(conds, false)
|
||||||
|
shouldPrint = shouldPrint && false
|
||||||
|
|
||||||
|
printVTrunc = true
|
||||||
|
}
|
||||||
|
} // inner for
|
||||||
|
if shouldPrint {
|
||||||
|
dimsLabel := "("
|
||||||
|
for _, d := range dims {
|
||||||
|
dimsLabel += fmt.Sprintf("%v, ", d)
|
||||||
|
}
|
||||||
|
dimsLabel += ".,.) =\n"
|
||||||
|
f.Write([]byte(dimsLabel))
|
||||||
|
|
||||||
|
// Print matrix [H, W]
|
||||||
|
data := sliceInterface(values, offset, offset+mSize)
|
||||||
|
f.writeMatrix(data, shape[len(shape)-2:])
|
||||||
|
|
||||||
|
printOne = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if printVTrunc && !printOne {
|
||||||
|
// NOTE. vertical truncation at > 2D level
|
||||||
|
vtrunc := fmt.Sprintf("...,\n\n")
|
||||||
|
f.Write([]byte(vtrunc))
|
||||||
|
printOne = true
|
||||||
|
}
|
||||||
|
offset += mSize
|
||||||
|
} // outer for
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fmtState) writeMatrix(data []interface{}, shape []int) {
|
||||||
|
n := shapeToNumels(shape)
|
||||||
|
dataLen := len(data)
|
||||||
|
if dataLen != n {
|
||||||
|
log.Fatalf("mismatched: slice data has %v elements - shape: %v\n", dataLen, n)
|
||||||
|
}
|
||||||
|
if len(shape) != 2 {
|
||||||
|
log.Fatal("Shape must have length of 2.\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
stride := int(shape[1])
|
||||||
|
currIdx := 0
|
||||||
|
nextIdx := stride
|
||||||
|
truncatedRows := shape[0] - f.printCols
|
||||||
|
for row := 0; row < int(shape[0]); row++ {
|
||||||
|
var slice []interface{}
|
||||||
|
switch {
|
||||||
|
case row < f.printCols/2: // First part
|
||||||
|
slice = data[currIdx:nextIdx]
|
||||||
|
f.writeVector(slice)
|
||||||
|
currIdx = nextIdx
|
||||||
|
nextIdx += stride
|
||||||
|
case row == f.printCols/2: // Truncated sign
|
||||||
|
if f.printCols != f.cols { // truncated mode
|
||||||
|
f.writeVTrunc()
|
||||||
|
} else { // full mode
|
||||||
|
// Do nothing
|
||||||
|
}
|
||||||
|
currIdx = nextIdx
|
||||||
|
nextIdx += stride
|
||||||
|
case row > f.printCols/2 && row < f.printCols/2+truncatedRows: // Skip part
|
||||||
|
currIdx = nextIdx
|
||||||
|
nextIdx += stride
|
||||||
|
case row >= f.printCols/2+truncatedRows: // Second part
|
||||||
|
slice = data[currIdx:nextIdx]
|
||||||
|
f.writeVector(slice)
|
||||||
|
currIdx = nextIdx
|
||||||
|
nextIdx += stride
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1 line between matrix
|
||||||
|
f.Write([]byte("\n"))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fmtState) writeVector(data []interface{}) {
|
||||||
|
format := f.initFmt()
|
||||||
|
vlen := len(data)
|
||||||
|
for col := 0; col < vlen; col++ {
|
||||||
|
if f.cols <= f.printCols || (col < f.printCols/2 || (col >= f.cols-f.printCols/2)) {
|
||||||
|
el := data[col]
|
||||||
|
// TODO: more format options here
|
||||||
|
w, _ := fmt.Fprintf(f.buf, format, el)
|
||||||
|
f.Write(f.buf.Bytes())
|
||||||
|
f.Write(f.pad[:f.w-w]) // prepad
|
||||||
|
f.Write(f.pad[:2]) // pad
|
||||||
|
f.buf.Reset()
|
||||||
|
} else if col == f.printCols/2 {
|
||||||
|
f.writeHTrunc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f.Write([]byte("\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print prints tensor meta data to stdout.
|
||||||
|
func (ts *Tensor) Info() {
|
||||||
|
fmt.Printf("%i", ts)
|
||||||
|
}
|
||||||
|
|
19
ts/print_test.go
Normal file
19
ts/print_test.go
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
package ts_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
"github.com/sugarme/gotch/ts"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTensor_Format(t *testing.T) {
|
||||||
|
shape := []int64{8, 8, 8}
|
||||||
|
numels := int64(8 * 8 * 8)
|
||||||
|
|
||||||
|
x := ts.MustArange(ts.IntScalar(numels), gotch.Float, gotch.CPU).MustView(shape, true)
|
||||||
|
|
||||||
|
fmt.Printf("%0.1f", x) // print truncated data
|
||||||
|
// fmt.Printf("%#0.1f", x) // print full data
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user