fix(wrapper/tensor-generated-sample): cuda matmul works only with float type
This commit is contained in:
parent
41c0cfaab2
commit
fb2ef97a60
|
@ -4,7 +4,7 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
|
||||
// gotch "github.com/sugarme/gotch"
|
||||
gotch "github.com/sugarme/gotch"
|
||||
wrapper "github.com/sugarme/gotch/wrapper"
|
||||
)
|
||||
|
||||
|
@ -55,18 +55,18 @@ func main() {
|
|||
|
||||
fmt.Printf("DType: %v\n", ts.DType())
|
||||
|
||||
dx := [][]int32{
|
||||
dx := [][]float64{
|
||||
{1, 1},
|
||||
{1, 1},
|
||||
{1, 1},
|
||||
}
|
||||
|
||||
dy := [][]int32{
|
||||
dy := [][]float64{
|
||||
{1, 2, 3},
|
||||
{1, 1, 1},
|
||||
}
|
||||
|
||||
xs, err := wrapper.NewTensorFromData(dx, []int64{2, 3})
|
||||
xs, err := wrapper.NewTensorFromData(dx, []int64{3, 2})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -77,23 +77,16 @@ func main() {
|
|||
|
||||
xs.Matmul(ys)
|
||||
|
||||
// device := gotch.NewCuda()
|
||||
//
|
||||
// // cy := ys.To(device)
|
||||
// // cx := xs.To(device)
|
||||
// zs := wrapper.NewTensor()
|
||||
// cz := zs.To(device)
|
||||
// fmt.Println(cz.Device().Name)
|
||||
device := gotch.NewCuda()
|
||||
|
||||
// cx.Matmul(cy)
|
||||
|
||||
// for i := 1; i < 1000000; i++ {
|
||||
// for i := 1; i < 2; i++ {
|
||||
// cx := xs.To(device)
|
||||
// cx.Print()
|
||||
// cy := ys.To(device)
|
||||
// cy.Print()
|
||||
//
|
||||
// }
|
||||
// NOTE: this will cause a CUDA out-of-memory error.
|
||||
// TODO: free CUDA memory somewhere in the API.
|
||||
for i := 1; i < 1000000; i++ {
|
||||
cx := xs.To(device)
|
||||
// cx.Print()
|
||||
cy := ys.To(device)
|
||||
// cy.Print()
|
||||
cx.Matmul(cy)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -29,13 +29,13 @@ func (ts Tensor) To(device gt.Device) Tensor {
|
|||
|
||||
// TODO: how to get pointer to CUDA memory???
|
||||
// Something like `C.cudaMalloc()`???
|
||||
// ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
// var cudaPtr unsafe.Pointer
|
||||
// C.cuMemAlloc((*C.ulonglong)(cudaPtr), 1)
|
||||
// fmt.Printf("Cuda Pointer: %v\n", &cudaPtr)
|
||||
// ptr := (*lib.Ctensor)(unsafe.Pointer(C.cuMemAlloc(device.CInt(), 0)))
|
||||
|
||||
var ptr unsafe.Pointer
|
||||
// var ptr unsafe.Pointer
|
||||
// lib.AtgTo(ptr, ts.ctensor, int(device.CInt()))
|
||||
lib.AtgTo((*lib.Ctensor)(ptr), ts.ctensor, int(device.CInt()))
|
||||
// lib.AtgTo((*lib.Ctensor)(cudaPtr), ts.ctensor, int(device.CInt()))
|
||||
|
@ -44,8 +44,10 @@ func (ts Tensor) To(device gt.Device) Tensor {
|
|||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return Tensor{ctensor: *(*lib.Ctensor)(unsafe.Pointer(&ptr))}
|
||||
// return Tensor{ctensor: *(*lib.Ctensor)(unsafe.Pointer(&ptr))}
|
||||
// return Tensor{ctensor: (lib.Ctensor)(cudaPtr)}
|
||||
|
||||
return Tensor{ctensor: *ptr}
|
||||
}
|
||||
|
||||
func (ts Tensor) Device() gt.Device {
|
||||
|
|
Loading…
Reference in New Issue
Block a user