feat(nn/sparse): add embedding layers
This commit is contained in:
parent
67a01f1294
commit
3669367b71
|
@ -475,3 +475,25 @@ func AtgRandn(ptr *Ctensor, sizeData []int64, sizeLen int, optionsKind int32, op
|
|||
|
||||
C.atg_randn(ptr, csizeDataPtr, csizeLen, coptionKind, coptionDevice)
|
||||
}
|
||||
|
||||
// void atg_embedding(tensor *, tensor weight, tensor indices, int64_t padding_idx, int scale_grad_by_freq, int sparse);
|
||||
func AtgEmbedding(ptr *Ctensor, weight Ctensor, indices Ctensor, paddingIdx int64, scaleGradByFreq int, sparse int) {
|
||||
|
||||
cpaddingIdx := *(*C.int64_t)(unsafe.Pointer(&paddingIdx))
|
||||
cscaleGradByFreq := *(*C.int)(unsafe.Pointer(&scaleGradByFreq))
|
||||
csparse := *(*C.int)(unsafe.Pointer(&sparse))
|
||||
|
||||
C.atg_embedding(ptr, weight, indices, cpaddingIdx, cscaleGradByFreq, csparse)
|
||||
}
|
||||
|
||||
// void atg_randint(tensor *, int64_t high, int64_t *size_data, int size_len, int options_kind, int options_device);
|
||||
func AtgRandint(ptr *Ctensor, high int64, sizeData []int64, sizeLen int, optionsKind int32, optionsDevice int32) {
|
||||
|
||||
chigh := *(*C.int64_t)(unsafe.Pointer(&high))
|
||||
csizeDataPtr := (*C.int64_t)(unsafe.Pointer(&sizeData[0]))
|
||||
csizeLen := *(*C.int)(unsafe.Pointer(&sizeLen))
|
||||
coptionKind := *(*C.int)(unsafe.Pointer(&optionsKind))
|
||||
coptionDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
|
||||
|
||||
C.atg_randint(ptr, chigh, csizeDataPtr, csizeLen, coptionKind, coptionDevice)
|
||||
}
|
||||
|
|
10
nn/init.go
10
nn/init.go
|
@ -73,11 +73,13 @@ func NewRandnInit(mean, stdev float64) randnInit {
|
|||
|
||||
func (r randnInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
||||
var err error
|
||||
rd := rand.Rand{}
|
||||
rand.Seed(86)
|
||||
|
||||
data := make([]float64, ts.FlattenDim(dims))
|
||||
for i := range data {
|
||||
data[i] = rd.NormFloat64()*r.mean + r.stdev
|
||||
data[i] = rand.NormFloat64()*r.mean + r.stdev
|
||||
}
|
||||
|
||||
retVal, err = ts.NewTensorFromData(data, dims)
|
||||
if err != nil {
|
||||
log.Fatalf("randInit - InitTensor method call error: %v\n", err)
|
||||
|
@ -98,10 +100,10 @@ func (r randnInit) Set(tensor ts.Tensor) {
|
|||
log.Fatalf("randInit - Set method call error: %v\n", err)
|
||||
}
|
||||
|
||||
rd := rand.Rand{}
|
||||
rand.Seed(86)
|
||||
data := make([]float64, ts.FlattenDim(dims))
|
||||
for i := range data {
|
||||
data[i] = rd.NormFloat64()*r.mean + r.stdev
|
||||
data[i] = rand.NormFloat64()*r.mean + r.stdev
|
||||
}
|
||||
randnTs, err = ts.NewTensorFromData(data, dims)
|
||||
if err != nil {
|
||||
|
|
47
nn/sparse.go
Normal file
47
nn/sparse.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package nn
|
||||
|
||||
// Sparse layers
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// Configuration option for an embedding layer.
type EmbeddingConfig struct {
	Sparse          bool  // forwarded to the embedding op as its `sparse` flag (presumably sparse weight gradients — confirm against libtorch docs)
	ScaleGradByFreq bool  // forwarded to the embedding op as its `scale_grad_by_freq` flag
	WsInit          Init  // initializer for the layer's "weight" variable
	PaddingIdx      int64 // forwarded to the embedding op as `padding_idx`; -1 in the default config
}
|
||||
|
||||
func DefaultEmbeddingConfig() EmbeddingConfig {
|
||||
return EmbeddingConfig{
|
||||
Sparse: false,
|
||||
ScaleGradByFreq: false,
|
||||
WsInit: NewRandnInit(0.0, 1.0),
|
||||
PaddingIdx: -1,
|
||||
}
|
||||
}
|
||||
|
||||
// An embedding layer.
//
// An embedding layer acts as a simple lookup table that stores embeddings.
// This is commonly used to store word embeddings.
type Embedding struct {
	Ws     ts.Tensor       // weight matrix, created with shape [numEmbeddings, embeddingDim] by NewEmbedding
	config EmbeddingConfig // options forwarded to ts.MustEmbedding on each Forward call
}
|
||||
|
||||
// NewEmbedding creates a new Embedding
|
||||
func NewEmbedding(vs Path, numEmbeddings int64, embeddingDim int64, config EmbeddingConfig) Embedding {
|
||||
return Embedding{
|
||||
Ws: vs.NewVar("weight", []int64{numEmbeddings, embeddingDim}, config.WsInit),
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// Implement Module interface for Embedding:
|
||||
// =========================================
|
||||
func (e Embedding) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
||||
return ts.MustEmbedding(e.Ws, xs, e.config.PaddingIdx, e.config.ScaleGradByFreq, e.config.Sparse)
|
||||
}
|
64
nn/sparse_test.go
Normal file
64
nn/sparse_test.go
Normal file
|
@ -0,0 +1,64 @@
|
|||
package nn_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func embeddingTest(embeddingConfig nn.EmbeddingConfig, t *testing.T) {
|
||||
|
||||
var (
|
||||
batchDim int64 = 5
|
||||
seqLen int64 = 7
|
||||
inputDim int64 = 10
|
||||
outputDim int64 = 4
|
||||
)
|
||||
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
embeddings := nn.NewEmbedding(vs.Root(), inputDim, outputDim, embeddingConfig)
|
||||
|
||||
// Forward test
|
||||
input := ts.MustRandint(10, []int64{batchDim, seqLen}, gotch.Int64, gotch.CPU)
|
||||
output := embeddings.Forward(input)
|
||||
|
||||
want := []int64{batchDim, seqLen, outputDim}
|
||||
got := output.MustSize()
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("Forward - Expected output shape: %v\n", want)
|
||||
t.Errorf("Forward - Got output shape: %v\n", got)
|
||||
}
|
||||
|
||||
// Padding test
|
||||
paddingIdx := embeddingConfig.PaddingIdx
|
||||
if embeddingConfig.PaddingIdx < 0 {
|
||||
paddingIdx = inputDim + embeddingConfig.PaddingIdx
|
||||
}
|
||||
|
||||
input = ts.MustOfSlice([]int64{paddingIdx})
|
||||
output = embeddings.Forward(input)
|
||||
want = []int64{1, outputDim}
|
||||
got = output.MustSize()
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("Padding - Expected output shape: %v\n", want)
|
||||
t.Errorf("Padding - Got output shape: %v\n", got)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestEmbedding(t *testing.T) {
|
||||
|
||||
cfg := nn.DefaultEmbeddingConfig()
|
||||
embeddingTest(cfg, t)
|
||||
|
||||
cfg.PaddingIdx = -1
|
||||
embeddingTest(cfg, t)
|
||||
|
||||
cfg.PaddingIdx = 0
|
||||
embeddingTest(cfg, t)
|
||||
}
|
|
@ -1390,3 +1390,62 @@ func MustRandn(sizeData []int64, optionsKind gotch.DType, optionsDevice gotch.De
|
|||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func Embedding(weight, indices Tensor, paddingIdx int64, scaleGradByFreq, sparse bool) (retVal Tensor, err error) {
|
||||
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
cscaleGradByFreq := 0
|
||||
if scaleGradByFreq {
|
||||
cscaleGradByFreq = 1
|
||||
}
|
||||
|
||||
csparse := 0
|
||||
if sparse {
|
||||
csparse = 1
|
||||
}
|
||||
|
||||
lib.AtgEmbedding(ptr, weight.ctensor, indices.ctensor, paddingIdx, cscaleGradByFreq, csparse)
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func MustEmbedding(weight, indices Tensor, paddingIdx int64, scaleGradByFreq, sparse bool) (retVal Tensor) {
|
||||
|
||||
retVal, err := Embedding(weight, indices, paddingIdx, scaleGradByFreq, sparse)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func Randint(high int64, sizeData []int64, optionsKind gotch.DType, optionsDevice gotch.Device) (retVal Tensor, err error) {
|
||||
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
lib.AtgRandint(ptr, high, sizeData, len(sizeData), optionsKind.CInt(), optionsDevice.CInt())
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func MustRandint(high int64, sizeData []int64, optionsKind gotch.DType, optionsDevice gotch.Device) (retVal Tensor) {
|
||||
|
||||
retVal, err := Randint(high, sizeData, optionsKind, optionsDevice)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user