commit
4bbff96bda
|
@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
## [Nofix]
|
||||
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.
|
||||
|
||||
## [0.4.0]
|
||||
- **Update libtorch to 1.9**. Generated **1716 APIs**. There are APIs naming changes ie. `Name1` change to `NameDim` or `NameTensor`.
|
||||
|
||||
## [0.3.14]
|
||||
- Fixed temporary fix huge number of learning group returned from C at `libtch/tensor.go AtoGetLearningRates`
|
||||
- Fixed incorrect `nn.AdamWConfig` and some documentation.
|
||||
|
|
10
README.md
10
README.md
|
@ -18,12 +18,12 @@ Gotch is in active development mode and may have API breaking changes. Feel free
|
|||
|
||||
## Dependencies
|
||||
|
||||
- **Libtorch** C++ v1.7.0 library of [Pytorch](https://pytorch.org/)
|
||||
- **Libtorch** C++ v1.9.0 library of [Pytorch](https://pytorch.org/)
|
||||
|
||||
## Installation
|
||||
|
||||
- Default CUDA version is `10.1` if CUDA is available otherwise using CPU version.
|
||||
- Default Pytorch C++ API version is `1.7.0`
|
||||
- Default Pytorch C++ API version is `1.9.0`
|
||||
|
||||
**NOTE**: `libtorch` will be installed at **`/usr/local/lib`**
|
||||
|
||||
|
@ -51,7 +51,7 @@ Gotch is in active development mode and may have API breaking changes. Feel free
|
|||
```bash
|
||||
wget https://raw.githubusercontent.com/sugarme/gotch/master/setup-gotch.sh
|
||||
chmod +x setup-gotch.sh
|
||||
export CUDA_VER=cpu && export GOTCH_VER=v0.3.14 && bash setup-gotch.sh
|
||||
export CUDA_VER=cpu && export GOTCH_VER=v0.4.0 && bash setup-gotch.sh
|
||||
```
|
||||
|
||||
### GPU
|
||||
|
@ -89,9 +89,9 @@ Gotch is in active development mode and may have API breaking changes. Feel free
|
|||
wget https://raw.githubusercontent.com/sugarme/gotch/master/setup-gotch.sh
|
||||
chmod +x setup-gotch.sh
|
||||
# CUDA 10.1
|
||||
export CUDA_VER=10.1 && export GOTCH_VER=v0.3.14 && bash setup-gotch.sh
|
||||
export CUDA_VER=10.1 && export GOTCH_VER=v0.4.0 && bash setup-gotch.sh
|
||||
# CUDA 11.0
|
||||
export CUDA_VER=11.0 && export GOTCH_VER=v0.3.14 && bash setup-gotch.sh
|
||||
export CUDA_VER=11.0 && export GOTCH_VER=v0.4.0 && bash setup-gotch.sh
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
|
|
@ -35,7 +35,7 @@ func TestDataLoader_Next(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
want := 100
|
||||
want := []int{100}
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("Want: %v\n", want)
|
||||
|
|
|
@ -22,9 +22,9 @@ func createTensors(samples int) []ts.Tensor {
|
|||
s := ts.FloatScalar(float64(0.23))
|
||||
|
||||
for i := 0; i < 1; i++ {
|
||||
t := ts.MustOfSlice(data).MustMul1(s, true)
|
||||
t := ts.MustOfSlice(data).MustMulScalar(s, true)
|
||||
|
||||
tensors = append(tensors, t)
|
||||
tensors = append(tensors, *t)
|
||||
}
|
||||
|
||||
return tensors
|
||||
|
@ -72,7 +72,7 @@ func main() {
|
|||
tensors := createTensors(10000)
|
||||
var gpuTensors []ts.Tensor
|
||||
for _, t := range tensors {
|
||||
gpuTensors = append(gpuTensors, t.MustTo(gpu, true))
|
||||
gpuTensors = append(gpuTensors, *t.MustTo(gpu, true))
|
||||
}
|
||||
|
||||
for _, t := range gpuTensors {
|
||||
|
|
|
@ -40,12 +40,12 @@ func runLinear() {
|
|||
loss.MustBackward()
|
||||
|
||||
ts.NoGrad(func() {
|
||||
ws.Add_(ws.MustGrad(false).MustMul1(ts.FloatScalar(-1.0), true))
|
||||
bs.Add_(bs.MustGrad(false).MustMul1(ts.FloatScalar(-1.0), true))
|
||||
ws.Add_(ws.MustGrad(false).MustMulScalar(ts.FloatScalar(-1.0), true))
|
||||
bs.Add_(bs.MustGrad(false).MustMulScalar(ts.FloatScalar(-1.0), true))
|
||||
})
|
||||
|
||||
testLogits := ds.TestImages.MustMm(ws, false).MustAdd(bs, true)
|
||||
testAccuracy := testLogits.MustArgmax([]int64{-1}, false, true).MustEq1(ds.TestLabels, true).MustTotype(gotch.Float, true).MustMean(gotch.Float, true).MustView([]int64{-1}, true).MustFloat64Value([]int64{0})
|
||||
testAccuracy := testLogits.MustArgmax([]int64{-1}, false, true).MustEqTensor(ds.TestLabels, true).MustTotype(gotch.Float, true).MustMean(gotch.Float, true).MustView([]int64{-1}, true).MustFloat64Value([]int64{0})
|
||||
|
||||
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Float64Values()[0], testAccuracy*100)
|
||||
|
||||
|
|
358
gen/gen.ml
358
gen/gen.ml
|
@ -1,8 +1,7 @@
|
|||
(* Automatically generate the C++ -> C -> Go bindings.
|
||||
This takes as input the Descriptions.yaml file that gets generated when
|
||||
func (Func.c_go_args_list func) building PyTorch from source.
|
||||
(* Automatically generated C++ -> C -> Go bindings.
|
||||
Input: Declarations-VERSION.yaml artifact generated when building Pytorch from source.
|
||||
Run with: dune exec gen/gen.exe
|
||||
*)
|
||||
*)
|
||||
open Base
|
||||
open Stdio
|
||||
|
||||
|
@ -26,10 +25,18 @@ let excluded_functions =
|
|||
; "backward"
|
||||
; "set_data"
|
||||
; "_amp_non_finite_check_and_unscale_"
|
||||
; "_amp_foreach_non_finite_check_and_unscale_"
|
||||
; "_cummin_helper"
|
||||
; "_cummax_helper"
|
||||
; "retain_grad"
|
||||
; "_validate_sparse_coo_tensor_args" ]
|
||||
; "_validate_sparse_coo_tensor_args"
|
||||
; "_backward"
|
||||
; "size"
|
||||
; "stride"
|
||||
; "_assert_async"
|
||||
; "gradient"
|
||||
; "linalg_vector_norm"
|
||||
; "linalg_vector_norm_out" ]
|
||||
|
||||
let no_tensor_options =
|
||||
Set.of_list
|
||||
|
@ -85,9 +92,12 @@ module Func = struct
|
|||
| DoubleOption
|
||||
| Tensor
|
||||
| TensorOption
|
||||
(* Tensor.t option *)
|
||||
| IntList
|
||||
| TensorOptList
|
||||
| TensorList
|
||||
| TensorOptions
|
||||
(* Tensor kind and device *)
|
||||
| Scalar
|
||||
| ScalarType
|
||||
| Device
|
||||
|
@ -99,8 +109,10 @@ module Func = struct
|
|||
(* `Func` type *)
|
||||
type t =
|
||||
{ name: string
|
||||
; args: arg list
|
||||
; returns: [`fixed of int | `dynamic]
|
||||
; operator_name: string
|
||||
; overload_name: string
|
||||
; args: arg list (* ; returns: [`fixed of int | `dynamic] *)
|
||||
; returns: [`fixed of int | `dynamic | `bool | `int64_t | `double]
|
||||
; (* number of tensors that are returned *)
|
||||
kind: [`function_ | `method_] }
|
||||
|
||||
|
@ -109,14 +121,14 @@ module Func = struct
|
|||
| "bool" -> Some Bool
|
||||
| "int64_t" -> Some (if is_nullable then Int64Option else Int64)
|
||||
| "double" -> Some (if is_nullable then DoubleOption else Double)
|
||||
| "booltensor" | "indextensor" | "tensor" ->
|
||||
Some (if is_nullable then TensorOption else Tensor)
|
||||
| "tensoroptions" -> Some TensorOptions
|
||||
| "intarrayref" | "intlist" -> Some IntList
|
||||
| "tensorlist" -> Some TensorList
|
||||
| "device" -> Some Device
|
||||
| "scalar" -> Some Scalar
|
||||
| "scalartype" -> Some ScalarType
|
||||
| "at::tensor" -> Some (if is_nullable then TensorOption else Tensor)
|
||||
| "at::tensoroptions" -> Some TensorOptions
|
||||
| "at::intarrayref" -> Some IntList
|
||||
| "const c10::list<c10::optional<at::tensor>> &" -> Some TensorOptList
|
||||
| "at::tensorlist" -> Some TensorList
|
||||
| "at::device" -> Some Device
|
||||
| "const at::scalar &" | "at::scalar" -> Some Scalar
|
||||
| "at::scalartype" -> Some ScalarType
|
||||
| "std::string" -> Some String
|
||||
| _ -> None
|
||||
|
||||
|
@ -125,7 +137,7 @@ module Func = struct
|
|||
match arg_type with
|
||||
| IntList ->
|
||||
Printf.sprintf "int64_t *%s_data, int %s_len" arg_name arg_name
|
||||
| TensorList ->
|
||||
| TensorOptList | TensorList ->
|
||||
Printf.sprintf "tensor *%s_data, int %s_len" arg_name arg_name
|
||||
| TensorOptions ->
|
||||
Printf.sprintf "int %s_kind, int %s_device" arg_name arg_name
|
||||
|
@ -145,8 +157,8 @@ module Func = struct
|
|||
| ScalarType -> "int"
|
||||
| Device -> "int"
|
||||
| Scalar -> "scalar"
|
||||
| Int64Option | DoubleOption | String | IntList | TensorList
|
||||
|TensorOptions ->
|
||||
| Int64Option | DoubleOption | String | IntList | TensorOptList
|
||||
|TensorList | TensorOptions ->
|
||||
assert false
|
||||
in
|
||||
Printf.sprintf "%s %s" simple_type_cstring arg_name )
|
||||
|
@ -164,6 +176,9 @@ module Func = struct
|
|||
arg_name
|
||||
| String ->
|
||||
Printf.sprintf "std::string(%s_ptr, %s_len)" arg_name arg_name
|
||||
| TensorOptList ->
|
||||
Printf.sprintf "of_carray_tensor_opt(%s_data, %s_len)" arg_name
|
||||
arg_name
|
||||
| TensorList ->
|
||||
Printf.sprintf "of_carray_tensor(%s_data, %s_len)" arg_name
|
||||
arg_name
|
||||
|
@ -196,14 +211,15 @@ module Func = struct
|
|||
t.name () )
|
||||
|
||||
(*
|
||||
* let replace_map =
|
||||
* Map.of_alist_exn
|
||||
* (module String)
|
||||
* [ ("t", "tr")
|
||||
* ; ("where", "where_")
|
||||
* ; ("view", "view_")
|
||||
* ; ("unsafe", "unsafe_") ]
|
||||
* *)
|
||||
let replace_map =
|
||||
Map.of_alist_exn
|
||||
(module String)
|
||||
[ ("t", "tr")
|
||||
; ("where", "where_")
|
||||
; ("view", "view_")
|
||||
; ("unsafe", "unsafe_")
|
||||
; ("to_device", "to_device_") ]
|
||||
*)
|
||||
|
||||
let is_method t =
|
||||
List.exists t.args ~f:(fun arg ->
|
||||
|
@ -245,6 +261,7 @@ module Func = struct
|
|||
| Device -> single_param "int32"
|
||||
| String -> single_param "string"
|
||||
| IntList -> Printf.sprintf "%sData []int64, %sLen int" an an
|
||||
| TensorOptList -> Printf.sprintf "%sData []Ctensor, %sLen int" an an
|
||||
| TensorList -> Printf.sprintf "%sData []Ctensor, %sLen int" an an
|
||||
| Int64Option -> Printf.sprintf "%sVal int64, %sNull int" an an
|
||||
| DoubleOption -> Printf.sprintf "%sVal float64, %sNull int" an an
|
||||
|
@ -268,6 +285,7 @@ module Func = struct
|
|||
| Device -> Printf.sprintf "c%s" an
|
||||
| String -> Printf.sprintf "c%s, c%sLen" an an
|
||||
| IntList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
||||
| TensorOptList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
||||
| TensorList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
||||
| Int64Option -> Printf.sprintf "c%sVal, c%sNull" an an
|
||||
| DoubleOption -> Printf.sprintf "c%sVal, c%sNull" an an
|
||||
|
@ -306,6 +324,12 @@ module Func = struct
|
|||
c%sDataPtr := (*C.int64_t)(unsafe.Pointer(&%sData[0]))\n\
|
||||
c%sLen := *(*C.int)(unsafe.Pointer(&%sLen))"
|
||||
an an an an
|
||||
| TensorOptList ->
|
||||
Printf.sprintf
|
||||
"\n\
|
||||
c%sDataPtr := (*Ctensor)(unsafe.Pointer(&%sData[0]))\n\
|
||||
c%sLen := *(*C.int)(unsafe.Pointer(&%sLen))"
|
||||
an an an an
|
||||
| TensorList ->
|
||||
Printf.sprintf
|
||||
"\n\
|
||||
|
@ -382,6 +406,7 @@ module Func = struct
|
|||
| Tensor -> "*Tensor"
|
||||
| TensorOption -> "*Tensor"
|
||||
| IntList -> "[]int64"
|
||||
| TensorOptList -> "[]Tensor"
|
||||
| TensorList -> "[]Tensor"
|
||||
| String -> "string"
|
||||
(* TODO. Struct{Kind gotch.DType Device gotch.Device} *)
|
||||
|
@ -435,6 +460,9 @@ module Func = struct
|
|||
List.init v ~f:(fun i -> Printf.sprintf "retVal%d *Tensor" i)
|
||||
|> String.concat ~sep:", " |> Printf.sprintf "%s"
|
||||
| `dynamic -> "retVal []Tensor"
|
||||
| `bool -> "retVal bool"
|
||||
| `int64_t -> "retVal int64"
|
||||
| `double -> "retVal float64"
|
||||
in
|
||||
if is_inplace t then
|
||||
if fallible then Printf.sprintf "err error" else Printf.sprintf ""
|
||||
|
@ -449,6 +477,9 @@ module Func = struct
|
|||
List.init v ~f:(fun i -> Printf.sprintf "retVal%d" i)
|
||||
|> String.concat ~sep:", " |> Printf.sprintf "%s"
|
||||
| `dynamic -> "retVal"
|
||||
| `bool -> "retVal"
|
||||
| `int64_t -> "retVal"
|
||||
| `double -> "retVal"
|
||||
in
|
||||
if is_inplace t then
|
||||
if fallible then Printf.sprintf "err" else Printf.sprintf ""
|
||||
|
@ -511,6 +542,11 @@ module Func = struct
|
|||
\ c%sNull = 0\n\
|
||||
\ }\n"
|
||||
an an an an an an
|
||||
| TensorOptList ->
|
||||
Printf.sprintf
|
||||
" var c%s []lib.Ctensor\n\
|
||||
\ for _, t := range %s {c%s = append(c%s, t.ctensor)}\n"
|
||||
an an an an
|
||||
| TensorList ->
|
||||
Printf.sprintf
|
||||
" var c%s []lib.Ctensor\n\
|
||||
|
@ -536,6 +572,8 @@ let read_yaml filename =
|
|||
List.filter_map funcs ~f:(fun yaml ->
|
||||
let map = extract_map yaml in
|
||||
let name = Map.find_exn map "name" |> extract_string in
|
||||
let operator_name = Map.find_exn map "operator_name" |> extract_string in
|
||||
let overload_name = Map.find_exn map "overload_name" |> extract_string in
|
||||
let deprecated = Map.find_exn map "deprecated" |> extract_bool in
|
||||
let method_of =
|
||||
Map.find_exn map "method_of"
|
||||
|
@ -548,22 +586,26 @@ let read_yaml filename =
|
|||
let return_type =
|
||||
Map.find_exn returns "dynamic_type" |> extract_string
|
||||
in
|
||||
String.( = ) return_type "Tensor"
|
||||
|| String.( = ) return_type "BoolTensor"
|
||||
|| String.( = ) return_type "IndexTensor"
|
||||
String.( = ) return_type "at::Tensor"
|
||||
in
|
||||
let returns = Map.find_exn map "returns" |> extract_list in
|
||||
if List.for_all returns ~f:is_tensor then
|
||||
Some (`fixed (List.length returns))
|
||||
else
|
||||
match returns with
|
||||
| [returns] ->
|
||||
| [returns] -> (
|
||||
let return_type =
|
||||
Map.find_exn (extract_map returns) "dynamic_type"
|
||||
|> extract_string
|
||||
in
|
||||
if String.( = ) return_type "TensorList" then Some `dynamic
|
||||
else None
|
||||
match return_type with
|
||||
| "bool" -> Some `bool
|
||||
| "int64_t" -> Some `int64_t
|
||||
| "double" -> Some `double
|
||||
| "at::TensorList"
|
||||
|"dynamic_type: const c10::List<c10::optional<Tensor>> &" ->
|
||||
Some `dynamic
|
||||
| _ -> None )
|
||||
| [] | _ :: _ :: _ -> None
|
||||
in
|
||||
let kind =
|
||||
|
@ -622,7 +664,13 @@ let read_yaml filename =
|
|||
if Option.is_some default_value then None
|
||||
else raise Not_a_simple_arg )
|
||||
in
|
||||
Some {Func.name; args; returns; kind}
|
||||
Some
|
||||
{ Func.name
|
||||
; operator_name
|
||||
; overload_name
|
||||
; args
|
||||
; returns
|
||||
; kind }
|
||||
with Not_a_simple_arg -> None )
|
||||
else None )
|
||||
|
||||
|
@ -684,8 +732,23 @@ let write_cpp funcs filename =
|
|||
pc " return nullptr;" ;
|
||||
pc "}" ;
|
||||
pc "" ;
|
||||
ph "tensor *atg_%s(%s);" exported_name c_typed_args_list ) )
|
||||
)
|
||||
ph "tensor *atg_%s(%s);" exported_name c_typed_args_list
|
||||
| (`bool | `int64_t | `double) as returns ->
|
||||
let c_type =
|
||||
match returns with
|
||||
| `bool -> "int"
|
||||
| `int64_t -> "int64_t"
|
||||
| `double -> "double"
|
||||
in
|
||||
pc "%s atg_%s(%s) {" c_type exported_name c_typed_args_list ;
|
||||
pc " PROTECT(" ;
|
||||
pc " return %s;" (Func.c_call func) ;
|
||||
pc " )" ;
|
||||
pc " return 0;" ;
|
||||
pc "}" ;
|
||||
pc "" ;
|
||||
ph "%s atg_%s(%s);" c_type exported_name c_typed_args_list )
|
||||
) )
|
||||
|
||||
let write_wrapper funcs filename =
|
||||
Out_channel.with_file filename ~f:(fun out_ml ->
|
||||
|
@ -751,7 +814,26 @@ let write_wrapper funcs filename =
|
|||
; "UnsafeChunk"
|
||||
; "UnsafeSplit"
|
||||
; "UnsafeSplitWithSizes"
|
||||
; "AlignTensors" ]
|
||||
; "AlignTensors"
|
||||
; "UnflattenDenseTensors"
|
||||
; "TensorSplit"
|
||||
; "TensorSplitIndices"
|
||||
; "TensorSplitTensorIndicesOrSections"
|
||||
; "QuantizePerTensorTensors"
|
||||
; "Dsplit"
|
||||
; "DsplitArray"
|
||||
; "Hsplit"
|
||||
; "HsplitArray"
|
||||
; "Vsplit"
|
||||
; "VsplitArray"
|
||||
; "DequantizeTensors"
|
||||
; "Atleast1dSequence"
|
||||
; "Atleast2dSequence"
|
||||
; "Atleast3dSequence"
|
||||
; "Index"
|
||||
; "IndexPut"
|
||||
; "IndexPut_"
|
||||
; "_IndexPutImpl_" ]
|
||||
in
|
||||
if
|
||||
List.exists excluded_funcs ~f:(fun name ->
|
||||
|
@ -777,7 +859,8 @@ let write_wrapper funcs filename =
|
|||
pm " }\n" ;
|
||||
(* NOTE. if in_place method, no retVal return *)
|
||||
if not (Func.is_inplace func) then
|
||||
pm " retVal = &Tensor{ctensor: *ptr}\n" ;
|
||||
pm " retVal = &Tensor{ctensor: *ptr}\n"
|
||||
else pm " ts.ctensor = *ptr\n" ;
|
||||
pm " \n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
|
||||
pm "} \n"
|
||||
|
@ -799,10 +882,62 @@ let write_wrapper funcs filename =
|
|||
pm " }\n" ;
|
||||
(* NOTE. if in_place method, no retVal return *)
|
||||
if not (Func.is_inplace func) then
|
||||
pm " retVal = &Tensor{ctensor: *ptr}\n" ;
|
||||
pm " retVal = &Tensor{ctensor: *ptr}\n"
|
||||
else pm " ts.ctensor = *ptr\n" ;
|
||||
pm " \n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
|
||||
pm "} \n"
|
||||
| `bool ->
|
||||
pm "\n" ;
|
||||
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
|
||||
else pm "func %s(" gofunc_name ;
|
||||
pm "%s" go_args_list ;
|
||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
||||
if is_method && not is_inplace then
|
||||
pm "if del { defer ts.MustDrop() }\n" ;
|
||||
pm " \n" ;
|
||||
pm " %s" (Func.go_binding_body func) ;
|
||||
pm "retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||
pm " if err = TorchErr(); err != nil {\n" ;
|
||||
pm " return %s\n"
|
||||
(Func.go_return_notype func ~fallible:true) ;
|
||||
pm " }\n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
|
||||
pm "} \n"
|
||||
| `int64_t ->
|
||||
pm "\n" ;
|
||||
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
|
||||
else pm "func %s(" gofunc_name ;
|
||||
pm "%s" go_args_list ;
|
||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
||||
if is_method && not is_inplace then
|
||||
pm "if del { defer ts.MustDrop() }\n" ;
|
||||
pm " \n" ;
|
||||
pm " %s" (Func.go_binding_body func) ;
|
||||
pm "retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||
pm " if err = TorchErr(); err != nil {\n" ;
|
||||
pm " return %s\n"
|
||||
(Func.go_return_notype func ~fallible:true) ;
|
||||
pm " }\n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
|
||||
pm "} \n"
|
||||
| `double ->
|
||||
pm "\n" ;
|
||||
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
|
||||
else pm "func %s(" gofunc_name ;
|
||||
pm "%s" go_args_list ;
|
||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
||||
if is_method && not is_inplace then
|
||||
pm "if del { defer ts.MustDrop() }\n" ;
|
||||
pm " \n" ;
|
||||
pm " %s" (Func.go_binding_body func) ;
|
||||
pm "retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||
pm " if err = TorchErr(); err != nil {\n" ;
|
||||
pm " return %s\n"
|
||||
(Func.go_return_notype func ~fallible:true) ;
|
||||
pm " }\n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
|
||||
pm "} \n"
|
||||
| `fixed _ -> pm "" ) ;
|
||||
(* TODO. implement for return multiple tensor - []Tensor *)
|
||||
pm "// End of implementing Tensor ================================= \n"
|
||||
|
@ -866,7 +1001,26 @@ let write_must_wrapper funcs filename =
|
|||
; "UnsafeChunk"
|
||||
; "UnsafeSplit"
|
||||
; "UnsafeSplitWithSizes"
|
||||
; "AlignTensors" ]
|
||||
; "AlignTensors"
|
||||
; "UnflattenDenseTensors"
|
||||
; "TensorSplit"
|
||||
; "TensorSplitIndices"
|
||||
; "TensorSplitTensorIndicesOrSections"
|
||||
; "QuantizePerTensorTensors"
|
||||
; "Dsplit"
|
||||
; "DsplitArray"
|
||||
; "Hsplit"
|
||||
; "HsplitArray"
|
||||
; "Vsplit"
|
||||
; "VsplitArray"
|
||||
; "DequantizeTensors"
|
||||
; "Atleast1dSequence"
|
||||
; "Atleast2dSequence"
|
||||
; "Atleast3dSequence"
|
||||
; "Index"
|
||||
; "IndexPut"
|
||||
; "IndexPut_"
|
||||
; "_IndexPutImpl_" ]
|
||||
in
|
||||
if
|
||||
List.exists excluded_funcs ~f:(fun name ->
|
||||
|
@ -876,7 +1030,7 @@ let write_must_wrapper funcs filename =
|
|||
match func.returns with
|
||||
| `dynamic ->
|
||||
pm "\n" ;
|
||||
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
|
||||
if is_method then pm "func(ts *Tensor) Must%s(" gofunc_name
|
||||
else pm "func Must%s(" gofunc_name ;
|
||||
pm "%s" go_args_list ;
|
||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:false) ;
|
||||
|
@ -913,6 +1067,57 @@ let write_must_wrapper funcs filename =
|
|||
pm " \n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
|
||||
pm "} \n"
|
||||
| `bool ->
|
||||
pm "\n" ;
|
||||
if is_method then pm "func(ts *Tensor) Must%s(" gofunc_name
|
||||
else pm "func Must%s(" gofunc_name ;
|
||||
pm "%s" go_args_list ;
|
||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:false) ;
|
||||
pm " \n" ;
|
||||
if is_method then
|
||||
pm " retVal, err := ts.%s(%s)\n" gofunc_name
|
||||
go_args_list_notype
|
||||
else
|
||||
pm " retVal, err := %s(%s)\n" gofunc_name
|
||||
go_args_list_notype ;
|
||||
pm " if err != nil { log.Fatal(err) }\n" ;
|
||||
pm " \n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
|
||||
pm "} \n"
|
||||
| `int64_t ->
|
||||
pm "\n" ;
|
||||
if is_method then pm "func(ts *Tensor) Must%s(" gofunc_name
|
||||
else pm "func Must%s(" gofunc_name ;
|
||||
pm "%s" go_args_list ;
|
||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:false) ;
|
||||
pm " \n" ;
|
||||
if is_method then
|
||||
pm " retVal, err := ts.%s(%s)\n" gofunc_name
|
||||
go_args_list_notype
|
||||
else
|
||||
pm " retVal, err := %s(%s)\n" gofunc_name
|
||||
go_args_list_notype ;
|
||||
pm " if err != nil { log.Fatal(err) }\n" ;
|
||||
pm " \n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
|
||||
pm "} \n"
|
||||
| `double ->
|
||||
pm "\n" ;
|
||||
if is_method then pm "func(ts *Tensor) Must%s(" gofunc_name
|
||||
else pm "func Must%s(" gofunc_name ;
|
||||
pm "%s" go_args_list ;
|
||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:false) ;
|
||||
pm " \n" ;
|
||||
if is_method then
|
||||
pm " retVal, err := ts.%s(%s)\n" gofunc_name
|
||||
go_args_list_notype
|
||||
else
|
||||
pm " retVal, err := %s(%s)\n" gofunc_name
|
||||
go_args_list_notype ;
|
||||
pm " if err != nil { log.Fatal(err) }\n" ;
|
||||
pm " \n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
|
||||
pm "} \n"
|
||||
| `fixed _ -> pm "" ) ;
|
||||
(* TODO. implement for return multiple tensor - []Tensor *)
|
||||
pm "// End of implementing Tensor ================================= \n"
|
||||
|
@ -969,19 +1174,52 @@ let write_ffi funcs filename =
|
|||
in
|
||||
match func.Func.returns with
|
||||
| `fixed _ ->
|
||||
pm "func Atg%s(ptr *Ctensor, %s){%s \nC.atg_%s(ptr, %s)\n}"
|
||||
pm "func Atg%s(ptr *Ctensor, %s){%s \n\tC.atg_%s(ptr, %s)\n}"
|
||||
ffifunc_name (Func.c_go_args_list func)
|
||||
(Func.c_go_args_list_body func)
|
||||
exported_name
|
||||
(Func.c_go_args_list_notype func)
|
||||
| `dynamic -> pm ""
|
||||
| `bool ->
|
||||
pm "func Atg%s(%s) bool{%s" ffifunc_name
|
||||
(Func.c_go_args_list func)
|
||||
(Func.c_go_args_list_body func) ;
|
||||
pm "\t cResult := C.atg_%s(%s)" exported_name
|
||||
(Func.c_go_args_list_notype func) ;
|
||||
pm "\t cbool := *(*int)(unsafe.Pointer(&cResult))" ;
|
||||
pm "\t if cbool == 1{return true}" ;
|
||||
pm "\t return false" ;
|
||||
pm "}"
|
||||
| `int64_t ->
|
||||
pm "func Atg%s(%s) int64{%s" ffifunc_name
|
||||
(Func.c_go_args_list func)
|
||||
(Func.c_go_args_list_body func) ;
|
||||
pm "\t cResult := C.atg_%s(%s)" exported_name
|
||||
(Func.c_go_args_list_notype func) ;
|
||||
pm "\t return *(*int64)(unsafe.Pointer(&cResult))" ;
|
||||
pm "}"
|
||||
| `double ->
|
||||
pm "func Atg%s(%s) float64{%s" ffifunc_name
|
||||
(Func.c_go_args_list func)
|
||||
(Func.c_go_args_list_body func) ;
|
||||
pm "\t cResult := C.atg_%s(%s)" exported_name
|
||||
(Func.c_go_args_list_notype func) ;
|
||||
pm "\t return *(*float64)(unsafe.Pointer(&cResult))" ;
|
||||
pm "}"
|
||||
(* TODO: need more implement here *)
|
||||
(* pm "func Atg%s(%s)(retValPtr *Ctensor)" *)
|
||||
(* (Func.go_name exported_name) *)
|
||||
(* (Func.c_go_args_list func) *) ) )
|
||||
|
||||
let methods =
|
||||
let c name args = {Func.name; args; returns= `fixed 1; kind= `method_} in
|
||||
let c name args =
|
||||
{ Func.name
|
||||
; operator_name= name
|
||||
; overload_name= ""
|
||||
; args
|
||||
; returns= `fixed 1
|
||||
; kind= `method_ }
|
||||
in
|
||||
let ca arg_name arg_type = {Func.arg_name; arg_type; default_value= None} in
|
||||
[ c "grad" [ca "self" Tensor]
|
||||
; c "set_requires_grad" [ca "self" Tensor; ca "r" Bool]
|
||||
|
@ -995,7 +1233,7 @@ let run ~yaml_filename ~cpp_filename ~ffi_filename ~must_wrapper_filename
|
|||
printf "Generating code for %d functions.\n%!" (List.length funcs) ;
|
||||
(* Generate some unique names for overloaded functions. *)
|
||||
let funcs =
|
||||
List.map funcs ~f:(fun func -> (String.lowercase func.name, func))
|
||||
List.map funcs ~f:(fun func -> (String.lowercase func.operator_name, func))
|
||||
|> Map.of_alist_multi (module String)
|
||||
|> Map.to_alist
|
||||
|> List.concat_map ~f:(fun (name, funcs) ->
|
||||
|
@ -1003,11 +1241,35 @@ let run ~yaml_filename ~cpp_filename ~ffi_filename ~must_wrapper_filename
|
|||
| [] -> assert false
|
||||
| [func] -> [(name, func)]
|
||||
| funcs ->
|
||||
let has_empty_overload =
|
||||
List.exists funcs ~f:(fun (func : Func.t) ->
|
||||
String.is_empty func.overload_name )
|
||||
in
|
||||
List.sort funcs ~compare:(fun (f1 : Func.t) (f2 : Func.t) ->
|
||||
Int.compare (List.length f1.args) (List.length f2.args) )
|
||||
|> List.mapi ~f:(fun i func ->
|
||||
( (if i = 0 then name else Printf.sprintf "%s%d" name i)
|
||||
, func ) ) )
|
||||
match
|
||||
Int.compare (String.length f1.name)
|
||||
(String.length f2.name)
|
||||
with
|
||||
| 0 ->
|
||||
Int.compare (List.length f1.args) (List.length f2.args)
|
||||
| cmp -> cmp )
|
||||
|> List.mapi ~f:(fun index (func : Func.t) ->
|
||||
let operator_name =
|
||||
String.lowercase func.operator_name
|
||||
in
|
||||
let overload_name =
|
||||
String.lowercase func.overload_name
|
||||
in
|
||||
let name =
|
||||
if
|
||||
String.is_empty overload_name
|
||||
|| (index = 0 && not has_empty_overload)
|
||||
then operator_name
|
||||
else if String.is_suffix operator_name ~suffix:"_" then
|
||||
operator_name ^ overload_name ^ "_"
|
||||
else operator_name ^ "_" ^ overload_name
|
||||
in
|
||||
(name, func) ) )
|
||||
|> Map.of_alist_exn (module String)
|
||||
in
|
||||
write_cpp funcs cpp_filename ;
|
||||
|
@ -1016,7 +1278,7 @@ let run ~yaml_filename ~cpp_filename ~ffi_filename ~must_wrapper_filename
|
|||
write_wrapper funcs wrapper_filename
|
||||
|
||||
let () =
|
||||
run ~yaml_filename:"gen/pytorch/Declarations-v1.7.0.yaml"
|
||||
run ~yaml_filename:"gen/pytorch/Declarations-v1.9.0.yaml"
|
||||
~cpp_filename:"libtch/torch_api_generated"
|
||||
~ffi_filename:"libtch/c-generated.go"
|
||||
~must_wrapper_filename:"tensor/must-tensor-generated.go"
|
||||
|
|
129672
gen/pytorch/Declarations-v1.9.0.yaml
Normal file
129672
gen/pytorch/Declarations-v1.9.0.yaml
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
1254
libtch/torch_api.cpp
1254
libtch/torch_api.cpp
File diff suppressed because it is too large
Load Diff
|
@ -105,8 +105,13 @@ void at_set_num_threads(int n_threads);
|
|||
|
||||
void at_free(tensor);
|
||||
|
||||
void at_run_backward(tensor *tensors, int ntensors, tensor *inputs, int ninputs,
|
||||
tensor *outputs, int keep_graph, int create_graph);
|
||||
void at_run_backward(tensor *tensors,
|
||||
int ntensors,
|
||||
tensor *inputs,
|
||||
int ninputs,
|
||||
tensor *outputs,
|
||||
int keep_graph,
|
||||
int create_graph);
|
||||
|
||||
optimizer ato_adam(double learning_rate, double beta1, double beta2,
|
||||
double weight_decay);
|
||||
|
@ -131,7 +136,7 @@ void ato_step(optimizer);
|
|||
void ato_free(optimizer);
|
||||
|
||||
// TT. APIs for learning rate scheduler
|
||||
void ato_set_learning_rates(optimizer, double* learning_rates, int lrs_num);
|
||||
void ato_set_learning_rates(optimizer, double *learning_rates, int lrs_num);
|
||||
int64_t ato_param_group_num(optimizer);
|
||||
void ato_get_learning_rates(optimizer, double *lrs, int *ngroup);
|
||||
void ato_add_param_group(optimizer, tensor *params, int param_num);
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -77,7 +77,7 @@ func (r randnInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Ten
|
|||
}
|
||||
|
||||
initTs := ts.MustRandn(dims, gotch.Float, device)
|
||||
return initTs.MustMul1(ts.FloatScalar(r.stdev), true).MustAdd1(ts.FloatScalar(r.mean), true)
|
||||
return initTs.MustMulScalar(ts.FloatScalar(r.stdev), true).MustAddScalar(ts.FloatScalar(r.mean), true)
|
||||
}
|
||||
|
||||
func (r randnInit) Set(tensor *ts.Tensor) {
|
||||
|
|
12
nn/rnn.go
12
nn/rnn.go
|
@ -107,9 +107,9 @@ func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) *LSTM {
|
|||
// if vs.Device().IsCuda() && gotch.Cuda.CudnnIsAvailable() {
|
||||
// TODO: check if Cudnn is available here!!!
|
||||
if vs.Device().IsCuda() {
|
||||
// NOTE. 2 is for LSTM
|
||||
// ref. rnn.cpp in Pytorch
|
||||
ts.Must_CudnnRnnFlattenWeight(flatWeights, 4, inDim, 2, hiddenDim, cfg.NumLayers, cfg.BatchFirst, cfg.Bidirectional)
|
||||
// 2: for LSTM
|
||||
// 0: disables projections
|
||||
ts.Must_CudnnRnnFlattenWeight(flatWeights, 4, inDim, 2, hiddenDim, 0, cfg.NumLayers, cfg.BatchFirst, cfg.Bidirectional)
|
||||
}
|
||||
|
||||
return &LSTM{
|
||||
|
@ -227,9 +227,9 @@ func NewGRU(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) (retVal *GRU) {
|
|||
}
|
||||
|
||||
if vs.Device().IsCuda() {
|
||||
// NOTE. 3 is for GRU
|
||||
// ref. rnn.cpp in Pytorch
|
||||
ts.Must_CudnnRnnFlattenWeight(flatWeights, 4, inDim, 3, hiddenDim, cfg.NumLayers, cfg.BatchFirst, cfg.Bidirectional)
|
||||
// 3: for GRU
|
||||
// 0: disable projections
|
||||
ts.Must_CudnnRnnFlattenWeight(flatWeights, 4, inDim, 3, hiddenDim, 0, cfg.NumLayers, cfg.BatchFirst, cfg.Bidirectional)
|
||||
}
|
||||
|
||||
return &GRU{
|
||||
|
|
|
@ -68,8 +68,8 @@ func TestSaveLoad(t *testing.T) {
|
|||
u2, v2 := add(vs2.Root())
|
||||
|
||||
ts.NoGrad(func() {
|
||||
u1.Add1_(ts.FloatScalar(42.0))
|
||||
v1.Mul1_(ts.FloatScalar(2.0))
|
||||
u1.AddScalar_(ts.FloatScalar(42.0))
|
||||
v1.MulScalar_(ts.FloatScalar(2.0))
|
||||
})
|
||||
|
||||
wantU1 := float64(42.0)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
GOTCH_VERSION="${GOTCH_VER:-v0.3.14}"
|
||||
GOTCH_VERSION="${GOTCH_VER:-v0.4.0}"
|
||||
CUDA_VERSION="${CUDA_VER:-10.1}"
|
||||
GOTCH_PATH="$GOPATH/pkg/mod/github.com/sugarme/gotch@$GOTCH_VERSION"
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
LIBTORCH_VERSION="${LIBTORCH_VER:-1.7.0}"
|
||||
LIBTORCH_VERSION="${LIBTORCH_VER:-1.9.0}"
|
||||
CUDA_VERSION="${CUDA_VER:-10.1}"
|
||||
|
||||
if [[ -z "${CUDA_VERSION}"=="cpu" ]]; then
|
||||
|
|
|
@ -7,8 +7,8 @@ import (
|
|||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func ExampleTensor_MustArange1() {
|
||||
tensor := ts.MustArange1(ts.FloatScalar(0), ts.FloatScalar(12), gotch.Int64, gotch.CPU).MustView([]int64{3, 4}, true)
|
||||
func ExampleTensor_MustArange() {
|
||||
tensor := ts.MustArange(ts.FloatScalar(12), gotch.Int64, gotch.CPU).MustView([]int64{3, 4}, true)
|
||||
|
||||
fmt.Printf("%v", tensor)
|
||||
|
||||
|
@ -50,12 +50,12 @@ func ExampleTensor_Matmul() {
|
|||
|
||||
}
|
||||
|
||||
func ExampleTensor_Add1_() {
|
||||
func ExampleTensor_AddScalar_() {
|
||||
// In-place operation
|
||||
ts3 := ts.MustOnes([]int64{2, 3}, gotch.Float, gotch.CPU)
|
||||
fmt.Println("Before:")
|
||||
ts3.Print()
|
||||
ts3.MustAdd1_(ts.FloatScalar(2.0))
|
||||
ts3.MustAddScalar_(ts.FloatScalar(2.0))
|
||||
fmt.Printf("After (ts3 + 2.0): \n")
|
||||
ts3.Print()
|
||||
ts3.MustDrop()
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
func TestIntegerIndex(t *testing.T) {
|
||||
// [ 0 1 2
|
||||
// 3 4 5 ]
|
||||
tensor := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(2*3), gotch.Int64, gotch.CPU).MustView([]int64{2, 3}, true)
|
||||
tensor := ts.MustArange(ts.IntScalar(2*3), gotch.Int64, gotch.CPU).MustView([]int64{2, 3}, true)
|
||||
// tensor, err := ts.NewTensorFromData([]bool{true, false, false, false, false, false}, []int64{2, 3})
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
|
@ -71,7 +71,7 @@ func TestIntegerIndex(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestNewInsertAxis(t *testing.T) {
|
||||
tensor := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(2*3), gotch.Int64, gotch.CPU).MustView([]int64{2, 3}, true)
|
||||
tensor := ts.MustArange(ts.IntScalar(2*3), gotch.Int64, gotch.CPU).MustView([]int64{2, 3}, true)
|
||||
var idxs1 []ts.TensorIndexer = []ts.TensorIndexer{
|
||||
ts.NewInsertNewAxis(),
|
||||
}
|
||||
|
@ -112,7 +112,7 @@ func TestNewInsertAxis(t *testing.T) {
|
|||
func TestRangeIndex(t *testing.T) {
|
||||
|
||||
// Range
|
||||
tensor1 := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(4*3), gotch.Int64, gotch.CPU).MustView([]int64{4, 3}, true)
|
||||
tensor1 := ts.MustArange(ts.IntScalar(4*3), gotch.Int64, gotch.CPU).MustView([]int64{4, 3}, true)
|
||||
idx1 := []ts.TensorIndexer{
|
||||
ts.NewNarrow(1, 3),
|
||||
}
|
||||
|
@ -131,7 +131,7 @@ func TestRangeIndex(t *testing.T) {
|
|||
}
|
||||
|
||||
// Full range
|
||||
tensor2 := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(2*3), gotch.Int64, gotch.CPU).MustView([]int64{2, 3}, true)
|
||||
tensor2 := ts.MustArange(ts.IntScalar(2*3), gotch.Int64, gotch.CPU).MustView([]int64{2, 3}, true)
|
||||
idx2 := []ts.TensorIndexer{
|
||||
ts.NewNarrow(0, tensor2.MustSize()[0]),
|
||||
}
|
||||
|
@ -150,7 +150,7 @@ func TestRangeIndex(t *testing.T) {
|
|||
}
|
||||
|
||||
// Range from
|
||||
tensor3 := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(4*3), gotch.Int64, gotch.CPU).MustView([]int64{4, 3}, true)
|
||||
tensor3 := ts.MustArange(ts.IntScalar(4*3), gotch.Int64, gotch.CPU).MustView([]int64{4, 3}, true)
|
||||
idx3 := []ts.TensorIndexer{
|
||||
ts.NewNarrow(2, tensor3.MustSize()[0]),
|
||||
}
|
||||
|
@ -169,7 +169,7 @@ func TestRangeIndex(t *testing.T) {
|
|||
}
|
||||
|
||||
// Range to
|
||||
tensor4 := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(4*3), gotch.Int64, gotch.CPU).MustView([]int64{4, 3}, true)
|
||||
tensor4 := ts.MustArange(ts.IntScalar(4*3), gotch.Int64, gotch.CPU).MustView([]int64{4, 3}, true)
|
||||
idx4 := []ts.TensorIndexer{
|
||||
ts.NewNarrow(0, 2),
|
||||
}
|
||||
|
@ -189,7 +189,7 @@ func TestRangeIndex(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSliceIndex(t *testing.T) {
|
||||
tensor1 := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(6*2), gotch.Int64, gotch.CPU).MustView([]int64{6, 2}, true)
|
||||
tensor1 := ts.MustArange(ts.IntScalar(6*2), gotch.Int64, gotch.CPU).MustView([]int64{6, 2}, true)
|
||||
idx1 := []ts.TensorIndexer{
|
||||
ts.NewSliceIndex([]int64{1, 3, 5}),
|
||||
}
|
||||
|
@ -207,7 +207,7 @@ func TestSliceIndex(t *testing.T) {
|
|||
t.Errorf("Got tensor values: %v\n", got1Shape)
|
||||
}
|
||||
|
||||
tensor2 := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(3*4), gotch.Int64, gotch.CPU).MustView([]int64{3, 4}, true)
|
||||
tensor2 := ts.MustArange(ts.IntScalar(3*4), gotch.Int64, gotch.CPU).MustView([]int64{3, 4}, true)
|
||||
idx2 := []ts.TensorIndexer{
|
||||
ts.NewNarrow(0, tensor2.MustSize()[0]),
|
||||
ts.NewSliceIndex([]int64{3, 0}),
|
||||
|
@ -229,7 +229,7 @@ func TestSliceIndex(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestComplexIndex(t *testing.T) {
|
||||
tensor := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(2*3*5*7), gotch.Int64, gotch.CPU).MustView([]int64{2, 3, 5, 7}, true)
|
||||
tensor := ts.MustArange(ts.IntScalar(2*3*5*7), gotch.Int64, gotch.CPU).MustView([]int64{2, 3, 5, 7}, true)
|
||||
idx := []ts.TensorIndexer{
|
||||
ts.NewSelect(1),
|
||||
ts.NewNarrow(1, 2),
|
||||
|
@ -253,7 +253,7 @@ func TestComplexIndex(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestIndex3D(t *testing.T) {
|
||||
tensor := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(24), gotch.Int64, gotch.CPU).MustView([]int64{2, 3, 4}, true)
|
||||
tensor := ts.MustArange(ts.IntScalar(24), gotch.Int64, gotch.CPU).MustView([]int64{2, 3, 4}, true)
|
||||
|
||||
idx1 := []ts.TensorIndexer{
|
||||
ts.NewSelect(0),
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -20,7 +20,7 @@ func (ts *Tensor) CrossEntropyForLogits(targets *Tensor) (retVal *Tensor) {
|
|||
// targets represent ground-truth.
|
||||
func (ts *Tensor) AccuracyForLogits(targets *Tensor) (retVal *Tensor) {
|
||||
argmax := ts.MustArgmax([]int64{-1}, false, false)
|
||||
eq1 := argmax.MustEq1(targets, true)
|
||||
eq1 := argmax.MustEqTensor(targets, true)
|
||||
return eq1.MustTotype(gotch.Float, true).MustMean(gotch.Float, true)
|
||||
}
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
)
|
||||
|
||||
func ExampleTensor_Split(t *testing.T) {
|
||||
tensor := ts.MustArange1(ts.FloatScalar(0), ts.FloatScalar(10), gotch.Float, gotch.CPU).MustView([]int64{5, 2}, true)
|
||||
tensor := ts.MustArange(ts.FloatScalar(10), gotch.Float, gotch.CPU).MustView([]int64{5, 2}, true)
|
||||
splitTensors := tensor.MustSplit(2, 0, false)
|
||||
|
||||
for _, t := range splitTensors {
|
||||
|
@ -27,7 +27,7 @@ func ExampleTensor_Split(t *testing.T) {
|
|||
}
|
||||
|
||||
func ExampleTensorSplitWithSizes(t *testing.T) {
|
||||
tensor := ts.MustArange1(ts.FloatScalar(0), ts.FloatScalar(10), gotch.Float, gotch.CPU).MustView([]int64{5, 2}, true)
|
||||
tensor := ts.MustArange(ts.FloatScalar(10), gotch.Float, gotch.CPU).MustView([]int64{5, 2}, true)
|
||||
splitTensors := tensor.MustSplitWithSizes([]int64{1, 4}, 0, false)
|
||||
|
||||
for _, t := range splitTensors {
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1207,7 +1207,7 @@ func (ts *Tensor) Onehot(labels int64) *Tensor {
|
|||
inputTs := unsqueezeTs.MustTotype(gotch.Int64, true)
|
||||
|
||||
zerosTs := MustZeros(dims, gotch.Float, gotch.CPU)
|
||||
retVal := zerosTs.MustScatter1(-1, inputTs, FloatScalar(1.0), true)
|
||||
retVal := zerosTs.MustScatterValue(-1, inputTs, FloatScalar(1.0), true)
|
||||
inputTs.MustDrop()
|
||||
|
||||
return retVal
|
||||
|
|
|
@ -9,9 +9,9 @@ import (
|
|||
)
|
||||
|
||||
func TestTensorInit(t *testing.T) {
|
||||
tensor := ts.MustArange1(ts.IntScalar(1), ts.IntScalar(5), gotch.Int64, gotch.CPU)
|
||||
tensor := ts.MustArange(ts.IntScalar(5), gotch.Int64, gotch.CPU)
|
||||
|
||||
want := []float64{1, 2, 3, 4}
|
||||
want := []float64{0, 1, 2, 3, 4}
|
||||
got := tensor.Float64Values()
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
|
@ -23,9 +23,9 @@ func TestTensorInit(t *testing.T) {
|
|||
func TestInplaceAssign(t *testing.T) {
|
||||
tensor := ts.MustOfSlice([]int64{3, 1, 4, 1, 5})
|
||||
|
||||
tensor.MustAdd1_(ts.IntScalar(1))
|
||||
tensor.MustMul1_(ts.IntScalar(2))
|
||||
tensor.MustSub1_(ts.IntScalar(1))
|
||||
tensor.MustAddScalar_(ts.IntScalar(1))
|
||||
tensor.MustMulScalar_(ts.IntScalar(2))
|
||||
tensor.MustSubScalar_(ts.IntScalar(1))
|
||||
|
||||
want := []int64{7, 3, 9, 3, 11}
|
||||
got := tensor.Vals()
|
||||
|
@ -38,7 +38,7 @@ func TestInplaceAssign(t *testing.T) {
|
|||
|
||||
func TestConstantOp(t *testing.T) {
|
||||
tensor := ts.MustOfSlice([]int64{3, 9, 3, 11})
|
||||
resTs1 := tensor.MustMul1(ts.IntScalar(-1), true)
|
||||
resTs1 := tensor.MustMulScalar(ts.IntScalar(-1), true)
|
||||
|
||||
want1 := []int64{-3, -9, -3, -11}
|
||||
got1 := resTs1.Vals()
|
||||
|
|
|
@ -3,6 +3,7 @@ package aug
|
|||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
// "math"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
|
@ -38,11 +39,11 @@ func (c *RandomCrop) params(x *ts.Tensor) (int64, int64, int64, int64) {
|
|||
return 0, 0, h, w
|
||||
}
|
||||
|
||||
iTs := ts.MustRandint1(0, h-th+1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
iTs := ts.MustRandint(h-th+1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
i := iTs.Int64Values()[0]
|
||||
iTs.MustDrop()
|
||||
|
||||
jTs := ts.MustRandint1(0, w-tw+1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
jTs := ts.MustRandint(w-tw+1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
j := jTs.Int64Values()[0]
|
||||
jTs.MustDrop()
|
||||
|
||||
|
|
|
@ -130,11 +130,11 @@ func (rc *RandomCutout) cutoutParams(x *ts.Tensor) (int64, int64, int64, int64,
|
|||
v := ts.MustOfSlice(rc.rgbVal).MustUnsqueeze(1, true).MustUnsqueeze(1, true)
|
||||
|
||||
// i = torch.randint(0, img_h - h + 1, size=(1, )).item()
|
||||
iTs := ts.MustRandint1(0, imgH-h+1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
iTs := ts.MustRandint(imgH-h+1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
i := iTs.Int64Values()[0]
|
||||
iTs.MustDrop()
|
||||
// j = torch.randint(0, img_w - w + 1, size=(1, )).item()
|
||||
jTs := ts.MustRandint1(0, imgW-w+1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
jTs := ts.MustRandint(imgW-w+1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
j := jTs.Int64Values()[0]
|
||||
jTs.MustDrop()
|
||||
return i, j, h, w, v
|
||||
|
|
|
@ -16,7 +16,7 @@ func gaussianKernel1D(ks int64, sigma float64, dtype gotch.DType, device gotch.D
|
|||
x := ts.MustLinspace(ts.IntScalar(-ksHalf), ts.IntScalar(ksHalf), []int64{ks}, dtype, device)
|
||||
|
||||
// pdf = torch.exp(-0.5 * (x / sigma).pow(2))
|
||||
pdf := x.MustDiv1(ts.FloatScalar(sigma), true).MustPow(ts.IntScalar(2), true).MustMul1(ts.FloatScalar(0.5), true).MustExp(true)
|
||||
pdf := x.MustDivScalar(ts.FloatScalar(sigma), true).MustPow(ts.IntScalar(2), true).MustMulScalar(ts.FloatScalar(0.5), true).MustExp(true)
|
||||
// kernel1d = pdf / pdf.sum()
|
||||
pdfSum := pdf.MustSum(dtype, false)
|
||||
kernel1d := pdf.MustDiv(pdfSum, true)
|
||||
|
@ -76,7 +76,7 @@ func castSqueezeOut(x *ts.Tensor, needCast, needSqueeze bool, outDType gotch.DTy
|
|||
)
|
||||
switch needSqueeze {
|
||||
case true:
|
||||
squeezeTs = x.MustSqueeze1(0, false)
|
||||
squeezeTs = x.MustSqueezeDim(0, false)
|
||||
case false:
|
||||
squeezeTs = x.MustShallowClone()
|
||||
}
|
||||
|
@ -192,8 +192,8 @@ func blend(img1, img2 *ts.Tensor, ratio float64) *ts.Tensor {
|
|||
bound := 255.0
|
||||
|
||||
// (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)
|
||||
i1 := img1.MustMul1(ts.FloatScalar(ratio), false)
|
||||
i2 := img2.MustMul1(ts.FloatScalar(1.0-ratio), false)
|
||||
i1 := img1.MustMulScalar(ts.FloatScalar(ratio), false)
|
||||
i2 := img2.MustMulScalar(ts.FloatScalar(1.0-ratio), false)
|
||||
sumTs := i1.MustAdd(i2, true)
|
||||
i2.MustDrop()
|
||||
out := sumTs.MustClamp(ts.FloatScalar(0), ts.FloatScalar(bound), true).MustTotype(dtype, true)
|
||||
|
@ -262,9 +262,9 @@ func rgb2Gray(x *ts.Tensor, outChanOpt ...int64) *ts.Tensor {
|
|||
// This implementation closely follows the TF one:
|
||||
// https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
|
||||
// l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype)
|
||||
rmul := r.MustMul1(ts.FloatScalar(0.2989), true)
|
||||
gmul := g.MustMul1(ts.FloatScalar(0.587), true)
|
||||
bmul := b.MustMul1(ts.FloatScalar(0.114), true)
|
||||
rmul := r.MustMulScalar(ts.FloatScalar(0.2989), true)
|
||||
gmul := g.MustMulScalar(ts.FloatScalar(0.587), true)
|
||||
bmul := b.MustMulScalar(ts.FloatScalar(0.114), true)
|
||||
addTs := rmul.MustAdd(gmul, true).MustAdd(bmul, true)
|
||||
gmul.MustDrop()
|
||||
bmul.MustDrop()
|
||||
|
@ -288,7 +288,7 @@ func adjustContrast(x *ts.Tensor, contrast float64) *ts.Tensor {
|
|||
|
||||
grayTs := rgb2Gray(x).MustTotype(x.DType(), true)
|
||||
|
||||
mean := grayTs.MustMean1([]int64{-3, -2, -1}, true, gotch.Float, true).MustTotype(x.DType(), true)
|
||||
mean := grayTs.MustMeanDim([]int64{-3, -2, -1}, true, gotch.Float, true).MustTotype(x.DType(), true)
|
||||
out := blend(x, mean, contrast)
|
||||
mean.MustDrop()
|
||||
|
||||
|
@ -331,7 +331,7 @@ func rgb2HSV(x *ts.Tensor) *ts.Tensor {
|
|||
// # we don't need to deal with it in case we save the NaN in a buffer in
|
||||
// # backprop, if it is ever supported, but it doesn't hurt to do so.
|
||||
// eqc = maxc == minc
|
||||
eqC := maxC.MustEq1(minC, false)
|
||||
eqC := maxC.MustEqTensor(minC, false)
|
||||
|
||||
// cr = maxc - minc
|
||||
cr := maxC.MustSub(minC, false)
|
||||
|
@ -340,7 +340,7 @@ func rgb2HSV(x *ts.Tensor) *ts.Tensor {
|
|||
ones := maxC.MustOnesLike(false)
|
||||
|
||||
// s = cr / torch.where(eqc, ones, maxc)
|
||||
condMaxC := ones.MustWhere1(eqC, maxC, false)
|
||||
condMaxC := ones.MustWhereSelf(eqC, maxC, false)
|
||||
s := cr.MustDiv(condMaxC, false)
|
||||
|
||||
// # Note that `eqc => maxc = minc = r = g = b`. So the following calculation
|
||||
|
@ -351,27 +351,27 @@ func rgb2HSV(x *ts.Tensor) *ts.Tensor {
|
|||
// rc = (maxc - r) / cr_divisor
|
||||
// gc = (maxc - g) / cr_divisor
|
||||
// bc = (maxc - b) / cr_divisor
|
||||
crDivisor := ones.MustWhere1(eqC, cr, true) // delete ones
|
||||
crDivisor := ones.MustWhereSelf(eqC, cr, true) // delete ones
|
||||
rc := maxC.MustSub(r, false).MustDiv(crDivisor, true)
|
||||
gc := maxC.MustSub(g, false).MustDiv(crDivisor, true)
|
||||
bc := maxC.MustSub(b, false).MustDiv(crDivisor, true)
|
||||
|
||||
// hr = (maxc == r) * (bc - gc)
|
||||
rSub := bc.MustSub(gc, false)
|
||||
hr := maxC.MustEq1(r, false).MustMul(rSub, true)
|
||||
hr := maxC.MustEqTensor(r, false).MustMul(rSub, true)
|
||||
rSub.MustDrop()
|
||||
|
||||
// hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
|
||||
maxcCond1 := maxC.MustNotEqual1(r, false)
|
||||
hgMul := rc.MustSub(bc, false).MustAdd1(ts.FloatScalar(2.0), true)
|
||||
hg := maxC.MustEq1(g, false).MustLogicalAnd(maxcCond1, true).MustMul(hgMul, true)
|
||||
maxcCond1 := maxC.MustNotEqualTensor(r, false)
|
||||
hgMul := rc.MustSub(bc, false).MustAddScalar(ts.FloatScalar(2.0), true)
|
||||
hg := maxC.MustEqTensor(g, false).MustLogicalAnd(maxcCond1, true).MustMul(hgMul, true)
|
||||
maxcCond1.MustDrop()
|
||||
hgMul.MustDrop()
|
||||
|
||||
// hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
|
||||
maxcCond2 := maxC.MustNotEqual1(r, false)
|
||||
hbMul := gc.MustSub(rc, false).MustAdd1(ts.FloatScalar(4.0), true)
|
||||
hb := maxC.MustNotEqual1(g, false).MustLogicalAnd(maxcCond2, true).MustMul(hbMul, true)
|
||||
maxcCond2 := maxC.MustNotEqualTensor(r, false)
|
||||
hbMul := gc.MustSub(rc, false).MustAddScalar(ts.FloatScalar(4.0), true)
|
||||
hb := maxC.MustNotEqualTensor(g, false).MustLogicalAnd(maxcCond2, true).MustMul(hbMul, true)
|
||||
maxcCond2.MustDrop()
|
||||
hbMul.MustDrop()
|
||||
|
||||
|
@ -379,8 +379,8 @@ func rgb2HSV(x *ts.Tensor) *ts.Tensor {
|
|||
h1 := hr.MustAdd(hg, false).MustAdd(hb, true)
|
||||
|
||||
// h = torch.fmod((h / 6.0 + 1.0), 1.0)
|
||||
h2 := h1.MustDiv1(ts.FloatScalar(6.0), true).MustAdd1(ts.FloatScalar(1.0), true) // delete h1
|
||||
h3 := h2.MustFmod(ts.FloatScalar(1.0), true) // delete h2
|
||||
h2 := h1.MustDivScalar(ts.FloatScalar(6.0), true).MustAddScalar(ts.FloatScalar(1.0), true) // delete h1
|
||||
h3 := h2.MustFmod(ts.FloatScalar(1.0), true) // delete h2
|
||||
|
||||
// torch.stack((h, s, maxc), dim=-3)
|
||||
out := ts.MustStack([]ts.Tensor{*h3, *s, *maxC}, -3)
|
||||
|
@ -413,26 +413,26 @@ func hsv2RGB(x *ts.Tensor) *ts.Tensor {
|
|||
s := &hsvTs[1]
|
||||
v := &hsvTs[2]
|
||||
// i = torch.floor(h * 6.0)
|
||||
i := h.MustMul1(ts.FloatScalar(6.0), false).MustFloor(true)
|
||||
i := h.MustMulScalar(ts.FloatScalar(6.0), false).MustFloor(true)
|
||||
// f = (h * 6.0) - i
|
||||
f := h.MustMul1(ts.FloatScalar(6.0), false).MustSub(i, true)
|
||||
f := h.MustMulScalar(ts.FloatScalar(6.0), false).MustSub(i, true)
|
||||
|
||||
// p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
|
||||
x1 := s.MustMul1(ts.FloatScalar(-1), false).MustAdd1(ts.FloatScalar(1.0), true)
|
||||
x1 := s.MustMulScalar(ts.FloatScalar(-1), false).MustAddScalar(ts.FloatScalar(1.0), true)
|
||||
p := v.MustMul(x1, false).MustClamp(ts.FloatScalar(0.0), ts.FloatScalar(1.0), true)
|
||||
x1.MustDrop()
|
||||
|
||||
// q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
|
||||
x2 := s.MustMul(f, false).MustMul1(ts.FloatScalar(-1), true).MustAdd1(ts.FloatScalar(1.0), true)
|
||||
x2 := s.MustMul(f, false).MustMulScalar(ts.FloatScalar(-1), true).MustAddScalar(ts.FloatScalar(1.0), true)
|
||||
q := v.MustMul(x2, false).MustClamp(ts.FloatScalar(0.0), ts.FloatScalar(1.0), true)
|
||||
x2.MustDrop()
|
||||
|
||||
//t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
|
||||
// step1. s * (1.0 - f)
|
||||
sub1 := f.MustMul1(ts.FloatScalar(-1), false).MustAdd1(ts.FloatScalar(1.0), true).MustMul(s, true)
|
||||
sub1 := f.MustMulScalar(ts.FloatScalar(-1), false).MustAddScalar(ts.FloatScalar(1.0), true).MustMul(s, true)
|
||||
// step 2: v *(1.0 - step1)
|
||||
x3 := sub1.MustMul1(ts.FloatScalar(-1), true).MustAdd1(ts.FloatScalar(1.0), true).MustMul(v, true) // deleted sub1
|
||||
t := x3.MustClamp(ts.FloatScalar(0.0), ts.FloatScalar(1.0), true) // deleted x3
|
||||
x3 := sub1.MustMulScalar(ts.FloatScalar(-1), true).MustAddScalar(ts.FloatScalar(1.0), true).MustMul(v, true) // deleted sub1
|
||||
t := x3.MustClamp(ts.FloatScalar(0.0), ts.FloatScalar(1.0), true) // deleted x3
|
||||
|
||||
// i = i.to(dtype=torch.int32)
|
||||
i = i.MustTotype(gotch.Int, true)
|
||||
|
@ -441,7 +441,7 @@ func hsv2RGB(x *ts.Tensor) *ts.Tensor {
|
|||
// torch.arange(6, device=i.device).view(-1, 1, 1)
|
||||
x4 := ts.MustArange(ts.FloatScalar(6), gotch.Float, iremainder.MustDevice()).MustView([]int64{-1, 1, 1}, true)
|
||||
// mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
|
||||
mask := iremainder.MustUnsqueeze(-3, true).MustEq1(x4, true).MustTotype(x.DType(), true) // delete iremainder
|
||||
mask := iremainder.MustUnsqueeze(-3, true).MustEqTensor(x4, true).MustTotype(x.DType(), true) // delete iremainder
|
||||
x4.MustDrop()
|
||||
|
||||
// a1 = torch.stack((v, q, p, p, t, v), dim=-3)
|
||||
|
@ -487,7 +487,7 @@ func adjustHue(x *ts.Tensor, hue float64) *ts.Tensor {
|
|||
return out
|
||||
}
|
||||
|
||||
imgFl := x.MustTotype(gotch.Float, false).MustDiv1(ts.FloatScalar(255.0), true)
|
||||
imgFl := x.MustTotype(gotch.Float, false).MustDivScalar(ts.FloatScalar(255.0), true)
|
||||
hsvImg := rgb2HSV(imgFl)
|
||||
|
||||
hsvTs := hsvImg.MustUnbind(-3, true)
|
||||
|
@ -495,13 +495,13 @@ func adjustHue(x *ts.Tensor, hue float64) *ts.Tensor {
|
|||
s := &hsvTs[1]
|
||||
v := &hsvTs[2]
|
||||
// h = (h + hue_factor) % 1.0
|
||||
hAdj := h.MustAdd1(ts.FloatScalar(hue), false).MustRemainder(ts.FloatScalar(1.0), true)
|
||||
hAdj := h.MustAddScalar(ts.FloatScalar(hue), false).MustRemainder(ts.FloatScalar(1.0), true)
|
||||
|
||||
hsvAdj := ts.MustStack([]ts.Tensor{*hAdj, *s, *v}, -3)
|
||||
|
||||
imgHueAdj := hsv2RGB(hsvAdj)
|
||||
|
||||
out := imgHueAdj.MustMul1(ts.FloatScalar(255.0), true)
|
||||
out := imgHueAdj.MustMulScalar(ts.FloatScalar(255.0), true)
|
||||
|
||||
imgFl.MustDrop()
|
||||
h.MustDrop()
|
||||
|
@ -658,7 +658,7 @@ func cutout(x *ts.Tensor, top, left, height, width int64, rgbVal []int64) *ts.Te
|
|||
srcIdx := []ts.TensorIndexer{cIdx, hNar, wNar}
|
||||
view := output.Idx(srcIdx)
|
||||
oneTs := view.MustOnesLike(false)
|
||||
vTs := oneTs.MustMul1(ts.IntScalar(rgbVal[i]), true)
|
||||
vTs := oneTs.MustMulScalar(ts.IntScalar(rgbVal[i]), true)
|
||||
view.Copy_(vTs)
|
||||
vTs.MustDrop()
|
||||
view.MustDrop()
|
||||
|
@ -760,7 +760,7 @@ func applyGridTransform(x, gridInput *ts.Tensor, mode string, fillValue []float6
|
|||
fillImg := ts.MustOfSlice(fillValue).MustTotype(image.DType(), true).MustTo(image.MustDevice(), true).MustView([]int64{1, 3, 1, 1}, true).MustExpandAs(image, true)
|
||||
|
||||
// img = img * mask + (1.0 - mask) * fill_img
|
||||
addTs := mask.MustMul1(ts.FloatScalar(-1), false).MustAdd1(ts.FloatScalar(1.0), true).MustMul(fillImg, true)
|
||||
addTs := mask.MustMulScalar(ts.FloatScalar(-1), false).MustAddScalar(ts.FloatScalar(1.0), true).MustMul(fillImg, true)
|
||||
imgOut := image.MustMul(mask, true).MustAdd(addTs, true)
|
||||
addTs.MustDrop()
|
||||
mask.MustDrop()
|
||||
|
@ -817,7 +817,7 @@ func perspectiveCoeff(startPoints, endPoints [][]int64) []float64 {
|
|||
res := bMat.MustLstsq(aMat, true)
|
||||
|
||||
aMat.MustDrop()
|
||||
outputTs := res.MustSqueeze1(1, true)
|
||||
outputTs := res.MustSqueezeDim(1, true)
|
||||
output := outputTs.Float64Values()
|
||||
outputTs.MustDrop()
|
||||
|
||||
|
@ -897,7 +897,7 @@ func perspectiveGrid(coef []float64, ow, oh int64, dtype gotch.DType, device got
|
|||
rescaledTheta1.MustDrop()
|
||||
rescaledTheta2.MustDrop()
|
||||
|
||||
outputGrid := outputGrid1.MustDiv(outputGrid2, true).MustSub1(ts.FloatScalar(1.0), true).MustView([]int64{1, oh, ow, 2}, true)
|
||||
outputGrid := outputGrid1.MustDiv(outputGrid2, true).MustSubScalar(ts.FloatScalar(1.0), true).MustView([]int64{1, oh, ow, 2}, true)
|
||||
outputGrid2.MustDrop()
|
||||
|
||||
baseGrid.MustDrop()
|
||||
|
@ -1132,7 +1132,7 @@ func solarize(img *ts.Tensor, threshold float64) *ts.Tensor {
|
|||
// return torch.where(img >= threshold, inverted_img, img)
|
||||
conditionTs := img.MustGe(ts.FloatScalar(threshold), false)
|
||||
|
||||
out := img.MustWhere1(conditionTs, invertedImg, false)
|
||||
out := img.MustWhereSelf(conditionTs, invertedImg, false)
|
||||
|
||||
invertedImg.MustDrop()
|
||||
conditionTs.MustDrop()
|
||||
|
@ -1153,7 +1153,7 @@ func invert(img *ts.Tensor) *ts.Tensor {
|
|||
|
||||
var bound int64 = 255
|
||||
// return bound - img
|
||||
out := img.MustMul1(ts.IntScalar(-1), false).MustAdd1(ts.IntScalar(bound), true)
|
||||
out := img.MustMulScalar(ts.IntScalar(-1), false).MustAddScalar(ts.IntScalar(bound), true)
|
||||
return out
|
||||
}
|
||||
|
||||
|
@ -1201,7 +1201,7 @@ func autocontrast(img *ts.Tensor) *ts.Tensor {
|
|||
|
||||
// eq_idxs = torch.where(minimum == maximum)[0]
|
||||
// NOTE. Eq(minTs, maxTs) give [n, c, 1, 1] or [channels, 1, 1]
|
||||
eqIdx := minTs.MustEq1(maxTs, false).MustSqueeze1(-1, true).MustSqueeze1(-1, true).MustTotype(gotch.Int64, true)
|
||||
eqIdx := minTs.MustEqTensor(maxTs, false).MustSqueezeDim(-1, true).MustSqueezeDim(-1, true).MustTotype(gotch.Int64, true)
|
||||
|
||||
// minimum[eq_idxs] = 0
|
||||
minTsView := minTs.MustIndexSelect(0, eqIdx, false)
|
||||
|
@ -1212,13 +1212,13 @@ func autocontrast(img *ts.Tensor) *ts.Tensor {
|
|||
|
||||
// maximum[eq_idxs] = bound
|
||||
maxTsView := maxTs.MustIndexSelect(0, eqIdx, false)
|
||||
boundTs := maxTsView.MustOnesLike(false).MustMul1(ts.FloatScalar(bound), true)
|
||||
boundTs := maxTsView.MustOnesLike(false).MustMulScalar(ts.FloatScalar(bound), true)
|
||||
maxTsView.Copy_(boundTs)
|
||||
boundTs.MustDrop()
|
||||
maxTsView.MustDrop()
|
||||
|
||||
// scale = bound / (maximum - minimum)
|
||||
scale := maxTs.MustSub(minTs, false).MustPow(ts.IntScalar(-1), true).MustMul1(ts.FloatScalar(bound), true)
|
||||
scale := maxTs.MustSub(minTs, false).MustPow(ts.IntScalar(-1), true).MustMulScalar(ts.FloatScalar(bound), true)
|
||||
//
|
||||
// return ((img - minimum) * scale).clamp(0, bound).to(img.dtype)
|
||||
out := img.MustSub(minTs, false).MustMul(scale, true).MustClamp(ts.IntScalar(0), ts.FloatScalar(bound), true).MustTotype(dtype, true)
|
||||
|
@ -1265,7 +1265,7 @@ func blurredDegenerateImage(img *ts.Tensor) *ts.Tensor {
|
|||
|
||||
// kernel[1, 1] = 5.0
|
||||
kernelView := kernel.MustNarrow(1, 1, 1, false).MustNarrow(0, 1, 1, true)
|
||||
centerVal := kernelView.MustOnesLike(false).MustMul1(ts.FloatScalar(5.0), true)
|
||||
centerVal := kernelView.MustOnesLike(false).MustMulScalar(ts.FloatScalar(5.0), true)
|
||||
kernelView.Copy_(centerVal) // center kernel value
|
||||
centerVal.MustDrop()
|
||||
kernelView.MustDrop()
|
||||
|
@ -1393,7 +1393,7 @@ func scaleChannel(imgChan *ts.Tensor) *ts.Tensor {
|
|||
|
||||
// step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode='floor')
|
||||
histoLen := nonzeroHisto.MustSize()[0]
|
||||
step := nonzeroHisto.MustNarrow(0, 0, histoLen-1, true).MustSum(gotch.Float, true).MustFloorDivide1(ts.FloatScalar(255.0), true)
|
||||
step := nonzeroHisto.MustNarrow(0, 0, histoLen-1, true).MustSum(gotch.Float, true).MustFloorDivideScalar(ts.FloatScalar(255.0), true)
|
||||
|
||||
stepVal := step.Float64Values()[0]
|
||||
if stepVal == 0 {
|
||||
|
@ -1404,7 +1404,7 @@ func scaleChannel(imgChan *ts.Tensor) *ts.Tensor {
|
|||
}
|
||||
|
||||
// lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode='floor'), step, rounding_mode='floor')
|
||||
halfStep := step.MustFloorDivide1(ts.FloatScalar(2.0), false)
|
||||
halfStep := step.MustFloorDivideScalar(ts.FloatScalar(2.0), false)
|
||||
lut := histo.Must_Cumsum(0, true).MustAdd(halfStep, true).MustFloorDivide(step, true)
|
||||
step.MustDrop()
|
||||
halfStep.MustDrop()
|
||||
|
@ -1491,7 +1491,7 @@ func Byte2FloatImage(x *ts.Tensor) *ts.Tensor {
|
|||
panic(err)
|
||||
}
|
||||
|
||||
return x.MustDiv1(ts.FloatScalar(255.0), false)
|
||||
return x.MustDivScalar(ts.FloatScalar(255.0), false)
|
||||
}
|
||||
|
||||
// Float2ByteImage converts float dtype image to uint8 dtype image.
|
||||
|
@ -1503,5 +1503,5 @@ func Float2ByteImage(x *ts.Tensor) *ts.Tensor {
|
|||
panic(err)
|
||||
}
|
||||
|
||||
return x.MustMul1(ts.IntScalar(255), false).MustTotype(gotch.Uint8, true)
|
||||
return x.MustMulScalar(ts.IntScalar(255), false).MustTotype(gotch.Uint8, true)
|
||||
}
|
||||
|
|
|
@ -107,11 +107,11 @@ func (rp *RandomPerspective) getParams(w, h int64) ([][]int64, [][]int64) {
|
|||
// int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
|
||||
// ]
|
||||
tlVal1 := int64(rp.distortionScale*float64(halfW)) + 1
|
||||
tlTs1 := ts.MustRandint1(0, tlVal1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
tlTs1 := ts.MustRandint(tlVal1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
tl1 := tlTs1.Int64Values()[0]
|
||||
tlTs1.MustDrop()
|
||||
tlVal2 := int64(rp.distortionScale*float64(halfH)) + 1
|
||||
tlTs2 := ts.MustRandint1(0, tlVal2, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
tlTs2 := ts.MustRandint(tlVal2, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
tl2 := tlTs2.Int64Values()[0]
|
||||
tlTs2.MustDrop()
|
||||
topLeft = []int64{tl1, tl2}
|
||||
|
@ -121,11 +121,11 @@ func (rp *RandomPerspective) getParams(w, h int64) ([][]int64, [][]int64) {
|
|||
// int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
|
||||
// ]
|
||||
trVal1 := w - int64(rp.distortionScale*float64(halfW)) - 1
|
||||
trTs1 := ts.MustRandint1(trVal1, w, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
trTs1 := ts.MustRandintLow(trVal1, w, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
tr1 := trTs1.Int64Values()[0]
|
||||
trTs1.MustDrop()
|
||||
trVal2 := int64(rp.distortionScale*float64(halfH)) + 1
|
||||
trTs2 := ts.MustRandint1(0, trVal2, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
trTs2 := ts.MustRandint(trVal2, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
tr2 := trTs2.Int64Values()[0]
|
||||
trTs2.MustDrop()
|
||||
topRight = []int64{tr1, tr2}
|
||||
|
@ -135,11 +135,11 @@ func (rp *RandomPerspective) getParams(w, h int64) ([][]int64, [][]int64) {
|
|||
// int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
|
||||
// ]
|
||||
brVal1 := w - int64(rp.distortionScale*float64(halfW)) - 1
|
||||
brTs1 := ts.MustRandint1(brVal1, w, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
brTs1 := ts.MustRandintLow(brVal1, w, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
br1 := brTs1.Int64Values()[0]
|
||||
brTs1.MustDrop()
|
||||
brVal2 := h - int64(rp.distortionScale*float64(halfH)) - 1
|
||||
brTs2 := ts.MustRandint1(brVal2, h, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
brTs2 := ts.MustRandintLow(brVal2, h, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
br2 := brTs2.Int64Values()[0]
|
||||
brTs2.MustDrop()
|
||||
bottomRight = []int64{br1, br2}
|
||||
|
@ -149,11 +149,11 @@ func (rp *RandomPerspective) getParams(w, h int64) ([][]int64, [][]int64) {
|
|||
// int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
|
||||
// ]
|
||||
blVal1 := int64(rp.distortionScale*float64(halfW)) + 1
|
||||
blTs1 := ts.MustRandint1(0, blVal1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
blTs1 := ts.MustRandint(blVal1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
bl1 := blTs1.Int64Values()[0]
|
||||
blTs1.MustDrop()
|
||||
blVal2 := h - int64(rp.distortionScale*float64(halfH)) - 1
|
||||
blTs2 := ts.MustRandint1(blVal2, h, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
blTs2 := ts.MustRandintLow(blVal2, h, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
bl2 := blTs2.Int64Values()[0]
|
||||
blTs2.MustDrop()
|
||||
bottomLeft = []int64{bl1, bl2}
|
||||
|
|
|
@ -67,7 +67,7 @@ func readFile(filename string) (imagesTs *ts.Tensor, labelsTs *ts.Tensor) {
|
|||
}
|
||||
|
||||
tmp1 := images.MustTotype(gotch.Float, true)
|
||||
imagesTs = tmp1.MustDiv1(ts.FloatScalar(255.0), true)
|
||||
imagesTs = tmp1.MustDivScalar(ts.FloatScalar(255.0), true)
|
||||
|
||||
labelsTs = labels
|
||||
|
||||
|
|
|
@ -313,8 +313,8 @@ func efficientnet(p *nn.Path, params *params, nclasses int64) ts.ModuleT {
|
|||
tmp6.MustDrop()
|
||||
tmp8 := tmp7.MustAdaptiveAvgPool2d([]int64{1, 1}, false)
|
||||
tmp7.MustDrop()
|
||||
tmp9 := tmp8.MustSqueeze1(-1, true)
|
||||
tmp10 := tmp9.MustSqueeze1(-1, true)
|
||||
tmp9 := tmp8.MustSqueezeDim(-1, true)
|
||||
tmp10 := tmp9.MustSqueezeDim(-1, true)
|
||||
|
||||
res := tmp10.ApplyT(classifier, train)
|
||||
tmp10.MustDrop()
|
||||
|
|
|
@ -68,7 +68,7 @@ func Save(tensor *ts.Tensor, path string) error {
|
|||
var tsCHW, tsHWC *ts.Tensor
|
||||
switch {
|
||||
case len(shape) == 4 && shape[0] == 1:
|
||||
tsCHW = t.MustSqueeze1(int64(0), true)
|
||||
tsCHW = t.MustSqueezeDim(int64(0), true)
|
||||
chwTs := chwToHWC(tsCHW)
|
||||
tsCHW.MustDrop()
|
||||
tsHWC = chwTs.MustTo(gotch.CPU, true)
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
|
||||
// "os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
|
@ -38,7 +39,7 @@ func (in *ImageNet) Normalize(tensor *ts.Tensor) (*ts.Tensor, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
resDiv1, err := res.Div1(ts.FloatScalar(float64(255.0)), true)
|
||||
resDiv1, err := res.DivScalar(ts.FloatScalar(float64(255.0)), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -69,7 +70,7 @@ func (in *ImageNet) UnNormalize(tensor *ts.Tensor) (*ts.Tensor, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
resMul1, err := resAdd.Mul1(ts.FloatScalar(float64(255.0)), true)
|
||||
resMul1, err := resAdd.MulScalar(ts.FloatScalar(float64(255.0)), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -263,7 +264,7 @@ func (in *ImageNet) LoadFromDir(path string) (*Dataset, error) {
|
|||
trainImages = append(trainImages, *trainTs)
|
||||
|
||||
trainLabelOnes := ts.MustOnes([]int64{ntrainTs}, gotch.Int64, gotch.CPU)
|
||||
trainLabels = append(trainLabels, *trainLabelOnes.MustMul1(ts.IntScalar(labelIndex), true))
|
||||
trainLabels = append(trainLabels, *trainLabelOnes.MustMulScalar(ts.IntScalar(labelIndex), true))
|
||||
|
||||
// test
|
||||
testDir := fmt.Sprintf("%v/%v", validPath, labelDir)
|
||||
|
@ -276,7 +277,7 @@ func (in *ImageNet) LoadFromDir(path string) (*Dataset, error) {
|
|||
testImages = append(testImages, *testTs)
|
||||
|
||||
testLabelOnes := ts.MustOnes([]int64{ntestTs}, gotch.Int64, gotch.CPU)
|
||||
testLabels = append(testLabels, *testLabelOnes.MustMul1(ts.IntScalar(labelIndex), true))
|
||||
testLabels = append(testLabels, *testLabelOnes.MustMulScalar(ts.IntScalar(labelIndex), true))
|
||||
}
|
||||
|
||||
trainImageTs := ts.MustCat(trainImages, 0)
|
||||
|
|
|
@ -124,7 +124,7 @@ func readImages(filename string) *ts.Tensor {
|
|||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return imagesTs.MustView([]int64{int64(samples), int64(rows * cols)}, true).MustTotype(gotch.Float, true).MustDiv1(ts.FloatScalar(255.0), true)
|
||||
return imagesTs.MustView([]int64{int64(samples), int64(rows * cols)}, true).MustTotype(gotch.Float, true).MustDivScalar(ts.FloatScalar(255.0), true)
|
||||
}
|
||||
|
||||
// LoadMNISTDir loads all MNIST data from a given directory to Dataset
|
||||
|
|
|
@ -117,8 +117,8 @@ func MobileNetV2(p *nn.Path, nclasses int64) ts.ModuleT {
|
|||
return nn.NewFuncT(func(xs *ts.Tensor, train bool) *ts.Tensor {
|
||||
tmp1 := xs.ApplyT(features, train)
|
||||
|
||||
tmp2 := tmp1.MustMean1([]int64{2}, false, gotch.Float, true)
|
||||
tmp3 := tmp2.MustMean1([]int64{2}, false, gotch.Float, true)
|
||||
tmp2 := tmp1.MustMeanDim([]int64{2}, false, gotch.Float, true)
|
||||
tmp3 := tmp2.MustMeanDim([]int64{2}, false, gotch.Float, true)
|
||||
|
||||
res := tmp3.ApplyT(classifier, train)
|
||||
tmp3.MustDrop()
|
||||
|
|
Loading…
Reference in New Issue
Block a user