tensor/print: more print options; updated for minor version
This commit is contained in:
parent
5d60adf0fd
commit
b18d1cde89
|
@ -31,7 +31,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
### Added
|
||||
- Added `tensor.SaveMultiNew`
|
||||
|
||||
[#10]: https://github.com/sugarme/gotch/issues/10
|
||||
|
||||
## [0.2.0]
|
||||
|
||||
|
@ -55,3 +54,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- Changed to use `*Path` argument of `NewLayerNorm` method at `nn/layer-norm.go`
|
||||
- Lots of clean-up return variables i.e. retVal, err
|
||||
|
||||
## [0.3.2]
|
||||
|
||||
### Added
|
||||
- [#6]: Go native tensor print using `fmt.Formatter` interface. Now, a tensor can be printed out like: `fmt.Printf("%.3f", tensor)` (for float type)
|
||||
|
||||
[#10]: https://github.com/sugarme/gotch/issues/10
|
||||
[#6]: https://github.com/sugarme/gotch/issues/6
|
||||
|
|
12
README.md
12
README.md
|
@ -16,21 +16,21 @@
|
|||
|
||||
- **CPU**
|
||||
|
||||
Default values: `LIBTORCH_VER=1.7.0` and `GOTCH_VER=v0.3.1`
|
||||
Default values: `LIBTORCH_VER=1.7.0` and `GOTCH_VER=v0.3.2`
|
||||
|
||||
```bash
|
||||
go get -u github.com/sugarme/gotch@v0.3.1
|
||||
bash ${GOPATH}/pkg/mod/github.com/sugarme/gotch@v0.3.1/setup-cpu.sh
|
||||
go get -u github.com/sugarme/gotch@v0.3.2
|
||||
bash ${GOPATH}/pkg/mod/github.com/sugarme/gotch@v0.3.2/setup-cpu.sh
|
||||
|
||||
```
|
||||
|
||||
- **GPU**
|
||||
|
||||
Default values: `LIBTORCH_VER=1.7.0`, `CUDA_VER=10.1` and `GOTCH_VER=v0.3.1`
|
||||
Default values: `LIBTORCH_VER=1.7.0`, `CUDA_VER=10.1` and `GOTCH_VER=v0.3.2`
|
||||
|
||||
```bash
|
||||
go get -u github.com/sugarme/gotch@v0.3.1
|
||||
bash ${GOPATH}/pkg/mod/github.com/sugarme/gotch@v0.3.1/setup-gpu.sh
|
||||
go get -u github.com/sugarme/gotch@v0.3.2
|
||||
bash ${GOPATH}/pkg/mod/github.com/sugarme/gotch@v0.3.2/setup-gpu.sh
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -8,16 +8,17 @@ import (
|
|||
)
|
||||
|
||||
func main() {
|
||||
// intTensor()
|
||||
intTensor()
|
||||
floatTensor()
|
||||
}
|
||||
|
||||
func intTensor() {
|
||||
xs := ts.MustArange(ts.IntScalar(7*3*4*5*6), gotch.Int64, gotch.CPU).MustView([]int64{7, 3, 4, 5, 6}, true)
|
||||
fmt.Printf("%v\n", xs)
|
||||
fmt.Printf("%0.4d\n", xs)
|
||||
}
|
||||
|
||||
func floatTensor() {
|
||||
xs := ts.MustRand([]int64{7, 3, 4, 5, 6}, gotch.Double, gotch.CPU)
|
||||
fmt.Printf("%v\n", xs)
|
||||
// xs := ts.MustRand([]int64{7, 3, 4, 5, 6}, gotch.Double, gotch.CPU)
|
||||
xs := ts.MustRand([]int64{3, 5, 6}, gotch.Float, gotch.CPU)
|
||||
fmt.Printf("%.3f\n", xs)
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Env
|
||||
GOTCH_VERSION="${GOTCH_VER:-v0.3.1}"
|
||||
GOTCH_VERSION="${GOTCH_VER:-v0.3.2}"
|
||||
LIBTORCH_VERSION="${LIBTORCH_VER:-1.7.0}"
|
||||
|
||||
GOTCH="$GOPATH/pkg/mod/github.com/sugarme/gotch@$GOTCH_VERSION"
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
GOTCH_VERSION="${GOTCH_VER:-v0.3.1}"
|
||||
GOTCH_VERSION="${GOTCH_VER:-v0.3.2}"
|
||||
LIBTORCH_VERSION="${LIBTORCH_VER:-1.7.0}"
|
||||
CUDA_VERSION="${CUDA_VER:-10.1}"
|
||||
CU_VERSION="${CUDA_VERSION//./}"
|
||||
|
|
2
setup.sh
2
setup.sh
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
export GOTCH_VERSION="${GOTCH_VER:-v0.3.1}"
|
||||
export GOTCH_VERSION="${GOTCH_VER:-v0.3.2}"
|
||||
export LIBTORCH_VERSION="${LIBTORCH_VER:-1.7.0}"
|
||||
export CUDA_VERSION="${CUDA_VER:-10.1}"
|
||||
export CU_VERSION="${CUDA_VERSION//./}"
|
||||
|
|
|
@ -3,12 +3,17 @@ package tensor
|
|||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/sugarme/gotch"
|
||||
"log"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
)
|
||||
|
||||
var fmtByte = []byte("%")
|
||||
var precByte = []byte(".")
|
||||
|
||||
func (ts *Tensor) ValueGo() interface{} {
|
||||
dtype := ts.DType()
|
||||
numel := ts.Numel()
|
||||
|
@ -52,14 +57,7 @@ func (ts *Tensor) ToSlice() reflect.Value {
|
|||
typ reflect.Type
|
||||
)
|
||||
if dt.String() == "String" {
|
||||
panic("Unsupported")
|
||||
// strs, err := decodeOneDimString(raw, n)
|
||||
// if err != nil {
|
||||
// e := fmt.Errorf("unable to decode string with shape %v: %v", shape, err)
|
||||
// panic(e)
|
||||
// }
|
||||
// slice = reflect.ValueOf(strs)
|
||||
// typ = slice.Type()
|
||||
panic("Unsupported 'String' type")
|
||||
} else {
|
||||
gtyp, err := gotch.ToGoType(dt)
|
||||
if err != nil {
|
||||
|
@ -212,7 +210,10 @@ func (f *fmtState) writeTensor(d int, dims []int64, offset int, mSize int, mShap
|
|||
for i := 0; i < int(dims[d]); i++ {
|
||||
name := fmt.Sprintf("%v,%v", i+1, mName)
|
||||
if d >= len(dims)-1 { // last dim, let's print out
|
||||
fmt.Printf("(%v,.,.) =\n", name)
|
||||
// write matrix name
|
||||
nameStr := fmt.Sprintf("(%v.,.) =\n", name)
|
||||
f.Write([]byte(nameStr))
|
||||
// write matrix values
|
||||
slice := reflect.ValueOf(data).Slice(offset, offset+mSize).Interface()
|
||||
f.writeMatrix(slice, mShape)
|
||||
} else { // recursively loop
|
||||
|
@ -253,25 +254,47 @@ func (f *fmtState) writeMatrix(data interface{}, shape []int64) {
|
|||
for i := 0; i < int(shape[0]); i++ {
|
||||
slice := reflect.ValueOf(data).Slice(currIdx, nextIdx)
|
||||
f.writeSlice(slice.Interface())
|
||||
// fmt.Printf("%4v\n", slice)
|
||||
currIdx = nextIdx
|
||||
nextIdx += stride
|
||||
}
|
||||
|
||||
// 2 lines
|
||||
fmt.Printf("\n\n")
|
||||
// 1 line between matrix
|
||||
f.Write([]byte("\n"))
|
||||
|
||||
}
|
||||
|
||||
func (f *fmtState) writeSlice(data interface{}) {
|
||||
format := f.cleanFmt()
|
||||
dataLen := reflect.ValueOf(data).Len()
|
||||
for i := 0; i < dataLen; i++ {
|
||||
el := reflect.ValueOf(data).Index(i)
|
||||
el := reflect.ValueOf(data).Index(i).Interface()
|
||||
|
||||
// TODO: more format options here
|
||||
fmt.Printf("%4v ", el)
|
||||
fmt.Fprintf(f.buf, format, el)
|
||||
f.Write(f.buf.Bytes())
|
||||
f.Write([]byte(" "))
|
||||
f.buf.Reset()
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
f.Write([]byte("\n"))
|
||||
}
|
||||
|
||||
func (f *fmtState) cleanFmt() string {
|
||||
buf := bytes.NewBuffer(fmtByte)
|
||||
|
||||
// width
|
||||
if w, ok := f.Width(); ok {
|
||||
buf.WriteString(strconv.Itoa(w))
|
||||
}
|
||||
|
||||
// precision
|
||||
if p, ok := f.Precision(); ok {
|
||||
buf.Write(precByte)
|
||||
buf.WriteString(strconv.Itoa(p))
|
||||
}
|
||||
|
||||
buf.WriteRune(f.c)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func shapeToSize(shape []int64) int {
|
||||
|
|
Loading…
Reference in New Issue
Block a user