tensor/print: added padding and precision
This commit is contained in:
parent
5997cc1a2c
commit
c1bbee4880
46
README.md
46
README.md
|
@ -49,29 +49,41 @@ import (
|
|||
|
||||
func basicOps() {
|
||||
|
||||
// Initiate a tensor
|
||||
tensor := ts.MustArange1(ts.FloatScalar(0), ts.FloatScalar(12), gotch.Float, gotch.CPU).MustView([]int64{3, 4}, true)
|
||||
xs := ts.MustRand([]int64{3, 5, 6}, gotch.Float, gotch.CPU)
|
||||
fmt.Printf("%8.3f\n", xs)
|
||||
fmt.Printf("%i", xs)
|
||||
|
||||
tensor.Print()
|
||||
// 0 1 2 3
|
||||
// 4 5 6 7
|
||||
// 8 9 10 11
|
||||
// [ CPUFloatType{3,4} ]
|
||||
// output
|
||||
/*
|
||||
(1,.,.) =
|
||||
0.391 0.055 0.638 0.514 0.757 0.446
|
||||
0.817 0.075 0.437 0.452 0.077 0.492
|
||||
0.504 0.945 0.863 0.243 0.254 0.640
|
||||
0.850 0.132 0.763 0.572 0.216 0.116
|
||||
0.410 0.660 0.156 0.336 0.885 0.391
|
||||
|
||||
fmt.Printf("tensor values: %v\n", tensor.Float64Values())
|
||||
//tensor values: [0 1 2 3 4 5 6 7 8 9 10 11]
|
||||
(2,.,.) =
|
||||
0.952 0.731 0.380 0.390 0.374 0.001
|
||||
0.455 0.142 0.088 0.039 0.862 0.939
|
||||
0.621 0.198 0.728 0.914 0.168 0.057
|
||||
0.655 0.231 0.680 0.069 0.803 0.243
|
||||
0.853 0.729 0.983 0.534 0.749 0.624
|
||||
|
||||
fmt.Printf("tensor dtype: %v\n", tensor.DType())
|
||||
//tensor dtype: float32
|
||||
(3,.,.) =
|
||||
0.734 0.447 0.914 0.956 0.269 0.000
|
||||
0.427 0.034 0.477 0.535 0.440 0.972
|
||||
0.407 0.945 0.099 0.184 0.778 0.058
|
||||
0.482 0.996 0.085 0.605 0.282 0.671
|
||||
0.887 0.029 0.005 0.216 0.354 0.262
|
||||
|
||||
fmt.Printf("tensor shape: %v\n", tensor.MustSize())
|
||||
//tensor shape: [3 4]
|
||||
|
||||
fmt.Printf("tensor element number: %v\n", tensor.Numel())
|
||||
//tensor element number: 12
|
||||
|
||||
// Delete a tensor (NOTE. tensor is created in C memory and will need to free up manually.)
|
||||
tensor.MustDrop()
|
||||
TENSOR INFO:
|
||||
Shape: [3 5 6]
|
||||
DType: float32
|
||||
Device: {CPU 1}
|
||||
Defined: true
|
||||
*/
|
||||
|
||||
// Basic tensor operations
|
||||
ts1 := ts.MustArange(ts.IntScalar(6), gotch.Int64, gotch.CPU).MustView([]int64{2, 3}, true)
|
||||
|
|
|
@ -8,17 +8,18 @@ import (
|
|||
)
|
||||
|
||||
func main() {
|
||||
intTensor()
|
||||
// intTensor()
|
||||
floatTensor()
|
||||
}
|
||||
|
||||
func intTensor() {
|
||||
xs := ts.MustArange(ts.IntScalar(7*3*4*5*6), gotch.Int64, gotch.CPU).MustView([]int64{7, 3, 4, 5, 6}, true)
|
||||
fmt.Printf("%0.4d\n", xs)
|
||||
fmt.Printf("%4d\n", xs)
|
||||
}
|
||||
|
||||
func floatTensor() {
|
||||
// xs := ts.MustRand([]int64{7, 3, 4, 5, 6}, gotch.Double, gotch.CPU)
|
||||
xs := ts.MustRand([]int64{3, 5, 6}, gotch.Float, gotch.CPU)
|
||||
fmt.Printf("%.3f\n", xs)
|
||||
fmt.Printf("%8.3f\n", xs)
|
||||
fmt.Printf("%i", xs)
|
||||
}
|
||||
|
|
71
tensor/basic-example_test.go
Normal file
71
tensor/basic-example_test.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
package tensor_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func ExampleTensor_MustArange1() {
|
||||
tensor := ts.MustArange1(ts.FloatScalar(0), ts.FloatScalar(12), gotch.Int64, gotch.CPU).MustView([]int64{3, 4}, true)
|
||||
|
||||
fmt.Printf("%v", tensor)
|
||||
|
||||
// output
|
||||
// 0 1 2 3
|
||||
// 4 5 6 7
|
||||
// 8 9 10 11
|
||||
}
|
||||
|
||||
func ExampleTensor_Matmul() {
|
||||
// Basic tensor operations
|
||||
ts1 := ts.MustArange(ts.IntScalar(6), gotch.Int64, gotch.CPU).MustView([]int64{2, 3}, true)
|
||||
defer ts1.MustDrop()
|
||||
ts2 := ts.MustOnes([]int64{3, 4}, gotch.Int64, gotch.CPU)
|
||||
defer ts2.MustDrop()
|
||||
|
||||
mul := ts1.MustMatmul(ts2, false)
|
||||
defer mul.MustDrop()
|
||||
fmt.Println("ts1: ")
|
||||
ts1.Print()
|
||||
fmt.Println("ts2: ")
|
||||
ts2.Print()
|
||||
fmt.Println("mul tensor (ts1 x ts2): ")
|
||||
mul.Print()
|
||||
|
||||
//ts1:
|
||||
// 0 1 2
|
||||
// 3 4 5
|
||||
//[ CPULongType{2,3} ]
|
||||
//ts2:
|
||||
// 1 1 1 1
|
||||
// 1 1 1 1
|
||||
// 1 1 1 1
|
||||
//[ CPULongType{3,4} ]
|
||||
//mul tensor (ts1 x ts2):
|
||||
// 3 3 3 3
|
||||
// 12 12 12 12
|
||||
//[ CPULongType{2,4} ]
|
||||
|
||||
}
|
||||
|
||||
func ExampleTensor_Add1_() {
|
||||
// In-place operation
|
||||
ts3 := ts.MustOnes([]int64{2, 3}, gotch.Float, gotch.CPU)
|
||||
fmt.Println("Before:")
|
||||
ts3.Print()
|
||||
ts3.MustAdd1_(ts.FloatScalar(2.0))
|
||||
fmt.Printf("After (ts3 + 2.0): \n")
|
||||
ts3.Print()
|
||||
ts3.MustDrop()
|
||||
|
||||
//Before:
|
||||
// 1 1 1
|
||||
// 1 1 1
|
||||
//[ CPUFloatType{2,3} ]
|
||||
//After (ts3 + 2.0):
|
||||
// 3 3 3
|
||||
// 3 3 3
|
||||
//[ CPUFloatType{2,3} ]
|
||||
}
|
|
@ -152,10 +152,12 @@ func (ts *Tensor) Format(s fmt.State, c rune) {
|
|||
return
|
||||
}
|
||||
|
||||
f := newFmtState(s, c, shape)
|
||||
|
||||
data := ts.ValueGo()
|
||||
|
||||
f := newFmtState(s, c, shape)
|
||||
f.setWidth(data)
|
||||
f.makePad()
|
||||
|
||||
// 0d (scalar)
|
||||
if len(shape) == 0 {
|
||||
fmt.Printf("%v", data)
|
||||
|
@ -188,15 +190,23 @@ func (ts *Tensor) Format(s fmt.State, c rune) {
|
|||
// fmtState is a struct that implements fmt.State interface
|
||||
type fmtState struct {
|
||||
fmt.State
|
||||
c rune
|
||||
c rune // format verb
|
||||
pad []byte // padding
|
||||
w int // width
|
||||
p int // precision
|
||||
shape []int64
|
||||
buf *bytes.Buffer
|
||||
}
|
||||
|
||||
func newFmtState(s fmt.State, c rune, shape []int64) *fmtState {
|
||||
w, _ := s.Width()
|
||||
p, _ := s.Precision()
|
||||
|
||||
return &fmtState{
|
||||
State: s,
|
||||
c: c,
|
||||
w: w,
|
||||
p: p,
|
||||
shape: shape,
|
||||
buf: bytes.NewBuffer(make([]byte, 0)),
|
||||
}
|
||||
|
@ -270,9 +280,10 @@ func (f *fmtState) writeSlice(data interface{}) {
|
|||
el := reflect.ValueOf(data).Index(i).Interface()
|
||||
|
||||
// TODO: more format options here
|
||||
fmt.Fprintf(f.buf, format, el)
|
||||
w, _ := fmt.Fprintf(f.buf, format, el)
|
||||
f.Write(f.buf.Bytes())
|
||||
f.Write([]byte(" "))
|
||||
f.Write(f.pad[:f.w-w]) // prepad
|
||||
f.Write(f.pad[:2]) // pad
|
||||
f.buf.Reset()
|
||||
}
|
||||
|
||||
|
@ -297,6 +308,27 @@ func (f *fmtState) cleanFmt() string {
|
|||
return buf.String()
|
||||
}
|
||||
|
||||
func (f *fmtState) makePad() {
|
||||
f.pad = make([]byte, maxInt(f.w, 4))
|
||||
for i := range f.pad {
|
||||
f.pad[i] = ' '
|
||||
}
|
||||
}
|
||||
|
||||
// setWidth determines maximal width from input data and set to `w` field
|
||||
func (f *fmtState) setWidth(data interface{}) {
|
||||
format := f.cleanFmt()
|
||||
f.w = 0
|
||||
for i := 0; i < reflect.ValueOf(data).Len(); i++ {
|
||||
el := reflect.ValueOf(data).Index(i).Interface()
|
||||
w, _ := fmt.Fprintf(f.buf, format, el)
|
||||
if w > f.w {
|
||||
f.w = w
|
||||
}
|
||||
f.buf.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func shapeToSize(shape []int64) int {
|
||||
n := 1
|
||||
for _, v := range shape {
|
||||
|
@ -307,3 +339,10 @@ func shapeToSize(shape []int64) int {
|
|||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func maxInt(a, b int) int {
|
||||
if a >= b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user