tensor/print: more print options; updated for minor version
This commit is contained in:
parent
5d60adf0fd
commit
b18d1cde89
|
@ -31,7 +31,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
### Added
|
### Added
|
||||||
- Added `tensor.SaveMultiNew`
|
- Added `tensor.SaveMultiNew`
|
||||||
|
|
||||||
[#10]: https://github.com/sugarme/gotch/issues/10
|
|
||||||
|
|
||||||
## [0.2.0]
|
## [0.2.0]
|
||||||
|
|
||||||
|
@ -55,3 +54,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
- Changed to use `*Path` argument of `NewLayerNorm` method at `nn/layer-norm.go`
|
- Changed to use `*Path` argument of `NewLayerNorm` method at `nn/layer-norm.go`
|
||||||
- Lots of clean-up return variables i.e. retVal, err
|
- Lots of clean-up return variables i.e. retVal, err
|
||||||
|
|
||||||
|
## [0.3.2]
|
||||||
|
|
||||||
|
### Added
|
||||||
|
- [#6]: Go native tensor print using `fmt.Formatter` interface. Now, a tensor can be printed out like: `fmt.Printf("%.3f", tensor)` (for float type)
|
||||||
|
|
||||||
|
[#10]: https://github.com/sugarme/gotch/issues/10
|
||||||
|
[#6]: https://github.com/sugarme/gotch/issues/6
|
||||||
|
|
12
README.md
12
README.md
|
@ -16,21 +16,21 @@
|
||||||
|
|
||||||
- **CPU**
|
- **CPU**
|
||||||
|
|
||||||
Default values: `LIBTORCH_VER=1.7.0` and `GOTCH_VER=v0.3.1`
|
Default values: `LIBTORCH_VER=1.7.0` and `GOTCH_VER=v0.3.2`
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
go get -u github.com/sugarme/gotch@v0.3.1
|
go get -u github.com/sugarme/gotch@v0.3.2
|
||||||
bash ${GOPATH}/pkg/mod/github.com/sugarme/gotch@v0.3.1/setup-cpu.sh
|
bash ${GOPATH}/pkg/mod/github.com/sugarme/gotch@v0.3.2/setup-cpu.sh
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
- **GPU**
|
- **GPU**
|
||||||
|
|
||||||
Default values: `LIBTORCH_VER=1.7.0`, `CUDA_VER=10.1` and `GOTCH_VER=v0.3.1`
|
Default values: `LIBTORCH_VER=1.7.0`, `CUDA_VER=10.1` and `GOTCH_VER=v0.3.2`
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
go get -u github.com/sugarme/gotch@v0.3.1
|
go get -u github.com/sugarme/gotch@v0.3.2
|
||||||
bash ${GOPATH}/pkg/mod/github.com/sugarme/gotch@v0.3.1/setup-gpu.sh
|
bash ${GOPATH}/pkg/mod/github.com/sugarme/gotch@v0.3.2/setup-gpu.sh
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -8,16 +8,17 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// intTensor()
|
intTensor()
|
||||||
floatTensor()
|
floatTensor()
|
||||||
}
|
}
|
||||||
|
|
||||||
func intTensor() {
|
func intTensor() {
|
||||||
xs := ts.MustArange(ts.IntScalar(7*3*4*5*6), gotch.Int64, gotch.CPU).MustView([]int64{7, 3, 4, 5, 6}, true)
|
xs := ts.MustArange(ts.IntScalar(7*3*4*5*6), gotch.Int64, gotch.CPU).MustView([]int64{7, 3, 4, 5, 6}, true)
|
||||||
fmt.Printf("%v\n", xs)
|
fmt.Printf("%0.4d\n", xs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func floatTensor() {
|
func floatTensor() {
|
||||||
xs := ts.MustRand([]int64{7, 3, 4, 5, 6}, gotch.Double, gotch.CPU)
|
// xs := ts.MustRand([]int64{7, 3, 4, 5, 6}, gotch.Double, gotch.CPU)
|
||||||
fmt.Printf("%v\n", xs)
|
xs := ts.MustRand([]int64{3, 5, 6}, gotch.Float, gotch.CPU)
|
||||||
|
fmt.Printf("%.3f\n", xs)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
# Env
|
# Env
|
||||||
GOTCH_VERSION="${GOTCH_VER:-v0.3.1}"
|
GOTCH_VERSION="${GOTCH_VER:-v0.3.2}"
|
||||||
LIBTORCH_VERSION="${LIBTORCH_VER:-1.7.0}"
|
LIBTORCH_VERSION="${LIBTORCH_VER:-1.7.0}"
|
||||||
|
|
||||||
GOTCH="$GOPATH/pkg/mod/github.com/sugarme/gotch@$GOTCH_VERSION"
|
GOTCH="$GOPATH/pkg/mod/github.com/sugarme/gotch@$GOTCH_VERSION"
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
GOTCH_VERSION="${GOTCH_VER:-v0.3.1}"
|
GOTCH_VERSION="${GOTCH_VER:-v0.3.2}"
|
||||||
LIBTORCH_VERSION="${LIBTORCH_VER:-1.7.0}"
|
LIBTORCH_VERSION="${LIBTORCH_VER:-1.7.0}"
|
||||||
CUDA_VERSION="${CUDA_VER:-10.1}"
|
CUDA_VERSION="${CUDA_VER:-10.1}"
|
||||||
CU_VERSION="${CUDA_VERSION//./}"
|
CU_VERSION="${CUDA_VERSION//./}"
|
||||||
|
|
2
setup.sh
2
setup.sh
|
@ -1,6 +1,6 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
export GOTCH_VERSION="${GOTCH_VER:-v0.3.1}"
|
export GOTCH_VERSION="${GOTCH_VER:-v0.3.2}"
|
||||||
export LIBTORCH_VERSION="${LIBTORCH_VER:-1.7.0}"
|
export LIBTORCH_VERSION="${LIBTORCH_VER:-1.7.0}"
|
||||||
export CUDA_VERSION="${CUDA_VER:-10.1}"
|
export CUDA_VERSION="${CUDA_VER:-10.1}"
|
||||||
export CU_VERSION="${CUDA_VERSION//./}"
|
export CU_VERSION="${CUDA_VERSION//./}"
|
||||||
|
|
|
@ -3,12 +3,17 @@ package tensor
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/sugarme/gotch"
|
|
||||||
"log"
|
"log"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strconv"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var fmtByte = []byte("%")
|
||||||
|
var precByte = []byte(".")
|
||||||
|
|
||||||
func (ts *Tensor) ValueGo() interface{} {
|
func (ts *Tensor) ValueGo() interface{} {
|
||||||
dtype := ts.DType()
|
dtype := ts.DType()
|
||||||
numel := ts.Numel()
|
numel := ts.Numel()
|
||||||
|
@ -52,14 +57,7 @@ func (ts *Tensor) ToSlice() reflect.Value {
|
||||||
typ reflect.Type
|
typ reflect.Type
|
||||||
)
|
)
|
||||||
if dt.String() == "String" {
|
if dt.String() == "String" {
|
||||||
panic("Unsupported")
|
panic("Unsupported 'String' type")
|
||||||
// strs, err := decodeOneDimString(raw, n)
|
|
||||||
// if err != nil {
|
|
||||||
// e := fmt.Errorf("unable to decode string with shape %v: %v", shape, err)
|
|
||||||
// panic(e)
|
|
||||||
// }
|
|
||||||
// slice = reflect.ValueOf(strs)
|
|
||||||
// typ = slice.Type()
|
|
||||||
} else {
|
} else {
|
||||||
gtyp, err := gotch.ToGoType(dt)
|
gtyp, err := gotch.ToGoType(dt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -212,7 +210,10 @@ func (f *fmtState) writeTensor(d int, dims []int64, offset int, mSize int, mShap
|
||||||
for i := 0; i < int(dims[d]); i++ {
|
for i := 0; i < int(dims[d]); i++ {
|
||||||
name := fmt.Sprintf("%v,%v", i+1, mName)
|
name := fmt.Sprintf("%v,%v", i+1, mName)
|
||||||
if d >= len(dims)-1 { // last dim, let's print out
|
if d >= len(dims)-1 { // last dim, let's print out
|
||||||
fmt.Printf("(%v,.,.) =\n", name)
|
// write matrix name
|
||||||
|
nameStr := fmt.Sprintf("(%v.,.) =\n", name)
|
||||||
|
f.Write([]byte(nameStr))
|
||||||
|
// write matrix values
|
||||||
slice := reflect.ValueOf(data).Slice(offset, offset+mSize).Interface()
|
slice := reflect.ValueOf(data).Slice(offset, offset+mSize).Interface()
|
||||||
f.writeMatrix(slice, mShape)
|
f.writeMatrix(slice, mShape)
|
||||||
} else { // recursively loop
|
} else { // recursively loop
|
||||||
|
@ -253,25 +254,47 @@ func (f *fmtState) writeMatrix(data interface{}, shape []int64) {
|
||||||
for i := 0; i < int(shape[0]); i++ {
|
for i := 0; i < int(shape[0]); i++ {
|
||||||
slice := reflect.ValueOf(data).Slice(currIdx, nextIdx)
|
slice := reflect.ValueOf(data).Slice(currIdx, nextIdx)
|
||||||
f.writeSlice(slice.Interface())
|
f.writeSlice(slice.Interface())
|
||||||
// fmt.Printf("%4v\n", slice)
|
|
||||||
currIdx = nextIdx
|
currIdx = nextIdx
|
||||||
nextIdx += stride
|
nextIdx += stride
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2 lines
|
// 1 line between matrix
|
||||||
fmt.Printf("\n\n")
|
f.Write([]byte("\n"))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fmtState) writeSlice(data interface{}) {
|
func (f *fmtState) writeSlice(data interface{}) {
|
||||||
|
format := f.cleanFmt()
|
||||||
dataLen := reflect.ValueOf(data).Len()
|
dataLen := reflect.ValueOf(data).Len()
|
||||||
for i := 0; i < dataLen; i++ {
|
for i := 0; i < dataLen; i++ {
|
||||||
el := reflect.ValueOf(data).Index(i)
|
el := reflect.ValueOf(data).Index(i).Interface()
|
||||||
|
|
||||||
// TODO: more format options here
|
// TODO: more format options here
|
||||||
fmt.Printf("%4v ", el)
|
fmt.Fprintf(f.buf, format, el)
|
||||||
|
f.Write(f.buf.Bytes())
|
||||||
|
f.Write([]byte(" "))
|
||||||
|
f.buf.Reset()
|
||||||
}
|
}
|
||||||
fmt.Println()
|
|
||||||
|
f.Write([]byte("\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fmtState) cleanFmt() string {
|
||||||
|
buf := bytes.NewBuffer(fmtByte)
|
||||||
|
|
||||||
|
// width
|
||||||
|
if w, ok := f.Width(); ok {
|
||||||
|
buf.WriteString(strconv.Itoa(w))
|
||||||
|
}
|
||||||
|
|
||||||
|
// precision
|
||||||
|
if p, ok := f.Precision(); ok {
|
||||||
|
buf.Write(precByte)
|
||||||
|
buf.WriteString(strconv.Itoa(p))
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteRune(f.c)
|
||||||
|
return buf.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func shapeToSize(shape []int64) int {
|
func shapeToSize(shape []int64) int {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user