feat(wrapper/scalar): added scalar.go

This commit is contained in:
sugarme 2020-06-13 11:11:12 +10:00
parent 49e0335469
commit dc8c5d1a03
3 changed files with 150 additions and 0 deletions

41
example/scalar/main.go Normal file
View File

@ -0,0 +1,41 @@
package main
import (
"fmt"
"log"
"github.com/sugarme/gotch/wrapper"
)
func main() {
s := wrapper.FloatScalar(float64(1.23))
fmt.Printf("scalar value: %v\n", s)
intVal, err := s.ToInt()
if err != nil {
panic(err)
}
floatVal, err := s.ToFloat()
if err != nil {
panic(err)
}
strVal, err := s.ToString()
if err != nil {
panic(err)
}
fmt.Printf("scalar to int64 value: %v\n", intVal)
fmt.Printf("scalar to float64 value: %v\n", floatVal)
fmt.Printf("scalar to string value: %v\n", strVal)
s.Drop() // will set scalar to zero
fmt.Printf("scalar value: %v\n", s)
zeroVal, err := s.ToInt()
if err != nil {
log.Fatalf("Panic: %v\n", err)
}
fmt.Printf("Won't expect this val: %v\n", zeroVal)
}

44
libtch/scalar.go Normal file
View File

@ -0,0 +1,44 @@
package libtch
//#include "stdbool.h"
//#include "torch_api.h"
import "C"
import (
"unsafe"
)
// scalar ats_int(int64_t);
func AtsInt(v int64) Cscalar {
cv := *(*C.int64_t)(unsafe.Pointer(&v))
return C.ats_int(cv)
}
// scalar ats_float(double);
func AtsFloat(v float64) Cscalar {
cv := *(*C.double)(unsafe.Pointer(&v))
return C.ats_float(cv)
}
// int64_t ats_to_int(scalar);
func AtsToInt(cscalar Cscalar) int64 {
cint := C.ats_to_int(cscalar)
return *(*int64)(unsafe.Pointer(&cint))
}
// double ats_to_float(scalar);
func AtsToFloat(cscalar Cscalar) float64 {
cfloat := C.ats_to_float(cscalar)
return *(*float64)(unsafe.Pointer(&cfloat))
}
// char *ats_to_string(scalar);
func AtsToString(cscalar Cscalar) string {
charPtr := C.ats_to_string(cscalar)
return C.GoString(charPtr)
}
// void ats_free(scalar);
func AtsFree(cscalar Cscalar) {
C.ats_free(cscalar)
}

65
wrapper/scalar.go Normal file
View File

@ -0,0 +1,65 @@
package wrapper
import (
// "unsafe"
lib "github.com/sugarme/gotch/libtch"
)
type Scalar struct {
cscalar lib.Cscalar
}
// IntScalar creates a integer scalar
func IntScalar(v int64) Scalar {
cscalar := lib.AtsInt(v)
return Scalar{cscalar}
}
// FloatScalar creates a float scalar
func FloatScalar(v float64) Scalar {
cscalar := lib.AtsFloat(v)
return Scalar{cscalar}
}
// ToInt returns a integer value
func (sc Scalar) ToInt() (retVal int64, err error) {
retVal = lib.AtsToInt(sc.cscalar)
err = TorchErr()
if err != nil {
return retVal, err
}
return retVal, nil
}
// ToFloat returns a float value
func (sc Scalar) ToFloat() (retVal float64, err error) {
retVal = lib.AtsToFloat(sc.cscalar)
err = TorchErr()
if err != nil {
return retVal, err
}
return retVal, nil
}
// ToString returns a string representation of scalar value
func (sc Scalar) ToString() (retVal string, err error) {
retVal = lib.AtsToString(sc.cscalar)
err = TorchErr()
if err != nil {
return retVal, err
}
return retVal, nil
}
// Drop sets scalar to zero and frees up C memory
//
// TODO: Really? after running s.Drop() and s.ToInt()
// it returns Zero.
func (sc Scalar) Drop() (err error) {
lib.AtsFree(sc.cscalar)
return TorchErr()
}