From 9817f7393a86e3fe217ace475f93258264f9bc57 Mon Sep 17 00:00:00 2001 From: sugarme Date: Wed, 24 Jun 2020 18:14:34 +1000 Subject: [PATCH] feat(nn/rnn): added rnn.go feat(nn/conv-transpose): added conv-transpose.go --- example/mnist/linear.go | 2 +- example/mnist/nn.go | 4 +- libtch/c-generated-sample.go | 77 ++++++++++ nn/conv-transpose.go | 118 +++++++------- nn/rnn.go | 248 ++++++++++++++++++++++++++++++ tensor/tensor-generated-sample.go | 188 ++++++++++++++++++++++ 6 files changed, 577 insertions(+), 60 deletions(-) create mode 100644 nn/rnn.go diff --git a/example/mnist/linear.go b/example/mnist/linear.go index 86be1a0..f30e92d 100644 --- a/example/mnist/linear.go +++ b/example/mnist/linear.go @@ -13,7 +13,7 @@ const ( Label int64 = 10 MnistDir string = "../../data/mnist" - epochs = 500 + epochs = 200 ) func runLinear() { diff --git a/example/mnist/nn.go b/example/mnist/nn.go index f4e0640..950d2f2 100644 --- a/example/mnist/nn.go +++ b/example/mnist/nn.go @@ -16,9 +16,9 @@ const ( LabelNN int64 = 10 MnistDirNN string = "../../data/mnist" - epochsNN = 500 + epochsNN = 200 - LrNN = 1e-2 + LrNN = 1e-3 ) var l nn.Linear diff --git a/libtch/c-generated-sample.go b/libtch/c-generated-sample.go index ed94253..b902a32 100644 --- a/libtch/c-generated-sample.go +++ b/libtch/c-generated-sample.go @@ -387,3 +387,80 @@ func AtgDropout_(ptr *Ctensor, self Ctensor, p float64, train int) { C.atg_dropout_(ptr, self, cp, ctrain) } + +// void atg_conv_transpose1d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *output_padding_data, int output_padding_len, int64_t groups, int64_t *dilation_data, int dilation_len); +func AtgConvTranspose1d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, outputPaddingData []int64, outputPaddingLen int, dilationData []int64, dilationLen int, groups int64) { + cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0])) + cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen)) + cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0])) + cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen)) + coutputPaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&outputPaddingData[0])) + coutputPaddingLen := *(*C.int)(unsafe.Pointer(&outputPaddingLen)) + cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0])) + cdilationLen := *(*C.int)(unsafe.Pointer(&dilationLen)) + cgroups := *(*C.int64_t)(unsafe.Pointer(&groups)) + + C.atg_conv_transpose1d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, coutputPaddingDataPtr, coutputPaddingLen, cgroups, cdilationDataPtr, cdilationLen) +} + +// void atg_conv_transpose2d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *output_padding_data, int output_padding_len, int64_t groups, int64_t *dilation_data, int dilation_len); +func AtgConvTranspose2d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, outputPaddingData []int64, outputPaddingLen int, dilationData []int64, dilationLen int, groups int64) { + cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0])) + cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen)) + cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0])) + cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen)) + coutputPaddingDataPtr := 
(*C.int64_t)(unsafe.Pointer(&outputPaddingData[0]))
+	coutputPaddingLen := *(*C.int)(unsafe.Pointer(&outputPaddingLen))
+	cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))
+	cdilationLen := *(*C.int)(unsafe.Pointer(&dilationLen))
+	cgroups := *(*C.int64_t)(unsafe.Pointer(&groups))
+
+	C.atg_conv_transpose2d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, coutputPaddingDataPtr, coutputPaddingLen, cgroups, cdilationDataPtr, cdilationLen)
+}
+
+// void atg_conv_transpose3d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *output_padding_data, int output_padding_len, int64_t groups, int64_t *dilation_data, int dilation_len);
+func AtgConvTranspose3d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, outputPaddingData []int64, outputPaddingLen int, dilationData []int64, dilationLen int, groups int64) {
+	cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
+	cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen))
+	cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
+	cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
+	coutputPaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&outputPaddingData[0]))
+	coutputPaddingLen := *(*C.int)(unsafe.Pointer(&outputPaddingLen))
+	cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))
+	cdilationLen := *(*C.int)(unsafe.Pointer(&dilationLen))
+	cgroups := *(*C.int64_t)(unsafe.Pointer(&groups))
+
+	C.atg_conv_transpose3d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, coutputPaddingDataPtr, coutputPaddingLen, cgroups, cdilationDataPtr, cdilationLen)
+}
+
+// void atg_lstm(tensor *, tensor input, tensor *hx_data, int hx_len, tensor *params_data, int params_len, int has_biases, int64_t num_layers, double dropout, int train, int bidirectional, int batch_first);
+func AtgLstm(ctensorsPtr []*Ctensor, input Ctensor, hxData []Ctensor, hxLen int, paramsData []Ctensor, paramsLen int, hasBiases int, numLayers int64, dropout float64, train int, bidirectional int, batchFirst int) {
+
+	chxDataPtr := (*Ctensor)(unsafe.Pointer(&hxData[0]))
+	chxLen := *(*C.int)(unsafe.Pointer(&hxLen))
+	cparamsDataPtr := (*Ctensor)(unsafe.Pointer(&paramsData[0]))
+	cparamsLen := *(*C.int)(unsafe.Pointer(&paramsLen))
+	chasBiases := *(*C.int)(unsafe.Pointer(&hasBiases))
+	cnumLayers := *(*C.int64_t)(unsafe.Pointer(&numLayers))
+	cdropout := *(*C.double)(unsafe.Pointer(&dropout))
+	ctrain := *(*C.int)(unsafe.Pointer(&train))
+	cbidirectional := *(*C.int)(unsafe.Pointer(&bidirectional))
+	cbatchFirst := *(*C.int)(unsafe.Pointer(&batchFirst))
+
+	C.atg_lstm(ctensorsPtr[0], input, chxDataPtr, chxLen, cparamsDataPtr, cparamsLen, chasBiases, cnumLayers, cdropout, ctrain, cbidirectional, cbatchFirst)
+}
+
+// void atg_gru(tensor *, tensor input, tensor hx, tensor *params_data, int params_len, int has_biases, int64_t num_layers, double dropout, int train, int bidirectional, int batch_first);
+func AtgGru(ctensorsPtr []*Ctensor, input Ctensor, hx Ctensor, paramsData []Ctensor, paramsLen int, hasBiases int, numLayers int64, dropout float64, train int, bidirectional int, batchFirst int) {
+
+	cparamsDataPtr := (*Ctensor)(unsafe.Pointer(&paramsData[0]))
+	cparamsLen := *(*C.int)(unsafe.Pointer(&paramsLen))
+	chasBiases := *(*C.int)(unsafe.Pointer(&hasBiases))
+	cnumLayers := *(*C.int64_t)(unsafe.Pointer(&numLayers))
+	cdropout := 
*(*C.double)(unsafe.Pointer(&dropout)) + ctrain := *(*C.int)(unsafe.Pointer(&train)) + cbidirectional := *(*C.int)(unsafe.Pointer(&bidirectional)) + cbatchFirst := *(*C.int)(unsafe.Pointer(&batchFirst)) + + C.atg_gru(ctensorsPtr[0], input, hx, cparamsDataPtr, cparamsLen, chasBiases, cnumLayers, cdropout, ctrain, cbidirectional, cbatchFirst) +} diff --git a/nn/conv-transpose.go b/nn/conv-transpose.go index 422a741..82cb31b 100644 --- a/nn/conv-transpose.go +++ b/nn/conv-transpose.go @@ -3,62 +3,55 @@ package nn // A two dimension transposed convolution layer. import ( + "log" + ts "github.com/sugarme/gotch/tensor" ) type ConvTranspose1DConfig struct { - Stride []int64 - Padding []int64 - Dilation []int64 - Groups int64 - Bias bool - WsInit Init - BsInit Init + Stride []int64 + Padding []int64 + OutputPadding []int64 + Dilation []int64 + Groups int64 + Bias bool + WsInit Init + BsInit Init } type ConvTranspose2DConfig struct { - Stride []int64 - Padding []int64 - Dilation []int64 - Groups int64 - Bias bool - WsInit Init - BsInit Init + Stride []int64 + Padding []int64 + OutputPadding []int64 + Dilation []int64 + Groups int64 + Bias bool + WsInit Init + BsInit Init } type ConvTranspose3DConfig struct { - Stride []int64 - Padding []int64 - Dilation []int64 - Groups int64 - Bias bool - WsInit Init - BsInit Init + Stride []int64 + Padding []int64 + OutputPadding []int64 + Dilation []int64 + Groups int64 + Bias bool + WsInit Init + BsInit Init } // DefaultConvConfig create a default 1D ConvConfig func DefaultConvTranspose1DConfig() ConvTranspose1DConfig { return ConvTranspose1DConfig{ - Stride: []int64{1}, - Padding: []int64{0}, - Dilation: []int64{1}, - Groups: 1, - Bias: true, - WsInit: NewKaimingUniformInit(), - BsInit: NewConstInit(float64(0.0)), - } -} - -// DefaultConvConfig2D creates a default 2D ConvConfig -func DefaultConvTranspose2DConfig() ConvTranspose2DConfig { - return ConvTranspose2DConfig{ - Stride: []int64{1, 1}, - Padding: []int64{0, 0}, - Dilation: []int64{1, 1}, - Groups: 1, - Bias: true, - WsInit: NewKaimingUniformInit(), - BsInit: NewConstInit(float64(0.0)), + Stride: []int64{1}, + Padding: []int64{0}, + OutputPadding: []int64{0}, + Dilation: []int64{1}, + Groups: 1, + Bias: true, + WsInit: NewKaimingUniformInit(), + BsInit: NewConstInit(float64(0.0)), } } @@ -68,14 +61,18 @@ type ConvTranspose1D struct { Config ConvTranspose1DConfig } -func NewConvTranspose1D(vs *Path, inDim, outDim, k int64, cfg ConvTranspose1DConfig) ConvTranspose1D { +func NewConvTranspose1D(vs *Path, inDim, outDim int64, ksizes []int64, cfg ConvTranspose1DConfig) ConvTranspose1D { + if len(ksizes) != 1 { + log.Fatalf("NewConvTranspose1D method call: Kernel size should be 1. Got %v\n", len(ksizes)) + } + var conv ConvTranspose1D conv.Config = cfg if cfg.Bias { conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit) } weightSize := []int64{outDim, int64(inDim / cfg.Groups)} - weightSize = append(weightSize, k) + weightSize = append(weightSize, ksizes...) conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit) return conv @@ -87,14 +84,18 @@ type ConvTranspose2D struct { Config ConvTranspose2DConfig } -func NewConvTranspose2D(vs *Path, inDim, outDim int64, k int64, cfg ConvTranspose2DConfig) ConvTranspose2D { +func NewConvTranspose2D(vs *Path, inDim, outDim int64, ksizes []int64, cfg ConvTranspose2DConfig) ConvTranspose2D { + + if len(ksizes) != 2 { + log.Fatalf("NewConvTranspose2D method call: Kernel size should be 2. 
Got %v\n", len(ksizes)) + } var conv ConvTranspose2D conv.Config = cfg if cfg.Bias { conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit) } weightSize := []int64{outDim, int64(inDim / cfg.Groups)} - weightSize = append(weightSize, k, k) + weightSize = append(weightSize, ksizes...) conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit) return conv @@ -106,14 +107,17 @@ type ConvTranspose3D struct { Config ConvTranspose3DConfig } -func NewConvTranspose3D(vs *Path, inDim, outDim, k int64, cfg ConvTranspose3DConfig) ConvTranspose3D { +func NewConvTranspose3D(vs *Path, inDim, outDim int64, ksizes []int64, cfg ConvTranspose3DConfig) ConvTranspose3D { + if len(ksizes) != 3 { + log.Fatalf("NewConvTranspose3D method call: Kernel size should be 3. Got %v\n", len(ksizes)) + } var conv ConvTranspose3D conv.Config = cfg if cfg.Bias { conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit) } weightSize := []int64{outDim, int64(inDim / cfg.Groups)} - weightSize = append(weightSize, k, k, k) + weightSize = append(weightSize, ksizes...) conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit) return conv @@ -122,13 +126,13 @@ func NewConvTranspose3D(vs *Path, inDim, outDim, k int64, cfg ConvTranspose3DCon // Implement Module for Conv1D, Conv2D, Conv3D: // ============================================ -/* func (c ConvTranspose1D) Forward(xs ts.Tensor) ts.Tensor { - * return ts.MustConvTranspose1D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups) - * } - * - * func (c ConvTranspose2D) Forward(xs ts.Tensor) ts.Tensor { - * return ts.MustConvTranspose2D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups) - * } - * func (c ConvTranspose3D) Forward(xs ts.Tensor) ts.Tensor { - * return ts.MustConvTranspose3D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups) - * } */ +func (c ConvTranspose1D) Forward(xs ts.Tensor) ts.Tensor { + return ts.MustConvTranspose1D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Dilation, c.Config.Groups) +} + +func (c ConvTranspose2D) Forward(xs ts.Tensor) ts.Tensor { + return ts.MustConvTranspose2D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Dilation, c.Config.Groups) +} +func (c ConvTranspose3D) Forward(xs ts.Tensor) ts.Tensor { + return ts.MustConvTranspose3D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Dilation, c.Config.Groups) +} diff --git a/nn/rnn.go b/nn/rnn.go new file mode 100644 index 0000000..8978486 --- /dev/null +++ b/nn/rnn.go @@ -0,0 +1,248 @@ +package nn + +import ( + "github.com/sugarme/gotch" + ts "github.com/sugarme/gotch/tensor" +) + +type State interface{} + +type RNN interface { + + // A zero state from which the recurrent network is usually initialized. + ZeroState(batchDim int64) State + + // Applies a single step of the recurrent network. + // + // The input should have dimensions [batch_size, features]. + Step(input ts.Tensor, inState State) State + + // Applies multiple steps of the recurrent network. + // + // The input should have dimensions [batch_size, seq_len, features]. + // The initial state is the result of applying zero_state. + Seq(input ts.Tensor) (ts.Tensor, State) + + // Applies multiple steps of the recurrent network. + // + // The input should have dimensions [batch_size, seq_len, features]. 
+ SeqInit(input ts.Tensor, inState State) (ts.Tensor, State) +} + +func defaultSeq(self interface{}, input ts.Tensor) (ts.Tensor, State) { + batchDim := input.MustSize()[0] + inState := self.(RNN).ZeroState(batchDim) + + return self.(RNN).SeqInit(input, inState) +} + +// The state for a LSTM network, this contains two tensors. +type LSTMState struct { + Tensor1 ts.Tensor + Tensor2 ts.Tensor +} + +// The hidden state vector, which is also the output of the LSTM. +func (ls LSTMState) H() (retVal ts.Tensor) { + return ls.Tensor1.MustShallowClone() +} + +// The cell state vector. +func (ls LSTMState) C() (retVal ts.Tensor) { + return ls.Tensor2.MustShallowClone() +} + +// The GRU and LSTM layers share the same config. +// Configuration for the GRU and LSTM layers. +type RNNConfig struct { + HasBiases bool + NumLayers int64 + Dropout float64 + Train bool + Bidirectional bool + BatchFirst bool +} + +// Default creates default RNN configuration +func DefaultRNNConfig() RNNConfig { + return RNNConfig{ + HasBiases: true, + NumLayers: 1, + Dropout: float64(0.0), + Train: true, + Bidirectional: false, + BatchFirst: true, + } +} + +// A Long Short-Term Memory (LSTM) layer. +// +// https://en.wikipedia.org/wiki/Long_short-term_memory +type LSTM struct { + flatWeights []ts.Tensor + hiddenDim int64 + config RNNConfig + device gotch.Device +} + +// NewLSTM creates a LSTM layer. +func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal LSTM) { + + var numDirections int64 = 1 + if cfg.Bidirectional { + numDirections = 2 + } + + gateDim := 4 * hiddenDim + flatWeights := make([]ts.Tensor, 0) + + for i := 0; i < int(cfg.NumLayers); i++ { + for n := 0; n < int(numDirections); n++ { + if i != 0 { + inDim = hiddenDim * numDirections + } + + wIh := vs.KaimingUniform("w_ih", []int64{gateDim, inDim}) + wHh := vs.KaimingUniform("w_hh", []int64{gateDim, hiddenDim}) + bIh := vs.Zeros("b_ih", []int64{gateDim}) + bHh := vs.Zeros("b_hh", []int64{gateDim}) + + flatWeights = append(flatWeights, wIh, wHh, bIh, bHh) + } + } + + return LSTM{ + flatWeights: flatWeights, + hiddenDim: hiddenDim, + config: cfg, + device: vs.Device(), + } + +} + +// Implement RNN interface for LSTM: +// ================================= + +func (l LSTM) ZeroState(batchDim int64) (retVal State) { + var numDirections int64 = 1 + if l.config.Bidirectional { + numDirections = 2 + } + + layerDim := l.config.NumLayers * numDirections + shape := []int64{layerDim, batchDim, l.hiddenDim} + zeros := ts.MustZeros(shape, gotch.Float.CInt(), l.device.CInt()) + + return LSTMState{ + Tensor1: zeros.MustShallowClone(), + Tensor2: zeros.MustShallowClone(), + } +} + +func (l LSTM) Step(input ts.Tensor, inState State) (retVal State) { + ip := input.MustUnsqueeze(1, false) + + _, state := l.SeqInit(ip, inState.(LSTMState)) + + return state +} + +func (l LSTM) Seq(input ts.Tensor) (ts.Tensor, State) { + return defaultSeq(l, input) +} + +func (l LSTM) SeqInit(input ts.Tensor, inState State) (ts.Tensor, State) { + + output, h, c := input.MustLSTM([]ts.Tensor{inState.(LSTMState).Tensor1, inState.(LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst) + + return output, LSTMState{ + Tensor1: h, + Tensor2: c, + } +} + +// GRUState is a GRU state. It contains a single tensor. +type GRUState struct { + Tensor ts.Tensor +} + +func (gs GRUState) Value() ts.Tensor { + return gs.Tensor +} + +// A Gated Recurrent Unit (GRU) layer. 
+//
+// https://en.wikipedia.org/wiki/Gated_recurrent_unit
+type GRU struct {
+	flatWeights []ts.Tensor
+	hiddenDim   int64
+	config      RNNConfig
+	device      gotch.Device
+}
+
+// NewGRU creates a new GRU layer
+func NewGRU(vs *Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
+	var numDirections int64 = 1
+	if cfg.Bidirectional {
+		numDirections = 2
+	}
+
+	gateDim := 3 * hiddenDim
+	flatWeights := make([]ts.Tensor, 0)
+
+	for i := 0; i < int(cfg.NumLayers); i++ {
+		for n := 0; n < int(numDirections); n++ {
+			if i != 0 {
+				inDim = hiddenDim * numDirections
+			}
+
+			wIh := vs.KaimingUniform("w_ih", []int64{gateDim, inDim})
+			wHh := vs.KaimingUniform("w_hh", []int64{gateDim, hiddenDim})
+			bIh := vs.Zeros("b_ih", []int64{gateDim})
+			bHh := vs.Zeros("b_hh", []int64{gateDim})
+
+			flatWeights = append(flatWeights, wIh, wHh, bIh, bHh)
+		}
+	}
+
+	return GRU{
+		flatWeights: flatWeights,
+		hiddenDim:   hiddenDim,
+		config:      cfg,
+		device:      vs.Device(),
+	}
+}
+
+// Implement RNN interface for GRU:
+// ================================
+
+func (g GRU) ZeroState(batchDim int64) (retVal State) {
+	var numDirections int64 = 1
+	if g.config.Bidirectional {
+		numDirections = 2
+	}
+
+	layerDim := g.config.NumLayers * numDirections
+	shape := []int64{layerDim, batchDim, g.hiddenDim}
+
+	return GRUState{Tensor: ts.MustZeros(shape, gotch.Float.CInt(), g.device.CInt())}
+}
+
+func (g GRU) Step(input ts.Tensor, inState State) (retVal State) {
+	ip := input.MustUnsqueeze(1, false)
+
+	_, state := g.SeqInit(ip, inState.(GRUState))
+
+	return state
+}
+
+func (g GRU) Seq(input ts.Tensor) (ts.Tensor, State) {
+	return defaultSeq(g, input)
+}
+
+func (g GRU) SeqInit(input ts.Tensor, inState State) (ts.Tensor, State) {
+
+	output, h := input.MustGRU(inState.(GRUState).Tensor, g.flatWeights, g.config.HasBiases, g.config.NumLayers, g.config.Dropout, g.config.Train, g.config.Bidirectional, g.config.BatchFirst)
+
+	return output, GRUState{Tensor: h}
+}
diff --git a/tensor/tensor-generated-sample.go b/tensor/tensor-generated-sample.go
index d5d0c1d..120e6b4 100644
--- a/tensor/tensor-generated-sample.go
+++ b/tensor/tensor-generated-sample.go
@@ -316,6 +316,15 @@ func (ts Tensor) Unsqueeze(dim int64, del bool) (retVal Tensor, err error) {
 	return retVal, nil
 }
 
+func (ts Tensor) MustUnsqueeze(dim int64, del bool) (retVal Tensor) {
+	retVal, err := ts.Unsqueeze(dim, del)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return retVal
+}
+
 // Select creates a new tensor from current tensor given dim and index. 
func (ts Tensor) Select(dim int64, index int64, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) @@ -1176,3 +1185,182 @@ func (ts Tensor) Dropout_(p float64, train bool) { log.Fatal(err) } } + +func ConvTranspose1D(input, weight, bias Tensor, stride, padding, outputPadding, dilation []int64, groups int64) (retVal Tensor, err error) { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + + lib.AtgConvTranspose1d(ptr, input.ctensor, weight.ctensor, bias.ctensor, stride, len(stride), padding, len(padding), outputPadding, len(outputPadding), dilation, len(dilation), groups) + + err = TorchErr() + if err != nil { + return retVal, err + } + + retVal = Tensor{ctensor: *ptr} + + return retVal, nil +} + +func MustConvTranspose1D(input, weight, bias Tensor, stride, padding, outputPadding, dilation []int64, groups int64) (retVal Tensor) { + retVal, err := ConvTranspose1D(input, weight, bias, stride, padding, outputPadding, dilation, groups) + + if err != nil { + log.Fatal(err) + } + + return retVal +} + +func ConvTranspose2D(input, weight, bias Tensor, stride, padding, outputPadding, dilation []int64, groups int64) (retVal Tensor, err error) { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + + lib.AtgConvTranspose2d(ptr, input.ctensor, weight.ctensor, bias.ctensor, stride, len(stride), padding, len(padding), outputPadding, len(outputPadding), dilation, len(dilation), groups) + + err = TorchErr() + if err != nil { + return retVal, err + } + + retVal = Tensor{ctensor: *ptr} + + return retVal, nil +} + +func MustConvTranspose2D(input, weight, bias Tensor, stride, padding, outputPadding, dilation []int64, groups int64) (retVal Tensor) { + retVal, err := ConvTranspose2D(input, weight, bias, stride, padding, outputPadding, dilation, groups) + + if err != nil { + log.Fatal(err) + } + + return retVal +} + +func ConvTranspose3D(input, weight, bias Tensor, stride, padding, outputPadding, dilation []int64, groups int64) (retVal Tensor, err error) { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + + lib.AtgConvTranspose3d(ptr, input.ctensor, weight.ctensor, bias.ctensor, stride, len(stride), padding, len(padding), outputPadding, len(outputPadding), dilation, len(dilation), groups) + + err = TorchErr() + if err != nil { + return retVal, err + } + + retVal = Tensor{ctensor: *ptr} + + return retVal, nil +} + +func MustConvTranspose3D(input, weight, bias Tensor, stride, padding, outputPadding, dilation []int64, groups int64) (retVal Tensor) { + retVal, err := ConvTranspose3D(input, weight, bias, stride, padding, outputPadding, dilation, groups) + + if err != nil { + log.Fatal(err) + } + + return retVal +} + +func (ts Tensor) LSTM(hxData []Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c Tensor, err error) { + + // NOTE: atg_lstm will return an array of 3 Ctensors + ts1Ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + ts2Ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + ts3Ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + var ctensorsPtr []*lib.Ctensor + ctensorsPtr = append(ctensorsPtr, ts1Ptr, ts2Ptr, ts3Ptr) + + var chxData []lib.Ctensor + for _, t := range hxData { + chxData = append(chxData, t.ctensor) + } + + var cparamsData []lib.Ctensor + for _, t := range paramsData { + cparamsData = append(cparamsData, t.ctensor) + } + + chasBiases := 0 + if hasBiases { + chasBiases = 1 + } + ctrain := 0 + if train { + ctrain = 1 + } + cbidirectional := 0 + if 
bidirectional {
+		cbidirectional = 1
+	}
+	cbatchFirst := 0
+	if batchFirst {
+		cbatchFirst = 1
+	}
+
+	lib.AtgLstm(ctensorsPtr, ts.ctensor, chxData, len(hxData), cparamsData, len(paramsData), chasBiases, numLayers, dropout, ctrain, cbidirectional, cbatchFirst)
+	err = TorchErr()
+	if err != nil {
+		return output, h, c, err
+	}
+
+	return Tensor{ctensor: *ts1Ptr}, Tensor{ctensor: *ts2Ptr}, Tensor{ctensor: *ts3Ptr}, nil
+
+}
+
+func (ts Tensor) MustLSTM(hxData []Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c Tensor) {
+	output, h, c, err := ts.LSTM(hxData, paramsData, hasBiases, numLayers, dropout, train, bidirectional, batchFirst)
+
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return output, h, c
+}
+
+func (ts Tensor) GRU(hx Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h Tensor, err error) {
+
+	// NOTE: atg_gru returns an array of 2 Ctensors
+	ts1Ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
+	ts2Ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
+	var ctensorsPtr []*lib.Ctensor
+	ctensorsPtr = append(ctensorsPtr, ts1Ptr, ts2Ptr)
+
+	var cparamsData []lib.Ctensor
+	for _, t := range paramsData {
+		cparamsData = append(cparamsData, t.ctensor)
+	}
+
+	chasBiases := 0
+	if hasBiases {
+		chasBiases = 1
+	}
+	ctrain := 0
+	if train {
+		ctrain = 1
+	}
+	cbidirectional := 0
+	if bidirectional {
+		cbidirectional = 1
+	}
+	cbatchFirst := 0
+	if batchFirst {
+		cbatchFirst = 1
+	}
+
+	lib.AtgGru(ctensorsPtr, ts.ctensor, hx.ctensor, cparamsData, len(paramsData), chasBiases, numLayers, dropout, ctrain, cbidirectional, cbatchFirst)
+	err = TorchErr()
+	if err != nil {
+		return output, h, err
+	}
+
+	return Tensor{ctensor: *ts1Ptr}, Tensor{ctensor: *ts2Ptr}, nil
+}
+
+func (ts Tensor) MustGRU(hx Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h Tensor) {
+	output, h, err := ts.GRU(hx, paramsData, hasBiases, numLayers, dropout, train, bidirectional, batchFirst)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return output, h
+}
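
Usage sketch (illustrative only, not part of the patch): a minimal program exercising the layers added above. It assumes the usual gotch setup seen in the repo's MNIST examples — nn.NewVarStore(gotch.CPU) and vs.Root() yielding a *nn.Path — and the dimensions, tensor shapes, and sizes in the comments are made up for illustration; they follow from the layer definitions in this patch rather than from a verified run.

package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	ts "github.com/sugarme/gotch/tensor"
)

func main() {
	// Variable store on CPU, mirroring the existing MNIST examples (assumed API).
	vs := nn.NewVarStore(gotch.CPU)
	path := vs.Root()

	// LSTM added by this patch: inDim=10, hiddenDim=32, default config
	// (1 layer, batch-first, no dropout, unidirectional).
	lstm := nn.NewLSTM(path, 10, 32, nn.DefaultRNNConfig())

	// Dummy input of shape [batch, seq_len, features] = [4, 7, 10].
	xs := ts.MustZeros([]int64{4, 7, 10}, gotch.Float.CInt(), gotch.CPU.CInt())

	// Seq builds the zero state internally and runs the whole sequence.
	out, state := lstm.Seq(xs)
	fmt.Println(out.MustSize())                      // expect [4 7 32]
	fmt.Println(state.(nn.LSTMState).H().MustSize()) // expect [1 4 32]

	// Transposed 1D convolution: kernel sizes are now passed as a slice.
	tcfg := nn.DefaultConvTranspose1DConfig()
	tconv := nn.NewConvTranspose1D(path, 16, 16, []int64{3}, tcfg)
	cs := ts.MustZeros([]int64{4, 16, 20}, gotch.Float.CInt(), gotch.CPU.CInt())
	fmt.Println(tconv.Forward(cs).MustSize()) // stride 1, padding 0, kernel 3 -> expect [4 16 22]
}

For stepwise processing, a ZeroState call followed by repeated Step calls would replace the single Seq call.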