BREAKING CHANGE: switch to auto-generated
This commit is contained in:
parent
9f5eccb4e5
commit
0ff2f910f2
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -20,6 +20,7 @@ _build/
|
||||||
data/
|
data/
|
||||||
example/testdata/
|
example/testdata/
|
||||||
tmp/
|
tmp/
|
||||||
|
bak/
|
||||||
gen/.merlin
|
gen/.merlin
|
||||||
**/*.rs.bk
|
**/*.rs.bk
|
||||||
Cargo.lock
|
Cargo.lock
|
||||||
|
|
|
@ -143,7 +143,7 @@ func main() {
|
||||||
loss := logits.CrossEntropyForLogits(devicedLabel)
|
loss := logits.CrossEntropyForLogits(devicedLabel)
|
||||||
opt.BackwardStep(loss)
|
opt.BackwardStep(loss)
|
||||||
|
|
||||||
lossVal = loss.Values()[0]
|
lossVal = loss.Float64Values()[0]
|
||||||
|
|
||||||
devicedData.MustDrop()
|
devicedData.MustDrop()
|
||||||
devicedLabel.MustDrop()
|
devicedLabel.MustDrop()
|
||||||
|
|
|
@ -133,7 +133,7 @@ func runCNN1() {
|
||||||
vs.Freeze()
|
vs.Freeze()
|
||||||
testAccuracy := nn.BatchAccuracyForLogits(vs, net, testImages, testLabels, vs.Device(), 1024)
|
testAccuracy := nn.BatchAccuracyForLogits(vs, net, testImages, testLabels, vs.Device(), 1024)
|
||||||
vs.Unfreeze()
|
vs.Unfreeze()
|
||||||
fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Values()[0], testAccuracy*100.0)
|
fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Float64Values()[0], testAccuracy*100.0)
|
||||||
if testAccuracy > bestAccuracy {
|
if testAccuracy > bestAccuracy {
|
||||||
bestAccuracy = testAccuracy
|
bestAccuracy = testAccuracy
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,30 +20,34 @@ func runLinear() {
|
||||||
var ds vision.Dataset
|
var ds vision.Dataset
|
||||||
ds = vision.LoadMNISTDir(MnistDir)
|
ds = vision.LoadMNISTDir(MnistDir)
|
||||||
|
|
||||||
device := (gotch.CPU).CInt()
|
device := gotch.CPU
|
||||||
dtype := (gotch.Float).CInt()
|
dtype := gotch.Float
|
||||||
|
|
||||||
ws := ts.MustZeros([]int64{ImageDim, Label}, dtype, device).MustSetRequiresGrad(true)
|
ws := ts.MustZeros([]int64{ImageDim, Label}, dtype, device).MustSetRequiresGrad(true, false)
|
||||||
bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true)
|
bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true, false)
|
||||||
|
|
||||||
for epoch := 0; epoch < epochs; epoch++ {
|
for epoch := 0; epoch < epochs; epoch++ {
|
||||||
|
|
||||||
|
weight := ts.NewTensor()
|
||||||
|
reduction := int64(1) // Mean of loss
|
||||||
|
ignoreIndex := int64(-100)
|
||||||
|
|
||||||
logits := ds.TrainImages.MustMm(ws, false).MustAdd(bs, true)
|
logits := ds.TrainImages.MustMm(ws, false).MustAdd(bs, true)
|
||||||
loss := logits.MustLogSoftmax(-1, dtype, true).MustNllLoss(ds.TrainLabels, true)
|
loss := logits.MustLogSoftmax(-1, dtype, true).MustNllLoss(ds.TrainLabels, weight, reduction, ignoreIndex, true)
|
||||||
|
|
||||||
ws.ZeroGrad()
|
ws.ZeroGrad()
|
||||||
bs.ZeroGrad()
|
bs.ZeroGrad()
|
||||||
loss.MustBackward()
|
loss.MustBackward()
|
||||||
|
|
||||||
ts.NoGrad(func() {
|
ts.NoGrad(func() {
|
||||||
ws.Add_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0), true))
|
ws.Add_(ws.MustGrad(false).MustMul1(ts.FloatScalar(-1.0), true))
|
||||||
bs.Add_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0), true))
|
bs.Add_(bs.MustGrad(false).MustMul1(ts.FloatScalar(-1.0), true))
|
||||||
})
|
})
|
||||||
|
|
||||||
testLogits := ds.TestImages.MustMm(ws, false).MustAdd(bs, true)
|
testLogits := ds.TestImages.MustMm(ws, false).MustAdd(bs, true)
|
||||||
testAccuracy := testLogits.MustArgmax(-1, false, true).MustEq1(ds.TestLabels, true).MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true).MustView([]int64{-1}, true).MustFloat64Value([]int64{0})
|
testAccuracy := testLogits.MustArgmax(-1, false, true).MustEq1(ds.TestLabels, true).MustTotype(gotch.Float, true).MustMean(gotch.Float, true).MustView([]int64{-1}, true).MustFloat64Value([]int64{0})
|
||||||
|
|
||||||
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Values()[0], testAccuracy*100)
|
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Float64Values()[0], testAccuracy*100)
|
||||||
|
|
||||||
loss.MustDrop()
|
loss.MustDrop()
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,9 +46,9 @@ func train(trainX, trainY, testX, testY ts.Tensor, m ts.Module, opt nn.Optimizer
|
||||||
|
|
||||||
testLogits := m.Forward(testX)
|
testLogits := m.Forward(testX)
|
||||||
testAccuracy := testLogits.AccuracyForLogits(testY)
|
testAccuracy := testLogits.AccuracyForLogits(testY)
|
||||||
accuracy := testAccuracy.Values()[0] * 100
|
accuracy := testAccuracy.Float64Values()[0] * 100
|
||||||
testAccuracy.MustDrop()
|
testAccuracy.MustDrop()
|
||||||
lossVal := loss.Values()[0]
|
lossVal := loss.Float64Values()[0]
|
||||||
loss.MustDrop()
|
loss.MustDrop()
|
||||||
|
|
||||||
fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, lossVal, accuracy)
|
fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, lossVal, accuracy)
|
||||||
|
|
|
@ -46,7 +46,7 @@ func gramMatrix(m ts.Tensor) (retVal ts.Tensor) {
|
||||||
|
|
||||||
mview := m.MustView([]int64{a * b, c * d}, false)
|
mview := m.MustView([]int64{a * b, c * d}, false)
|
||||||
mviewT := mview.MustT(false)
|
mviewT := mview.MustT(false)
|
||||||
gram := mview.MustMatMul(mviewT, true)
|
gram := mview.MustMatmul(mviewT, true)
|
||||||
mviewT.MustDrop()
|
mviewT.MustDrop()
|
||||||
|
|
||||||
return gram.MustDiv1(ts.IntScalar(a*b*c*d), true)
|
return gram.MustDiv1(ts.IntScalar(a*b*c*d), true)
|
||||||
|
@ -57,7 +57,7 @@ func styleLoss(m1 ts.Tensor, m2 ts.Tensor) (retVal ts.Tensor) {
|
||||||
// m1.MustDrop()
|
// m1.MustDrop()
|
||||||
gram2 := gramMatrix(m2)
|
gram2 := gramMatrix(m2)
|
||||||
// m2.MustDrop()
|
// m2.MustDrop()
|
||||||
loss := gram1.MustMseLoss(gram2, ts.ReductionMean.ToInt(), true)
|
loss := gram1.MustMseLoss(gram2, int64(ts.ReductionMean), true)
|
||||||
gram2.MustDrop()
|
gram2.MustDrop()
|
||||||
return loss
|
return loss
|
||||||
}
|
}
|
||||||
|
@ -89,8 +89,8 @@ func main() {
|
||||||
|
|
||||||
cuda := gotch.CudaBuilder(0)
|
cuda := gotch.CudaBuilder(0)
|
||||||
device := cuda.CudaIfAvailable()
|
device := cuda.CudaIfAvailable()
|
||||||
|
|
||||||
// device := gotch.CPU
|
// device := gotch.CPU
|
||||||
|
|
||||||
netVS := nn.NewVarStore(device)
|
netVS := nn.NewVarStore(device)
|
||||||
in := vision.NewImageNet()
|
in := vision.NewImageNet()
|
||||||
net := vision.VGG16(netVS.Root(), in.ClassCount())
|
net := vision.VGG16(netVS.Root(), in.ClassCount())
|
||||||
|
@ -150,8 +150,8 @@ func main() {
|
||||||
inputLayers := net.ForwardAllT(inputVar, false, maxLayer)
|
inputLayers := net.ForwardAllT(inputVar, false, maxLayer)
|
||||||
|
|
||||||
// var sLoss ts.Tensor
|
// var sLoss ts.Tensor
|
||||||
sLoss := ts.MustZeros([]int64{1}, gotch.Float.CInt(), device.CInt())
|
sLoss := ts.MustZeros([]int64{1}, gotch.Float, device)
|
||||||
cLoss := ts.MustZeros([]int64{1}, gotch.Float.CInt(), device.CInt())
|
cLoss := ts.MustZeros([]int64{1}, gotch.Float, device)
|
||||||
for _, idx := range StyleIndexes {
|
for _, idx := range StyleIndexes {
|
||||||
l := styleLoss(inputLayers[idx], styleLayers[idx])
|
l := styleLoss(inputLayers[idx], styleLayers[idx])
|
||||||
sLoss = sLoss.MustAdd(l, true)
|
sLoss = sLoss.MustAdd(l, true)
|
||||||
|
@ -159,7 +159,7 @@ func main() {
|
||||||
}
|
}
|
||||||
for _, idx := range ContentIndexes {
|
for _, idx := range ContentIndexes {
|
||||||
// NOTE: set `del` = true called panic at GPU train (tested on Colab)
|
// NOTE: set `del` = true called panic at GPU train (tested on Colab)
|
||||||
l := inputLayers[idx].MustMseLoss(contentLayers[idx], ts.ReductionMean.ToInt(), false)
|
l := inputLayers[idx].MustMseLoss(contentLayers[idx], int64(ts.ReductionMean), false)
|
||||||
cLoss = cLoss.MustAdd(l, true)
|
cLoss = cLoss.MustAdd(l, true)
|
||||||
l.MustDrop()
|
l.MustDrop()
|
||||||
}
|
}
|
||||||
|
@ -174,7 +174,7 @@ func main() {
|
||||||
|
|
||||||
if (stepIdx % 1000) == 0 {
|
if (stepIdx % 1000) == 0 {
|
||||||
clone := inputVar.MustShallowClone()
|
clone := inputVar.MustShallowClone()
|
||||||
img := clone.MustDetach()
|
img := clone.MustDetach(false)
|
||||||
imageTs := img.MustTo(gotch.CPU, true)
|
imageTs := img.MustTo(gotch.CPU, true)
|
||||||
clone.MustDrop()
|
clone.MustDrop()
|
||||||
err := in.SaveImage(imageTs, fmt.Sprintf("../../data/neural-style-transfer/out%v.jpg", stepIdx))
|
err := in.SaveImage(imageTs, fmt.Sprintf("../../data/neural-style-transfer/out%v.jpg", stepIdx))
|
||||||
|
@ -184,7 +184,7 @@ func main() {
|
||||||
imageTs.MustDrop()
|
imageTs.MustDrop()
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Step %v ... Done. Loss %10.1f\n", stepIdx, loss.Values()[0])
|
fmt.Printf("Step %v ... Done. Loss %10.1f\n", stepIdx, loss.Float64Values()[0])
|
||||||
cLoss.MustDrop()
|
cLoss.MustDrop()
|
||||||
loss.MustDrop()
|
loss.MustDrop()
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
x := tensor.TensorFrom([]float64{2.0})
|
x := tensor.TensorFrom([]float64{2.0})
|
||||||
x = x.MustSetRequiresGrad(true)
|
x = x.MustSetRequiresGrad(true, false)
|
||||||
x.ZeroGrad()
|
x.ZeroGrad()
|
||||||
|
|
||||||
xy := tensor.TensorFrom([]float64{2.0})
|
xy := tensor.TensorFrom([]float64{2.0})
|
||||||
|
@ -19,10 +19,10 @@ func main() {
|
||||||
z := x.MustMul(xz, false)
|
z := x.MustMul(xz, false)
|
||||||
|
|
||||||
y.Backward()
|
y.Backward()
|
||||||
xgrad := x.MustGrad()
|
xgrad := x.MustGrad(false)
|
||||||
xgrad.Print() // [2.0]
|
xgrad.Print() // [2.0]
|
||||||
z.Backward()
|
z.Backward()
|
||||||
xgrad = x.MustGrad()
|
xgrad = x.MustGrad(false)
|
||||||
xgrad.Print() // [5.0] due to accumulated 2.0 + 3.0
|
xgrad.Print() // [5.0] due to accumulated 2.0 + 3.0
|
||||||
|
|
||||||
isGradEnabled := tensor.MustGradSetEnabled(false)
|
isGradEnabled := tensor.MustGradSetEnabled(false)
|
||||||
|
|
|
@ -78,6 +78,6 @@ func main() {
|
||||||
loss.MustDrop()
|
loss.MustDrop()
|
||||||
|
|
||||||
testAccuracy := testImages.Apply(linear).AccuracyForLogits(dataset.TestLabels)
|
testAccuracy := testImages.Apply(linear).AccuracyForLogits(dataset.TestLabels)
|
||||||
fmt.Printf("Epoch %v\t Accuracy: %5.2f%%\n", epoch, testAccuracy.Values()[0]*100)
|
fmt.Printf("Epoch %v\t Accuracy: %5.2f%%\n", epoch, testAccuracy.Float64Values()[0]*100)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -270,7 +270,7 @@ func upsample(prevChannels int64) (retVal1 int64, retVal2 interface{}) {
|
||||||
h := res[2]
|
h := res[2]
|
||||||
w := res[3]
|
w := res[3]
|
||||||
|
|
||||||
return xs.MustUpsampleNearest2d([]int64{h * 2, w * 2}, 2.0, 2.0)
|
return xs.MustUpsampleNearest2d([]int64{h * 2, w * 2}, 2.0, 2.0, false)
|
||||||
})
|
})
|
||||||
|
|
||||||
return prevChannels, Layer{Val: layer}
|
return prevChannels, Layer{Val: layer}
|
||||||
|
@ -396,7 +396,7 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (r
|
||||||
|
|
||||||
xOffset := a.MustView([]int64{-1, 1}, true)
|
xOffset := a.MustView([]int64{-1, 1}, true)
|
||||||
yOffset := b.MustView([]int64{-1, 1}, true)
|
yOffset := b.MustView([]int64{-1, 1}, true)
|
||||||
xyOffsetTmp1 := ts.MustCat([]ts.Tensor{xOffset, yOffset}, 1, false)
|
xyOffsetTmp1 := ts.MustCat([]ts.Tensor{xOffset, yOffset}, 1)
|
||||||
xyOffsetTmp2 := xyOffsetTmp1.MustRepeat([]int64{1, nanchors}, true)
|
xyOffsetTmp2 := xyOffsetTmp1.MustRepeat([]int64{1, nanchors}, true)
|
||||||
xyOffsetTmp3 := xyOffsetTmp2.MustView([]int64{-1, 2}, true)
|
xyOffsetTmp3 := xyOffsetTmp2.MustView([]int64{-1, 2}, true)
|
||||||
xyOffset := xyOffsetTmp3.MustUnsqueeze(0, true)
|
xyOffset := xyOffsetTmp3.MustUnsqueeze(0, true)
|
||||||
|
@ -512,7 +512,7 @@ func (dn *Darknet) BuildModel(vs nn.Path) (retVal nn.FuncT) {
|
||||||
for _, i := range route.TsIdxs {
|
for _, i := range route.TsIdxs {
|
||||||
layers = append(layers, prevYs[int(i)])
|
layers = append(layers, prevYs[int(i)])
|
||||||
}
|
}
|
||||||
ysTs = ts.MustCat(layers, 1, false)
|
ysTs = ts.MustCat(layers, 1)
|
||||||
|
|
||||||
case "Shortcut":
|
case "Shortcut":
|
||||||
from := b.Bl.(Shortcut).TsIdx
|
from := b.Bl.(Shortcut).TsIdx
|
||||||
|
@ -540,7 +540,7 @@ func (dn *Darknet) BuildModel(vs nn.Path) (retVal nn.FuncT) {
|
||||||
prevYs = append(prevYs, ysTs)
|
prevYs = append(prevYs, ysTs)
|
||||||
} // end of For loop
|
} // end of For loop
|
||||||
|
|
||||||
res = ts.MustCat(detections, 1, true)
|
res = ts.MustCat(detections, 1)
|
||||||
|
|
||||||
// Now, free-up memory held up by prevYs
|
// Now, free-up memory held up by prevYs
|
||||||
for _, t := range prevYs {
|
for _, t := range prevYs {
|
||||||
|
|
|
@ -87,7 +87,7 @@ func report(pred ts.Tensor, img ts.Tensor, w int64, h int64) (retVal ts.Tensor)
|
||||||
// Extract the bounding boxes for which confidence is above the threshold.
|
// Extract the bounding boxes for which confidence is above the threshold.
|
||||||
for index := 0; index < int(npreds); index++ {
|
for index := 0; index < int(npreds); index++ {
|
||||||
predIdx := pred.MustGet(index)
|
predIdx := pred.MustGet(index)
|
||||||
var predVals []float64 = predIdx.Values()
|
var predVals []float64 = predIdx.Float64Values()
|
||||||
predIdx.MustDrop()
|
predIdx.MustDrop()
|
||||||
|
|
||||||
confidence := predVals[4]
|
confidence := predVals[4]
|
||||||
|
@ -229,7 +229,7 @@ func main() {
|
||||||
netHeight := darknet.Height()
|
netHeight := darknet.Height()
|
||||||
netWidth := darknet.Width()
|
netWidth := darknet.Width()
|
||||||
|
|
||||||
imgClone := originalImage.MustShallowClone().MustDetach()
|
imgClone := originalImage.MustShallowClone().MustDetach(false)
|
||||||
|
|
||||||
imageTs, err := vision.Resize(imgClone, netWidth, netHeight)
|
imageTs, err := vision.Resize(imgClone, netWidth, netHeight)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
702
gen/gen.ml
702
gen/gen.ml
|
@ -1,6 +1,6 @@
|
||||||
(* Automatically generate the C++ -> C -> rust bindings.
|
(* Automatically generate the C++ -> C -> Go bindings.
|
||||||
This takes as input the Descriptions.yaml file that gets generated when
|
This takes as input the Descriptions.yaml file that gets generated when
|
||||||
(Func.c_go_args_list func) building PyTorch from source.
|
func (Func.c_go_args_list func) building PyTorch from source.
|
||||||
|
|
||||||
Run with: dune exec gen/gen.exe
|
Run with: dune exec gen/gen.exe
|
||||||
*)
|
*)
|
||||||
|
@ -42,11 +42,12 @@ let no_tensor_options =
|
||||||
; "randint_like"
|
; "randint_like"
|
||||||
; "randn_like" ]
|
; "randn_like" ]
|
||||||
|
|
||||||
let prefixed_functions =
|
(*
|
||||||
Set.of_list
|
* let prefixed_functions =
|
||||||
(module String)
|
* Set.of_list
|
||||||
["add"; "add_"; "div"; "div_"; "mul"; "mul_"; "sub"; "sub_"; "nll_loss"]
|
* (module String)
|
||||||
|
* ["add"; "add_"; "div"; "div_"; "mul"; "mul_"; "sub"; "sub_"; "nll_loss"]
|
||||||
|
* *)
|
||||||
let excluded_prefixes = ["_thnn_"; "_th_"; "thnn_"; "th_"]
|
let excluded_prefixes = ["_thnn_"; "_th_"; "thnn_"; "th_"]
|
||||||
|
|
||||||
let excluded_suffixes = ["_forward"; "_forward_out"]
|
let excluded_suffixes = ["_forward"; "_forward_out"]
|
||||||
|
@ -178,153 +179,291 @@ module Func = struct
|
||||||
Printf.failwithf "Method calls should have at least one argument %s"
|
Printf.failwithf "Method calls should have at least one argument %s"
|
||||||
t.name () )
|
t.name () )
|
||||||
|
|
||||||
let replace_map =
|
(*
|
||||||
Map.of_alist_exn
|
* let replace_map =
|
||||||
(module String)
|
* Map.of_alist_exn
|
||||||
[ ("t", "tr")
|
* (module String)
|
||||||
; ("where", "where_")
|
* [ ("t", "tr")
|
||||||
; ("view", "view_")
|
* ; ("where", "where_")
|
||||||
; ("unsafe", "unsafe_") ]
|
* ; ("view", "view_")
|
||||||
|
* ; ("unsafe", "unsafe_") ]
|
||||||
|
* *)
|
||||||
|
|
||||||
|
let is_method t =
|
||||||
|
List.exists t.args ~f:(fun arg ->
|
||||||
|
match arg.arg_name with "self" -> true | _ -> false )
|
||||||
|
|
||||||
let go_name name =
|
let go_name name =
|
||||||
let name =
|
let last_underscore name = Str.string_match (Str.regexp ".*_$") name 0 in
|
||||||
Map.find replace_map name |> Option.value ~default:name
|
let words = Str.split (Str.regexp "_") name in
|
||||||
|> String.capitalize
|
if last_underscore name then
|
||||||
|> String.substr_replace_all ~pattern:"__" ~with_:""
|
let cap_words = List.map words ~f:(fun word -> String.capitalize word) in
|
||||||
|
String.concat ~sep:"" cap_words ^ "_"
|
||||||
|
else
|
||||||
|
let cap_words = List.map words ~f:(fun word -> String.capitalize word) in
|
||||||
|
String.concat ~sep:"" cap_words
|
||||||
|
|
||||||
|
let go_variable name =
|
||||||
|
let goname = go_name name in
|
||||||
|
(* NOTE: Deal with Go namespace conflict *)
|
||||||
|
let safe_name =
|
||||||
|
match goname with
|
||||||
|
| "Var" -> "vari"
|
||||||
|
| "Unsafe" -> "unsafety"
|
||||||
|
| _ -> goname
|
||||||
in
|
in
|
||||||
if String.is_prefix name ~prefix:"_" then
|
String.uncapitalize safe_name
|
||||||
"Internal" ^ (name |> String.capitalize)
|
|
||||||
else name |> String.capitalize
|
|
||||||
|
|
||||||
let c_go_args_list t =
|
let c_go_args_list t =
|
||||||
List.map t.args ~f:(fun arg ->
|
List.map t.args ~f:(fun arg ->
|
||||||
let an = arg.arg_name in
|
let an = go_variable arg.arg_name in
|
||||||
let single_param = Printf.sprintf "%s %s" an in
|
let single_param = Printf.sprintf "%s %s" an in
|
||||||
match arg.arg_type with
|
match arg.arg_type with
|
||||||
| Bool -> single_param "C.int"
|
| Bool -> single_param "int32"
|
||||||
| Int64 -> single_param "C.long"
|
| Int64 -> single_param "int64"
|
||||||
| Double -> single_param "C.double"
|
| Double -> single_param "float64"
|
||||||
| Tensor -> single_param "Ctensor"
|
| Tensor -> single_param "Ctensor"
|
||||||
| TensorOption -> single_param "Ctensor"
|
| TensorOption -> single_param "Ctensor"
|
||||||
| Scalar -> single_param "Cscalar"
|
| Scalar -> single_param "Cscalar"
|
||||||
| ScalarType -> single_param "C.int"
|
| ScalarType -> single_param "int32"
|
||||||
| Device -> single_param "C.int"
|
| Device -> single_param "int32"
|
||||||
| String -> Printf.sprintf "%s_ptr C.int, %s_len C.int" an an
|
| String -> single_param "string"
|
||||||
| IntList -> Printf.sprintf "%s_data C.long, %s_len C.int" an an
|
| IntList -> Printf.sprintf "%sData []int64, %sLen int" an an
|
||||||
| TensorList -> Printf.sprintf "%s_data Ctensor, %s_len C.int" an an
|
| TensorList -> Printf.sprintf "%sData []Ctensor, %sLen int" an an
|
||||||
| TensorOptions ->
|
| TensorOptions -> Printf.sprintf "%sKind int32, %sDevice int32" an an
|
||||||
Printf.sprintf "%s_kind C.int, %s_device C.int" an an )
|
)
|
||||||
|> String.concat ~sep:", "
|
|> String.concat ~sep:", "
|
||||||
|
|
||||||
let c_go_args_list_notype t =
|
let c_go_args_list_notype t =
|
||||||
List.map t.args ~f:(fun arg ->
|
List.map t.args ~f:(fun arg ->
|
||||||
let an = arg.arg_name in
|
let an = go_variable arg.arg_name in
|
||||||
|
let an = match an with "var" -> "vari" | _ -> an in
|
||||||
let single_param = Printf.sprintf "%s %s" an in
|
let single_param = Printf.sprintf "%s %s" an in
|
||||||
match arg.arg_type with
|
match arg.arg_type with
|
||||||
| Bool -> single_param ""
|
| Bool -> Printf.sprintf "c%s" an
|
||||||
| Int64 -> single_param ""
|
| Int64 -> Printf.sprintf "c%s" an
|
||||||
| Double -> single_param ""
|
| Double -> Printf.sprintf "c%s" an
|
||||||
| Tensor -> single_param ""
|
| Tensor -> Printf.sprintf "%s" an
|
||||||
| TensorOption -> single_param ""
|
| TensorOption -> Printf.sprintf "%s" an
|
||||||
| Scalar -> single_param ""
|
| Scalar -> single_param ""
|
||||||
| ScalarType -> single_param ""
|
| ScalarType -> Printf.sprintf "c%s" an
|
||||||
| Device -> single_param ""
|
| Device -> Printf.sprintf "c%s" an
|
||||||
| String -> Printf.sprintf "%s_ptr, %s_len" an an
|
| String -> Printf.sprintf "c%s, c%sLen" an an
|
||||||
| IntList -> Printf.sprintf "%s_data, %s_len" an an
|
| IntList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
||||||
| TensorList -> Printf.sprintf "%s_data, %s_len" an an
|
| TensorList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
||||||
| TensorOptions -> Printf.sprintf "%s_kind, %s_device" an an )
|
| TensorOptions -> Printf.sprintf "c%sKind, c%sDevice" an an )
|
||||||
|> String.concat ~sep:", "
|
|> String.concat ~sep:", "
|
||||||
|
|
||||||
let self_name = "self"
|
(* TODO: convert Go pointer to C pointer *)
|
||||||
|
let c_go_args_list_body t =
|
||||||
|
List.map t.args ~f:(fun arg ->
|
||||||
|
let an = go_variable arg.arg_name in
|
||||||
|
(* let single_param = Printf.sprintf "%s %s" an in *)
|
||||||
|
match arg.arg_type with
|
||||||
|
| Bool ->
|
||||||
|
Printf.sprintf "\nc%s := *(*C.int)(unsafe.Pointer(&%s))" an an
|
||||||
|
| Int64 ->
|
||||||
|
Printf.sprintf "\nc%s := *(*C.int64_t)(unsafe.Pointer(&%s))" an an
|
||||||
|
| Double ->
|
||||||
|
Printf.sprintf "\nc%s := *(*C.double)(unsafe.Pointer(&%s))" an an
|
||||||
|
| Tensor -> ""
|
||||||
|
| TensorOption -> ""
|
||||||
|
| Scalar -> ""
|
||||||
|
| ScalarType ->
|
||||||
|
Printf.sprintf "\nc%s := *(*C.int)(unsafe.Pointer(&%s))" an an
|
||||||
|
| Device ->
|
||||||
|
Printf.sprintf "\nc%s := *(*C.int)(unsafe.Pointer(&%s))" an an
|
||||||
|
| String ->
|
||||||
|
Printf.sprintf
|
||||||
|
"\n\
|
||||||
|
c%s := C.CString(%s)\n\
|
||||||
|
%sLen := len(%s)\n\
|
||||||
|
c%sLen := *(*C.int)(unsafe.Pointer(&%sLen))"
|
||||||
|
an an an an an an
|
||||||
|
| IntList ->
|
||||||
|
Printf.sprintf
|
||||||
|
"\n\
|
||||||
|
c%sDataPtr := (*C.int64_t)(unsafe.Pointer(&%sData[0]))\n\
|
||||||
|
c%sLen := *(*C.int)(unsafe.Pointer(&%sLen))"
|
||||||
|
an an an an
|
||||||
|
| TensorList ->
|
||||||
|
Printf.sprintf
|
||||||
|
"\n\
|
||||||
|
c%sDataPtr := (*Ctensor)(unsafe.Pointer(&%sData[0]))\n\
|
||||||
|
c%sLen := *(*C.int)(unsafe.Pointer(&%sLen))"
|
||||||
|
an an an an
|
||||||
|
| TensorOptions ->
|
||||||
|
Printf.sprintf
|
||||||
|
"\n\
|
||||||
|
c%sKind := *(*C.int)(unsafe.Pointer(&%sKind))\n\
|
||||||
|
c%sDevice := *(*C.int)(unsafe.Pointer(&%sDevice))"
|
||||||
|
an an an an )
|
||||||
|
|> String.concat ~sep:""
|
||||||
|
|
||||||
(* let input_name = "input" *)
|
let self_name = "self"
|
||||||
|
|
||||||
let self_tensor arg =
|
let self_tensor arg =
|
||||||
match arg.arg_type with
|
match arg.arg_type with
|
||||||
| Tensor -> String.( = ) arg.arg_name self_name
|
| Tensor -> String.( = ) arg.arg_name self_name
|
||||||
| _ -> false
|
| _ -> false
|
||||||
|
|
||||||
let type_parameters t =
|
(*
|
||||||
let needs_scalar_parameter =
|
* let type_parameters t =
|
||||||
List.exists t.args ~f:(fun arg ->
|
* let needs_scalar_parameter =
|
||||||
match arg.arg_type with Scalar -> true | _ -> false )
|
* List.exists t.args ~f:(fun arg ->
|
||||||
in
|
* match arg.arg_type with Scalar -> true | _ -> false )
|
||||||
let needs_type_parameter =
|
* in
|
||||||
List.exists t.args ~f:(fun arg ->
|
* let needs_type_parameter =
|
||||||
match arg.arg_type with
|
* List.exists t.args ~f:(fun arg ->
|
||||||
| TensorList | TensorOption -> true
|
* match arg.arg_type with
|
||||||
| _ -> false )
|
* | TensorList | TensorOption -> true
|
||||||
in
|
* | _ -> false )
|
||||||
if needs_type_parameter && needs_scalar_parameter then "Tensor, Scalar"
|
* in
|
||||||
else if needs_type_parameter then "Tensor"
|
* if needs_type_parameter && needs_scalar_parameter then "Tensor, Scalar"
|
||||||
else if needs_scalar_parameter then "Scalar"
|
* else if needs_type_parameter then "Tensor"
|
||||||
else ""
|
* else if needs_scalar_parameter then "Scalar"
|
||||||
|
* else ""
|
||||||
|
* *)
|
||||||
|
|
||||||
|
(*
|
||||||
|
* let go_args_list t =
|
||||||
|
* (* https://ocaml.janestreet.com/ocaml-core/latest/doc/base/Base/List/#val-partition_tf *)
|
||||||
|
* (* TODO. implement special cases - TensorOptions, ... *)
|
||||||
|
* match List.partition_tf t.args ~f:self_tensor with _, args_list ->
|
||||||
|
* args_list
|
||||||
|
* *)
|
||||||
|
|
||||||
let go_args_list t =
|
let is_inplace t =
|
||||||
(* https://ocaml.janestreet.com/ocaml-core/latest/doc/base/Base/List/#val-partition_tf *)
|
match Str.string_match (Str.regexp ".*_$") t.name 0 with
|
||||||
match List.partition_tf t.args ~f:self_tensor with _, args_list ->
|
| true -> true
|
||||||
args_list
|
| _ -> false
|
||||||
|
|
||||||
let go_typed_args_list t =
|
let go_typed_args_list t =
|
||||||
let to_string args =
|
let to_string args =
|
||||||
List.map args ~f:(fun arg ->
|
let args_list =
|
||||||
let go_arg_type =
|
List.map args ~f:(fun arg ->
|
||||||
|
let go_arg_type =
|
||||||
|
match arg.arg_type with
|
||||||
|
| Bool -> "bool"
|
||||||
|
| Int64 -> "int64"
|
||||||
|
| Double -> "float64"
|
||||||
|
| Tensor -> "Tensor"
|
||||||
|
| TensorOption -> "Tensor"
|
||||||
|
| IntList -> "[]int64"
|
||||||
|
| TensorList -> "[]Tensor"
|
||||||
|
| String -> "string"
|
||||||
|
(* TODO. Struct{Kind gotch.DType Device gotch.Device} *)
|
||||||
|
(* E.g. `type KindDevice struct{}` *)
|
||||||
|
| TensorOptions -> "gotch.KindDevice"
|
||||||
|
| Scalar -> "Scalar"
|
||||||
|
| ScalarType -> "gotch.DType"
|
||||||
|
| Device -> "gotch.Device"
|
||||||
|
in
|
||||||
match arg.arg_type with
|
match arg.arg_type with
|
||||||
| Bool -> "bool"
|
| TensorOptions ->
|
||||||
| Int64 -> "int64"
|
Printf.sprintf "%sKind gotch.DType, %sDevice gotch.Device"
|
||||||
| Double -> "float64"
|
(go_variable arg.arg_name) (go_variable arg.arg_name)
|
||||||
| Tensor -> "Tensor"
|
| _ ->
|
||||||
| TensorOption -> "TensorOption"
|
Printf.sprintf "%s %s" (go_variable arg.arg_name) go_arg_type
|
||||||
| IntList -> "[]int64"
|
)
|
||||||
| TensorList -> "[]Tensor"
|
in
|
||||||
| String -> "string"
|
if is_method t && not (is_inplace t) then
|
||||||
| TensorOptions -> "(Kind, Device)"
|
args_list @ ["del bool"] |> String.concat ~sep:", "
|
||||||
| Scalar -> "Scalar"
|
else args_list |> String.concat ~sep:", "
|
||||||
| ScalarType -> "Kind"
|
|
||||||
| Device -> "Device"
|
|
||||||
in
|
|
||||||
Printf.sprintf "%s %s" (go_name arg.arg_name) go_arg_type )
|
|
||||||
|> String.concat ~sep:", "
|
|
||||||
in
|
in
|
||||||
let self_arg =
|
(* let self_arg = "self Tensor" in *)
|
||||||
"self Tensor"
|
match List.partition_tf t.args ~f:self_tensor with _, args_list ->
|
||||||
(* if String.is_suffix t.name ~suffix:"_" then "self" else "&self" *)
|
Printf.sprintf "%s" (to_string args_list)
|
||||||
|
|
||||||
|
let go_notype_args_list t =
|
||||||
|
let to_string args =
|
||||||
|
let args_list =
|
||||||
|
List.map args ~f:(fun arg ->
|
||||||
|
match arg.arg_type with
|
||||||
|
| TensorOptions ->
|
||||||
|
Printf.sprintf "%sKind, %sDevice" (go_variable arg.arg_name)
|
||||||
|
(go_variable arg.arg_name)
|
||||||
|
| _ -> Printf.sprintf "%s" (go_variable arg.arg_name) )
|
||||||
|
in
|
||||||
|
if is_method t && not (is_inplace t) then
|
||||||
|
args_list @ ["del"] |> String.concat ~sep:", "
|
||||||
|
else args_list |> String.concat ~sep:", "
|
||||||
in
|
in
|
||||||
match List.partition_tf t.args ~f:self_tensor with _, args_list ->
|
match List.partition_tf t.args ~f:self_tensor with _, args_list ->
|
||||||
Printf.sprintf "%s, %s" self_arg (to_string args_list)
|
Printf.sprintf "%s" (to_string args_list)
|
||||||
|
|
||||||
let go_return_type t ~fallible =
|
let go_return_type t ~fallible =
|
||||||
|
(* printf "t name: %s\n" t.name ; *)
|
||||||
let returns =
|
let returns =
|
||||||
match t.returns with
|
match t.returns with
|
||||||
| `fixed 1 -> "Tensor"
|
| `fixed 1 -> "retVal Tensor"
|
||||||
| `fixed v ->
|
| `fixed v ->
|
||||||
List.init v ~f:(fun _ -> "Tensor")
|
List.init v ~f:(fun i -> Printf.sprintf "retVal%d Tensor" i)
|
||||||
|> String.concat ~sep:", " |> Printf.sprintf "(%s)"
|
|> String.concat ~sep:", " |> Printf.sprintf "%s"
|
||||||
| `dynamic -> "[]Tensor"
|
| `dynamic -> "retVal []Tensor"
|
||||||
in
|
in
|
||||||
if fallible then Printf.sprintf "(error, %s)" returns
|
if is_inplace t then
|
||||||
else Printf.sprintf " %s" returns
|
if fallible then Printf.sprintf "err error" else Printf.sprintf ""
|
||||||
|
else if fallible then Printf.sprintf "%s, err error" returns
|
||||||
|
else Printf.sprintf "%s" returns
|
||||||
|
|
||||||
|
let go_return_notype t ~fallible =
|
||||||
|
let returns =
|
||||||
|
match t.returns with
|
||||||
|
| `fixed 1 -> "retVal"
|
||||||
|
| `fixed v ->
|
||||||
|
List.init v ~f:(fun i -> Printf.sprintf "retVal%d" i)
|
||||||
|
|> String.concat ~sep:", " |> Printf.sprintf "%s"
|
||||||
|
| `dynamic -> "retVal"
|
||||||
|
in
|
||||||
|
if is_inplace t then
|
||||||
|
if fallible then Printf.sprintf "err" else Printf.sprintf ""
|
||||||
|
else if fallible then Printf.sprintf "%s, err" returns
|
||||||
|
else Printf.sprintf "%s" returns
|
||||||
|
|
||||||
let go_binding_args t =
|
let go_binding_args t =
|
||||||
List.map t.args ~f:(fun arg ->
|
List.map t.args ~f:(fun arg ->
|
||||||
let name = go_name arg.arg_name in
|
let name = go_variable arg.arg_name in
|
||||||
match arg.arg_type with
|
match arg.arg_type with
|
||||||
| Tensor -> Printf.sprintf "%s.c_tensor" name
|
| Tensor ->
|
||||||
| Scalar -> Printf.sprintf "%s.c_scalar" name
|
if String.( = ) name "self" then "ts.ctensor"
|
||||||
| Bool -> Printf.sprintf "if %s { 1 } else { 0 }" name
|
else Printf.sprintf "%s.ctensor" name
|
||||||
| ScalarType -> Printf.sprintf "%s.c_int()" name
|
| Scalar -> Printf.sprintf "%s.cscalar" name
|
||||||
| Device -> Printf.sprintf "%s.c_int()" name
|
| Bool -> Printf.sprintf "c%s" name
|
||||||
|
| ScalarType -> Printf.sprintf "%s.CInt()" name
|
||||||
|
| Device -> Printf.sprintf "%s.CInt()" name
|
||||||
| TensorOptions ->
|
| TensorOptions ->
|
||||||
Printf.sprintf "%s.0.c_int(), %s.1.c_int()" name name
|
Printf.sprintf "%sKind.CInt(), %sDevice.CInt()" name name
|
||||||
| String -> Printf.sprintf "%s.as_ptr(), %s.len() int32" name name
|
| String -> Printf.sprintf "%s" name
|
||||||
| IntList -> Printf.sprintf "%s.as_ptr(), %s.len() int32" name name
|
| IntList -> Printf.sprintf "%s, len(%s)" name name
|
||||||
| TensorList ->
|
| TensorList -> Printf.sprintf "c%s, len(c%s)" name name
|
||||||
Printf.sprintf "ptr_list(%s).as_ptr(), %s.len() int32" name name
|
| TensorOption -> Printf.sprintf "%s.ctensor" name
|
||||||
| TensorOption -> Printf.sprintf "%s.c_tensor)" name
|
|
||||||
| Int64 when String.( = ) name "reduction" -> "reduction.to_int()"
|
|
||||||
| _ -> name )
|
| _ -> name )
|
||||||
(* |> String.concat ~sep:",\n " *)
|
|
||||||
|> String.concat ~sep:", "
|
|> String.concat ~sep:", "
|
||||||
|
|
||||||
|
let go_binding_body t =
|
||||||
|
List.map t.args ~f:(fun arg ->
|
||||||
|
let an = go_variable arg.arg_name in
|
||||||
|
match arg.arg_type with
|
||||||
|
| Bool ->
|
||||||
|
Printf.sprintf "c%s := int32(0)\n if %s { c%s = int32(1) }\n" an an
|
||||||
|
an
|
||||||
|
| Int64 -> ""
|
||||||
|
| Double -> ""
|
||||||
|
| Tensor -> ""
|
||||||
|
| TensorOption -> ""
|
||||||
|
| Scalar -> ""
|
||||||
|
| ScalarType -> ""
|
||||||
|
| Device -> ""
|
||||||
|
| String -> ""
|
||||||
|
| IntList -> ""
|
||||||
|
| TensorList ->
|
||||||
|
Printf.sprintf
|
||||||
|
" var c%s []lib.Ctensor\n\
|
||||||
|
\ for _, t := range %s {c%s = append(c%s, t.ctensor)}\n"
|
||||||
|
an an an an
|
||||||
|
| TensorOptions -> "" )
|
||||||
|
|> String.concat ~sep:""
|
||||||
end
|
end
|
||||||
|
|
||||||
exception Not_a_simple_arg
|
exception Not_a_simple_arg
|
||||||
|
@ -494,110 +633,280 @@ let write_cpp funcs filename =
|
||||||
ph "tensor *atg_%s(%s);" exported_name c_typed_args_list ) )
|
ph "tensor *atg_%s(%s);" exported_name c_typed_args_list ) )
|
||||||
)
|
)
|
||||||
|
|
||||||
let write_fallible_wrapper funcs filename =
|
|
||||||
Out_channel.with_file filename ~f:(fun out_ml ->
|
|
||||||
let pm s = print_inline out_ml s in
|
|
||||||
pm "/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND! */" ;
|
|
||||||
pm "\n" ;
|
|
||||||
pm "package libtch" ;
|
|
||||||
pm "\n\n" ;
|
|
||||||
pm "func ptr_list(l []Tensor) []*C_tensor {\n" ;
|
|
||||||
pm " var retVal []*C_tensor \n" ;
|
|
||||||
pm " for _, x := range l{ \n" ;
|
|
||||||
pm " retVal = append(retVal, x) \n" ;
|
|
||||||
pm " } \n" ;
|
|
||||||
pm "} \n" ;
|
|
||||||
pm "\n" ;
|
|
||||||
(* Implement Tensor *)
|
|
||||||
Map.iteri funcs ~f:(fun ~key:exported_name ~data:(func : Func.t) ->
|
|
||||||
let go_name = Func.go_name exported_name in
|
|
||||||
let go_args_list = Func.go_typed_args_list func in
|
|
||||||
pm "\n" ;
|
|
||||||
pm "func f_%s%s(" go_name (Func.type_parameters func) ;
|
|
||||||
pm "%s" go_args_list ;
|
|
||||||
pm ")%s { \n" (Func.go_return_type func ~fallible:true) ;
|
|
||||||
match func.returns with
|
|
||||||
| `dynamic ->
|
|
||||||
pm " c_tensors := unsafe_torch_err!({" ;
|
|
||||||
pm "atg_%s(" exported_name ;
|
|
||||||
pm "%s)}) \n" (Func.go_binding_args func) ;
|
|
||||||
pm " var r__ []Tensor \n" ;
|
|
||||||
pm " i := 0 \n" ;
|
|
||||||
pm " for { \n" ;
|
|
||||||
pm " c__ := unsafe{*c_tensors.add(i)} \n" ;
|
|
||||||
pm " if c__.is_null() { break } \n" ;
|
|
||||||
pm " r__ = append(r__, Tensor {C_tensor: c__}) \n" ;
|
|
||||||
pm " i += 1 \n" ;
|
|
||||||
pm " } \n" ;
|
|
||||||
(* pm " // unsafe{libc::free(c_tensors as *mut libc::c_void)}" ; *)
|
|
||||||
pm " return r__ \n" ;
|
|
||||||
pm "} \n"
|
|
||||||
| `fixed ntensors ->
|
|
||||||
pm " var c_tensors []C_tensor = make([]C_tensor, %d) \n"
|
|
||||||
ntensors ;
|
|
||||||
pm " unsafe_torch_err({ \n" ;
|
|
||||||
pm " atg_%s(c_tensors, " exported_name ;
|
|
||||||
pm "%s) \n" (Func.go_binding_args func) ;
|
|
||||||
pm " }) \n" ;
|
|
||||||
let returns =
|
|
||||||
if ntensors = 1 then "Tensor { C_tensor: c_tensors[0] }"
|
|
||||||
else
|
|
||||||
List.init ntensors
|
|
||||||
~f:(Printf.sprintf "Tensor { C_tensor: c_tensors[%d] }")
|
|
||||||
|> String.concat ~sep:", " |> Printf.sprintf "(%s)"
|
|
||||||
in
|
|
||||||
pm " return %s \n" returns ;
|
|
||||||
pm "} \n" ) )
|
|
||||||
|
|
||||||
let write_wrapper funcs filename =
|
let write_wrapper funcs filename =
|
||||||
Out_channel.with_file filename ~f:(fun out_ml ->
|
Out_channel.with_file filename ~f:(fun out_ml ->
|
||||||
let pm s = print_inline out_ml s in
|
let pm s = print_inline out_ml s in
|
||||||
pm "/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND! */" ;
|
pm "package tensor" ;
|
||||||
pm "\n\n" ;
|
pm "\n\n" ;
|
||||||
pm "package libtch" ;
|
pm "// NOTE. THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!" ;
|
||||||
pm "\n\n" ;
|
pm "\n\n" ;
|
||||||
Map.iteri funcs ~f:(fun ~key:exported_name ~data:(func : Func.t) ->
|
pm "// #include \"stdlib.h\"\n" ;
|
||||||
let go_name = Func.go_name exported_name in
|
pm "import \"C\"" ;
|
||||||
let go_name, fallible_go_name =
|
pm "" ;
|
||||||
if Set.mem prefixed_functions func.name then
|
pm "\n\n" ;
|
||||||
("g_" ^ go_name, "f_" ^ go_name)
|
pm "import(\n" ;
|
||||||
else (go_name, "f_" ^ go_name)
|
pm " \"unsafe\"\n" ;
|
||||||
|
pm "\n" ;
|
||||||
|
pm " \"github.com/sugarme/gotch\"\n" ;
|
||||||
|
pm " lib \"github.com/sugarme/gotch/libtch\"\n" ;
|
||||||
|
pm ")" ;
|
||||||
|
pm "\n\n" ;
|
||||||
|
Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
|
||||||
|
let is_method = Func.is_method func in
|
||||||
|
let is_inplace = Func.is_inplace func in
|
||||||
|
(* NOTE. `torch.__PATTERN` *)
|
||||||
|
let prefix_2underscore exported_name =
|
||||||
|
Str.string_match (Str.regexp "^__") exported_name 0
|
||||||
in
|
in
|
||||||
pm "\n" ;
|
(* NOTE. `torch._PATTERN` *)
|
||||||
pm "func %s%s(" go_name (Func.type_parameters func) ;
|
let prefix_1underscore exported_name =
|
||||||
|
Str.string_match (Str.regexp "^_") exported_name 0
|
||||||
|
in
|
||||||
|
(* NOTE. `torch.PATTERN_1` *)
|
||||||
|
let suffix_1 exported_name =
|
||||||
|
Str.string_match (Str.regexp ".*_1$") exported_name 0
|
||||||
|
in
|
||||||
|
let gofunc_name =
|
||||||
|
if prefix_2underscore exported_name then
|
||||||
|
"__" ^ Func.go_name exported_name
|
||||||
|
else if prefix_1underscore exported_name then
|
||||||
|
"_" ^ Func.go_name exported_name
|
||||||
|
else if suffix_1 exported_name then
|
||||||
|
Func.go_name exported_name ^ "_"
|
||||||
|
else Func.go_name exported_name
|
||||||
|
in
|
||||||
|
let cfunc_name = "lib.Atg" ^ gofunc_name in
|
||||||
let go_args_list = Func.go_typed_args_list func in
|
let go_args_list = Func.go_typed_args_list func in
|
||||||
pm "%s" go_args_list ;
|
(* NOTE. temporarily excluding these functions as not implemented at FFI *)
|
||||||
pm ")%s {\n" (Func.go_return_type func ~fallible:false) ;
|
(* TODO. implement multiple tensors return function []Tensor *)
|
||||||
let go_args_list = Func.go_args_list func in
|
let excluded_funcs =
|
||||||
let go_args_list =
|
[ "Chunk"
|
||||||
List.map go_args_list ~f:(fun arg -> Func.go_name arg.Func.arg_name)
|
; "AlignTensors"
|
||||||
|> String.concat ~sep:", "
|
; "BroadcastTensors"
|
||||||
|
; "Meshgrid"
|
||||||
|
; "NonzeroNumpy"
|
||||||
|
; "Split"
|
||||||
|
; "SplitWithSizes"
|
||||||
|
; "Unbind"
|
||||||
|
; "Where" ]
|
||||||
in
|
in
|
||||||
pm " %s(%s)\n" fallible_go_name go_args_list ;
|
if
|
||||||
pm "}\n" ) ;
|
List.exists excluded_funcs ~f:(fun name ->
|
||||||
|
String.( = ) name gofunc_name )
|
||||||
|
then pm ""
|
||||||
|
else
|
||||||
|
match func.returns with
|
||||||
|
| `dynamic ->
|
||||||
|
pm "\n" ;
|
||||||
|
if is_method then pm "func(ts Tensor) %s(" gofunc_name
|
||||||
|
else pm "func %s(" gofunc_name ;
|
||||||
|
pm "%s" go_args_list ;
|
||||||
|
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
||||||
|
if is_method && not is_inplace then
|
||||||
|
pm "if del { defer ts.MustDrop() }\n" ;
|
||||||
|
pm " ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))\n" ;
|
||||||
|
pm " \n" ;
|
||||||
|
pm " %s" (Func.go_binding_body func) ;
|
||||||
|
pm "%s(ptr, %s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||||
|
pm " if err = TorchErr(); err != nil {\n" ;
|
||||||
|
pm " return %s\n"
|
||||||
|
(Func.go_return_notype func ~fallible:true) ;
|
||||||
|
pm " }\n" ;
|
||||||
|
(* NOTE. if in_place method, no retVal return *)
|
||||||
|
if not (Func.is_inplace func) then
|
||||||
|
pm " retVal = Tensor{ctensor: *ptr}\n" ;
|
||||||
|
pm " \n" ;
|
||||||
|
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
|
||||||
|
pm "} \n"
|
||||||
|
| `fixed 1 ->
|
||||||
|
pm "\n" ;
|
||||||
|
if is_method then pm "func(ts Tensor) %s(" gofunc_name
|
||||||
|
else pm "func %s(" gofunc_name ;
|
||||||
|
pm "%s" go_args_list ;
|
||||||
|
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
||||||
|
if is_method && not is_inplace then
|
||||||
|
pm "if del { defer ts.MustDrop() }\n" ;
|
||||||
|
pm " ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))\n" ;
|
||||||
|
pm " \n" ;
|
||||||
|
pm " %s" (Func.go_binding_body func) ;
|
||||||
|
pm "%s(ptr, %s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||||
|
pm " if err = TorchErr(); err != nil {\n" ;
|
||||||
|
pm " return %s\n"
|
||||||
|
(Func.go_return_notype func ~fallible:true) ;
|
||||||
|
pm " }\n" ;
|
||||||
|
(* NOTE. if in_place method, no retVal return *)
|
||||||
|
if not (Func.is_inplace func) then
|
||||||
|
pm " retVal = Tensor{ctensor: *ptr}\n" ;
|
||||||
|
pm " \n" ;
|
||||||
|
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
|
||||||
|
pm "} \n"
|
||||||
|
| `fixed _ -> pm "" ) ;
|
||||||
|
(* TODO. implement for return multiple tensor - []Tensor *)
|
||||||
|
pm "// End of implementing Tensor ================================= \n"
|
||||||
|
)
|
||||||
|
|
||||||
|
let write_must_wrapper funcs filename =
|
||||||
|
Out_channel.with_file filename ~f:(fun out_ml ->
|
||||||
|
let pm s = print_inline out_ml s in
|
||||||
|
pm "package tensor" ;
|
||||||
|
pm "\n\n" ;
|
||||||
|
pm "// NOTE. THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!" ;
|
||||||
|
pm "\n\n" ;
|
||||||
|
pm "import(\n" ;
|
||||||
|
pm " \"log\"\n" ;
|
||||||
|
pm "\n" ;
|
||||||
|
pm " \"github.com/sugarme/gotch\"\n" ;
|
||||||
|
pm ")" ;
|
||||||
|
pm "\n\n" ;
|
||||||
|
Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
|
||||||
|
let is_method = Func.is_method func in
|
||||||
|
(* NOTE. `torch.__PATTERN` *)
|
||||||
|
let prefix_2underscore exported_name =
|
||||||
|
Str.string_match (Str.regexp "^__") exported_name 0
|
||||||
|
in
|
||||||
|
(* NOTE. `torch._PATTERN` *)
|
||||||
|
let prefix_1underscore exported_name =
|
||||||
|
Str.string_match (Str.regexp "^_") exported_name 0
|
||||||
|
in
|
||||||
|
(* NOTE. `torch.PATTERN_1` *)
|
||||||
|
let suffix_1 exported_name =
|
||||||
|
Str.string_match (Str.regexp ".*_1$") exported_name 0
|
||||||
|
in
|
||||||
|
let gofunc_name =
|
||||||
|
if prefix_2underscore exported_name then
|
||||||
|
"__" ^ Func.go_name exported_name
|
||||||
|
else if prefix_1underscore exported_name then
|
||||||
|
"_" ^ Func.go_name exported_name
|
||||||
|
else if suffix_1 exported_name then
|
||||||
|
Func.go_name exported_name ^ "_"
|
||||||
|
else Func.go_name exported_name
|
||||||
|
in
|
||||||
|
let go_args_list = Func.go_typed_args_list func in
|
||||||
|
let go_args_list_notype = Func.go_notype_args_list func in
|
||||||
|
(* NOTE. temporarily excluding these functions as not implemented at FFI *)
|
||||||
|
(* TODO. implement multiple tensors return function []Tensor *)
|
||||||
|
let excluded_funcs =
|
||||||
|
[ "Chunk"
|
||||||
|
; "AlignTensors"
|
||||||
|
; "BroadcastTensors"
|
||||||
|
; "Meshgrid"
|
||||||
|
; "NonzeroNumpy"
|
||||||
|
; "Split"
|
||||||
|
; "SplitWithSizes"
|
||||||
|
; "Unbind"
|
||||||
|
; "Where" ]
|
||||||
|
in
|
||||||
|
if
|
||||||
|
List.exists excluded_funcs ~f:(fun name ->
|
||||||
|
String.( = ) name gofunc_name )
|
||||||
|
then pm ""
|
||||||
|
else
|
||||||
|
match func.returns with
|
||||||
|
| `dynamic ->
|
||||||
|
pm "\n" ;
|
||||||
|
if is_method then pm "func(ts Tensor) %s(" gofunc_name
|
||||||
|
else pm "func Must%s(" gofunc_name ;
|
||||||
|
pm "%s" go_args_list ;
|
||||||
|
pm ")(%s) { \n" (Func.go_return_type func ~fallible:false) ;
|
||||||
|
pm " \n" ;
|
||||||
|
if is_method then
|
||||||
|
pm " retVal, err := ts.%s(%s)\n" gofunc_name
|
||||||
|
go_args_list_notype
|
||||||
|
else
|
||||||
|
pm " retVal, err := %s(%s)\n" gofunc_name
|
||||||
|
go_args_list_notype ;
|
||||||
|
pm " if err != nil { log.Fatal(err) }\n" ;
|
||||||
|
pm " \n" ;
|
||||||
|
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
|
||||||
|
pm "} \n"
|
||||||
|
| `fixed 1 ->
|
||||||
|
pm "\n" ;
|
||||||
|
if is_method then pm "func(ts Tensor) Must%s(" gofunc_name
|
||||||
|
else pm "func Must%s(" gofunc_name ;
|
||||||
|
pm "%s" go_args_list ;
|
||||||
|
pm ")(%s) { \n" (Func.go_return_type func ~fallible:false) ;
|
||||||
|
pm " \n" ;
|
||||||
|
(* NOTE. No return retVal for in_place method *)
|
||||||
|
if Func.is_inplace func then
|
||||||
|
if is_method then
|
||||||
|
pm " err := ts.%s(%s)\n" gofunc_name go_args_list_notype
|
||||||
|
else pm " err := %s(%s)\n" gofunc_name go_args_list_notype
|
||||||
|
else if is_method then
|
||||||
|
pm " retVal, err := ts.%s(%s)\n" gofunc_name
|
||||||
|
go_args_list_notype
|
||||||
|
else
|
||||||
|
pm " retVal, err := %s(%s)\n" gofunc_name
|
||||||
|
go_args_list_notype ;
|
||||||
|
pm " if err != nil { log.Fatal(err) }\n" ;
|
||||||
|
pm " \n" ;
|
||||||
|
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
|
||||||
|
pm "} \n"
|
||||||
|
| `fixed _ -> pm "" ) ;
|
||||||
|
(* TODO. implement for return multiple tensor - []Tensor *)
|
||||||
pm "// End of implementing Tensor ================================= \n"
|
pm "// End of implementing Tensor ================================= \n"
|
||||||
)
|
)
|
||||||
|
|
||||||
let write_ffi funcs filename =
|
let write_ffi funcs filename =
|
||||||
Out_channel.with_file filename ~f:(fun out_ml ->
|
Out_channel.with_file filename ~f:(fun out_ml ->
|
||||||
let pm s = p out_ml s in
|
let pm s = p out_ml s in
|
||||||
pm "/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND! */" ;
|
|
||||||
pm "package libtch" ;
|
pm "package libtch" ;
|
||||||
pm "" ;
|
pm "" ;
|
||||||
pm "// #include \"stdbool.h\" " ;
|
pm "// NOTE. THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!" ;
|
||||||
pm "// #include \"torch_api.h\" " ;
|
pm "" ;
|
||||||
|
pm "//#include \"stdbool.h\" " ;
|
||||||
|
pm "//#include \"torch_api.h\" " ;
|
||||||
pm "import \"C\"" ;
|
pm "import \"C\"" ;
|
||||||
pm "" ;
|
pm "" ;
|
||||||
|
pm "import \"unsafe\"" ;
|
||||||
|
pm "" ;
|
||||||
Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
|
Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
|
||||||
|
(* let is_method = *)
|
||||||
|
(* match func.Func.kind with `method_ -> true | `function_ -> false *)
|
||||||
|
(* in *)
|
||||||
|
(* let is_inplace = *)
|
||||||
|
(* Func.is_inplace func *)
|
||||||
|
(*
|
||||||
|
* match exported_name with
|
||||||
|
* | "add_1" -> true
|
||||||
|
* | "sub_1" -> true
|
||||||
|
* | "div_1" -> true
|
||||||
|
* | "mul_1" -> true
|
||||||
|
* | _ -> false
|
||||||
|
* *)
|
||||||
|
(* in *)
|
||||||
|
(* NOTE. `torch.__PATTERN` *)
|
||||||
|
let prefix_2underscore exported_name =
|
||||||
|
Str.string_match (Str.regexp "^__") exported_name 0
|
||||||
|
in
|
||||||
|
(* NOTE. `torch._PATTERN` *)
|
||||||
|
let prefix_1underscore exported_name =
|
||||||
|
Str.string_match (Str.regexp "^_") exported_name 0
|
||||||
|
in
|
||||||
|
(* NOTE. `torch.PATTERN_1` *)
|
||||||
|
let suffix_1 exported_name =
|
||||||
|
Str.string_match (Str.regexp ".*_1$") exported_name 0
|
||||||
|
in
|
||||||
|
let ffifunc_name =
|
||||||
|
if prefix_2underscore exported_name then
|
||||||
|
"__" ^ Func.go_name exported_name
|
||||||
|
else if prefix_1underscore exported_name then
|
||||||
|
"_" ^ Func.go_name exported_name
|
||||||
|
else if suffix_1 exported_name then
|
||||||
|
Func.go_name exported_name ^ "_"
|
||||||
|
else Func.go_name exported_name
|
||||||
|
in
|
||||||
match func.Func.returns with
|
match func.Func.returns with
|
||||||
| `fixed _ ->
|
| `fixed _ ->
|
||||||
pm "func Atg_%s(ptr *Ctensor, %s){C.atg_%s(ptr, %s)}"
|
pm "func Atg%s(ptr *Ctensor, %s){%s \nC.atg_%s(ptr, %s)\n}"
|
||||||
(Func.go_name exported_name)
|
ffifunc_name (Func.c_go_args_list func)
|
||||||
(Func.c_go_args_list func) exported_name
|
(Func.c_go_args_list_body func)
|
||||||
|
exported_name
|
||||||
(Func.c_go_args_list_notype func)
|
(Func.c_go_args_list_notype func)
|
||||||
| `dynamic ->
|
| `dynamic -> pm ""
|
||||||
pm "func Atg_%s(%s)(*Ctensor)" exported_name
|
(* TODO: need more implement here *)
|
||||||
(Func.c_go_args_list func) ) )
|
(* pm "func Atg%s(%s)(retValPtr *Ctensor)" *)
|
||||||
|
(* (Func.go_name exported_name) *)
|
||||||
|
(* (Func.c_go_args_list func) *) ) )
|
||||||
|
|
||||||
let methods =
|
let methods =
|
||||||
let c name args = {Func.name; args; returns= `fixed 1; kind= `method_} in
|
let c name args = {Func.name; args; returns= `fixed 1; kind= `method_} in
|
||||||
|
@ -607,8 +916,8 @@ let methods =
|
||||||
; c "toType" [ca "self" Tensor; ca "scalar_type" ScalarType]
|
; c "toType" [ca "self" Tensor; ca "scalar_type" ScalarType]
|
||||||
; c "to" [ca "self" Tensor; ca "device" Device] ]
|
; c "to" [ca "self" Tensor; ca "device" Device] ]
|
||||||
|
|
||||||
let run ~yaml_filename ~cpp_filename ~ffi_filename ~wrapper_filename
|
let run ~yaml_filename ~cpp_filename ~ffi_filename ~must_wrapper_filename
|
||||||
~fallible_wrapper_filename =
|
~wrapper_filename =
|
||||||
let funcs = read_yaml yaml_filename in
|
let funcs = read_yaml yaml_filename in
|
||||||
let funcs = methods @ funcs in
|
let funcs = methods @ funcs in
|
||||||
printf "Generating code for %d functions.\n%!" (List.length funcs) ;
|
printf "Generating code for %d functions.\n%!" (List.length funcs) ;
|
||||||
|
@ -631,11 +940,12 @@ let run ~yaml_filename ~cpp_filename ~ffi_filename ~wrapper_filename
|
||||||
in
|
in
|
||||||
write_cpp funcs cpp_filename ;
|
write_cpp funcs cpp_filename ;
|
||||||
write_ffi funcs ffi_filename ;
|
write_ffi funcs ffi_filename ;
|
||||||
write_wrapper funcs wrapper_filename ;
|
write_must_wrapper funcs must_wrapper_filename ;
|
||||||
write_fallible_wrapper funcs fallible_wrapper_filename
|
write_wrapper funcs wrapper_filename
|
||||||
|
|
||||||
let () =
|
let () =
|
||||||
run ~yaml_filename:"third_party/pytorch/Declarations-v1.5.0.yaml"
|
run ~yaml_filename:"gen/pytorch/Declarations-v1.5.0.yaml"
|
||||||
~cpp_filename:"tmp/torch_api_generated" ~ffi_filename:"tmp/c_generated.go"
|
~cpp_filename:"libtch/torch_api_generated"
|
||||||
~wrapper_filename:"tmp/tensor_generated.go"
|
~ffi_filename:"libtch/c-generated.go"
|
||||||
~fallible_wrapper_filename:"tmp/tensor_fallible_generated.go"
|
~must_wrapper_filename:"tensor/must-tensor-generated.go"
|
||||||
|
~wrapper_filename:"tensor/tensor-generated.go"
|
||||||
|
|
|
@ -1,732 +0,0 @@
|
||||||
// NOTE: this is a sample for OCaml generated code for `c-generated.go`
|
|
||||||
package libtch
|
|
||||||
|
|
||||||
//#include "stdbool.h"
|
|
||||||
//#include "torch_api.h"
|
|
||||||
import "C"
|
|
||||||
|
|
||||||
import (
|
|
||||||
"unsafe"
|
|
||||||
)
|
|
||||||
|
|
||||||
// void atg_eq1(tensor *, tensor self, tensor other);
|
|
||||||
func AtgEq1(ptr *Ctensor, self Ctensor, other Ctensor) {
|
|
||||||
C.atg_eq1(ptr, self, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_matmul(tensor *, tensor self, tensor other);
|
|
||||||
func AtgMatmul(ptr *Ctensor, self Ctensor, other Ctensor) {
|
|
||||||
C.atg_matmul(ptr, self, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_to(tensor *, tensor self, int device);
|
|
||||||
func AtgTo(ptr *Ctensor, self Ctensor, device int) {
|
|
||||||
cdevice := *(*C.int)(unsafe.Pointer(&device))
|
|
||||||
C.atg_to(ptr, self, cdevice)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_grad(tensor *, tensor self);
|
|
||||||
func AtgGrad(ptr *Ctensor, self Ctensor) {
|
|
||||||
C.atg_grad(ptr, self)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_detach_(tensor *, tensor self);
|
|
||||||
func AtgDetach_(ptr *Ctensor, self Ctensor) {
|
|
||||||
C.atg_detach_(ptr, self)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_detach(tensor *, tensor self);
|
|
||||||
func AtgDetach(ptr *Ctensor, self Ctensor) {
|
|
||||||
C.atg_detach(ptr, self)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_zero_(tensor *, tensor self);
|
|
||||||
func AtgZero_(ptr *Ctensor, self Ctensor) {
|
|
||||||
C.atg_zero_(ptr, self)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_set_requires_grad(tensor *, tensor self, int r);
|
|
||||||
func AtgSetRequiresGrad(ptr *Ctensor, self Ctensor, r int) {
|
|
||||||
cr := *(*C.int)(unsafe.Pointer(&r))
|
|
||||||
C.atg_set_requires_grad(ptr, self, cr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_mul(tensor *, tensor self, tensor other);
|
|
||||||
func AtgMul(ptr *Ctensor, self Ctensor, other Ctensor) {
|
|
||||||
C.atg_mul(ptr, self, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_mul_(tensor *, tensor self, tensor other);
|
|
||||||
func AtgMul_(ptr *Ctensor, self Ctensor, other Ctensor) {
|
|
||||||
C.atg_mul_(ptr, self, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_mul1(tensor *, tensor self, scalar other);
|
|
||||||
func AtgMul1(ptr *Ctensor, self Ctensor, other Cscalar) {
|
|
||||||
C.atg_mul1(ptr, self, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_mul_1(tensor *, tensor self, scalar other);
|
|
||||||
func AtgMul1_(ptr *Ctensor, self Ctensor, other Cscalar) {
|
|
||||||
C.atg_mul_1(ptr, self, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_add(tensor *, tensor self, tensor other);
|
|
||||||
func AtgAdd(ptr *Ctensor, self Ctensor, other Ctensor) {
|
|
||||||
C.atg_add(ptr, self, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_add_(tensor *, tensor self, tensor other);
|
|
||||||
func AtgAdd_(ptr *Ctensor, self Ctensor, other Ctensor) {
|
|
||||||
C.atg_add_(ptr, self, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
// id atg_add1(tensor *, tensor self, scalar other);
|
|
||||||
func AtgAdd1(ptr *Ctensor, self Ctensor, other Cscalar) {
|
|
||||||
C.atg_add1(ptr, self, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_add_1(tensor *, tensor self, scalar other);
|
|
||||||
func AtgAdd1_(ptr *Ctensor, self Ctensor, other Cscalar) {
|
|
||||||
C.atg_add_1(ptr, self, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_totype(tensor *, tensor self, int scalar_type);
|
|
||||||
func AtgTotype(ptr *Ctensor, self Ctensor, scalar_type int32) {
|
|
||||||
cscalar_type := *(*C.int)(unsafe.Pointer(&scalar_type))
|
|
||||||
C.atg_totype(ptr, self, cscalar_type)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_unsqueeze(tensor *, tensor self, int64_t dim);
|
|
||||||
func AtgUnsqueeze(ptr *Ctensor, self Ctensor, dim int64) {
|
|
||||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
|
||||||
C.atg_unsqueeze(ptr, self, cdim)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_select(tensor *, tensor self, int64_t dim, int64_t index);
|
|
||||||
func AtgSelect(ptr *Ctensor, self Ctensor, dim int64, index int64) {
|
|
||||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
|
||||||
cindex := *(*C.int64_t)(unsafe.Pointer(&index))
|
|
||||||
C.atg_select(ptr, self, cdim, cindex)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_narrow(tensor *, tensor self, int64_t dim, int64_t start, int64_t length);
|
|
||||||
func AtgNarrow(ptr *Ctensor, self Ctensor, dim int64, start int64, length int64) {
|
|
||||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
|
||||||
cstart := *(*C.int64_t)(unsafe.Pointer(&start))
|
|
||||||
clength := *(*C.int64_t)(unsafe.Pointer(&length))
|
|
||||||
C.atg_narrow(ptr, self, cdim, cstart, clength)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_index_select(tensor *, tensor self, int64_t dim, tensor index);
|
|
||||||
func AtgIndexSelect(ptr *Ctensor, self Ctensor, dim int64, index Ctensor) {
|
|
||||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
|
||||||
C.atg_index_select(ptr, self, cdim, index)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_zeros(tensor *, int64_t *size_data, int size_len, int options_kind, int options_device);
|
|
||||||
func AtgZeros(ptr *Ctensor, sizeData []int64, sizeLen int, optionsKind, optionsDevice int32) {
|
|
||||||
// just get pointer of the first element of the shape(sizeData)
|
|
||||||
csizeDataPtr := (*C.int64_t)(unsafe.Pointer(&sizeData[0]))
|
|
||||||
csizeLen := *(*C.int)(unsafe.Pointer(&sizeLen))
|
|
||||||
coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
|
|
||||||
coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
|
|
||||||
|
|
||||||
C.atg_zeros(ptr, csizeDataPtr, csizeLen, coptionsKind, coptionsDevice)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_ones(tensor *, int64_t *size_data, int size_len, int options_kind, int options_device);
|
|
||||||
func AtgOnes(ptr *Ctensor, sizeData []int64, sizeLen int, optionsKind, optionsDevice int32) {
|
|
||||||
// just get pointer of the first element of the shape(sizeData)
|
|
||||||
csizeDataPtr := (*C.int64_t)(unsafe.Pointer(&sizeData[0]))
|
|
||||||
csizeLen := *(*C.int)(unsafe.Pointer(&sizeLen))
|
|
||||||
coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
|
|
||||||
coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
|
|
||||||
|
|
||||||
C.atg_ones(ptr, csizeDataPtr, csizeLen, coptionsKind, coptionsDevice)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_uniform_(tensor *, tensor self, double from, double to);
|
|
||||||
func AtgUniform_(ptr *Ctensor, self Ctensor, from float64, to float64) {
|
|
||||||
cfrom := *(*C.double)(unsafe.Pointer(&from))
|
|
||||||
cto := *(*C.double)(unsafe.Pointer(&to))
|
|
||||||
|
|
||||||
C.atg_uniform_(ptr, self, cfrom, cto)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_zeros_like(tensor *, tensor self);
|
|
||||||
func AtgZerosLike(ptr *Ctensor, self Ctensor) {
|
|
||||||
C.atg_zeros_like(ptr, self)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_fill_(tensor *, tensor self, scalar value);
|
|
||||||
func AtgFill_(ptr *Ctensor, self Ctensor, value Cscalar) {
|
|
||||||
C.atg_fill_(ptr, self, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_randn_like(tensor *, tensor self);
|
|
||||||
func AtgRandnLike(ptr *Ctensor, self Ctensor) {
|
|
||||||
C.atg_rand_like(ptr, self)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_log_softmax(tensor *, tensor self, int64_t dim, int dtype);
|
|
||||||
func AtgLogSoftmax(ptr *Ctensor, self Ctensor, dim int64, dtype int32) {
|
|
||||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
|
||||||
cdtype := *(*C.int)(unsafe.Pointer(&dtype))
|
|
||||||
|
|
||||||
C.atg_log_softmax(ptr, self, cdim, cdtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_nll_loss(tensor *, tensor self, tensor target, tensor weight, int64_t reduction, int64_t ignore_index);
|
|
||||||
func AtgNllLoss(ptr *Ctensor, self Ctensor, target Ctensor, weight Ctensor, reduction int64, ignoreIndex int64) {
|
|
||||||
creduction := *(*C.int64_t)(unsafe.Pointer(&reduction))
|
|
||||||
cignoreIndex := *(*C.int64_t)(unsafe.Pointer(&ignoreIndex))
|
|
||||||
|
|
||||||
C.atg_nll_loss(ptr, self, target, weight, creduction, cignoreIndex)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_argmax(tensor *, tensor self, int64_t dim, int keepdim);
|
|
||||||
func AtgArgmax(ptr *Ctensor, self Ctensor, dim int64, keepDim int) {
|
|
||||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
|
||||||
ckeepDim := *(*C.int)(unsafe.Pointer(&keepDim))
|
|
||||||
|
|
||||||
C.atg_argmax(ptr, self, cdim, ckeepDim)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_mean(tensor *, tensor self, int dtype);
|
|
||||||
func AtgMean(ptr *Ctensor, self Ctensor, dtype int32) {
|
|
||||||
cdtype := *(*C.int)(unsafe.Pointer(&dtype))
|
|
||||||
|
|
||||||
C.atg_mean(ptr, self, cdtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_mean1(tensor *, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype);
|
|
||||||
func AtgMean1(ptr *Ctensor, self Ctensor, dimData []int64, dimLen int, keepDim int, dtype int32) {
|
|
||||||
|
|
||||||
cdimDataPtr := (*C.int64_t)(unsafe.Pointer(&dimData[0]))
|
|
||||||
cdimLen := *(*C.int)(unsafe.Pointer(&dimLen))
|
|
||||||
ckeepDim := *(*C.int)(unsafe.Pointer(&keepDim))
|
|
||||||
cdtype := *(*C.int)(unsafe.Pointer(&dtype))
|
|
||||||
|
|
||||||
C.atg_mean1(ptr, self, cdimDataPtr, cdimLen, ckeepDim, cdtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_permute(tensor *, tensor self, int64_t *dims_data, int dims_len);
|
|
||||||
func AtgPermute(ptr *Ctensor, self Ctensor, dims []int64, dimLen int) {
|
|
||||||
// just get pointer of the first element of the shape
|
|
||||||
cdimsPtr := (*C.int64_t)(unsafe.Pointer(&dims[0]))
|
|
||||||
cdimLen := *(*C.int)(unsafe.Pointer(&dimLen))
|
|
||||||
|
|
||||||
C.atg_permute(ptr, self, cdimsPtr, cdimLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_squeeze1(tensor *, tensor self, int64_t dim);
|
|
||||||
func AtgSqueeze1(ptr *Ctensor, self Ctensor, dim int64) {
|
|
||||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
|
||||||
|
|
||||||
C.atg_squeeze1(ptr, self, cdim)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_squeeze_(tensor *, tensor self);
|
|
||||||
func AtgSqueeze_(ptr *Ctensor, self Ctensor) {
|
|
||||||
C.atg_squeeze_(ptr, self)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_stack(tensor *, tensor *tensors_data, int tensors_len, int64_t dim);
|
|
||||||
func AtgStack(ptr *Ctensor, tensorsData []Ctensor, tensorsLen int, dim int64) {
|
|
||||||
tensorsDataPtr := (*Ctensor)(unsafe.Pointer(&tensorsData[0]))
|
|
||||||
ctensorsLen := *(*C.int)(unsafe.Pointer(&tensorsLen))
|
|
||||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
|
||||||
|
|
||||||
C.atg_stack(ptr, tensorsDataPtr, ctensorsLen, cdim)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_mm(tensor *, tensor self, tensor mat2);
|
|
||||||
func AtgMm(ptr *Ctensor, self Ctensor, mat2 Ctensor) {
|
|
||||||
C.atg_mm(ptr, self, mat2)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_view(tensor *, tensor self, int64_t *size_data, int size_len);
|
|
||||||
func AtgView(ptr *Ctensor, self Ctensor, sizeData []int64, sizeLen int) {
|
|
||||||
sizeDataPtr := (*C.int64_t)(unsafe.Pointer(&sizeData[0]))
|
|
||||||
csizeLen := *(*C.int)(unsafe.Pointer(&sizeLen))
|
|
||||||
|
|
||||||
C.atg_view(ptr, self, sizeDataPtr, csizeLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_div1(tensor *, tensor self, scalar other);
|
|
||||||
func AtgDiv1(ptr *Ctensor, self Ctensor, other Cscalar) {
|
|
||||||
C.atg_div1(ptr, self, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_div(tensor *, tensor self, tensor other);
|
|
||||||
func AtgDiv(ptr *Ctensor, self Ctensor, other Ctensor) {
|
|
||||||
C.atg_div(ptr, self, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_div_(tensor *, tensor self, tensor other);
|
|
||||||
func AtgDiv_(ptr *Ctensor, self Ctensor, other Ctensor) {
|
|
||||||
C.atg_div_(ptr, self, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_div_1(tensor *, tensor self, scalar other);
|
|
||||||
func AtgDiv1_(ptr *Ctensor, self Ctensor, other Cscalar) {
|
|
||||||
C.atg_div_1(ptr, self, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_randperm(tensor *, int64_t n, int options_kind, int options_device);
|
|
||||||
func AtgRandperm(ptr *Ctensor, n int64, optionKind int32, optionDevice int32) {
|
|
||||||
cn := *(*C.int64_t)(unsafe.Pointer(&n))
|
|
||||||
coptionKind := *(*C.int)(unsafe.Pointer(&optionKind))
|
|
||||||
coptionDevice := *(*C.int)(unsafe.Pointer(&optionDevice))
|
|
||||||
|
|
||||||
C.atg_randperm(ptr, cn, coptionKind, coptionDevice)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_clamp_(tensor *, tensor self, scalar min, scalar max);
|
|
||||||
func AtgClamp_(ptr *Ctensor, self Ctensor, min Cscalar, max Cscalar) {
|
|
||||||
C.atg_clamp_(ptr, self, min, max)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_clamp(tensor *, tensor self, scalar min, scalar max);
|
|
||||||
func AtgClamp(ptr *Ctensor, self Ctensor, min Cscalar, max Cscalar) {
|
|
||||||
C.atg_clamp(ptr, self, min, max)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_clamp_max(tensor *, tensor self, scalar max);
|
|
||||||
func AtgClampMax(ptr *Ctensor, self Ctensor, max Cscalar) {
|
|
||||||
C.atg_clamp_max(ptr, self, max)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_relu(tensor *, tensor self);
|
|
||||||
func AtgRelu(ptr *Ctensor, self Ctensor) {
|
|
||||||
C.atg_relu(ptr, self)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_relu_(tensor *, tensor self);
|
|
||||||
func AtgRelu_(ptr *Ctensor, self Ctensor) {
|
|
||||||
C.atg_relu_(ptr, self)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_t(tensor *, tensor self);
|
|
||||||
func AtgT(ptr *Ctensor, self Ctensor) {
|
|
||||||
C.atg_t(ptr, self)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_t_(tensor *, tensor self);
|
|
||||||
func AtgT_(ptr *Ctensor, self Ctensor) {
|
|
||||||
C.atg_t_(ptr, self)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_mse_loss(tensor *, tensor self, tensor target, int64_t reduction);
|
|
||||||
func AtgMseLoss(ptr *Ctensor, self Ctensor, target Ctensor, reduction int) {
|
|
||||||
creduction := *(*C.int64_t)(unsafe.Pointer(&reduction))
|
|
||||||
|
|
||||||
C.atg_mse_loss(ptr, self, target, creduction)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_exp(tensor *, tensor self);
|
|
||||||
func AtgExp(ptr *Ctensor, self Ctensor) {
|
|
||||||
C.atg_exp(ptr, self)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_exp_(tensor *, tensor self);
|
|
||||||
func AtgExp_(ptr *Ctensor, self Ctensor) {
|
|
||||||
C.atg_exp_(ptr, self)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_pow(tensor *, tensor self, scalar exponent);
|
|
||||||
func AtgPow(ptr *Ctensor, self Ctensor, exponent Cscalar) {
|
|
||||||
C.atg_pow(ptr, self, exponent)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_sum(tensor *, tensor self, int dtype);
|
|
||||||
func AtgSum(ptr *Ctensor, self Ctensor, dtype int32) {
|
|
||||||
cdtype := *(*C.int)(unsafe.Pointer(&dtype))
|
|
||||||
|
|
||||||
C.atg_sum(ptr, self, cdtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_sub(tensor *, tensor self, tensor other);
|
|
||||||
func AtgSub(ptr *Ctensor, self Ctensor, other Ctensor) {
|
|
||||||
C.atg_sub(ptr, self, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_sub1(tensor *, tensor self, scalar other);
|
|
||||||
func AtgSub1(ptr *Ctensor, self Ctensor, other Cscalar) {
|
|
||||||
C.atg_sub1(ptr, self, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_sub_(tensor *, tensor self, tensor other);
|
|
||||||
func AtgSub_(ptr *Ctensor, self Ctensor, other Ctensor) {
|
|
||||||
C.atg_sub_(ptr, self, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_sub_1(tensor *, tensor self, scalar other);
|
|
||||||
func AtgSub1_(ptr *Ctensor, self Ctensor, other Cscalar) {
|
|
||||||
C.atg_sub_1(ptr, self, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_conv1d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *dilation_data, int dilation_len, int64_t groups);
|
|
||||||
func AtgConv1d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, dilationData []int64, dilationLen int, groups int64) {
|
|
||||||
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
|
|
||||||
cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen))
|
|
||||||
cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
|
|
||||||
cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
|
|
||||||
cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))
|
|
||||||
cdilationLen := *(*C.int)(unsafe.Pointer(&dilationLen))
|
|
||||||
cgroups := *(*C.int64_t)(unsafe.Pointer(&groups))
|
|
||||||
|
|
||||||
C.atg_conv1d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, cdilationDataPtr, cdilationLen, cgroups)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_conv2d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *dilation_data, int dilation_len, int64_t groups);
|
|
||||||
func AtgConv2d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, dilationData []int64, dilationLen int, groups int64) {
|
|
||||||
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
|
|
||||||
cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen))
|
|
||||||
cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
|
|
||||||
cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
|
|
||||||
cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))
|
|
||||||
cdilationLen := *(*C.int)(unsafe.Pointer(&dilationLen))
|
|
||||||
cgroups := *(*C.int64_t)(unsafe.Pointer(&groups))
|
|
||||||
|
|
||||||
C.atg_conv2d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, cdilationDataPtr, cdilationLen, cgroups)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_conv3d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *dilation_data, int dilation_len, int64_t groups);
|
|
||||||
func AtgConv3d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, dilationData []int64, dilationLen int, groups int64) {
|
|
||||||
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
|
|
||||||
cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen))
|
|
||||||
cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
|
|
||||||
cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
|
|
||||||
cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))
|
|
||||||
cdilationLen := *(*C.int)(unsafe.Pointer(&dilationLen))
|
|
||||||
cgroups := *(*C.int64_t)(unsafe.Pointer(&groups))
|
|
||||||
|
|
||||||
C.atg_conv3d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, cdilationDataPtr, cdilationLen, cgroups)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_max_pool2d(tensor *, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *dilation_data, int dilation_len, int ceil_mode);
|
|
||||||
func AtgMaxPool2d(ptr *Ctensor, self Ctensor, kernelSizeData []int64, kernelSizeLen int, strideData []int64, strideLen int, paddingData []int64, paddingLen int, dilationData []int64, dilationLen int, ceilMode int) {
|
|
||||||
|
|
||||||
ckernelSizeDataPtr := (*C.int64_t)(unsafe.Pointer(&kernelSizeData[0]))
|
|
||||||
ckernelSizeLen := *(*C.int)(unsafe.Pointer(&kernelSizeLen))
|
|
||||||
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
|
|
||||||
cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen))
|
|
||||||
cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
|
|
||||||
cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
|
|
||||||
cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))
|
|
||||||
cdilationLen := *(*C.int)(unsafe.Pointer(&dilationLen))
|
|
||||||
cceilMode := *(*C.int)(unsafe.Pointer(&ceilMode))
|
|
||||||
|
|
||||||
C.atg_max_pool2d(ptr, self, ckernelSizeDataPtr, ckernelSizeLen, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, cdilationDataPtr, cdilationLen, cceilMode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_avg_pool2d(tensor *, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int ceil_mode, int count_include_pad, int64_t divisor_override);
|
|
||||||
func AtgAvgPool2d(ptr *Ctensor, self Ctensor, kernelSizeData []int64, kernelSizeLen int, strideData []int64, strideLen int, paddingData []int64, paddingLen int, ceilMode int, countIncludePad int, divisorOverride int64) {
|
|
||||||
|
|
||||||
ckernelSizeDataPtr := (*C.int64_t)(unsafe.Pointer(&kernelSizeData[0]))
|
|
||||||
ckernelSizeLen := *(*C.int)(unsafe.Pointer(&kernelSizeLen))
|
|
||||||
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
|
|
||||||
cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen))
|
|
||||||
cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
|
|
||||||
cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
|
|
||||||
cceilMode := *(*C.int)(unsafe.Pointer(&ceilMode))
|
|
||||||
ccountIncludePad := *(*C.int)(unsafe.Pointer(&countIncludePad))
|
|
||||||
cdivisorOverride := *(*C.int64_t)(unsafe.Pointer(&divisorOverride))
|
|
||||||
|
|
||||||
C.atg_avg_pool2d(ptr, self, ckernelSizeDataPtr, ckernelSizeLen, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, cceilMode, ccountIncludePad, cdivisorOverride)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_dropout(tensor *, tensor input, double p, int train);
|
|
||||||
func AtgDropout(ptr *Ctensor, input Ctensor, p float64, train int) {
|
|
||||||
cp := *(*C.double)(unsafe.Pointer(&p))
|
|
||||||
ctrain := *(*C.int)(unsafe.Pointer(&train))
|
|
||||||
|
|
||||||
C.atg_dropout(ptr, input, cp, ctrain)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_dropout_(tensor *, tensor self, double p, int train);
|
|
||||||
func AtgDropout_(ptr *Ctensor, self Ctensor, p float64, train int) {
|
|
||||||
cp := *(*C.double)(unsafe.Pointer(&p))
|
|
||||||
ctrain := *(*C.int)(unsafe.Pointer(&train))
|
|
||||||
|
|
||||||
C.atg_dropout_(ptr, self, cp, ctrain)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_conv_transpose1d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *output_padding_data, int output_padding_len, int64_t groups, int64_t *dilation_data, int dilation_len);
|
|
||||||
func AtgConvTranspose1d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, outputPaddingData []int64, outputPaddingLen int, dilationData []int64, dilationLen int, groups int64) {
|
|
||||||
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
|
|
||||||
cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen))
|
|
||||||
cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
|
|
||||||
cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
|
|
||||||
coutputPaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&outputPaddingData[0]))
|
|
||||||
coutputPaddingLen := *(*C.int)(unsafe.Pointer(&outputPaddingLen))
|
|
||||||
cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))
|
|
||||||
cdilationLen := *(*C.int)(unsafe.Pointer(&dilationLen))
|
|
||||||
cgroups := *(*C.int64_t)(unsafe.Pointer(&groups))
|
|
||||||
|
|
||||||
C.atg_conv_transpose1d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, coutputPaddingDataPtr, coutputPaddingLen, cgroups, cdilationDataPtr, cdilationLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_conv_transpose2d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *output_padding_data, int output_padding_len, int64_t groups, int64_t *dilation_data, int dilation_len);
|
|
||||||
func AtgConvTranspose2d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, outputPaddingData []int64, outputPaddingLen int, dilationData []int64, dilationLen int, groups int64) {
|
|
||||||
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
|
|
||||||
cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen))
|
|
||||||
cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
|
|
||||||
cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
|
|
||||||
coutputPaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&outputPaddingData[0]))
|
|
||||||
coutputPaddingLen := *(*C.int)(unsafe.Pointer(&outputPaddingLen))
|
|
||||||
cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))
|
|
||||||
cdilationLen := *(*C.int)(unsafe.Pointer(&dilationLen))
|
|
||||||
cgroups := *(*C.int64_t)(unsafe.Pointer(&groups))
|
|
||||||
|
|
||||||
C.atg_conv_transpose2d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, coutputPaddingDataPtr, coutputPaddingLen, cgroups, cdilationDataPtr, cdilationLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_conv_transpose3d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *output_padding_data, int output_padding_len, int64_t groups, int64_t *dilation_data, int dilation_len);
|
|
||||||
func AtgConvTranspose3d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, outputPaddingData []int64, outputPaddingLen int, dilationData []int64, dilationLen int, groups int64) {
|
|
||||||
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
|
|
||||||
cstrideLen := *(*C.int)(unsafe.Pointer(&strideLen))
|
|
||||||
cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
|
|
||||||
cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
|
|
||||||
coutputPaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&outputPaddingData[0]))
|
|
||||||
coutputPaddingLen := *(*C.int)(unsafe.Pointer(&outputPaddingLen))
|
|
||||||
cdilationDataPtr := (*C.int64_t)(unsafe.Pointer(&dilationData[0]))
|
|
||||||
cdilationLen := *(*C.int)(unsafe.Pointer(&dilationLen))
|
|
||||||
cgroups := *(*C.int64_t)(unsafe.Pointer(&groups))
|
|
||||||
|
|
||||||
C.atg_conv_transpose3d(ptr, input, weight, bias, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, coutputPaddingDataPtr, coutputPaddingLen, cgroups, cdilationDataPtr, cdilationLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_lstm(tensor *, tensor input, tensor *hx_data, int hx_len, tensor *params_data, int params_len, int has_biases, int64_t num_layers, double dropout, int train, int bidirectional, int batch_first);
|
|
||||||
func AtgLstm(ptr *Ctensor, input Ctensor, hxData []Ctensor, hxLen int, paramsData []Ctensor, paramsLen int, hasBiases int, numLayers int64, dropout float64, train int, bidirectional int, batchFirst int) {
|
|
||||||
|
|
||||||
chxDataPtr := (*Ctensor)(unsafe.Pointer(&hxData[0]))
|
|
||||||
chxLen := *(*C.int)(unsafe.Pointer(&hxLen))
|
|
||||||
cparamsDataPtr := (*Ctensor)(unsafe.Pointer(¶msData[0]))
|
|
||||||
cparamsLen := *(*C.int)(unsafe.Pointer(¶msLen))
|
|
||||||
chasBiases := *(*C.int)(unsafe.Pointer(&hasBiases))
|
|
||||||
cnumLayers := *(*C.int64_t)(unsafe.Pointer(&numLayers))
|
|
||||||
cdropout := *(*C.double)(unsafe.Pointer(&dropout))
|
|
||||||
ctrain := *(*C.int)(unsafe.Pointer(&train))
|
|
||||||
cbidirectional := *(*C.int)(unsafe.Pointer(&bidirectional))
|
|
||||||
cbatchFirst := *(*C.int)(unsafe.Pointer(&batchFirst))
|
|
||||||
|
|
||||||
C.atg_lstm(ptr, input, chxDataPtr, chxLen, cparamsDataPtr, cparamsLen, chasBiases, cnumLayers, cdropout, ctrain, cbidirectional, cbatchFirst)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_gru(tensor *, tensor input, tensor hx, tensor *params_data, int params_len, int has_biases, int64_t num_layers, double dropout, int train, int bidirectional, int batch_first);
|
|
||||||
func AtgGru(ptr *Ctensor, input Ctensor, hx Ctensor, paramsData []Ctensor, paramsLen int, hasBiases int, numLayers int64, dropout float64, train int, bidirectional int, batchFirst int) {
|
|
||||||
|
|
||||||
cparamsDataPtr := (*Ctensor)(unsafe.Pointer(¶msData[0]))
|
|
||||||
cparamsLen := *(*C.int)(unsafe.Pointer(¶msLen))
|
|
||||||
chasBiases := *(*C.int)(unsafe.Pointer(&hasBiases))
|
|
||||||
cnumLayers := *(*C.int64_t)(unsafe.Pointer(&numLayers))
|
|
||||||
cdropout := *(*C.double)(unsafe.Pointer(&dropout))
|
|
||||||
ctrain := *(*C.int)(unsafe.Pointer(&train))
|
|
||||||
cbidirectional := *(*C.int)(unsafe.Pointer(&bidirectional))
|
|
||||||
cbatchFirst := *(*C.int)(unsafe.Pointer(&batchFirst))
|
|
||||||
|
|
||||||
C.atg_gru(ptr, input, hx, cparamsDataPtr, cparamsLen, chasBiases, cnumLayers, cdropout, ctrain, cbidirectional, cbatchFirst)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_randn(tensor *, int64_t *size_data, int size_len, int options_kind, int options_device);
|
|
||||||
func AtgRandn(ptr *Ctensor, sizeData []int64, sizeLen int, optionsKind int32, optionsDevice int32) {
|
|
||||||
|
|
||||||
csizeDataPtr := (*C.int64_t)(unsafe.Pointer(&sizeData[0]))
|
|
||||||
csizeLen := *(*C.int)(unsafe.Pointer(&sizeLen))
|
|
||||||
coptionKind := *(*C.int)(unsafe.Pointer(&optionsKind))
|
|
||||||
coptionDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
|
|
||||||
|
|
||||||
C.atg_randn(ptr, csizeDataPtr, csizeLen, coptionKind, coptionDevice)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_embedding(tensor *, tensor weight, tensor indices, int64_t padding_idx, int scale_grad_by_freq, int sparse);
|
|
||||||
func AtgEmbedding(ptr *Ctensor, weight Ctensor, indices Ctensor, paddingIdx int64, scaleGradByFreq int, sparse int) {
|
|
||||||
|
|
||||||
cpaddingIdx := *(*C.int64_t)(unsafe.Pointer(&paddingIdx))
|
|
||||||
cscaleGradByFreq := *(*C.int)(unsafe.Pointer(&scaleGradByFreq))
|
|
||||||
csparse := *(*C.int)(unsafe.Pointer(&sparse))
|
|
||||||
|
|
||||||
C.atg_embedding(ptr, weight, indices, cpaddingIdx, cscaleGradByFreq, csparse)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_randint(tensor *, int64_t high, int64_t *size_data, int size_len, int options_kind, int options_device);
|
|
||||||
func AtgRandint(ptr *Ctensor, high int64, sizeData []int64, sizeLen int, optionsKind int32, optionsDevice int32) {
|
|
||||||
|
|
||||||
chigh := *(*C.int64_t)(unsafe.Pointer(&high))
|
|
||||||
csizeDataPtr := (*C.int64_t)(unsafe.Pointer(&sizeData[0]))
|
|
||||||
csizeLen := *(*C.int)(unsafe.Pointer(&sizeLen))
|
|
||||||
coptionKind := *(*C.int)(unsafe.Pointer(&optionsKind))
|
|
||||||
coptionDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
|
|
||||||
|
|
||||||
C.atg_randint(ptr, chigh, csizeDataPtr, csizeLen, coptionKind, coptionDevice)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_layer_norm(tensor *, tensor input, int64_t *normalized_shape_data, int normalized_shape_len, tensor weight, tensor bias, double eps, int cudnn_enable);
|
|
||||||
func AtgLayerNorm(ptr *Ctensor, input Ctensor, normalizedShapeData []int64, normalizedShapeLen int, weight Ctensor, bias Ctensor, eps float64, cudnnEnable int) {
|
|
||||||
|
|
||||||
cnormalizedShapeDataPtr := (*C.int64_t)(unsafe.Pointer(&normalizedShapeData[0]))
|
|
||||||
cnormalizedShapeLen := *(*C.int)(unsafe.Pointer(&normalizedShapeLen))
|
|
||||||
ceps := *(*C.double)(unsafe.Pointer(&eps))
|
|
||||||
ccudnnEnable := *(*C.int)(unsafe.Pointer(&cudnnEnable))
|
|
||||||
|
|
||||||
C.atg_layer_norm(ptr, input, cnormalizedShapeDataPtr, cnormalizedShapeLen, weight, bias, ceps, ccudnnEnable)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_batch_norm(tensor *, tensor input, tensor weight, tensor bias, tensor running_mean, tensor running_var, int training, double momentum, double eps, int cudnn_enabled);
|
|
||||||
func AtgBatchNorm(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, runningMean Ctensor, runningVar Ctensor, training int, momentum float64, eps float64, cudnnEnable int) {
|
|
||||||
|
|
||||||
ctraining := *(*C.int)(unsafe.Pointer(&training))
|
|
||||||
cmomentum := *(*C.double)(unsafe.Pointer(&momentum))
|
|
||||||
ceps := *(*C.double)(unsafe.Pointer(&eps))
|
|
||||||
ccudnnEnable := *(*C.int)(unsafe.Pointer(&cudnnEnable))
|
|
||||||
|
|
||||||
C.atg_batch_norm(ptr, input, weight, bias, runningMean, runningVar, ctraining, cmomentum, ceps, ccudnnEnable)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_cat(tensor *, tensor *tensors_data, int tensors_len, int64_t dim);
|
|
||||||
func AtgCat(ptr *Ctensor, tensorsData []Ctensor, tensorsLen int, dim int64) {
|
|
||||||
tensorsDataPtr := (*Ctensor)(unsafe.Pointer(&tensorsData[0]))
|
|
||||||
ctensorsLen := *(*C.int)(unsafe.Pointer(&tensorsLen))
|
|
||||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
|
||||||
|
|
||||||
C.atg_cat(ptr, tensorsDataPtr, ctensorsLen, cdim)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_topk(tensor *, tensor self, int64_t k, int64_t dim, int largest, int sorted);
|
|
||||||
func AtgTopk(ptr *Ctensor, self Ctensor, k int64, dim int64, largest int, sorted int) {
|
|
||||||
ck := *(*C.int64_t)(unsafe.Pointer(&k))
|
|
||||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
|
||||||
clargest := *(*C.int)(unsafe.Pointer(&largest))
|
|
||||||
csorted := *(*C.int)(unsafe.Pointer(&sorted))
|
|
||||||
|
|
||||||
C.atg_topk(ptr, self, ck, cdim, clargest, csorted)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_adaptive_avg_pool2d(tensor *, tensor self, int64_t *output_size_data, int output_size_len);
|
|
||||||
func AtgAdaptiveAvgPool2d(ptr *Ctensor, self Ctensor, outputSizeData []int64, outputSizeLen int) {
|
|
||||||
outputSizeDataPtr := (*C.int64_t)(unsafe.Pointer(&outputSizeData[0]))
|
|
||||||
coutputSizeLen := *(*C.int)(unsafe.Pointer(&outputSizeLen))
|
|
||||||
|
|
||||||
C.atg_adaptive_avg_pool2d(ptr, self, outputSizeDataPtr, coutputSizeLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_softmax(tensor *, tensor self, int64_t dim, int dtype);
|
|
||||||
func AtgSoftmax(ptr *Ctensor, self Ctensor, dim int64, dtype int32) {
|
|
||||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
|
||||||
cdtype := *(*C.int)(unsafe.Pointer(&dtype))
|
|
||||||
|
|
||||||
C.atg_softmax(ptr, self, cdim, cdtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_constant_pad_nd(tensor *, tensor self, int64_t *pad_data, int pad_len);
|
|
||||||
func AtgConstantPadNd(ptr *Ctensor, self Ctensor, padData []int64, padLen int) {
|
|
||||||
cpadDataPtr := (*C.int64_t)(unsafe.Pointer(&padData[0]))
|
|
||||||
cpadLen := *(*C.int)(unsafe.Pointer(&padLen))
|
|
||||||
|
|
||||||
C.atg_constant_pad_nd(ptr, self, cpadDataPtr, cpadLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_sigmoid(tensor *, tensor self);
|
|
||||||
func AtgSigmoid(ptr *Ctensor, self Ctensor) {
|
|
||||||
C.atg_sigmoid(ptr, self)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_flip(tensor *, tensor self, int64_t *dims_data, int dims_len);
|
|
||||||
func AtgFlip(ptr *Ctensor, self Ctensor, dimsData []int64, dimsLen int) {
|
|
||||||
|
|
||||||
cdimsDataPtr := (*C.int64_t)(unsafe.Pointer(&dimsData[0]))
|
|
||||||
cdimsLen := *(*C.int)(unsafe.Pointer(&dimsLen))
|
|
||||||
|
|
||||||
C.atg_flip(ptr, self, cdimsDataPtr, cdimsLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_reflection_pad2d(tensor *, tensor self, int64_t *padding_data, int padding_len);
|
|
||||||
func AtgReflectionPad2d(ptr *Ctensor, self Ctensor, paddingData []int64, paddingLen int) {
|
|
||||||
|
|
||||||
cpaddingDataPtr := (*C.int64_t)(unsafe.Pointer(&paddingData[0]))
|
|
||||||
cpaddingLen := *(*C.int)(unsafe.Pointer(&paddingLen))
|
|
||||||
|
|
||||||
C.atg_reflection_pad2d(ptr, self, cpaddingDataPtr, cpaddingLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_arange(tensor *, scalar end, int options_kind, int options_device);
|
|
||||||
func AtgArange(ptr *Ctensor, end Cscalar, optionsKind int32, optionsDevice int32) {
|
|
||||||
coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
|
|
||||||
coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
|
|
||||||
|
|
||||||
C.atg_arange(ptr, end, coptionsKind, coptionsDevice)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_arange1(tensor *, scalar start, scalar end, int options_kind, int options_device);
|
|
||||||
func AtgArange1(ptr *Ctensor, start Cscalar, end Cscalar, optionsKind int32, optionsDevice int32) {
|
|
||||||
coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
|
|
||||||
coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
|
|
||||||
|
|
||||||
C.atg_arange1(ptr, start, end, coptionsKind, coptionsDevice)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_arange2(tensor *, scalar start, scalar end, scalar step, int options_kind, int options_device);
|
|
||||||
func AtgArange2(ptr *Ctensor, start Cscalar, end Cscalar, step Cscalar, optionsKind int32, optionsDevice int32) {
|
|
||||||
coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
|
|
||||||
coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
|
|
||||||
|
|
||||||
C.atg_arange2(ptr, start, end, step, coptionsKind, coptionsDevice)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_arange_out(tensor *, tensor out, scalar end);
|
|
||||||
func AtgArangeOut(ptr *Ctensor, out Ctensor, end Cscalar) {
|
|
||||||
|
|
||||||
C.atg_arange_out(ptr, out, end)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_arange_out1(tensor *, tensor out, scalar start, scalar end);
|
|
||||||
func AtgArangeOut1(ptr *Ctensor, out Ctensor, start Cscalar, end Cscalar) {
|
|
||||||
|
|
||||||
C.atg_arange_out1(ptr, out, start, end)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_max1(tensor *, tensor self, tensor other);
|
|
||||||
func AtgMax1(ptr *Ctensor, self Ctensor, other Ctensor) {
|
|
||||||
C.atg_max1(ptr, self, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_upsample_nearest2d(tensor *, tensor self, int64_t *output_size_data, int output_size_len, double scales_h, double scales_w);
|
|
||||||
func AtgUpsampleNearest2d(ptr *Ctensor, self Ctensor, outputSizeData []int64, outputSizeLen int, scalesH float64, scalesW float64) {
|
|
||||||
|
|
||||||
coutputSizeDataPtr := (*C.int64_t)(unsafe.Pointer(&outputSizeData[0]))
|
|
||||||
coutputSizeLen := *(*C.int)(unsafe.Pointer(&outputSizeLen))
|
|
||||||
cscalesH := *(*C.double)(unsafe.Pointer(&scalesH))
|
|
||||||
cscalesW := *(*C.double)(unsafe.Pointer(&scalesW))
|
|
||||||
|
|
||||||
C.atg_upsample_nearest2d(ptr, self, coutputSizeDataPtr, coutputSizeLen, cscalesH, cscalesW)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_repeat(tensor *, tensor self, int64_t *repeats_data, int repeats_len);
|
|
||||||
func AtgRepeat(ptr *Ctensor, self Ctensor, repeatData []int64, repeatLen int) {
|
|
||||||
crepeatDataPtr := (*C.int64_t)(unsafe.Pointer(&repeatData[0]))
|
|
||||||
crepeatLen := *(*C.int)(unsafe.Pointer(&repeatLen))
|
|
||||||
|
|
||||||
C.atg_repeat(ptr, self, crepeatDataPtr, crepeatLen)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_contiguous(tensor *, tensor self);
|
|
||||||
func AtgContiguous(ptr *Ctensor, self Ctensor) {
|
|
||||||
C.atg_contiguous(ptr, self)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_transpose(tensor *, tensor self, int64_t dim0, int64_t dim1);
|
|
||||||
func AtgTranspose(ptr *Ctensor, self Ctensor, dim0 int64, dim1 int64) {
|
|
||||||
|
|
||||||
cdim0 := *(*C.int64_t)(unsafe.Pointer(&dim0))
|
|
||||||
cdim1 := *(*C.int64_t)(unsafe.Pointer(&dim1))
|
|
||||||
|
|
||||||
C.atg_transpose(ptr, self, cdim0, cdim1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// void atg_squeeze(tensor *, tensor self);
|
|
||||||
func AtgSqueeze(ptr *Ctensor, self Ctensor) {
|
|
||||||
C.atg_squeeze(ptr, self)
|
|
||||||
}
|
|
5487
libtch/c-generated.go
Normal file
5487
libtch/c-generated.go
Normal file
File diff suppressed because it is too large
Load Diff
|
@ -127,12 +127,12 @@ func NewConvTranspose3D(vs *Path, inDim, outDim int64, ksizes []int64, cfg ConvT
|
||||||
// ============================================
|
// ============================================
|
||||||
|
|
||||||
func (c ConvTranspose1D) Forward(xs ts.Tensor) ts.Tensor {
|
func (c ConvTranspose1D) Forward(xs ts.Tensor) ts.Tensor {
|
||||||
return ts.MustConvTranspose1D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Dilation, c.Config.Groups)
|
return ts.MustConvTranspose1d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Groups, c.Config.Dilation)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c ConvTranspose2D) Forward(xs ts.Tensor) ts.Tensor {
|
func (c ConvTranspose2D) Forward(xs ts.Tensor) ts.Tensor {
|
||||||
return ts.MustConvTranspose2D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Dilation, c.Config.Groups)
|
return ts.MustConvTranspose2d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Groups, c.Config.Dilation)
|
||||||
}
|
}
|
||||||
func (c ConvTranspose3D) Forward(xs ts.Tensor) ts.Tensor {
|
func (c ConvTranspose3D) Forward(xs ts.Tensor) ts.Tensor {
|
||||||
return ts.MustConvTranspose3D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Dilation, c.Config.Groups)
|
return ts.MustConvTranspose3d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Groups, c.Config.Dilation)
|
||||||
}
|
}
|
||||||
|
|
12
nn/conv.go
12
nn/conv.go
|
@ -217,14 +217,14 @@ func NewConv(vs Path, inDim, outDim int64, ksizes []int64, config interface{}) C
|
||||||
// ============================================
|
// ============================================
|
||||||
|
|
||||||
func (c Conv1D) Forward(xs ts.Tensor) ts.Tensor {
|
func (c Conv1D) Forward(xs ts.Tensor) ts.Tensor {
|
||||||
return ts.MustConv1D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
|
return ts.MustConv1d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Conv2D) Forward(xs ts.Tensor) ts.Tensor {
|
func (c Conv2D) Forward(xs ts.Tensor) ts.Tensor {
|
||||||
return ts.MustConv2D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
|
return ts.MustConv2d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
|
||||||
}
|
}
|
||||||
func (c Conv3D) Forward(xs ts.Tensor) ts.Tensor {
|
func (c Conv3D) Forward(xs ts.Tensor) ts.Tensor {
|
||||||
return ts.MustConv3D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
|
return ts.MustConv3d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Implement ModuleT for Conv1D, Conv2D, Conv3D:
|
// Implement ModuleT for Conv1D, Conv2D, Conv3D:
|
||||||
|
@ -233,12 +233,12 @@ func (c Conv3D) Forward(xs ts.Tensor) ts.Tensor {
|
||||||
// NOTE: `train` param won't be used, will be?
|
// NOTE: `train` param won't be used, will be?
|
||||||
|
|
||||||
func (c Conv1D) ForwardT(xs ts.Tensor, train bool) ts.Tensor {
|
func (c Conv1D) ForwardT(xs ts.Tensor, train bool) ts.Tensor {
|
||||||
return ts.MustConv1D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
|
return ts.MustConv1d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Conv2D) ForwardT(xs ts.Tensor, train bool) ts.Tensor {
|
func (c Conv2D) ForwardT(xs ts.Tensor, train bool) ts.Tensor {
|
||||||
return ts.MustConv2D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
|
return ts.MustConv2d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
|
||||||
}
|
}
|
||||||
func (c Conv3D) ForwardT(xs ts.Tensor, train bool) ts.Tensor {
|
func (c Conv3D) ForwardT(xs ts.Tensor, train bool) ts.Tensor {
|
||||||
return ts.MustConv3D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
|
return ts.MustConv3d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
|
||||||
}
|
}
|
||||||
|
|
14
nn/init.go
14
nn/init.go
|
@ -30,12 +30,12 @@ func NewConstInit(v float64) constInit {
|
||||||
|
|
||||||
func (c constInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
func (c constInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
||||||
var err error
|
var err error
|
||||||
kind := gotch.Float.CInt()
|
kind := gotch.Float
|
||||||
switch {
|
switch {
|
||||||
case c.value == 0.0:
|
case c.value == 0.0:
|
||||||
retVal = ts.MustZeros(dims, kind, device.CInt())
|
retVal = ts.MustZeros(dims, kind, device)
|
||||||
case c.value == 1.0:
|
case c.value == 1.0:
|
||||||
retVal = ts.MustOnes(dims, kind, device.CInt())
|
retVal = ts.MustOnes(dims, kind, device)
|
||||||
default:
|
default:
|
||||||
data := make([]float64, ts.FlattenDim(dims))
|
data := make([]float64, ts.FlattenDim(dims))
|
||||||
for i := range data {
|
for i := range data {
|
||||||
|
@ -127,8 +127,8 @@ func NewUniformInit(lo, up float64) uniformInit {
|
||||||
|
|
||||||
func (u uniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
func (u uniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
||||||
var err error
|
var err error
|
||||||
kind := gotch.Float.CInt()
|
kind := gotch.Float
|
||||||
retVal = ts.MustZeros(dims, kind, device.CInt())
|
retVal = ts.MustZeros(dims, kind, device)
|
||||||
retVal.Uniform_(u.lo, u.up)
|
retVal.Uniform_(u.lo, u.up)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
|
log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
|
||||||
|
@ -158,8 +158,8 @@ func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVa
|
||||||
}
|
}
|
||||||
|
|
||||||
bound := math.Sqrt(1.0 / float64(fanIn))
|
bound := math.Sqrt(1.0 / float64(fanIn))
|
||||||
kind := gotch.Float.CInt()
|
kind := gotch.Float
|
||||||
retVal = ts.MustZeros(dims, kind, device.CInt())
|
retVal = ts.MustZeros(dims, kind, device)
|
||||||
retVal.Uniform_(-bound, bound)
|
retVal.Uniform_(-bound, bound)
|
||||||
|
|
||||||
return retVal
|
return retVal
|
||||||
|
|
|
@ -43,7 +43,7 @@ func NewLinear(vs Path, inDim, outDim int64, c LinearConfig) Linear {
|
||||||
// bs has size of output dimension
|
// bs has size of output dimension
|
||||||
switch c.Bias {
|
switch c.Bias {
|
||||||
case false:
|
case false:
|
||||||
bs = ts.MustZeros([]int64{outDim}, gotch.Float.CInt(), vs.Device().CInt())
|
bs = ts.MustZeros([]int64{outDim}, gotch.Float, vs.Device())
|
||||||
case true:
|
case true:
|
||||||
switch {
|
switch {
|
||||||
case c.BsInit == nil:
|
case c.BsInit == nil:
|
||||||
|
@ -91,7 +91,7 @@ func NewLinear(vs Path, inDim, outDim int64, c LinearConfig) Linear {
|
||||||
// 1 1 1 ]
|
// 1 1 1 ]
|
||||||
func (l Linear) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
func (l Linear) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
||||||
|
|
||||||
mul := xs.MustMatMul(l.Ws, false)
|
mul := xs.MustMatmul(l.Ws, false)
|
||||||
return mul.MustAdd(l.Bs, true)
|
return mul.MustAdd(l.Bs, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -100,6 +100,6 @@ func (l Linear) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
||||||
// NOTE: train param will not be used.
|
// NOTE: train param will not be used.
|
||||||
func (l Linear) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
func (l Linear) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
||||||
|
|
||||||
mul := xs.MustMatMul(l.Ws, false)
|
mul := xs.MustMatmul(l.Ws, false)
|
||||||
return mul.MustAdd(l.Bs, true)
|
return mul.MustAdd(l.Bs, true)
|
||||||
}
|
}
|
||||||
|
|
|
@ -131,7 +131,7 @@ func (l LSTM) ZeroState(batchDim int64) (retVal State) {
|
||||||
|
|
||||||
layerDim := l.config.NumLayers * numDirections
|
layerDim := l.config.NumLayers * numDirections
|
||||||
shape := []int64{layerDim, batchDim, l.hiddenDim}
|
shape := []int64{layerDim, batchDim, l.hiddenDim}
|
||||||
zeros := ts.MustZeros(shape, gotch.Float.CInt(), l.device.CInt())
|
zeros := ts.MustZeros(shape, gotch.Float, l.device)
|
||||||
|
|
||||||
return LSTMState{
|
return LSTMState{
|
||||||
Tensor1: zeros.MustShallowClone(),
|
Tensor1: zeros.MustShallowClone(),
|
||||||
|
@ -157,7 +157,7 @@ func (l LSTM) Seq(input ts.Tensor) (ts.Tensor, State) {
|
||||||
|
|
||||||
func (l LSTM) SeqInit(input ts.Tensor, inState State) (ts.Tensor, State) {
|
func (l LSTM) SeqInit(input ts.Tensor, inState State) (ts.Tensor, State) {
|
||||||
|
|
||||||
output, h, c := input.MustLSTM([]ts.Tensor{inState.(LSTMState).Tensor1, inState.(LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst)
|
output, h, c := input.MustLstm([]ts.Tensor{inState.(LSTMState).Tensor1, inState.(LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst)
|
||||||
|
|
||||||
return output, LSTMState{
|
return output, LSTMState{
|
||||||
Tensor1: h,
|
Tensor1: h,
|
||||||
|
@ -229,7 +229,7 @@ func (g GRU) ZeroState(batchDim int64) (retVal State) {
|
||||||
layerDim := g.config.NumLayers * numDirections
|
layerDim := g.config.NumLayers * numDirections
|
||||||
shape := []int64{layerDim, batchDim, g.hiddenDim}
|
shape := []int64{layerDim, batchDim, g.hiddenDim}
|
||||||
|
|
||||||
tensor := ts.MustZeros(shape, gotch.Float.CInt(), g.device.CInt())
|
tensor := ts.MustZeros(shape, gotch.Float, g.device)
|
||||||
|
|
||||||
return GRUState{Tensor: tensor}
|
return GRUState{Tensor: tensor}
|
||||||
}
|
}
|
||||||
|
@ -252,7 +252,7 @@ func (g GRU) Seq(input ts.Tensor) (ts.Tensor, State) {
|
||||||
|
|
||||||
func (g GRU) SeqInit(input ts.Tensor, inState State) (ts.Tensor, State) {
|
func (g GRU) SeqInit(input ts.Tensor, inState State) (ts.Tensor, State) {
|
||||||
|
|
||||||
output, h := input.MustGRU(inState.(GRUState).Tensor, g.flatWeights, g.config.HasBiases, g.config.NumLayers, g.config.Dropout, g.config.Train, g.config.Bidirectional, g.config.BatchFirst)
|
output, h := input.MustGru(inState.(GRUState).Tensor, g.flatWeights, g.config.HasBiases, g.config.NumLayers, g.config.Dropout, g.config.Train, g.config.Bidirectional, g.config.BatchFirst)
|
||||||
|
|
||||||
return output, GRUState{Tensor: h}
|
return output, GRUState{Tensor: h}
|
||||||
}
|
}
|
||||||
|
|
|
@ -258,7 +258,7 @@ func BatchAccuracyForLogits(vs VarStore, m ts.ModuleT, xs, ys ts.Tensor, d gotch
|
||||||
|
|
||||||
logits := m.ForwardT(bImages, false)
|
logits := m.ForwardT(bImages, false)
|
||||||
acc := logits.AccuracyForLogits(bLabels)
|
acc := logits.AccuracyForLogits(bLabels)
|
||||||
sumAccuracy += acc.Values()[0] * size
|
sumAccuracy += acc.Float64Values()[0] * size
|
||||||
sampleCount += size
|
sampleCount += size
|
||||||
|
|
||||||
bImages.MustDrop()
|
bImages.MustDrop()
|
||||||
|
@ -310,7 +310,7 @@ func BatchAccuracyForLogitsIdx(vs VarStore, m ts.ModuleT, xs, ys ts.Tensor, d go
|
||||||
logits := m.ForwardT(bImages, true)
|
logits := m.ForwardT(bImages, true)
|
||||||
bAccuracy := logits.AccuracyForLogits(bLabels)
|
bAccuracy := logits.AccuracyForLogits(bLabels)
|
||||||
|
|
||||||
accuVal := bAccuracy.Values()[0]
|
accuVal := bAccuracy.Float64Values()[0]
|
||||||
bSamples := float64(xs.MustSize()[0])
|
bSamples := float64(xs.MustSize()[0])
|
||||||
sumAccuracy += accuVal * bSamples
|
sumAccuracy += accuVal * bSamples
|
||||||
sampleCount += bSamples
|
sampleCount += bSamples
|
||||||
|
|
|
@ -239,7 +239,7 @@ func (vs *VarStore) Freeze() {
|
||||||
defer vs.Vars.mutex.Unlock()
|
defer vs.Vars.mutex.Unlock()
|
||||||
|
|
||||||
for _, v := range vs.Vars.TrainableVariables {
|
for _, v := range vs.Vars.TrainableVariables {
|
||||||
_, err := v.SetRequiresGrad(false)
|
_, err := v.SetRequiresGrad(false, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Freeze() Error: %v\n", err)
|
log.Fatalf("Freeze() Error: %v\n", err)
|
||||||
}
|
}
|
||||||
|
@ -254,7 +254,7 @@ func (vs *VarStore) Unfreeze() {
|
||||||
defer vs.Vars.mutex.Unlock()
|
defer vs.Vars.mutex.Unlock()
|
||||||
|
|
||||||
for _, v := range vs.Vars.TrainableVariables {
|
for _, v := range vs.Vars.TrainableVariables {
|
||||||
_, err := v.SetRequiresGrad(true)
|
_, err := v.SetRequiresGrad(true, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Unfreeze() Error: %v\n", err)
|
log.Fatalf("Unfreeze() Error: %v\n", err)
|
||||||
}
|
}
|
||||||
|
@ -349,7 +349,7 @@ func (p *Path) add(name string, newTs ts.Tensor, trainable bool) (retVal ts.Tens
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
if trainable {
|
if trainable {
|
||||||
tensor, err = newTs.MustShallowClone().SetRequiresGrad(true)
|
tensor, err = newTs.MustShallowClone().SetRequiresGrad(true, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Path 'add' method error: %v\n", err)
|
log.Fatalf("Path 'add' method error: %v\n", err)
|
||||||
}
|
}
|
||||||
|
@ -378,7 +378,7 @@ func (p *Path) getOrAddWithLock(name string, tensor ts.Tensor, trainable bool, v
|
||||||
var err error
|
var err error
|
||||||
var ttensor ts.Tensor
|
var ttensor ts.Tensor
|
||||||
if trainable {
|
if trainable {
|
||||||
ttensor, err = tensor.SetRequiresGrad(true)
|
ttensor, err = tensor.SetRequiresGrad(true, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Path - call method 'getOrAddWithLock' error: %v\n", err)
|
log.Fatalf("Path - call method 'getOrAddWithLock' error: %v\n", err)
|
||||||
}
|
}
|
||||||
|
@ -403,9 +403,8 @@ func (p *Path) getOrAddWithLock(name string, tensor ts.Tensor, trainable bool, v
|
||||||
// The variable uses a float tensor initialized with zeros.
|
// The variable uses a float tensor initialized with zeros.
|
||||||
func (p *Path) ZerosNoTrain(name string, dims []int64) (retVal ts.Tensor) {
|
func (p *Path) ZerosNoTrain(name string, dims []int64) (retVal ts.Tensor) {
|
||||||
|
|
||||||
dtype, err := gotch.DType2CInt(gotch.Float) // DType Float
|
device := p.Device()
|
||||||
device := p.Device().CInt()
|
z, err := ts.Zeros(dims, gotch.Float, device)
|
||||||
z, err := ts.Zeros(dims, dtype, device)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Path - 'ZerosNoTrain' method call error: %v\n", err)
|
log.Fatalf("Path - 'ZerosNoTrain' method call error: %v\n", err)
|
||||||
}
|
}
|
||||||
|
@ -421,9 +420,8 @@ func (p *Path) ZerosNoTrain(name string, dims []int64) (retVal ts.Tensor) {
|
||||||
// The variable uses a float tensor initialized with ones.
|
// The variable uses a float tensor initialized with ones.
|
||||||
func (p *Path) OnesNoTrain(name string, dims []int64) (retVal ts.Tensor) {
|
func (p *Path) OnesNoTrain(name string, dims []int64) (retVal ts.Tensor) {
|
||||||
|
|
||||||
dtype, err := gotch.DType2CInt(gotch.Float) // DType Float
|
device := p.Device()
|
||||||
device := p.Device().CInt()
|
z, err := ts.Ones(dims, gotch.Float, device)
|
||||||
z, err := ts.Ones(dims, dtype, device)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Path - 'OnesNoTrain' method call error: %v\n", err)
|
log.Fatalf("Path - 'OnesNoTrain' method call error: %v\n", err)
|
||||||
}
|
}
|
||||||
|
@ -610,7 +608,7 @@ func (e *Entry) OrOnes(dims []int64) (retVal ts.Tensor) {
|
||||||
// OrOnesNoTrain returns the existing entry if, otherwise create a new variable.
|
// OrOnesNoTrain returns the existing entry if, otherwise create a new variable.
|
||||||
func (e *Entry) OrOnesNoTrain(dims []int64) (retVal ts.Tensor) {
|
func (e *Entry) OrOnesNoTrain(dims []int64) (retVal ts.Tensor) {
|
||||||
|
|
||||||
o := ts.MustOnes(dims, gotch.Float.CInt(), e.path.Device().CInt())
|
o := ts.MustOnes(dims, gotch.Float, e.path.Device())
|
||||||
return e.path.getOrAddWithLock(e.name, o, true, *e.variables)
|
return e.path.getOrAddWithLock(e.name, o, true, *e.variables)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -641,7 +639,7 @@ func (e *Entry) OrZeros(dims []int64) (retVal ts.Tensor) {
|
||||||
// OrZerosNoTrain returns the existing entry if, otherwise create a new variable.
|
// OrZerosNoTrain returns the existing entry if, otherwise create a new variable.
|
||||||
func (e *Entry) OrZerosNoTrain(dims []int64) (retVal ts.Tensor) {
|
func (e *Entry) OrZerosNoTrain(dims []int64) (retVal ts.Tensor) {
|
||||||
|
|
||||||
z := ts.MustZeros(dims, gotch.Float.CInt(), e.path.Device().CInt())
|
z := ts.MustZeros(dims, gotch.Float, e.path.Device())
|
||||||
return e.path.getOrAddWithLock(e.name, z, true, *e.variables)
|
return e.path.getOrAddWithLock(e.name, z, true, *e.variables)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -74,10 +74,10 @@ func TestSaveLoad(t *testing.T) {
|
||||||
wantU2 := float64(0.0)
|
wantU2 := float64(0.0)
|
||||||
wantV2 := float64(1.0)
|
wantV2 := float64(1.0)
|
||||||
|
|
||||||
gotU1 := u1.MustMean(gotch.Float.CInt(), false).Values()[0]
|
gotU1 := u1.MustMean(gotch.Float, false).Float64Values()[0]
|
||||||
gotV1 := v1.MustMean(gotch.Float.CInt(), false).Values()[0]
|
gotV1 := v1.MustMean(gotch.Float, false).Float64Values()[0]
|
||||||
gotU2 := u2.MustMean(gotch.Float.CInt(), false).Values()[0]
|
gotU2 := u2.MustMean(gotch.Float, false).Float64Values()[0]
|
||||||
gotV2 := v2.MustMean(gotch.Float.CInt(), false).Values()[0]
|
gotV2 := v2.MustMean(gotch.Float, false).Float64Values()[0]
|
||||||
|
|
||||||
if !reflect.DeepEqual(wantU1, gotU1) {
|
if !reflect.DeepEqual(wantU1, gotU1) {
|
||||||
t.Errorf("Expected u1: %v\n", wantU1)
|
t.Errorf("Expected u1: %v\n", wantU1)
|
||||||
|
@ -109,8 +109,8 @@ func TestSaveLoad(t *testing.T) {
|
||||||
|
|
||||||
wantU2 = float64(42.0)
|
wantU2 = float64(42.0)
|
||||||
wantV2 = float64(2.0)
|
wantV2 = float64(2.0)
|
||||||
gotU2 = u2.MustMean(gotch.Float.CInt(), false).Values()[0]
|
gotU2 = u2.MustMean(gotch.Float, false).Float64Values()[0]
|
||||||
gotV2 = v2.MustMean(gotch.Float.CInt(), false).Values()[0]
|
gotV2 = v2.MustMean(gotch.Float, false).Float64Values()[0]
|
||||||
|
|
||||||
if !reflect.DeepEqual(wantU1, gotU1) {
|
if !reflect.DeepEqual(wantU1, gotU1) {
|
||||||
t.Errorf("Expected u1: %v\n", wantU1)
|
t.Errorf("Expected u1: %v\n", wantU1)
|
||||||
|
|
|
@ -59,7 +59,7 @@ func TestModuleForwardTs(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
got := int(res.Values()[0])
|
got := int(res.Float64Values()[0])
|
||||||
|
|
||||||
want := 1421
|
want := 1421
|
||||||
|
|
||||||
|
|
|
@ -1,3 +0,0 @@
|
||||||
package tensor
|
|
||||||
|
|
||||||
// TODO: implement tensor.From macro
|
|
8043
tensor/must-tensor-generated.go
Normal file
8043
tensor/must-tensor-generated.go
Normal file
File diff suppressed because it is too large
Load Diff
|
@ -8,10 +8,12 @@ import (
|
||||||
|
|
||||||
// CrossEntropyForLogits computes the cross-entropy loss based on some logits and targets.
|
// CrossEntropyForLogits computes the cross-entropy loss based on some logits and targets.
|
||||||
func (ts Tensor) CrossEntropyForLogits(targets Tensor) (retVal Tensor) {
|
func (ts Tensor) CrossEntropyForLogits(targets Tensor) (retVal Tensor) {
|
||||||
// return ts.MustLogSoftmax(-1, gotch.Float.CInt(), true).MustNllLoss(targets, true)
|
weight := NewTensor()
|
||||||
|
reduction := int64(1) // Mean of loss
|
||||||
|
ignoreIndex := int64(-100)
|
||||||
|
|
||||||
logSm := ts.MustLogSoftmax(-1, gotch.Float.CInt(), true)
|
logSm := ts.MustLogSoftmax(-1, gotch.Float, true)
|
||||||
return logSm.MustNllLoss(targets, true)
|
return logSm.MustNllLoss(targets, weight, reduction, ignoreIndex, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccuracyForLogits returns the average accuracy for some given logits assuming that
|
// AccuracyForLogits returns the average accuracy for some given logits assuming that
|
||||||
|
@ -19,11 +21,11 @@ func (ts Tensor) CrossEntropyForLogits(targets Tensor) (retVal Tensor) {
|
||||||
func (ts Tensor) AccuracyForLogits(targets Tensor) (retVal Tensor) {
|
func (ts Tensor) AccuracyForLogits(targets Tensor) (retVal Tensor) {
|
||||||
argmax := ts.MustArgmax(-1, false, true)
|
argmax := ts.MustArgmax(-1, false, true)
|
||||||
eq1 := argmax.MustEq1(targets, true)
|
eq1 := argmax.MustEq1(targets, true)
|
||||||
return eq1.MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true)
|
return eq1.MustTotype(gotch.Float, true).MustMean(gotch.Float, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts Tensor) MaxPool2DDefault(ksize int64, del bool) (retVal Tensor) {
|
func (ts Tensor) MaxPool2DDefault(ksize int64, del bool) (retVal Tensor) {
|
||||||
return ts.MustMaxPool2D([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, []int64{1, 1}, false, del)
|
return ts.MustMaxPool2d([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, []int64{1, 1}, false, del)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: continue
|
// TODO: continue
|
||||||
|
|
154
tensor/patch.go
Normal file
154
tensor/patch.go
Normal file
|
@ -0,0 +1,154 @@
|
||||||
|
package tensor
|
||||||
|
|
||||||
|
// #include "stdlib.h"
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
// "github.com/sugarme/gotch"
|
||||||
|
lib "github.com/sugarme/gotch/libtch"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NOTE. This is a temporarily patched to make it run.
|
||||||
|
// TODO. make change at generator for []Tensor input
|
||||||
|
|
||||||
|
func (ts Tensor) Lstm(hxData []Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c Tensor, err error) {
|
||||||
|
|
||||||
|
// NOTE: `atg_lstm` will create 3 consecutive Ctensors in memory of C land. The first
|
||||||
|
// Ctensor will have address given by `ctensorPtr1` here.
|
||||||
|
// The next pointers can be calculated based on `ctensorPtr1`
|
||||||
|
ctensorPtr1 := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
ctensorPtr2 := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(ctensorPtr1)) + unsafe.Sizeof(ctensorPtr1)))
|
||||||
|
ctensorPtr3 := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(ctensorPtr2)) + unsafe.Sizeof(ctensorPtr1)))
|
||||||
|
|
||||||
|
var chxData []lib.Ctensor
|
||||||
|
for _, t := range hxData {
|
||||||
|
chxData = append(chxData, t.ctensor)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cparamsData []lib.Ctensor
|
||||||
|
for _, t := range paramsData {
|
||||||
|
cparamsData = append(cparamsData, t.ctensor)
|
||||||
|
}
|
||||||
|
|
||||||
|
var chasBiases int32 = 0
|
||||||
|
if hasBiases {
|
||||||
|
chasBiases = 1
|
||||||
|
}
|
||||||
|
var ctrain int32 = 0
|
||||||
|
if train {
|
||||||
|
ctrain = 1
|
||||||
|
}
|
||||||
|
var cbidirectional int32 = 0
|
||||||
|
if bidirectional {
|
||||||
|
cbidirectional = 1
|
||||||
|
}
|
||||||
|
var cbatchFirst int32 = 0
|
||||||
|
if batchFirst {
|
||||||
|
cbatchFirst = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
lib.AtgLstm(ctensorPtr1, ts.ctensor, chxData, len(hxData), cparamsData, len(paramsData), chasBiases, numLayers, dropout, ctrain, cbidirectional, cbatchFirst)
|
||||||
|
err = TorchErr()
|
||||||
|
if err != nil {
|
||||||
|
return output, h, c, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return Tensor{ctensor: *ctensorPtr1}, Tensor{ctensor: *ctensorPtr2}, Tensor{ctensor: *ctensorPtr3}, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustLstm(hxData []Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c Tensor) {
|
||||||
|
output, h, c, err := ts.Lstm(hxData, paramsData, hasBiases, numLayers, dropout, train, bidirectional, batchFirst)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return output, h, c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) Gru(hx Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h Tensor, err error) {
|
||||||
|
|
||||||
|
// NOTE: `atg_gru` will create 2 consecutive Ctensors in memory of C land.
|
||||||
|
// The first Ctensor will have address given by `ctensorPtr1` here.
|
||||||
|
// The next pointer can be calculated based on `ctensorPtr1`
|
||||||
|
ctensorPtr1 := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
ctensorPtr2 := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(ctensorPtr1)) + unsafe.Sizeof(ctensorPtr1)))
|
||||||
|
|
||||||
|
var cparamsData []lib.Ctensor
|
||||||
|
for _, t := range paramsData {
|
||||||
|
cparamsData = append(cparamsData, t.ctensor)
|
||||||
|
}
|
||||||
|
|
||||||
|
var chasBiases int32 = 0
|
||||||
|
if hasBiases {
|
||||||
|
chasBiases = 1
|
||||||
|
}
|
||||||
|
var ctrain int32 = 0
|
||||||
|
if train {
|
||||||
|
ctrain = 1
|
||||||
|
}
|
||||||
|
var cbidirectional int32 = 0
|
||||||
|
if bidirectional {
|
||||||
|
cbidirectional = 1
|
||||||
|
}
|
||||||
|
var cbatchFirst int32 = 0
|
||||||
|
if batchFirst {
|
||||||
|
cbatchFirst = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
lib.AtgGru(ctensorPtr1, ts.ctensor, hx.ctensor, cparamsData, len(paramsData), chasBiases, numLayers, dropout, ctrain, cbidirectional, cbatchFirst)
|
||||||
|
err = TorchErr()
|
||||||
|
if err != nil {
|
||||||
|
return output, h, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return Tensor{ctensor: *ctensorPtr1}, Tensor{ctensor: *ctensorPtr2}, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustGru(hx Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h Tensor) {
|
||||||
|
output, h, err := ts.Gru(hx, paramsData, hasBiases, numLayers, dropout, train, bidirectional, batchFirst)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return output, h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) TopK(k int64, dim int64, largest bool, sorted bool) (ts1 Tensor, ts2 Tensor, err error) {
|
||||||
|
|
||||||
|
// NOTE: `lib.AtgTopk` will return 2 tensors in C memory. First tensor pointer
|
||||||
|
// is given by ctensorPtr1
|
||||||
|
ctensorPtr1 := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
ctensorPtr2 := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(ctensorPtr1)) + unsafe.Sizeof(ctensorPtr1)))
|
||||||
|
var clargest int32 = 0
|
||||||
|
if largest {
|
||||||
|
clargest = 1
|
||||||
|
}
|
||||||
|
var csorted int32 = 0
|
||||||
|
if sorted {
|
||||||
|
csorted = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
lib.AtgTopk(ctensorPtr1, ts.ctensor, k, dim, clargest, csorted)
|
||||||
|
err = TorchErr()
|
||||||
|
if err != nil {
|
||||||
|
return ts1, ts2, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return Tensor{ctensor: *ctensorPtr1}, Tensor{ctensor: *ctensorPtr2}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustTopK(k int64, dim int64, largest bool, sorted bool) (ts1 Tensor, ts2 Tensor) {
|
||||||
|
|
||||||
|
ts1, ts2, err := ts.TopK(k, dim, largest, sorted)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ts1, ts2
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
13035
tensor/tensor-generated.go
Normal file
13035
tensor/tensor-generated.go
Normal file
File diff suppressed because it is too large
Load Diff
|
@ -55,6 +55,7 @@ func (ts Tensor) Size() (retVal []int64, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
retVal = decodeSize(szPtr, dim)
|
retVal = decodeSize(szPtr, dim)
|
||||||
|
|
||||||
return retVal, nil
|
return retVal, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -63,6 +64,7 @@ func (ts Tensor) MustSize() (retVal []int64) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return retVal
|
return retVal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -295,33 +297,34 @@ func (ts Tensor) Device() (retVal gotch.Device, err error) {
|
||||||
return device.OfCInt(int32(cInt)), nil
|
return device.OfCInt(int32(cInt)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts Tensor) Eq1(other Tensor, del bool) (retVal Tensor, err error) {
|
/*
|
||||||
|
* func (ts Tensor) Eq1(other Tensor, del bool) (retVal Tensor, err error) {
|
||||||
// Get a C null pointer
|
*
|
||||||
// https://stackoverflow.com/a/2022369
|
* // Get a C null pointer
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
* // https://stackoverflow.com/a/2022369
|
||||||
if del {
|
* ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
defer ts.MustDrop()
|
* if del {
|
||||||
}
|
* defer ts.MustDrop()
|
||||||
|
* }
|
||||||
lib.AtgEq1(ptr, ts.ctensor, other.ctensor)
|
*
|
||||||
if err = TorchErr(); err != nil {
|
* lib.AtgEq1(ptr, ts.ctensor, other.ctensor)
|
||||||
return retVal, err
|
* if err = TorchErr(); err != nil {
|
||||||
}
|
* return retVal, err
|
||||||
|
* }
|
||||||
return Tensor{ctensor: *ptr}, nil
|
*
|
||||||
|
* return Tensor{ctensor: *ptr}, nil
|
||||||
}
|
*
|
||||||
|
* }
|
||||||
func (ts Tensor) MustEq1(other Tensor, del bool) (retVal Tensor) {
|
*
|
||||||
retVal, err := ts.Eq1(other, del)
|
* func (ts Tensor) MustEq1(other Tensor, del bool) (retVal Tensor) {
|
||||||
if err != nil {
|
* retVal, err := ts.Eq1(other, del)
|
||||||
log.Fatal(err)
|
* if err != nil {
|
||||||
}
|
* log.Fatal(err)
|
||||||
|
* }
|
||||||
return retVal
|
*
|
||||||
}
|
* return retVal
|
||||||
|
* }
|
||||||
|
* */
|
||||||
// Float64Value returns a float value on tensors holding a single element.
|
// Float64Value returns a float value on tensors holding a single element.
|
||||||
// An error is returned otherwise.
|
// An error is returned otherwise.
|
||||||
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
|
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
|
||||||
|
@ -440,7 +443,7 @@ func (ts Tensor) IsSparse() (retVal bool, err error) {
|
||||||
|
|
||||||
// ZeroGrad zeroes the gradient tensor attached to this tensor if defined.
|
// ZeroGrad zeroes the gradient tensor attached to this tensor if defined.
|
||||||
func (ts Tensor) ZeroGrad() {
|
func (ts Tensor) ZeroGrad() {
|
||||||
grad := ts.MustGrad()
|
grad := ts.MustGrad(false)
|
||||||
if grad.MustDefined() {
|
if grad.MustDefined() {
|
||||||
grad.Detach_()
|
grad.Detach_()
|
||||||
grad.Zero_()
|
grad.Zero_()
|
||||||
|
@ -1022,8 +1025,8 @@ func (r Reduction) ToInt() (retVal int) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Values returns values of tensor in a slice of float64.
|
// Float64Values returns values of tensor in a slice of float64.
|
||||||
func (ts Tensor) Values() []float64 {
|
func (ts Tensor) Float64Values() []float64 {
|
||||||
numel := ts.Numel()
|
numel := ts.Numel()
|
||||||
vec := make([]float64, numel)
|
vec := make([]float64, numel)
|
||||||
|
|
||||||
|
@ -1102,5 +1105,5 @@ func (ts Tensor) Swish() (retVal Tensor) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts Tensor) AvgPool2DDefault(ksize int64, del bool) (retVal Tensor) {
|
func (ts Tensor) AvgPool2DDefault(ksize int64, del bool) (retVal Tensor) {
|
||||||
return ts.MustAvgPool2D([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, false, true, 1, del)
|
return ts.MustAvgPool2d([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, false, true, 1, del)
|
||||||
}
|
}
|
||||||
|
|
35
tensor/tensor.go1
Normal file
35
tensor/tensor.go1
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
package tensor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
|
||||||
|
lib "github.com/sugarme/gotch/libtch"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Tensor struct {
|
||||||
|
ctensor lib.Ctensor
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) Print() {
|
||||||
|
lib.AtPrint(ts.ctensor)
|
||||||
|
if err := TorchErr(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drop drops (frees) the tensor
|
||||||
|
func (ts Tensor) Drop() (err error) {
|
||||||
|
lib.AtFree(ts.ctensor)
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustDrop drops the tensor. It will be panic if error
|
||||||
|
func (ts Tensor) MustDrop() {
|
||||||
|
if err := ts.Drop(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
|
@ -8,12 +8,26 @@ import (
|
||||||
ts "github.com/sugarme/gotch/tensor"
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestTensorInit(t *testing.T) {
|
||||||
|
tensor := ts.MustArange1(ts.IntScalar(1), ts.IntScalar(5), gotch.Int64, gotch.CPU)
|
||||||
|
|
||||||
|
tensor.Print()
|
||||||
|
|
||||||
|
want := []float64{1, 2, 3, 4}
|
||||||
|
got := tensor.Float64Values()
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(want, got) {
|
||||||
|
t.Errorf("Expected tensor values: %v\n", want)
|
||||||
|
t.Errorf("Got tensor values: %v\n", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestInplaceAssign(t *testing.T) {
|
func TestInplaceAssign(t *testing.T) {
|
||||||
tensor := ts.MustOfSlice([]int64{3, 1, 4, 1, 5})
|
tensor := ts.MustOfSlice([]int64{3, 1, 4, 1, 5})
|
||||||
|
|
||||||
tensor.Add1_(ts.IntScalar(1))
|
tensor.MustAdd1_(ts.IntScalar(1))
|
||||||
tensor.Mul1_(ts.IntScalar(2))
|
tensor.MustMul1_(ts.IntScalar(2))
|
||||||
tensor.Sub1_(ts.IntScalar(1))
|
tensor.MustSub1_(ts.IntScalar(1))
|
||||||
|
|
||||||
want := []int64{7, 3, 9, 3, 11}
|
want := []int64{7, 3, 9, 3, 11}
|
||||||
got := tensor.Vals()
|
got := tensor.Vals()
|
||||||
|
@ -83,5 +97,3 @@ func TestIter(t *testing.T) {
|
||||||
t.Errorf("Got tensor values: %v\n", got1)
|
t.Errorf("Got tensor values: %v\n", got1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: more tests
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ func anConv2d(p nn.Path, cIn, cOut, ksize, padding, stride int64) (retVal nn.Con
|
||||||
}
|
}
|
||||||
|
|
||||||
func anMaxPool2d(xs ts.Tensor, ksize, stride int64) (retVal ts.Tensor) {
|
func anMaxPool2d(xs ts.Tensor, ksize, stride int64) (retVal ts.Tensor) {
|
||||||
return xs.MustMaxPool2D([]int64{ksize, ksize}, []int64{stride, stride}, []int64{0, 0}, []int64{1, 1}, false, false)
|
return xs.MustMaxPool2d([]int64{ksize, ksize}, []int64{stride, stride}, []int64{0, 0}, []int64{1, 1}, false, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func features(p nn.Path) (retVal ts.ModuleT) {
|
func features(p nn.Path) (retVal ts.ModuleT) {
|
||||||
|
@ -68,7 +68,7 @@ func classifier(p nn.Path, nclasses int64) (retVal ts.ModuleT) {
|
||||||
seq := nn.SeqT()
|
seq := nn.SeqT()
|
||||||
|
|
||||||
seq.AddFnT(nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
seq.AddFnT(nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
||||||
return xs.MustDropout(0.5, train, false)
|
return ts.MustDropout(xs, 0.5, train)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
seq.Add(nn.NewLinear(p.Sub("1"), 256*6*6, 4096, nn.DefaultLinearConfig()))
|
seq.Add(nn.NewLinear(p.Sub("1"), 256*6*6, 4096, nn.DefaultLinearConfig()))
|
||||||
|
@ -78,7 +78,7 @@ func classifier(p nn.Path, nclasses int64) (retVal ts.ModuleT) {
|
||||||
}))
|
}))
|
||||||
|
|
||||||
seq.AddFnT(nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
seq.AddFnT(nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
||||||
return xs.MustDropout(0.5, train, false)
|
return ts.MustDropout(xs, 0.5, train)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
seq.Add(nn.NewLinear(p.Sub("4"), 4096, 4096, nn.DefaultLinearConfig()))
|
seq.Add(nn.NewLinear(p.Sub("4"), 4096, 4096, nn.DefaultLinearConfig()))
|
||||||
|
@ -98,7 +98,7 @@ func AlexNet(p nn.Path, nclasses int64) (retVal ts.ModuleT) {
|
||||||
seq.Add(features(p.Sub("features")))
|
seq.Add(features(p.Sub("features")))
|
||||||
|
|
||||||
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
||||||
tmp1 := xs.MustAdaptiveAvgPool2D([]int64{6, 6})
|
tmp1 := xs.MustAdaptiveAvgPool2d([]int64{6, 6}, false)
|
||||||
res := tmp1.FlatView()
|
res := tmp1.FlatView()
|
||||||
tmp1.MustDrop()
|
tmp1.MustDrop()
|
||||||
return res
|
return res
|
||||||
|
|
|
@ -45,8 +45,8 @@ func readFile(filename string) (imagesTs ts.Tensor, labelsTs ts.Tensor) {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
images := ts.MustZeros([]int64{samplesPerFile, cfC, cfH, cfW}, gotch.Float.CInt(), gotch.CPU.CInt())
|
images := ts.MustZeros([]int64{samplesPerFile, cfC, cfH, cfW}, gotch.Float, gotch.CPU)
|
||||||
labels := ts.MustZeros([]int64{samplesPerFile}, gotch.Int64.CInt(), gotch.CPU.CInt())
|
labels := ts.MustZeros([]int64{samplesPerFile}, gotch.Int64, gotch.CPU)
|
||||||
|
|
||||||
for idx := 0; idx < int(samplesPerFile); idx++ {
|
for idx := 0; idx < int(samplesPerFile); idx++ {
|
||||||
contentOffset := int(bytesPerImage) * idx
|
contentOffset := int(bytesPerImage) * idx
|
||||||
|
@ -101,8 +101,8 @@ func CFLoadDir(dir string) (retVal Dataset) {
|
||||||
}
|
}
|
||||||
|
|
||||||
return Dataset{
|
return Dataset{
|
||||||
TrainImages: ts.MustCat(trainImages, 0, true),
|
TrainImages: ts.MustCat(trainImages, 0),
|
||||||
TrainLabels: ts.MustCat(trainLabels, 0, true),
|
TrainLabels: ts.MustCat(trainLabels, 0),
|
||||||
TestImages: testImages,
|
TestImages: testImages,
|
||||||
TestLabels: testLabels,
|
TestLabels: testLabels,
|
||||||
Labels: 10,
|
Labels: 10,
|
||||||
|
|
|
@ -57,7 +57,7 @@ func RandomFlip(t ts.Tensor) (retVal ts.Tensor) {
|
||||||
if rand.Float64() == 1.0 {
|
if rand.Float64() == 1.0 {
|
||||||
src = tView
|
src = tView
|
||||||
} else {
|
} else {
|
||||||
src = tView.MustFlip([]int64{2})
|
src = tView.MustFlip([]int64{2}, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
tView.MustDrop()
|
tView.MustDrop()
|
||||||
|
@ -82,7 +82,7 @@ func RandomCrop(t ts.Tensor, pad int64) (retVal ts.Tensor) {
|
||||||
|
|
||||||
szH := size[2]
|
szH := size[2]
|
||||||
szW := size[3]
|
szW := size[3]
|
||||||
padded := t.MustReflectionPad2d([]int64{pad, pad, pad, pad})
|
padded := t.MustReflectionPad2d([]int64{pad, pad, pad, pad}, false)
|
||||||
output, err := t.ZerosLike(false)
|
output, err := t.ZerosLike(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
|
|
|
@ -39,7 +39,7 @@ func denseLayer(p nn.Path, cIn, bnSize, growth int64) (retVal ts.ModuleT) {
|
||||||
ys := ys5.Apply(conv2)
|
ys := ys5.Apply(conv2)
|
||||||
ys5.MustDrop()
|
ys5.MustDrop()
|
||||||
|
|
||||||
res := ts.MustCat([]ts.Tensor{xs, ys}, 1, false)
|
res := ts.MustCat([]ts.Tensor{xs, ys}, 1)
|
||||||
ys.MustDrop()
|
ys.MustDrop()
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
@ -84,7 +84,7 @@ func densenet(p nn.Path, cIn, cOut, bnSize int64, blockConfig []int64, growth in
|
||||||
|
|
||||||
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
||||||
tmp := xs.MustRelu(false)
|
tmp := xs.MustRelu(false)
|
||||||
return tmp.MustMaxPool2D([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true)
|
return tmp.MustMaxPool2d([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
nfeat := cIn
|
nfeat := cIn
|
||||||
|
@ -103,7 +103,7 @@ func densenet(p nn.Path, cIn, cOut, bnSize int64, blockConfig []int64, growth in
|
||||||
|
|
||||||
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
||||||
tmp1 := xs.MustRelu(false)
|
tmp1 := xs.MustRelu(false)
|
||||||
tmp2 := tmp1.MustAvgPool2D([]int64{7, 7}, []int64{1, 1}, []int64{0, 0}, false, true, 1, true)
|
tmp2 := tmp1.MustAvgPool2d([]int64{7, 7}, []int64{1, 1}, []int64{0, 0}, false, true, 1, true)
|
||||||
res := tmp2.FlatView()
|
res := tmp2.FlatView()
|
||||||
tmp2.MustDrop()
|
tmp2.MustDrop()
|
||||||
return res
|
return res
|
||||||
|
|
|
@ -218,7 +218,7 @@ func block(p nn.Path, args BlockArgs) (retVal ts.ModuleT) {
|
||||||
if args.SeRatio == 0 {
|
if args.SeRatio == 0 {
|
||||||
ys4 = ys3
|
ys4 = ys3
|
||||||
} else {
|
} else {
|
||||||
tmp1 := ys3.MustAdaptiveAvgPool2D([]int64{1, 1})
|
tmp1 := ys3.MustAdaptiveAvgPool2d([]int64{1, 1}, false)
|
||||||
tmp2 := tmp1.ApplyT(se, train)
|
tmp2 := tmp1.ApplyT(se, train)
|
||||||
tmp1.MustDrop()
|
tmp1.MustDrop()
|
||||||
tmp3 := tmp2.MustSigmoid(true)
|
tmp3 := tmp2.MustSigmoid(true)
|
||||||
|
@ -288,7 +288,7 @@ func efficientnet(p nn.Path, params params, nclasses int64) (retVal ts.ModuleT)
|
||||||
classifier := nn.SeqT()
|
classifier := nn.SeqT()
|
||||||
|
|
||||||
classifier.AddFnT(nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
classifier.AddFnT(nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
||||||
return xs.MustDropout(0.2, train, false)
|
return ts.MustDropout(xs, 0.2, train)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
classifier.Add(nn.NewLinear(p.Sub("_fc"), outC, nclasses, nn.DefaultLinearConfig()))
|
classifier.Add(nn.NewLinear(p.Sub("_fc"), outC, nclasses, nn.DefaultLinearConfig()))
|
||||||
|
@ -306,7 +306,7 @@ func efficientnet(p nn.Path, params params, nclasses int64) (retVal ts.ModuleT)
|
||||||
tmp5.MustDrop()
|
tmp5.MustDrop()
|
||||||
tmp7 := tmp6.Swish()
|
tmp7 := tmp6.Swish()
|
||||||
tmp6.MustDrop()
|
tmp6.MustDrop()
|
||||||
tmp8 := tmp7.MustAdaptiveAvgPool2D([]int64{1, 1})
|
tmp8 := tmp7.MustAdaptiveAvgPool2d([]int64{1, 1}, false)
|
||||||
tmp7.MustDrop()
|
tmp7.MustDrop()
|
||||||
tmp9 := tmp8.MustSqueeze1(-1, true)
|
tmp9 := tmp8.MustSqueeze1(-1, true)
|
||||||
tmp10 := tmp9.MustSqueeze1(-1, true)
|
tmp10 := tmp9.MustSqueeze1(-1, true)
|
||||||
|
|
|
@ -236,7 +236,7 @@ func (in ImageNet) LoadFromDir(path string) (retVal Dataset, err error) {
|
||||||
ntrainTs := trainTs.MustSize()[0]
|
ntrainTs := trainTs.MustSize()[0]
|
||||||
trainImages = append(trainImages, trainTs)
|
trainImages = append(trainImages, trainTs)
|
||||||
|
|
||||||
trainLabelOnes := ts.MustOnes([]int64{ntrainTs}, gotch.Int64.CInt(), gotch.CPU.CInt())
|
trainLabelOnes := ts.MustOnes([]int64{ntrainTs}, gotch.Int64, gotch.CPU)
|
||||||
trainLabels = append(trainLabels, trainLabelOnes.MustMul1(ts.IntScalar(labelIndex), true))
|
trainLabels = append(trainLabels, trainLabelOnes.MustMul1(ts.IntScalar(labelIndex), true))
|
||||||
|
|
||||||
// test
|
// test
|
||||||
|
@ -249,15 +249,15 @@ func (in ImageNet) LoadFromDir(path string) (retVal Dataset, err error) {
|
||||||
ntestTs := testTs.MustSize()[0]
|
ntestTs := testTs.MustSize()[0]
|
||||||
testImages = append(testImages, testTs)
|
testImages = append(testImages, testTs)
|
||||||
|
|
||||||
testLabelOnes := ts.MustOnes([]int64{ntestTs}, gotch.Int64.CInt(), gotch.CPU.CInt())
|
testLabelOnes := ts.MustOnes([]int64{ntestTs}, gotch.Int64, gotch.CPU)
|
||||||
testLabels = append(testLabels, testLabelOnes.MustMul1(ts.IntScalar(labelIndex), true))
|
testLabels = append(testLabels, testLabelOnes.MustMul1(ts.IntScalar(labelIndex), true))
|
||||||
}
|
}
|
||||||
|
|
||||||
return Dataset{
|
return Dataset{
|
||||||
TrainImages: ts.MustCat(trainImages, 0, true),
|
TrainImages: ts.MustCat(trainImages, 0),
|
||||||
TrainLabels: ts.MustCat(trainLabels, 0, true),
|
TrainLabels: ts.MustCat(trainLabels, 0),
|
||||||
TestImages: ts.MustCat(testImages, 0, true),
|
TestImages: ts.MustCat(testImages, 0),
|
||||||
TestLabels: ts.MustCat(testLabels, 0, true),
|
TestLabels: ts.MustCat(testLabels, 0),
|
||||||
Labels: int64(len(classes)),
|
Labels: int64(len(classes)),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -1301,8 +1301,8 @@ func (in ImageNet) Top(input ts.Tensor, k int64) (retVal []TopItem) {
|
||||||
|
|
||||||
var topItems []TopItem
|
var topItems []TopItem
|
||||||
|
|
||||||
vals := valsTs.Values()
|
vals := valsTs.Float64Values()
|
||||||
idxs := idxsTs.Values()
|
idxs := idxsTs.Float64Values()
|
||||||
|
|
||||||
for i := 0; i < int(k); i++ {
|
for i := 0; i < int(k); i++ {
|
||||||
val := vals[i]
|
val := vals[i]
|
||||||
|
|
|
@ -53,7 +53,7 @@ func convBn2(p nn.Path, cIn, cOut int64, ksize []int64, pad []int64) (retVal ts.
|
||||||
}
|
}
|
||||||
|
|
||||||
func inMaxPool2D(xs ts.Tensor, ksize, stride int64) (retVal ts.Tensor) {
|
func inMaxPool2D(xs ts.Tensor, ksize, stride int64) (retVal ts.Tensor) {
|
||||||
return xs.MustMaxPool2D([]int64{ksize, ksize}, []int64{stride, stride}, []int64{0, 0}, []int64{1, 1}, false, false)
|
return xs.MustMaxPool2d([]int64{ksize, ksize}, []int64{stride, stride}, []int64{0, 0}, []int64{1, 1}, false, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func inceptionA(p nn.Path, cIn, cPool int64) (retVal ts.ModuleT) {
|
func inceptionA(p nn.Path, cIn, cPool int64) (retVal ts.ModuleT) {
|
||||||
|
@ -78,10 +78,10 @@ func inceptionA(p nn.Path, cIn, cPool int64) (retVal ts.ModuleT) {
|
||||||
b3Ts := b3Tmp2.ApplyT(b33, train)
|
b3Ts := b3Tmp2.ApplyT(b33, train)
|
||||||
b3Tmp2.MustDrop()
|
b3Tmp2.MustDrop()
|
||||||
|
|
||||||
bpoolTmp := xs.MustAvgPool2D([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, 9, false)
|
bpoolTmp := xs.MustAvgPool2d([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, 9, false)
|
||||||
bpoolTs := bpoolTmp.ApplyT(bpool, train)
|
bpoolTs := bpoolTmp.ApplyT(bpool, train)
|
||||||
|
|
||||||
res := ts.MustCat([]ts.Tensor{b1Ts, b2Ts, b3Ts, bpoolTs}, 1, true)
|
res := ts.MustCat([]ts.Tensor{b1Ts, b2Ts, b3Ts, bpoolTs}, 1)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
})
|
})
|
||||||
|
@ -104,7 +104,7 @@ func inceptionB(p nn.Path, cIn int64) (retVal ts.ModuleT) {
|
||||||
|
|
||||||
bpoolTs := inMaxPool2D(xs, 3, 2)
|
bpoolTs := inMaxPool2D(xs, 3, 2)
|
||||||
|
|
||||||
res := ts.MustCat([]ts.Tensor{b1Ts, b2Ts, bpoolTs}, 1, true)
|
res := ts.MustCat([]ts.Tensor{b1Ts, b2Ts, bpoolTs}, 1)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
})
|
})
|
||||||
|
@ -145,10 +145,10 @@ func inceptionC(p nn.Path, cIn int64, c7 int64) (retVal ts.ModuleT) {
|
||||||
b3Ts := b3Tmp4.ApplyT(b35, train)
|
b3Ts := b3Tmp4.ApplyT(b35, train)
|
||||||
b3Tmp4.MustDrop()
|
b3Tmp4.MustDrop()
|
||||||
|
|
||||||
bpTmp1 := xs.MustAvgPool2D([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, 9, false)
|
bpTmp1 := xs.MustAvgPool2d([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, 9, false)
|
||||||
bpoolTs := bpTmp1.ApplyT(bpool, train)
|
bpoolTs := bpTmp1.ApplyT(bpool, train)
|
||||||
|
|
||||||
res = ts.MustCat([]ts.Tensor{b1Ts, b2Ts, b3Ts, bpoolTs}, 1, true)
|
res = ts.MustCat([]ts.Tensor{b1Ts, b2Ts, b3Ts, bpoolTs}, 1)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
@ -180,7 +180,7 @@ func inceptionD(p nn.Path, cIn int64) (retVal ts.ModuleT) {
|
||||||
|
|
||||||
bpoolTs := inMaxPool2D(xs, 3, 2)
|
bpoolTs := inMaxPool2D(xs, 3, 2)
|
||||||
|
|
||||||
return ts.MustCat([]ts.Tensor{b1Ts, b2Ts, bpoolTs}, 1, true)
|
return ts.MustCat([]ts.Tensor{b1Ts, b2Ts, bpoolTs}, 1)
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -205,19 +205,19 @@ func inceptionE(p nn.Path, cIn int64) (retVal ts.ModuleT) {
|
||||||
b2Tmp := xs.ApplyT(b21, train)
|
b2Tmp := xs.ApplyT(b21, train)
|
||||||
b2aTs := b2Tmp.ApplyT(b22a, train)
|
b2aTs := b2Tmp.ApplyT(b22a, train)
|
||||||
b2bTs := b2Tmp.ApplyT(b22b, train)
|
b2bTs := b2Tmp.ApplyT(b22b, train)
|
||||||
b2Ts := ts.MustCat([]ts.Tensor{b2aTs, b2bTs}, 1, true)
|
b2Ts := ts.MustCat([]ts.Tensor{b2aTs, b2bTs}, 1)
|
||||||
|
|
||||||
b3Tmp1 := xs.ApplyT(b31, train)
|
b3Tmp1 := xs.ApplyT(b31, train)
|
||||||
b3Tmp2 := b3Tmp1.ApplyT(b32, train)
|
b3Tmp2 := b3Tmp1.ApplyT(b32, train)
|
||||||
b3Tmp1.MustDrop()
|
b3Tmp1.MustDrop()
|
||||||
b3aTs := b3Tmp2.ApplyT(b33a, train)
|
b3aTs := b3Tmp2.ApplyT(b33a, train)
|
||||||
b3bTs := b3Tmp2.ApplyT(b33b, train)
|
b3bTs := b3Tmp2.ApplyT(b33b, train)
|
||||||
b3Ts := ts.MustCat([]ts.Tensor{b3aTs, b3bTs}, 1, true)
|
b3Ts := ts.MustCat([]ts.Tensor{b3aTs, b3bTs}, 1)
|
||||||
|
|
||||||
bpTmp1 := xs.MustAvgPool2D([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, 9, false)
|
bpTmp1 := xs.MustAvgPool2d([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, 9, false)
|
||||||
bpoolTs := bpTmp1.ApplyT(bpool, train)
|
bpoolTs := bpTmp1.ApplyT(bpool, train)
|
||||||
|
|
||||||
return ts.MustCat([]ts.Tensor{b1Ts, b2Ts, b3Ts, bpoolTs}, 1, true)
|
return ts.MustCat([]ts.Tensor{b1Ts, b2Ts, b3Ts, bpoolTs}, 1)
|
||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -263,10 +263,10 @@ func InceptionV3(p nn.Path, nclasses int64) (retVal ts.ModuleT) {
|
||||||
seq.Add(inceptionE(p.Sub("Mixed_7c"), 2048))
|
seq.Add(inceptionE(p.Sub("Mixed_7c"), 2048))
|
||||||
|
|
||||||
seq.AddFnT(nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
seq.AddFnT(nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
||||||
tmp1 := xs.MustAdaptiveAvgPool2D([]int64{1, 1})
|
tmp1 := xs.MustAdaptiveAvgPool2d([]int64{1, 1}, false)
|
||||||
tmp2 := tmp1.MustDropout(0.5, train, true)
|
tmp2 := ts.MustDropout(tmp1, 0.5, train)
|
||||||
|
tmp1.MustDrop()
|
||||||
res := tmp2.FlatView()
|
res := tmp2.FlatView()
|
||||||
tmp2.MustDrop()
|
|
||||||
return res
|
return res
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
|
@ -109,7 +109,7 @@ func MobileNetV2(p nn.Path, nclasses int64) (retVal ts.ModuleT) {
|
||||||
classifier := nn.SeqT()
|
classifier := nn.SeqT()
|
||||||
|
|
||||||
classifier.AddFnT(nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
classifier.AddFnT(nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
||||||
return xs.MustDropout(0.5, train, false)
|
return ts.MustDropout(xs, 0.5, train)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
classifier.Add(nn.NewLinear(cp.Sub("1"), 1280, nclasses, nn.DefaultLinearConfig()))
|
classifier.Add(nn.NewLinear(cp.Sub("1"), 1280, nclasses, nn.DefaultLinearConfig()))
|
||||||
|
|
|
@ -92,7 +92,7 @@ func resnet(path nn.Path, nclasses int64, c1, c2, c3, c4 int64) (retVal nn.FuncT
|
||||||
bn1 := c1.ApplyT(bn1, train)
|
bn1 := c1.ApplyT(bn1, train)
|
||||||
c1.MustDrop()
|
c1.MustDrop()
|
||||||
relu := bn1.MustRelu(true)
|
relu := bn1.MustRelu(true)
|
||||||
maxpool := relu.MustMaxPool2D([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true)
|
maxpool := relu.MustMaxPool2d([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true)
|
||||||
l1 := maxpool.ApplyT(layer1, train)
|
l1 := maxpool.ApplyT(layer1, train)
|
||||||
l2 := l1.ApplyT(layer2, train)
|
l2 := l1.ApplyT(layer2, train)
|
||||||
l1.MustDrop()
|
l1.MustDrop()
|
||||||
|
@ -100,7 +100,7 @@ func resnet(path nn.Path, nclasses int64, c1, c2, c3, c4 int64) (retVal nn.FuncT
|
||||||
l2.MustDrop()
|
l2.MustDrop()
|
||||||
l4 := l3.ApplyT(layer4, train)
|
l4 := l3.ApplyT(layer4, train)
|
||||||
l3.MustDrop()
|
l3.MustDrop()
|
||||||
avgpool := l4.MustAdaptiveAvgPool2D([]int64{1, 1})
|
avgpool := l4.MustAdaptiveAvgPool2d([]int64{1, 1}, false)
|
||||||
l4.MustDrop()
|
l4.MustDrop()
|
||||||
fv := avgpool.FlatView()
|
fv := avgpool.FlatView()
|
||||||
avgpool.MustDrop()
|
avgpool.MustDrop()
|
||||||
|
@ -118,7 +118,7 @@ func resnet(path nn.Path, nclasses int64, c1, c2, c3, c4 int64) (retVal nn.FuncT
|
||||||
bn1 := c1.ApplyT(bn1, train)
|
bn1 := c1.ApplyT(bn1, train)
|
||||||
c1.MustDrop()
|
c1.MustDrop()
|
||||||
relu := bn1.MustRelu(true)
|
relu := bn1.MustRelu(true)
|
||||||
maxpool := relu.MustMaxPool2D([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true)
|
maxpool := relu.MustMaxPool2d([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true)
|
||||||
l1 := maxpool.ApplyT(layer1, train)
|
l1 := maxpool.ApplyT(layer1, train)
|
||||||
maxpool.MustDrop()
|
maxpool.MustDrop()
|
||||||
l2 := l1.ApplyT(layer2, train)
|
l2 := l1.ApplyT(layer2, train)
|
||||||
|
@ -127,7 +127,7 @@ func resnet(path nn.Path, nclasses int64, c1, c2, c3, c4 int64) (retVal nn.FuncT
|
||||||
l2.MustDrop()
|
l2.MustDrop()
|
||||||
l4 := l3.ApplyT(layer4, train)
|
l4 := l3.ApplyT(layer4, train)
|
||||||
l3.MustDrop()
|
l3.MustDrop()
|
||||||
avgpool := l4.MustAdaptiveAvgPool2D([]int64{1, 1})
|
avgpool := l4.MustAdaptiveAvgPool2d([]int64{1, 1}, false)
|
||||||
l4.MustDrop()
|
l4.MustDrop()
|
||||||
retVal = avgpool.FlatView()
|
retVal = avgpool.FlatView()
|
||||||
avgpool.MustDrop()
|
avgpool.MustDrop()
|
||||||
|
@ -215,7 +215,7 @@ func bottleneckResnet(path nn.Path, nclasses int64, c1, c2, c3, c4 int64) (retVa
|
||||||
bn1 := c1.ApplyT(bn1, train)
|
bn1 := c1.ApplyT(bn1, train)
|
||||||
c1.MustDrop()
|
c1.MustDrop()
|
||||||
relu := bn1.MustRelu(true)
|
relu := bn1.MustRelu(true)
|
||||||
maxpool := relu.MustMaxPool2D([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true)
|
maxpool := relu.MustMaxPool2d([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true)
|
||||||
l1 := maxpool.ApplyT(layer1, train)
|
l1 := maxpool.ApplyT(layer1, train)
|
||||||
l2 := l1.ApplyT(layer2, train)
|
l2 := l1.ApplyT(layer2, train)
|
||||||
l1.MustDrop()
|
l1.MustDrop()
|
||||||
|
@ -223,7 +223,7 @@ func bottleneckResnet(path nn.Path, nclasses int64, c1, c2, c3, c4 int64) (retVa
|
||||||
l2.MustDrop()
|
l2.MustDrop()
|
||||||
l4 := l3.ApplyT(layer4, train)
|
l4 := l3.ApplyT(layer4, train)
|
||||||
l3.MustDrop()
|
l3.MustDrop()
|
||||||
avgpool := l4.MustAdaptiveAvgPool2D([]int64{1, 1})
|
avgpool := l4.MustAdaptiveAvgPool2d([]int64{1, 1}, false)
|
||||||
l4.MustDrop()
|
l4.MustDrop()
|
||||||
fv := avgpool.FlatView()
|
fv := avgpool.FlatView()
|
||||||
avgpool.MustDrop()
|
avgpool.MustDrop()
|
||||||
|
@ -239,7 +239,7 @@ func bottleneckResnet(path nn.Path, nclasses int64, c1, c2, c3, c4 int64) (retVa
|
||||||
bn1 := c1.ApplyT(bn1, train)
|
bn1 := c1.ApplyT(bn1, train)
|
||||||
c1.MustDrop()
|
c1.MustDrop()
|
||||||
relu := bn1.MustRelu(true)
|
relu := bn1.MustRelu(true)
|
||||||
maxpool := relu.MustMaxPool2D([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true)
|
maxpool := relu.MustMaxPool2d([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true)
|
||||||
l1 := maxpool.ApplyT(layer1, train)
|
l1 := maxpool.ApplyT(layer1, train)
|
||||||
maxpool.MustDrop()
|
maxpool.MustDrop()
|
||||||
l2 := l1.ApplyT(layer2, train)
|
l2 := l1.ApplyT(layer2, train)
|
||||||
|
@ -248,7 +248,7 @@ func bottleneckResnet(path nn.Path, nclasses int64, c1, c2, c3, c4 int64) (retVa
|
||||||
l2.MustDrop()
|
l2.MustDrop()
|
||||||
l4 := l3.ApplyT(layer4, train)
|
l4 := l3.ApplyT(layer4, train)
|
||||||
l3.MustDrop()
|
l3.MustDrop()
|
||||||
avgpool := l4.MustAdaptiveAvgPool2D([]int64{1, 1})
|
avgpool := l4.MustAdaptiveAvgPool2d([]int64{1, 1}, false)
|
||||||
l4.MustDrop()
|
l4.MustDrop()
|
||||||
retVal = avgpool.FlatView()
|
retVal = avgpool.FlatView()
|
||||||
avgpool.MustDrop()
|
avgpool.MustDrop()
|
||||||
|
|
|
@ -8,7 +8,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func snMaxPool2D(xs ts.Tensor) (retVal ts.Tensor) {
|
func snMaxPool2D(xs ts.Tensor) (retVal ts.Tensor) {
|
||||||
return xs.MustMaxPool2D([]int64{3, 3}, []int64{2, 2}, []int64{0, 0}, []int64{1, 1}, true, false)
|
return xs.MustMaxPool2d([]int64{3, 3}, []int64{2, 2}, []int64{0, 0}, []int64{1, 1}, true, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func fire(p nn.Path, cIn int64, cSqueeze int64, cExp1 int64, cExp3 int64) (retVal ts.ModuleT) {
|
func fire(p nn.Path, cIn int64, cSqueeze int64, cExp1 int64, cExp3 int64) (retVal ts.ModuleT) {
|
||||||
|
@ -31,7 +31,7 @@ func fire(p nn.Path, cIn int64, cSqueeze int64, cExp1 int64, cExp3 int64) (retVa
|
||||||
exp3Tmp := tmp2.Apply(exp3)
|
exp3Tmp := tmp2.Apply(exp3)
|
||||||
exp3Ts := exp3Tmp.MustRelu(true)
|
exp3Ts := exp3Tmp.MustRelu(true)
|
||||||
|
|
||||||
return ts.MustCat([]ts.Tensor{exp1Ts, exp3Ts}, 1, true)
|
return ts.MustCat([]ts.Tensor{exp1Ts, exp3Ts}, 1)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,14 +119,14 @@ func squeezenet(p nn.Path, v1_0 bool, nclasses int64) (retVal ts.ModuleT) {
|
||||||
}
|
}
|
||||||
|
|
||||||
features.AddFnT(nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
features.AddFnT(nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
||||||
return xs.MustDropout(0.5, train, false)
|
return ts.MustDropout(xs, 0.5, train)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
features.Add(nn.NewConv2D(cp.Sub("1"), 512, nclasses, 1, finalConvConfig))
|
features.Add(nn.NewConv2D(cp.Sub("1"), 512, nclasses, 1, finalConvConfig))
|
||||||
|
|
||||||
features.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
features.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
||||||
tmp1 := xs.MustRelu(false)
|
tmp1 := xs.MustRelu(false)
|
||||||
tmp2 := tmp1.MustAdaptiveAvgPool2D([]int64{1, 1})
|
tmp2 := tmp1.MustAdaptiveAvgPool2d([]int64{1, 1}, false)
|
||||||
tmp1.MustDrop()
|
tmp1.MustDrop()
|
||||||
res := tmp2.FlatView()
|
res := tmp2.FlatView()
|
||||||
tmp2.MustDrop()
|
tmp2.MustDrop()
|
||||||
|
|
|
@ -101,7 +101,7 @@ func vgg(path nn.Path, config [][]int64, nclasses int64, batchNorm bool) nn.Sequ
|
||||||
}))
|
}))
|
||||||
|
|
||||||
seq.AddFn(nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
seq.AddFn(nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
||||||
return xs.MustDropout(0.5, train, false)
|
return ts.MustDropout(xs, 0.5, train)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
seq.Add(nn.NewLinear(c.Sub(fmt.Sprint("3")), 4096, 4096, nn.DefaultLinearConfig()))
|
seq.Add(nn.NewLinear(c.Sub(fmt.Sprint("3")), 4096, 4096, nn.DefaultLinearConfig()))
|
||||||
|
@ -111,7 +111,7 @@ func vgg(path nn.Path, config [][]int64, nclasses int64, batchNorm bool) nn.Sequ
|
||||||
}))
|
}))
|
||||||
|
|
||||||
seq.AddFn(nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
seq.AddFn(nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
||||||
return xs.MustDropout(0.5, train, false)
|
return ts.MustDropout(xs, 0.5, train)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
seq.Add(nn.NewLinear(c.Sub(fmt.Sprint("6")), 4096, nclasses, nn.DefaultLinearConfig()))
|
seq.Add(nn.NewLinear(c.Sub(fmt.Sprint("6")), 4096, nclasses, nn.DefaultLinearConfig()))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user