feat(nn/rnn): added rnn.go
feat(nn/conv-transpose): added conv-transpose.go

sugarme 2020-06-24 18:14:34 +10:00
parent 71fb5ae79b
commit 9817f7393a
6 changed files with 577 additions and 60 deletions

View File

@@ -13,7 +13,7 @@ const (
Label int64 = 10
MnistDir string = "../../data/mnist"
-epochs = 500
+epochs = 200
)
func runLinear() {

View File

@@ -16,9 +16,9 @@ const (
LabelNN int64 = 10
MnistDirNN string = "../../data/mnist"
-epochsNN = 500
+epochsNN = 200
-LrNN = 1e-2
+LrNN = 1e-3
)
var l nn.Linear

View File

@@ -387,3 +387,80 @@ func AtgDropout_(ptr *Ctensor, self Ctensor, p float64, train int) {
C.atg_dropout_(ptr, self, cp, ctrain)
}
// void atg_conv_transpose1d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *output_padding_data, int output_padding_len, int64_t groups, int64_t *dilation_data, int dilation_len);
func AtgConvTranspose1d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, outputPaddingData []int64, outputPaddingLen int, dilationData []int64, dilationLen int, groups int64) {
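// NOTE: here and in the 2d/3d variants below, the Go parameter list puts
// groups last, while the C call follows the C signature's order
// (groups comes before the dilation pair).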
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen))
cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
coutputPaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&outputPaddingData[0]))
coutputPaddingLen := *(*C.int)(unsafe.Pointer(&outputPaddingLen))
cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))
cdilationLen := *(*C.int)(unsafe.Pointer(&dilationLen))
cgroups := *(*C.int64_t)(unsafe.Pointer(&groups))
C.atg_conv_transpose1d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, coutputPaddingDataPtr, coutputPaddingLen, cgroups, cdilationDataPtr, cdilationLen)
}
// void atg_conv_transpose2d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *output_padding_data, int output_padding_len, int64_t groups, int64_t *dilation_data, int dilation_len);
func AtgConvTranspose2d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, outputPaddingData []int64, outputPaddingLen int, dilationData []int64, dilationLen int, groups int64) {
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen))
cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
coutputPaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&outputPaddingData[0]))
coutputPaddingLen := *(*C.int)(unsafe.Pointer(&outputPaddingLen))
cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))
cdilationLen := *(*C.int)(unsafe.Pointer(&dilationLen))
cgroups := *(*C.int64_t)(unsafe.Pointer(&groups))
C.atg_conv_transpose2d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, coutputPaddingDataPtr, coutputPaddingLen, cgroups, cdilationDataPtr, cdilationLen)
}
// void atg_conv_transpose3d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *output_padding_data, int output_padding_len, int64_t groups, int64_t *dilation_data, int dilation_len);
func AtgConvTranspose3d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, outputPaddingData []int64, outputPaddingLen int, dilationData []int64, dilationLen int, groups int64) {
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen))
cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
coutputPaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&outputPaddingData[0]))
coutputPaddingLen := *(*C.int)(unsafe.Pointer(&outputPaddingLen))
cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))
cdilationLen := *(*C.int)(unsafe.Pointer(&dilationLen))
cgroups := *(*C.int64_t)(unsafe.Pointer(&groups))
C.atg_conv_transpose3d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, coutputPaddingDataPtr, coutputPaddingLen, cgroups, cdilationDataPtr, cdilationLen)
}
// void atg_lstm(tensor *, tensor input, tensor *hx_data, int hx_len, tensor *params_data, int params_len, int has_biases, int64_t num_layers, double dropout, int train, int bidirectional, int batch_first);
func AtgLstm(ctensorsPtr []*Ctensor, input Ctensor, hxData []Ctensor, hxLen int, paramsData []Ctensor, paramsLen int, hasBiases int, numLayers int64, dropout float64, train int, bidirectional int, batchFirst int) {
chxDataPtr := (*Ctensor)(unsafe.Pointer(&hxData[0]))
chxLen := *(*C.int)(unsafe.Pointer(&hxLen))
cparamsDataPtr := (*Ctensor)(unsafe.Pointer(&paramsData[0]))
cparamsLen := *(*C.int)(unsafe.Pointer(&paramsLen))
chasBiases := *(*C.int)(unsafe.Pointer(&hasBiases))
cnumLayers := *(*C.int64_t)(unsafe.Pointer(&numLayers))
cdropout := *(*C.double)(unsafe.Pointer(&dropout))
ctrain := *(*C.int)(unsafe.Pointer(&train))
cbidirectional := *(*C.int)(unsafe.Pointer(&bidirectional))
cbatchFirst := *(*C.int)(unsafe.Pointer(&batchFirst))
C.atg_lstm(ctensorsPtr[0], input, chxDataPtr, chxLen, cparamsDataPtr, cparamsLen, chasBiases, cnumLayers, cdropout, ctrain, cbidirectional, cbatchFirst)
}
// void atg_gru(tensor *, tensor input, tensor hx, tensor *params_data, int params_len, int has_biases, int64_t num_layers, double dropout, int train, int bidirectional, int batch_first);
func AtgGru(ctensorsPtr []*Ctensor, input Ctensor, hx Ctensor, paramsData []Ctensor, paramsLen int, hasBiases int, numLayers int64, dropout float64, train int, bidirectional int, batchFirst int) {
cparamsDataPtr := (*Ctensor)(unsafe.Pointer(&paramsData[0]))
cparamsLen := *(*C.int)(unsafe.Pointer(&paramsLen))
chasBiases := *(*C.int)(unsafe.Pointer(&hasBiases))
cnumLayers := *(*C.int64_t)(unsafe.Pointer(&numLayers))
cdropout := *(*C.double)(unsafe.Pointer(&dropout))
ctrain := *(*C.int)(unsafe.Pointer(&train))
cbidirectional := *(*C.int)(unsafe.Pointer(&bidirectional))
cbatchFirst := *(*C.int)(unsafe.Pointer(&batchFirst))
C.atg_gru(ctensorsPtr[0], input, hx, cparamsDataPtr, cparamsLen, chasBiases, cnumLayers, cdropout, ctrain, cbidirectional, cbatchFirst)
}
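
For readers new to these wrappers: every Go slice crosses into C as a pointer to its first element, and every Go int is reinterpreted in place as a C int. A minimal standalone sketch of that pattern (plain Go, no cgo, variable names hypothetical):

package main

import (
	"fmt"
	"unsafe"
)

func main() {
	// Analogue of cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0])).
	// The slice must be non-empty, otherwise &strideData[0] panics.
	strideData := []int64{2, 2}
	cstrideDataPtr := (*int64)(unsafe.Pointer(&strideData[0]))

	// Analogue of cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen)).
	// This reads the low 4 bytes of a Go int, which matches a C int
	// on little-endian platforms.
	strideLen := len(strideData)
	cstrideLen := *(*int32)(unsafe.Pointer(&strideLen))

	fmt.Println(*cstrideDataPtr, cstrideLen) // 2 2
}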

View File

@@ -3,62 +3,55 @@ package nn
// Transposed convolution layers (one, two and three dimensions).
import (
"log"
ts "github.com/sugarme/gotch/tensor"
)
type ConvTranspose1DConfig struct {
-Stride []int64
-Padding []int64
-Dilation []int64
-Groups int64
-Bias bool
-WsInit Init
-BsInit Init
+Stride []int64
+Padding []int64
+OutputPadding []int64
+Dilation []int64
+Groups int64
+Bias bool
+WsInit Init
+BsInit Init
}
type ConvTranspose2DConfig struct {
-Stride []int64
-Padding []int64
-Dilation []int64
-Groups int64
-Bias bool
-WsInit Init
-BsInit Init
+Stride []int64
+Padding []int64
+OutputPadding []int64
+Dilation []int64
+Groups int64
+Bias bool
+WsInit Init
+BsInit Init
}
type ConvTranspose3DConfig struct {
-Stride []int64
-Padding []int64
-Dilation []int64
-Groups int64
-Bias bool
-WsInit Init
-BsInit Init
+Stride []int64
+Padding []int64
+OutputPadding []int64
+Dilation []int64
+Groups int64
+Bias bool
+WsInit Init
+BsInit Init
}
// DefaultConvTranspose1DConfig creates a default 1D ConvTranspose config
func DefaultConvTranspose1DConfig() ConvTranspose1DConfig {
return ConvTranspose1DConfig{
Stride: []int64{1},
Padding: []int64{0},
OutputPadding: []int64{0},
Dilation: []int64{1},
Groups: 1,
Bias: true,
WsInit: NewKaimingUniformInit(),
BsInit: NewConstInit(float64(0.0)),
}
}
// DefaultConvTranspose2DConfig creates a default 2D ConvTranspose config
func DefaultConvTranspose2DConfig() ConvTranspose2DConfig {
return ConvTranspose2DConfig{
-Stride: []int64{1, 1},
-Padding: []int64{0, 0},
-Dilation: []int64{1, 1},
-Groups: 1,
-Bias: true,
-WsInit: NewKaimingUniformInit(),
-BsInit: NewConstInit(float64(0.0)),
+Stride: []int64{1, 1},
+Padding: []int64{0, 0},
+OutputPadding: []int64{0, 0},
+Dilation: []int64{1, 1},
+Groups: 1,
+Bias: true,
+WsInit: NewKaimingUniformInit(),
+BsInit: NewConstInit(float64(0.0)),
}
}
@@ -68,14 +61,18 @@ type ConvTranspose1D struct {
Config ConvTranspose1DConfig
}
-func NewConvTranspose1D(vs *Path, inDim, outDim, k int64, cfg ConvTranspose1DConfig) ConvTranspose1D {
+func NewConvTranspose1D(vs *Path, inDim, outDim int64, ksizes []int64, cfg ConvTranspose1DConfig) ConvTranspose1D {
if len(ksizes) != 1 {
log.Fatalf("NewConvTranspose1D method call: Kernel size should be 1. Got %v\n", len(ksizes))
}
var conv ConvTranspose1D
conv.Config = cfg
if cfg.Bias {
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
}
// Transposed convolution weight layout is [inDim, outDim/groups, kernel...]
weightSize := []int64{inDim, int64(outDim / cfg.Groups)}
-weightSize = append(weightSize, k)
+weightSize = append(weightSize, ksizes...)
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
return conv
@@ -87,14 +84,18 @@ type ConvTranspose2D struct {
Config ConvTranspose2DConfig
}
-func NewConvTranspose2D(vs *Path, inDim, outDim int64, k int64, cfg ConvTranspose2DConfig) ConvTranspose2D {
+func NewConvTranspose2D(vs *Path, inDim, outDim int64, ksizes []int64, cfg ConvTranspose2DConfig) ConvTranspose2D {
if len(ksizes) != 2 {
log.Fatalf("NewConvTranspose2D method call: Kernel size should be 2. Got %v\n", len(ksizes))
}
var conv ConvTranspose2D
conv.Config = cfg
if cfg.Bias {
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
}
// Transposed convolution weight layout is [inDim, outDim/groups, kernel...]
weightSize := []int64{inDim, int64(outDim / cfg.Groups)}
-weightSize = append(weightSize, k, k)
+weightSize = append(weightSize, ksizes...)
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
return conv
@@ -106,14 +107,17 @@ type ConvTranspose3D struct {
Config ConvTranspose3DConfig
}
-func NewConvTranspose3D(vs *Path, inDim, outDim, k int64, cfg ConvTranspose3DConfig) ConvTranspose3D {
+func NewConvTranspose3D(vs *Path, inDim, outDim int64, ksizes []int64, cfg ConvTranspose3DConfig) ConvTranspose3D {
if len(ksizes) != 3 {
log.Fatalf("NewConvTranspose3D method call: Kernel size should be 3. Got %v\n", len(ksizes))
}
var conv ConvTranspose3D
conv.Config = cfg
if cfg.Bias {
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
}
// Transposed convolution weight layout is [inDim, outDim/groups, kernel...]
weightSize := []int64{inDim, int64(outDim / cfg.Groups)}
-weightSize = append(weightSize, k, k, k)
+weightSize = append(weightSize, ksizes...)
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
return conv
@@ -122,13 +126,13 @@ func NewConvTranspose3D(vs *Path, inDim, outDim, k int64, cfg ConvTranspose3DCon
// Implement Module for ConvTranspose1D, ConvTranspose2D, ConvTranspose3D:
// ============================================
-/* func (c ConvTranspose1D) Forward(xs ts.Tensor) ts.Tensor {
- * return ts.MustConvTranspose1D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
- * }
- *
- * func (c ConvTranspose2D) Forward(xs ts.Tensor) ts.Tensor {
- * return ts.MustConvTranspose2D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
- * }
- * func (c ConvTranspose3D) Forward(xs ts.Tensor) ts.Tensor {
- * return ts.MustConvTranspose3D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
- * } */
+func (c ConvTranspose1D) Forward(xs ts.Tensor) ts.Tensor {
+return ts.MustConvTranspose1D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Dilation, c.Config.Groups)
+}
+func (c ConvTranspose2D) Forward(xs ts.Tensor) ts.Tensor {
+return ts.MustConvTranspose2D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Dilation, c.Config.Groups)
+}
+func (c ConvTranspose3D) Forward(xs ts.Tensor) ts.Tensor {
+return ts.MustConvTranspose3D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Dilation, c.Config.Groups)
+}
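
For context, a hedged usage sketch of the new slice-based kernel API (not part of this commit; it assumes nn.NewVarStore and VarStore.Root() from gotch). Each spatial length L maps to (L-1)*stride - 2*padding + dilation*(k-1) + outputPadding + 1:

package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	ts "github.com/sugarme/gotch/tensor"
)

func main() {
	vs := nn.NewVarStore(gotch.CPU)
	cfg := nn.DefaultConvTranspose2DConfig()
	// 16 input channels, 8 output channels, 3x3 kernel
	tconv := nn.NewConvTranspose2D(vs.Root(), 16, 8, []int64{3, 3}, cfg)

	x := ts.MustZeros([]int64{4, 16, 28, 28}, gotch.Float.CInt(), gotch.CPU.CInt())
	y := tconv.Forward(x)
	// (28-1)*1 - 2*0 + 1*(3-1) + 0 + 1 = 30 along each spatial dim
	fmt.Println(y.MustSize()) // [4 8 30 30]
}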

nn/rnn.go (new file, 248 lines)
View File

@@ -0,0 +1,248 @@
package nn
import (
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
)
type State interface{}
type RNN interface {
// A zero state from which the recurrent network is usually initialized.
ZeroState(batchDim int64) State
// Applies a single step of the recurrent network.
//
// The input should have dimensions [batch_size, features].
Step(input ts.Tensor, inState State) State
// Applies multiple steps of the recurrent network.
//
// The input should have dimensions [batch_size, seq_len, features].
// The initial state is the result of applying zero_state.
Seq(input ts.Tensor) (ts.Tensor, State)
// Applies multiple steps of the recurrent network.
//
// The input should have dimensions [batch_size, seq_len, features].
SeqInit(input ts.Tensor, inState State) (ts.Tensor, State)
}
func defaultSeq(self interface{}, input ts.Tensor) (ts.Tensor, State) {
batchDim := input.MustSize()[0]
inState := self.(RNN).ZeroState(batchDim)
return self.(RNN).SeqInit(input, inState)
}
// The state for an LSTM network; it contains two tensors.
type LSTMState struct {
Tensor1 ts.Tensor
Tensor2 ts.Tensor
}
// The hidden state vector, which is also the output of the LSTM.
func (ls LSTMState) H() (retVal ts.Tensor) {
return ls.Tensor1.MustShallowClone()
}
// The cell state vector.
func (ls LSTMState) C() (retVal ts.Tensor) {
return ls.Tensor2.MustShallowClone()
}
// RNNConfig is the configuration for the GRU and LSTM layers; both share the same config.
type RNNConfig struct {
HasBiases bool
NumLayers int64
Dropout float64
Train bool
Bidirectional bool
BatchFirst bool
}
// DefaultRNNConfig creates a default RNN configuration
func DefaultRNNConfig() RNNConfig {
return RNNConfig{
HasBiases: true,
NumLayers: 1,
Dropout: float64(0.0),
Train: true,
Bidirectional: false,
BatchFirst: true,
}
}
// A Long Short-Term Memory (LSTM) layer.
//
// https://en.wikipedia.org/wiki/Long_short-term_memory
type LSTM struct {
flatWeights []ts.Tensor
hiddenDim int64
config RNNConfig
device gotch.Device
}
// NewLSTM creates an LSTM layer.
func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal LSTM) {
var numDirections int64 = 1
if cfg.Bidirectional {
numDirections = 2
}
gateDim := 4 * hiddenDim
flatWeights := make([]ts.Tensor, 0)
for i := 0; i < int(cfg.NumLayers); i++ {
for n := 0; n < int(numDirections); n++ {
if i != 0 {
inDim = hiddenDim * numDirections
}
wIh := vs.KaimingUniform("w_ih", []int64{gateDim, inDim})
wHh := vs.KaimingUniform("w_hh", []int64{gateDim, hiddenDim})
bIh := vs.Zeros("b_ih", []int64{gateDim})
bHh := vs.Zeros("b_hh", []int64{gateDim})
flatWeights = append(flatWeights, wIh, wHh, bIh, bHh)
}
}
return LSTM{
flatWeights: flatWeights,
hiddenDim: hiddenDim,
config: cfg,
device: vs.Device(),
}
}
// Implement RNN interface for LSTM:
// =================================
func (l LSTM) ZeroState(batchDim int64) (retVal State) {
var numDirections int64 = 1
if l.config.Bidirectional {
numDirections = 2
}
layerDim := l.config.NumLayers * numDirections
shape := []int64{layerDim, batchDim, l.hiddenDim}
zeros := ts.MustZeros(shape, gotch.Float.CInt(), l.device.CInt())
return LSTMState{
Tensor1: zeros.MustShallowClone(),
Tensor2: zeros.MustShallowClone(),
}
}
func (l LSTM) Step(input ts.Tensor, inState State) (retVal State) {
ip := input.MustUnsqueeze(1, false)
_, state := l.SeqInit(ip, inState.(LSTMState))
return state
}
func (l LSTM) Seq(input ts.Tensor) (ts.Tensor, State) {
return defaultSeq(l, input)
}
func (l LSTM) SeqInit(input ts.Tensor, inState State) (ts.Tensor, State) {
output, h, c := input.MustLSTM([]ts.Tensor{inState.(LSTMState).Tensor1, inState.(LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst)
return output, LSTMState{
Tensor1: h,
Tensor2: c,
}
}
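
For orientation, a hedged usage sketch of the layer above (not part of this commit; it assumes nn.NewVarStore and VarStore.Root() from gotch). With BatchFirst true, input is [batch, seq, features]:

package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	ts "github.com/sugarme/gotch/tensor"
)

func main() {
	vs := nn.NewVarStore(gotch.CPU)
	lstm := nn.NewLSTM(vs.Root(), 10, 32, nn.DefaultRNNConfig())

	input := ts.MustZeros([]int64{8, 20, 10}, gotch.Float.CInt(), gotch.CPU.CInt())
	output, state := lstm.Seq(input)

	fmt.Println(output.MustSize()) // [8 20 32]: [batch, seq, numDirections*hiddenDim]
	h := state.(nn.LSTMState).H()
	fmt.Println(h.MustSize()) // [1 8 32]: [numLayers*numDirections, batch, hiddenDim]
}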
// GRUState is a GRU state. It contains a single tensor.
type GRUState struct {
Tensor ts.Tensor
}
func (gs GRUState) Value() ts.Tensor {
return gs.Tensor
}
// A Gated Recurrent Unit (GRU) layer.
//
// https://en.wikipedia.org/wiki/Gated_recurrent_unit
type GRU struct {
flatWeights []ts.Tensor
hiddenDim int64
config RNNConfig
device gotch.Device
}
// NewGRU creates a new GRU layer
func NewGRU(vs *Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
var numDirections int64 = 1
if cfg.Bidirectional {
numDirections = 2
}
gateDim := 3 * hiddenDim
flatWeights := make([]ts.Tensor, 0)
for i := 0; i < int(cfg.NumLayers); i++ {
for n := 0; n < int(numDirections); n++ {
if i != 0 {
inDim = hiddenDim * numDirections
}
wIh := vs.KaimingUniform("w_ih", []int64{gateDim, inDim})
wHh := vs.KaimingUniform("w_hh", []int64{gateDim, hiddenDim})
bIh := vs.Zeros("b_ih", []int64{gateDim})
bHh := vs.Zeros("b_hh", []int64{gateDim})
flatWeights = append(flatWeights, wIh, wHh, bIh, bHh)
}
}
return GRU{
flatWeights: flatWeights,
hiddenDim: hiddenDim,
config: cfg,
device: vs.Device(),
}
}
// Implement RNN interface for GRU:
// ================================
func (g GRU) ZeroState(batchDim int64) (retVal State) {
var numDirections int64 = 1
if g.config.Bidirectional {
numDirections = 2
}
layerDim := g.config.NumLayers * numDirections
shape := []int64{layerDim, batchDim, g.hiddenDim}
zeros := ts.MustZeros(shape, gotch.Float.CInt(), g.device.CInt())
return GRUState{Tensor: zeros}
}
func (g GRU) Step(input ts.Tensor, inState State) (retVal State) {
ip := input.MustUnsqueeze(1, false)
_, state := g.SeqInit(ip, inState.(GRUState))
return state
}
func (g GRU) Seq(input ts.Tensor) (ts.Tensor, State) {
return defaultSeq(g, input)
}
func (g GRU) SeqInit(input ts.Tensor, inState State) (ts.Tensor, State) {
output, h := input.MustGRU(inState.(GRUState).Tensor, g.flatWeights, g.config.HasBiases, g.config.NumLayers, g.config.Dropout, g.config.Train, g.config.Bidirectional, g.config.BatchFirst)
return output, GRUState{Tensor: h}
}
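
And the GRU counterpart, a minimal hedged sketch under the same assumptions:

package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	ts "github.com/sugarme/gotch/tensor"
)

func main() {
	vs := nn.NewVarStore(gotch.CPU)
	gru := nn.NewGRU(vs.Root(), 10, 32, nn.DefaultRNNConfig())

	input := ts.MustZeros([]int64{8, 20, 10}, gotch.Float.CInt(), gotch.CPU.CInt())
	output, state := gru.Seq(input)

	fmt.Println(output.MustSize())                      // [8 20 32]
	fmt.Println(state.(nn.GRUState).Value().MustSize()) // [1 8 32]
}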

View File

@@ -316,6 +316,15 @@ func (ts Tensor) Unsqueeze(dim int64, del bool) (retVal Tensor, err error) {
return retVal, nil
}
func (ts Tensor) MustUnsqueeze(dim int64, del bool) (retVal Tensor) {
retVal, err := ts.Unsqueeze(dim, del)
if err != nil {
log.Fatal(err)
}
return retVal
}
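
As a quick illustration, Unsqueeze inserts a size-1 dimension at the given index; this is how LSTM.Step in nn/rnn.go above lifts a [batch, features] input to [batch, 1, features]. A hedged fragment, assuming the usual gotch imports:

x := ts.MustZeros([]int64{8, 10}, gotch.Float.CInt(), gotch.CPU.CInt())
y := x.MustUnsqueeze(1, false) // del=false: x remains valid afterwards
fmt.Println(y.MustSize())      // [8 1 10]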
// Select creates a new tensor from current tensor given dim and index.
func (ts Tensor) Select(dim int64, index int64, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
@@ -1176,3 +1185,182 @@ func (ts Tensor) Dropout_(p float64, train bool) {
log.Fatal(err)
}
}
func ConvTranspose1D(input, weight, bias Tensor, stride, padding, outputPadding, dilation []int64, groups int64) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtgConvTranspose1d(ptr, input.ctensor, weight.ctensor, bias.ctensor, stride, len(stride), padding, len(padding), outputPadding, len(outputPadding), dilation, len(dilation), groups)
err = TorchErr()
if err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
func MustConvTranspose1D(input, weight, bias Tensor, stride, padding, outputPadding, dilation []int64, groups int64) (retVal Tensor) {
retVal, err := ConvTranspose1D(input, weight, bias, stride, padding, outputPadding, dilation, groups)
if err != nil {
log.Fatal(err)
}
return retVal
}
func ConvTranspose2D(input, weight, bias Tensor, stride, padding, outputPadding, dilation []int64, groups int64) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtgConvTranspose2d(ptr, input.ctensor, weight.ctensor, bias.ctensor, stride, len(stride), padding, len(padding), outputPadding, len(outputPadding), dilation, len(dilation), groups)
err = TorchErr()
if err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
func MustConvTranspose2D(input, weight, bias Tensor, stride, padding, outputPadding, dilation []int64, groups int64) (retVal Tensor) {
retVal, err := ConvTranspose2D(input, weight, bias, stride, padding, outputPadding, dilation, groups)
if err != nil {
log.Fatal(err)
}
return retVal
}
func ConvTranspose3D(input, weight, bias Tensor, stride, padding, outputPadding, dilation []int64, groups int64) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtgConvTranspose3d(ptr, input.ctensor, weight.ctensor, bias.ctensor, stride, len(stride), padding, len(padding), outputPadding, len(outputPadding), dilation, len(dilation), groups)
err = TorchErr()
if err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
func MustConvTranspose3D(input, weight, bias Tensor, stride, padding, outputPadding, dilation []int64, groups int64) (retVal Tensor) {
retVal, err := ConvTranspose3D(input, weight, bias, stride, padding, outputPadding, dilation, groups)
if err != nil {
log.Fatal(err)
}
return retVal
}
func (ts Tensor) LSTM(hxData []Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c Tensor, err error) {
// NOTE: atg_lstm will return an array of 3 Ctensors
ts1Ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
ts2Ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
ts3Ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
var ctensorsPtr []*lib.Ctensor
ctensorsPtr = append(ctensorsPtr, ts1Ptr, ts2Ptr, ts3Ptr)
var chxData []lib.Ctensor
for _, t := range hxData {
chxData = append(chxData, t.ctensor)
}
var cparamsData []lib.Ctensor
for _, t := range paramsData {
cparamsData = append(cparamsData, t.ctensor)
}
chasBiases := 0
if hasBiases {
chasBiases = 1
}
ctrain := 0
if train {
ctrain = 1
}
cbidirectional := 0
if bidirectional {
cbidirectional = 1
}
cbatchFirst := 0
if batchFirst {
cbatchFirst = 1
}
lib.AtgLstm(ctensorsPtr, ts.ctensor, chxData, len(hxData), cparamsData, len(paramsData), chasBiases, numLayers, dropout, ctrain, cbidirectional, cbatchFirst)
err = TorchErr()
if err != nil {
return output, h, c, err
}
return Tensor{ctensor: *ts1Ptr}, Tensor{ctensor: *ts2Ptr}, Tensor{ctensor: *ts3Ptr}, nil
}
func (ts Tensor) MustLSTM(hxData []Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c Tensor) {
output, h, c, err := ts.LSTM(hxData, paramsData, hasBiases, numLayers, dropout, train, bidirectional, batchFirst)
if err != nil {
log.Fatal(err)
}
return output, h, c
}
func (ts Tensor) GRU(hx Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h Tensor, err error) {
// NOTE: atg_gru will return an array of 2 Ctensors
ts1Ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
ts2Ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
var ctensorsPtr []*lib.Ctensor
ctensorsPtr = append(ctensorsPtr, ts1Ptr, ts2Ptr)
var cparamsData []lib.Ctensor
for _, t := range paramsData {
cparamsData = append(cparamsData, t.ctensor)
}
chasBiases := 0
if hasBiases {
chasBiases = 1
}
ctrain := 0
if train {
ctrain = 1
}
cbidirectional := 0
if bidirectional {
cbidirectional = 1
}
cbatchFirst := 0
if batchFirst {
cbatchFirst = 1
}
lib.AtgGru(ctensorsPtr, ts.ctensor, hx.ctensor, cparamsData, len(paramsData), chasBiases, numLayers, dropout, ctrain, cbidirectional, cbatchFirst)
err = TorchErr()
if err != nil {
return output, h, err
}
return Tensor{ctensor: *ts1Ptr}, Tensor{ctensor: *ts2Ptr}, nil
}
func (ts Tensor) MustGRU(hx Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h Tensor) {
output, h, err := ts.GRU(hx, paramsData, hasBiases, numLayers, dropout, train, bidirectional, batchFirst)
if err != nil {
log.Fatal(err)
}
return output, h
}