feat(wrapper/scalar): added scalar.go
This commit is contained in:
parent
49e0335469
commit
dc8c5d1a03
41
example/scalar/main.go
Normal file
41
example/scalar/main.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch/wrapper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
s := wrapper.FloatScalar(float64(1.23))
|
||||
fmt.Printf("scalar value: %v\n", s)
|
||||
|
||||
intVal, err := s.ToInt()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
floatVal, err := s.ToFloat()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
strVal, err := s.ToString()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("scalar to int64 value: %v\n", intVal)
|
||||
fmt.Printf("scalar to float64 value: %v\n", floatVal)
|
||||
fmt.Printf("scalar to string value: %v\n", strVal)
|
||||
|
||||
s.Drop() // will set scalar to zero
|
||||
fmt.Printf("scalar value: %v\n", s)
|
||||
|
||||
zeroVal, err := s.ToInt()
|
||||
if err != nil {
|
||||
log.Fatalf("Panic: %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Won't expect this val: %v\n", zeroVal)
|
||||
}
|
44
libtch/scalar.go
Normal file
44
libtch/scalar.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package libtch
|
||||
|
||||
//#include "stdbool.h"
|
||||
//#include "torch_api.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// scalar ats_int(int64_t);
|
||||
func AtsInt(v int64) Cscalar {
|
||||
cv := *(*C.int64_t)(unsafe.Pointer(&v))
|
||||
return C.ats_int(cv)
|
||||
}
|
||||
|
||||
// scalar ats_float(double);
|
||||
func AtsFloat(v float64) Cscalar {
|
||||
cv := *(*C.double)(unsafe.Pointer(&v))
|
||||
return C.ats_float(cv)
|
||||
}
|
||||
|
||||
// int64_t ats_to_int(scalar);
|
||||
func AtsToInt(cscalar Cscalar) int64 {
|
||||
cint := C.ats_to_int(cscalar)
|
||||
return *(*int64)(unsafe.Pointer(&cint))
|
||||
}
|
||||
|
||||
// double ats_to_float(scalar);
|
||||
func AtsToFloat(cscalar Cscalar) float64 {
|
||||
cfloat := C.ats_to_float(cscalar)
|
||||
return *(*float64)(unsafe.Pointer(&cfloat))
|
||||
}
|
||||
|
||||
// char *ats_to_string(scalar);
|
||||
func AtsToString(cscalar Cscalar) string {
|
||||
charPtr := C.ats_to_string(cscalar)
|
||||
return C.GoString(charPtr)
|
||||
}
|
||||
|
||||
// void ats_free(scalar);
|
||||
func AtsFree(cscalar Cscalar) {
|
||||
C.ats_free(cscalar)
|
||||
}
|
65
wrapper/scalar.go
Normal file
65
wrapper/scalar.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package wrapper
|
||||
|
||||
import (
|
||||
// "unsafe"
|
||||
|
||||
lib "github.com/sugarme/gotch/libtch"
|
||||
)
|
||||
|
||||
type Scalar struct {
|
||||
cscalar lib.Cscalar
|
||||
}
|
||||
|
||||
// IntScalar creates a integer scalar
|
||||
func IntScalar(v int64) Scalar {
|
||||
cscalar := lib.AtsInt(v)
|
||||
return Scalar{cscalar}
|
||||
}
|
||||
|
||||
// FloatScalar creates a float scalar
|
||||
func FloatScalar(v float64) Scalar {
|
||||
cscalar := lib.AtsFloat(v)
|
||||
return Scalar{cscalar}
|
||||
}
|
||||
|
||||
// ToInt returns a integer value
|
||||
func (sc Scalar) ToInt() (retVal int64, err error) {
|
||||
retVal = lib.AtsToInt(sc.cscalar)
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// ToFloat returns a float value
|
||||
func (sc Scalar) ToFloat() (retVal float64, err error) {
|
||||
retVal = lib.AtsToFloat(sc.cscalar)
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// ToString returns a string representation of scalar value
|
||||
func (sc Scalar) ToString() (retVal string, err error) {
|
||||
retVal = lib.AtsToString(sc.cscalar)
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// Drop sets scalar to zero and frees up C memory
|
||||
//
|
||||
// TODO: Really? after running s.Drop() and s.ToInt()
|
||||
// it returns Zero.
|
||||
func (sc Scalar) Drop() (err error) {
|
||||
lib.AtsFree(sc.cscalar)
|
||||
return TorchErr()
|
||||
}
|
Loading…
Reference in New Issue
Block a user