converted to pointer receiver at tensor APIs, tensor and nn sub-packages

This commit is contained in:
sugarme 2020-10-31 19:25:32 +11:00
parent 59ea5f0e1b
commit 75a7d89b92
32 changed files with 30763 additions and 26902 deletions

View File

@ -5,6 +5,7 @@
- **GoTch** is a C++ Libtorch Go binding for developing and implementing deep learning projects in Go. - **GoTch** is a C++ Libtorch Go binding for developing and implementing deep learning projects in Go.
- This package is to create a thin wrapper of Libtorch to make use of its tensor APIs and CUDA support while implementing as much idiomatic Go as possible. - This package is to create a thin wrapper of Libtorch to make use of its tensor APIs and CUDA support while implementing as much idiomatic Go as possible.
- There are about **1129** auto-generated tensor APIs.
## Dependencies ## Dependencies

View File

@ -1,7 +1,6 @@
(* Automatically generate the C++ -> C -> Go bindings. (* Automatically generate the C++ -> C -> Go bindings.
This takes as input the Descriptions.yaml file that gets generated when This takes as input the Descriptions.yaml file that gets generated when
func (Func.c_go_args_list func) building PyTorch from source. func (Func.c_go_args_list func) building PyTorch from source.
Run with: dune exec gen/gen.exe Run with: dune exec gen/gen.exe
*) *)
open Base open Base
@ -347,15 +346,15 @@ module Func = struct
| Bool -> "bool" | Bool -> "bool"
| Int64 -> "int64" | Int64 -> "int64"
| Double -> "float64" | Double -> "float64"
| Tensor -> "Tensor" | Tensor -> "*Tensor"
| TensorOption -> "Tensor" | TensorOption -> "*Tensor"
| IntList -> "[]int64" | IntList -> "[]int64"
| TensorList -> "[]Tensor" | TensorList -> "[]Tensor"
| String -> "string" | String -> "string"
(* TODO. Struct{Kind gotch.DType Device gotch.Device} *) (* TODO. Struct{Kind gotch.DType Device gotch.Device} *)
(* E.g. `type KindDevice struct{}` *) (* E.g. `type KindDevice struct{}` *)
| TensorOptions -> "gotch.KindDevice" | TensorOptions -> "gotch.KindDevice"
| Scalar -> "Scalar" | Scalar -> "*Scalar"
| ScalarType -> "gotch.DType" | ScalarType -> "gotch.DType"
| Device -> "gotch.Device" | Device -> "gotch.Device"
in in
@ -396,9 +395,9 @@ module Func = struct
(* printf "t name: %s\n" t.name ; *) (* printf "t name: %s\n" t.name ; *)
let returns = let returns =
match t.returns with match t.returns with
| `fixed 1 -> "retVal Tensor" | `fixed 1 -> "retVal *Tensor"
| `fixed v -> | `fixed v ->
List.init v ~f:(fun i -> Printf.sprintf "retVal%d Tensor" i) List.init v ~f:(fun i -> Printf.sprintf "retVal%d *Tensor" i)
|> String.concat ~sep:", " |> Printf.sprintf "%s" |> String.concat ~sep:", " |> Printf.sprintf "%s"
| `dynamic -> "retVal []Tensor" | `dynamic -> "retVal []Tensor"
in in
@ -698,7 +697,7 @@ let write_wrapper funcs filename =
match func.returns with match func.returns with
| `dynamic -> | `dynamic ->
pm "\n" ; pm "\n" ;
if is_method then pm "func(ts Tensor) %s(" gofunc_name if is_method then pm "func(ts *Tensor) %s(" gofunc_name
else pm "func %s(" gofunc_name ; else pm "func %s(" gofunc_name ;
pm "%s" go_args_list ; pm "%s" go_args_list ;
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ; pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
@ -714,13 +713,13 @@ let write_wrapper funcs filename =
pm " }\n" ; pm " }\n" ;
(* NOTE. if in_place method, no retVal return *) (* NOTE. if in_place method, no retVal return *)
if not (Func.is_inplace func) then if not (Func.is_inplace func) then
pm " retVal = Tensor{ctensor: *ptr}\n" ; pm " retVal = &Tensor{ctensor: *ptr}\n" ;
pm " \n" ; pm " \n" ;
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ; pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
pm "} \n" pm "} \n"
| `fixed 1 -> | `fixed 1 ->
pm "\n" ; pm "\n" ;
if is_method then pm "func(ts Tensor) %s(" gofunc_name if is_method then pm "func(ts *Tensor) %s(" gofunc_name
else pm "func %s(" gofunc_name ; else pm "func %s(" gofunc_name ;
pm "%s" go_args_list ; pm "%s" go_args_list ;
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ; pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
@ -736,7 +735,7 @@ let write_wrapper funcs filename =
pm " }\n" ; pm " }\n" ;
(* NOTE. if in_place method, no retVal return *) (* NOTE. if in_place method, no retVal return *)
if not (Func.is_inplace func) then if not (Func.is_inplace func) then
pm " retVal = Tensor{ctensor: *ptr}\n" ; pm " retVal = &Tensor{ctensor: *ptr}\n" ;
pm " \n" ; pm " \n" ;
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ; pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
pm "} \n" pm "} \n"
@ -804,7 +803,7 @@ let write_must_wrapper funcs filename =
match func.returns with match func.returns with
| `dynamic -> | `dynamic ->
pm "\n" ; pm "\n" ;
if is_method then pm "func(ts Tensor) %s(" gofunc_name if is_method then pm "func(ts *Tensor) %s(" gofunc_name
else pm "func Must%s(" gofunc_name ; else pm "func Must%s(" gofunc_name ;
pm "%s" go_args_list ; pm "%s" go_args_list ;
pm ")(%s) { \n" (Func.go_return_type func ~fallible:false) ; pm ")(%s) { \n" (Func.go_return_type func ~fallible:false) ;
@ -821,7 +820,7 @@ let write_must_wrapper funcs filename =
pm "} \n" pm "} \n"
| `fixed 1 -> | `fixed 1 ->
pm "\n" ; pm "\n" ;
if is_method then pm "func(ts Tensor) Must%s(" gofunc_name if is_method then pm "func(ts *Tensor) Must%s(" gofunc_name
else pm "func Must%s(" gofunc_name ; else pm "func Must%s(" gofunc_name ;
pm "%s" go_args_list ; pm "%s" go_args_list ;
pm ")(%s) { \n" (Func.go_return_type func ~fallible:false) ; pm ")(%s) { \n" (Func.go_return_type func ~fallible:false) ;

File diff suppressed because it is too large Load Diff

View File

@ -17,8 +17,8 @@ type BatchNormConfig struct {
BsInit Init BsInit Init
} }
func DefaultBatchNormConfig() BatchNormConfig { func DefaultBatchNormConfig() *BatchNormConfig {
return BatchNormConfig{ return &BatchNormConfig{
CudnnEnable: true, CudnnEnable: true,
Eps: 1e-5, Eps: 1e-5,
Momentum: 0.1, Momentum: 0.1,
@ -29,17 +29,17 @@ func DefaultBatchNormConfig() BatchNormConfig {
// A batch-normalization layer. // A batch-normalization layer.
type BatchNorm struct { type BatchNorm struct {
config BatchNormConfig config *BatchNormConfig
RunningMean ts.Tensor RunningMean *ts.Tensor
RunningVar ts.Tensor RunningVar *ts.Tensor
Ws ts.Tensor Ws *ts.Tensor
Bs ts.Tensor Bs *ts.Tensor
Nd uint Nd uint
} }
// NewBatchNorm creates a new BatchNorm layer // NewBatchNorm creates a new BatchNorm layer
func NewBatchNorm(vs Path, nd uint, outDim int64, config BatchNormConfig) BatchNorm { func NewBatchNorm(vs Path, nd uint, outDim int64, config *BatchNormConfig) *BatchNorm {
return BatchNorm{ return &BatchNorm{
config: config, config: config,
RunningMean: vs.ZerosNoTrain("running_mean", []int64{outDim}), RunningMean: vs.ZerosNoTrain("running_mean", []int64{outDim}),
RunningVar: vs.OnesNoTrain("running_var", []int64{outDim}), RunningVar: vs.OnesNoTrain("running_var", []int64{outDim}),
@ -52,7 +52,7 @@ func NewBatchNorm(vs Path, nd uint, outDim int64, config BatchNormConfig) BatchN
// //
// The input shape is assumed to be (N, C, L). Normalization // The input shape is assumed to be (N, C, L). Normalization
// is performed over the first batch dimension N. // is performed over the first batch dimension N.
func BatchNorm1D(vs Path, outDim int64, config BatchNormConfig) BatchNorm { func BatchNorm1D(vs Path, outDim int64, config *BatchNormConfig) *BatchNorm {
return NewBatchNorm(vs, 1, outDim, config) return NewBatchNorm(vs, 1, outDim, config)
} }
@ -60,7 +60,7 @@ func BatchNorm1D(vs Path, outDim int64, config BatchNormConfig) BatchNorm {
// //
// The input shape is assumed to be (N, C, H, W). Normalization // The input shape is assumed to be (N, C, H, W). Normalization
// is performed over the first batch dimension N. // is performed over the first batch dimension N.
func BatchNorm2D(vs Path, outDim int64, config BatchNormConfig) BatchNorm { func BatchNorm2D(vs Path, outDim int64, config *BatchNormConfig) *BatchNorm {
return NewBatchNorm(vs, 2, outDim, config) return NewBatchNorm(vs, 2, outDim, config)
} }
@ -68,14 +68,14 @@ func BatchNorm2D(vs Path, outDim int64, config BatchNormConfig) BatchNorm {
// //
// The input shape is assumed to be (N, C, D, H, W). Normalization // The input shape is assumed to be (N, C, D, H, W). Normalization
// is performed over the first batch dimension N. // is performed over the first batch dimension N.
func BatchNorm3D(vs Path, outDim int64, config BatchNormConfig) BatchNorm { func BatchNorm3D(vs Path, outDim int64, config *BatchNormConfig) *BatchNorm {
return NewBatchNorm(vs, 3, outDim, config) return NewBatchNorm(vs, 3, outDim, config)
} }
// Implement ModuleT interface for BatchNorm: // Implement ModuleT interface for BatchNorm:
// ========================================== // ==========================================
func (bn BatchNorm) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) { func (bn *BatchNorm) ForwardT(xs *ts.Tensor, train bool) (retVal *ts.Tensor) {
dim := xs.Dim() dim := xs.Dim()

View File

@ -42,8 +42,8 @@ type ConvTranspose3DConfig struct {
} }
// DefaultConvConfig create a default 1D ConvConfig // DefaultConvConfig create a default 1D ConvConfig
func DefaultConvTranspose1DConfig() ConvTranspose1DConfig { func DefaultConvTranspose1DConfig() *ConvTranspose1DConfig {
return ConvTranspose1DConfig{ return &ConvTranspose1DConfig{
Stride: []int64{1}, Stride: []int64{1},
Padding: []int64{0}, Padding: []int64{0},
OutputPadding: []int64{0}, OutputPadding: []int64{0},
@ -56,83 +56,107 @@ func DefaultConvTranspose1DConfig() ConvTranspose1DConfig {
} }
type ConvTranspose1D struct { type ConvTranspose1D struct {
Ws ts.Tensor Ws *ts.Tensor
Bs ts.Tensor // optional Bs *ts.Tensor // optional
Config ConvTranspose1DConfig Config *ConvTranspose1DConfig
} }
func NewConvTranspose1D(vs *Path, inDim, outDim int64, ksizes []int64, cfg ConvTranspose1DConfig) ConvTranspose1D { func NewConvTranspose1D(vs *Path, inDim, outDim int64, ksizes []int64, cfg *ConvTranspose1DConfig) *ConvTranspose1D {
if len(ksizes) != 1 { if len(ksizes) != 1 {
log.Fatalf("NewConvTranspose1D method call: Kernel size should be 1. Got %v\n", len(ksizes)) log.Fatalf("NewConvTranspose1D method call: Kernel size should be 1. Got %v\n", len(ksizes))
} }
var conv ConvTranspose1D var (
conv.Config = cfg ws *ts.Tensor
if cfg.Bias { bs *ts.Tensor
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit) )
}
weightSize := []int64{outDim, int64(inDim / cfg.Groups)} weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
weightSize = append(weightSize, ksizes...) weightSize = append(weightSize, ksizes...)
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit) ws = vs.NewVar("weight", weightSize, cfg.WsInit)
return conv if cfg.Bias {
bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
}
return &ConvTranspose1D{
Ws: ws,
Bs: bs,
Config: cfg,
}
} }
type ConvTranspose2D struct { type ConvTranspose2D struct {
Ws ts.Tensor Ws *ts.Tensor
Bs ts.Tensor // optional Bs *ts.Tensor // optional
Config ConvTranspose2DConfig Config *ConvTranspose2DConfig
} }
func NewConvTranspose2D(vs *Path, inDim, outDim int64, ksizes []int64, cfg ConvTranspose2DConfig) ConvTranspose2D { func NewConvTranspose2D(vs *Path, inDim, outDim int64, ksizes []int64, cfg *ConvTranspose2DConfig) *ConvTranspose2D {
if len(ksizes) != 2 { if len(ksizes) != 2 {
log.Fatalf("NewConvTranspose2D method call: Kernel size should be 2. Got %v\n", len(ksizes)) log.Fatalf("NewConvTranspose2D method call: Kernel size should be 2. Got %v\n", len(ksizes))
} }
var conv ConvTranspose2D
conv.Config = cfg var (
ws *ts.Tensor
bs *ts.Tensor
)
if cfg.Bias { if cfg.Bias {
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit) bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
} }
weightSize := []int64{outDim, int64(inDim / cfg.Groups)} weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
weightSize = append(weightSize, ksizes...) weightSize = append(weightSize, ksizes...)
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit) ws = vs.NewVar("weight", weightSize, cfg.WsInit)
return conv return &ConvTranspose2D{
Ws: ws,
Bs: bs,
Config: cfg,
}
} }
type ConvTranspose3D struct { type ConvTranspose3D struct {
Ws ts.Tensor Ws *ts.Tensor
Bs ts.Tensor // optional Bs *ts.Tensor // optional
Config ConvTranspose3DConfig Config *ConvTranspose3DConfig
} }
func NewConvTranspose3D(vs *Path, inDim, outDim int64, ksizes []int64, cfg ConvTranspose3DConfig) ConvTranspose3D { func NewConvTranspose3D(vs *Path, inDim, outDim int64, ksizes []int64, cfg *ConvTranspose3DConfig) *ConvTranspose3D {
if len(ksizes) != 3 { if len(ksizes) != 3 {
log.Fatalf("NewConvTranspose3D method call: Kernel size should be 3. Got %v\n", len(ksizes)) log.Fatalf("NewConvTranspose3D method call: Kernel size should be 3. Got %v\n", len(ksizes))
} }
var conv ConvTranspose3D
conv.Config = cfg var (
ws *ts.Tensor
bs *ts.Tensor
)
if cfg.Bias { if cfg.Bias {
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit) bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
} }
weightSize := []int64{outDim, int64(inDim / cfg.Groups)} weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
weightSize = append(weightSize, ksizes...) weightSize = append(weightSize, ksizes...)
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit) ws = vs.NewVar("weight", weightSize, cfg.WsInit)
return conv return &ConvTranspose3D{
Ws: ws,
Bs: bs,
Config: cfg,
}
} }
// Implement Module for Conv1D, Conv2D, Conv3D: // Implement Module for Conv1D, Conv2D, Conv3D:
// ============================================ // ============================================
func (c ConvTranspose1D) Forward(xs ts.Tensor) ts.Tensor { func (c *ConvTranspose1D) Forward(xs *ts.Tensor) *ts.Tensor {
return ts.MustConvTranspose1d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Groups, c.Config.Dilation) return ts.MustConvTranspose1d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Groups, c.Config.Dilation)
} }
func (c ConvTranspose2D) Forward(xs ts.Tensor) ts.Tensor { func (c *ConvTranspose2D) Forward(xs *ts.Tensor) *ts.Tensor {
return ts.MustConvTranspose2d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Groups, c.Config.Dilation) return ts.MustConvTranspose2d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Groups, c.Config.Dilation)
} }
func (c ConvTranspose3D) Forward(xs ts.Tensor) ts.Tensor { func (c *ConvTranspose3D) Forward(xs *ts.Tensor) *ts.Tensor {
return ts.MustConvTranspose3d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Groups, c.Config.Dilation) return ts.MustConvTranspose3d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Groups, c.Config.Dilation)
} }

View File

@ -40,8 +40,8 @@ type Conv3DConfig struct {
} }
// DefaultConvConfig create a default 1D ConvConfig // DefaultConvConfig create a default 1D ConvConfig
func DefaultConv1DConfig() Conv1DConfig { func DefaultConv1DConfig() *Conv1DConfig {
return Conv1DConfig{ return &Conv1DConfig{
Stride: []int64{1}, Stride: []int64{1},
Padding: []int64{0}, Padding: []int64{0},
Dilation: []int64{1}, Dilation: []int64{1},
@ -53,8 +53,8 @@ func DefaultConv1DConfig() Conv1DConfig {
} }
// DefaultConvConfig2D creates a default 2D ConvConfig // DefaultConvConfig2D creates a default 2D ConvConfig
func DefaultConv2DConfig() Conv2DConfig { func DefaultConv2DConfig() *Conv2DConfig {
return Conv2DConfig{ return &Conv2DConfig{
Stride: []int64{1, 1}, Stride: []int64{1, 1},
Padding: []int64{0, 0}, Padding: []int64{0, 0},
Dilation: []int64{1, 1}, Dilation: []int64{1, 1},
@ -66,60 +66,78 @@ func DefaultConv2DConfig() Conv2DConfig {
} }
type Conv1D struct { type Conv1D struct {
Ws ts.Tensor Ws *ts.Tensor
Bs ts.Tensor // optional Bs *ts.Tensor // optional
Config Conv1DConfig Config *Conv1DConfig
} }
func NewConv1D(vs *Path, inDim, outDim, k int64, cfg Conv1DConfig) Conv1D { func NewConv1D(vs *Path, inDim, outDim, k int64, cfg *Conv1DConfig) *Conv1D {
var conv Conv1D var (
conv.Config = cfg ws *ts.Tensor
bs *ts.Tensor
)
if cfg.Bias { if cfg.Bias {
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit) bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
} }
weightSize := []int64{outDim, int64(inDim / cfg.Groups)} weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
weightSize = append(weightSize, k) weightSize = append(weightSize, k)
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit) ws = vs.NewVar("weight", weightSize, cfg.WsInit)
return conv return &Conv1D{
Ws: ws,
Bs: bs,
Config: cfg,
}
} }
type Conv2D struct { type Conv2D struct {
Ws ts.Tensor Ws *ts.Tensor
Bs ts.Tensor // optional Bs *ts.Tensor // optional
Config Conv2DConfig Config *Conv2DConfig
} }
func NewConv2D(vs Path, inDim, outDim int64, k int64, cfg Conv2DConfig) Conv2D { func NewConv2D(vs Path, inDim, outDim int64, k int64, cfg *Conv2DConfig) *Conv2D {
var conv Conv2D var (
conv.Config = cfg ws *ts.Tensor
bs *ts.Tensor
)
if cfg.Bias { if cfg.Bias {
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit) bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
} }
weightSize := []int64{outDim, int64(inDim / cfg.Groups)} weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
weightSize = append(weightSize, k, k) weightSize = append(weightSize, k, k)
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit) ws = vs.NewVar("weight", weightSize, cfg.WsInit)
return conv return &Conv2D{
Ws: ws,
Bs: bs,
Config: cfg,
}
} }
type Conv3D struct { type Conv3D struct {
Ws ts.Tensor Ws *ts.Tensor
Bs ts.Tensor // optional Bs *ts.Tensor // optional
Config Conv3DConfig Config *Conv3DConfig
} }
func NewConv3D(vs *Path, inDim, outDim, k int64, cfg Conv3DConfig) Conv3D { func NewConv3D(vs *Path, inDim, outDim, k int64, cfg *Conv3DConfig) *Conv3D {
var conv Conv3D var (
conv.Config = cfg ws *ts.Tensor
bs *ts.Tensor
)
if cfg.Bias { if cfg.Bias {
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit) bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
} }
weightSize := []int64{outDim, int64(inDim / cfg.Groups)} weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
weightSize = append(weightSize, k, k, k) weightSize = append(weightSize, k, k, k)
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit) ws = vs.NewVar("weight", weightSize, cfg.WsInit)
return conv return &Conv3D{
Ws: ws,
Bs: bs,
Config: cfg,
}
} }
type Conv interface{} type Conv interface{}
@ -175,38 +193,51 @@ func buildConvConfig(ksizes []int64) interface{} {
func NewConv(vs Path, inDim, outDim int64, ksizes []int64, config interface{}) Conv { func NewConv(vs Path, inDim, outDim int64, ksizes []int64, config interface{}) Conv {
configT := reflect.TypeOf(config) configT := reflect.TypeOf(config)
var (
ws *ts.Tensor
bs *ts.Tensor
)
switch { switch {
case len(ksizes) == 1 && configT.Name() == "Conv1DConfig": case len(ksizes) == 1 && configT.Name() == "Conv1DConfig":
var conv Conv1D cfg := config.(Conv1DConfig)
conv.Config = config.(Conv1DConfig) if cfg.Bias {
if config.(Conv1DConfig).Bias { bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
conv.Bs = vs.NewVar("bias", []int64{outDim}, config.(Conv1DConfig).BsInit)
} }
weightSize := []int64{outDim, int64(inDim / config.(Conv1DConfig).Groups)} weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
weightSize = append(weightSize, ksizes...) weightSize = append(weightSize, ksizes...)
conv.Ws = vs.NewVar("weight", weightSize, config.(Conv1DConfig).WsInit) ws = vs.NewVar("weight", weightSize, cfg.WsInit)
return conv return &Conv1D{
Ws: ws,
Bs: bs,
Config: &cfg,
}
case len(ksizes) == 2 && configT.Name() == "Conv2DConfig": case len(ksizes) == 2 && configT.Name() == "Conv2DConfig":
var conv Conv2D cfg := config.(Conv2DConfig)
conv.Config = config.(Conv2DConfig) if cfg.Bias {
if config.(Conv2DConfig).Bias { bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
conv.Bs = vs.NewVar("bias", []int64{outDim}, config.(Conv2DConfig).BsInit)
} }
weightSize := []int64{outDim, int64(inDim / config.(Conv2DConfig).Groups)} weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
weightSize = append(weightSize, ksizes...) weightSize = append(weightSize, ksizes...)
conv.Ws = vs.NewVar("weight", weightSize, config.(Conv2DConfig).WsInit) ws = vs.NewVar("weight", weightSize, config.(Conv2DConfig).WsInit)
return conv return &Conv2D{
Ws: ws,
Bs: bs,
Config: &cfg,
}
case len(ksizes) == 3 && configT.Name() == "Conv3DConfig": case len(ksizes) == 3 && configT.Name() == "Conv3DConfig":
var conv Conv3D cfg := config.(Conv3DConfig)
conv.Config = config.(Conv3DConfig) if cfg.Bias {
if config.(Conv3DConfig).Bias { bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
conv.Bs = vs.NewVar("bias", []int64{outDim}, config.(Conv3DConfig).BsInit)
} }
weightSize := []int64{outDim, int64(inDim / config.(Conv3DConfig).Groups)} weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
weightSize = append(weightSize, ksizes...) weightSize = append(weightSize, ksizes...)
conv.Ws = vs.NewVar("weight", weightSize, config.(Conv3DConfig).WsInit) ws = vs.NewVar("weight", weightSize, cfg.WsInit)
return conv return &Conv3D{
Ws: ws,
Bs: bs,
Config: &cfg,
}
default: default:
err := fmt.Errorf("Expected nd length from 1 to 3. Got %v\n", len(ksizes)) err := fmt.Errorf("Expected nd length from 1 to 3. Got %v\n", len(ksizes))
panic(err) panic(err)
@ -216,14 +247,14 @@ func NewConv(vs Path, inDim, outDim int64, ksizes []int64, config interface{}) C
// Implement Module for Conv1D, Conv2D, Conv3D: // Implement Module for Conv1D, Conv2D, Conv3D:
// ============================================ // ============================================
func (c Conv1D) Forward(xs ts.Tensor) ts.Tensor { func (c *Conv1D) Forward(xs *ts.Tensor) *ts.Tensor {
return ts.MustConv1d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups) return ts.MustConv1d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
} }
func (c Conv2D) Forward(xs ts.Tensor) ts.Tensor { func (c *Conv2D) Forward(xs *ts.Tensor) *ts.Tensor {
return ts.MustConv2d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups) return ts.MustConv2d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
} }
func (c Conv3D) Forward(xs ts.Tensor) ts.Tensor { func (c *Conv3D) Forward(xs *ts.Tensor) *ts.Tensor {
return ts.MustConv3d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups) return ts.MustConv3d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
} }
@ -232,13 +263,13 @@ func (c Conv3D) Forward(xs ts.Tensor) ts.Tensor {
// NOTE: `train` param won't be used, will be? // NOTE: `train` param won't be used, will be?
func (c Conv1D) ForwardT(xs ts.Tensor, train bool) ts.Tensor { func (c *Conv1D) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
return ts.MustConv1d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups) return ts.MustConv1d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
} }
func (c Conv2D) ForwardT(xs ts.Tensor, train bool) ts.Tensor { func (c *Conv2D) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
return ts.MustConv2d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups) return ts.MustConv2d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
} }
func (c Conv3D) ForwardT(xs ts.Tensor, train bool) ts.Tensor { func (c *Conv3D) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
return ts.MustConv3d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups) return ts.MustConv3d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
} }

View File

@ -7,36 +7,36 @@ import (
) )
type Func struct { type Func struct {
f func(ts.Tensor) ts.Tensor f func(*ts.Tensor) *ts.Tensor
} }
func NewFunc(fn func(ts.Tensor) ts.Tensor) (retVal Func) { func NewFunc(fn func(*ts.Tensor) *ts.Tensor) (retVal Func) {
return Func{f: fn} return Func{f: fn}
} }
// Implement Module interface for Func: // Implement Module interface for Func:
// ==================================== // ====================================
func (fn Func) Forward(xs ts.Tensor) (retVal ts.Tensor) { func (fn Func) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
return fn.f(xs) return fn.f(xs)
} }
// ForwardT implements ModuleT for Func object as well. // ForwardT implements ModuleT for Func object as well.
// //
// NOTE: train param will not be used. // NOTE: train param will not be used.
func (fn Func) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) { func (fn Func) ForwardT(xs *ts.Tensor, train bool) (retVal *ts.Tensor) {
return fn.f(xs) return fn.f(xs)
} }
type FuncT struct { type FuncT struct {
f func(ts.Tensor, bool) ts.Tensor f func(*ts.Tensor, bool) *ts.Tensor
} }
func NewFuncT(fn func(ts.Tensor, bool) ts.Tensor) (retVal FuncT) { func NewFuncT(fn func(*ts.Tensor, bool) *ts.Tensor) (retVal FuncT) {
return FuncT{f: fn} return FuncT{f: fn}
} }
// Implement Module interface for Func: // Implement Module interface for Func:
// ==================================== // ====================================
func (fn FuncT) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) { func (fn FuncT) ForwardT(xs *ts.Tensor, train bool) (retVal *ts.Tensor) {
return fn.f(xs, train) return fn.f(xs, train)
} }

View File

@ -11,10 +11,10 @@ import (
type Init interface { type Init interface {
// creates a new tensor with specified initiation // creates a new tensor with specified initiation
InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor)
// re-initializes (in-place) an existing tensor with the specified initiation // re-initializes (in-place) an existing tensor with the specified initiation
Set(tensor ts.Tensor) Set(tensor *ts.Tensor)
} }
// constInit: // constInit:
@ -28,7 +28,7 @@ func NewConstInit(v float64) constInit {
return constInit{v} return constInit{v}
} }
func (c constInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) { func (c constInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
var err error var err error
kind := gotch.Float kind := gotch.Float
switch { switch {
@ -50,7 +50,7 @@ func (c constInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tens
return retVal return retVal
} }
func (c constInit) Set(tensor ts.Tensor) { func (c constInit) Set(tensor *ts.Tensor) {
var err error var err error
scalarVal := ts.FloatScalar(c.value) scalarVal := ts.FloatScalar(c.value)
if err != nil { if err != nil {
@ -71,7 +71,7 @@ func NewRandnInit(mean, stdev float64) randnInit {
return randnInit{mean, stdev} return randnInit{mean, stdev}
} }
func (r randnInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) { func (r randnInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
var err error var err error
rand.Seed(86) rand.Seed(86)
@ -92,9 +92,9 @@ func (r randnInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tens
} }
func (r randnInit) Set(tensor ts.Tensor) { func (r randnInit) Set(tensor *ts.Tensor) {
var ( var (
randnTs ts.Tensor randnTs *ts.Tensor
err error err error
) )
@ -128,7 +128,7 @@ func NewUniformInit(lo, up float64) uniformInit {
return uniformInit{lo, up} return uniformInit{lo, up}
} }
func (u uniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) { func (u uniformInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
var err error var err error
kind := gotch.Float kind := gotch.Float
retVal = ts.MustZeros(dims, kind, device) retVal = ts.MustZeros(dims, kind, device)
@ -139,7 +139,7 @@ func (u uniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Te
return retVal return retVal
} }
func (u uniformInit) Set(tensor ts.Tensor) { func (u uniformInit) Set(tensor *ts.Tensor) {
tensor.Uniform_(u.lo, u.up) tensor.Uniform_(u.lo, u.up)
} }
@ -152,7 +152,7 @@ func NewKaimingUniformInit() kaimingUniformInit {
return kaimingUniformInit{} return kaimingUniformInit{}
} }
func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) { func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
var fanIn int64 var fanIn int64
if len(dims) == 0 { if len(dims) == 0 {
log.Fatalf("KaimingUniformInit method call: dims (%v) should have length >= 1", dims) log.Fatalf("KaimingUniformInit method call: dims (%v) should have length >= 1", dims)
@ -191,7 +191,7 @@ func factorial(n int64) (result int64) {
return 1 return 1
} }
func (k kaimingUniformInit) Set(tensor ts.Tensor) { func (k kaimingUniformInit) Set(tensor *ts.Tensor) {
dims, err := tensor.Size() dims, err := tensor.Size()
if err != nil { if err != nil {
log.Fatalf("uniformInit - Set method call error: %v\n", err) log.Fatalf("uniformInit - Set method call error: %v\n", err)
@ -218,12 +218,12 @@ func NewGlorotNInit() glorotNInit {
return glorotNInit{} return glorotNInit{}
} }
func (gl glorotNInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) { func (gl glorotNInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
// TODO: implement // TODO: implement
return return
} }
func (gl glorotNInit) Set(tensor ts.Tensor) { func (gl glorotNInit) Set(tensor *ts.Tensor) {
// TODO: implement // TODO: implement
} }

View File

@ -14,8 +14,8 @@ type LayerNormConfig struct {
BsInit Init BsInit Init
} }
func DefaultLayerNormConfig() LayerNormConfig { func DefaultLayerNormConfig() *LayerNormConfig {
return LayerNormConfig{ return &LayerNormConfig{
CudnnEnable: true, CudnnEnable: true,
Eps: 1e-5, Eps: 1e-5,
ElementwiseAffine: true, ElementwiseAffine: true,
@ -26,30 +26,30 @@ func DefaultLayerNormConfig() LayerNormConfig {
// A layer-normalization layer. // A layer-normalization layer.
type LayerNorm struct { type LayerNorm struct {
Config LayerNormConfig Config *LayerNormConfig
Ws ts.Tensor // optional Ws *ts.Tensor // optional
Bs ts.Tensor // optional Bs *ts.Tensor // optional
NormalizedShape []int64 NormalizedShape []int64
} }
func NewLayerNorm(vs Path, normalizedShape []int64, config LayerNormConfig) LayerNorm { func NewLayerNorm(vs Path, normalizedShape []int64, config *LayerNormConfig) *LayerNorm {
var ( var (
ws ts.Tensor ws *ts.Tensor
bs ts.Tensor bs *ts.Tensor
) )
if config.ElementwiseAffine { if config.ElementwiseAffine {
ws = vs.NewVar("weight", normalizedShape, config.WsInit) ws = vs.NewVar("weight", normalizedShape, config.WsInit)
bs = vs.NewVar("bias", normalizedShape, config.BsInit) bs = vs.NewVar("bias", normalizedShape, config.BsInit)
} }
return LayerNorm{config, ws, bs, normalizedShape} return &LayerNorm{config, ws, bs, normalizedShape}
} }
// Implement Module interface for LayerNorm: // Implement Module interface for LayerNorm:
// ========================================= // =========================================
func (ln LayerNorm) Forward(xs ts.Tensor) (retVal ts.Tensor) { func (ln *LayerNorm) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
return ts.MustLayerNorm(xs, ln.NormalizedShape, ln.Ws, ln.Bs, ln.Config.Eps, ln.Config.CudnnEnable) return ts.MustLayerNorm(xs, ln.NormalizedShape, ln.Ws, ln.Bs, ln.Config.Eps, ln.Config.CudnnEnable)
} }

View File

@ -18,8 +18,8 @@ type LinearConfig struct {
// DefaultLinearConfig creates default LinearConfig with // DefaultLinearConfig creates default LinearConfig with
// weights initiated using KaimingUniform and Bias is set to true // weights initiated using KaimingUniform and Bias is set to true
func DefaultLinearConfig() LinearConfig { func DefaultLinearConfig() *LinearConfig {
return LinearConfig{ return &LinearConfig{
WsInit: NewKaimingUniformInit(), WsInit: NewKaimingUniformInit(),
BsInit: nil, BsInit: nil,
Bias: true, Bias: true,
@ -28,8 +28,8 @@ func DefaultLinearConfig() LinearConfig {
// Linear is a linear fully-connected layer // Linear is a linear fully-connected layer
type Linear struct { type Linear struct {
Ws ts.Tensor Ws *ts.Tensor
Bs ts.Tensor Bs *ts.Tensor
} }
// NewLinear creates a new linear layer // NewLinear creates a new linear layer
@ -37,9 +37,9 @@ type Linear struct {
// inDim - input dimension (x) [input features - columns] // inDim - input dimension (x) [input features - columns]
// outDim - output dimension (y) [output features - columns] // outDim - output dimension (y) [output features - columns]
// NOTE: w will have shape{outDim, inDim}; b will have shape{outDim} // NOTE: w will have shape{outDim, inDim}; b will have shape{outDim}
func NewLinear(vs Path, inDim, outDim int64, c LinearConfig) Linear { func NewLinear(vs Path, inDim, outDim int64, c *LinearConfig) *Linear {
var bs ts.Tensor var bs *ts.Tensor
// bs has size of output dimension // bs has size of output dimension
switch c.Bias { switch c.Bias {
case false: case false:
@ -55,7 +55,7 @@ func NewLinear(vs Path, inDim, outDim int64, c LinearConfig) Linear {
} }
} }
return Linear{ return &Linear{
Ws: vs.NewVar("weight", []int64{outDim, inDim}, c.WsInit).MustT(false), Ws: vs.NewVar("weight", []int64{outDim, inDim}, c.WsInit).MustT(false),
Bs: bs, Bs: bs,
} }
@ -89,7 +89,7 @@ func NewLinear(vs Path, inDim, outDim int64, c LinearConfig) Linear {
// 1 1 1 // 1 1 1
// 1 1 1 // 1 1 1
// 1 1 1 ] // 1 1 1 ]
func (l Linear) Forward(xs ts.Tensor) (retVal ts.Tensor) { func (l *Linear) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
mul := xs.MustMatmul(l.Ws, false) mul := xs.MustMatmul(l.Ws, false)
return mul.MustAdd(l.Bs, true) return mul.MustAdd(l.Bs, true)
@ -98,7 +98,7 @@ func (l Linear) Forward(xs ts.Tensor) (retVal ts.Tensor) {
// ForwardT implements ModuleT interface for Linear layer. // ForwardT implements ModuleT interface for Linear layer.
// //
// NOTE: train param will not be used. // NOTE: train param will not be used.
func (l Linear) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) { func (l *Linear) ForwardT(xs *ts.Tensor, train bool) (retVal *ts.Tensor) {
mul := xs.MustMatmul(l.Ws, false) mul := xs.MustMatmul(l.Ws, false)
return mul.MustAdd(l.Bs, true) return mul.MustAdd(l.Bs, true)

View File

@ -10,7 +10,7 @@ import (
// Optimizer is a struct object to run gradient descent. // Optimizer is a struct object to run gradient descent.
type Optimizer struct { type Optimizer struct {
opt ts.COptimizer opt *ts.COptimizer
// variables Variables // having embedded sync.Mutex // variables Variables // having embedded sync.Mutex
variablesInOptimizer uint8 variablesInOptimizer uint8
config interface{} config interface{}
@ -18,7 +18,7 @@ type Optimizer struct {
// OptimizerConfig defines Optimizer configurations. These configs can be used to build optimizer. // OptimizerConfig defines Optimizer configurations. These configs can be used to build optimizer.
type OptimizerConfig interface { type OptimizerConfig interface {
buildCOpt(lr float64) (retVal ts.COptimizer, err error) buildCOpt(lr float64) (*ts.COptimizer, error)
// Build builds an optimizer with the specified learning rate handling variables stored in `vs`. // Build builds an optimizer with the specified learning rate handling variables stored in `vs`.
// //
@ -29,11 +29,11 @@ type OptimizerConfig interface {
// (config AdamOptimizerConfig) Build(vs VarStore, lr float64) (retVal Optimizer, err error){ // (config AdamOptimizerConfig) Build(vs VarStore, lr float64) (retVal Optimizer, err error){
// return defaultBuild(config, vs, lr) // return defaultBuild(config, vs, lr)
// } // }
Build(vs VarStore, lr float64) (retVal Optimizer, err error) Build(vs *VarStore, lr float64) (*Optimizer, error)
} }
// defaultBuild is `default` Build method for OptimizerConfig interface // defaultBuild is `default` Build method for OptimizerConfig interface
func defaultBuild(config OptimizerConfig, vs VarStore, lr float64) (retVal Optimizer, err error) { func defaultBuild(config OptimizerConfig, vs *VarStore, lr float64) (retVal *Optimizer, err error) {
opt, err := config.buildCOpt(lr) opt, err := config.buildCOpt(lr)
if err != nil { if err != nil {
@ -43,7 +43,7 @@ func defaultBuild(config OptimizerConfig, vs VarStore, lr float64) (retVal Optim
var parameters []ts.Tensor var parameters []ts.Tensor
for _, v := range vs.Vars.TrainableVariables { for _, v := range vs.Vars.TrainableVariables {
param := v.MustShallowClone() param := v.MustShallowClone()
parameters = append(parameters, param) parameters = append(parameters, *param)
} }
if len(vs.Vars.TrainableVariables) > 0 { if len(vs.Vars.TrainableVariables) > 0 {
@ -54,7 +54,7 @@ func defaultBuild(config OptimizerConfig, vs VarStore, lr float64) (retVal Optim
// TODO: should we clone or copy? // TODO: should we clone or copy?
return Optimizer{ return &Optimizer{
opt: opt, opt: opt,
// variables: vs.Vars, // variables: vs.Vars,
variablesInOptimizer: uint8(len(vs.Vars.TrainableVariables)), variablesInOptimizer: uint8(len(vs.Vars.TrainableVariables)),
@ -74,8 +74,8 @@ type SGDConfig struct {
} }
// DefaultSGDConfig creates SGDConfig with default values. // DefaultSGDConfig creates SGDConfig with default values.
func DefaultSGDConfig() SGDConfig { func DefaultSGDConfig() *SGDConfig {
return SGDConfig{ return &SGDConfig{
Momentum: 0.0, Momentum: 0.0,
Dampening: 0.0, Dampening: 0.0,
Wd: 0.0, Wd: 0.0,
@ -84,8 +84,8 @@ func DefaultSGDConfig() SGDConfig {
} }
// NewSGD creates the configuration for a SGD optimizer with specified values // NewSGD creates the configuration for a SGD optimizer with specified values
func NewSGDConfig(momentum, dampening, wd float64, nesterov bool) (retVal SGDConfig) { func NewSGDConfig(momentum, dampening, wd float64, nesterov bool) *SGDConfig {
return SGDConfig{ return &SGDConfig{
Momentum: momentum, Momentum: momentum,
Dampening: dampening, Dampening: dampening,
Wd: wd, Wd: wd,
@ -94,11 +94,11 @@ func NewSGDConfig(momentum, dampening, wd float64, nesterov bool) (retVal SGDCon
} }
// Implement OptimizerConfig interface for SGDConfig // Implement OptimizerConfig interface for SGDConfig
func (c SGDConfig) buildCOpt(lr float64) (retVal ts.COptimizer, err error) { func (c *SGDConfig) buildCOpt(lr float64) (*ts.COptimizer, error) {
return ts.Sgd(lr, c.Momentum, c.Dampening, c.Wd, c.Nesterov) return ts.Sgd(lr, c.Momentum, c.Dampening, c.Wd, c.Nesterov)
} }
func (c SGDConfig) Build(vs VarStore, lr float64) (retVal Optimizer, err error) { func (c *SGDConfig) Build(vs *VarStore, lr float64) (*Optimizer, error) {
return defaultBuild(c, vs, lr) return defaultBuild(c, vs, lr)
} }
@ -112,8 +112,8 @@ type AdamConfig struct {
} }
// DefaultAdamConfig creates AdamConfig with default values // DefaultAdamConfig creates AdamConfig with default values
func DefaultAdamConfig() AdamConfig { func DefaultAdamConfig() *AdamConfig {
return AdamConfig{ return &AdamConfig{
Beta1: 0.9, Beta1: 0.9,
Beta2: 0.999, Beta2: 0.999,
Wd: 0.0, Wd: 0.0,
@ -121,8 +121,8 @@ func DefaultAdamConfig() AdamConfig {
} }
// NewAdamConfig creates AdamConfig with specified values // NewAdamConfig creates AdamConfig with specified values
func NewAdamConfig(beta1, beta2, wd float64) AdamConfig { func NewAdamConfig(beta1, beta2, wd float64) *AdamConfig {
return AdamConfig{ return &AdamConfig{
Beta1: beta1, Beta1: beta1,
Beta2: beta2, Beta2: beta2,
Wd: wd, Wd: wd,
@ -130,11 +130,11 @@ func NewAdamConfig(beta1, beta2, wd float64) AdamConfig {
} }
// Implement OptimizerConfig interface for AdamConfig // Implement OptimizerConfig interface for AdamConfig
func (c AdamConfig) buildCOpt(lr float64) (retVal ts.COptimizer, err error) { func (c *AdamConfig) buildCOpt(lr float64) (*ts.COptimizer, error) {
return ts.Adam(lr, c.Beta1, c.Beta2, c.Wd) return ts.Adam(lr, c.Beta1, c.Beta2, c.Wd)
} }
func (c AdamConfig) Build(vs VarStore, lr float64) (retVal Optimizer, err error) { func (c *AdamConfig) Build(vs *VarStore, lr float64) (*Optimizer, error) {
return defaultBuild(c, vs, lr) return defaultBuild(c, vs, lr)
} }
@ -150,8 +150,8 @@ type RMSPropConfig struct {
} }
// DefaultAdamConfig creates AdamConfig with default values // DefaultAdamConfig creates AdamConfig with default values
func DefaultRMSPropConfig() RMSPropConfig { func DefaultRMSPropConfig() *RMSPropConfig {
return RMSPropConfig{ return &RMSPropConfig{
Alpha: 0.99, Alpha: 0.99,
Eps: 1e-8, Eps: 1e-8,
Wd: 0.0, Wd: 0.0,
@ -161,8 +161,8 @@ func DefaultRMSPropConfig() RMSPropConfig {
} }
// NewRMSPropConfig creates RMSPropConfig with specified values // NewRMSPropConfig creates RMSPropConfig with specified values
func NewRMSPropConfig(alpha, eps, wd, momentum float64, centered bool) RMSPropConfig { func NewRMSPropConfig(alpha, eps, wd, momentum float64, centered bool) *RMSPropConfig {
return RMSPropConfig{ return &RMSPropConfig{
Alpha: alpha, Alpha: alpha,
Eps: eps, Eps: eps,
Wd: wd, Wd: wd,
@ -172,11 +172,11 @@ func NewRMSPropConfig(alpha, eps, wd, momentum float64, centered bool) RMSPropCo
} }
// Implement OptimizerConfig interface for RMSPropConfig // Implement OptimizerConfig interface for RMSPropConfig
func (c RMSPropConfig) buildCOpt(lr float64) (retVal ts.COptimizer, err error) { func (c *RMSPropConfig) buildCOpt(lr float64) (*ts.COptimizer, error) {
return ts.RmsProp(lr, c.Alpha, c.Eps, c.Wd, c.Momentum, c.Centered) return ts.RmsProp(lr, c.Alpha, c.Eps, c.Wd, c.Momentum, c.Centered)
} }
func (c RMSPropConfig) Build(vs VarStore, lr float64) (retVal Optimizer, err error) { func (c *RMSPropConfig) Build(vs *VarStore, lr float64) (*Optimizer, error) {
return defaultBuild(c, vs, lr) return defaultBuild(c, vs, lr)
} }
@ -229,7 +229,7 @@ func (opt *Optimizer) Step() {
} }
// BackwardStep applies a backward step pass, update the gradients, and performs an optimization step. // BackwardStep applies a backward step pass, update the gradients, and performs an optimization step.
func (opt *Optimizer) BackwardStep(loss ts.Tensor) { func (opt *Optimizer) BackwardStep(loss *ts.Tensor) {
opt.addMissingVariables() opt.addMissingVariables()
@ -250,7 +250,7 @@ func (opt *Optimizer) BackwardStep(loss ts.Tensor) {
// BackwardStepClip applies a backward step pass, update the gradients, and performs an optimization step. // BackwardStepClip applies a backward step pass, update the gradients, and performs an optimization step.
// //
// The gradients are clipped based on `max` before being applied. // The gradients are clipped based on `max` before being applied.
func (opt *Optimizer) BackwardStepClip(loss ts.Tensor, max float64) { func (opt *Optimizer) BackwardStepClip(loss *ts.Tensor, max float64) {
opt.addMissingVariables() opt.addMissingVariables()
err := opt.opt.ZeroGrad() err := opt.opt.ZeroGrad()

View File

@ -15,33 +15,33 @@ type RNN interface {
// Applies a single step of the recurrent network. // Applies a single step of the recurrent network.
// //
// The input should have dimensions [batch_size, features]. // The input should have dimensions [batch_size, features].
Step(input ts.Tensor, inState State) State Step(input *ts.Tensor, inState State) State
// Applies multiple steps of the recurrent network. // Applies multiple steps of the recurrent network.
// //
// The input should have dimensions [batch_size, seq_len, features]. // The input should have dimensions [batch_size, seq_len, features].
// The initial state is the result of applying zero_state. // The initial state is the result of applying zero_state.
Seq(input ts.Tensor) (ts.Tensor, State) Seq(input *ts.Tensor) (*ts.Tensor, State)
// Applies multiple steps of the recurrent network. // Applies multiple steps of the recurrent network.
// //
// The input should have dimensions [batch_size, seq_len, features]. // The input should have dimensions [batch_size, seq_len, features].
SeqInit(input ts.Tensor, inState State) (ts.Tensor, State) SeqInit(input *ts.Tensor, inState State) (*ts.Tensor, State)
} }
// The state for a LSTM network, this contains two tensors. // The state for a LSTM network, this contains two tensors.
type LSTMState struct { type LSTMState struct {
Tensor1 ts.Tensor Tensor1 *ts.Tensor
Tensor2 ts.Tensor Tensor2 *ts.Tensor
} }
// The hidden state vector, which is also the output of the LSTM. // The hidden state vector, which is also the output of the LSTM.
func (ls LSTMState) H() (retVal ts.Tensor) { func (ls *LSTMState) H() *ts.Tensor {
return ls.Tensor1.MustShallowClone() return ls.Tensor1.MustShallowClone()
} }
// The cell state vector. // The cell state vector.
func (ls LSTMState) C() (retVal ts.Tensor) { func (ls *LSTMState) C() *ts.Tensor {
return ls.Tensor2.MustShallowClone() return ls.Tensor2.MustShallowClone()
} }
@ -57,8 +57,8 @@ type RNNConfig struct {
} }
// Default creates default RNN configuration // Default creates default RNN configuration
func DefaultRNNConfig() RNNConfig { func DefaultRNNConfig() *RNNConfig {
return RNNConfig{ return &RNNConfig{
HasBiases: true, HasBiases: true,
NumLayers: 1, NumLayers: 1,
Dropout: float64(0.0), Dropout: float64(0.0),
@ -74,12 +74,12 @@ func DefaultRNNConfig() RNNConfig {
type LSTM struct { type LSTM struct {
flatWeights []ts.Tensor flatWeights []ts.Tensor
hiddenDim int64 hiddenDim int64
config RNNConfig config *RNNConfig
device gotch.Device device gotch.Device
} }
// NewLSTM creates a LSTM layer. // NewLSTM creates a LSTM layer.
func NewLSTM(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal LSTM) { func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) *LSTM {
var numDirections int64 = 1 var numDirections int64 = 1
if cfg.Bidirectional { if cfg.Bidirectional {
@ -100,7 +100,7 @@ func NewLSTM(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal LSTM) {
bIh := vs.Zeros("b_ih", []int64{gateDim}) bIh := vs.Zeros("b_ih", []int64{gateDim})
bHh := vs.Zeros("b_hh", []int64{gateDim}) bHh := vs.Zeros("b_hh", []int64{gateDim})
flatWeights = append(flatWeights, wIh, wHh, bIh, bHh) flatWeights = append(flatWeights, *wIh, *wHh, *bIh, *bHh)
} }
} }
@ -112,7 +112,7 @@ func NewLSTM(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal LSTM) {
ts.Must_CudnnRnnFlattenWeight(flatWeights, 4, inDim, 2, hiddenDim, cfg.NumLayers, cfg.BatchFirst, cfg.Bidirectional) ts.Must_CudnnRnnFlattenWeight(flatWeights, 4, inDim, 2, hiddenDim, cfg.NumLayers, cfg.BatchFirst, cfg.Bidirectional)
} }
return LSTM{ return &LSTM{
flatWeights: flatWeights, flatWeights: flatWeights,
hiddenDim: hiddenDim, hiddenDim: hiddenDim,
config: cfg, config: cfg,
@ -124,7 +124,7 @@ func NewLSTM(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal LSTM) {
// Implement RNN interface for LSTM: // Implement RNN interface for LSTM:
// ================================= // =================================
func (l LSTM) ZeroState(batchDim int64) (retVal State) { func (l *LSTM) ZeroState(batchDim int64) (retVal State) {
var numDirections int64 = 1 var numDirections int64 = 1
if l.config.Bidirectional { if l.config.Bidirectional {
numDirections = 2 numDirections = 2
@ -144,7 +144,7 @@ func (l LSTM) ZeroState(batchDim int64) (retVal State) {
return retVal return retVal
} }
func (l LSTM) Step(input ts.Tensor, inState State) (retVal State) { func (l *LSTM) Step(input *ts.Tensor, inState State) (retVal State) {
ip := input.MustUnsqueeze(1, false) ip := input.MustUnsqueeze(1, false)
output, state := l.SeqInit(ip, inState) output, state := l.SeqInit(ip, inState)
@ -156,7 +156,7 @@ func (l LSTM) Step(input ts.Tensor, inState State) (retVal State) {
return state return state
} }
func (l LSTM) Seq(input ts.Tensor) (output ts.Tensor, state State) { func (l *LSTM) Seq(input *ts.Tensor) (output *ts.Tensor, state State) {
batchDim := input.MustSize()[0] batchDim := input.MustSize()[0]
inState := l.ZeroState(batchDim) inState := l.ZeroState(batchDim)
@ -169,9 +169,9 @@ func (l LSTM) Seq(input ts.Tensor) (output ts.Tensor, state State) {
return output, state return output, state
} }
func (l LSTM) SeqInit(input ts.Tensor, inState State) (ts.Tensor, State) { func (l *LSTM) SeqInit(input *ts.Tensor, inState State) (*ts.Tensor, State) {
output, h, c := input.MustLstm([]ts.Tensor{inState.(LSTMState).Tensor1, inState.(LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst) output, h, c := input.MustLstm([]ts.Tensor{*inState.(LSTMState).Tensor1, *inState.(LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst)
return output, LSTMState{ return output, LSTMState{
Tensor1: h, Tensor1: h,
@ -181,10 +181,10 @@ func (l LSTM) SeqInit(input ts.Tensor, inState State) (ts.Tensor, State) {
// GRUState is a GRU state. It contains a single tensor. // GRUState is a GRU state. It contains a single tensor.
type GRUState struct { type GRUState struct {
Tensor ts.Tensor Tensor *ts.Tensor
} }
func (gs GRUState) Value() ts.Tensor { func (gs *GRUState) Value() *ts.Tensor {
return gs.Tensor return gs.Tensor
} }
@ -194,12 +194,12 @@ func (gs GRUState) Value() ts.Tensor {
type GRU struct { type GRU struct {
flatWeights []ts.Tensor flatWeights []ts.Tensor
hiddenDim int64 hiddenDim int64
config RNNConfig config *RNNConfig
device gotch.Device device gotch.Device
} }
// NewGRU create a new GRU layer // NewGRU create a new GRU layer
func NewGRU(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) { func NewGRU(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) (retVal *GRU) {
var numDirections int64 = 1 var numDirections int64 = 1
if cfg.Bidirectional { if cfg.Bidirectional {
numDirections = 2 numDirections = 2
@ -222,7 +222,7 @@ func NewGRU(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
bIh := vs.Zeros("b_ih", []int64{gateDim}) bIh := vs.Zeros("b_ih", []int64{gateDim})
bHh := vs.Zeros("b_hh", []int64{gateDim}) bHh := vs.Zeros("b_hh", []int64{gateDim})
flatWeights = append(flatWeights, wIh, wHh, bIh, bHh) flatWeights = append(flatWeights, *wIh, *wHh, *bIh, *bHh)
} }
} }
@ -232,7 +232,7 @@ func NewGRU(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
ts.Must_CudnnRnnFlattenWeight(flatWeights, 4, inDim, 3, hiddenDim, cfg.NumLayers, cfg.BatchFirst, cfg.Bidirectional) ts.Must_CudnnRnnFlattenWeight(flatWeights, 4, inDim, 3, hiddenDim, cfg.NumLayers, cfg.BatchFirst, cfg.Bidirectional)
} }
return GRU{ return &GRU{
flatWeights: flatWeights, flatWeights: flatWeights,
hiddenDim: hiddenDim, hiddenDim: hiddenDim,
config: cfg, config: cfg,
@ -243,7 +243,7 @@ func NewGRU(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
// Implement RNN interface for GRU: // Implement RNN interface for GRU:
// ================================ // ================================
func (g GRU) ZeroState(batchDim int64) (retVal State) { func (g *GRU) ZeroState(batchDim int64) (retVal State) {
var numDirections int64 = 1 var numDirections int64 = 1
if g.config.Bidirectional { if g.config.Bidirectional {
numDirections = 2 numDirections = 2
@ -257,7 +257,7 @@ func (g GRU) ZeroState(batchDim int64) (retVal State) {
return GRUState{Tensor: tensor} return GRUState{Tensor: tensor}
} }
func (g GRU) Step(input ts.Tensor, inState State) (retVal State) { func (g *GRU) Step(input *ts.Tensor, inState State) (retVal State) {
unsqueezedInput := input.MustUnsqueeze(1, false) unsqueezedInput := input.MustUnsqueeze(1, false)
output, state := g.SeqInit(unsqueezedInput, inState) output, state := g.SeqInit(unsqueezedInput, inState)
@ -269,7 +269,7 @@ func (g GRU) Step(input ts.Tensor, inState State) (retVal State) {
return state return state
} }
func (g GRU) Seq(input ts.Tensor) (output ts.Tensor, state State) { func (g *GRU) Seq(input *ts.Tensor) (output *ts.Tensor, state State) {
batchDim := input.MustSize()[0] batchDim := input.MustSize()[0]
inState := g.ZeroState(batchDim) inState := g.ZeroState(batchDim)
@ -281,7 +281,7 @@ func (g GRU) Seq(input ts.Tensor) (output ts.Tensor, state State) {
return output, state return output, state
} }
func (g GRU) SeqInit(input ts.Tensor, inState State) (ts.Tensor, State) { func (g *GRU) SeqInit(input *ts.Tensor, inState State) (*ts.Tensor, State) {
output, h := input.MustGru(inState.(GRUState).Tensor, g.flatWeights, g.config.HasBiases, g.config.NumLayers, g.config.Dropout, g.config.Train, g.config.Bidirectional, g.config.BatchFirst) output, h := input.MustGru(inState.(GRUState).Tensor, g.flatWeights, g.config.HasBiases, g.config.NumLayers, g.config.Dropout, g.config.Train, g.config.Bidirectional, g.config.BatchFirst)

View File

@ -10,7 +10,7 @@ import (
ts "github.com/sugarme/gotch/tensor" ts "github.com/sugarme/gotch/tensor"
) )
func gruTest(rnnConfig nn.RNNConfig, t *testing.T) { func gruTest(rnnConfig *nn.RNNConfig, t *testing.T) {
var ( var (
batchDim int64 = 5 batchDim int64 = 5
@ -47,7 +47,7 @@ func gruTest(rnnConfig nn.RNNConfig, t *testing.T) {
input = ts.MustRandn([]int64{batchDim, seqLen, inputDim}, gotch.Float, gotch.CPU) input = ts.MustRandn([]int64{batchDim, seqLen, inputDim}, gotch.Float, gotch.CPU)
output, _ = gru.Seq(input) output, _ = gru.Seq(input)
wantSeq := []int64{batchDim, seqLen, outputDim * numDirections} wantSeq := []int64{batchDim, seqLen, outputDim * numDirections}
gotSeq := output.(ts.Tensor).MustSize() gotSeq := output.(*ts.Tensor).MustSize()
if !reflect.DeepEqual(wantSeq, gotSeq) { if !reflect.DeepEqual(wantSeq, gotSeq) {
fmt.Println("Seq test:") fmt.Println("Seq test:")
@ -75,7 +75,7 @@ func TestGRU(t *testing.T) {
gruTest(cfg, t) gruTest(cfg, t)
} }
func lstmTest(rnnConfig nn.RNNConfig, t *testing.T) { func lstmTest(rnnConfig *nn.RNNConfig, t *testing.T) {
var ( var (
batchDim int64 = 5 batchDim int64 = 5
@ -121,7 +121,7 @@ func lstmTest(rnnConfig nn.RNNConfig, t *testing.T) {
output, _ = lstm.Seq(input) output, _ = lstm.Seq(input)
wantSeq := []int64{batchDim, seqLen, outputDim * numDirections} wantSeq := []int64{batchDim, seqLen, outputDim * numDirections}
gotSeq := output.(ts.Tensor).MustSize() gotSeq := output.(*ts.Tensor).MustSize()
if !reflect.DeepEqual(wantSeq, gotSeq) { if !reflect.DeepEqual(wantSeq, gotSeq) {
fmt.Println("Seq test:") fmt.Println("Seq test:")

View File

@ -14,15 +14,15 @@ type Sequential struct {
} }
// Seq creates a new empty sequential layer // Seq creates a new empty sequential layer
func Seq() Sequential { func Seq() *Sequential {
return Sequential{layers: make([]ts.Module, 0)} return &Sequential{layers: make([]ts.Module, 0)}
} }
// Sequential methods: // Sequential methods:
//==================== //====================
// Len returns number of sub-layers embedded in this layer // Len returns number of sub-layers embedded in this layer
func (s Sequential) Len() (retVal int64) { func (s *Sequential) Len() (retVal int64) {
return int64(len(s.layers)) return int64(len(s.layers))
} }
@ -47,7 +47,7 @@ func (s *Sequential) AddFn(fn ts.Module) {
} }
// ForwardAll applies the forward pass and returns the output for each layer. // ForwardAll applies the forward pass and returns the output for each layer.
func (s *Sequential) ForwardAll(xs ts.Tensor, opts ...uint8) (retVal []ts.Tensor) { func (s *Sequential) ForwardAll(xs *ts.Tensor, opts ...uint8) (retVal []ts.Tensor) {
var n uint8 = uint8(len(s.layers)) var n uint8 = uint8(len(s.layers))
if len(opts) > 0 { if len(opts) > 0 {
@ -55,11 +55,11 @@ func (s *Sequential) ForwardAll(xs ts.Tensor, opts ...uint8) (retVal []ts.Tensor
} }
if s.IsEmpty() { if s.IsEmpty() {
return []ts.Tensor{xs.MustShallowClone()} return []ts.Tensor{*xs.MustShallowClone()}
} }
for i := 0; i < int(n); i++ { for i := 0; i < int(n); i++ {
retVal = append(retVal, s.layers[i].Forward(xs)) retVal = append(retVal, *s.layers[i].Forward(xs))
} }
return retVal return retVal
@ -76,7 +76,7 @@ func WithUint8(n uint8) func() uint8 {
// ========================================== // ==========================================
// Forward implements Module interface for Sequential // Forward implements Module interface for Sequential
func (s *Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) { func (s *Sequential) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
if s.IsEmpty() { if s.IsEmpty() {
return xs.MustShallowClone() return xs.MustShallowClone()
} }
@ -85,12 +85,12 @@ func (s *Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
outs := make([]ts.Tensor, len(s.layers)) outs := make([]ts.Tensor, len(s.layers))
for i := 0; i < len(s.layers); i++ { for i := 0; i < len(s.layers); i++ {
if i == 0 { if i == 0 {
outs[0] = s.layers[i].Forward(xs) outs[0] = *s.layers[i].Forward(xs)
defer outs[0].MustDrop() defer outs[0].MustDrop()
} else if i == len(s.layers)-1 { } else if i == len(s.layers)-1 {
return s.layers[i].Forward(outs[i-1]) return s.layers[i].Forward(&outs[i-1])
} else { } else {
outs[i] = s.layers[i].Forward(outs[i-1]) outs[i] = *s.layers[i].Forward(&outs[i-1])
defer outs[i].MustDrop() defer outs[i].MustDrop()
} }
} }
@ -104,8 +104,8 @@ type SequentialT struct {
} }
/// SeqT creates a new empty sequential layer. /// SeqT creates a new empty sequential layer.
func SeqT() SequentialT { func SeqT() *SequentialT {
return SequentialT{ return &SequentialT{
layers: make([]ts.ModuleT, 0), layers: make([]ts.ModuleT, 0),
} }
} }
@ -140,7 +140,7 @@ func (s *SequentialT) IsEmpty() (retVal bool) {
* return currTs * return currTs
* } * }
* */ * */
func (s SequentialT) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) { func (s *SequentialT) ForwardT(xs *ts.Tensor, train bool) (retVal *ts.Tensor) {
if s.IsEmpty() { if s.IsEmpty() {
return xs.MustShallowClone() return xs.MustShallowClone()
} }
@ -149,12 +149,12 @@ func (s SequentialT) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
outs := make([]ts.Tensor, len(s.layers)) outs := make([]ts.Tensor, len(s.layers))
for i := 0; i < len(s.layers); i++ { for i := 0; i < len(s.layers); i++ {
if i == 0 { if i == 0 {
outs[0] = s.layers[i].ForwardT(xs, train) outs[0] = *s.layers[i].ForwardT(xs, train)
defer outs[0].MustDrop() defer outs[0].MustDrop()
} else if i == len(s.layers)-1 { } else if i == len(s.layers)-1 {
return s.layers[i].ForwardT(outs[i-1], train) return s.layers[i].ForwardT(&outs[i-1], train)
} else { } else {
outs[i] = s.layers[i].ForwardT(outs[i-1], train) outs[i] = *s.layers[i].ForwardT(&outs[i-1], train)
defer outs[i].MustDrop() defer outs[i].MustDrop()
} }
} }
@ -187,7 +187,7 @@ func (s *SequentialT) AddFnT(fn ts.ModuleT) {
} }
// ForwardAll applies the forward pass and returns the output for each layer. // ForwardAll applies the forward pass and returns the output for each layer.
func (s *SequentialT) ForwardAllT(xs ts.Tensor, train bool, opts ...uint8) (retVal []ts.Tensor) { func (s *SequentialT) ForwardAllT(xs *ts.Tensor, train bool, opts ...uint8) (retVal []ts.Tensor) {
var n uint8 = uint8(len(s.layers)) var n uint8 = uint8(len(s.layers))
if len(opts) > 0 { if len(opts) > 0 {
@ -195,13 +195,13 @@ func (s *SequentialT) ForwardAllT(xs ts.Tensor, train bool, opts ...uint8) (retV
} }
if s.IsEmpty() { if s.IsEmpty() {
return []ts.Tensor{xs.MustShallowClone()} return []ts.Tensor{*xs.MustShallowClone()}
} }
currTs := xs currTs := xs
for i := 0; i < int(n); i++ { for i := 0; i < int(n); i++ {
res := s.layers[i].ForwardT(currTs, train) res := s.layers[i].ForwardT(currTs, train)
retVal = append(retVal, res) retVal = append(retVal, *res)
currTs = res currTs = res
} }
@ -214,15 +214,15 @@ func (s *SequentialT) ForwardAllT(xs ts.Tensor, train bool, opts ...uint8) (retV
// Ref. https://stackoverflow.com/a/42182987 // Ref. https://stackoverflow.com/a/42182987
// NOTE: Specifically, `ForwardWith` is used to wrap anonymous function // NOTE: Specifically, `ForwardWith` is used to wrap anonymous function
// as input parameter of `AddFn` Sequential method. // as input parameter of `AddFn` Sequential method.
type ForwardWith func(ts.Tensor) ts.Tensor type ForwardWith func(*ts.Tensor) *ts.Tensor
func (fw ForwardWith) Forward(xs ts.Tensor) ts.Tensor { func (fw ForwardWith) Forward(xs *ts.Tensor) *ts.Tensor {
return fw(xs) return fw(xs)
} }
type ForwardTWith func(ts.Tensor, bool) ts.Tensor type ForwardTWith func(*ts.Tensor, bool) *ts.Tensor
func (fw ForwardTWith) ForwardT(xs ts.Tensor, train bool) ts.Tensor { func (fw ForwardTWith) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
return fw(xs, train) return fw(xs, train)
} }
@ -235,7 +235,7 @@ func (fw ForwardTWith) ForwardT(xs ts.Tensor, train bool) ts.Tensor {
// This seems not working in Go. // This seems not working in Go.
// There 2 ways to get around. One is freeze VarStore, the other is // There 2 ways to get around. One is freeze VarStore, the other is
// set manually set AutoGrad at `loss` tensor. I.e., `loss = loss.MustSetRequiresGrad(true)` // set manually set AutoGrad at `loss` tensor. I.e., `loss = loss.MustSetRequiresGrad(true)`
func BatchAccuracyForLogits(vs VarStore, m ts.ModuleT, xs, ys ts.Tensor, d gotch.Device, batchSize int) (retVal float64) { func BatchAccuracyForLogits(vs *VarStore, m ts.ModuleT, xs, ys *ts.Tensor, d gotch.Device, batchSize int) (retVal float64) {
var ( var (
sumAccuracy float64 = 0.0 sumAccuracy float64 = 0.0
@ -272,7 +272,7 @@ func BatchAccuracyForLogits(vs VarStore, m ts.ModuleT, xs, ys ts.Tensor, d gotch
// BatchAccuracyForLogitIdx is an alternative of BatchAccuracyForLogits to // BatchAccuracyForLogitIdx is an alternative of BatchAccuracyForLogits to
// calculate accuracy for specified batch on module weight. It uses tensor // calculate accuracy for specified batch on module weight. It uses tensor
// indexing instead of Iter2 // indexing instead of Iter2
func BatchAccuracyForLogitsIdx(vs VarStore, m ts.ModuleT, xs, ys ts.Tensor, d gotch.Device, batchSize int) (retVal float64) { func BatchAccuracyForLogitsIdx(vs *VarStore, m ts.ModuleT, xs, ys *ts.Tensor, d gotch.Device, batchSize int) (retVal float64) {
var ( var (
sumAccuracy float64 = 0.0 sumAccuracy float64 = 0.0
sampleCount float64 = 0.0 sampleCount float64 = 0.0

View File

@ -14,8 +14,8 @@ type EmbeddingConfig struct {
PaddingIdx int64 PaddingIdx int64
} }
func DefaultEmbeddingConfig() EmbeddingConfig { func DefaultEmbeddingConfig() *EmbeddingConfig {
return EmbeddingConfig{ return &EmbeddingConfig{
Sparse: false, Sparse: false,
ScaleGradByFreq: false, ScaleGradByFreq: false,
WsInit: NewRandnInit(0.0, 1.0), WsInit: NewRandnInit(0.0, 1.0),
@ -28,13 +28,13 @@ func DefaultEmbeddingConfig() EmbeddingConfig {
// An embedding layer acts as a simple lookup table that stores embeddings. // An embedding layer acts as a simple lookup table that stores embeddings.
// This is commonly used to store word embeddings. // This is commonly used to store word embeddings.
type Embedding struct { type Embedding struct {
Ws ts.Tensor Ws *ts.Tensor
config EmbeddingConfig config *EmbeddingConfig
} }
// NewEmbedding creates a new Embedding // NewEmbedding creates a new Embedding
func NewEmbedding(vs Path, numEmbeddings int64, embeddingDim int64, config EmbeddingConfig) Embedding { func NewEmbedding(vs *Path, numEmbeddings int64, embeddingDim int64, config *EmbeddingConfig) *Embedding {
return Embedding{ return &Embedding{
Ws: vs.NewVar("weight", []int64{numEmbeddings, embeddingDim}, config.WsInit), Ws: vs.NewVar("weight", []int64{numEmbeddings, embeddingDim}, config.WsInit),
config: config, config: config,
} }
@ -44,11 +44,11 @@ func NewEmbedding(vs Path, numEmbeddings int64, embeddingDim int64, config Embed
// ========================================= // =========================================
// Forward implements Module interface for Embedding // Forward implements Module interface for Embedding
func (e Embedding) Forward(xs ts.Tensor) (retVal ts.Tensor) { func (e *Embedding) Forward(xs *ts.Tensor) *ts.Tensor {
return ts.MustEmbedding(e.Ws, xs, e.config.PaddingIdx, e.config.ScaleGradByFreq, e.config.Sparse) return ts.MustEmbedding(e.Ws, xs, e.config.PaddingIdx, e.config.ScaleGradByFreq, e.config.Sparse)
} }
// ForwardT implements ModuleT interface for Embedding // ForwardT implements ModuleT interface for Embedding
func (e Embedding) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) { func (e *Embedding) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
return ts.MustEmbedding(e.Ws, xs, e.config.PaddingIdx, e.config.ScaleGradByFreq, e.config.Sparse) return ts.MustEmbedding(e.Ws, xs, e.config.PaddingIdx, e.config.ScaleGradByFreq, e.config.Sparse)
} }

View File

@ -9,7 +9,7 @@ import (
ts "github.com/sugarme/gotch/tensor" ts "github.com/sugarme/gotch/tensor"
) )
func embeddingTest(embeddingConfig nn.EmbeddingConfig, t *testing.T) { func embeddingTest(embeddingConfig *nn.EmbeddingConfig, t *testing.T) {
var ( var (
batchDim int64 = 5 batchDim int64 = 5

View File

@ -20,7 +20,7 @@ const SEP = "."
// however the tensor is not set to require gradients. // however the tensor is not set to require gradients.
type Variables struct { type Variables struct {
mutex *sync.Mutex mutex *sync.Mutex
NamedVariables map[string]ts.Tensor NamedVariables map[string]*ts.Tensor
TrainableVariables []ts.Tensor TrainableVariables []ts.Tensor
} }
@ -45,14 +45,14 @@ type Entry struct {
} }
// NewVarStore creates a new variable store located on the specified device // NewVarStore creates a new variable store located on the specified device
func NewVarStore(device gotch.Device) VarStore { func NewVarStore(device gotch.Device) *VarStore {
variables := Variables{ variables := Variables{
mutex: &sync.Mutex{}, mutex: &sync.Mutex{},
NamedVariables: make(map[string]ts.Tensor, 0), NamedVariables: make(map[string]*ts.Tensor, 0),
TrainableVariables: make([]ts.Tensor, 0), TrainableVariables: make([]ts.Tensor, 0),
} }
return VarStore{ return &VarStore{
device: device, device: device,
Vars: variables, Vars: variables,
} }
@ -94,7 +94,7 @@ func (vs *VarStore) TrainableVariables() (retVal []ts.Tensor) {
retVal = vs.Vars.TrainableVariables retVal = vs.Vars.TrainableVariables
for _, t := range vs.Vars.TrainableVariables { for _, t := range vs.Vars.TrainableVariables {
retVal = append(retVal, t.MustShallowClone()) retVal = append(retVal, *t.MustShallowClone())
} }
return retVal return retVal
@ -108,7 +108,7 @@ func (vs *VarStore) Variables() (retVal map[string]ts.Tensor) {
retVal = make(map[string]ts.Tensor, 0) retVal = make(map[string]ts.Tensor, 0)
for k, v := range vs.Vars.NamedVariables { for k, v := range vs.Vars.NamedVariables {
retVal[k] = v.MustShallowClone() retVal[k] = *v.MustShallowClone()
} }
return retVal return retVal
@ -119,8 +119,8 @@ func (vs *VarStore) Variables() (retVal map[string]ts.Tensor) {
// NOTE: Variables are named and organized using paths. This function returns // NOTE: Variables are named and organized using paths. This function returns
// the top level path for the var store and can be combined with '/' // the top level path for the var store and can be combined with '/'
// to create sub-paths. // to create sub-paths.
func (vs *VarStore) Root() (retVal Path) { func (vs *VarStore) Root() *Path {
return Path{ return &Path{
path: []string{}, path: []string{},
varstore: vs, varstore: vs,
} }
@ -130,7 +130,7 @@ func (vs *VarStore) Root() (retVal Path) {
// //
// NOTE: Weight values for all the tensors currently stored in the // NOTE: Weight values for all the tensors currently stored in the
// var-store gets saved in the given file. // var-store gets saved in the given file.
func (vs *VarStore) Save(filepath string) (err error) { func (vs *VarStore) Save(filepath string) error {
vs.Vars.mutex.Lock() vs.Vars.mutex.Lock()
defer vs.Vars.mutex.Unlock() defer vs.Vars.mutex.Unlock()
@ -155,7 +155,7 @@ func (vs *VarStore) Save(filepath string) (err error) {
// for these tensors are modified. // for these tensors are modified.
// It will throw error if name of the loaded tensors can not find // It will throw error if name of the loaded tensors can not find
// in the current var-store named tensors set. // in the current var-store named tensors set.
func (vs *VarStore) Load(filepath string) (err error) { func (vs *VarStore) Load(filepath string) error {
namedTensors, err := ts.LoadMultiWithDevice(filepath, vs.device) namedTensors, err := ts.LoadMultiWithDevice(filepath, vs.device)
if err != nil { if err != nil {
return err return err
@ -163,7 +163,7 @@ func (vs *VarStore) Load(filepath string) (err error) {
var namedTensorsMap map[string]ts.Tensor = make(map[string]ts.Tensor, 0) var namedTensorsMap map[string]ts.Tensor = make(map[string]ts.Tensor, 0)
for _, namedTensor := range namedTensors { for _, namedTensor := range namedTensors {
namedTensorsMap[namedTensor.Name] = namedTensor.Tensor namedTensorsMap[namedTensor.Name] = *namedTensor.Tensor
} }
// Match and in-place copy value (update) from newly loaded tensors // Match and in-place copy value (update) from newly loaded tensors
@ -190,7 +190,7 @@ func (vs *VarStore) Load(filepath string) (err error) {
} }
ts.NoGrad(func() { ts.NoGrad(func() {
vs.Vars.NamedVariables[tsName].Copy_(currTs) vs.Vars.NamedVariables[tsName].Copy_(&currTs)
}) })
} }
return nil return nil
@ -213,7 +213,7 @@ func (vs *VarStore) LoadPartial(filepath string) (retVal []string, err error) {
return nil, err return nil, err
} }
var namedTensorsMap map[string]ts.Tensor = make(map[string]ts.Tensor, 0) var namedTensorsMap map[string]*ts.Tensor = make(map[string]*ts.Tensor, 0)
for _, namedTensor := range namedTensors { for _, namedTensor := range namedTensors {
namedTensorsMap[namedTensor.Name] = namedTensor.Tensor namedTensorsMap[namedTensor.Name] = namedTensor.Tensor
} }
@ -226,7 +226,7 @@ func (vs *VarStore) LoadPartial(filepath string) (retVal []string, err error) {
defer vs.Vars.mutex.Unlock() defer vs.Vars.mutex.Unlock()
for tsName := range vs.Vars.NamedVariables { for tsName := range vs.Vars.NamedVariables {
var currTs ts.Tensor var currTs *ts.Tensor
var ok bool var ok bool
// missing variable // missing variable
@ -320,7 +320,7 @@ func (vs *VarStore) Copy(src VarStore) (err error) {
// ============= // =============
// Sub gets a sub-path of the given path. // Sub gets a sub-path of the given path.
func (p *Path) Sub(str string) (retVal Path) { func (p *Path) Sub(str string) *Path {
if strings.Contains(str, SEP) { if strings.Contains(str, SEP) {
log.Fatalf("Sub name cannot contain %v (%v)\n", SEP, str) log.Fatalf("Sub name cannot contain %v (%v)\n", SEP, str)
@ -328,7 +328,7 @@ func (p *Path) Sub(str string) (retVal Path) {
path := p.path path := p.path
path = append(path, str) path = append(path, str)
return Path{ return &Path{
path: path, path: path,
varstore: p.varstore, varstore: p.varstore,
} }
@ -355,7 +355,7 @@ func (p *Path) getpath(name string) (retVal string) {
} }
} }
func (p *Path) add(name string, newTs ts.Tensor, trainable bool) (retVal ts.Tensor) { func (p *Path) add(name string, newTs *ts.Tensor, trainable bool) (retVal *ts.Tensor) {
path := p.getpath(name) path := p.getpath(name)
p.varstore.Vars.mutex.Lock() p.varstore.Vars.mutex.Lock()
@ -366,7 +366,7 @@ func (p *Path) add(name string, newTs ts.Tensor, trainable bool) (retVal ts.Tens
} }
var ( var (
tensor ts.Tensor tensor *ts.Tensor
err error err error
) )
if trainable { if trainable {
@ -379,7 +379,7 @@ func (p *Path) add(name string, newTs ts.Tensor, trainable bool) (retVal ts.Tens
} }
if trainable { if trainable {
p.varstore.Vars.TrainableVariables = append(p.varstore.Vars.TrainableVariables, tensor) p.varstore.Vars.TrainableVariables = append(p.varstore.Vars.TrainableVariables, *tensor)
} }
p.varstore.Vars.NamedVariables[path] = tensor p.varstore.Vars.NamedVariables[path] = tensor
@ -387,7 +387,7 @@ func (p *Path) add(name string, newTs ts.Tensor, trainable bool) (retVal ts.Tens
return tensor return tensor
} }
func (p *Path) getOrAddWithLock(name string, tensor ts.Tensor, trainable bool, variables Variables) (retVal ts.Tensor) { func (p *Path) getOrAddWithLock(name string, tensor *ts.Tensor, trainable bool, variables Variables) (retVal *ts.Tensor) {
path := p.getpath(name) path := p.getpath(name)
// if found, return it // if found, return it
@ -397,7 +397,7 @@ func (p *Path) getOrAddWithLock(name string, tensor ts.Tensor, trainable bool, v
// not found, add it // not found, add it
var err error var err error
var ttensor ts.Tensor var ttensor *ts.Tensor
if trainable { if trainable {
ttensor, err = tensor.SetRequiresGrad(true, false) ttensor, err = tensor.SetRequiresGrad(true, false)
if err != nil { if err != nil {
@ -408,7 +408,7 @@ func (p *Path) getOrAddWithLock(name string, tensor ts.Tensor, trainable bool, v
} }
if trainable { if trainable {
variables.TrainableVariables = append(variables.TrainableVariables, ttensor) variables.TrainableVariables = append(variables.TrainableVariables, *ttensor)
} }
variables.NamedVariables[path] = ttensor variables.NamedVariables[path] = ttensor
@ -422,7 +422,7 @@ func (p *Path) getOrAddWithLock(name string, tensor ts.Tensor, trainable bool, v
// has the specified shape. The variable will not be trainable so // has the specified shape. The variable will not be trainable so
// gradients will not be tracked. // gradients will not be tracked.
// The variable uses a float tensor initialized with zeros. // The variable uses a float tensor initialized with zeros.
func (p *Path) ZerosNoTrain(name string, dims []int64) (retVal ts.Tensor) { func (p *Path) ZerosNoTrain(name string, dims []int64) (retVal *ts.Tensor) {
device := p.Device() device := p.Device()
z, err := ts.Zeros(dims, gotch.Float, device) z, err := ts.Zeros(dims, gotch.Float, device)
@ -439,7 +439,7 @@ func (p *Path) ZerosNoTrain(name string, dims []int64) (retVal ts.Tensor) {
// has the specified shape. The variable will not be trainable so // has the specified shape. The variable will not be trainable so
// gradients will not be tracked. // gradients will not be tracked.
// The variable uses a float tensor initialized with ones. // The variable uses a float tensor initialized with ones.
func (p *Path) OnesNoTrain(name string, dims []int64) (retVal ts.Tensor) { func (p *Path) OnesNoTrain(name string, dims []int64) (retVal *ts.Tensor) {
device := p.Device() device := p.Device()
z, err := ts.Ones(dims, gotch.Float, device) z, err := ts.Ones(dims, gotch.Float, device)
@ -457,7 +457,7 @@ func (p *Path) OnesNoTrain(name string, dims []int64) (retVal ts.Tensor) {
// will be tracked. // will be tracked.
// The variable uses a float tensor initialized as per the // The variable uses a float tensor initialized as per the
// related argument. // related argument.
func (p *Path) NewVar(name string, dims []int64, ini Init) (retVal ts.Tensor) { func (p *Path) NewVar(name string, dims []int64, ini Init) (retVal *ts.Tensor) {
v := ini.InitTensor(dims, p.varstore.device) v := ini.InitTensor(dims, p.varstore.device)
@ -470,7 +470,7 @@ func (p *Path) NewVar(name string, dims []int64, ini Init) (retVal ts.Tensor) {
// has the specified shape. The variable is trainable, its gradient // has the specified shape. The variable is trainable, its gradient
// will be tracked. // will be tracked.
// The variable uses a float tensor initialized with zeros. // The variable uses a float tensor initialized with zeros.
func (p *Path) Zeros(name string, dims []int64) (retVal ts.Tensor) { func (p *Path) Zeros(name string, dims []int64) (retVal *ts.Tensor) {
return p.NewVar(name, dims, NewConstInit(0.0)) return p.NewVar(name, dims, NewConstInit(0.0))
} }
@ -481,7 +481,7 @@ func (p *Path) Zeros(name string, dims []int64) (retVal ts.Tensor) {
// has the specified shape. The variable is trainable, its gradient // has the specified shape. The variable is trainable, its gradient
// will be tracked. // will be tracked.
// The variable uses a float tensor initialized with ones. // The variable uses a float tensor initialized with ones.
func (p *Path) Ones(name string, dims []int64) (retVal ts.Tensor) { func (p *Path) Ones(name string, dims []int64) (retVal *ts.Tensor) {
return p.NewVar(name, dims, NewConstInit(1.0)) return p.NewVar(name, dims, NewConstInit(1.0))
} }
@ -493,7 +493,7 @@ func (p *Path) Ones(name string, dims []int64) (retVal ts.Tensor) {
// will be tracked. // will be tracked.
// The variable uses a float tensor initialized randomly using a // The variable uses a float tensor initialized randomly using a
// standard normal distribution. // standard normal distribution.
func (p *Path) RandnStandard(name string, dims []int64) (retVal ts.Tensor) { func (p *Path) RandnStandard(name string, dims []int64) (retVal *ts.Tensor) {
return p.NewVar(name, dims, NewRandnInit(0.0, 1.0)) return p.NewVar(name, dims, NewRandnInit(0.0, 1.0))
} }
@ -505,7 +505,7 @@ func (p *Path) RandnStandard(name string, dims []int64) (retVal ts.Tensor) {
// will be tracked. // will be tracked.
// The variable uses a float tensor initialized randomly using a // The variable uses a float tensor initialized randomly using a
// normal distribution with the specified mean and standard deviation. // normal distribution with the specified mean and standard deviation.
func (p *Path) Randn(name string, dims []int64, mean float64, stdev float64) (retVal ts.Tensor) { func (p *Path) Randn(name string, dims []int64, mean float64, stdev float64) (retVal *ts.Tensor) {
return p.NewVar(name, dims, NewRandnInit(mean, stdev)) return p.NewVar(name, dims, NewRandnInit(mean, stdev))
} }
@ -517,7 +517,7 @@ func (p *Path) Randn(name string, dims []int64, mean float64, stdev float64) (re
// will be tracked. // will be tracked.
// The variable uses a float tensor initialized randomly using a // The variable uses a float tensor initialized randomly using a
// uniform distribution between the specified bounds. // uniform distribution between the specified bounds.
func (p *Path) Uniform(name string, dims []int64, lo, up float64) (retVal ts.Tensor) { func (p *Path) Uniform(name string, dims []int64, lo, up float64) (retVal *ts.Tensor) {
return p.NewVar(name, dims, NewUniformInit(lo, up)) return p.NewVar(name, dims, NewUniformInit(lo, up))
} }
@ -529,7 +529,7 @@ func (p *Path) Uniform(name string, dims []int64, lo, up float64) (retVal ts.Ten
// will be tracked. // will be tracked.
// The variable uses a float tensor initialized randomly using a // The variable uses a float tensor initialized randomly using a
// uniform distribution which bounds follow Kaiming initialization. // uniform distribution which bounds follow Kaiming initialization.
func (p *Path) KaimingUniform(name string, dims []int64) (retVal ts.Tensor) { func (p *Path) KaimingUniform(name string, dims []int64) (retVal *ts.Tensor) {
return p.NewVar(name, dims, NewKaimingUniformInit()) return p.NewVar(name, dims, NewKaimingUniformInit())
} }
@ -541,7 +541,7 @@ func (p *Path) KaimingUniform(name string, dims []int64) (retVal ts.Tensor) {
// will be tracked. // will be tracked.
// The variable uses a float tensor initialized by copying some // The variable uses a float tensor initialized by copying some
// given tensor. // given tensor.
func (p *Path) VarCopy(name string, t ts.Tensor) (retVal ts.Tensor) { func (p *Path) VarCopy(name string, t *ts.Tensor) (retVal *ts.Tensor) {
size, err := t.Size() size, err := t.Size()
if err != nil { if err != nil {
@ -557,7 +557,7 @@ func (p *Path) VarCopy(name string, t ts.Tensor) (retVal ts.Tensor) {
} }
// Get gets the tensor corresponding to a given name if present. // Get gets the tensor corresponding to a given name if present.
func (p *Path) Get(name string) (retVal ts.Tensor, err error) { func (p *Path) Get(name string) (retVal *ts.Tensor, err error) {
p.varstore.Vars.mutex.Lock() p.varstore.Vars.mutex.Lock()
defer p.varstore.Vars.mutex.Unlock() defer p.varstore.Vars.mutex.Unlock()
@ -572,11 +572,11 @@ func (p *Path) Get(name string) (retVal ts.Tensor, err error) {
} }
// Entry gets the entry corresponding to a given name for in-place manipulation. // Entry gets the entry corresponding to a given name for in-place manipulation.
func (p *Path) Entry(name string) (retVal Entry) { func (p *Path) Entry(name string) *Entry {
p.varstore.Vars.mutex.Lock() p.varstore.Vars.mutex.Lock()
defer p.varstore.Vars.mutex.Unlock() defer p.varstore.Vars.mutex.Unlock()
return Entry{ return &Entry{
name: name, name: name,
variables: &p.varstore.Vars, variables: &p.varstore.Vars,
path: p, path: p,
@ -592,14 +592,14 @@ func (p *Path) Entry(name string) (retVal Entry) {
// var store, the corresponding tensor is returned. Otherwise a new // var store, the corresponding tensor is returned. Otherwise a new
// variable is added to the var-store with the entry name and is // variable is added to the var-store with the entry name and is
// initialized according to the init parameter. // initialized according to the init parameter.
func (e *Entry) OrVar(dims []int64, init Init) (retVal ts.Tensor) { func (e *Entry) OrVar(dims []int64, init Init) (retVal *ts.Tensor) {
v := init.InitTensor(dims, e.path.varstore.device) v := init.InitTensor(dims, e.path.varstore.device)
return e.path.getOrAddWithLock(e.name, v, true, *e.variables) return e.path.getOrAddWithLock(e.name, v, true, *e.variables)
} }
// Returns the existing entry if, otherwise create a new variable. // Returns the existing entry if, otherwise create a new variable.
func (e *Entry) OrVarCopy(tensor ts.Tensor) (retVal ts.Tensor) { func (e *Entry) OrVarCopy(tensor *ts.Tensor) (retVal *ts.Tensor) {
size, err := tensor.Size() size, err := tensor.Size()
if err != nil { if err != nil {
@ -615,50 +615,50 @@ func (e *Entry) OrVarCopy(tensor ts.Tensor) (retVal ts.Tensor) {
} }
// Returns the existing entry if, otherwise create a new variable. // Returns the existing entry if, otherwise create a new variable.
func (e *Entry) OrKaimingUniform(dims []int64) (retVal ts.Tensor) { func (e *Entry) OrKaimingUniform(dims []int64) (retVal *ts.Tensor) {
return e.OrVar(dims, NewKaimingUniformInit()) return e.OrVar(dims, NewKaimingUniformInit())
} }
// OrOnes returns the existing entry if, otherwise create a new variable. // OrOnes returns the existing entry if, otherwise create a new variable.
func (e *Entry) OrOnes(dims []int64) (retVal ts.Tensor) { func (e *Entry) OrOnes(dims []int64) (retVal *ts.Tensor) {
return e.OrVar(dims, NewConstInit(1.0)) return e.OrVar(dims, NewConstInit(1.0))
} }
// OrOnesNoTrain returns the existing entry if, otherwise create a new variable. // OrOnesNoTrain returns the existing entry if, otherwise create a new variable.
func (e *Entry) OrOnesNoTrain(dims []int64) (retVal ts.Tensor) { func (e *Entry) OrOnesNoTrain(dims []int64) (retVal *ts.Tensor) {
o := ts.MustOnes(dims, gotch.Float, e.path.Device()) o := ts.MustOnes(dims, gotch.Float, e.path.Device())
return e.path.getOrAddWithLock(e.name, o, true, *e.variables) return e.path.getOrAddWithLock(e.name, o, true, *e.variables)
} }
// OrRandn returns the existing entry if, otherwise create a new variable. // OrRandn returns the existing entry if, otherwise create a new variable.
func (e *Entry) OrRandn(dims []int64, mean, stdev float64) (retVal ts.Tensor) { func (e *Entry) OrRandn(dims []int64, mean, stdev float64) (retVal *ts.Tensor) {
return e.OrVar(dims, NewRandnInit(mean, stdev)) return e.OrVar(dims, NewRandnInit(mean, stdev))
} }
// OrRandnStandard returns the existing entry if, otherwise create a new variable. // OrRandnStandard returns the existing entry if, otherwise create a new variable.
func (e *Entry) OrRandnStandard(dims []int64) (retVal ts.Tensor) { func (e *Entry) OrRandnStandard(dims []int64) (retVal *ts.Tensor) {
return e.OrVar(dims, NewRandnInit(0.0, 1.0)) return e.OrVar(dims, NewRandnInit(0.0, 1.0))
} }
// OrUniform returns the existing entry if, otherwise create a new variable. // OrUniform returns the existing entry if, otherwise create a new variable.
func (e *Entry) OrUniform(dims []int64, lo, up float64) (retVal ts.Tensor) { func (e *Entry) OrUniform(dims []int64, lo, up float64) (retVal *ts.Tensor) {
return e.OrVar(dims, NewUniformInit(lo, up)) return e.OrVar(dims, NewUniformInit(lo, up))
} }
// OrZeros returns the existing entry if, otherwise create a new variable. // OrZeros returns the existing entry if, otherwise create a new variable.
func (e *Entry) OrZeros(dims []int64) (retVal ts.Tensor) { func (e *Entry) OrZeros(dims []int64) (retVal *ts.Tensor) {
return e.OrVar(dims, NewConstInit(0.0)) return e.OrVar(dims, NewConstInit(0.0))
} }
// OrZerosNoTrain returns the existing entry if, otherwise create a new variable. // OrZerosNoTrain returns the existing entry if, otherwise create a new variable.
func (e *Entry) OrZerosNoTrain(dims []int64) (retVal ts.Tensor) { func (e *Entry) OrZerosNoTrain(dims []int64) (retVal *ts.Tensor) {
z := ts.MustZeros(dims, gotch.Float, e.path.Device()) z := ts.MustZeros(dims, gotch.Float, e.path.Device())
return e.path.getOrAddWithLock(e.name, z, true, *e.variables) return e.path.getOrAddWithLock(e.name, z, true, *e.variables)

View File

@ -46,7 +46,7 @@ func TestSaveLoad(t *testing.T) {
panic(err) panic(err)
} }
add := func(vs nn.Path) (ts.Tensor, ts.Tensor) { add := func(vs *nn.Path) (*ts.Tensor, *ts.Tensor) {
subA := vs.Sub("a") subA := vs.Sub("a")
subB := subA.Sub("b") subB := subA.Sub("b")
v := subB.Ones("t2", []int64{3}) v := subB.Ones("t2", []int64{3})

View File

@ -16,8 +16,8 @@ import (
// containing a (potentially random) slice of each of the two input // containing a (potentially random) slice of each of the two input
// tensors. // tensors.
type Iter2 struct { type Iter2 struct {
xs Tensor xs *Tensor
ys Tensor ys *Tensor
batchIndex int64 batchIndex int64
batchSize int64 batchSize int64
totalSize int64 totalSize int64
@ -38,12 +38,16 @@ type Iter2 struct {
// * `xs` - the features to be used by the model. // * `xs` - the features to be used by the model.
// * `ys` - the targets that the model attempts to predict. // * `ys` - the targets that the model attempts to predict.
// * `batch_size` - the size of batches to be returned. // * `batch_size` - the size of batches to be returned.
func NewIter2(xs, ys Tensor, batchSize int64) (retVal Iter2, err error) { func NewIter2(xs, ys *Tensor, batchSize int64) (*Iter2, error) {
var (
iter *Iter2
err error
)
totalSize := xs.MustSize()[0] totalSize := xs.MustSize()[0]
if ys.MustSize()[0] != totalSize { if ys.MustSize()[0] != totalSize {
err = fmt.Errorf("Different dimension for the two inputs: %v - %v", xs.MustSize(), ys.MustSize()) err = fmt.Errorf("Different dimension for the two inputs: %v - %v", xs.MustSize(), ys.MustSize())
return retVal, err return nil, err
} }
// xsClone, err := xs.ZerosLike(false) // xsClone, err := xs.ZerosLike(false)
@ -58,7 +62,7 @@ func NewIter2(xs, ys Tensor, batchSize int64) (retVal Iter2, err error) {
// } // }
// ysClone.Copy_(ys) // ysClone.Copy_(ys)
retVal = Iter2{ iter = &Iter2{
xs: xs.MustShallowClone(), xs: xs.MustShallowClone(),
ys: ys.MustShallowClone(), ys: ys.MustShallowClone(),
// xs: xsClone, // xs: xsClone,
@ -69,7 +73,7 @@ func NewIter2(xs, ys Tensor, batchSize int64) (retVal Iter2, err error) {
returnSmallLastBatch: false, returnSmallLastBatch: false,
} }
return retVal, nil return iter, nil
} }
// MustNewIter2 returns a new iterator. // MustNewIter2 returns a new iterator.
@ -84,14 +88,14 @@ func NewIter2(xs, ys Tensor, batchSize int64) (retVal Iter2, err error) {
// * `xs` - the features to be used by the model. // * `xs` - the features to be used by the model.
// * `ys` - the targets that the model attempts to predict. // * `ys` - the targets that the model attempts to predict.
// * `batch_size` - the size of batches to be returned. // * `batch_size` - the size of batches to be returned.
func MustNewIter2(xs, ys Tensor, batchSize int64) (retVal Iter2) { func MustNewIter2(xs, ys *Tensor, batchSize int64) *Iter2 {
retVal, err := NewIter2(xs, ys, batchSize) iter, err := NewIter2(xs, ys, batchSize)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
return retVal return iter
} }
// Shuffle shuffles the dataset. // Shuffle shuffles the dataset.
@ -108,20 +112,20 @@ func (it *Iter2) Shuffle() {
} }
// ToDevice transfers the mini-batches to a specified device. // ToDevice transfers the mini-batches to a specified device.
func (it Iter2) ToDevice(device gotch.Device) (retVal Iter2) { func (it *Iter2) ToDevice(device gotch.Device) *Iter2 {
it.device = device it.device = device
return it return it
} }
// ReturnSmallLastBatch when set, returns the last batch even if smaller than the batch size. // ReturnSmallLastBatch when set, returns the last batch even if smaller than the batch size.
func (it Iter2) ReturnSmallLastBatch() (retVal Iter2) { func (it *Iter2) ReturnSmallLastBatch() *Iter2 {
it.returnSmallLastBatch = true it.returnSmallLastBatch = true
return it return it
} }
type Iter2Item struct { type Iter2Item struct {
Data Tensor Data *Tensor
Label Tensor Label *Tensor
} }
// Next implements iterator for Iter2 // Next implements iterator for Iter2
@ -148,7 +152,7 @@ func (it *Iter2) Next() (item Iter2Item, ok bool) {
} }
} }
func (it Iter2) Drop() { func (it *Iter2) Drop() {
it.xs.MustDrop() it.xs.MustDrop()
it.ys.MustDrop() it.ys.MustDrop()
} }
@ -156,17 +160,17 @@ func (it Iter2) Drop() {
// TextData represent text data in tensor of runes (uint8) // TextData represent text data in tensor of runes (uint8)
// and its corresponding string // and its corresponding string
type TextData struct { type TextData struct {
Data Tensor // frequency (occurence) of byte value from input text Data *Tensor // frequency (occurence) of byte value from input text
CharForLabel []rune // unique rune values from input text CharForLabel []rune // unique rune values from input text
} }
// TextDataIter is a text data interator // TextDataIter is a text data interator
type TextDataIter struct { type TextDataIter struct {
Data Tensor Data *Tensor
SeqLen int64 SeqLen int64
BatchIndex int64 BatchIndex int64
BatchSize int64 BatchSize int64
Indexes Tensor Indexes *Tensor
IndexesLen int64 IndexesLen int64
} }
@ -179,17 +183,17 @@ type TextDataIter struct {
// will labelled with new label(index) // will labelled with new label(index)
// Data: tensor of labels // Data: tensor of labels
// CharForLabel: []rune (unique runes from text input) // CharForLabel: []rune (unique runes from text input)
func NewTextData(filename string) (retVal TextData, err error) { func NewTextData(filename string) (*TextData, error) {
filePath, err := filepath.Abs(filename) filePath, err := filepath.Abs(filename)
if err != nil { if err != nil {
return retVal, err return nil, err
} }
r, err := os.Open(filePath) r, err := os.Open(filePath)
buffer, err := ioutil.ReadAll(r) buffer, err := ioutil.ReadAll(r)
if err != nil { if err != nil {
return retVal, err return nil, err
} }
var labelForChar map[byte]uint8 = make(map[byte]uint8, 0) var labelForChar map[byte]uint8 = make(map[byte]uint8, 0)
@ -216,35 +220,35 @@ func NewTextData(filename string) (retVal TextData, err error) {
data := MustOfSlice(dataIndexes) data := MustOfSlice(dataIndexes)
return TextData{ return &TextData{
Data: data, Data: data,
CharForLabel: charForLabel, CharForLabel: charForLabel,
}, nil }, nil
} }
// Labels returns the number of different `character` (rune) used by the dataset. // Labels returns the number of different `character` (rune) used by the dataset.
func (td TextData) Labels() (retVal int64) { func (td *TextData) Labels() (retVal int64) {
return int64(len(td.CharForLabel)) return int64(len(td.CharForLabel))
} }
// Data returns a shallow copy of the data. // Data returns a shallow copy of the data.
func (td TextData) CloneData() (retVal Tensor) { func (td *TextData) CloneData() *Tensor {
return td.Data.MustShallowClone() return td.Data.MustShallowClone()
} }
// LabelForChar returns a corresponding `char` (rune) for // LabelForChar returns a corresponding `char` (rune) for
// specified label input // specified label input
func (td TextData) LabelForChar(label int64) (retVal rune) { func (td *TextData) LabelForChar(label int64) rune {
return td.CharForLabel[int(label)] return td.CharForLabel[int(label)]
} }
// IterShuffle returns a batch iterator over the dataset. // IterShuffle returns a batch iterator over the dataset.
// Each sample is made of seq_len characters. // Each sample is made of seq_len characters.
func (td TextData) IterShuffle(seqLen int64, batchSize int64) (retVal TextDataIter) { func (td *TextData) IterShuffle(seqLen int64, batchSize int64) *TextDataIter {
indexesLen := td.Data.MustSize()[0] - seqLen + 1 indexesLen := td.Data.MustSize()[0] - seqLen + 1
return TextDataIter{ return &TextDataIter{
Data: td.Data.MustShallowClone(), Data: td.Data.MustShallowClone(),
SeqLen: seqLen, SeqLen: seqLen,
BatchIndex: 0, BatchIndex: 0,
@ -255,12 +259,12 @@ func (td TextData) IterShuffle(seqLen int64, batchSize int64) (retVal TextDataIt
} }
// Next implements iterator for TextDataIter // Next implements iterator for TextDataIter
func (tdi *TextDataIter) Next() (retVal Tensor, ok bool) { func (tdi *TextDataIter) Next() (*Tensor, bool) {
start := tdi.BatchIndex * tdi.BatchSize start := tdi.BatchIndex * tdi.BatchSize
size := min(tdi.BatchSize, tdi.IndexesLen-start) size := min(tdi.BatchSize, tdi.IndexesLen-start)
if size < tdi.BatchSize { if size < tdi.BatchSize {
return retVal, false return nil, false
} }
tdi.BatchIndex += 1 tdi.BatchIndex += 1
@ -276,10 +280,10 @@ func (tdi *TextDataIter) Next() (retVal Tensor, ok bool) {
for _, idx := range indexes { for _, idx := range indexes {
narrowIdx := NewNarrow(idx, idx+tdi.SeqLen) narrowIdx := NewNarrow(idx, idx+tdi.SeqLen)
idxTs := tdi.Data.Idx(narrowIdx) idxTs := tdi.Data.Idx(narrowIdx)
batch = append(batch, idxTs) batch = append(batch, *idxTs)
} }
retVal = MustStack(batch, 0) retVal := MustStack(batch, 0)
// Delete intermediate tensors // Delete intermediate tensors
for _, xs := range batch { for _, xs := range batch {
@ -289,7 +293,7 @@ func (tdi *TextDataIter) Next() (retVal Tensor, ok bool) {
return retVal, true return retVal, true
} }
func min(v1, v2 int64) (retVal int64) { func min(v1, v2 int64) int64 {
if v1 < v2 { if v1 < v2 {
return v1 return v1
} }

View File

@ -9,22 +9,20 @@ import (
) )
// LoadHwc returns a tensor of shape [height, width, channels] on success. // LoadHwc returns a tensor of shape [height, width, channels] on success.
func LoadHwc(path string) (retVal Tensor, err error) { func LoadHwc(path string) (*Tensor, error) {
ctensor := lib.AtLoadImage(path) ctensor := lib.AtLoadImage(path)
err = TorchErr() err := TorchErr()
if err != nil { if err != nil {
return retVal, err return nil, err
} }
retVal = Tensor{ctensor} return &Tensor{ctensor}, nil
return retVal, nil
} }
// SaveHwc save an image from tensor. It expects a tensor of shape [height, // SaveHwc save an image from tensor. It expects a tensor of shape [height,
// width, channels] // width, channels]
func SaveHwc(ts Tensor, path string) (err error) { func SaveHwc(ts *Tensor, path string) error {
lib.AtSaveImage(ts.ctensor, path) lib.AtSaveImage(ts.ctensor, path)
return TorchErr() return TorchErr()
@ -32,14 +30,13 @@ func SaveHwc(ts Tensor, path string) (err error) {
// ResizeHwc expects a tensor of shape [height, width, channels]. // ResizeHwc expects a tensor of shape [height, width, channels].
// On success returns a tensor of shape [height, width, channels]. // On success returns a tensor of shape [height, width, channels].
func ResizeHwc(ts Tensor, outWidth, outHeight int64) (retVal Tensor, err error) { func ResizeHwc(ts *Tensor, outWidth, outHeight int64) (*Tensor, error) {
ctensor := lib.AtResizeImage(ts.ctensor, outWidth, outHeight) ctensor := lib.AtResizeImage(ts.ctensor, outWidth, outHeight)
err = TorchErr() err := TorchErr()
if err != nil { if err != nil {
return retVal, err return nil, err
} }
retVal = Tensor{ctensor}
return retVal, nil return &Tensor{ctensor}, nil
} }

View File

@ -79,7 +79,7 @@ type Narrow struct {
Start int64 Start int64
End int64 End int64
} }
type IndexSelect struct{ Index Tensor } type IndexSelect struct{ Index *Tensor }
type InsertNewAxis struct{} type InsertNewAxis struct{}
// NewSelect creates an tensor indexer with given index. // NewSelect creates an tensor indexer with given index.
@ -93,7 +93,7 @@ func NewNarrow(start, end int64) Narrow {
return Narrow{Start: start, End: end} return Narrow{Start: start, End: end}
} }
func NewIndexSelect(ts Tensor) IndexSelect { func NewIndexSelect(ts *Tensor) IndexSelect {
return IndexSelect{Index: ts} return IndexSelect{Index: ts}
} }
@ -130,7 +130,7 @@ type IndexOp interface {
// //
// NOTE: // NOTE:
// - `index`: expects type `TensorIndexer` or `[]TensorIndexer` // - `index`: expects type `TensorIndexer` or `[]TensorIndexer`
func (ts *Tensor) Idx(index interface{}) (retVal Tensor) { func (ts *Tensor) Idx(index interface{}) (retVal *Tensor) {
// indexTyp := reflect.TypeOf(index) // indexTyp := reflect.TypeOf(index)
indexVal := reflect.ValueOf(index) indexVal := reflect.ValueOf(index)
@ -196,7 +196,7 @@ func (ts *Tensor) Idx(index interface{}) (retVal Tensor) {
// Tensor Methods: // Tensor Methods:
// =============== // ===============
func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) { func (ts *Tensor) indexer(indexSpec []TensorIndexer) (retVal *Tensor, err error) {
// Make sure number of non-newaxis is not exceed number of dimensions // Make sure number of non-newaxis is not exceed number of dimensions
var numNewAxis int = 0 var numNewAxis int = 0
@ -221,7 +221,7 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
// If `spec` is `IndexSelect` type and // If `spec` is `IndexSelect` type and
if reflect.TypeOf(spec).Name() == "IndexSelect" { if reflect.TypeOf(spec).Name() == "IndexSelect" {
if reflect.ValueOf(spec).Kind() == reflect.Struct { if reflect.ValueOf(spec).Kind() == reflect.Struct {
inputTensor := reflect.ValueOf(spec).FieldByName("Index").Interface().(Tensor) inputTensor := reflect.ValueOf(spec).FieldByName("Index").Interface().(*Tensor)
// 1. Either its input tensor has dimension > 1, throw error. // 1. Either its input tensor has dimension > 1, throw error.
inputTensorShape, err := inputTensor.Size() inputTensorShape, err := inputTensor.Size()
@ -249,9 +249,9 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
// Now, apply indexing from left to right. // Now, apply indexing from left to right.
var ( var (
currTensor Tensor = ts.MustShallowClone() currTensor *Tensor = ts.MustShallowClone()
currIdx int64 = 0 currIdx int64 = 0
nextTensor Tensor nextTensor *Tensor
nextIdx int64 nextIdx int64
) )
@ -282,8 +282,8 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
return retVal, err return retVal, err
} }
nextIdx = currIdx + 1 nextIdx = currIdx + 1
case "IndexSelect": // 1 field `(Index Tensor)` case "IndexSelect": // 1 field `(Index *Tensor)`
indexTensor := reflect.ValueOf(spec).FieldByName("Index").Interface().(Tensor) indexTensor := reflect.ValueOf(spec).FieldByName("Index").Interface().(*Tensor)
device, err := currTensor.Device() device, err := currTensor.Device()
if err != nil { if err != nil {
return retVal, err return retVal, err
@ -307,7 +307,7 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
return retVal, nil return retVal, nil
} }
func (ts Tensor) mustIndexer(indexSpec []TensorIndexer) (retVal Tensor) { func (ts *Tensor) mustIndexer(indexSpec []TensorIndexer) (retVal *Tensor) {
retVal, err := ts.indexer(indexSpec) retVal, err := ts.indexer(indexSpec)
if err != nil { if err != nil {
panic(err) panic(err)

View File

@ -14,27 +14,27 @@ type Iterator interface {
type Iterable struct { type Iterable struct {
Index int64 Index int64
Len int64 Len int64
Content Tensor Content *Tensor
ItemKind gotch.DType ItemKind gotch.DType
} }
// Next implements Iterator interface // Next implements Iterator interface
func (it *Iterable) Next() (retVal interface{}, ok bool) { func (it *Iterable) Next() (item interface{}, ok bool) {
if it.Index == it.Len { if it.Index == it.Len {
return retVal, false return nil, false
} }
var err error var err error
switch it.ItemKind.Kind().String() { switch it.ItemKind.Kind().String() {
case "int64": case "int64":
retVal, err = it.Content.Int64Value([]int64{it.Index}) item, err = it.Content.Int64Value([]int64{it.Index})
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
it.Index += 1 it.Index += 1
case "float64": case "float64":
retVal, err = it.Content.Float64Value([]int64{it.Index}) item, err = it.Content.Float64Value([]int64{it.Index})
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -44,22 +44,22 @@ func (it *Iterable) Next() (retVal interface{}, ok bool) {
log.Fatal(err) log.Fatal(err)
} }
return retVal, true return item, true
} }
// Iter creates an iterable object with specified item type. // Iter creates an iterable object with specified item type.
func (ts Tensor) Iter(dtype gotch.DType) (retVal Iterable, err error) { func (ts *Tensor) Iter(dtype gotch.DType) (*Iterable, error) {
num, err := ts.Size1() // size for 1D tensor num, err := ts.Size1() // size for 1D tensor
if err != nil { if err != nil {
return retVal, err return nil, err
} }
tmp, err := ts.ShallowClone() tmp, err := ts.ShallowClone()
if err != nil { if err != nil {
return retVal, err return nil, err
} }
content := tmp.MustTotype(dtype, true) content := tmp.MustTotype(dtype, true)
return Iterable{ return &Iterable{
Index: 0, Index: 0,
Len: num, Len: num,
Content: content, Content: content,

View File

@ -950,7 +950,7 @@ func ModuleLoadDataOnDevice(stream io.Reader, device gotch.Device) (retVal CModu
} }
// Performs the forward pass for a model on some specified tensor inputs. // Performs the forward pass for a model on some specified tensor inputs.
func (cm CModule) ForwardTs(tensors []Tensor) (retVal Tensor, err error) { func (cm CModule) ForwardTs(tensors []Tensor) (retVal *Tensor, err error) {
var ctensors []lib.Ctensor var ctensors []lib.Ctensor
for _, t := range tensors { for _, t := range tensors {
ctensors = append(ctensors, t.ctensor) ctensors = append(ctensors, t.ctensor)
@ -994,7 +994,7 @@ func (cm CModule) ForwardTs(tensors []Tensor) (retVal Tensor, err error) {
return retVal, err return retVal, err
} }
return Tensor{ctensor}, nil return &Tensor{ctensor}, nil
} }
// Performs the forward pass for a model on some specified ivalue input. // Performs the forward pass for a model on some specified ivalue input.
@ -1066,9 +1066,9 @@ func (cm CModule) To(device gotch.Device, kind gotch.DType, nonBlocking bool) {
// Implement Module for CModule: // Implement Module for CModule:
// ============================= // =============================
func (cm CModule) Forward(tensor Tensor) (retVal Tensor, err error) { func (cm CModule) Forward(tensor *Tensor) (retVal *Tensor, err error) {
var tensors []Tensor = []Tensor{tensor} var tensors []Tensor = []Tensor{*tensor}
return cm.ForwardTs(tensors) return cm.ForwardTs(tensors)
} }
@ -1076,7 +1076,7 @@ func (cm CModule) Forward(tensor Tensor) (retVal Tensor, err error) {
// ====================================== // ======================================
// Apply forwards tensor itself through a module. // Apply forwards tensor itself through a module.
func (ts Tensor) ApplyCModule(m CModule) (retVal Tensor) { func (ts *Tensor) ApplyCModule(m CModule) (retVal *Tensor) {
retVal, err := m.Forward(ts) retVal, err := m.Forward(ts)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@ -59,7 +59,7 @@ func TestModuleForwardTs(t *testing.T) {
ts1 := ts.TensorFrom([]int64{42}) ts1 := ts.TensorFrom([]int64{42})
ts2 := ts.TensorFrom([]int64{1337}) ts2 := ts.TensorFrom([]int64{1337})
res, err := foo.ForwardTs([]ts.Tensor{ts1, ts2}) res, err := foo.ForwardTs([]ts.Tensor{*ts1, *ts2})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -83,8 +83,8 @@ func TestModuleForwardIValue(t *testing.T) {
ts1 := ts.TensorFrom([]int64{42}) ts1 := ts.TensorFrom([]int64{42})
ts2 := ts.TensorFrom([]int64{1337}) ts2 := ts.TensorFrom([]int64{1337})
iv1 := ts.NewIValue(ts1) iv1 := ts.NewIValue(*ts1)
iv2 := ts.NewIValue(ts2) iv2 := ts.NewIValue(*ts2)
got, err := foo.ForwardIs([]ts.IValue{iv1, iv2}) got, err := foo.ForwardIs([]ts.IValue{iv1, iv2})
if err != nil { if err != nil {
@ -93,7 +93,7 @@ func TestModuleForwardIValue(t *testing.T) {
expectedTs1 := ts.TensorFrom([]int64{1421}) expectedTs1 := ts.TensorFrom([]int64{1421})
expectedTs2 := ts.TensorFrom([]int64{-1295}) expectedTs2 := ts.TensorFrom([]int64{-1295})
want := ts.NewIValue([]ts.Tensor{expectedTs1, expectedTs2}) want := ts.NewIValue([]ts.Tensor{*expectedTs1, *expectedTs2})
if !reflect.DeepEqual(want.Name(), got.Name()) { if !reflect.DeepEqual(want.Name(), got.Name()) {
t.Errorf("Expected Ivalue Name: %v\n", want.Name()) t.Errorf("Expected Ivalue Name: %v\n", want.Name())

View File

@ -9,7 +9,7 @@ package tensor
// be registered, and will have their parameters converted too when you call .cuda(), etc. // be registered, and will have their parameters converted too when you call .cuda(), etc.
type Module interface { type Module interface {
// ModuleT // ModuleT
Forward(xs Tensor) Tensor Forward(xs *Tensor) *Tensor
} }
// ModuleT is a `Module` with an additional train parameter // ModuleT is a `Module` with an additional train parameter
@ -17,7 +17,7 @@ type Module interface {
// between training and evaluation. E.g. When using dropout or batch-normalization. // between training and evaluation. E.g. When using dropout or batch-normalization.
type ModuleT interface { type ModuleT interface {
// Forward(xs Tensor) Tensor // Forward(xs Tensor) Tensor
ForwardT(xs Tensor, train bool) Tensor ForwardT(xs *Tensor, train bool) *Tensor
} }
/* /*
@ -99,18 +99,18 @@ type ModuleT interface {
// ====================================== // ======================================
// Apply forwards tensor itself through a module. // Apply forwards tensor itself through a module.
func (ts Tensor) Apply(m Module) (retVal Tensor) { func (ts *Tensor) Apply(m Module) (retVal *Tensor) {
return m.Forward(ts) return m.Forward(ts)
} }
// Apply forwards tensor itself through a module T. // Apply forwards tensor itself through a module T.
func (ts Tensor) ApplyT(m ModuleT, train bool) (retVal Tensor) { func (ts *Tensor) ApplyT(m ModuleT, train bool) (retVal *Tensor) {
return m.ForwardT(ts, train) return m.ForwardT(ts, train)
} }
// ApplyOpt forwards a tensor itself through a module if given, shallow-copies // ApplyOpt forwards a tensor itself through a module if given, shallow-copies
// the tensor otherwise. // the tensor otherwise.
func (ts Tensor) ApplyOpt(opts ...ModuleOption) (retVal Tensor) { func (ts *Tensor) ApplyOpt(opts ...ModuleOption) (retVal *Tensor) {
switch { switch {
case len(opts) > 0: case len(opts) > 0:
@ -131,7 +131,7 @@ func WithModule(m Module) ModuleOption {
// ApplyOptT forwards a tensor itself through a module T if given, shallow-copies // ApplyOptT forwards a tensor itself through a module T if given, shallow-copies
// the tensor otherwise. // the tensor otherwise.
func (ts Tensor) ApplyOptT(train bool, opts ...ModuleTOption) (retVal Tensor) { func (ts *Tensor) ApplyOptT(train bool, opts ...ModuleTOption) (retVal *Tensor) {
switch { switch {
case len(opts) > 0: case len(opts) > 0:

File diff suppressed because it is too large Load Diff

View File

@ -11,20 +11,18 @@ type COptimizer struct {
} }
// Adam returns Adam optimizer // Adam returns Adam optimizer
func Adam(lr, beta1, beta2, weightDecay float64) (retVal COptimizer, err error) { func Adam(lr, beta1, beta2, weightDecay float64) (*COptimizer, error) {
coptimizer := lib.AtoAdam(lr, beta1, beta2, weightDecay) coptimizer := lib.AtoAdam(lr, beta1, beta2, weightDecay)
err = TorchErr() if err := TorchErr(); err != nil {
if err != nil { return nil, err
return retVal, err
} }
retVal = COptimizer{coptimizer} return &COptimizer{coptimizer}, nil
return retVal, nil
} }
// RmsProp returns RMSProp optimizer // RmsProp returns RMSProp optimizer
func RmsProp(lr, alpha, eps, wd, momentum float64, centered bool) (retVal COptimizer, err error) { func RmsProp(lr, alpha, eps, wd, momentum float64, centered bool) (*COptimizer, error) {
var centeredCInt int var centeredCInt int
switch centered { switch centered {
case true: case true:
@ -34,19 +32,15 @@ func RmsProp(lr, alpha, eps, wd, momentum float64, centered bool) (retVal COptim
} }
coptimizer := lib.AtoRmsProp(lr, alpha, eps, wd, momentum, centeredCInt) coptimizer := lib.AtoRmsProp(lr, alpha, eps, wd, momentum, centeredCInt)
err = TorchErr() if err := TorchErr(); err != nil {
if err != nil { return nil, err
return retVal, err
} }
retVal = COptimizer{coptimizer} return &COptimizer{coptimizer}, nil
return retVal, nil
} }
// Sgd returns SGD optimizer // Sgd returns SGD optimizer
func Sgd(lr, momentum, dampening, wd float64, nesterov bool) (retVal COptimizer, err error) { func Sgd(lr, momentum, dampening, wd float64, nesterov bool) (*COptimizer, error) {
var nesterovCInt int var nesterovCInt int
switch nesterov { switch nesterov {
case true: case true:
@ -56,18 +50,15 @@ func Sgd(lr, momentum, dampening, wd float64, nesterov bool) (retVal COptimizer,
} }
coptimizer := lib.AtoSgd(lr, momentum, dampening, wd, nesterovCInt) coptimizer := lib.AtoSgd(lr, momentum, dampening, wd, nesterovCInt)
err = TorchErr() if err := TorchErr(); err != nil {
if err != nil { return nil, err
return retVal, err
} }
retVal = COptimizer{coptimizer} return &COptimizer{coptimizer}, nil
return retVal, nil
} }
// AddParameters adds parameters as a slice of tensors to optimizer // AddParameters adds parameters as a slice of tensors to optimizer
func (co COptimizer) AddParameters(tensors []Tensor) (err error) { func (co *COptimizer) AddParameters(tensors []Tensor) error {
var ctensors []lib.Ctensor var ctensors []lib.Ctensor
for _, t := range tensors { for _, t := range tensors {
@ -82,35 +73,35 @@ func (co COptimizer) AddParameters(tensors []Tensor) (err error) {
} }
// SetLeanringRate sets learning rate for the optimizer // SetLeanringRate sets learning rate for the optimizer
func (co COptimizer) SetLearningRate(lr float64) (err error) { func (co *COptimizer) SetLearningRate(lr float64) error {
lib.AtoSetLearningRate(co.coptimizer, lr) lib.AtoSetLearningRate(co.coptimizer, lr)
return TorchErr() return TorchErr()
} }
// SetMomentum sets a momentum for the optimizer // SetMomentum sets a momentum for the optimizer
func (co COptimizer) SetMomentum(m float64) (err error) { func (co *COptimizer) SetMomentum(m float64) error {
lib.AtoSetMomentum(co.coptimizer, m) lib.AtoSetMomentum(co.coptimizer, m)
return TorchErr() return TorchErr()
} }
// ZeroGrad sets gradients to zero // ZeroGrad sets gradients to zero
func (co COptimizer) ZeroGrad() (err error) { func (co *COptimizer) ZeroGrad() error {
lib.AtoZeroGrad(co.coptimizer) lib.AtoZeroGrad(co.coptimizer)
return TorchErr() return TorchErr()
} }
// Steps proceeds optimizer // Steps proceeds optimizer
func (co COptimizer) Step() (err error) { func (co *COptimizer) Step() error {
lib.AtoStep(co.coptimizer) lib.AtoStep(co.coptimizer)
return TorchErr() return TorchErr()
} }
// Drop removes optimizer and frees up memory. // Drop removes optimizer and frees up memory.
func (co COptimizer) Drop() { func (co *COptimizer) Drop() {
lib.AtoFree(co.coptimizer) lib.AtoFree(co.coptimizer)
if err := TorchErr(); err != nil { if err := TorchErr(); err != nil {

View File

@ -7,7 +7,7 @@ import (
) )
// CrossEntropyForLogits computes the cross-entropy loss based on some logits and targets. // CrossEntropyForLogits computes the cross-entropy loss based on some logits and targets.
func (ts Tensor) CrossEntropyForLogits(targets Tensor) (retVal Tensor) { func (ts *Tensor) CrossEntropyForLogits(targets *Tensor) (retVal *Tensor) {
weight := NewTensor() weight := NewTensor()
reduction := int64(1) // Mean of loss reduction := int64(1) // Mean of loss
ignoreIndex := int64(-100) ignoreIndex := int64(-100)
@ -18,13 +18,13 @@ func (ts Tensor) CrossEntropyForLogits(targets Tensor) (retVal Tensor) {
// AccuracyForLogits returns the average accuracy for some given logits assuming that // AccuracyForLogits returns the average accuracy for some given logits assuming that
// targets represent ground-truth. // targets represent ground-truth.
func (ts Tensor) AccuracyForLogits(targets Tensor) (retVal Tensor) { func (ts *Tensor) AccuracyForLogits(targets *Tensor) (retVal *Tensor) {
argmax := ts.MustArgmax(-1, false, true) argmax := ts.MustArgmax(-1, false, true)
eq1 := argmax.MustEq1(targets, true) eq1 := argmax.MustEq1(targets, true)
return eq1.MustTotype(gotch.Float, true).MustMean(gotch.Float, true) return eq1.MustTotype(gotch.Float, true).MustMean(gotch.Float, true)
} }
func (ts Tensor) MaxPool2DDefault(ksize int64, del bool) (retVal Tensor) { func (ts *Tensor) MaxPool2DDefault(ksize int64, del bool) (retVal *Tensor) {
return ts.MustMaxPool2d([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, []int64{1, 1}, false, del) return ts.MustMaxPool2d([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, []int64{1, 1}, false, del)
} }

View File

@ -13,7 +13,7 @@ import (
// NOTE. This is a temporarily patched to make it run. // NOTE. This is a temporarily patched to make it run.
// TODO. make change at generator for []Tensor input // TODO. make change at generator for []Tensor input
func (ts Tensor) Lstm(hxData []Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c Tensor, err error) { func (ts *Tensor) Lstm(hxData []Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c *Tensor, err error) {
// NOTE: `atg_lstm` will create 3 consecutive Ctensors in memory of C land. The first // NOTE: `atg_lstm` will create 3 consecutive Ctensors in memory of C land. The first
// Ctensor will have address given by `ctensorPtr1` here. // Ctensor will have address given by `ctensorPtr1` here.
@ -55,11 +55,11 @@ func (ts Tensor) Lstm(hxData []Tensor, paramsData []Tensor, hasBiases bool, numL
return output, h, c, err return output, h, c, err
} }
return Tensor{ctensor: *ctensorPtr1}, Tensor{ctensor: *ctensorPtr2}, Tensor{ctensor: *ctensorPtr3}, nil return &Tensor{ctensor: *ctensorPtr1}, &Tensor{ctensor: *ctensorPtr2}, &Tensor{ctensor: *ctensorPtr3}, nil
} }
func (ts Tensor) MustLstm(hxData []Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c Tensor) { func (ts *Tensor) MustLstm(hxData []Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c *Tensor) {
output, h, c, err := ts.Lstm(hxData, paramsData, hasBiases, numLayers, dropout, train, bidirectional, batchFirst) output, h, c, err := ts.Lstm(hxData, paramsData, hasBiases, numLayers, dropout, train, bidirectional, batchFirst)
if err != nil { if err != nil {
@ -69,7 +69,7 @@ func (ts Tensor) MustLstm(hxData []Tensor, paramsData []Tensor, hasBiases bool,
return output, h, c return output, h, c
} }
func (ts Tensor) Gru(hx Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h Tensor, err error) { func (ts *Tensor) Gru(hx *Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h *Tensor, err error) {
// NOTE: `atg_gru` will create 2 consecutive Ctensors in memory of C land. // NOTE: `atg_gru` will create 2 consecutive Ctensors in memory of C land.
// The first Ctensor will have address given by `ctensorPtr1` here. // The first Ctensor will have address given by `ctensorPtr1` here.
@ -105,11 +105,11 @@ func (ts Tensor) Gru(hx Tensor, paramsData []Tensor, hasBiases bool, numLayers i
return output, h, err return output, h, err
} }
return Tensor{ctensor: *ctensorPtr1}, Tensor{ctensor: *ctensorPtr2}, nil return &Tensor{ctensor: *ctensorPtr1}, &Tensor{ctensor: *ctensorPtr2}, nil
} }
func (ts Tensor) MustGru(hx Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h Tensor) { func (ts *Tensor) MustGru(hx *Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h *Tensor) {
output, h, err := ts.Gru(hx, paramsData, hasBiases, numLayers, dropout, train, bidirectional, batchFirst) output, h, err := ts.Gru(hx, paramsData, hasBiases, numLayers, dropout, train, bidirectional, batchFirst)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
@ -118,7 +118,7 @@ func (ts Tensor) MustGru(hx Tensor, paramsData []Tensor, hasBiases bool, numLaye
return output, h return output, h
} }
func (ts Tensor) TopK(k int64, dim int64, largest bool, sorted bool) (ts1 Tensor, ts2 Tensor, err error) { func (ts *Tensor) TopK(k int64, dim int64, largest bool, sorted bool) (ts1, ts2 *Tensor, err error) {
// NOTE: `lib.AtgTopk` will return 2 tensors in C memory. First tensor pointer // NOTE: `lib.AtgTopk` will return 2 tensors in C memory. First tensor pointer
// is given by ctensorPtr1 // is given by ctensorPtr1
@ -139,10 +139,10 @@ func (ts Tensor) TopK(k int64, dim int64, largest bool, sorted bool) (ts1 Tensor
return ts1, ts2, err return ts1, ts2, err
} }
return Tensor{ctensor: *ctensorPtr1}, Tensor{ctensor: *ctensorPtr2}, nil return &Tensor{ctensor: *ctensorPtr1}, &Tensor{ctensor: *ctensorPtr2}, nil
} }
func (ts Tensor) MustTopK(k int64, dim int64, largest bool, sorted bool) (ts1 Tensor, ts2 Tensor) { func (ts *Tensor) MustTopK(k int64, dim int64, largest bool, sorted bool) (ts1, ts2 *Tensor) {
ts1, ts2, err := ts.TopK(k, dim, largest, sorted) ts1, ts2, err := ts.TopK(k, dim, largest, sorted)
if err != nil { if err != nil {
@ -154,7 +154,7 @@ func (ts Tensor) MustTopK(k int64, dim int64, largest bool, sorted bool) (ts1 Te
// NOTE. `NLLLoss` is a version of `NllLoss` in tensor-generated // NOTE. `NLLLoss` is a version of `NllLoss` in tensor-generated
// with default weight, reduction and ignoreIndex // with default weight, reduction and ignoreIndex
func (ts Tensor) NLLLoss(target Tensor, del bool) (retVal Tensor, err error) { func (ts *Tensor) NLLLoss(target Tensor, del bool) (retVal *Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
if del { if del {
defer ts.MustDrop() defer ts.MustDrop()
@ -169,12 +169,12 @@ func (ts Tensor) NLLLoss(target Tensor, del bool) (retVal Tensor, err error) {
return retVal, err return retVal, err
} }
retVal = Tensor{ctensor: *ptr} retVal = &Tensor{ctensor: *ptr}
return retVal, nil return retVal, nil
} }
func (ts Tensor) MustNLLLoss(target Tensor, del bool) (retVal Tensor) { func (ts *Tensor) MustNLLLoss(target Tensor, del bool) (retVal *Tensor) {
retVal, err := ts.NLLLoss(target, del) retVal, err := ts.NLLLoss(target, del)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
@ -285,7 +285,7 @@ func MustBroadcastTensors(tensors []Tensor, del bool) (retVal []Tensor) {
} }
// tensor *atg_chunk(tensor self, int64_t chunks, int64_t dim); // tensor *atg_chunk(tensor self, int64_t chunks, int64_t dim);
func (ts Tensor) Chunk(chunks int64, dim int64) (retVal []Tensor, err error) { func (ts *Tensor) Chunk(chunks int64, dim int64) (retVal []Tensor, err error) {
ctensorsPtr := lib.AtgChunk(ts.ctensor, chunks, dim) ctensorsPtr := lib.AtgChunk(ts.ctensor, chunks, dim)
if err = TorchErr(); err != nil { if err = TorchErr(); err != nil {
return retVal, err return retVal, err
@ -307,7 +307,7 @@ func (ts Tensor) Chunk(chunks int64, dim int64) (retVal []Tensor, err error) {
return retVal, nil return retVal, nil
} }
func (ts Tensor) MustChunk(chunks int64, dim int64, del bool) (retVal []Tensor) { func (ts *Tensor) MustChunk(chunks int64, dim int64, del bool) (retVal []Tensor) {
if del { if del {
defer ts.MustDrop() defer ts.MustDrop()
} }
@ -321,7 +321,7 @@ func (ts Tensor) MustChunk(chunks int64, dim int64, del bool) (retVal []Tensor)
} }
// tensor *atg_meshgrid(tensor *tensors_data, int tensors_len); // tensor *atg_meshgrid(tensor *tensors_data, int tensors_len);
func (ts Tensor) Meshgrid(tensors []Tensor) (retVal []Tensor, err error) { func (ts *Tensor) Meshgrid(tensors []Tensor) (retVal []Tensor, err error) {
var ctensors []lib.Ctensor var ctensors []lib.Ctensor
for _, t := range tensors { for _, t := range tensors {
@ -348,7 +348,7 @@ func (ts Tensor) Meshgrid(tensors []Tensor) (retVal []Tensor, err error) {
return retVal, nil return retVal, nil
} }
func (ts Tensor) MustMeshgrid(tensors []Tensor, del bool) (retVal []Tensor) { func (ts *Tensor) MustMeshgrid(tensors []Tensor, del bool) (retVal []Tensor) {
if del { if del {
defer ts.MustDrop() defer ts.MustDrop()
} }
@ -362,7 +362,7 @@ func (ts Tensor) MustMeshgrid(tensors []Tensor, del bool) (retVal []Tensor) {
} }
// tensor *atg_nonzero_numpy(tensor self); // tensor *atg_nonzero_numpy(tensor self);
func (ts Tensor) NonzeroNumpy() (retVal []Tensor, err error) { func (ts *Tensor) NonzeroNumpy() (retVal []Tensor, err error) {
ctensorsPtr := lib.AtgNonzeroNumpy(ts.ctensor) ctensorsPtr := lib.AtgNonzeroNumpy(ts.ctensor)
if err = TorchErr(); err != nil { if err = TorchErr(); err != nil {
@ -384,7 +384,7 @@ func (ts Tensor) NonzeroNumpy() (retVal []Tensor, err error) {
return retVal, nil return retVal, nil
} }
func (ts Tensor) MustNonzeroNumpy(del bool) (retVal []Tensor) { func (ts *Tensor) MustNonzeroNumpy(del bool) (retVal []Tensor) {
if del { if del {
defer ts.MustDrop() defer ts.MustDrop()
} }
@ -403,7 +403,7 @@ func (ts Tensor) MustNonzeroNumpy(del bool) (retVal []Tensor) {
// - splitSize size of a single chunk // - splitSize size of a single chunk
// - dim dimension along which to split the tensor. // - dim dimension along which to split the tensor.
// Ref. https://pytorch.org/docs/stable/generated/torch.split.html // Ref. https://pytorch.org/docs/stable/generated/torch.split.html
func (ts Tensor) Split(splitSize, dim int64) (retVal []Tensor, err error) { func (ts *Tensor) Split(splitSize, dim int64) (retVal []Tensor, err error) {
ctensorsPtr := lib.AtgSplit(ts.ctensor, splitSize, dim) ctensorsPtr := lib.AtgSplit(ts.ctensor, splitSize, dim)
if err = TorchErr(); err != nil { if err = TorchErr(); err != nil {
@ -430,7 +430,7 @@ func (ts Tensor) Split(splitSize, dim int64) (retVal []Tensor, err error) {
return retVal, nil return retVal, nil
} }
func (ts Tensor) MustSplit(splitSize, dim int64, del bool) (retVal []Tensor) { func (ts *Tensor) MustSplit(splitSize, dim int64, del bool) (retVal []Tensor) {
if del { if del {
defer ts.MustDrop() defer ts.MustDrop()
} }
@ -449,7 +449,7 @@ func (ts Tensor) MustSplit(splitSize, dim int64, del bool) (retVal []Tensor) {
// - splitSizes slice of sizes for each chunk // - splitSizes slice of sizes for each chunk
// - dim dimension along which to split the tensor. // - dim dimension along which to split the tensor.
// Ref. https://pytorch.org/docs/stable/generated/torch.split.html // Ref. https://pytorch.org/docs/stable/generated/torch.split.html
func (ts Tensor) SplitWithSizes(splitSizes []int64, dim int64) (retVal []Tensor, err error) { func (ts *Tensor) SplitWithSizes(splitSizes []int64, dim int64) (retVal []Tensor, err error) {
ctensorsPtr := lib.AtgSplitWithSizes(ts.ctensor, splitSizes, len(splitSizes), dim) ctensorsPtr := lib.AtgSplitWithSizes(ts.ctensor, splitSizes, len(splitSizes), dim)
if err = TorchErr(); err != nil { if err = TorchErr(); err != nil {
@ -476,7 +476,7 @@ func (ts Tensor) SplitWithSizes(splitSizes []int64, dim int64) (retVal []Tensor,
return retVal, nil return retVal, nil
} }
func (ts Tensor) MustSplitWithSizes(splitSizes []int64, dim int64, del bool) (retVal []Tensor) { func (ts *Tensor) MustSplitWithSizes(splitSizes []int64, dim int64, del bool) (retVal []Tensor) {
if del { if del {
defer ts.MustDrop() defer ts.MustDrop()
} }
@ -490,7 +490,7 @@ func (ts Tensor) MustSplitWithSizes(splitSizes []int64, dim int64, del bool) (re
} }
// tensor *atg_unbind(tensor self, int64_t dim); // tensor *atg_unbind(tensor self, int64_t dim);
func (ts Tensor) Unbind(dim int64) (retVal []Tensor, err error) { func (ts *Tensor) Unbind(dim int64) (retVal []Tensor, err error) {
ctensorsPtr := lib.AtgUnbind(ts.ctensor, dim) ctensorsPtr := lib.AtgUnbind(ts.ctensor, dim)
if err = TorchErr(); err != nil { if err = TorchErr(); err != nil {
@ -512,7 +512,7 @@ func (ts Tensor) Unbind(dim int64) (retVal []Tensor, err error) {
return retVal, nil return retVal, nil
} }
func (ts Tensor) MustUnbind(dim int64, del bool) (retVal []Tensor) { func (ts *Tensor) MustUnbind(dim int64, del bool) (retVal []Tensor) {
if del { if del {
defer ts.MustDrop() defer ts.MustDrop()
} }

View File

@ -12,19 +12,19 @@ type Scalar struct {
} }
// IntScalar creates a integer scalar // IntScalar creates a integer scalar
func IntScalar(v int64) Scalar { func IntScalar(v int64) *Scalar {
cscalar := lib.AtsInt(v) cscalar := lib.AtsInt(v)
return Scalar{cscalar} return &Scalar{cscalar}
} }
// FloatScalar creates a float scalar // FloatScalar creates a float scalar
func FloatScalar(v float64) Scalar { func FloatScalar(v float64) *Scalar {
cscalar := lib.AtsFloat(v) cscalar := lib.AtsFloat(v)
return Scalar{cscalar} return &Scalar{cscalar}
} }
// ToInt returns a integer value // ToInt returns a integer value
func (sc Scalar) ToInt() (retVal int64, err error) { func (sc *Scalar) ToInt() (retVal int64, err error) {
retVal = lib.AtsToInt(sc.cscalar) retVal = lib.AtsToInt(sc.cscalar)
err = TorchErr() err = TorchErr()
if err != nil { if err != nil {
@ -35,7 +35,7 @@ func (sc Scalar) ToInt() (retVal int64, err error) {
} }
// ToFloat returns a float value // ToFloat returns a float value
func (sc Scalar) ToFloat() (retVal float64, err error) { func (sc *Scalar) ToFloat() (retVal float64, err error) {
retVal = lib.AtsToFloat(sc.cscalar) retVal = lib.AtsToFloat(sc.cscalar)
err = TorchErr() err = TorchErr()
if err != nil { if err != nil {
@ -46,7 +46,7 @@ func (sc Scalar) ToFloat() (retVal float64, err error) {
} }
// ToString returns a string representation of scalar value // ToString returns a string representation of scalar value
func (sc Scalar) ToString() (retVal string, err error) { func (sc *Scalar) ToString() (retVal string, err error) {
retVal = lib.AtsToString(sc.cscalar) retVal = lib.AtsToString(sc.cscalar)
err = TorchErr() err = TorchErr()
if err != nil { if err != nil {
@ -60,12 +60,12 @@ func (sc Scalar) ToString() (retVal string, err error) {
// //
// TODO: Really? after running s.Drop() and s.ToInt() // TODO: Really? after running s.Drop() and s.ToInt()
// it returns Zero. // it returns Zero.
func (sc Scalar) Drop() (err error) { func (sc *Scalar) Drop() (err error) {
lib.AtsFree(sc.cscalar) lib.AtsFree(sc.cscalar)
return TorchErr() return TorchErr()
} }
func (sc Scalar) MustDrop() { func (sc *Scalar) MustDrop() {
lib.AtsFree(sc.cscalar) lib.AtsFree(sc.cscalar)
if err := TorchErr(); err != nil { if err := TorchErr(); err != nil {
log.Fatal(err) log.Fatal(err)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff