diff --git a/CHANGELOG.md b/CHANGELOG.md index 019d9e5..a216a06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] - Fixed incorrect indexing at `dutil/Dataset.Next()` - Added `nn.MSELoss()` +- reworked `ts.Format()` ## [Nofix] - ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box. diff --git a/ts/print.go b/ts/print.go index ffe11da..d26637a 100644 --- a/ts/print.go +++ b/ts/print.go @@ -6,14 +6,10 @@ import ( "log" "reflect" "strconv" - "unsafe" "github.com/sugarme/gotch" ) -var fmtByte = []byte("%") -var precByte = []byte(".") - func (ts *Tensor) ValueGo() interface{} { dtype := ts.DType() numel := ts.Numel() @@ -51,288 +47,6 @@ func (ts *Tensor) ValueGo() interface{} { return dst } -func (ts *Tensor) ToSlice() reflect.Value { - // Create a 1-dimensional slice of the base large enough for the data and - // copy the data in. - shape := ts.MustSize() - dt := ts.DType() - n := int(numElements(shape)) - var ( - slice reflect.Value - typ reflect.Type - ) - if dt.String() == "String" { - panic("Unsupported 'String' type") - } else { - gtyp, err := gotch.ToGoType(dt) - if err != nil { - log.Fatal(err) - } - typ = reflect.SliceOf(gtyp) - slice = reflect.MakeSlice(typ, n, n) - data := ts.ValueGo() - slice = reflect.ValueOf(data) - } - // Now we have the data in place in the base slice we can add the - // dimensions. We want to walk backwards through the shape. If the shape is - // length 1 or 0 then we're already done. - if len(shape) == 0 { - return slice.Index(0) - } - if len(shape) == 1 { - return slice - } - // We have a special case if the tensor has no data. Our backing slice is - // empty, but we still want to create slices following the shape. In this - // case only the final part of the shape will be 0 and we want to recalculate - // n at this point ignoring that 0. 
- // For example if our shape is 3 * 2 * 0 then n will be zero, but we still - // want 6 zero length slices to group as follows. - // {{} {}} {{} {}} {{} {}} - if n == 0 { - n = int(numElements(shape[:len(shape)-1])) - } - for i := len(shape) - 2; i >= 0; i-- { - underlyingSize := typ.Elem().Size() - typ = reflect.SliceOf(typ) - subsliceLen := int(shape[i+1]) - if subsliceLen != 0 { - n = n / subsliceLen - } - // Just using reflection it is difficult to avoid unnecessary - // allocations while setting up the sub-slices as the Slice function on - // a slice Value allocates. So we end up doing pointer arithmetic! - // Pointer() on a slice gives us access to the data backing the slice. - // We insert slice headers directly into this data. - data := unsafe.Pointer(slice.Pointer()) - nextSlice := reflect.MakeSlice(typ, n, n) - for j := 0; j < n; j++ { - // This is equivalent to nSlice[j] = slice[j*subsliceLen: (j+1)*subsliceLen] - setSliceInSlice(nextSlice, j, sliceHeader{ - Data: unsafe.Pointer(uintptr(data) + (uintptr(j*subsliceLen) * underlyingSize)), - Len: subsliceLen, - Cap: subsliceLen, - }) - } - - fmt.Printf("nextSlice length: %v\n", nextSlice.Len()) - fmt.Printf("%v\n\n", nextSlice) - - slice = nextSlice - } - return slice -} - -// setSliceInSlice sets slice[index] = content. -func setSliceInSlice(slice reflect.Value, index int, content sliceHeader) { - const sliceSize = unsafe.Sizeof(sliceHeader{}) - // We must cast slice.Pointer to uninptr & back again to avoid GC issues. 
- // See https://github.com/google/go-cmp/issues/167#issuecomment-546093202 - *(*sliceHeader)(unsafe.Pointer(uintptr(unsafe.Pointer(slice.Pointer())) + (uintptr(index) * sliceSize))) = content -} -func numElements(shape []int64) int64 { - n := int64(1) - for _, d := range shape { - n *= d - } - return n -} - -// It isn't safe to use reflect.SliceHeader as it uses a uintptr for Data and -// this is not inspected by the garbage collector -type sliceHeader struct { - Data unsafe.Pointer - Len int - Cap int -} - -// Format implements fmt.Formatter interface so that we can use -// fmt.Print... and verbs to print out Tensor value in different formats. -func (ts *Tensor) Format(s fmt.State, c rune) { - shape := ts.MustSize() - device := ts.MustDevice() - dtype := ts.DType() - if c == 'i' { - fmt.Fprintf(s, "\nTENSOR INFO:\n\tShape:\t\t%v\n\tDType:\t\t%v\n\tDevice:\t\t%v\n\tDefined:\t%v\n", shape, dtype, device, ts.MustDefined()) - return - } - - data := ts.ValueGo() - - f := newFmtState(s, c, shape) - f.setWidth(data) - f.makePad() - - // 0d (scalar) - if len(shape) == 0 { - fmt.Printf("%v", data) - } - - // 1d (slice) - if len(shape) == 1 { - f.writeSlice(data) - return - } - - // 2d (matrix) - if len(shape) == 2 { - f.writeMatrix(data, shape) - return - } - - // >= 3d (tensor) - mSize := int(shape[len(shape)-2] * shape[len(shape)-1]) - mShape := shape[len(shape)-2:] - dims := shape[:len(shape)-2] - var rdims []int64 - for d := len(dims) - 1; d >= 0; d-- { - rdims = append(rdims, dims[d]) - } - - f.writeTensor(0, rdims, 0, mSize, mShape, data, "") -} - -// fmtState is a struct that implements fmt.State interface -type fmtState struct { - fmt.State - c rune // format verb - pad []byte // padding - w int // width - p int // precision - shape []int64 - buf *bytes.Buffer -} - -func newFmtState(s fmt.State, c rune, shape []int64) *fmtState { - w, _ := s.Width() - p, _ := s.Precision() - - return &fmtState{ - State: s, - c: c, - w: w, - p: p, - shape: shape, - buf: 
bytes.NewBuffer(make([]byte, 0)), - } -} - -// writeTensor iterates recursively through a reversed shape of tensor starting from axis 3 and -// and prints out matrices (of last two dims size). -func (f *fmtState) writeTensor(d int, dims []int64, offset int, mSize int, mShape []int64, data interface{}, mName string) { - - offsetSize := product(dims[:d]) * mSize - for i := 0; i < int(dims[d]); i++ { - name := fmt.Sprintf("%v,%v", i+1, mName) - if d >= len(dims)-1 { // last dim, let's print out - // write matrix name - nameStr := fmt.Sprintf("(%v.,.) =\n", name) - f.Write([]byte(nameStr)) - // write matrix values - slice := reflect.ValueOf(data).Slice(offset, offset+mSize).Interface() - f.writeMatrix(slice, mShape) - } else { // recursively loop - f.writeTensor(d+1, dims, offset, mSize, mShape, data, name) - } - - // update offset - offset += offsetSize - } -} - -func product(dims []int64) int { - var p int = 1 - if len(dims) == 0 { - return p - } - - for _, d := range dims { - p = p * int(d) - } - - return p -} - -func (f *fmtState) writeMatrix(data interface{}, shape []int64) { - n := shapeToSize(shape) - dataLen := reflect.ValueOf(data).Len() - if dataLen != n { - log.Fatalf("mismatched: slice data has %v elements - shape: %v\n", dataLen, n) - } - if len(shape) != 2 { - log.Fatal("Shape must have length of 2.\n") - } - - stride := int(shape[1]) - currIdx := 0 - nextIdx := stride - for i := 0; i < int(shape[0]); i++ { - slice := reflect.ValueOf(data).Slice(currIdx, nextIdx) - f.writeSlice(slice.Interface()) - currIdx = nextIdx - nextIdx += stride - } - - // 1 line between matrix - f.Write([]byte("\n")) - -} - -func (f *fmtState) writeSlice(data interface{}) { - format := f.cleanFmt() - dataLen := reflect.ValueOf(data).Len() - for i := 0; i < dataLen; i++ { - el := reflect.ValueOf(data).Index(i).Interface() - - // TODO: more format options here - w, _ := fmt.Fprintf(f.buf, format, el) - f.Write(f.buf.Bytes()) - f.Write(f.pad[:f.w-w]) // prepad - f.Write(f.pad[:2]) // 
pad - f.buf.Reset() - } - - f.Write([]byte("\n")) -} - -func (f *fmtState) cleanFmt() string { - buf := bytes.NewBuffer(fmtByte) - - // width - if w, ok := f.Width(); ok { - buf.WriteString(strconv.Itoa(w)) - } - - // precision - if p, ok := f.Precision(); ok { - buf.Write(precByte) - buf.WriteString(strconv.Itoa(p)) - } - - buf.WriteRune(f.c) - return buf.String() -} - -func (f *fmtState) makePad() { - f.pad = make([]byte, maxInt(f.w, 4)) - for i := range f.pad { - f.pad[i] = ' ' - } -} - -// setWidth determines maximal width from input data and set to `w` field -func (f *fmtState) setWidth(data interface{}) { - format := f.cleanFmt() - f.w = 0 - for i := 0; i < reflect.ValueOf(data).Len(); i++ { - el := reflect.ValueOf(data).Index(i).Interface() - w, _ := fmt.Fprintf(f.buf, format, el) - if w > f.w { - f.w = w - } - f.buf.Reset() - } -} func shapeToSize(shape []int64) int { n := 1 @@ -345,9 +59,467 @@ func shapeToSize(shape []int64) int { return n } +func shapeToNumels(shape []int) int { + n := 1 + for _, d := range shape { + n *= int(d) + } + return n +} + +func shapeToStrides(shape []int) []int { + numel := shapeToNumels(shape) + var strides []int + for _, v := range shape { + numel /= int(v) + strides = append(strides, numel) + } + + return strides +} + +func toSliceInt(in []int64) []int { + out := make([]int, len(in)) + for i := 0; i < len(in); i++ { + out[i] = int(in[i]) + } + return out +} + func maxInt(a, b int) int { if a >= b { return a } return b } + +func minInt(a, b int) int { + if a < b { + return a + } else { + return b + } +} + +func toSlice(input interface{}) []interface{} { + vlen := reflect.ValueOf(input).Len() + out := make([]interface{}, vlen) + for i := 0; i < vlen; i++ { + out[i] = reflect.ValueOf(input).Index(i).Interface() + } + + return out +} + +func sliceInterface(data interface{}, start, end int) []interface{} { + return toSlice(data)[start:end] +} + +// Implement Format interface for Tensor: +// ====================================== 
+var ( + fmtByte = []byte("%") + precByte = []byte(".") + fmtFlags = [...]rune{'+', '-', '#', ' ', '0'} + + ufVec = []byte("Vector") + ufMat = []byte("Matrix") + ufTensor = []byte("Tensor") +) + +// fmtState is a custom formatter for Tensor that implements fmt.State interface +type fmtState struct { + fmt.State + verb rune // format verb + flat bool // whether to print tensor in flatten format + meta bool // whether to print out meta data + ext bool // whether to print full tensor data (no truncation) + pad []byte // padding space + w int // width - total calculated space to print out tensor values + p int // precision for float dtype + base int // integer counting base for integer dtype cases + htrunc []byte // horizontal truncation symbol space + vtrunc []byte // vertical truncation symbol space + rows int // total rows + cols int // total columns + printRows int // rows to print + printCols int // columns to print + shape []int // shape of tensor to print out + buf *bytes.Buffer // memory to hold formatted tensor data to print +} + +func newFmtState(s fmt.State, verb rune, shape []int) *fmtState { + w, _ := s.Width() + p, _ := s.Precision() + + return &fmtState{ + State: s, + verb: verb, + flat: s.Flag('-'), + meta: s.Flag('+'), + ext: s.Flag('#'), + w: w, + p: p, + htrunc: []byte("..., "), + vtrunc: []byte("...,\n"), + shape: shape, + buf: bytes.NewBuffer(make([]byte, 0)), + } +} + +// originalFmt reconstructs the caller's original format string (flags, width, precision and verb). +func (f *fmtState) originalFmt() string { + // write format symbol and verbs + buf := bytes.NewBuffer(fmtByte) // '%' + for _, flag := range fmtFlags { + if f.Flag(int(flag)) { + buf.WriteRune(flag) + } + } + + // write width format + if w, ok := f.Width(); ok { + buf.WriteString(strconv.Itoa(w)) + } + + // write precision verb + if p, ok := f.Precision(); ok { + buf.Write(precByte) + buf.WriteString(strconv.Itoa(p)) + } + + buf.WriteRune(f.verb) + + return buf.String() +} + +// initFmt returns the start of the format (width, precision and verb only — no flags). 
+func (f *fmtState) initFmt() string { + buf := bytes.NewBuffer(fmtByte) + + // write width format + if w, ok := f.Width(); ok { + buf.WriteString(strconv.Itoa(w)) + } + + // write precision verb + if p, ok := f.Precision(); ok { + buf.Write(precByte) + buf.WriteString(strconv.Itoa(p)) + } + + buf.WriteRune(f.verb) + + return buf.String() +} + +// cast casts tensor data to the formatter. +func (f *fmtState) cast(ts *Tensor) { + // rows and columns + if ts.Dim() == 1 { + f.rows = 1 + f.cols = int(ts.Numel()) + } else { + shape := ts.MustSize() + f.rows = int(shape[len(shape)-2]) + f.cols = int(shape[len(shape)-1]) + } + + // printRows and printCols + switch { + case f.flat && f.ext: + f.printCols = int(ts.Numel()) + case f.flat: + f.printCols = 10 + case f.ext: + f.printCols = f.cols + f.printRows = f.rows + default: + f.printCols = minInt(f.cols, 6) + f.printRows = minInt(f.rows, 6) + } +} + +// fmtVerb formats verbs. +func (f *fmtState) fmtVerb(ts *Tensor) { + if f.verb == 'H' { // print out only header. + f.meta = true + return + } + + // var typ T + typ := ts.DType() + + switch typ.String() { + case "float32", "float64": + switch f.verb { + case 'f', 'e', 'E', 'G', 'b': + // accepted. Do nothing + default: + f.verb = 'g' + } + + case "uint8", "int8", "int16", "int32", "int64": + switch f.verb { + case 'b': + f.base = 2 + case 'd': + f.base = 10 + case 'o': + f.base = 8 + case 'x', 'X': + f.base = 16 + default: + f.base = 10 + f.verb = 'd' + } + case "bool": + f.verb = 't' + default: + f.verb = 'v' + } +} + +// computeWidth computes a width that can fit for every element. +func (f *fmtState) computeWidth(values interface{}) { + format := f.initFmt() + vlen := reflect.ValueOf(values).Len() + f.w = 0 + for i := 0; i < vlen; i++ { + val := reflect.ValueOf(values).Index(i) + w, _ := fmt.Fprintf(f.buf, format, val) + + if w > f.w { + f.w = w + } + f.buf.Reset() + } +} + +// makePad prepares white spaces for print-out format. 
+func (f *fmtState) makePad() { + f.pad = make([]byte, maxInt(f.w, 2)) + for i := range f.pad { + f.pad[i] = ' ' // one white space + } +} + +func (f *fmtState) writeHTrunc() { + f.Write(f.htrunc) +} + +func (f *fmtState) writeVTrunc() { + f.Write(f.vtrunc) +} + +// Format implements fmt.Formatter interface so that we can use +// fmt.Print... and verbs to print out Tensor value in different formats. +func (ts *Tensor) Format(s fmt.State, verb rune) { + shape := toSliceInt(ts.MustSize()) + strides := shapeToStrides(shape) + device := ts.MustDevice() + dtype := ts.DType().String() + if verb == 'i' { + fmt.Fprintf( + s, + "\nTENSOR META:\n\tShape:\t\t%v\n\tDType:\t\t%v\n\tDevice:\t\t%v\n", + shape, + dtype, + device, + ) + return + } + + data := ts.ValueGo() + + f := newFmtState(s, verb, shape) + f.computeWidth(data) + f.makePad() + f.cast(ts) + + // Tensor meta data + if f.meta { + switch ts.Dim() { + case 1: + f.Write(ufVec) + case 2: + f.Write(ufMat) + default: + f.Write(ufTensor) + fmt.Fprintf(f, ": Dim=%d, ", ts.Dim()) + } + fmt.Fprintf(f, "Shape=%v, Strides=%v\n", shape, strides) + } + + if f.verb == 'H' { + return + } + + if f.flat { + // TODO. + // writeFlatTensor() + log.Printf("WARNING: f.writeFlatTensor() NotImplemedted.\n") + return + } + + // 0d (scalar) + if len(shape) == 0 { + fmt.Printf("%v", data) + } + + // 1d (slice) + if len(shape) == 1 { + values := sliceInterface(data, 0, shape[0]) + f.writeVector(values) + return + } + + // 2d (matrix) + if len(shape) == 2 { + vlen := shape[0] * shape[1] + values := sliceInterface(data, 0, vlen) + f.writeMatrix(values, shape) + return + } + + // >= 3d (tensor) + f.Write([]byte("\n")) + f.writeTensor(ts, data) +} + +// writeTensor writes input tensor in specified format. 
+func (f *fmtState) writeTensor(ts *Tensor, values interface{}) { + shape := toSliceInt(ts.MustSize()) + strides := shapeToStrides(shape) + size := shapeToNumels(shape) + mSize := shape[len(shape)-1] * shape[len(shape)-2] + var ( + offset int + printOne = false + ) + + for i := 0; i < size; i += int(mSize) { + dims := make([]int, len(strides)-2) + for n := 0; n < len(strides[:len(strides)-2]); n++ { + stride := strides[n] + var dim int = offset + for _, s := range strides[:n] { + dim = dim % s + } + dim = dim / stride + dims[n] = dim + } + + var ( + vlimit = f.printCols + shouldPrint = true + printVTrunc bool + conds []bool + ) + for i, val := range dims { + maxDim := shape[i] + if (val < vlimit/2) || (val >= maxDim-vlimit/2) { + conds = append(conds, true) + shouldPrint = shouldPrint && true + } else { + conds = append(conds, false) + shouldPrint = shouldPrint && false + + printVTrunc = true + } + } // inner for + if shouldPrint { + dimsLabel := "(" + for _, d := range dims { + dimsLabel += fmt.Sprintf("%v, ", d) + } + dimsLabel += ".,.) =\n" + f.Write([]byte(dimsLabel)) + + // Print matrix [H, W] + data := sliceInterface(values, offset, offset+mSize) + f.writeMatrix(data, shape[len(shape)-2:]) + + printOne = false + } + + if printVTrunc && !printOne { + // NOTE. 
vertical truncation at > 2D level + vtrunc := fmt.Sprintf("...,\n\n") + f.Write([]byte(vtrunc)) + printOne = true + } + offset += mSize + } // outer for +} + +func (f *fmtState) writeMatrix(data []interface{}, shape []int) { + n := shapeToNumels(shape) + dataLen := len(data) + if dataLen != n { + log.Fatalf("mismatched: slice data has %v elements - shape: %v\n", dataLen, n) + } + if len(shape) != 2 { + log.Fatal("Shape must have length of 2.\n") + } + + stride := int(shape[1]) + currIdx := 0 + nextIdx := stride + truncatedRows := shape[0] - f.printCols + for row := 0; row < int(shape[0]); row++ { + var slice []interface{} + switch { + case row < f.printCols/2: // First part + slice = data[currIdx:nextIdx] + f.writeVector(slice) + currIdx = nextIdx + nextIdx += stride + case row == f.printCols/2: // Truncated sign + if f.printCols != f.cols { // truncated mode + f.writeVTrunc() + } else { // full mode + // Do nothing + } + currIdx = nextIdx + nextIdx += stride + case row > f.printCols/2 && row < f.printCols/2+truncatedRows: // Skip part + currIdx = nextIdx + nextIdx += stride + case row >= f.printCols/2+truncatedRows: // Second part + slice = data[currIdx:nextIdx] + f.writeVector(slice) + currIdx = nextIdx + nextIdx += stride + } + } + + // 1 line between matrix + f.Write([]byte("\n")) + +} + +func (f *fmtState) writeVector(data []interface{}) { + format := f.initFmt() + vlen := len(data) + for col := 0; col < vlen; col++ { + if f.cols <= f.printCols || (col < f.printCols/2 || (col >= f.cols-f.printCols/2)) { + el := data[col] + // TODO: more format options here + w, _ := fmt.Fprintf(f.buf, format, el) + f.Write(f.buf.Bytes()) + f.Write(f.pad[:f.w-w]) // prepad + f.Write(f.pad[:2]) // pad + f.buf.Reset() + } else if col == f.printCols/2 { + f.writeHTrunc() + } + } + + f.Write([]byte("\n")) +} + +// Info prints tensor meta data to stdout. 
+func (ts *Tensor) Info() { + fmt.Printf("%i", ts) +} diff --git a/ts/print_test.go b/ts/print_test.go new file mode 100644 index 0000000..07fbf05 --- /dev/null +++ b/ts/print_test.go @@ -0,0 +1,19 @@ +package ts_test + +import ( + "fmt" + "testing" + + "github.com/sugarme/gotch" + "github.com/sugarme/gotch/ts" +) + +func TestTensor_Format(t *testing.T) { + shape := []int64{8, 8, 8} + numels := int64(8 * 8 * 8) + + x := ts.MustArange(ts.IntScalar(numels), gotch.Float, gotch.CPU).MustView(shape, true) + + fmt.Printf("%0.1f", x) // print truncated data + // fmt.Printf("%#0.1f", x) // print full data +}