commit
12bb50c5fa
|
@ -4,7 +4,7 @@ import (
|
|||
"fmt"
|
||||
"reflect"
|
||||
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// DataLoader combines a dataset and a sampler and provides
|
||||
|
|
|
@ -85,7 +85,7 @@ func TestNewMapDataset(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestMaptDataset_Len(t *testing.T) {
|
||||
func TestMapDataset_Len(t *testing.T) {
|
||||
var data map[string]int = map[string]int{"one": 1, "two": 2}
|
||||
ds, err := dutil.NewMapDataset(data)
|
||||
if err != nil {
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
"github.com/sugarme/gotch/vision/aug"
|
||||
// ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
var device string
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
|
|
|
@ -6,8 +6,7 @@ import (
|
|||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
// "github.com/sugarme/gotch/vision"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
|
|
@ -2,18 +2,17 @@ package main
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
// "log"
|
||||
|
||||
"github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
func main() {
|
||||
x := tensor.TensorFrom([]float64{2.0})
|
||||
x := ts.TensorFrom([]float64{2.0})
|
||||
x = x.MustSetRequiresGrad(true, false)
|
||||
x.ZeroGrad()
|
||||
|
||||
xy := tensor.TensorFrom([]float64{2.0})
|
||||
xz := tensor.TensorFrom([]float64{3.0})
|
||||
xy := ts.TensorFrom([]float64{2.0})
|
||||
xz := ts.TensorFrom([]float64{3.0})
|
||||
|
||||
y := x.MustMul(xy, false)
|
||||
z := x.MustMul(xz, false)
|
||||
|
@ -25,9 +24,9 @@ func main() {
|
|||
xgrad = x.MustGrad(false)
|
||||
xgrad.Print() // [5.0] due to accumulated 2.0 + 3.0
|
||||
|
||||
isGradEnabled := tensor.MustGradSetEnabled(false)
|
||||
isGradEnabled := ts.MustGradSetEnabled(false)
|
||||
fmt.Printf("Previous GradMode enabled state: %v\n", isGradEnabled)
|
||||
isGradEnabled = tensor.MustGradSetEnabled(true)
|
||||
isGradEnabled = ts.MustGradSetEnabled(true)
|
||||
fmt.Printf("Previous GradMode enabled state: %v\n", isGradEnabled)
|
||||
|
||||
}
|
||||
|
|
|
@ -4,45 +4,45 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
ts, err := tensor.OfSlice([]float64{1.3, 29.7})
|
||||
x, err := ts.OfSlice([]float64{1.3, 29.7})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
path := "file.pt"
|
||||
ts.MustSave(path)
|
||||
x.MustSave(path)
|
||||
|
||||
loadedTs := tensor.MustLoad(path)
|
||||
loadedTs := ts.MustLoad(path)
|
||||
|
||||
loadedTs.Print()
|
||||
|
||||
ts1 := tensor.MustOfSlice([]float64{1.3, 29.7})
|
||||
ts2 := tensor.MustOfSlice([]float64{2.1, 31.2})
|
||||
ts1 := ts.MustOfSlice([]float64{1.3, 29.7})
|
||||
ts2 := ts.MustOfSlice([]float64{2.1, 31.2})
|
||||
|
||||
var namedTensors []tensor.NamedTensor = []tensor.NamedTensor{
|
||||
var namedTensors []ts.NamedTensor = []ts.NamedTensor{
|
||||
{Name: "ts1", Tensor: ts1},
|
||||
{Name: "ts2", Tensor: ts2},
|
||||
}
|
||||
|
||||
pathMulti := "file_multi.pt"
|
||||
|
||||
// err = tensor.SaveMulti(namedTensors, pathMulti)
|
||||
// err = ts.SaveMulti(namedTensors, pathMulti)
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
err = tensor.SaveMultiNew(namedTensors, pathMulti)
|
||||
err = ts.SaveMultiNew(namedTensors, pathMulti)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var data []tensor.NamedTensor
|
||||
var data []ts.NamedTensor
|
||||
|
||||
data = tensor.MustLoadMulti(pathMulti)
|
||||
data = ts.MustLoadMulti(pathMulti)
|
||||
|
||||
for _, v := range data {
|
||||
v.Tensor.Print()
|
||||
|
@ -50,18 +50,18 @@ func main() {
|
|||
|
||||
device := gotch.NewCuda()
|
||||
|
||||
data = tensor.MustLoadMultiWithDevice(pathMulti, device)
|
||||
data = ts.MustLoadMultiWithDevice(pathMulti, device)
|
||||
for _, v := range data {
|
||||
v.Tensor.Print()
|
||||
}
|
||||
|
||||
tsString := ts.MustToString(80)
|
||||
tsString := x.MustToString(80)
|
||||
|
||||
fmt.Printf("Tensor String: \n%v\n", tsString)
|
||||
|
||||
imagePath := "mnist-sample.png"
|
||||
|
||||
imageTs, err := tensor.LoadHwc(imagePath)
|
||||
imageTs, err := ts.LoadHwc(imagePath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ import (
|
|||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
type Block struct {
|
||||
|
|
|
@ -3,6 +3,7 @@ package main
|
|||
import (
|
||||
"image"
|
||||
"image/color"
|
||||
|
||||
// "image/jpeg"
|
||||
"io/ioutil"
|
||||
|
||||
|
@ -15,7 +16,7 @@ import (
|
|||
"golang.org/x/image/font"
|
||||
|
||||
"github.com/sugarme/gotch/example/yolo/freetype"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
|
|
|
@ -758,7 +758,7 @@ let write_cpp funcs filename =
|
|||
let write_wrapper funcs filename =
|
||||
Out_channel.with_file filename ~f:(fun out_ml ->
|
||||
let pm s = print_inline out_ml s in
|
||||
pm "package tensor" ;
|
||||
pm "package ts" ;
|
||||
pm "\n\n" ;
|
||||
pm "// NOTE. THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!" ;
|
||||
pm "\n\n" ;
|
||||
|
@ -989,7 +989,7 @@ let write_wrapper funcs filename =
|
|||
let write_must_wrapper funcs filename =
|
||||
Out_channel.with_file filename ~f:(fun out_ml ->
|
||||
let pm s = print_inline out_ml s in
|
||||
pm "package tensor" ;
|
||||
pm "package ts" ;
|
||||
pm "\n\n" ;
|
||||
pm "// NOTE. THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!" ;
|
||||
pm "\n\n" ;
|
||||
|
@ -1347,5 +1347,5 @@ let () =
|
|||
run ~yaml_filename:"gen/pytorch/Declarations-v1.10.0.yaml"
|
||||
~cpp_filename:"libtch/torch_api_generated"
|
||||
~ffi_filename:"libtch/c-generated.go"
|
||||
~must_wrapper_filename:"tensor/must-tensor-generated.go"
|
||||
~wrapper_filename:"tensor/tensor-generated.go"
|
||||
~must_wrapper_filename:"ts/must-tensor-generated.go"
|
||||
~wrapper_filename:"ts/tensor-generated.go"
|
||||
|
|
|
@ -5,7 +5,7 @@ package nn
|
|||
import (
|
||||
"log"
|
||||
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// Batch-normalization config.
|
||||
|
|
|
@ -5,7 +5,7 @@ package nn
|
|||
import (
|
||||
"log"
|
||||
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
type ConvTranspose1DConfig struct {
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"fmt"
|
||||
"reflect"
|
||||
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// Conv1DConfig:
|
||||
|
|
|
@ -3,7 +3,7 @@ package nn
|
|||
// Layers defined by closure
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
type Func struct {
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"math"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
type Init interface {
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"log"
|
||||
"strings"
|
||||
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// TrainableCModule is a trainable version of JIT Pytorch module
|
||||
|
|
|
@ -2,7 +2,7 @@ package nn
|
|||
|
||||
// A layer-normalization layer.
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// Layer-normalization config.
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"math"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// LinearConfig is a configuration for a linear layer
|
||||
|
|
|
@ -2,7 +2,7 @@ package nn
|
|||
|
||||
import (
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
type lossFnOptions struct {
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
"math"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// Optimizer is a struct object to run gradient descent.
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
func TestOptimizer(t *testing.T) {
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package nn
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// Dropout:
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
type State interface{}
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
func gruTest(rnnConfig *nn.RNNConfig, t *testing.T) {
|
||||
|
|
|
@ -4,8 +4,7 @@ package nn
|
|||
|
||||
import (
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
// "reflect"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// Sequential is a layer (container) that combines multiple other layers.
|
||||
|
|
|
@ -3,7 +3,7 @@ package nn
|
|||
// Sparse layers
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// Configuration option for an embedding layer.
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
func embeddingTest(embeddingConfig *nn.EmbeddingConfig, t *testing.T) {
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// SEP is a separator to separate path elements in the tensor names.
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
func TestVarStoreEntry(t *testing.T) {
|
||||
|
|
|
@ -21,7 +21,7 @@ import (
|
|||
"sort"
|
||||
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
const hexMagicNumber = "1950a86a20f9469cfc6c"
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -1,10 +1,10 @@
|
|||
package tensor_test
|
||||
package ts_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
func ExampleTensor_MustArange() {
|
|
@ -1,4 +1,4 @@
|
|||
package tensor
|
||||
package ts
|
||||
|
||||
import (
|
||||
"fmt"
|
|
@ -1,4 +1,4 @@
|
|||
package tensor_test
|
||||
package ts_test
|
||||
|
||||
import (
|
||||
// "fmt"
|
||||
|
@ -9,7 +9,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
func TestTextData_NewTextData(t *testing.T) {
|
|
@ -1,4 +1,4 @@
|
|||
package tensor
|
||||
package ts
|
||||
|
||||
// #include <stdlib.h>
|
||||
import "C"
|
|
@ -1,4 +1,4 @@
|
|||
package tensor
|
||||
package ts
|
||||
|
||||
import "C"
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
package tensor
|
||||
package ts
|
||||
|
||||
// Indexing operations for tensor
|
||||
// It defines a `i` indexing operation. This can be used in various scenarios.
|
|
@ -1,11 +1,11 @@
|
|||
package tensor_test
|
||||
package ts_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
func TestIntegerIndex(t *testing.T) {
|
|
@ -1,4 +1,4 @@
|
|||
package tensor
|
||||
package ts
|
||||
|
||||
import (
|
||||
"fmt"
|
|
@ -1,4 +1,4 @@
|
|||
package tensor
|
||||
package ts
|
||||
|
||||
// JIT interface to run model trained/saved using PyTorch Python API.
|
||||
|
|
@ -1,10 +1,10 @@
|
|||
package tensor_test
|
||||
package ts_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
func roundTrip(v interface{}, t *testing.T) {
|
|
@ -1,4 +1,4 @@
|
|||
package tensor
|
||||
package ts
|
||||
|
||||
// Module interface is a container with only one method `Forward`
|
||||
//
|
18121
ts/must-tensor-generated.go
Normal file
18121
ts/must-tensor-generated.go
Normal file
File diff suppressed because it is too large
Load Diff
|
@ -1,4 +1,4 @@
|
|||
package tensor
|
||||
package ts
|
||||
|
||||
import (
|
||||
"archive/zip"
|
|
@ -1,11 +1,11 @@
|
|||
package tensor_test
|
||||
package ts_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
func TestNpyHeaderParse(t *testing.T) {
|
|
@ -1,4 +1,4 @@
|
|||
package tensor
|
||||
package ts
|
||||
|
||||
import (
|
||||
"log"
|
|
@ -1,4 +1,4 @@
|
|||
package tensor
|
||||
package ts
|
||||
|
||||
// Other tensor methods
|
||||
|
|
@ -1,10 +1,10 @@
|
|||
package tensor_test
|
||||
package ts_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
func ExampleTensor_Split(t *testing.T) {
|
|
@ -1,4 +1,4 @@
|
|||
package tensor
|
||||
package ts
|
||||
|
||||
// #include "stdlib.h"
|
||||
import "C"
|
|
@ -1,4 +1,4 @@
|
|||
package tensor
|
||||
package ts
|
||||
|
||||
import (
|
||||
"bytes"
|
|
@ -1,4 +1,4 @@
|
|||
package tensor
|
||||
package ts
|
||||
|
||||
import (
|
||||
// "unsafe"
|
28745
ts/tensor-generated.go
Normal file
28745
ts/tensor-generated.go
Normal file
File diff suppressed because it is too large
Load Diff
|
@ -1,4 +1,4 @@
|
|||
package tensor
|
||||
package ts
|
||||
|
||||
//#include "stdlib.h"
|
||||
//#include "stdbool.h"
|
|
@ -1,11 +1,11 @@
|
|||
package tensor_test
|
||||
package ts_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
func TestTensorInit(t *testing.T) {
|
|
@ -1,4 +1,4 @@
|
|||
package tensor
|
||||
package ts
|
||||
|
||||
// #include <stdlib.h>
|
||||
import "C"
|
|
@ -2,7 +2,7 @@ package vision
|
|||
|
||||
import (
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// AlexNet implementation
|
||||
|
|
|
@ -2,7 +2,7 @@ package aug
|
|||
|
||||
import (
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// RandomAffine is transformation of the image keeping center invariant.
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
type GaussianBlur struct {
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// Ref. https://github.com/pytorch/vision/blob/f1d734213af65dc06e777877d315973ba8386080/torchvision/transforms/functional_tensor.py
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// RandomAutocontrast autocontrasts the pixels of the given image randomly with a given probability.
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
// "math"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
type RandomCrop struct {
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"math"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// Randomly selects a rectangle region in an torch Tensor image and erases its pixels.
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// RandomEqualize equalizes the histogram of the given image randomly with a given probability.
|
||||
|
|
|
@ -2,7 +2,7 @@ package aug
|
|||
|
||||
import (
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// RandomHorizontalFlip horizontally flips the given image randomly with a given probability.
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
func gaussianKernel1D(ks int64, sigma float64, dtype gotch.DType, device gotch.Device) *ts.Tensor {
|
||||
|
|
|
@ -3,8 +3,7 @@ package aug
|
|||
import (
|
||||
"log"
|
||||
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
// "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// GrayScale converts image to grayscale.
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
type RandomInvert struct {
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// Normalize normalizes a tensor image with mean and standard deviation.
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
// "fmt"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// RandomPerspective performs a random perspective transformation of the given image with a given probability.
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// RandomPosterize posterizes the image randomly with a given probability by reducing the
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// RandomRotate randomly rotates a tensor image within a specifed angle range (degree).
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// Adjust the sharpness of the image randomly with a given probability. If the image is torch Tensor,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// RandomSolarize solarizes the image randomly with a given probability by inverting all pixel
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// Transformer is an interface that can transform an image tensor.
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
"path/filepath"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
"math/rand"
|
||||
"time"
|
||||
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
type Dataset struct {
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
func dnConv2d(p *nn.Path, cIn, cOut, ksize, padding, stride int64) *nn.Conv2D {
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"math"
|
||||
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
"math"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// (height, width, channel) -> (channel, height, width)
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// Helper functions for ImageNet like datasets.
|
||||
|
|
|
@ -4,7 +4,7 @@ package vision
|
|||
|
||||
import (
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
func convBn(p *nn.Path, cIn, cOut, ksize, pad, stride int64) ts.ModuleT {
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
"path/filepath"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// readInt32 read 4 bytes and convert to MSB first (big endian) interger.
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user