feat(nn/conv): added conv
This commit is contained in:
parent
5861f3c525
commit
d480f969bb
|
@ -316,3 +316,42 @@ func AtgSub1(ptr *Ctensor, self Ctensor, other Cscalar) {
|
|||
func AtgSub_(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||
C.atg_sub_(ptr, self, other)
|
||||
}
|
||||
|
||||
// void atg_conv1d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *dilation_data, int dilation_len, int64_t groups);
|
||||
func AtgConv1d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, dilationData []int64, dilationLen int, groups int64) {
|
||||
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
|
||||
cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen))
|
||||
cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
|
||||
cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
|
||||
cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))
|
||||
cdilationLen := *(*C.int)(unsafe.Pointer(&dilationLen))
|
||||
cgroups := *(*C.int64_t)(unsafe.Pointer(&groups))
|
||||
|
||||
C.atg_conv1d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, cdilationDataPtr, cdilationLen, cgroups)
|
||||
}
|
||||
|
||||
// void atg_conv2d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *dilation_data, int dilation_len, int64_t groups);
|
||||
func AtgConv2d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, dilationData []int64, dilationLen int, groups int64) {
|
||||
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
|
||||
cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen))
|
||||
cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
|
||||
cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
|
||||
cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))
|
||||
cdilationLen := *(*C.int)(unsafe.Pointer(&dilationLen))
|
||||
cgroups := *(*C.int64_t)(unsafe.Pointer(&groups))
|
||||
|
||||
C.atg_conv2d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, cdilationDataPtr, cdilationLen, cgroups)
|
||||
}
|
||||
|
||||
// void atg_conv3d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *dilation_data, int dilation_len, int64_t groups);
|
||||
func AtgConv3d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, dilationData []int64, dilationLen int, groups int64) {
|
||||
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
|
||||
cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen))
|
||||
cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
|
||||
cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
|
||||
cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))
|
||||
cdilationLen := *(*C.int)(unsafe.Pointer(&dilationLen))
|
||||
cgroups := *(*C.int64_t)(unsafe.Pointer(&groups))
|
||||
|
||||
C.atg_conv3d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, cdilationDataPtr, cdilationLen, cgroups)
|
||||
}
|
||||
|
|
137
nn/conv.go
Normal file
137
nn/conv.go
Normal file
|
@ -0,0 +1,137 @@
|
|||
package nn
|
||||
|
||||
// N-dimensional convolution layers.
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
type Conv1DConfig struct {
|
||||
Kval int64
|
||||
Stride []int64
|
||||
Padding []int64
|
||||
Dilation []int64
|
||||
Groups int64
|
||||
Bias bool
|
||||
WsInit Init
|
||||
BsInit Init
|
||||
}
|
||||
|
||||
type Conv2DConfig struct {
|
||||
Kval int64
|
||||
Stride []int64
|
||||
Padding []int64
|
||||
Dilation []int64
|
||||
Groups int64
|
||||
Bias bool
|
||||
WsInit Init
|
||||
BsInit Init
|
||||
}
|
||||
|
||||
type Conv3DConfig struct {
|
||||
Kval int64
|
||||
Stride []int64
|
||||
Padding []int64
|
||||
Dilation []int64
|
||||
Groups int64
|
||||
Bias bool
|
||||
WsInit Init
|
||||
BsInit Init
|
||||
}
|
||||
|
||||
// DefaultConvConfig create a default 1D ConvConfig
|
||||
func DefaultConv1DConfig() Conv1DConfig {
|
||||
return Conv1DConfig{
|
||||
Stride: []int64{1},
|
||||
Padding: []int64{0},
|
||||
Dilation: []int64{1},
|
||||
Groups: 1,
|
||||
Bias: true,
|
||||
WsInit: NewKaimingUniformInit(),
|
||||
BsInit: NewConstInit(float64(0.0)),
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultConvConfig2D creates a default 2D ConvConfig
|
||||
func DefaultConv2DConfig() Conv2DConfig {
|
||||
return Conv2DConfig{
|
||||
Stride: []int64{1, 1},
|
||||
Padding: []int64{0, 0},
|
||||
Dilation: []int64{1, 1},
|
||||
Groups: 1,
|
||||
Bias: true,
|
||||
WsInit: NewKaimingUniformInit(),
|
||||
BsInit: NewConstInit(float64(0.0)),
|
||||
}
|
||||
}
|
||||
|
||||
type Conv1D struct {
|
||||
Ws ts.Tensor
|
||||
Bs ts.Tensor // optional
|
||||
Config Conv1DConfig
|
||||
}
|
||||
|
||||
func NewConv1D(vs *Path, inDim, outDim int64, cfg Conv1DConfig) Conv1D {
|
||||
var conv Conv1D
|
||||
conv.Config = cfg
|
||||
if cfg.Bias {
|
||||
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
}
|
||||
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||
weightSize = append(weightSize, cfg.Kval)
|
||||
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||
|
||||
return conv
|
||||
}
|
||||
|
||||
type Conv2D struct {
|
||||
Ws ts.Tensor
|
||||
Bs ts.Tensor // optional
|
||||
Config Conv2DConfig
|
||||
}
|
||||
|
||||
func NewConv2D(vs *Path, inDim, outDim int64, cfg Conv2DConfig) Conv2D {
|
||||
var conv Conv2D
|
||||
conv.Config = cfg
|
||||
if cfg.Bias {
|
||||
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
}
|
||||
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||
weightSize = append(weightSize, cfg.Kval, cfg.Kval)
|
||||
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||
|
||||
return conv
|
||||
}
|
||||
|
||||
type Conv3D struct {
|
||||
Ws ts.Tensor
|
||||
Bs ts.Tensor // optional
|
||||
Config Conv3DConfig
|
||||
}
|
||||
|
||||
func NewConv3D(vs *Path, inDim, outDim int64, cfg Conv3DConfig) Conv3D {
|
||||
var conv Conv3D
|
||||
conv.Config = cfg
|
||||
if cfg.Bias {
|
||||
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
}
|
||||
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||
weightSize = append(weightSize, cfg.Kval, cfg.Kval, cfg.Kval)
|
||||
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||
|
||||
return conv
|
||||
}
|
||||
|
||||
// Implement Module for Conv1D, Conv2D, Conv3D:
|
||||
// ============================================
|
||||
|
||||
func (c Conv1D) Forward(xs ts.Tensor) ts.Tensor {
|
||||
return ts.MustConv1D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
|
||||
}
|
||||
|
||||
func (c Conv2D) Forward(xs ts.Tensor) ts.Tensor {
|
||||
return ts.MustConv2D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
|
||||
}
|
||||
func (c Conv3D) Forward(xs ts.Tensor) ts.Tensor {
|
||||
return ts.MustConv3D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
|
||||
}
|
|
@ -22,7 +22,7 @@ func TestOptimizer(t *testing.T) {
|
|||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ys := xs.MustMul1(ts.FloatScalar(0.42)).MustAdd1(ts.FloatScalar(1.337))
|
||||
ys := xs.MustMul1(ts.FloatScalar(0.42), true).MustAdd1(ts.FloatScalar(1.337), true)
|
||||
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
|
||||
|
@ -39,7 +39,7 @@ func TestOptimizer(t *testing.T) {
|
|||
|
||||
linear := nn.NewLinear(vs.Root(), 1, 1, cfg)
|
||||
|
||||
loss := xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt())
|
||||
loss := xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt(), true)
|
||||
|
||||
initialLoss := loss.MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
|
||||
|
@ -50,13 +50,13 @@ func TestOptimizer(t *testing.T) {
|
|||
}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt())
|
||||
loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt(), true)
|
||||
|
||||
opt.BackwardStep(loss)
|
||||
fmt.Printf("Loss: %.3f\n", loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}))
|
||||
}
|
||||
|
||||
loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt())
|
||||
loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt(), true)
|
||||
finalLoss := loss.MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
fmt.Printf("Final loss: %v\n", finalLoss)
|
||||
|
||||
|
|
|
@ -1005,7 +1005,6 @@ func (ts Tensor) MustSub1(other Scalar, del bool) (retVal Tensor) {
|
|||
|
||||
func (ts Tensor) Sub_(other Tensor) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgSub_(ptr, ts.ctensor, other.ctensor)
|
||||
err := TorchErr()
|
||||
|
@ -1013,3 +1012,78 @@ func (ts Tensor) Sub_(other Tensor) {
|
|||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Conv1D(input, weight, bias Tensor, stride, padding, dilation []int64, groups int64) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
lib.AtgConv1d(ptr, input.ctensor, weight.ctensor, bias.ctensor, stride, len(stride), padding, len(padding), dilation, len(dilation), groups)
|
||||
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func MustConv1D(input, weight, bias Tensor, stride, padding, dilation []int64, groups int64) (retVal Tensor) {
|
||||
retVal, err := Conv1D(input, weight, bias, stride, padding, dilation, groups)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func Conv2D(input, weight, bias Tensor, stride, padding, dilation []int64, groups int64) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
lib.AtgConv2d(ptr, input.ctensor, weight.ctensor, bias.ctensor, stride, len(stride), padding, len(padding), dilation, len(dilation), groups)
|
||||
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func MustConv2D(input, weight, bias Tensor, stride, padding, dilation []int64, groups int64) (retVal Tensor) {
|
||||
retVal, err := Conv2D(input, weight, bias, stride, padding, dilation, groups)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func Conv3D(input, weight, bias Tensor, stride, padding, dilation []int64, groups int64) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
lib.AtgConv3d(ptr, input.ctensor, weight.ctensor, bias.ctensor, stride, len(stride), padding, len(padding), dilation, len(dilation), groups)
|
||||
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func MustConv3D(input, weight, bias Tensor, stride, padding, dilation []int64, groups int64) (retVal Tensor) {
|
||||
retVal, err := Conv3D(input, weight, bias, stride, padding, dilation, groups)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user