feat(nn/conv): added conv

This commit is contained in:
sugarme 2020-06-22 18:49:41 +10:00
parent 5861f3c525
commit d480f969bb
4 changed files with 255 additions and 5 deletions

View File

@ -316,3 +316,42 @@ func AtgSub1(ptr *Ctensor, self Ctensor, other Cscalar) {
// void atg_sub_(tensor *, tensor self, tensor other);
// AtgSub_ performs an in-place tensor subtraction via the libtorch C API,
// writing the resulting tensor handle into *ptr.
func AtgSub_(ptr *Ctensor, self Ctensor, other Ctensor) {
	C.atg_sub_(ptr, self, other)
}
// void atg_conv1d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *dilation_data, int dilation_len, int64_t groups);
// AtgConv1d performs a 1D convolution via the libtorch C API, writing the
// resulting tensor handle into *ptr. The slice arguments must be non-empty;
// their data pointers are only read for the duration of the call.
func AtgConv1d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, dilationData []int64, dilationLen int, groups int64) {
	cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
	cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
	cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))

	// Convert the scalar arguments with ordinary value conversions.
	// The previous *(*C.int)(unsafe.Pointer(&n)) reinterpretation reads only
	// part of Go's (usually 64-bit) int as a 32-bit C.int, which is correct
	// only on little-endian targets; C.int(n) is both portable and clearer.
	C.atg_conv1d(ptr, input, weight, bias, cstrideDataPtr, C.int(strideLen), cpaddingDataPtr, C.int(paddingLen), cdilationDataPtr, C.int(dilationLen), C.int64_t(groups))
}
// void atg_conv2d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *dilation_data, int dilation_len, int64_t groups);
// AtgConv2d performs a 2D convolution via the libtorch C API, writing the
// resulting tensor handle into *ptr. The slice arguments must be non-empty;
// their data pointers are only read for the duration of the call.
func AtgConv2d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, dilationData []int64, dilationLen int, groups int64) {
	cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
	cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
	cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))

	// Plain value conversions replace the previous unsafe pointer
	// reinterpretation of Go int as C.int, which truncated incorrectly on
	// big-endian platforms (Go int is typically 64-bit, C.int is 32-bit).
	C.atg_conv2d(ptr, input, weight, bias, cstrideDataPtr, C.int(strideLen), cpaddingDataPtr, C.int(paddingLen), cdilationDataPtr, C.int(dilationLen), C.int64_t(groups))
}
// void atg_conv3d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *dilation_data, int dilation_len, int64_t groups);
// AtgConv3d performs a 3D convolution via the libtorch C API, writing the
// resulting tensor handle into *ptr. The slice arguments must be non-empty;
// their data pointers are only read for the duration of the call.
func AtgConv3d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, dilationData []int64, dilationLen int, groups int64) {
	cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
	cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
	cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))

	// Plain value conversions replace the previous unsafe pointer
	// reinterpretation of Go int as C.int, which truncated incorrectly on
	// big-endian platforms (Go int is typically 64-bit, C.int is 32-bit).
	C.atg_conv3d(ptr, input, weight, bias, cstrideDataPtr, C.int(strideLen), cpaddingDataPtr, C.int(paddingLen), cdilationDataPtr, C.int(dilationLen), C.int64_t(groups))
}

137
nn/conv.go Normal file
View File

@ -0,0 +1,137 @@
package nn
// N-dimensional convolution layers.
import (
ts "github.com/sugarme/gotch/tensor"
)
// Conv1DConfig holds the hyper-parameters of a 1D convolution layer.
type Conv1DConfig struct {
	Kval     int64   // kernel size
	Stride   []int64 // stride per spatial dimension (length 1 in DefaultConv1DConfig)
	Padding  []int64 // zero padding per spatial dimension
	Dilation []int64 // dilation per spatial dimension
	Groups   int64   // number of groups; input channels are divided by this when sizing the weight
	Bias     bool    // if true, a learnable bias variable is created
	WsInit   Init    // weight initializer
	BsInit   Init    // bias initializer
}
// Conv2DConfig holds the hyper-parameters of a 2D convolution layer.
type Conv2DConfig struct {
	Kval     int64   // kernel size (same value used for both spatial dimensions)
	Stride   []int64 // stride per spatial dimension (length 2 in DefaultConv2DConfig)
	Padding  []int64 // zero padding per spatial dimension
	Dilation []int64 // dilation per spatial dimension
	Groups   int64   // number of groups; input channels are divided by this when sizing the weight
	Bias     bool    // if true, a learnable bias variable is created
	WsInit   Init    // weight initializer
	BsInit   Init    // bias initializer
}
// Conv3DConfig holds the hyper-parameters of a 3D convolution layer.
type Conv3DConfig struct {
	Kval     int64   // kernel size (same value used for all three spatial dimensions)
	Stride   []int64 // stride per spatial dimension
	Padding  []int64 // zero padding per spatial dimension
	Dilation []int64 // dilation per spatial dimension
	Groups   int64   // number of groups; input channels are divided by this when sizing the weight
	Bias     bool    // if true, a learnable bias variable is created
	WsInit   Init    // weight initializer
	BsInit   Init    // bias initializer
}
// DefaultConv1DConfig creates a default 1D convolution configuration:
// unit stride and dilation, no padding, a single group, a bias term,
// Kaiming-uniform weight init and constant-zero bias init. Kval (the
// kernel size) is left at its zero value and should be set by the caller.
func DefaultConv1DConfig() Conv1DConfig {
	var cfg Conv1DConfig
	cfg.Stride = []int64{1}
	cfg.Padding = []int64{0}
	cfg.Dilation = []int64{1}
	cfg.Groups = 1
	cfg.Bias = true
	cfg.WsInit = NewKaimingUniformInit()
	cfg.BsInit = NewConstInit(0.0)
	return cfg
}
// DefaultConv2DConfig creates a default 2D convolution configuration:
// unit stride and dilation, no padding, a single group, a bias term,
// Kaiming-uniform weight init and constant-zero bias init. Kval (the
// kernel size) is left at its zero value and should be set by the caller.
func DefaultConv2DConfig() Conv2DConfig {
	var cfg Conv2DConfig
	cfg.Stride = []int64{1, 1}
	cfg.Padding = []int64{0, 0}
	cfg.Dilation = []int64{1, 1}
	cfg.Groups = 1
	cfg.Bias = true
	cfg.WsInit = NewKaimingUniformInit()
	cfg.BsInit = NewConstInit(0.0)
	return cfg
}
// Conv1D is a one-dimensional convolution layer.
type Conv1D struct {
	Ws     ts.Tensor    // weight tensor, shape [out, in/groups, k]
	Bs     ts.Tensor    // bias tensor; left as the zero value when Config.Bias is false
	Config Conv1DConfig // layer hyper-parameters
}
// NewConv1D creates a 1D convolution layer whose variables are registered
// under the given var-store path. inDim and outDim are the input/output
// channel counts; kernel size, stride, padding, dilation, groups and the
// initializers come from cfg. The bias variable is only created when
// cfg.Bias is true. cfg.Groups must be non-zero (it divides inDim).
func NewConv1D(vs *Path, inDim, outDim int64, cfg Conv1DConfig) Conv1D {
	var conv Conv1D
	conv.Config = cfg

	if cfg.Bias {
		conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
	}

	// Weight shape is [outDim, inDim/groups, k]; inDim is already int64 so
	// the previous int64(...) conversion was redundant.
	weightSize := []int64{outDim, inDim / cfg.Groups, cfg.Kval}
	conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)

	return conv
}
// Conv2D is a two-dimensional convolution layer.
type Conv2D struct {
	Ws     ts.Tensor    // weight tensor, shape [out, in/groups, k, k]
	Bs     ts.Tensor    // bias tensor; left as the zero value when Config.Bias is false
	Config Conv2DConfig // layer hyper-parameters
}
// NewConv2D creates a 2D convolution layer whose variables are registered
// under the given var-store path. inDim and outDim are the input/output
// channel counts; kernel size, stride, padding, dilation, groups and the
// initializers come from cfg. The bias variable is only created when
// cfg.Bias is true. cfg.Groups must be non-zero (it divides inDim).
func NewConv2D(vs *Path, inDim, outDim int64, cfg Conv2DConfig) Conv2D {
	var conv Conv2D
	conv.Config = cfg

	if cfg.Bias {
		conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
	}

	// Weight shape is [outDim, inDim/groups, k, k]; inDim is already int64
	// so the previous int64(...) conversion was redundant.
	weightSize := []int64{outDim, inDim / cfg.Groups, cfg.Kval, cfg.Kval}
	conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)

	return conv
}
// Conv3D is a three-dimensional convolution layer.
type Conv3D struct {
	Ws     ts.Tensor    // weight tensor, shape [out, in/groups, k, k, k]
	Bs     ts.Tensor    // bias tensor; left as the zero value when Config.Bias is false
	Config Conv3DConfig // layer hyper-parameters
}
// NewConv3D creates a 3D convolution layer whose variables are registered
// under the given var-store path. inDim and outDim are the input/output
// channel counts; kernel size, stride, padding, dilation, groups and the
// initializers come from cfg. The bias variable is only created when
// cfg.Bias is true. cfg.Groups must be non-zero (it divides inDim).
func NewConv3D(vs *Path, inDim, outDim int64, cfg Conv3DConfig) Conv3D {
	var conv Conv3D
	conv.Config = cfg

	if cfg.Bias {
		conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
	}

	// Weight shape is [outDim, inDim/groups, k, k, k]; inDim is already
	// int64 so the previous int64(...) conversion was redundant.
	weightSize := []int64{outDim, inDim / cfg.Groups, cfg.Kval, cfg.Kval, cfg.Kval}
	conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)

	return conv
}
// Implement Module for Conv1D, Conv2D, Conv3D:
// ============================================
// Forward applies the configured 1D convolution to xs, using the layer's
// weight and bias tensors. Errors from the underlying op are fatal
// (MustConv1D calls log.Fatal on failure).
func (c Conv1D) Forward(xs ts.Tensor) ts.Tensor {
	return ts.MustConv1D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
}
// Forward applies the configured 2D convolution to xs, using the layer's
// weight and bias tensors. Errors from the underlying op are fatal
// (MustConv2D calls log.Fatal on failure).
func (c Conv2D) Forward(xs ts.Tensor) ts.Tensor {
	return ts.MustConv2D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
}
// Forward applies the configured 3D convolution to xs, using the layer's
// weight and bias tensors. Errors from the underlying op are fatal
// (MustConv3D calls log.Fatal on failure).
func (c Conv3D) Forward(xs ts.Tensor) ts.Tensor {
	return ts.MustConv3D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
}

View File

@ -22,7 +22,7 @@ func TestOptimizer(t *testing.T) {
log.Fatal(err)
}
ys := xs.MustMul1(ts.FloatScalar(0.42)).MustAdd1(ts.FloatScalar(1.337))
ys := xs.MustMul1(ts.FloatScalar(0.42), true).MustAdd1(ts.FloatScalar(1.337), true)
vs := nn.NewVarStore(gotch.CPU)
@ -39,7 +39,7 @@ func TestOptimizer(t *testing.T) {
linear := nn.NewLinear(vs.Root(), 1, 1, cfg)
loss := xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt())
loss := xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt(), true)
initialLoss := loss.MustView([]int64{-1}).MustFloat64Value([]int64{0})
@ -50,13 +50,13 @@ func TestOptimizer(t *testing.T) {
}
for i := 0; i < 50; i++ {
loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt())
loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt(), true)
opt.BackwardStep(loss)
fmt.Printf("Loss: %.3f\n", loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}))
}
loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt())
loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt(), true)
finalLoss := loss.MustView([]int64{-1}).MustFloat64Value([]int64{0})
fmt.Printf("Final loss: %v\n", finalLoss)

View File

@ -1005,7 +1005,6 @@ func (ts Tensor) MustSub1(other Scalar, del bool) (retVal Tensor) {
func (ts Tensor) Sub_(other Tensor) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgSub_(ptr, ts.ctensor, other.ctensor)
err := TorchErr()
@ -1013,3 +1012,78 @@ func (ts Tensor) Sub_(other Tensor) {
log.Fatal(err)
}
}
// Conv1D applies a 1D convolution over input with the given weight and
// bias tensors, stride/padding/dilation (one element per spatial dim) and
// group count. It returns the resulting tensor or an error from libtorch.
func Conv1D(input, weight, bias Tensor, stride, padding, dilation []int64, groups int64) (retVal Tensor, err error) {
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	// Free the malloc'd pointer cell once its Ctensor value has been copied
	// out, matching the pattern used elsewhere in this file (e.g. Sub_);
	// without it every call leaks the cell.
	defer C.free(unsafe.Pointer(ptr))

	lib.AtgConv1d(ptr, input.ctensor, weight.ctensor, bias.ctensor, stride, len(stride), padding, len(padding), dilation, len(dilation), groups)
	if err = TorchErr(); err != nil {
		return retVal, err
	}

	retVal = Tensor{ctensor: *ptr}
	return retVal, nil
}
// MustConv1D is the panicking variant of Conv1D: it applies a 1D
// convolution and terminates the program via log.Fatal if libtorch
// reports an error.
func MustConv1D(input, weight, bias Tensor, stride, padding, dilation []int64, groups int64) (retVal Tensor) {
	out, err := Conv1D(input, weight, bias, stride, padding, dilation, groups)
	if err != nil {
		log.Fatal(err)
	}

	retVal = out
	return retVal
}
// Conv2D applies a 2D convolution over input with the given weight and
// bias tensors, stride/padding/dilation (one element per spatial dim) and
// group count. It returns the resulting tensor or an error from libtorch.
func Conv2D(input, weight, bias Tensor, stride, padding, dilation []int64, groups int64) (retVal Tensor, err error) {
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	// Free the malloc'd pointer cell once its Ctensor value has been copied
	// out, matching the pattern used elsewhere in this file (e.g. Sub_);
	// without it every call leaks the cell.
	defer C.free(unsafe.Pointer(ptr))

	lib.AtgConv2d(ptr, input.ctensor, weight.ctensor, bias.ctensor, stride, len(stride), padding, len(padding), dilation, len(dilation), groups)
	if err = TorchErr(); err != nil {
		return retVal, err
	}

	retVal = Tensor{ctensor: *ptr}
	return retVal, nil
}
// MustConv2D is the panicking variant of Conv2D: it applies a 2D
// convolution and terminates the program via log.Fatal if libtorch
// reports an error.
func MustConv2D(input, weight, bias Tensor, stride, padding, dilation []int64, groups int64) (retVal Tensor) {
	out, err := Conv2D(input, weight, bias, stride, padding, dilation, groups)
	if err != nil {
		log.Fatal(err)
	}

	retVal = out
	return retVal
}
// Conv3D applies a 3D convolution over input with the given weight and
// bias tensors, stride/padding/dilation (one element per spatial dim) and
// group count. It returns the resulting tensor or an error from libtorch.
func Conv3D(input, weight, bias Tensor, stride, padding, dilation []int64, groups int64) (retVal Tensor, err error) {
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	// Free the malloc'd pointer cell once its Ctensor value has been copied
	// out, matching the pattern used elsewhere in this file (e.g. Sub_);
	// without it every call leaks the cell.
	defer C.free(unsafe.Pointer(ptr))

	lib.AtgConv3d(ptr, input.ctensor, weight.ctensor, bias.ctensor, stride, len(stride), padding, len(padding), dilation, len(dilation), groups)
	if err = TorchErr(); err != nil {
		return retVal, err
	}

	retVal = Tensor{ctensor: *ptr}
	return retVal, nil
}
// MustConv3D is the panicking variant of Conv3D: it applies a 3D
// convolution and terminates the program via log.Fatal if libtorch
// reports an error.
func MustConv3D(input, weight, bias Tensor, stride, padding, dilation []int64, groups int64) (retVal Tensor) {
	out, err := Conv3D(input, weight, bias, stride, padding, dilation, groups)
	if err != nil {
		log.Fatal(err)
	}

	retVal = out
	return retVal
}