fix(wrapper/tensor-generated-sample): cuda matmul works only with float type

This commit is contained in:
sugarme 2020-06-06 06:07:49 +10:00
parent 41c0cfaab2
commit fb2ef97a60
2 changed files with 19 additions and 24 deletions

View File

@ -4,7 +4,7 @@ import (
"fmt"
"log"
// gotch "github.com/sugarme/gotch"
gotch "github.com/sugarme/gotch"
wrapper "github.com/sugarme/gotch/wrapper"
)
@ -55,18 +55,18 @@ func main() {
fmt.Printf("DType: %v\n", ts.DType())
dx := [][]int32{
dx := [][]float64{
{1, 1},
{1, 1},
{1, 1},
}
dy := [][]int32{
dy := [][]float64{
{1, 2, 3},
{1, 1, 1},
}
xs, err := wrapper.NewTensorFromData(dx, []int64{2, 3})
xs, err := wrapper.NewTensorFromData(dx, []int64{3, 2})
if err != nil {
log.Fatal(err)
}
@ -77,23 +77,16 @@ func main() {
xs.Matmul(ys)
// device := gotch.NewCuda()
//
// // cy := ys.To(device)
// // cx := xs.To(device)
// zs := wrapper.NewTensor()
// cz := zs.To(device)
// fmt.Println(cz.Device().Name)
device := gotch.NewCuda()
// cx.Matmul(cy)
// for i := 1; i < 1000000; i++ {
// for i := 1; i < 2; i++ {
// cx := xs.To(device)
// cx.Print()
// cy := ys.To(device)
// cy.Print()
//
// }
// NOTE: this will cause a CUDA out-of-memory error.
// TODO: free CUDA memory somewhere in the API.
for i := 1; i < 1000000; i++ {
cx := xs.To(device)
// cx.Print()
cy := ys.To(device)
// cy.Print()
cx.Matmul(cy)
}
}

View File

@ -29,13 +29,13 @@ func (ts Tensor) To(device gt.Device) Tensor {
// TODO: how to get pointer to CUDA memory???
// Something like `C.cudaMalloc()`???
// ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
// var cudaPtr unsafe.Pointer
// C.cuMemAlloc((*C.ulonglong)(cudaPtr), 1)
// fmt.Printf("Cuda Pointer: %v\n", &cudaPtr)
// ptr := (*lib.Ctensor)(unsafe.Pointer(C.cuMemAlloc(device.CInt(), 0)))
var ptr unsafe.Pointer
// var ptr unsafe.Pointer
// lib.AtgTo(ptr, ts.ctensor, int(device.CInt()))
lib.AtgTo((*lib.Ctensor)(ptr), ts.ctensor, int(device.CInt()))
// lib.AtgTo((*lib.Ctensor)(cudaPtr), ts.ctensor, int(device.CInt()))
@ -44,8 +44,10 @@ func (ts Tensor) To(device gt.Device) Tensor {
log.Fatal(err)
}
return Tensor{ctensor: *(*lib.Ctensor)(unsafe.Pointer(&ptr))}
// return Tensor{ctensor: *(*lib.Ctensor)(unsafe.Pointer(&ptr))}
// return Tensor{ctensor: (lib.Ctensor)(cudaPtr)}
return Tensor{ctensor: *ptr}
}
func (ts Tensor) Device() gt.Device {