reworked ts.Format()

sugarme 2022-04-28 17:31:23 +10:00
parent 9121872ceb
commit ce421dd4c5
3 changed files with 478 additions and 286 deletions


@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
- Fixed incorrect indexing at `dutil/Dataset.Next()`
- Added `nn.MSELoss()`
- Reworked `ts.Format()`
## [Nofix]
- ctype `long` caused a compile error on macOS, as noted in [#44]. Not working on a Linux box.


@@ -6,14 +6,10 @@ import (
"log"
"reflect"
"strconv"
"unsafe"
"github.com/sugarme/gotch"
)
var fmtByte = []byte("%")
var precByte = []byte(".")
func (ts *Tensor) ValueGo() interface{} {
dtype := ts.DType()
numel := ts.Numel()
@@ -51,288 +47,6 @@ func (ts *Tensor) ValueGo() interface{} {
return dst
}
func (ts *Tensor) ToSlice() reflect.Value {
// Create a 1-dimensional slice of the base large enough for the data and
// copy the data in.
shape := ts.MustSize()
dt := ts.DType()
n := int(numElements(shape))
var (
slice reflect.Value
typ reflect.Type
)
if dt.String() == "String" {
panic("Unsupported 'String' type")
} else {
gtyp, err := gotch.ToGoType(dt)
if err != nil {
log.Fatal(err)
}
typ = reflect.SliceOf(gtyp)
slice = reflect.MakeSlice(typ, n, n)
data := ts.ValueGo()
slice = reflect.ValueOf(data)
}
// Now we have the data in place in the base slice we can add the
// dimensions. We want to walk backwards through the shape. If the shape is
// length 1 or 0 then we're already done.
if len(shape) == 0 {
return slice.Index(0)
}
if len(shape) == 1 {
return slice
}
// We have a special case if the tensor has no data. Our backing slice is
// empty, but we still want to create slices following the shape. In this
// case only the final part of the shape will be 0 and we want to recalculate
// n at this point ignoring that 0.
// For example if our shape is 3 * 2 * 0 then n will be zero, but we still
// want 6 zero length slices to group as follows.
// {{} {}} {{} {}} {{} {}}
if n == 0 {
n = int(numElements(shape[:len(shape)-1]))
}
for i := len(shape) - 2; i >= 0; i-- {
underlyingSize := typ.Elem().Size()
typ = reflect.SliceOf(typ)
subsliceLen := int(shape[i+1])
if subsliceLen != 0 {
n = n / subsliceLen
}
// Just using reflection it is difficult to avoid unnecessary
// allocations while setting up the sub-slices as the Slice function on
// a slice Value allocates. So we end up doing pointer arithmetic!
// Pointer() on a slice gives us access to the data backing the slice.
// We insert slice headers directly into this data.
data := unsafe.Pointer(slice.Pointer())
nextSlice := reflect.MakeSlice(typ, n, n)
for j := 0; j < n; j++ {
// This is equivalent to nSlice[j] = slice[j*subsliceLen: (j+1)*subsliceLen]
setSliceInSlice(nextSlice, j, sliceHeader{
Data: unsafe.Pointer(uintptr(data) + (uintptr(j*subsliceLen) * underlyingSize)),
Len: subsliceLen,
Cap: subsliceLen,
})
}
fmt.Printf("nextSlice length: %v\n", nextSlice.Len())
fmt.Printf("%v\n\n", nextSlice)
slice = nextSlice
}
return slice
}
// setSliceInSlice sets slice[index] = content.
func setSliceInSlice(slice reflect.Value, index int, content sliceHeader) {
const sliceSize = unsafe.Sizeof(sliceHeader{})
// We must cast slice.Pointer to uintptr & back again to avoid GC issues.
// See https://github.com/google/go-cmp/issues/167#issuecomment-546093202
*(*sliceHeader)(unsafe.Pointer(uintptr(unsafe.Pointer(slice.Pointer())) + (uintptr(index) * sliceSize))) = content
}
func numElements(shape []int64) int64 {
n := int64(1)
for _, d := range shape {
n *= d
}
return n
}
// It isn't safe to use reflect.SliceHeader as it uses a uintptr for Data and
// this is not inspected by the garbage collector
type sliceHeader struct {
Data unsafe.Pointer
Len int
Cap int
}
// Format implements fmt.Formatter interface so that we can use
// fmt.Print... and verbs to print out Tensor value in different formats.
func (ts *Tensor) Format(s fmt.State, c rune) {
shape := ts.MustSize()
device := ts.MustDevice()
dtype := ts.DType()
if c == 'i' {
fmt.Fprintf(s, "\nTENSOR INFO:\n\tShape:\t\t%v\n\tDType:\t\t%v\n\tDevice:\t\t%v\n\tDefined:\t%v\n", shape, dtype, device, ts.MustDefined())
return
}
data := ts.ValueGo()
f := newFmtState(s, c, shape)
f.setWidth(data)
f.makePad()
// 0d (scalar)
if len(shape) == 0 {
fmt.Printf("%v", data)
}
// 1d (slice)
if len(shape) == 1 {
f.writeSlice(data)
return
}
// 2d (matrix)
if len(shape) == 2 {
f.writeMatrix(data, shape)
return
}
// >= 3d (tensor)
mSize := int(shape[len(shape)-2] * shape[len(shape)-1])
mShape := shape[len(shape)-2:]
dims := shape[:len(shape)-2]
var rdims []int64
for d := len(dims) - 1; d >= 0; d-- {
rdims = append(rdims, dims[d])
}
f.writeTensor(0, rdims, 0, mSize, mShape, data, "")
}
// fmtState is a struct that implements fmt.State interface
type fmtState struct {
fmt.State
c rune // format verb
pad []byte // padding
w int // width
p int // precision
shape []int64
buf *bytes.Buffer
}
func newFmtState(s fmt.State, c rune, shape []int64) *fmtState {
w, _ := s.Width()
p, _ := s.Precision()
return &fmtState{
State: s,
c: c,
w: w,
p: p,
shape: shape,
buf: bytes.NewBuffer(make([]byte, 0)),
}
}
// writeTensor iterates recursively through the reversed shape of a tensor starting from axis 3
// and prints out matrices (of the size of the last two dims).
func (f *fmtState) writeTensor(d int, dims []int64, offset int, mSize int, mShape []int64, data interface{}, mName string) {
offsetSize := product(dims[:d]) * mSize
for i := 0; i < int(dims[d]); i++ {
name := fmt.Sprintf("%v,%v", i+1, mName)
if d >= len(dims)-1 { // last dim, let's print out
// write matrix name
nameStr := fmt.Sprintf("(%v.,.) =\n", name)
f.Write([]byte(nameStr))
// write matrix values
slice := reflect.ValueOf(data).Slice(offset, offset+mSize).Interface()
f.writeMatrix(slice, mShape)
} else { // recursively loop
f.writeTensor(d+1, dims, offset, mSize, mShape, data, name)
}
// update offset
offset += offsetSize
}
}
func product(dims []int64) int {
var p int = 1
if len(dims) == 0 {
return p
}
for _, d := range dims {
p = p * int(d)
}
return p
}
func (f *fmtState) writeMatrix(data interface{}, shape []int64) {
n := shapeToSize(shape)
dataLen := reflect.ValueOf(data).Len()
if dataLen != n {
log.Fatalf("mismatched: slice data has %v elements - shape: %v\n", dataLen, n)
}
if len(shape) != 2 {
log.Fatal("Shape must have length of 2.\n")
}
stride := int(shape[1])
currIdx := 0
nextIdx := stride
for i := 0; i < int(shape[0]); i++ {
slice := reflect.ValueOf(data).Slice(currIdx, nextIdx)
f.writeSlice(slice.Interface())
currIdx = nextIdx
nextIdx += stride
}
// 1 line between matrix
f.Write([]byte("\n"))
}
func (f *fmtState) writeSlice(data interface{}) {
format := f.cleanFmt()
dataLen := reflect.ValueOf(data).Len()
for i := 0; i < dataLen; i++ {
el := reflect.ValueOf(data).Index(i).Interface()
// TODO: more format options here
w, _ := fmt.Fprintf(f.buf, format, el)
f.Write(f.buf.Bytes())
f.Write(f.pad[:f.w-w]) // prepad
f.Write(f.pad[:2]) // pad
f.buf.Reset()
}
f.Write([]byte("\n"))
}
func (f *fmtState) cleanFmt() string {
buf := bytes.NewBuffer(fmtByte)
// width
if w, ok := f.Width(); ok {
buf.WriteString(strconv.Itoa(w))
}
// precision
if p, ok := f.Precision(); ok {
buf.Write(precByte)
buf.WriteString(strconv.Itoa(p))
}
buf.WriteRune(f.c)
return buf.String()
}
func (f *fmtState) makePad() {
f.pad = make([]byte, maxInt(f.w, 4))
for i := range f.pad {
f.pad[i] = ' '
}
}
// setWidth determines the maximal width from input data and sets it to the `w` field
func (f *fmtState) setWidth(data interface{}) {
format := f.cleanFmt()
f.w = 0
for i := 0; i < reflect.ValueOf(data).Len(); i++ {
el := reflect.ValueOf(data).Index(i).Interface()
w, _ := fmt.Fprintf(f.buf, format, el)
if w > f.w {
f.w = w
}
f.buf.Reset()
}
}
func shapeToSize(shape []int64) int {
n := 1
@@ -345,9 +59,467 @@ func shapeToSize(shape []int64) int {
return n
}
func shapeToNumels(shape []int) int {
n := 1
for _, d := range shape {
n *= int(d)
}
return n
}
func shapeToStrides(shape []int) []int {
numel := shapeToNumels(shape)
var strides []int
for _, v := range shape {
numel /= int(v)
strides = append(strides, numel)
}
return strides
}
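// For example (row-major strides, as computed above):
// shapeToStrides([]int{2, 3, 4}) // -> [12, 4, 1]
// shapeToStrides([]int{8, 8, 8}) // -> [64, 8, 1]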
func toSliceInt(in []int64) []int {
out := make([]int, len(in))
for i := 0; i < len(in); i++ {
out[i] = int(in[i])
}
return out
}
func maxInt(a, b int) int {
if a >= b {
return a
}
return b
}
func minInt(a, b int) int {
if a < b {
return a
} else {
return b
}
}
func toSlice(input interface{}) []interface{} {
vlen := reflect.ValueOf(input).Len()
out := make([]interface{}, vlen)
for i := 0; i < vlen; i++ {
out[i] = reflect.ValueOf(input).Index(i).Interface()
}
return out
}
func sliceInterface(data interface{}, start, end int) []interface{} {
return toSlice(data)[start:end]
}
// Implement Format interface for Tensor:
// ======================================
var (
fmtByte = []byte("%")
precByte = []byte(".")
fmtFlags = [...]rune{'+', '-', '#', ' ', '0'}
ufVec = []byte("Vector")
ufMat = []byte("Matrix")
ufTensor = []byte("Tensor")
)
// fmtState is a custom formatter for Tensor that implements fmt.State interface
type fmtState struct {
fmt.State
verb rune // format verb
flat bool // whether to print tensor in flatten format
meta bool // whether to print out meta data
ext bool // whether to print full tensor data (no truncation)
pad []byte // padding space
w int // width - total calculated space to print out tensor values
p int // precision for float dtype
base int // integer counting base for integer dtype cases
htrunc []byte // horizontal truncation symbol space
vtrunc []byte // vertical truncation symbol space
rows int // total rows
cols int // total columns
printRows int // rows to print
printCols int // columns to print
shape []int // shape of tensor to print out
buf *bytes.Buffer // memory to hold formatted tensor data to print
}
func newFmtState(s fmt.State, verb rune, shape []int) *fmtState {
w, _ := s.Width()
p, _ := s.Precision()
return &fmtState{
State: s,
verb: verb,
flat: s.Flag('-'),
meta: s.Flag('+'),
ext: s.Flag('#'),
w: w,
p: p,
htrunc: []byte("..., "),
vtrunc: []byte("...,\n"),
shape: shape,
buf: bytes.NewBuffer(make([]byte, 0)),
}
}
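// As wired above, the fmt flags choose the output mode. A quick sketch of the intended usage:
// fmt.Printf("%.3f", x)  // default: truncated rows/columns only
// fmt.Printf("%+.3f", x) // '+' -> meta: prints the Vector/Matrix/Tensor header with Shape and Strides
// fmt.Printf("%#.3f", x) // '#' -> ext: prints the full tensor without truncation
// fmt.Printf("%-.3f", x) // '-' -> flat: flattened output (still a TODO in Format below)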
// originalFmt reconstructs the original format string (flags, width, precision, and verb).
func (f *fmtState) originalFmt() string {
// write format symbol and verbs
buf := bytes.NewBuffer(fmtByte) // '%'
for _, flag := range fmtFlags {
if f.Flag(int(flag)) {
buf.WriteRune(flag)
}
}
// write width format
if w, ok := f.Width(); ok {
buf.WriteString(strconv.Itoa(w))
}
// write precision verb
if p, ok := f.Precision(); ok {
buf.Write(precByte)
buf.WriteString(strconv.Itoa(p))
}
buf.WriteRune(f.verb)
return buf.String()
}
// initFmt returns the base format string (width, precision, and verb) without flags.
func (f *fmtState) initFmt() string {
buf := bytes.NewBuffer(fmtByte)
// write width format
if w, ok := f.Width(); ok {
buf.WriteString(strconv.Itoa(w))
}
// write precision verb
if p, ok := f.Precision(); ok {
buf.Write(precByte)
buf.WriteString(strconv.Itoa(p))
}
buf.WriteRune(f.verb)
return buf.String()
}
// cast derives the tensor's row/column counts and how many rows/columns to print.
func (f *fmtState) cast(ts *Tensor) {
// rows and columns
if ts.Dim() == 1 {
f.rows = 1
f.cols = int(ts.Numel())
} else {
shape := ts.MustSize()
f.rows = int(shape[len(shape)-2])
f.cols = int(shape[len(shape)-1])
}
// printRows and printCols
switch {
case f.flat && f.ext:
f.printCols = int(ts.Numel())
case f.flat:
f.printCols = 10
case f.ext:
f.printCols = f.cols
f.printRows = f.rows
default:
f.printCols = minInt(f.cols, 6)
f.printRows = minInt(f.rows, 6)
}
}
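// With these defaults a 100x100 matrix prints at most 6 rows x 6 columns; the '#' flag
// prints all 100x100 values, and flat ('-') mode is capped at 10 values.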
// fmtVerb validates the verb against the tensor's dtype, setting the integer base where relevant and falling back to a default verb otherwise.
func (f *fmtState) fmtVerb(ts *Tensor) {
if f.verb == 'H' { // print out only header.
f.meta = true
return
}
// var typ T
typ := ts.DType()
switch typ.String() {
case "float32", "float64":
switch f.verb {
case 'f', 'e', 'E', 'G', 'b':
// accepted. Do nothing
default:
f.verb = 'g'
}
case "uint8", "int8", "int16", "int32", "int64":
switch f.verb {
case 'b':
f.base = 2
case 'd':
f.base = 10
case 'o':
f.base = 8
case 'x', 'X':
f.base = 16
default:
f.base = 10
f.verb = 'd'
}
case "bool":
f.verb = 't'
default:
f.verb = 'v'
}
}
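// For example: '%x' on an int64 tensor keeps verb 'x' with base 16, '%s' on a float32
// tensor falls back to verb 'g', and any verb on a bool tensor becomes 't'.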
// computeWidth computes a width wide enough to fit every element.
func (f *fmtState) computeWidth(values interface{}) {
format := f.initFmt()
vlen := reflect.ValueOf(values).Len()
f.w = 0
for i := 0; i < vlen; i++ {
val := reflect.ValueOf(values).Index(i)
w, _ := fmt.Fprintf(f.buf, format, val)
if w > f.w {
f.w = w
}
f.buf.Reset()
}
}
// makePad prepares white spaces for print-out format.
func (f *fmtState) makePad() {
f.pad = make([]byte, maxInt(f.w, 2))
for i := range f.pad {
f.pad[i] = ' ' // one white space
}
}
func (f *fmtState) writeHTrunc() {
f.Write(f.htrunc)
}
func (f *fmtState) writeVTrunc() {
f.Write(f.vtrunc)
}
// Format implements the fmt.Formatter interface so that fmt.Print-style functions
// and verbs can print out Tensor values in different formats.
func (ts *Tensor) Format(s fmt.State, verb rune) {
shape := toSliceInt(ts.MustSize())
strides := shapeToStrides(shape)
device := ts.MustDevice()
dtype := ts.DType().String()
if verb == 'i' {
fmt.Fprintf(
s,
"\nTENSOR META:\n\tShape:\t\t%v\n\tDType:\t\t%v\n\tDevice:\t\t%v\n",
shape,
dtype,
device,
)
return
}
data := ts.ValueGo()
f := newFmtState(s, verb, shape)
f.fmtVerb(ts) // normalize the verb for the tensor's dtype before computing widths
f.computeWidth(data)
f.makePad()
f.cast(ts)
// Tensor meta data
if f.meta {
switch ts.Dim() {
case 1:
f.Write(ufVec)
case 2:
f.Write(ufMat)
default:
f.Write(ufTensor)
fmt.Fprintf(f, ": Dim=%d, ", ts.Dim())
}
fmt.Fprintf(f, "Shape=%v, Strides=%v\n", shape, strides)
}
if f.verb == 'H' {
return
}
if f.flat {
// TODO.
// writeFlatTensor()
log.Printf("WARNING: f.writeFlatTensor() NotImplemedted.\n")
return
}
// 0d (scalar)
if len(shape) == 0 {
fmt.Fprintf(f, "%v", data)
return
}
// 1d (slice)
if len(shape) == 1 {
values := sliceInterface(data, 0, shape[0])
f.writeVector(values)
return
}
// 2d (matrix)
if len(shape) == 2 {
vlen := shape[0] * shape[1]
values := sliceInterface(data, 0, vlen)
f.writeMatrix(values, shape)
return
}
// >= 3d (tensor)
f.Write([]byte("\n"))
f.writeTensor(ts, data)
}
// writeTensor writes a tensor of dim >= 3 by iterating over its trailing 2D matrices.
func (f *fmtState) writeTensor(ts *Tensor, values interface{}) {
shape := toSliceInt(ts.MustSize())
strides := shapeToStrides(shape)
size := shapeToNumels(shape)
mSize := shape[len(shape)-1] * shape[len(shape)-2]
var (
offset int
printOne = false
)
for i := 0; i < size; i += int(mSize) {
dims := make([]int, len(strides)-2)
for n := 0; n < len(strides[:len(strides)-2]); n++ {
stride := strides[n]
var dim int = offset
for _, s := range strides[:n] {
dim = dim % s
}
dim = dim / stride
dims[n] = dim
}
var (
vlimit = f.printCols
shouldPrint = true
printVTrunc bool
conds []bool
)
for i, val := range dims {
maxDim := shape[i]
if (val < vlimit/2) || (val >= maxDim-vlimit/2) {
conds = append(conds, true)
shouldPrint = shouldPrint && true
} else {
conds = append(conds, false)
shouldPrint = shouldPrint && false
printVTrunc = true
}
} // inner for
if shouldPrint {
dimsLabel := "("
for _, d := range dims {
dimsLabel += fmt.Sprintf("%v, ", d)
}
dimsLabel += ".,.) =\n"
f.Write([]byte(dimsLabel))
// Print matrix [H, W]
data := sliceInterface(values, offset, offset+mSize)
f.writeMatrix(data, shape[len(shape)-2:])
printOne = false
}
if printVTrunc && !printOne {
// NOTE. vertical truncation at > 2D level
vtrunc := "...,\n\n"
f.Write([]byte(vtrunc))
printOne = true
}
offset += mSize
} // outer for
}
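// Worked example for the index math above: with shape [2, 3, 4, 5] the strides are
// [60, 20, 5, 1] and mSize = 4*5 = 20; at offset 80 the two leading indices are
// 80/60 = 1 and (80%60)/20 = 1, so that block is labeled "(1, 1, .,.) =".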
func (f *fmtState) writeMatrix(data []interface{}, shape []int) {
n := shapeToNumels(shape)
dataLen := len(data)
if dataLen != n {
log.Fatalf("mismatched: slice data has %v elements - shape: %v\n", dataLen, n)
}
if len(shape) != 2 {
log.Fatal("Shape must have length of 2.\n")
}
stride := int(shape[1])
currIdx := 0
nextIdx := stride
truncatedRows := shape[0] - f.printCols
for row := 0; row < int(shape[0]); row++ {
var slice []interface{}
switch {
case row < f.printCols/2: // First part
slice = data[currIdx:nextIdx]
f.writeVector(slice)
currIdx = nextIdx
nextIdx += stride
case row == f.printCols/2: // Truncated sign
if f.printCols != f.cols { // truncated mode
f.writeVTrunc()
} else { // full mode
// Do nothing
}
currIdx = nextIdx
nextIdx += stride
case row > f.printCols/2 && row < f.printCols/2+truncatedRows: // Skip part
currIdx = nextIdx
nextIdx += stride
case row >= f.printCols/2+truncatedRows: // Second part
slice = data[currIdx:nextIdx]
f.writeVector(slice)
currIdx = nextIdx
nextIdx += stride
}
}
// 1 line between matrix
f.Write([]byte("\n"))
}
func (f *fmtState) writeVector(data []interface{}) {
format := f.initFmt()
vlen := len(data)
for col := 0; col < vlen; col++ {
if f.cols <= f.printCols || (col < f.printCols/2 || (col >= f.cols-f.printCols/2)) {
el := data[col]
// TODO: more format options here
w, _ := fmt.Fprintf(f.buf, format, el)
f.Write(f.buf.Bytes())
f.Write(f.pad[:f.w-w]) // prepad
f.Write(f.pad[:2]) // pad
f.buf.Reset()
} else if col == f.printCols/2 {
f.writeHTrunc()
}
}
f.Write([]byte("\n"))
}
// Info prints tensor metadata (shape, dtype, device) to stdout.
func (ts *Tensor) Info() {
fmt.Printf("%i", ts)
}

ts/print_test.go (new file)

@@ -0,0 +1,19 @@
package ts_test
import (
"fmt"
"testing"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/ts"
)
func TestTensor_Format(t *testing.T) {
shape := []int64{8, 8, 8}
numels := int64(8 * 8 * 8)
x := ts.MustArange(ts.IntScalar(numels), gotch.Float, gotch.CPU).MustView(shape, true)
fmt.Printf("%0.1f", x) // print truncated data
// fmt.Printf("%#0.1f", x) // print full data
}
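// A further sketch (illustrative only, no assertions) exercising the other verbs wired
// into ts.Format above: '%i' prints meta only, '+' prepends the meta header, '#' disables truncation.
func TestTensor_FormatVerbs(t *testing.T) {
x := ts.MustArange(ts.IntScalar(8), gotch.Float, gotch.CPU).MustView([]int64{2, 4}, true)
fmt.Printf("%i", x)    // shape, dtype and device only
fmt.Printf("%+.2f", x) // meta header (Shape, Strides), then the values
fmt.Printf("%#.2f", x) // full data, no truncation
}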