tensor/print: added basic tensor print using fmt Formatter interface

This commit is contained in:
sugarme 2020-11-09 13:35:28 +11:00
parent e66046b45e
commit 5d60adf0fd
2 changed files with 299 additions and 10 deletions

View File

@ -1,20 +1,23 @@
package main
import (
// "fmt"
"fmt"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
)
func main() {
// Create a tensor [2,3,4]
tensor := ts.MustArange(ts.IntScalar(2*3*4), gotch.Int64, gotch.CPU).MustView([]int64{2, 3, 4}, true)
tensor.Print()
mul := ts.MustOnes([]int64{4, 5}, gotch.Int64, gotch.CPU)
res := tensor.MustMatmul(mul, false)
res.Print()
// intTensor()
floatTensor()
}
func intTensor() {
xs := ts.MustArange(ts.IntScalar(7*3*4*5*6), gotch.Int64, gotch.CPU).MustView([]int64{7, 3, 4, 5, 6}, true)
fmt.Printf("%v\n", xs)
}
func floatTensor() {
xs := ts.MustRand([]int64{7, 3, 4, 5, 6}, gotch.Double, gotch.CPU)
fmt.Printf("%v\n", xs)
}

286
tensor/print.go Normal file
View File

@ -0,0 +1,286 @@
package tensor
import (
"bytes"
"fmt"
"github.com/sugarme/gotch"
"log"
"reflect"
"unsafe"
)
func (ts *Tensor) ValueGo() interface{} {
dtype := ts.DType()
numel := ts.Numel()
var dst interface{}
switch dtype {
case gotch.Uint8:
dst = make([]uint8, numel)
case gotch.Int8:
dst = make([]int8, numel)
case gotch.Int16:
dst = make([]int16, numel)
case gotch.Int:
dst = make([]int32, numel)
case gotch.Int64:
dst = make([]int64, numel)
case gotch.Float:
dst = make([]float32, numel)
case gotch.Double:
dst = make([]float64, numel)
case gotch.Bool:
dst = make([]bool, numel)
default:
err := fmt.Errorf("Unsupported type: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
log.Fatal(err)
}
err := ts.CopyData(dst, ts.Numel())
if err != nil {
log.Fatal(err)
}
// fmt.Println(dst)
return dst
}
func (ts *Tensor) ToSlice() reflect.Value {
// Create a 1-dimensional slice of the base large enough for the data and
// copy the data in.
shape := ts.MustSize()
dt := ts.DType()
n := int(numElements(shape))
var (
slice reflect.Value
typ reflect.Type
)
if dt.String() == "String" {
panic("Unsupported")
// strs, err := decodeOneDimString(raw, n)
// if err != nil {
// e := fmt.Errorf("unable to decode string with shape %v: %v", shape, err)
// panic(e)
// }
// slice = reflect.ValueOf(strs)
// typ = slice.Type()
} else {
gtyp, err := gotch.ToGoType(dt)
if err != nil {
log.Fatal(err)
}
typ = reflect.SliceOf(gtyp)
slice = reflect.MakeSlice(typ, n, n)
data := ts.ValueGo()
slice = reflect.ValueOf(data)
}
// Now we have the data in place in the base slice we can add the
// dimensions. We want to walk backwards through the shape. If the shape is
// length 1 or 0 then we're already done.
if len(shape) == 0 {
return slice.Index(0)
}
if len(shape) == 1 {
return slice
}
// We have a special case if the tensor has no data. Our backing slice is
// empty, but we still want to create slices following the shape. In this
// case only the final part of the shape will be 0 and we want to recalculate
// n at this point ignoring that 0.
// For example if our shape is 3 * 2 * 0 then n will be zero, but we still
// want 6 zero length slices to group as follows.
// {{} {}} {{} {}} {{} {}}
if n == 0 {
n = int(numElements(shape[:len(shape)-1]))
}
for i := len(shape) - 2; i >= 0; i-- {
underlyingSize := typ.Elem().Size()
typ = reflect.SliceOf(typ)
subsliceLen := int(shape[i+1])
if subsliceLen != 0 {
n = n / subsliceLen
}
// Just using reflection it is difficult to avoid unnecessary
// allocations while setting up the sub-slices as the Slice function on
// a slice Value allocates. So we end up doing pointer arithmetic!
// Pointer() on a slice gives us access to the data backing the slice.
// We insert slice headers directly into this data.
data := unsafe.Pointer(slice.Pointer())
nextSlice := reflect.MakeSlice(typ, n, n)
for j := 0; j < n; j++ {
// This is equivalent to nSlice[j] = slice[j*subsliceLen: (j+1)*subsliceLen]
setSliceInSlice(nextSlice, j, sliceHeader{
Data: unsafe.Pointer(uintptr(data) + (uintptr(j*subsliceLen) * underlyingSize)),
Len: subsliceLen,
Cap: subsliceLen,
})
}
fmt.Printf("nextSlice length: %v\n", nextSlice.Len())
fmt.Printf("%v\n\n", nextSlice)
slice = nextSlice
}
return slice
}
// setSliceInSlice sets slice[index] = content.
func setSliceInSlice(slice reflect.Value, index int, content sliceHeader) {
const sliceSize = unsafe.Sizeof(sliceHeader{})
// We must cast slice.Pointer to uninptr & back again to avoid GC issues.
// See https://github.com/google/go-cmp/issues/167#issuecomment-546093202
*(*sliceHeader)(unsafe.Pointer(uintptr(unsafe.Pointer(slice.Pointer())) + (uintptr(index) * sliceSize))) = content
}
func numElements(shape []int64) int64 {
n := int64(1)
for _, d := range shape {
n *= d
}
return n
}
// It isn't safe to use reflect.SliceHeader as it uses a uintptr for Data and
// this is not inspected by the garbage collector
type sliceHeader struct {
Data unsafe.Pointer
Len int
Cap int
}
// Format implements fmt.Formatter interface so that we can use
// fmt.Print... and verbs to print out Tensor value in different formats.
func (ts *Tensor) Format(s fmt.State, c rune) {
shape := ts.MustSize()
device := ts.MustDevice()
dtype := ts.DType()
if c == 'i' {
fmt.Fprintf(s, "\nTENSOR INFO:\n\tShape:\t\t%v\n\tDType:\t\t%v\n\tDevice:\t\t%v\n\tDefined:\t%v\n", shape, dtype, device, ts.MustDefined())
return
}
f := newFmtState(s, c, shape)
data := ts.ValueGo()
// 0d (scalar)
if len(shape) == 0 {
fmt.Printf("%v", data)
}
// 1d (slice)
if len(shape) == 1 {
f.writeSlice(data)
return
}
// 2d (matrix)
if len(shape) == 2 {
f.writeMatrix(data, shape)
return
}
// >= 3d (tensor)
mSize := int(shape[len(shape)-2] * shape[len(shape)-1])
mShape := shape[len(shape)-2:]
dims := shape[:len(shape)-2]
var rdims []int64
for d := len(dims) - 1; d >= 0; d-- {
rdims = append(rdims, dims[d])
}
f.writeTensor(0, rdims, 0, mSize, mShape, data, "")
}
// fmtState is a struct that implements fmt.State interface
type fmtState struct {
fmt.State
c rune
shape []int64
buf *bytes.Buffer
}
func newFmtState(s fmt.State, c rune, shape []int64) *fmtState {
return &fmtState{
State: s,
c: c,
shape: shape,
buf: bytes.NewBuffer(make([]byte, 0)),
}
}
// writeTensor iterates recursively through a reversed shape of tensor starting from axis 3 and
// and prints out matrices (of last two dims size).
func (f *fmtState) writeTensor(d int, dims []int64, offset int, mSize int, mShape []int64, data interface{}, mName string) {
offsetSize := product(dims[:d]) * mSize
for i := 0; i < int(dims[d]); i++ {
name := fmt.Sprintf("%v,%v", i+1, mName)
if d >= len(dims)-1 { // last dim, let's print out
fmt.Printf("(%v,.,.) =\n", name)
slice := reflect.ValueOf(data).Slice(offset, offset+mSize).Interface()
f.writeMatrix(slice, mShape)
} else { // recursively loop
f.writeTensor(d+1, dims, offset, mSize, mShape, data, name)
}
// update offset
offset += offsetSize
}
}
func product(dims []int64) int {
var p int = 1
if len(dims) == 0 {
return p
}
for _, d := range dims {
p = p * int(d)
}
return p
}
func (f *fmtState) writeMatrix(data interface{}, shape []int64) {
n := shapeToSize(shape)
dataLen := reflect.ValueOf(data).Len()
if dataLen != n {
log.Fatalf("mismatched: slice data has %v elements - shape: %v\n", dataLen, n)
}
if len(shape) != 2 {
log.Fatal("Shape must have length of 2.\n")
}
stride := int(shape[1])
currIdx := 0
nextIdx := stride
for i := 0; i < int(shape[0]); i++ {
slice := reflect.ValueOf(data).Slice(currIdx, nextIdx)
f.writeSlice(slice.Interface())
// fmt.Printf("%4v\n", slice)
currIdx = nextIdx
nextIdx += stride
}
// 2 lines
fmt.Printf("\n\n")
}
func (f *fmtState) writeSlice(data interface{}) {
dataLen := reflect.ValueOf(data).Len()
for i := 0; i < dataLen; i++ {
el := reflect.ValueOf(data).Index(i)
// TODO: more format options here
fmt.Printf("%4v ", el)
}
fmt.Println()
}
func shapeToSize(shape []int64) int {
n := 1
for _, v := range shape {
if v == 0 {
continue
}
n = n * int(v)
}
return n
}