feat(nn/rnn): added rnn.go feat(nn/conv-transpose): added conv-transpose.go
This commit is contained in:
parent
71fb5ae79b
commit
9817f7393a
|
@ -13,7 +13,7 @@ const (
|
|||
Label int64 = 10
|
||||
MnistDir string = "../../data/mnist"
|
||||
|
||||
epochs = 500
|
||||
epochs = 200
|
||||
)
|
||||
|
||||
func runLinear() {
|
||||
|
|
|
@ -16,9 +16,9 @@ const (
|
|||
LabelNN int64 = 10
|
||||
MnistDirNN string = "../../data/mnist"
|
||||
|
||||
epochsNN = 500
|
||||
epochsNN = 200
|
||||
|
||||
LrNN = 1e-2
|
||||
LrNN = 1e-3
|
||||
)
|
||||
|
||||
var l nn.Linear
|
||||
|
|
|
@ -387,3 +387,80 @@ func AtgDropout_(ptr *Ctensor, self Ctensor, p float64, train int) {
|
|||
|
||||
C.atg_dropout_(ptr, self, cp, ctrain)
|
||||
}
|
||||
|
||||
// void atg_conv_transpose1d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *output_padding_data, int output_padding_len, int64_t groups, int64_t *dilation_data, int dilation_len);
|
||||
func AtgConvTranspose1d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, outputPaddingData []int64, outputPaddingLen int, dilationData []int64, dilationLen int, groups int64) {
|
||||
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
|
||||
cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen))
|
||||
cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
|
||||
cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
|
||||
coutputPaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&outputPaddingData[0]))
|
||||
coutputPaddingLen := *(*C.int)(unsafe.Pointer(&outputPaddingLen))
|
||||
cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))
|
||||
cdilationLen := *(*C.int)(unsafe.Pointer(&dilationLen))
|
||||
cgroups := *(*C.int64_t)(unsafe.Pointer(&groups))
|
||||
|
||||
C.atg_conv_transpose1d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, coutputPaddingDataPtr, coutputPaddingLen, cgroups, cdilationDataPtr, cdilationLen)
|
||||
}
|
||||
|
||||
// void atg_conv_transpose2d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *output_padding_data, int output_padding_len, int64_t groups, int64_t *dilation_data, int dilation_len);
|
||||
func AtgConvTranspose2d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, outputPaddingData []int64, outputPaddingLen int, dilationData []int64, dilationLen int, groups int64) {
|
||||
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
|
||||
cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen))
|
||||
cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
|
||||
cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
|
||||
coutputPaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&outputPaddingData[0]))
|
||||
coutputPaddingLen := *(*C.int)(unsafe.Pointer(&outputPaddingLen))
|
||||
cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))
|
||||
cdilationLen := *(*C.int)(unsafe.Pointer(&dilationLen))
|
||||
cgroups := *(*C.int64_t)(unsafe.Pointer(&groups))
|
||||
|
||||
C.atg_conv_transpose2d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, coutputPaddingDataPtr, coutputPaddingLen, cgroups, cdilationDataPtr, cdilationLen)
|
||||
}
|
||||
|
||||
// void atg_conv_transpose3d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *output_padding_data, int output_padding_len, int64_t groups, int64_t *dilation_data, int dilation_len);
|
||||
func AtgConvTranspose3d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, outputPaddingData []int64, outputPaddingLen int, dilationData []int64, dilationLen int, groups int64) {
|
||||
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
|
||||
cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen))
|
||||
cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
|
||||
cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
|
||||
coutputPaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&outputPaddingData[0]))
|
||||
coutputPaddingLen := *(*C.int)(unsafe.Pointer(&outputPaddingLen))
|
||||
cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))
|
||||
cdilationLen := *(*C.int)(unsafe.Pointer(&dilationLen))
|
||||
cgroups := *(*C.int64_t)(unsafe.Pointer(&groups))
|
||||
|
||||
C.atg_conv_transpose3d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, coutputPaddingDataPtr, coutputPaddingLen, cgroups, cdilationDataPtr, cdilationLen)
|
||||
}
|
||||
|
||||
// void atg_lstm(tensor *, tensor input, tensor *hx_data, int hx_len, tensor *params_data, int params_len, int has_biases, int64_t num_layers, double dropout, int train, int bidirectional, int batch_first);
|
||||
func AtgLstm(ctensorsPtr []*Ctensor, input Ctensor, hxData []Ctensor, hxLen int, paramsData []Ctensor, paramsLen int, hasBiases int, numLayers int64, dropout float64, train int, bidirectional int, batchFirst int) {
|
||||
|
||||
chxDataPtr := (*Ctensor)(unsafe.Pointer(&hxData[0]))
|
||||
chxLen := *(*C.int)(unsafe.Pointer(&hxLen))
|
||||
cparamsDataPtr := (*Ctensor)(unsafe.Pointer(¶msData[0]))
|
||||
cparamsLen := *(*C.int)(unsafe.Pointer(¶msLen))
|
||||
chasBiases := *(*C.int)(unsafe.Pointer(&hasBiases))
|
||||
cnumLayers := *(*C.int64_t)(unsafe.Pointer(&numLayers))
|
||||
cdropout := *(*C.double)(unsafe.Pointer(&dropout))
|
||||
ctrain := *(*C.int)(unsafe.Pointer(&train))
|
||||
cbidirectional := *(*C.int)(unsafe.Pointer(&bidirectional))
|
||||
cbatchFirst := *(*C.int)(unsafe.Pointer(&batchFirst))
|
||||
|
||||
C.atg_lstm(ctensorsPtr[0], input, chxDataPtr, chxLen, cparamsDataPtr, cparamsLen, chasBiases, cnumLayers, cdropout, ctrain, cbidirectional, cbatchFirst)
|
||||
}
|
||||
|
||||
// void atg_gru(tensor *, tensor input, tensor hx, tensor *params_data, int params_len, int has_biases, int64_t num_layers, double dropout, int train, int bidirectional, int batch_first);
|
||||
func AtgGru(ctensorsPtr []*Ctensor, input Ctensor, hx Ctensor, paramsData []Ctensor, paramsLen int, hasBiases int, numLayers int64, dropout float64, train int, bidirectional int, batchFirst int) {
|
||||
|
||||
cparamsDataPtr := (*Ctensor)(unsafe.Pointer(¶msData[0]))
|
||||
cparamsLen := *(*C.int)(unsafe.Pointer(¶msLen))
|
||||
chasBiases := *(*C.int)(unsafe.Pointer(&hasBiases))
|
||||
cnumLayers := *(*C.int64_t)(unsafe.Pointer(&numLayers))
|
||||
cdropout := *(*C.double)(unsafe.Pointer(&dropout))
|
||||
ctrain := *(*C.int)(unsafe.Pointer(&train))
|
||||
cbidirectional := *(*C.int)(unsafe.Pointer(&bidirectional))
|
||||
cbatchFirst := *(*C.int)(unsafe.Pointer(&batchFirst))
|
||||
|
||||
C.atg_gru(ctensorsPtr[0], input, hx, cparamsDataPtr, cparamsLen, chasBiases, cnumLayers, cdropout, ctrain, cbidirectional, cbatchFirst)
|
||||
}
|
||||
|
|
|
@ -3,62 +3,55 @@ package nn
|
|||
// A two dimension transposed convolution layer.
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
type ConvTranspose1DConfig struct {
|
||||
Stride []int64
|
||||
Padding []int64
|
||||
Dilation []int64
|
||||
Groups int64
|
||||
Bias bool
|
||||
WsInit Init
|
||||
BsInit Init
|
||||
Stride []int64
|
||||
Padding []int64
|
||||
OutputPadding []int64
|
||||
Dilation []int64
|
||||
Groups int64
|
||||
Bias bool
|
||||
WsInit Init
|
||||
BsInit Init
|
||||
}
|
||||
|
||||
type ConvTranspose2DConfig struct {
|
||||
Stride []int64
|
||||
Padding []int64
|
||||
Dilation []int64
|
||||
Groups int64
|
||||
Bias bool
|
||||
WsInit Init
|
||||
BsInit Init
|
||||
Stride []int64
|
||||
Padding []int64
|
||||
OutputPadding []int64
|
||||
Dilation []int64
|
||||
Groups int64
|
||||
Bias bool
|
||||
WsInit Init
|
||||
BsInit Init
|
||||
}
|
||||
|
||||
type ConvTranspose3DConfig struct {
|
||||
Stride []int64
|
||||
Padding []int64
|
||||
Dilation []int64
|
||||
Groups int64
|
||||
Bias bool
|
||||
WsInit Init
|
||||
BsInit Init
|
||||
Stride []int64
|
||||
Padding []int64
|
||||
OutputPadding []int64
|
||||
Dilation []int64
|
||||
Groups int64
|
||||
Bias bool
|
||||
WsInit Init
|
||||
BsInit Init
|
||||
}
|
||||
|
||||
// DefaultConvConfig create a default 1D ConvConfig
|
||||
func DefaultConvTranspose1DConfig() ConvTranspose1DConfig {
|
||||
return ConvTranspose1DConfig{
|
||||
Stride: []int64{1},
|
||||
Padding: []int64{0},
|
||||
Dilation: []int64{1},
|
||||
Groups: 1,
|
||||
Bias: true,
|
||||
WsInit: NewKaimingUniformInit(),
|
||||
BsInit: NewConstInit(float64(0.0)),
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultConvConfig2D creates a default 2D ConvConfig
|
||||
func DefaultConvTranspose2DConfig() ConvTranspose2DConfig {
|
||||
return ConvTranspose2DConfig{
|
||||
Stride: []int64{1, 1},
|
||||
Padding: []int64{0, 0},
|
||||
Dilation: []int64{1, 1},
|
||||
Groups: 1,
|
||||
Bias: true,
|
||||
WsInit: NewKaimingUniformInit(),
|
||||
BsInit: NewConstInit(float64(0.0)),
|
||||
Stride: []int64{1},
|
||||
Padding: []int64{0},
|
||||
OutputPadding: []int64{0},
|
||||
Dilation: []int64{1},
|
||||
Groups: 1,
|
||||
Bias: true,
|
||||
WsInit: NewKaimingUniformInit(),
|
||||
BsInit: NewConstInit(float64(0.0)),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -68,14 +61,18 @@ type ConvTranspose1D struct {
|
|||
Config ConvTranspose1DConfig
|
||||
}
|
||||
|
||||
func NewConvTranspose1D(vs *Path, inDim, outDim, k int64, cfg ConvTranspose1DConfig) ConvTranspose1D {
|
||||
func NewConvTranspose1D(vs *Path, inDim, outDim int64, ksizes []int64, cfg ConvTranspose1DConfig) ConvTranspose1D {
|
||||
if len(ksizes) != 1 {
|
||||
log.Fatalf("NewConvTranspose1D method call: Kernel size should be 1. Got %v\n", len(ksizes))
|
||||
}
|
||||
|
||||
var conv ConvTranspose1D
|
||||
conv.Config = cfg
|
||||
if cfg.Bias {
|
||||
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
}
|
||||
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||
weightSize = append(weightSize, k)
|
||||
weightSize = append(weightSize, ksizes...)
|
||||
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||
|
||||
return conv
|
||||
|
@ -87,14 +84,18 @@ type ConvTranspose2D struct {
|
|||
Config ConvTranspose2DConfig
|
||||
}
|
||||
|
||||
func NewConvTranspose2D(vs *Path, inDim, outDim int64, k int64, cfg ConvTranspose2DConfig) ConvTranspose2D {
|
||||
func NewConvTranspose2D(vs *Path, inDim, outDim int64, ksizes []int64, cfg ConvTranspose2DConfig) ConvTranspose2D {
|
||||
|
||||
if len(ksizes) != 2 {
|
||||
log.Fatalf("NewConvTranspose2D method call: Kernel size should be 2. Got %v\n", len(ksizes))
|
||||
}
|
||||
var conv ConvTranspose2D
|
||||
conv.Config = cfg
|
||||
if cfg.Bias {
|
||||
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
}
|
||||
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||
weightSize = append(weightSize, k, k)
|
||||
weightSize = append(weightSize, ksizes...)
|
||||
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||
|
||||
return conv
|
||||
|
@ -106,14 +107,17 @@ type ConvTranspose3D struct {
|
|||
Config ConvTranspose3DConfig
|
||||
}
|
||||
|
||||
func NewConvTranspose3D(vs *Path, inDim, outDim, k int64, cfg ConvTranspose3DConfig) ConvTranspose3D {
|
||||
func NewConvTranspose3D(vs *Path, inDim, outDim int64, ksizes []int64, cfg ConvTranspose3DConfig) ConvTranspose3D {
|
||||
if len(ksizes) != 3 {
|
||||
log.Fatalf("NewConvTranspose3D method call: Kernel size should be 3. Got %v\n", len(ksizes))
|
||||
}
|
||||
var conv ConvTranspose3D
|
||||
conv.Config = cfg
|
||||
if cfg.Bias {
|
||||
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
}
|
||||
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||
weightSize = append(weightSize, k, k, k)
|
||||
weightSize = append(weightSize, ksizes...)
|
||||
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||
|
||||
return conv
|
||||
|
@ -122,13 +126,13 @@ func NewConvTranspose3D(vs *Path, inDim, outDim, k int64, cfg ConvTranspose3DCon
|
|||
// Implement Module for Conv1D, Conv2D, Conv3D:
|
||||
// ============================================
|
||||
|
||||
/* func (c ConvTranspose1D) Forward(xs ts.Tensor) ts.Tensor {
|
||||
* return ts.MustConvTranspose1D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
|
||||
* }
|
||||
*
|
||||
* func (c ConvTranspose2D) Forward(xs ts.Tensor) ts.Tensor {
|
||||
* return ts.MustConvTranspose2D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
|
||||
* }
|
||||
* func (c ConvTranspose3D) Forward(xs ts.Tensor) ts.Tensor {
|
||||
* return ts.MustConvTranspose3D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
|
||||
* } */
|
||||
func (c ConvTranspose1D) Forward(xs ts.Tensor) ts.Tensor {
|
||||
return ts.MustConvTranspose1D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Dilation, c.Config.Groups)
|
||||
}
|
||||
|
||||
func (c ConvTranspose2D) Forward(xs ts.Tensor) ts.Tensor {
|
||||
return ts.MustConvTranspose2D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Dilation, c.Config.Groups)
|
||||
}
|
||||
func (c ConvTranspose3D) Forward(xs ts.Tensor) ts.Tensor {
|
||||
return ts.MustConvTranspose3D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Dilation, c.Config.Groups)
|
||||
}
|
||||
|
|
248
nn/rnn.go
Normal file
248
nn/rnn.go
Normal file
|
@ -0,0 +1,248 @@
|
|||
package nn
|
||||
|
||||
import (
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
type State interface{}
|
||||
|
||||
type RNN interface {
|
||||
|
||||
// A zero state from which the recurrent network is usually initialized.
|
||||
ZeroState(batchDim int64) State
|
||||
|
||||
// Applies a single step of the recurrent network.
|
||||
//
|
||||
// The input should have dimensions [batch_size, features].
|
||||
Step(input ts.Tensor, inState State) State
|
||||
|
||||
// Applies multiple steps of the recurrent network.
|
||||
//
|
||||
// The input should have dimensions [batch_size, seq_len, features].
|
||||
// The initial state is the result of applying zero_state.
|
||||
Seq(input ts.Tensor) (ts.Tensor, State)
|
||||
|
||||
// Applies multiple steps of the recurrent network.
|
||||
//
|
||||
// The input should have dimensions [batch_size, seq_len, features].
|
||||
SeqInit(input ts.Tensor, inState State) (ts.Tensor, State)
|
||||
}
|
||||
|
||||
func defaultSeq(self interface{}, input ts.Tensor) (ts.Tensor, State) {
|
||||
batchDim := input.MustSize()[0]
|
||||
inState := self.(RNN).ZeroState(batchDim)
|
||||
|
||||
return self.(RNN).SeqInit(input, inState)
|
||||
}
|
||||
|
||||
// The state for a LSTM network, this contains two tensors.
|
||||
type LSTMState struct {
|
||||
Tensor1 ts.Tensor
|
||||
Tensor2 ts.Tensor
|
||||
}
|
||||
|
||||
// The hidden state vector, which is also the output of the LSTM.
|
||||
func (ls LSTMState) H() (retVal ts.Tensor) {
|
||||
return ls.Tensor1.MustShallowClone()
|
||||
}
|
||||
|
||||
// The cell state vector.
|
||||
func (ls LSTMState) C() (retVal ts.Tensor) {
|
||||
return ls.Tensor2.MustShallowClone()
|
||||
}
|
||||
|
||||
// The GRU and LSTM layers share the same config.
|
||||
// Configuration for the GRU and LSTM layers.
|
||||
type RNNConfig struct {
|
||||
HasBiases bool
|
||||
NumLayers int64
|
||||
Dropout float64
|
||||
Train bool
|
||||
Bidirectional bool
|
||||
BatchFirst bool
|
||||
}
|
||||
|
||||
// Default creates default RNN configuration
|
||||
func DefaultRNNConfig() RNNConfig {
|
||||
return RNNConfig{
|
||||
HasBiases: true,
|
||||
NumLayers: 1,
|
||||
Dropout: float64(0.0),
|
||||
Train: true,
|
||||
Bidirectional: false,
|
||||
BatchFirst: true,
|
||||
}
|
||||
}
|
||||
|
||||
// A Long Short-Term Memory (LSTM) layer.
|
||||
//
|
||||
// https://en.wikipedia.org/wiki/Long_short-term_memory
|
||||
type LSTM struct {
|
||||
flatWeights []ts.Tensor
|
||||
hiddenDim int64
|
||||
config RNNConfig
|
||||
device gotch.Device
|
||||
}
|
||||
|
||||
// NewLSTM creates a LSTM layer.
|
||||
func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal LSTM) {
|
||||
|
||||
var numDirections int64 = 1
|
||||
if cfg.Bidirectional {
|
||||
numDirections = 2
|
||||
}
|
||||
|
||||
gateDim := 4 * hiddenDim
|
||||
flatWeights := make([]ts.Tensor, 0)
|
||||
|
||||
for i := 0; i < int(cfg.NumLayers); i++ {
|
||||
for n := 0; n < int(numDirections); n++ {
|
||||
if i != 0 {
|
||||
inDim = hiddenDim * numDirections
|
||||
}
|
||||
|
||||
wIh := vs.KaimingUniform("w_ih", []int64{gateDim, inDim})
|
||||
wHh := vs.KaimingUniform("w_hh", []int64{gateDim, hiddenDim})
|
||||
bIh := vs.Zeros("b_ih", []int64{gateDim})
|
||||
bHh := vs.Zeros("b_hh", []int64{gateDim})
|
||||
|
||||
flatWeights = append(flatWeights, wIh, wHh, bIh, bHh)
|
||||
}
|
||||
}
|
||||
|
||||
return LSTM{
|
||||
flatWeights: flatWeights,
|
||||
hiddenDim: hiddenDim,
|
||||
config: cfg,
|
||||
device: vs.Device(),
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Implement RNN interface for LSTM:
|
||||
// =================================
|
||||
|
||||
func (l LSTM) ZeroState(batchDim int64) (retVal State) {
|
||||
var numDirections int64 = 1
|
||||
if l.config.Bidirectional {
|
||||
numDirections = 2
|
||||
}
|
||||
|
||||
layerDim := l.config.NumLayers * numDirections
|
||||
shape := []int64{layerDim, batchDim, l.hiddenDim}
|
||||
zeros := ts.MustZeros(shape, gotch.Float.CInt(), l.device.CInt())
|
||||
|
||||
return LSTMState{
|
||||
Tensor1: zeros.MustShallowClone(),
|
||||
Tensor2: zeros.MustShallowClone(),
|
||||
}
|
||||
}
|
||||
|
||||
func (l LSTM) Step(input ts.Tensor, inState State) (retVal State) {
|
||||
ip := input.MustUnsqueeze(1, false)
|
||||
|
||||
_, state := l.SeqInit(ip, inState.(LSTMState))
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
func (l LSTM) Seq(input ts.Tensor) (ts.Tensor, State) {
|
||||
return defaultSeq(l, input)
|
||||
}
|
||||
|
||||
func (l LSTM) SeqInit(input ts.Tensor, inState State) (ts.Tensor, State) {
|
||||
|
||||
output, h, c := input.MustLSTM([]ts.Tensor{inState.(LSTMState).Tensor1, inState.(LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst)
|
||||
|
||||
return output, LSTMState{
|
||||
Tensor1: h,
|
||||
Tensor2: c,
|
||||
}
|
||||
}
|
||||
|
||||
// GRUState is a GRU state. It contains a single tensor.
|
||||
type GRUState struct {
|
||||
Tensor ts.Tensor
|
||||
}
|
||||
|
||||
func (gs GRUState) Value() ts.Tensor {
|
||||
return gs.Tensor
|
||||
}
|
||||
|
||||
// A Gated Recurrent Unit (GRU) layer.
|
||||
//
|
||||
// https://en.wikipedia.org/wiki/Gated_recurrent_unit
|
||||
type GRU struct {
|
||||
flatWeights []ts.Tensor
|
||||
hiddenDim int64
|
||||
config RNNConfig
|
||||
device gotch.Device
|
||||
}
|
||||
|
||||
// NewGRU create a new GRU layer
|
||||
func NewGRU(vs *Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
|
||||
var numDirections int64 = 1
|
||||
if cfg.Bidirectional {
|
||||
numDirections = 2
|
||||
}
|
||||
|
||||
gateDim := 3 * hiddenDim
|
||||
flatWeights := make([]ts.Tensor, 0)
|
||||
|
||||
for i := 0; i < int(cfg.NumLayers); i++ {
|
||||
for n := 0; n < int(numDirections); n++ {
|
||||
if i != 0 {
|
||||
inDim = hiddenDim * numDirections
|
||||
}
|
||||
|
||||
wIh := vs.KaimingUniform("w_ih", []int64{gateDim, inDim})
|
||||
wHh := vs.KaimingUniform("w_hh", []int64{gateDim, hiddenDim})
|
||||
bIh := vs.Zeros("b_ih", []int64{gateDim})
|
||||
bHh := vs.Zeros("b_hh", []int64{gateDim})
|
||||
|
||||
flatWeights = append(flatWeights, wIh, wHh, bIh, bHh)
|
||||
}
|
||||
}
|
||||
|
||||
return GRU{
|
||||
flatWeights: flatWeights,
|
||||
hiddenDim: hiddenDim,
|
||||
config: cfg,
|
||||
device: vs.Device(),
|
||||
}
|
||||
}
|
||||
|
||||
// Implement RNN interface for GRU:
|
||||
// ================================
|
||||
|
||||
func (g GRU) ZeroState(batchDim int64) (retVal State) {
|
||||
var numDirections int64 = 1
|
||||
if g.config.Bidirectional {
|
||||
numDirections = 2
|
||||
}
|
||||
|
||||
layerDim := g.config.NumLayers * numDirections
|
||||
shape := []int64{layerDim, batchDim, g.hiddenDim}
|
||||
|
||||
return ts.MustZeros(shape, gotch.Float.CInt(), g.device.CInt())
|
||||
}
|
||||
|
||||
func (g GRU) Step(input ts.Tensor, inState State) (retVal State) {
|
||||
ip := input.MustUnsqueeze(1, false)
|
||||
|
||||
_, state := g.SeqInit(ip, inState.(LSTMState))
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
func (g GRU) Seq(input ts.Tensor) (ts.Tensor, State) {
|
||||
return defaultSeq(g, input)
|
||||
}
|
||||
|
||||
func (g GRU) SeqInit(input ts.Tensor, inState State) (ts.Tensor, State) {
|
||||
|
||||
output, h := input.MustGRU(inState.(GRUState).Tensor, g.flatWeights, g.config.HasBiases, g.config.NumLayers, g.config.Dropout, g.config.Train, g.config.Bidirectional, g.config.BatchFirst)
|
||||
|
||||
return output, GRUState{Tensor: h}
|
||||
}
|
|
@ -316,6 +316,15 @@ func (ts Tensor) Unsqueeze(dim int64, del bool) (retVal Tensor, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustUnsqueeze(dim int64, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Unsqueeze(dim, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// Select creates a new tensor from current tensor given dim and index.
|
||||
func (ts Tensor) Select(dim int64, index int64, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
@ -1176,3 +1185,182 @@ func (ts Tensor) Dropout_(p float64, train bool) {
|
|||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func ConvTranspose1D(input, weight, bias Tensor, stride, padding, outputPadding, dilation []int64, groups int64) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
lib.AtgConvTranspose1d(ptr, input.ctensor, weight.ctensor, bias.ctensor, stride, len(stride), padding, len(padding), outputPadding, len(outputPadding), dilation, len(dilation), groups)
|
||||
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func MustConvTranspose1D(input, weight, bias Tensor, stride, padding, outputPadding, dilation []int64, groups int64) (retVal Tensor) {
|
||||
retVal, err := ConvTranspose1D(input, weight, bias, stride, padding, outputPadding, dilation, groups)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func ConvTranspose2D(input, weight, bias Tensor, stride, padding, outputPadding, dilation []int64, groups int64) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
lib.AtgConvTranspose2d(ptr, input.ctensor, weight.ctensor, bias.ctensor, stride, len(stride), padding, len(padding), outputPadding, len(outputPadding), dilation, len(dilation), groups)
|
||||
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func MustConvTranspose2D(input, weight, bias Tensor, stride, padding, outputPadding, dilation []int64, groups int64) (retVal Tensor) {
|
||||
retVal, err := ConvTranspose2D(input, weight, bias, stride, padding, outputPadding, dilation, groups)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func ConvTranspose3D(input, weight, bias Tensor, stride, padding, outputPadding, dilation []int64, groups int64) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
lib.AtgConvTranspose3d(ptr, input.ctensor, weight.ctensor, bias.ctensor, stride, len(stride), padding, len(padding), outputPadding, len(outputPadding), dilation, len(dilation), groups)
|
||||
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func MustConvTranspose3D(input, weight, bias Tensor, stride, padding, outputPadding, dilation []int64, groups int64) (retVal Tensor) {
|
||||
retVal, err := ConvTranspose3D(input, weight, bias, stride, padding, outputPadding, dilation, groups)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) LSTM(hxData []Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c Tensor, err error) {
|
||||
|
||||
// NOTE: atg_lstm will return an array of 3 Ctensors
|
||||
ts1Ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
ts2Ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
ts3Ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
var ctensorsPtr []*lib.Ctensor
|
||||
ctensorsPtr = append(ctensorsPtr, ts1Ptr, ts2Ptr, ts3Ptr)
|
||||
|
||||
var chxData []lib.Ctensor
|
||||
for _, t := range hxData {
|
||||
chxData = append(chxData, t.ctensor)
|
||||
}
|
||||
|
||||
var cparamsData []lib.Ctensor
|
||||
for _, t := range paramsData {
|
||||
cparamsData = append(cparamsData, t.ctensor)
|
||||
}
|
||||
|
||||
chasBiases := 0
|
||||
if hasBiases {
|
||||
chasBiases = 1
|
||||
}
|
||||
ctrain := 0
|
||||
if train {
|
||||
ctrain = 1
|
||||
}
|
||||
cbidirectional := 0
|
||||
if bidirectional {
|
||||
cbidirectional = 1
|
||||
}
|
||||
cbatchFirst := 0
|
||||
if batchFirst {
|
||||
cbatchFirst = 1
|
||||
}
|
||||
|
||||
lib.AtgLstm(ctensorsPtr, ts.ctensor, chxData, len(hxData), cparamsData, len(paramsData), chasBiases, numLayers, dropout, ctrain, cbidirectional, cbatchFirst)
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return output, h, c, err
|
||||
}
|
||||
|
||||
return Tensor{ctensor: *ts1Ptr}, Tensor{ctensor: *ts2Ptr}, Tensor{ctensor: *ts3Ptr}, nil
|
||||
|
||||
}
|
||||
|
||||
func (ts Tensor) MustLSTM(hxData []Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c Tensor) {
|
||||
output, h, c, err := ts.LSTM(hxData, paramsData, hasBiases, numLayers, dropout, train, bidirectional, batchFirst)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return output, h, c
|
||||
}
|
||||
|
||||
func (ts Tensor) GRU(hx Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h Tensor, err error) {
|
||||
|
||||
// NOTE: atg_gru will returns an array of 2 Ctensor
|
||||
ts1Ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
ts2Ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
var ctensorsPtr []*lib.Ctensor
|
||||
ctensorsPtr = append(ctensorsPtr, ts1Ptr, ts2Ptr)
|
||||
|
||||
var cparamsData []lib.Ctensor
|
||||
for _, t := range paramsData {
|
||||
cparamsData = append(cparamsData, t.ctensor)
|
||||
}
|
||||
|
||||
chasBiases := 0
|
||||
if hasBiases {
|
||||
chasBiases = 1
|
||||
}
|
||||
ctrain := 0
|
||||
if train {
|
||||
ctrain = 1
|
||||
}
|
||||
cbidirectional := 0
|
||||
if bidirectional {
|
||||
cbidirectional = 1
|
||||
}
|
||||
cbatchFirst := 0
|
||||
if batchFirst {
|
||||
cbatchFirst = 1
|
||||
}
|
||||
|
||||
lib.AtgGru(ctensorsPtr, ts.ctensor, hx.ctensor, cparamsData, len(paramsData), chasBiases, numLayers, dropout, ctrain, cbidirectional, cbatchFirst)
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return output, h, err
|
||||
}
|
||||
|
||||
return Tensor{ctensor: *ts1Ptr}, Tensor{ctensor: *ts2Ptr}, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustGRU(hx Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h Tensor) {
|
||||
output, h, err := ts.GRU(hx, paramsData, hasBiases, numLayers, dropout, train, bidirectional, batchFirst)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return output, h
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user