feat(nn/sparse): add embedding layers

This commit is contained in:
sugarme 2020-06-25 16:30:00 +10:00
parent 67a01f1294
commit 3669367b71
5 changed files with 198 additions and 4 deletions

View File

@ -475,3 +475,25 @@ func AtgRandn(ptr *Ctensor, sizeData []int64, sizeLen int, optionsKind int32, op
C.atg_randn(ptr, csizeDataPtr, csizeLen, coptionKind, coptionDevice)
}
// void atg_embedding(tensor *, tensor weight, tensor indices, int64_t padding_idx, int scale_grad_by_freq, int sparse);

// AtgEmbedding calls the libtorch C binding atg_embedding, writing the
// resulting tensor handle through ptr.
func AtgEmbedding(ptr *Ctensor, weight Ctensor, indices Ctensor, paddingIdx int64, scaleGradByFreq int, sparse int) {
	// Use direct numeric conversions instead of reinterpreting the Go values
	// through unsafe.Pointer: the previous `*(*C.int)(unsafe.Pointer(&x))`
	// casts read only part of a 64-bit Go int on big-endian platforms.
	cpaddingIdx := C.int64_t(paddingIdx)
	cscaleGradByFreq := C.int(scaleGradByFreq)
	csparse := C.int(sparse)
	C.atg_embedding(ptr, weight, indices, cpaddingIdx, cscaleGradByFreq, csparse)
}
// void atg_randint(tensor *, int64_t high, int64_t *size_data, int size_len, int options_kind, int options_device);

// AtgRandint calls the libtorch C binding atg_randint, writing the
// resulting tensor handle through ptr.
func AtgRandint(ptr *Ctensor, high int64, sizeData []int64, sizeLen int, optionsKind int32, optionsDevice int32) {
	// Direct numeric conversions: the previous unsafe.Pointer reinterpret
	// casts read only part of a 64-bit Go int on big-endian platforms.
	chigh := C.int64_t(high)
	// Guard the empty-shape case: &sizeData[0] panics on a nil/empty slice.
	var csizeDataPtr *C.int64_t
	if len(sizeData) > 0 {
		csizeDataPtr = (*C.int64_t)(unsafe.Pointer(&sizeData[0]))
	}
	csizeLen := C.int(sizeLen)
	coptionKind := C.int(optionsKind)
	coptionDevice := C.int(optionsDevice)
	C.atg_randint(ptr, chigh, csizeDataPtr, csizeLen, coptionKind, coptionDevice)
}

View File

@ -73,11 +73,13 @@ func NewRandnInit(mean, stdev float64) randnInit {
func (r randnInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
var err error
rd := rand.Rand{}
rand.Seed(86)
data := make([]float64, ts.FlattenDim(dims))
for i := range data {
data[i] = rd.NormFloat64()*r.mean + r.stdev
data[i] = rand.NormFloat64()*r.mean + r.stdev
}
retVal, err = ts.NewTensorFromData(data, dims)
if err != nil {
log.Fatalf("randInit - InitTensor method call error: %v\n", err)
@ -98,10 +100,10 @@ func (r randnInit) Set(tensor ts.Tensor) {
log.Fatalf("randInit - Set method call error: %v\n", err)
}
rd := rand.Rand{}
rand.Seed(86)
data := make([]float64, ts.FlattenDim(dims))
for i := range data {
data[i] = rd.NormFloat64()*r.mean + r.stdev
data[i] = rand.NormFloat64()*r.mean + r.stdev
}
randnTs, err = ts.NewTensorFromData(data, dims)
if err != nil {

47
nn/sparse.go Normal file
View File

@ -0,0 +1,47 @@
package nn
// Sparse layers
import (
ts "github.com/sugarme/gotch/tensor"
)
// EmbeddingConfig is the configuration option for an embedding layer.
// All options are forwarded to the libtorch embedding op on Forward.
type EmbeddingConfig struct {
	// Sparse requests sparse gradients with respect to the weight matrix.
	Sparse bool
	// ScaleGradByFreq asks libtorch to scale gradients by the inverse
	// frequency of the looked-up indices.
	ScaleGradByFreq bool
	// WsInit initializes the embedding weight matrix.
	WsInit Init
	// PaddingIdx designates a padding entry; a negative value (the
	// default, -1) is interpreted relative to the number of embeddings
	// by the test helpers — presumably it disables padding; confirm
	// against libtorch semantics.
	PaddingIdx int64
}
// DefaultEmbeddingConfig returns an EmbeddingConfig with sensible defaults:
// dense (non-sparse) gradients, no frequency scaling, standard-normal
// weight initialization, and padding index -1.
func DefaultEmbeddingConfig() EmbeddingConfig {
	var cfg EmbeddingConfig
	cfg.Sparse = false
	cfg.ScaleGradByFreq = false
	cfg.WsInit = NewRandnInit(0.0, 1.0)
	cfg.PaddingIdx = -1
	return cfg
}
// An embedding layer.
//
// An embedding layer acts as a simple lookup table that stores embeddings.
// This is commonly used to store word embeddings.
type Embedding struct {
	// Ws is the trainable weight matrix, created by NewEmbedding with
	// shape (numEmbeddings, embeddingDim).
	Ws ts.Tensor
	// config carries the options forwarded to the embedding op on Forward.
	config EmbeddingConfig
}
// NewEmbedding creates a new Embedding whose weight matrix of shape
// (numEmbeddings, embeddingDim) is registered under the path vs as
// "weight" and initialized with config.WsInit.
func NewEmbedding(vs Path, numEmbeddings int64, embeddingDim int64, config EmbeddingConfig) Embedding {
	dims := []int64{numEmbeddings, embeddingDim}
	ws := vs.NewVar("weight", dims, config.WsInit)
	return Embedding{Ws: ws, config: config}
}
// Implement Module interface for Embedding:
// =========================================

// Forward looks up each index in xs in the weight matrix Ws, applying the
// padding/scaling/sparsity options carried by the layer's config.
func (e Embedding) Forward(xs ts.Tensor) (retVal ts.Tensor) {
	retVal = ts.MustEmbedding(e.Ws, xs, e.config.PaddingIdx, e.config.ScaleGradByFreq, e.config.Sparse)
	return retVal
}

64
nn/sparse_test.go Normal file
View File

@ -0,0 +1,64 @@
package nn_test
import (
"reflect"
"testing"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
ts "github.com/sugarme/gotch/tensor"
)
// embeddingTest builds an embedding layer from embeddingConfig and checks
// two output shapes: a forward pass over a (batch, seq) index tensor, and
// a lookup of the (resolved) padding index alone.
func embeddingTest(embeddingConfig nn.EmbeddingConfig, t *testing.T) {
	const (
		batchDim  int64 = 5
		seqLen    int64 = 7
		inputDim  int64 = 10
		outputDim int64 = 4
	)

	vs := nn.NewVarStore(gotch.CPU)
	embeddings := nn.NewEmbedding(vs.Root(), inputDim, outputDim, embeddingConfig)

	// Forward pass: (batch, seq) indices should yield (batch, seq, dim).
	input := ts.MustRandint(10, []int64{batchDim, seqLen}, gotch.Int64, gotch.CPU)
	got := embeddings.Forward(input).MustSize()
	want := []int64{batchDim, seqLen, outputDim}
	if !reflect.DeepEqual(got, want) {
		t.Errorf("Forward - Expected output shape: %v\n", want)
		t.Errorf("Forward - Got output shape: %v\n", got)
	}

	// Padding lookup: resolve a negative padding index relative to the
	// vocabulary size, then look up that single index.
	paddingIdx := embeddingConfig.PaddingIdx
	if paddingIdx < 0 {
		paddingIdx += inputDim
	}
	got = embeddings.Forward(ts.MustOfSlice([]int64{paddingIdx})).MustSize()
	want = []int64{1, outputDim}
	if !reflect.DeepEqual(got, want) {
		t.Errorf("Padding - Expected output shape: %v\n", want)
		t.Errorf("Padding - Got output shape: %v\n", got)
	}
}
// TestEmbedding exercises the embedding layer with the default config and
// with explicit padding indices -1 and 0 (three runs in total, matching
// the default-config run plus each explicit index).
func TestEmbedding(t *testing.T) {
	cfg := nn.DefaultEmbeddingConfig()
	embeddingTest(cfg, t)
	for _, idx := range []int64{-1, 0} {
		cfg.PaddingIdx = idx
		embeddingTest(cfg, t)
	}
}

View File

@ -1390,3 +1390,62 @@ func MustRandn(sizeData []int64, optionsKind gotch.DType, optionsDevice gotch.De
return retVal
}
// Embedding performs an embedding lookup: each value in indices selects a
// row of weight. paddingIdx, scaleGradByFreq, and sparse are forwarded to
// the libtorch embedding op (see the PyTorch embedding documentation for
// their exact semantics). It returns the resulting tensor or the error
// reported by libtorch.
func Embedding(weight, indices Tensor, paddingIdx int64, scaleGradByFreq, sparse bool) (retVal Tensor, err error) {
	// Scratch slot the C side writes the new tensor handle into.
	// NOTE(review): this previously used C.malloc(0) — writing a pointer
	// into a zero-byte allocation is undefined behavior, and the slot was
	// never freed. Allocate pointer size and release it once the handle
	// has been copied out.
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(C.size_t(unsafe.Sizeof(uintptr(0))))))
	defer C.free(unsafe.Pointer(ptr))

	// The C ABI takes ints, not bools.
	cscaleGradByFreq := 0
	if scaleGradByFreq {
		cscaleGradByFreq = 1
	}
	csparse := 0
	if sparse {
		csparse = 1
	}

	lib.AtgEmbedding(ptr, weight.ctensor, indices.ctensor, paddingIdx, cscaleGradByFreq, csparse)
	if err = TorchErr(); err != nil {
		return retVal, err
	}

	retVal = Tensor{ctensor: *ptr}
	return retVal, nil
}
// MustEmbedding is like Embedding but terminates the program (log.Fatal)
// if the lookup fails.
func MustEmbedding(weight, indices Tensor, paddingIdx int64, scaleGradByFreq, sparse bool) (retVal Tensor) {
	result, err := Embedding(weight, indices, paddingIdx, scaleGradByFreq, sparse)
	if err != nil {
		log.Fatal(err)
	}
	return result
}
// Randint creates a tensor of shape sizeData filled with uniform random
// integers in [0, high), with the requested dtype and device. It returns
// the new tensor or the error reported by libtorch.
func Randint(high int64, sizeData []int64, optionsKind gotch.DType, optionsDevice gotch.Device) (retVal Tensor, err error) {
	// Scratch slot the C side writes the new tensor handle into.
	// NOTE(review): this previously used C.malloc(0) — writing a pointer
	// into a zero-byte allocation is undefined behavior, and the slot was
	// never freed. Allocate pointer size and release it once the handle
	// has been copied out.
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(C.size_t(unsafe.Sizeof(uintptr(0))))))
	defer C.free(unsafe.Pointer(ptr))

	lib.AtgRandint(ptr, high, sizeData, len(sizeData), optionsKind.CInt(), optionsDevice.CInt())
	if err = TorchErr(); err != nil {
		return retVal, err
	}

	retVal = Tensor{ctensor: *ptr}
	return retVal, nil
}
// MustRandint is like Randint but terminates the program (log.Fatal) if
// tensor creation fails.
func MustRandint(high int64, sizeData []int64, optionsKind gotch.DType, optionsDevice gotch.Device) (retVal Tensor) {
	result, err := Randint(high, sizeData, optionsKind, optionsDevice)
	if err != nil {
		log.Fatal(err)
	}
	return result
}