2022-03-12 07:20:20 +00:00
|
|
|
package ts_test
|
2020-11-09 06:16:34 +00:00
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
|
|
|
|
2024-04-21 15:15:00 +01:00
|
|
|
"git.andr3h3nriqu3s.com/andr3/gotch"
|
|
|
|
"git.andr3h3nriqu3s.com/andr3/gotch/ts"
|
2020-11-09 06:16:34 +00:00
|
|
|
)
|
|
|
|
|
2021-07-22 15:54:41 +01:00
|
|
|
func ExampleTensor_MustArange() {
|
|
|
|
tensor := ts.MustArange(ts.FloatScalar(12), gotch.Int64, gotch.CPU).MustView([]int64{3, 4}, true)
|
2020-11-09 06:16:34 +00:00
|
|
|
|
|
|
|
fmt.Printf("%v", tensor)
|
|
|
|
|
|
|
|
// output
|
|
|
|
// 0 1 2 3
|
|
|
|
// 4 5 6 7
|
|
|
|
// 8 9 10 11
|
|
|
|
}
|
|
|
|
|
|
|
|
func ExampleTensor_Matmul() {
|
|
|
|
// Basic tensor operations
|
|
|
|
ts1 := ts.MustArange(ts.IntScalar(6), gotch.Int64, gotch.CPU).MustView([]int64{2, 3}, true)
|
|
|
|
defer ts1.MustDrop()
|
|
|
|
ts2 := ts.MustOnes([]int64{3, 4}, gotch.Int64, gotch.CPU)
|
|
|
|
defer ts2.MustDrop()
|
|
|
|
|
|
|
|
mul := ts1.MustMatmul(ts2, false)
|
|
|
|
defer mul.MustDrop()
|
|
|
|
fmt.Println("ts1: ")
|
|
|
|
ts1.Print()
|
|
|
|
fmt.Println("ts2: ")
|
|
|
|
ts2.Print()
|
|
|
|
fmt.Println("mul tensor (ts1 x ts2): ")
|
|
|
|
mul.Print()
|
|
|
|
|
|
|
|
//ts1:
|
|
|
|
// 0 1 2
|
|
|
|
// 3 4 5
|
|
|
|
//[ CPULongType{2,3} ]
|
|
|
|
//ts2:
|
|
|
|
// 1 1 1 1
|
|
|
|
// 1 1 1 1
|
|
|
|
// 1 1 1 1
|
|
|
|
//[ CPULongType{3,4} ]
|
|
|
|
//mul tensor (ts1 x ts2):
|
|
|
|
// 3 3 3 3
|
|
|
|
// 12 12 12 12
|
|
|
|
//[ CPULongType{2,4} ]
|
|
|
|
|
|
|
|
}
|
|
|
|
|
2021-07-22 15:54:41 +01:00
|
|
|
func ExampleTensor_AddScalar_() {
|
2020-11-09 06:16:34 +00:00
|
|
|
// In-place operation
|
|
|
|
ts3 := ts.MustOnes([]int64{2, 3}, gotch.Float, gotch.CPU)
|
|
|
|
fmt.Println("Before:")
|
|
|
|
ts3.Print()
|
2021-07-22 15:54:41 +01:00
|
|
|
ts3.MustAddScalar_(ts.FloatScalar(2.0))
|
2020-11-09 06:16:34 +00:00
|
|
|
fmt.Printf("After (ts3 + 2.0): \n")
|
|
|
|
ts3.Print()
|
|
|
|
ts3.MustDrop()
|
|
|
|
|
|
|
|
//Before:
|
|
|
|
// 1 1 1
|
|
|
|
// 1 1 1
|
|
|
|
//[ CPUFloatType{2,3} ]
|
|
|
|
//After (ts3 + 2.0):
|
|
|
|
// 3 3 3
|
|
|
|
// 3 3 3
|
|
|
|
//[ CPUFloatType{2,3} ]
|
|
|
|
}
|