feat(nn/init): completed
This commit is contained in:
parent
5db0200e25
commit
3b74f1fd16
41
example/tensor-in-place/main.go
Normal file
41
example/tensor-in-place/main.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
// TODO: Check Go type of data and tensor DType
|
||||
// For. if data is []int and DType is Bool
|
||||
// It is still running but get wrong result.
|
||||
data := [][]int64{
|
||||
{1, 1, 1, 2, 2, 2, 3, 3},
|
||||
{1, 1, 1, 2, 2, 2, 4, 4},
|
||||
}
|
||||
shape := []int64{2, 8}
|
||||
|
||||
ts, err := tensor.NewTensorFromData(data, shape)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ts, err = ts.To(gotch.CPU)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Tensor value BEFORE: %v\n", ts)
|
||||
ts.Print()
|
||||
|
||||
scalarVal := tensor.IntScalar(int64(5))
|
||||
|
||||
ts.Fill_(scalarVal)
|
||||
|
||||
fmt.Printf("Tensor value AFTER: %v\n", ts)
|
||||
ts.Print()
|
||||
}
|
|
@ -116,3 +116,26 @@ func AtgOnes(ptr *Ctensor, sizeData []int64, sizeLen int, optionsKind, optionsDe
|
|||
|
||||
C.atg_ones(ptr, csizeDataPtr, csizeLen, coptionsKind, coptionsDevice)
|
||||
}
|
||||
|
||||
// void atg_uniform_(tensor *, tensor self, double from, double to);
|
||||
func AtgUniform_(ptr *Ctensor, self Ctensor, from float64, to float64) {
|
||||
cfrom := *(*C.double)(unsafe.Pointer(&from))
|
||||
cto := *(*C.double)(unsafe.Pointer(&to))
|
||||
|
||||
C.atg_uniform_(ptr, self, cfrom, cto)
|
||||
}
|
||||
|
||||
// void atg_zeros_like(tensor *, tensor self);
|
||||
func AtgZerosLike(ptr *Ctensor, self Ctensor) {
|
||||
C.atg_zeros_like(ptr, self)
|
||||
}
|
||||
|
||||
// void atg_fill_(tensor *, tensor self, scalar value);
|
||||
func AtgFill_(ptr *Ctensor, self Ctensor, value Cscalar) {
|
||||
C.atg_fill_(ptr, self, value)
|
||||
}
|
||||
|
||||
// void atg_randn_like(tensor *, tensor self);
|
||||
func AtgRandnLike(ptr *Ctensor, self Ctensor) {
|
||||
C.atg_rand_like(ptr, self)
|
||||
}
|
||||
|
|
197
nn/init.go
197
nn/init.go
|
@ -1,4 +1,197 @@
|
|||
package nn
|
||||
|
||||
// TODO: implement specifically
|
||||
type Init interface{}
|
||||
import (
|
||||
"log"
|
||||
"math"
|
||||
"math/rand"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
type Init interface {
|
||||
// creates a new tensor with specified initiation
|
||||
InitTensor(dims []int, device gotch.Device) (retVal ts.Tensor)
|
||||
|
||||
// re-initializes (in-place) an existing tensor with the specified initiation
|
||||
Set(tensor ts.Tensor)
|
||||
}
|
||||
|
||||
// constInit:
|
||||
// ==========
|
||||
|
||||
type constInit struct {
|
||||
value float64
|
||||
}
|
||||
|
||||
func NewConstInit(v float64) constInit {
|
||||
return constInit{v}
|
||||
}
|
||||
|
||||
func (c constInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
||||
var err error
|
||||
kind := gotch.DType2CInt(gotch.Float)
|
||||
switch {
|
||||
case c.value == 0.0:
|
||||
retVal = ts.Zeros(dims, kind, device.CInt())
|
||||
case c.value == 1.0:
|
||||
retVal = ts.Ones(dims, kind, device.CInt())
|
||||
default:
|
||||
data := make([]float64, ts.FlattenDim(dims))
|
||||
for i := range data {
|
||||
data[i] = c.value
|
||||
}
|
||||
retVal, err = ts.NewTensorFromData(data, dims)
|
||||
if err != nil {
|
||||
log.Fatalf("constInit - InitTensor method call error: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (c constInit) Set(tensor ts.Tensor) {
|
||||
var err error
|
||||
scalarVal := ts.FloatScalar(c.value)
|
||||
if err != nil {
|
||||
log.Fatalf("constInit - Set method call error: %v\n", err)
|
||||
}
|
||||
|
||||
ts.Fill_(scalarVal)
|
||||
}
|
||||
|
||||
// randnInit :
|
||||
// ===========
|
||||
type randnInit struct {
|
||||
mean float64
|
||||
stdev float64
|
||||
}
|
||||
|
||||
func NewRandnInit(mean, stdev float64) randnInit {
|
||||
return randnInit{mean, stdev}
|
||||
}
|
||||
|
||||
func (r randnInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
||||
var err error
|
||||
rd := rand.Rand{}
|
||||
data := make([]float64, ts.FlattenDim(dims))
|
||||
for i := range data {
|
||||
data[i] = rd.NormFloat64()*r.mean + r.stdev
|
||||
}
|
||||
retVal, err = ts.NewTensorFromData(data, dims)
|
||||
if err != nil {
|
||||
log.Fatalf("randInit - InitTensor method call error: %v\n", err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
|
||||
}
|
||||
|
||||
func (r randnInit) Set(tensor ts.Tensor) {
|
||||
var (
|
||||
randnTs ts.Tensor
|
||||
err error
|
||||
)
|
||||
|
||||
dims, err := tensor.Size()
|
||||
if err != nil {
|
||||
log.Fatalf("randInit - Set method call error: %v\n", err)
|
||||
}
|
||||
|
||||
rd := rand.Rand{}
|
||||
data := make([]float64, ts.FlattenDim(dims))
|
||||
for i := range data {
|
||||
data[i] = rd.NormFloat64()*r.mean + r.stdev
|
||||
}
|
||||
randnTs, err = ts.NewTensorFromData(data, dims)
|
||||
if err != nil {
|
||||
log.Fatalf("randInit - InitTensor method call error: %v\n", err)
|
||||
}
|
||||
|
||||
tensor.Copy_(randnTs)
|
||||
}
|
||||
|
||||
// uniformInit :
|
||||
// =============
|
||||
|
||||
type uniformInit struct {
|
||||
lo float64
|
||||
up float64
|
||||
}
|
||||
|
||||
func NewUniformInit(lo, up float64) uniformInit {
|
||||
return uniformInit{lo, up}
|
||||
}
|
||||
|
||||
func (u uniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
||||
var err error
|
||||
kind := gotch.DType2CInt(gotch.Float)
|
||||
tmpTs := ts.Zeros(dims, kind, device.CInt())
|
||||
retVal, err = tmpTs.Uniform_(u.lo, u.up)
|
||||
if err != nil {
|
||||
log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
|
||||
}
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (u uniformInit) Set(tensor ts.Tensor) {
|
||||
tensor.Uniform_(u.lo, u.up)
|
||||
}
|
||||
|
||||
// kaiminguniformInit :
|
||||
// ====================
|
||||
|
||||
type kaimingUniformInit struct{}
|
||||
|
||||
func NewKaimingUniformInit() kaimingUniformInit {
|
||||
return kaimingUniformInit{}
|
||||
}
|
||||
|
||||
func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
||||
fanIn := factorial(uint64(len(dims) - 1))
|
||||
bound := math.Sqrt(1.0 / float64(fanIn))
|
||||
var err error
|
||||
kind := gotch.DType2CInt(gotch.Float)
|
||||
tmpTs := ts.Zeros(dims, kind, device.CInt())
|
||||
retVal, err = tmpTs.Uniform_(-bound, bound)
|
||||
if err != nil {
|
||||
log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
|
||||
}
|
||||
return retVal
|
||||
}
|
||||
|
||||
func factorial(n uint64) (result uint64) {
|
||||
if n > 0 {
|
||||
result = n * factorial(n-1)
|
||||
return result
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func (k kaimingUniformInit) Set(tensor ts.Tensor) {
|
||||
dims, err := tensor.Size()
|
||||
if err != nil {
|
||||
log.Fatalf("uniformInit - Set method call error: %v\n", err)
|
||||
}
|
||||
fanIn := factorial(uint64(len(dims) - 1))
|
||||
bound := math.Sqrt(1.0 / float64(fanIn))
|
||||
tensor.Uniform_(-bound, bound)
|
||||
}
|
||||
|
||||
// glorotInit :
|
||||
// ====================
|
||||
type glorotNInit struct{}
|
||||
|
||||
func NewGlorotNInit() glorotNInit {
|
||||
return glorotNInit{}
|
||||
}
|
||||
|
||||
func (gl glorotNInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
||||
// TODO: implement
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (gl glorotNInit) Set(tensor ts.Tensor) {
|
||||
// TODO: implement
|
||||
}
|
||||
|
|
|
@ -321,3 +321,54 @@ func (ts Tensor) Ones(size []int64, optionsKind, optionsDevice int32) (retVal Te
|
|||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// NOTE: `_` denotes "in-place".
|
||||
func (ts Tensor) Uniform_(from float64, to float64) {
|
||||
var err error
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgUniform_(ptr, ts.ctensor, from, to)
|
||||
if err = TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (ts Tensor) ZerosLike() (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgZerosLike(ptr, ts.ctensor)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) Fill_(value Scalar) {
|
||||
var err error
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgFill_(ptr, ts.ctensor, value.cscalar)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (ts Tensor) RandnLike() (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgRandnLike(ptr, ts.ctensor)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
|
|
@ -311,7 +311,7 @@ func (ts Tensor) Eq1(other Tensor) (retVal Tensor, err error) {
|
|||
|
||||
}
|
||||
|
||||
// DoubleValue returns a float value on tensors holding a single element.
|
||||
// Float64Value returns a float value on tensors holding a single element.
|
||||
// An error is returned otherwise.
|
||||
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
|
||||
func (ts Tensor) Float64Value(idx []int64) (retVal float64, err error) {
|
||||
|
@ -630,20 +630,22 @@ func (ts Tensor) MustGet(index int) (retVal Tensor) {
|
|||
}
|
||||
|
||||
// Copy_ copies in-place values from the argument tensor to the input tensor.
|
||||
func Copy_(self, src Tensor) (err error) {
|
||||
func Copy_(self, src Tensor) {
|
||||
var err error
|
||||
lib.AtCopy_(self.ctensor, src.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return err
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MustCopy_ copies in-place values from the argument tensor to the input tensor.
|
||||
// It will panic if error occurred.
|
||||
func MustCopy_(self, src Tensor) {
|
||||
if err := Copy_(self, src); err != nil {
|
||||
// Copy_ copies in-place values from the argument tensor to existing tensor
|
||||
func (ts Tensor) Copy_(src Tensor) {
|
||||
var err error
|
||||
lib.AtCopy_(ts.ctensor, src.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user