reworked ts.Format()
This commit is contained in:
parent
9121872ceb
commit
ce421dd4c5
|
@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
## [Unreleased]
|
||||
- Fixed incorrect indexing at `dutil/Dataset.Next()`
|
||||
- Added `nn.MSELoss()`
|
||||
- reworked `ts.Format()`
|
||||
|
||||
## [Nofix]
|
||||
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.
|
||||
|
|
744
ts/print.go
744
ts/print.go
|
@ -6,14 +6,10 @@ import (
|
|||
"log"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
)
|
||||
|
||||
var fmtByte = []byte("%")
|
||||
var precByte = []byte(".")
|
||||
|
||||
func (ts *Tensor) ValueGo() interface{} {
|
||||
dtype := ts.DType()
|
||||
numel := ts.Numel()
|
||||
|
@ -51,288 +47,6 @@ func (ts *Tensor) ValueGo() interface{} {
|
|||
|
||||
return dst
|
||||
}
|
||||
func (ts *Tensor) ToSlice() reflect.Value {
|
||||
// Create a 1-dimensional slice of the base large enough for the data and
|
||||
// copy the data in.
|
||||
shape := ts.MustSize()
|
||||
dt := ts.DType()
|
||||
n := int(numElements(shape))
|
||||
var (
|
||||
slice reflect.Value
|
||||
typ reflect.Type
|
||||
)
|
||||
if dt.String() == "String" {
|
||||
panic("Unsupported 'String' type")
|
||||
} else {
|
||||
gtyp, err := gotch.ToGoType(dt)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
typ = reflect.SliceOf(gtyp)
|
||||
slice = reflect.MakeSlice(typ, n, n)
|
||||
data := ts.ValueGo()
|
||||
slice = reflect.ValueOf(data)
|
||||
}
|
||||
// Now we have the data in place in the base slice we can add the
|
||||
// dimensions. We want to walk backwards through the shape. If the shape is
|
||||
// length 1 or 0 then we're already done.
|
||||
if len(shape) == 0 {
|
||||
return slice.Index(0)
|
||||
}
|
||||
if len(shape) == 1 {
|
||||
return slice
|
||||
}
|
||||
// We have a special case if the tensor has no data. Our backing slice is
|
||||
// empty, but we still want to create slices following the shape. In this
|
||||
// case only the final part of the shape will be 0 and we want to recalculate
|
||||
// n at this point ignoring that 0.
|
||||
// For example if our shape is 3 * 2 * 0 then n will be zero, but we still
|
||||
// want 6 zero length slices to group as follows.
|
||||
// {{} {}} {{} {}} {{} {}}
|
||||
if n == 0 {
|
||||
n = int(numElements(shape[:len(shape)-1]))
|
||||
}
|
||||
for i := len(shape) - 2; i >= 0; i-- {
|
||||
underlyingSize := typ.Elem().Size()
|
||||
typ = reflect.SliceOf(typ)
|
||||
subsliceLen := int(shape[i+1])
|
||||
if subsliceLen != 0 {
|
||||
n = n / subsliceLen
|
||||
}
|
||||
// Just using reflection it is difficult to avoid unnecessary
|
||||
// allocations while setting up the sub-slices as the Slice function on
|
||||
// a slice Value allocates. So we end up doing pointer arithmetic!
|
||||
// Pointer() on a slice gives us access to the data backing the slice.
|
||||
// We insert slice headers directly into this data.
|
||||
data := unsafe.Pointer(slice.Pointer())
|
||||
nextSlice := reflect.MakeSlice(typ, n, n)
|
||||
for j := 0; j < n; j++ {
|
||||
// This is equivalent to nSlice[j] = slice[j*subsliceLen: (j+1)*subsliceLen]
|
||||
setSliceInSlice(nextSlice, j, sliceHeader{
|
||||
Data: unsafe.Pointer(uintptr(data) + (uintptr(j*subsliceLen) * underlyingSize)),
|
||||
Len: subsliceLen,
|
||||
Cap: subsliceLen,
|
||||
})
|
||||
}
|
||||
|
||||
fmt.Printf("nextSlice length: %v\n", nextSlice.Len())
|
||||
fmt.Printf("%v\n\n", nextSlice)
|
||||
|
||||
slice = nextSlice
|
||||
}
|
||||
return slice
|
||||
}
|
||||
|
||||
// setSliceInSlice sets slice[index] = content.
|
||||
func setSliceInSlice(slice reflect.Value, index int, content sliceHeader) {
|
||||
const sliceSize = unsafe.Sizeof(sliceHeader{})
|
||||
// We must cast slice.Pointer to uninptr & back again to avoid GC issues.
|
||||
// See https://github.com/google/go-cmp/issues/167#issuecomment-546093202
|
||||
*(*sliceHeader)(unsafe.Pointer(uintptr(unsafe.Pointer(slice.Pointer())) + (uintptr(index) * sliceSize))) = content
|
||||
}
|
||||
func numElements(shape []int64) int64 {
|
||||
n := int64(1)
|
||||
for _, d := range shape {
|
||||
n *= d
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// It isn't safe to use reflect.SliceHeader as it uses a uintptr for Data and
|
||||
// this is not inspected by the garbage collector
|
||||
type sliceHeader struct {
|
||||
Data unsafe.Pointer
|
||||
Len int
|
||||
Cap int
|
||||
}
|
||||
|
||||
// Format implements fmt.Formatter interface so that we can use
|
||||
// fmt.Print... and verbs to print out Tensor value in different formats.
|
||||
func (ts *Tensor) Format(s fmt.State, c rune) {
|
||||
shape := ts.MustSize()
|
||||
device := ts.MustDevice()
|
||||
dtype := ts.DType()
|
||||
if c == 'i' {
|
||||
fmt.Fprintf(s, "\nTENSOR INFO:\n\tShape:\t\t%v\n\tDType:\t\t%v\n\tDevice:\t\t%v\n\tDefined:\t%v\n", shape, dtype, device, ts.MustDefined())
|
||||
return
|
||||
}
|
||||
|
||||
data := ts.ValueGo()
|
||||
|
||||
f := newFmtState(s, c, shape)
|
||||
f.setWidth(data)
|
||||
f.makePad()
|
||||
|
||||
// 0d (scalar)
|
||||
if len(shape) == 0 {
|
||||
fmt.Printf("%v", data)
|
||||
}
|
||||
|
||||
// 1d (slice)
|
||||
if len(shape) == 1 {
|
||||
f.writeSlice(data)
|
||||
return
|
||||
}
|
||||
|
||||
// 2d (matrix)
|
||||
if len(shape) == 2 {
|
||||
f.writeMatrix(data, shape)
|
||||
return
|
||||
}
|
||||
|
||||
// >= 3d (tensor)
|
||||
mSize := int(shape[len(shape)-2] * shape[len(shape)-1])
|
||||
mShape := shape[len(shape)-2:]
|
||||
dims := shape[:len(shape)-2]
|
||||
var rdims []int64
|
||||
for d := len(dims) - 1; d >= 0; d-- {
|
||||
rdims = append(rdims, dims[d])
|
||||
}
|
||||
|
||||
f.writeTensor(0, rdims, 0, mSize, mShape, data, "")
|
||||
}
|
||||
|
||||
// fmtState is a struct that implements fmt.State interface
|
||||
type fmtState struct {
|
||||
fmt.State
|
||||
c rune // format verb
|
||||
pad []byte // padding
|
||||
w int // width
|
||||
p int // precision
|
||||
shape []int64
|
||||
buf *bytes.Buffer
|
||||
}
|
||||
|
||||
func newFmtState(s fmt.State, c rune, shape []int64) *fmtState {
|
||||
w, _ := s.Width()
|
||||
p, _ := s.Precision()
|
||||
|
||||
return &fmtState{
|
||||
State: s,
|
||||
c: c,
|
||||
w: w,
|
||||
p: p,
|
||||
shape: shape,
|
||||
buf: bytes.NewBuffer(make([]byte, 0)),
|
||||
}
|
||||
}
|
||||
|
||||
// writeTensor iterates recursively through a reversed shape of tensor starting from axis 3 and
|
||||
// and prints out matrices (of last two dims size).
|
||||
func (f *fmtState) writeTensor(d int, dims []int64, offset int, mSize int, mShape []int64, data interface{}, mName string) {
|
||||
|
||||
offsetSize := product(dims[:d]) * mSize
|
||||
for i := 0; i < int(dims[d]); i++ {
|
||||
name := fmt.Sprintf("%v,%v", i+1, mName)
|
||||
if d >= len(dims)-1 { // last dim, let's print out
|
||||
// write matrix name
|
||||
nameStr := fmt.Sprintf("(%v.,.) =\n", name)
|
||||
f.Write([]byte(nameStr))
|
||||
// write matrix values
|
||||
slice := reflect.ValueOf(data).Slice(offset, offset+mSize).Interface()
|
||||
f.writeMatrix(slice, mShape)
|
||||
} else { // recursively loop
|
||||
f.writeTensor(d+1, dims, offset, mSize, mShape, data, name)
|
||||
}
|
||||
|
||||
// update offset
|
||||
offset += offsetSize
|
||||
}
|
||||
}
|
||||
|
||||
func product(dims []int64) int {
|
||||
var p int = 1
|
||||
if len(dims) == 0 {
|
||||
return p
|
||||
}
|
||||
|
||||
for _, d := range dims {
|
||||
p = p * int(d)
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func (f *fmtState) writeMatrix(data interface{}, shape []int64) {
|
||||
n := shapeToSize(shape)
|
||||
dataLen := reflect.ValueOf(data).Len()
|
||||
if dataLen != n {
|
||||
log.Fatalf("mismatched: slice data has %v elements - shape: %v\n", dataLen, n)
|
||||
}
|
||||
if len(shape) != 2 {
|
||||
log.Fatal("Shape must have length of 2.\n")
|
||||
}
|
||||
|
||||
stride := int(shape[1])
|
||||
currIdx := 0
|
||||
nextIdx := stride
|
||||
for i := 0; i < int(shape[0]); i++ {
|
||||
slice := reflect.ValueOf(data).Slice(currIdx, nextIdx)
|
||||
f.writeSlice(slice.Interface())
|
||||
currIdx = nextIdx
|
||||
nextIdx += stride
|
||||
}
|
||||
|
||||
// 1 line between matrix
|
||||
f.Write([]byte("\n"))
|
||||
|
||||
}
|
||||
|
||||
func (f *fmtState) writeSlice(data interface{}) {
|
||||
format := f.cleanFmt()
|
||||
dataLen := reflect.ValueOf(data).Len()
|
||||
for i := 0; i < dataLen; i++ {
|
||||
el := reflect.ValueOf(data).Index(i).Interface()
|
||||
|
||||
// TODO: more format options here
|
||||
w, _ := fmt.Fprintf(f.buf, format, el)
|
||||
f.Write(f.buf.Bytes())
|
||||
f.Write(f.pad[:f.w-w]) // prepad
|
||||
f.Write(f.pad[:2]) // pad
|
||||
f.buf.Reset()
|
||||
}
|
||||
|
||||
f.Write([]byte("\n"))
|
||||
}
|
||||
|
||||
func (f *fmtState) cleanFmt() string {
|
||||
buf := bytes.NewBuffer(fmtByte)
|
||||
|
||||
// width
|
||||
if w, ok := f.Width(); ok {
|
||||
buf.WriteString(strconv.Itoa(w))
|
||||
}
|
||||
|
||||
// precision
|
||||
if p, ok := f.Precision(); ok {
|
||||
buf.Write(precByte)
|
||||
buf.WriteString(strconv.Itoa(p))
|
||||
}
|
||||
|
||||
buf.WriteRune(f.c)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (f *fmtState) makePad() {
|
||||
f.pad = make([]byte, maxInt(f.w, 4))
|
||||
for i := range f.pad {
|
||||
f.pad[i] = ' '
|
||||
}
|
||||
}
|
||||
|
||||
// setWidth determines maximal width from input data and set to `w` field
|
||||
func (f *fmtState) setWidth(data interface{}) {
|
||||
format := f.cleanFmt()
|
||||
f.w = 0
|
||||
for i := 0; i < reflect.ValueOf(data).Len(); i++ {
|
||||
el := reflect.ValueOf(data).Index(i).Interface()
|
||||
w, _ := fmt.Fprintf(f.buf, format, el)
|
||||
if w > f.w {
|
||||
f.w = w
|
||||
}
|
||||
f.buf.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func shapeToSize(shape []int64) int {
|
||||
n := 1
|
||||
|
@ -345,9 +59,467 @@ func shapeToSize(shape []int64) int {
|
|||
return n
|
||||
}
|
||||
|
||||
func shapeToNumels(shape []int) int {
|
||||
n := 1
|
||||
for _, d := range shape {
|
||||
n *= int(d)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func shapeToStrides(shape []int) []int {
|
||||
numel := shapeToNumels(shape)
|
||||
var strides []int
|
||||
for _, v := range shape {
|
||||
numel /= int(v)
|
||||
strides = append(strides, numel)
|
||||
}
|
||||
|
||||
return strides
|
||||
}
|
||||
|
||||
func toSliceInt(in []int64) []int {
|
||||
out := make([]int, len(in))
|
||||
for i := 0; i < len(in); i++ {
|
||||
out[i] = int(in[i])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func maxInt(a, b int) int {
|
||||
if a >= b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func minInt(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
func toSlice(input interface{}) []interface{} {
|
||||
vlen := reflect.ValueOf(input).Len()
|
||||
out := make([]interface{}, vlen)
|
||||
for i := 0; i < vlen; i++ {
|
||||
out[i] = reflect.ValueOf(input).Index(i).Interface()
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func sliceInterface(data interface{}, start, end int) []interface{} {
|
||||
return toSlice(data)[start:end]
|
||||
}
|
||||
|
||||
// Implement Format interface for Tensor:
|
||||
// ======================================
|
||||
var (
|
||||
fmtByte = []byte("%")
|
||||
precByte = []byte(".")
|
||||
fmtFlags = [...]rune{'+', '-', '#', ' ', '0'}
|
||||
|
||||
ufVec = []byte("Vector")
|
||||
ufMat = []byte("Matrix")
|
||||
ufTensor = []byte("Tensor")
|
||||
)
|
||||
|
||||
// fmtState is a custom formatter for Tensor that implements fmt.State interface
|
||||
type fmtState struct {
|
||||
fmt.State
|
||||
verb rune // format verb
|
||||
flat bool // whether to print tensor in flatten format
|
||||
meta bool // whether to print out meta data
|
||||
ext bool // whether to print full tensor data (no truncation)
|
||||
pad []byte // padding space
|
||||
w int // width - total calculated space to print out tensor values
|
||||
p int // precision for float dtype
|
||||
base int // integer counting base for integer dtype cases
|
||||
htrunc []byte // horizontal truncation symbol space
|
||||
vtrunc []byte // vertical truncation symbol space
|
||||
rows int // total rows
|
||||
cols int // total columns
|
||||
printRows int // rows to print
|
||||
printCols int // columns to print
|
||||
shape []int // shape of tensor to print out
|
||||
buf *bytes.Buffer // memory to hold formated tensor data to print
|
||||
}
|
||||
|
||||
func newFmtState(s fmt.State, verb rune, shape []int) *fmtState {
|
||||
w, _ := s.Width()
|
||||
p, _ := s.Precision()
|
||||
|
||||
return &fmtState{
|
||||
State: s,
|
||||
verb: verb,
|
||||
flat: s.Flag('-'),
|
||||
meta: s.Flag('+'),
|
||||
ext: s.Flag('#'),
|
||||
w: w,
|
||||
p: p,
|
||||
htrunc: []byte("..., "),
|
||||
vtrunc: []byte("...,\n"),
|
||||
shape: shape,
|
||||
buf: bytes.NewBuffer(make([]byte, 0)),
|
||||
}
|
||||
}
|
||||
|
||||
// originalFmt returns original format.
|
||||
func (f *fmtState) originalFmt() string {
|
||||
// write format symbol and verbs
|
||||
buf := bytes.NewBuffer(fmtByte) // '%'
|
||||
for _, flag := range fmtFlags {
|
||||
if f.Flag(int(flag)) {
|
||||
buf.WriteRune(flag)
|
||||
}
|
||||
}
|
||||
|
||||
// write width format
|
||||
if w, ok := f.Width(); ok {
|
||||
buf.WriteString(strconv.Itoa(w))
|
||||
}
|
||||
|
||||
// write precision verb
|
||||
if p, ok := f.Precision(); ok {
|
||||
buf.Write(precByte)
|
||||
buf.WriteString(strconv.Itoa(p))
|
||||
}
|
||||
|
||||
buf.WriteRune(f.verb)
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// cleanFmt returns a start of the format.
|
||||
func (f *fmtState) initFmt() string {
|
||||
buf := bytes.NewBuffer(fmtByte)
|
||||
|
||||
// write width format
|
||||
if w, ok := f.Width(); ok {
|
||||
buf.WriteString(strconv.Itoa(w))
|
||||
}
|
||||
|
||||
// write precision verb
|
||||
if p, ok := f.Precision(); ok {
|
||||
buf.Write(precByte)
|
||||
buf.WriteString(strconv.Itoa(p))
|
||||
}
|
||||
|
||||
buf.WriteRune(f.verb)
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// cast casts tensor data to the formatter.
|
||||
func (f *fmtState) cast(ts *Tensor) {
|
||||
// rows and columns
|
||||
if ts.Dim() == 1 {
|
||||
f.rows = 1
|
||||
f.cols = int(ts.Numel())
|
||||
} else {
|
||||
shape := ts.MustSize()
|
||||
f.rows = int(shape[len(shape)-2])
|
||||
f.cols = int(shape[len(shape)-1])
|
||||
}
|
||||
|
||||
// printRows and printCols
|
||||
switch {
|
||||
case f.flat && f.ext:
|
||||
f.printCols = int(ts.Numel())
|
||||
case f.flat:
|
||||
f.printCols = 10
|
||||
case f.ext:
|
||||
f.printCols = f.cols
|
||||
f.printRows = f.rows
|
||||
default:
|
||||
f.printCols = minInt(f.cols, 6)
|
||||
f.printRows = minInt(f.rows, 6)
|
||||
}
|
||||
}
|
||||
|
||||
// fmtVerb formats verbs.
|
||||
func (f *fmtState) fmtVerb(ts *Tensor) {
|
||||
if f.verb == 'H' { // print out only header.
|
||||
f.meta = true
|
||||
return
|
||||
}
|
||||
|
||||
// var typ T
|
||||
typ := ts.DType()
|
||||
|
||||
switch typ.String() {
|
||||
case "float32", "float64":
|
||||
switch f.verb {
|
||||
case 'f', 'e', 'E', 'G', 'b':
|
||||
// accepted. Do nothing
|
||||
default:
|
||||
f.verb = 'g'
|
||||
}
|
||||
|
||||
case "uint8", "int8", "int16", "int32", "int64":
|
||||
switch f.verb {
|
||||
case 'b':
|
||||
f.base = 2
|
||||
case 'd':
|
||||
f.base = 10
|
||||
case 'o':
|
||||
f.base = 8
|
||||
case 'x', 'X':
|
||||
f.base = 16
|
||||
default:
|
||||
f.base = 10
|
||||
f.verb = 'd'
|
||||
}
|
||||
case "bool":
|
||||
f.verb = 't'
|
||||
default:
|
||||
f.verb = 'v'
|
||||
}
|
||||
}
|
||||
|
||||
// computeWidth computes a width that can fit for every element.
|
||||
func (f *fmtState) computeWidth(values interface{}) {
|
||||
format := f.initFmt()
|
||||
vlen := reflect.ValueOf(values).Len()
|
||||
f.w = 0
|
||||
for i := 0; i < vlen; i++ {
|
||||
val := reflect.ValueOf(values).Index(i)
|
||||
w, _ := fmt.Fprintf(f.buf, format, val)
|
||||
|
||||
if w > f.w {
|
||||
f.w = w
|
||||
}
|
||||
f.buf.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
// makePad prepares white spaces for print-out format.
|
||||
func (f *fmtState) makePad() {
|
||||
f.pad = make([]byte, maxInt(f.w, 2))
|
||||
for i := range f.pad {
|
||||
f.pad[i] = ' ' // one white space
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fmtState) writeHTrunc() {
|
||||
f.Write(f.htrunc)
|
||||
}
|
||||
|
||||
func (f *fmtState) writeVTrunc() {
|
||||
f.Write(f.vtrunc)
|
||||
}
|
||||
|
||||
// Format implements fmt.Formatter interface so that we can use
|
||||
// fmt.Print... and verbs to print out Tensor value in different formats.
|
||||
func (ts *Tensor) Format(s fmt.State, verb rune) {
|
||||
shape := toSliceInt(ts.MustSize())
|
||||
strides := shapeToStrides(shape)
|
||||
device := ts.MustDevice()
|
||||
dtype := ts.DType().String()
|
||||
if verb == 'i' {
|
||||
fmt.Fprintf(
|
||||
s,
|
||||
"\nTENSOR META:\n\tShape:\t\t%v\n\tDType:\t\t%v\n\tDevice:\t\t%v\n",
|
||||
shape,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
data := ts.ValueGo()
|
||||
|
||||
f := newFmtState(s, verb, shape)
|
||||
f.computeWidth(data)
|
||||
f.makePad()
|
||||
f.cast(ts)
|
||||
|
||||
// Tensor meta data
|
||||
if f.meta {
|
||||
switch ts.Dim() {
|
||||
case 1:
|
||||
f.Write(ufVec)
|
||||
case 2:
|
||||
f.Write(ufMat)
|
||||
default:
|
||||
f.Write(ufTensor)
|
||||
fmt.Fprintf(f, ": Dim=%d, ", ts.Dim())
|
||||
}
|
||||
fmt.Fprintf(f, "Shape=%v, Strides=%v\n", shape, strides)
|
||||
}
|
||||
|
||||
if f.verb == 'H' {
|
||||
return
|
||||
}
|
||||
|
||||
if f.flat {
|
||||
// TODO.
|
||||
// writeFlatTensor()
|
||||
log.Printf("WARNING: f.writeFlatTensor() NotImplemedted.\n")
|
||||
return
|
||||
}
|
||||
|
||||
// 0d (scalar)
|
||||
if len(shape) == 0 {
|
||||
fmt.Printf("%v", data)
|
||||
}
|
||||
|
||||
// 1d (slice)
|
||||
if len(shape) == 1 {
|
||||
values := sliceInterface(data, 0, shape[0])
|
||||
f.writeVector(values)
|
||||
return
|
||||
}
|
||||
|
||||
// 2d (matrix)
|
||||
if len(shape) == 2 {
|
||||
vlen := shape[0] * shape[1]
|
||||
values := sliceInterface(data, 0, vlen)
|
||||
f.writeMatrix(values, shape)
|
||||
return
|
||||
}
|
||||
|
||||
// >= 3d (tensor)
|
||||
f.Write([]byte("\n"))
|
||||
f.writeTensor(ts, data)
|
||||
}
|
||||
|
||||
// writeTensor writes input tensor in specified format.
|
||||
func (f *fmtState) writeTensor(ts *Tensor, values interface{}) {
|
||||
shape := toSliceInt(ts.MustSize())
|
||||
strides := shapeToStrides(shape)
|
||||
size := shapeToNumels(shape)
|
||||
mSize := shape[len(shape)-1] * shape[len(shape)-2]
|
||||
var (
|
||||
offset int
|
||||
printOne = false
|
||||
)
|
||||
|
||||
for i := 0; i < size; i += int(mSize) {
|
||||
dims := make([]int, len(strides)-2)
|
||||
for n := 0; n < len(strides[:len(strides)-2]); n++ {
|
||||
stride := strides[n]
|
||||
var dim int = offset
|
||||
for _, s := range strides[:n] {
|
||||
dim = dim % s
|
||||
}
|
||||
dim = dim / stride
|
||||
dims[n] = dim
|
||||
}
|
||||
|
||||
var (
|
||||
vlimit = f.printCols
|
||||
shouldPrint = true
|
||||
printVTrunc bool
|
||||
conds []bool
|
||||
)
|
||||
for i, val := range dims {
|
||||
maxDim := shape[i]
|
||||
if (val < vlimit/2) || (val >= maxDim-vlimit/2) {
|
||||
conds = append(conds, true)
|
||||
shouldPrint = shouldPrint && true
|
||||
} else {
|
||||
conds = append(conds, false)
|
||||
shouldPrint = shouldPrint && false
|
||||
|
||||
printVTrunc = true
|
||||
}
|
||||
} // inner for
|
||||
if shouldPrint {
|
||||
dimsLabel := "("
|
||||
for _, d := range dims {
|
||||
dimsLabel += fmt.Sprintf("%v, ", d)
|
||||
}
|
||||
dimsLabel += ".,.) =\n"
|
||||
f.Write([]byte(dimsLabel))
|
||||
|
||||
// Print matrix [H, W]
|
||||
data := sliceInterface(values, offset, offset+mSize)
|
||||
f.writeMatrix(data, shape[len(shape)-2:])
|
||||
|
||||
printOne = false
|
||||
}
|
||||
|
||||
if printVTrunc && !printOne {
|
||||
// NOTE. vertical truncation at > 2D level
|
||||
vtrunc := fmt.Sprintf("...,\n\n")
|
||||
f.Write([]byte(vtrunc))
|
||||
printOne = true
|
||||
}
|
||||
offset += mSize
|
||||
} // outer for
|
||||
}
|
||||
|
||||
func (f *fmtState) writeMatrix(data []interface{}, shape []int) {
|
||||
n := shapeToNumels(shape)
|
||||
dataLen := len(data)
|
||||
if dataLen != n {
|
||||
log.Fatalf("mismatched: slice data has %v elements - shape: %v\n", dataLen, n)
|
||||
}
|
||||
if len(shape) != 2 {
|
||||
log.Fatal("Shape must have length of 2.\n")
|
||||
}
|
||||
|
||||
stride := int(shape[1])
|
||||
currIdx := 0
|
||||
nextIdx := stride
|
||||
truncatedRows := shape[0] - f.printCols
|
||||
for row := 0; row < int(shape[0]); row++ {
|
||||
var slice []interface{}
|
||||
switch {
|
||||
case row < f.printCols/2: // First part
|
||||
slice = data[currIdx:nextIdx]
|
||||
f.writeVector(slice)
|
||||
currIdx = nextIdx
|
||||
nextIdx += stride
|
||||
case row == f.printCols/2: // Truncated sign
|
||||
if f.printCols != f.cols { // truncated mode
|
||||
f.writeVTrunc()
|
||||
} else { // full mode
|
||||
// Do nothing
|
||||
}
|
||||
currIdx = nextIdx
|
||||
nextIdx += stride
|
||||
case row > f.printCols/2 && row < f.printCols/2+truncatedRows: // Skip part
|
||||
currIdx = nextIdx
|
||||
nextIdx += stride
|
||||
case row >= f.printCols/2+truncatedRows: // Second part
|
||||
slice = data[currIdx:nextIdx]
|
||||
f.writeVector(slice)
|
||||
currIdx = nextIdx
|
||||
nextIdx += stride
|
||||
}
|
||||
}
|
||||
|
||||
// 1 line between matrix
|
||||
f.Write([]byte("\n"))
|
||||
|
||||
}
|
||||
|
||||
func (f *fmtState) writeVector(data []interface{}) {
|
||||
format := f.initFmt()
|
||||
vlen := len(data)
|
||||
for col := 0; col < vlen; col++ {
|
||||
if f.cols <= f.printCols || (col < f.printCols/2 || (col >= f.cols-f.printCols/2)) {
|
||||
el := data[col]
|
||||
// TODO: more format options here
|
||||
w, _ := fmt.Fprintf(f.buf, format, el)
|
||||
f.Write(f.buf.Bytes())
|
||||
f.Write(f.pad[:f.w-w]) // prepad
|
||||
f.Write(f.pad[:2]) // pad
|
||||
f.buf.Reset()
|
||||
} else if col == f.printCols/2 {
|
||||
f.writeHTrunc()
|
||||
}
|
||||
}
|
||||
|
||||
f.Write([]byte("\n"))
|
||||
}
|
||||
|
||||
// Print prints tensor meta data to stdout.
|
||||
func (ts *Tensor) Info() {
|
||||
fmt.Printf("%i", ts)
|
||||
}
|
||||
|
|
19
ts/print_test.go
Normal file
19
ts/print_test.go
Normal file
|
@ -0,0 +1,19 @@
|
|||
package ts_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
func TestTensor_Format(t *testing.T) {
|
||||
shape := []int64{8, 8, 8}
|
||||
numels := int64(8 * 8 * 8)
|
||||
|
||||
x := ts.MustArange(ts.IntScalar(numels), gotch.Float, gotch.CPU).MustView(shape, true)
|
||||
|
||||
fmt.Printf("%0.1f", x) // print truncated data
|
||||
// fmt.Printf("%#0.1f", x) // print full data
|
||||
}
|
Loading…
Reference in New Issue
Block a user