chorus(cleanup): cleanup comments at wrapper/tensor and tensor-generated-sample, example
This commit is contained in:
parent
fb2ef97a60
commit
ff9ae65229
|
@ -3,6 +3,7 @@ package main
|
|||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
gotch "github.com/sugarme/gotch"
|
||||
wrapper "github.com/sugarme/gotch/wrapper"
|
||||
|
@ -34,25 +35,12 @@ func main() {
|
|||
|
||||
ts.Print()
|
||||
|
||||
// fmt.Printf("Dim: %v\n", ts.Dim())
|
||||
|
||||
// ts.Size()
|
||||
// fmt.Println(ts.Size())
|
||||
|
||||
sz, err := ts.Size2()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("Shape: %v\n", sz)
|
||||
|
||||
// typ, count, err := wrapper.DataCheck(data)
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// fmt.Printf("typ: %v\n", typ)
|
||||
// fmt.Printf("Count: %v\n", count)
|
||||
|
||||
fmt.Printf("DType: %v\n", ts.DType())
|
||||
|
||||
dx := [][]float64{
|
||||
|
@ -75,18 +63,21 @@ func main() {
|
|||
log.Fatal(err)
|
||||
}
|
||||
|
||||
xs.Matmul(ys)
|
||||
// CPU
|
||||
startCPUTime := time.Now()
|
||||
for i := 1; i < 100000; i++ {
|
||||
xs.Matmul(ys)
|
||||
}
|
||||
fmt.Printf("CPU time: %v\n", time.Since(startCPUTime))
|
||||
|
||||
// Cuda
|
||||
device := gotch.NewCuda()
|
||||
|
||||
// NOTE: this will call CUDA out of memory error.
|
||||
// TODO: free CUDA memory at API somewhere.
|
||||
for i := 1; i < 1000000; i++ {
|
||||
startGPUTime := time.Now()
|
||||
for i := 1; i < 100000; i++ {
|
||||
cx := xs.To(device)
|
||||
// cx.Print()
|
||||
cy := ys.To(device)
|
||||
// cy.Print()
|
||||
cx.Matmul(cy)
|
||||
}
|
||||
|
||||
fmt.Printf("GPU time: %v\n", time.Since(startGPUTime))
|
||||
}
|
||||
|
|
|
@ -1,19 +1,7 @@
|
|||
// NOTE: this is a sample for OCaml generated code for `tensor-generated.go`
|
||||
package wrapper
|
||||
|
||||
// #cgo CXXFLAGS: -std=c++17 -I${SRCDIR} -g -O3
|
||||
// #cgo CFLAGS: -I${SRCDIR} -O3 -Wall -Wno-unused-variable -Wno-deprecated-declarations -Wno-c++11-narrowing -g -Wno-sign-compare -Wno-unused-function
|
||||
// #cgo CFLAGS: -I/usr/local/include -I/opt/libtorch/include -I/opt/libtorch/include/torch/csrc/api/include
|
||||
// #cgo LDFLAGS: -lstdc++ -ltorch -lc10 -ltorch_cpu
|
||||
// #cgo LDFLAGS: -L/opt/libtorch/lib -L/lib64
|
||||
// #cgo CXXFLAGS: -isystem /opt/libtorch/lib
|
||||
// #cgo CXXFLAGS: -isystem /opt/libtorch/include
|
||||
// #cgo CXXFLAGS: -isystem /opt/libtorch/include/torch/csrc/api/include
|
||||
// #cgo CXXFLAGS: -isystem /opt/libtorch/include/torch/csrc
|
||||
// #cgo CFLAGS: -D_GLIBCXX_USE_CXX11_ABI=1
|
||||
// #cgo linux,amd64,!nogpu CFLAGS: -I/usr/local/cuda/include
|
||||
// #cgo linux,amd64,!nogpu LDFLAGS: -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcudnn -lcaffe2_nvrtc -lnvrtc-builtins -lnvrtc -lnvToolsExt -L/opt/libtorch/lib -lc10_cuda -ltorch_cuda
|
||||
// # include <cuda.h>
|
||||
// #include "stdlib.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
|
@ -28,25 +16,16 @@ import (
|
|||
func (ts Tensor) To(device gt.Device) Tensor {
|
||||
|
||||
// TODO: how to get pointer to CUDA memory???
|
||||
// Something like `C.cudaMalloc()`???
|
||||
// C.cuMemAlloc((*C.ulonglong)(cudaPtr), 1) // 0 byte is invalid
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
// var cudaPtr unsafe.Pointer
|
||||
// C.cuMemAlloc((*C.ulonglong)(cudaPtr), 1)
|
||||
// fmt.Printf("Cuda Pointer: %v\n", &cudaPtr)
|
||||
// ptr := (*lib.Ctensor)(unsafe.Pointer(C.cuMemAlloc(device.CInt(), 0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
// var ptr unsafe.Pointer
|
||||
// lib.AtgTo(ptr, ts.ctensor, int(device.CInt()))
|
||||
lib.AtgTo((*lib.Ctensor)(ptr), ts.ctensor, int(device.CInt()))
|
||||
// lib.AtgTo((*lib.Ctensor)(cudaPtr), ts.ctensor, int(device.CInt()))
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// return Tensor{ctensor: *(*lib.Ctensor)(unsafe.Pointer(&ptr))}
|
||||
// return Tensor{ctensor: (lib.Ctensor)(cudaPtr)}
|
||||
|
||||
return Tensor{ctensor: *ptr}
|
||||
}
|
||||
|
||||
|
|
|
@ -250,13 +250,14 @@ func (ts Tensor) Eq1(other Tensor) {
|
|||
lib.AtPrint(*ctensorPtr)
|
||||
}
|
||||
|
||||
func (ts Tensor) Matmul(other Tensor) {
|
||||
func (ts Tensor) Matmul(other Tensor) Tensor {
|
||||
ctensorPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ctensorPtr))
|
||||
lib.AtgMatmul(ctensorPtr, ts.ctensor, other.ctensor)
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
lib.AtPrint(*ctensorPtr)
|
||||
return Tensor{ctensor: *ctensorPtr}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user