feat(nn/init): complete tensor initializers (const, randn, uniform, kaiming-uniform)

This commit is contained in:
sugarme 2020-06-14 12:51:38 +10:00
parent 5db0200e25
commit 3b74f1fd16
5 changed files with 320 additions and 10 deletions

View File

@ -0,0 +1,41 @@
package main
import (
"fmt"
"log"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/tensor"
)
// main builds a 2x8 int64 tensor from Go data, prints it, fills every
// element in-place with the integer scalar 5, and prints it again.
func main() {
	// TODO: validate that the Go element type matches the tensor DType;
	// e.g. []int data with a Bool DType currently runs but yields wrong
	// results.
	values := [][]int64{
		{1, 1, 1, 2, 2, 2, 3, 3},
		{1, 1, 1, 2, 2, 2, 4, 4},
	}
	dims := []int64{2, 8}

	t, err := tensor.NewTensorFromData(values, dims)
	if err != nil {
		log.Fatal(err)
	}
	if t, err = t.To(gotch.CPU); err != nil {
		log.Fatal(err)
	}

	fmt.Printf("Tensor value BEFORE: %v\n", t)
	t.Print()

	// Overwrite all elements in-place with the scalar 5.
	five := tensor.IntScalar(int64(5))
	t.Fill_(five)

	fmt.Printf("Tensor value AFTER: %v\n", t)
	t.Print()
}

View File

@ -116,3 +116,26 @@ func AtgOnes(ptr *Ctensor, sizeData []int64, sizeLen int, optionsKind, optionsDe
C.atg_ones(ptr, csizeDataPtr, csizeLen, coptionsKind, coptionsDevice)
}
// AtgUniform_ wraps the C function `atg_uniform_`, which re-initializes
// `self` in-place with values drawn uniformly between `from` and `to`
// (presumably [from, to), per libtorch's `uniform_` — confirm), writing
// the resulting tensor handle through `ptr`.
//
// void atg_uniform_(tensor *, tensor self, double from, double to);
func AtgUniform_(ptr *Ctensor, self Ctensor, from float64, to float64) {
// Reinterpret the Go float64 bits as C.double (same IEEE-754 layout).
cfrom := *(*C.double)(unsafe.Pointer(&from))
cto := *(*C.double)(unsafe.Pointer(&to))
C.atg_uniform_(ptr, self, cfrom, cto)
}
// AtgZerosLike wraps the C function `atg_zeros_like`, writing through
// `ptr` a new zero-filled tensor shaped like `self`.
//
// void atg_zeros_like(tensor *, tensor self);
func AtgZerosLike(ptr *Ctensor, self Ctensor) {
C.atg_zeros_like(ptr, self)
}
// AtgFill_ wraps the C function `atg_fill_`, which fills `self` in-place
// with the scalar `value`, writing the resulting handle through `ptr`.
//
// void atg_fill_(tensor *, tensor self, scalar value);
func AtgFill_(ptr *Ctensor, self Ctensor, value Cscalar) {
C.atg_fill_(ptr, self, value)
}
// AtgRandnLike wraps the C function `atg_randn_like`, writing through
// `ptr` a new tensor shaped like `self` and filled with samples from the
// standard normal distribution.
//
// void atg_randn_like(tensor *, tensor self);
func AtgRandnLike(ptr *Ctensor, self Ctensor) {
// Bug fix: previously called C.atg_rand_like (uniform distribution)
// instead of C.atg_randn_like (normal), contradicting this wrapper's
// name and the signature comment above.
C.atg_randn_like(ptr, self)
}

View File

@ -1,4 +1,197 @@
package nn
// TODO: implement specifically
type Init interface{}
import (
"log"
"math"
"math/rand"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
)
// Init specifies a tensor initialization scheme.
type Init interface {
	// InitTensor creates a new tensor with shape `dims` on `device`,
	// initialized according to the scheme.
	// Bug fix: was `dims []int`, which no implementation in this file
	// satisfied — every implementation takes `dims []int64`.
	InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor)

	// Set re-initializes (in-place) an existing tensor with the scheme.
	Set(tensor ts.Tensor)
}
// constInit:
// ==========
// constInit fills tensors with a single constant value.
type constInit struct {
	value float64
}

// NewConstInit builds a constInit that writes `v` into every element.
func NewConstInit(v float64) constInit {
	return constInit{value: v}
}
// InitTensor creates a tensor of shape `dims` on `device` with every
// element set to c.value. Zero and one get the fast Zeros/Ones paths.
func (c constInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
	var err error
	kind := gotch.DType2CInt(gotch.Float)
	switch {
	case c.value == 0.0:
		retVal = ts.Zeros(dims, kind, device.CInt())
	case c.value == 1.0:
		retVal = ts.Ones(dims, kind, device.CInt())
	default:
		// NOTE(review): this branch builds from []float64 (Double) while the
		// fast paths use Float — TODO confirm the intended dtype.
		data := make([]float64, ts.FlattenDim(dims))
		for i := range data {
			data[i] = c.value
		}
		retVal, err = ts.NewTensorFromData(data, dims)
		if err != nil {
			log.Fatalf("constInit - InitTensor method call error: %v\n", err)
		}
		// Bug fix: the tensor built from host data was left on the default
		// device, ignoring the `device` argument; move it like the other
		// branches place theirs.
		retVal, err = retVal.To(device)
		if err != nil {
			log.Fatalf("constInit - InitTensor method call error: %v\n", err)
		}
	}

	return retVal
}
// Set fills `tensor` in-place with c.value.
func (c constInit) Set(tensor ts.Tensor) {
	scalarVal := ts.FloatScalar(c.value)
	// Bug fix: previously called ts.Fill_ at package scope (Fill_ is a
	// Tensor method, so the argument tensor was never touched), and checked
	// an `err` that was never assigned. Fill the argument directly.
	tensor.Fill_(scalarVal)
}
// randnInit :
// ===========
// randnInit fills tensors with samples from a normal distribution.
type randnInit struct {
	mean  float64
	stdev float64
}

// NewRandnInit builds a randnInit with the given mean and standard
// deviation.
func NewRandnInit(mean, stdev float64) randnInit {
	return randnInit{mean: mean, stdev: stdev}
}
// InitTensor creates a tensor of shape `dims` whose elements are sampled
// from N(r.mean, r.stdev^2).
//
// NOTE(review): `device` is currently unused — the tensor is built from
// host data; TODO confirm whether it should be moved to `device`.
func (r randnInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
	var err error
	data := make([]float64, ts.FlattenDim(dims))
	for i := range data {
		// Bug fixes:
		// 1. rand.Rand{} has a nil Source and panics on NormFloat64; sample
		//    via the package-level generator instead.
		// 2. mean and stdev were swapped — scale by stdev, then shift by
		//    mean (sample = NormFloat64()*stdev + mean).
		data[i] = rand.NormFloat64()*r.stdev + r.mean
	}

	retVal, err = ts.NewTensorFromData(data, dims)
	if err != nil {
		log.Fatalf("randInit - InitTensor method call error: %v\n", err)
	}

	return retVal
}
// Set re-initializes `tensor` in-place with samples from
// N(r.mean, r.stdev^2), matching the tensor's existing shape.
func (r randnInit) Set(tensor ts.Tensor) {
	dims, err := tensor.Size()
	if err != nil {
		log.Fatalf("randInit - Set method call error: %v\n", err)
	}

	data := make([]float64, ts.FlattenDim(dims))
	for i := range data {
		// Bug fixes: use the package-level RNG (rand.Rand{} panics on its
		// nil Source) and scale by stdev / shift by mean (previously the two
		// were swapped).
		data[i] = rand.NormFloat64()*r.stdev + r.mean
	}

	randnTs, err := ts.NewTensorFromData(data, dims)
	if err != nil {
		log.Fatalf("randInit - InitTensor method call error: %v\n", err)
	}
	tensor.Copy_(randnTs)
}
// uniformInit :
// =============
// uniformInit fills tensors with values drawn uniformly between lo and up.
type uniformInit struct {
	lo float64
	up float64
}

// NewUniformInit builds a uniformInit over the interval from lo to up.
func NewUniformInit(lo, up float64) uniformInit {
	return uniformInit{lo: lo, up: up}
}
// InitTensor creates a tensor of shape `dims` on `device`, filled
// in-place with values drawn uniformly between u.lo and u.up.
func (u uniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
	kind := gotch.DType2CInt(gotch.Float)
	retVal = ts.Zeros(dims, kind, device.CInt())
	// Bug fix: Uniform_ mutates the receiver in-place and returns nothing,
	// so the previous two-value assignment `retVal, err = tmpTs.Uniform_(...)`
	// could not compile.
	retVal.Uniform_(u.lo, u.up)

	return retVal
}
// Set re-initializes `tensor` in-place with values drawn uniformly
// between u.lo and u.up.
func (u uniformInit) Set(tensor ts.Tensor) {
tensor.Uniform_(u.lo, u.up)
}
// kaiminguniformInit :
// ====================
// kaimingUniformInit implements Kaiming (He) uniform initialization.
// It carries no state; the sampling bound is derived from the tensor's
// dimensions at call time.
type kaimingUniformInit struct{}

// NewKaimingUniformInit creates a new kaimingUniformInit.
func NewKaimingUniformInit() kaimingUniformInit {
return kaimingUniformInit{}
}
// InitTensor creates a tensor of shape `dims` on `device` using Kaiming
// uniform initialization: values drawn from U(-bound, bound) with
// bound = sqrt(1/fanIn).
func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
	// Bug fix: fan-in is the number of input connections — the product of
	// the non-leading dimensions — not factorial(len(dims)-1) as before.
	fanIn := int64(1)
	if len(dims) > 1 {
		for _, d := range dims[1:] {
			fanIn *= d
		}
	}
	bound := math.Sqrt(1.0 / float64(fanIn))

	kind := gotch.DType2CInt(gotch.Float)
	retVal = ts.Zeros(dims, kind, device.CInt())
	// Bug fix: Uniform_ mutates in-place and returns nothing; the previous
	// two-value assignment could not compile.
	retVal.Uniform_(-bound, bound)

	return retVal
}
// factorial returns n! iteratively, with factorial(0) == 1.
func factorial(n uint64) (result uint64) {
	result = 1
	for i := uint64(2); i <= n; i++ {
		result *= i
	}
	return result
}
// Set re-initializes `tensor` in-place with Kaiming uniform values:
// U(-bound, bound) where bound = sqrt(1/fanIn).
func (k kaimingUniformInit) Set(tensor ts.Tensor) {
	dims, err := tensor.Size()
	if err != nil {
		log.Fatalf("uniformInit - Set method call error: %v\n", err)
	}

	// Bug fix: fan-in is the product of the non-leading dimensions, not
	// factorial(len(dims)-1) as before.
	fanIn := int64(1)
	if len(dims) > 1 {
		for _, d := range dims[1:] {
			fanIn *= d
		}
	}
	bound := math.Sqrt(1.0 / float64(fanIn))
	tensor.Uniform_(-bound, bound)
}
// glorotInit :
// ====================

// glorotNInit is a placeholder for Glorot (Xavier) normal initialization.
// NOTE: both methods below are unimplemented stubs.
type glorotNInit struct{}

// NewGlorotNInit creates a new glorotNInit.
func NewGlorotNInit() glorotNInit {
return glorotNInit{}
}

// InitTensor is an unimplemented stub; it currently returns the zero
// Tensor value regardless of dims and device.
func (gl glorotNInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
// TODO: implement
return
}

// Set is an unimplemented stub; it currently leaves `tensor` untouched.
func (gl glorotNInit) Set(tensor ts.Tensor) {
// TODO: implement
}

View File

@ -321,3 +321,54 @@ func (ts Tensor) Ones(size []int64, optionsKind, optionsDevice int32) (retVal Te
return retVal, nil
}
// NOTE: `_` denotes "in-place".
func (ts Tensor) Uniform_(from float64, to float64) {
var err error
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgUniform_(ptr, ts.ctensor, from, to)
if err = TorchErr(); err != nil {
log.Fatal(err)
}
}
// ZerosLike returns a new zero-filled tensor with the same shape as `ts`,
// or an error reported by the C layer.
func (ts Tensor) ZerosLike() (retVal Tensor, err error) {
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	defer C.free(unsafe.Pointer(ptr))

	lib.AtgZerosLike(ptr, ts.ctensor)
	if err = TorchErr(); err != nil {
		return Tensor{}, err
	}

	return Tensor{ctensor: *ptr}, nil
}
// Fill_ sets every element of the tensor in-place to the scalar `value`.
// The program terminates if the C layer reports an error.
func (ts Tensor) Fill_(value Scalar) {
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	defer C.free(unsafe.Pointer(ptr))

	lib.AtgFill_(ptr, ts.ctensor, value.cscalar)
	if err := TorchErr(); err != nil {
		log.Fatal(err)
	}
}
// RandnLike returns a new tensor with the same shape as `ts` filled via
// lib.AtgRandnLike, or an error reported by the C layer.
func (ts Tensor) RandnLike() (retVal Tensor, err error) {
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	defer C.free(unsafe.Pointer(ptr))

	lib.AtgRandnLike(ptr, ts.ctensor)
	if err = TorchErr(); err != nil {
		return Tensor{}, err
	}

	return Tensor{ctensor: *ptr}, nil
}

View File

@ -311,7 +311,7 @@ func (ts Tensor) Eq1(other Tensor) (retVal Tensor, err error) {
}
// DoubleValue returns a float value on tensors holding a single element.
// Float64Value returns a float value on tensors holding a single element.
// An error is returned otherwise.
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
func (ts Tensor) Float64Value(idx []int64) (retVal float64, err error) {
@ -630,20 +630,22 @@ func (ts Tensor) MustGet(index int) (retVal Tensor) {
}
// Copy_ copies in-place values from the argument tensor to the input tensor.
func Copy_(self, src Tensor) (err error) {
func Copy_(self, src Tensor) {
var err error
lib.AtCopy_(self.ctensor, src.ctensor)
if err = TorchErr(); err != nil {
return err
log.Fatal(err)
}
return nil
}
// MustCopy_ copies in-place values from the argument tensor to the input tensor.
// It will panic if error occurred.
func MustCopy_(self, src Tensor) {
if err := Copy_(self, src); err != nil {
// Copy_ copies in-place values from the argument tensor to existing tensor
// Copy_ copies values in-place from `src` into the receiver tensor.
// The program terminates if the C layer reports an error.
func (ts Tensor) Copy_(src Tensor) {
	lib.AtCopy_(ts.ctensor, src.ctensor)
	if err := TorchErr(); err != nil {
		log.Fatal(err)
	}
}