docs | ||
example | ||
gen | ||
libtch | ||
nn | ||
tensor | ||
vision | ||
.gitignore | ||
.travis.yml | ||
CHANGELOG.md | ||
device.go | ||
dtype.go | ||
dune-project | ||
go.mod | ||
go.sum | ||
LICENSE | ||
README.md | ||
setup-cpu.sh | ||
setup-gpu.sh | ||
setup.sh |
GoTch
Overview
- GoTch is a C++ Libtorch Go binding for developing and implementing deep learning projects in Go.
- This package provides a thin wrapper around Libtorch that exposes its tensor APIs and CUDA support while keeping the Go side as idiomatic as possible.
- There are about 1404 auto-generated tensor APIs.
Dependencies
- Libtorch C++ v1.7.0 library of Pytorch
Installation
-
CPU
Default values: LIBTORCH_VER=1.7.0 and GOTCH_VER=v0.3.2
go get -u github.com/sugarme/gotch@v0.3.2
bash ${GOPATH}/pkg/mod/github.com/sugarme/gotch@v0.3.2/setup-cpu.sh
-
GPU
Default values: LIBTORCH_VER=1.7.0, CUDA_VER=10.1 and GOTCH_VER=v0.3.2
go get -u github.com/sugarme/gotch@v0.3.2
bash ${GOPATH}/pkg/mod/github.com/sugarme/gotch@v0.3.2/setup-gpu.sh
Examples
Basic tensor operations
import (
"fmt"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
)
// basicOps demonstrates a few tensor operations: printing a random tensor,
// a matrix multiplication, and an in-place scalar addition. Every tensor
// created here is released with MustDrop before the function returns.
func basicOps() {
	xs := ts.MustRand([]int64{3, 5, 6}, gotch.Float, gotch.CPU)
	// Release xs on return, like every other tensor in this function.
	defer xs.MustDrop()
	fmt.Printf("%8.3f\n", xs)
	fmt.Printf("%i", xs)
	// output
	/*
	   (1,.,.) =
	   0.391 0.055 0.638 0.514 0.757 0.446
	   0.817 0.075 0.437 0.452 0.077 0.492
	   0.504 0.945 0.863 0.243 0.254 0.640
	   0.850 0.132 0.763 0.572 0.216 0.116
	   0.410 0.660 0.156 0.336 0.885 0.391
	   (2,.,.) =
	   0.952 0.731 0.380 0.390 0.374 0.001
	   0.455 0.142 0.088 0.039 0.862 0.939
	   0.621 0.198 0.728 0.914 0.168 0.057
	   0.655 0.231 0.680 0.069 0.803 0.243
	   0.853 0.729 0.983 0.534 0.749 0.624
	   (3,.,.) =
	   0.734 0.447 0.914 0.956 0.269 0.000
	   0.427 0.034 0.477 0.535 0.440 0.972
	   0.407 0.945 0.099 0.184 0.778 0.058
	   0.482 0.996 0.085 0.605 0.282 0.671
	   0.887 0.029 0.005 0.216 0.354 0.262
	   TENSOR INFO:
	   Shape: [3 5 6]
	   DType: float32
	   Device: {CPU 1}
	   Defined: true
	*/

	// Basic tensor operations
	ts1 := ts.MustArange(ts.IntScalar(6), gotch.Int64, gotch.CPU).MustView([]int64{2, 3}, true)
	defer ts1.MustDrop()
	ts2 := ts.MustOnes([]int64{3, 4}, gotch.Int64, gotch.CPU)
	defer ts2.MustDrop()

	mul := ts1.MustMatmul(ts2, false)
	defer mul.MustDrop()

	fmt.Println("ts1: ")
	ts1.Print()
	fmt.Println("ts2: ")
	ts2.Print()
	fmt.Println("mul tensor (ts1 x ts2): ")
	mul.Print()
	//ts1:
	// 0 1 2
	// 3 4 5
	//[ CPULongType{2,3} ]
	//ts2:
	// 1 1 1 1
	// 1 1 1 1
	// 1 1 1 1
	//[ CPULongType{3,4} ]
	//mul tensor (ts1 x ts2):
	// 3 3 3 3
	// 12 12 12 12
	//[ CPULongType{2,4} ]

	// In-place operation
	ts3 := ts.MustOnes([]int64{2, 3}, gotch.Float, gotch.CPU)
	fmt.Println("Before:")
	ts3.Print()
	ts3.MustAdd1_(ts.FloatScalar(2.0))
	fmt.Printf("After (ts3 + 2.0): \n")
	ts3.Print()
	ts3.MustDrop()
	//Before:
	// 1 1 1
	// 1 1 1
	//[ CPUFloatType{2,3} ]
	//After (ts3 + 2.0):
	// 3 3 3
	// 3 3 3
	//[ CPUFloatType{2,3} ]
}
Simplified Convolutional neural network
import (
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
ts "github.com/sugarme/gotch/tensor"
)
// Net is the example model: two 2D convolution layers followed by one
// linear layer.
type Net struct {
	conv1, conv2 nn.Conv2D
	fc           nn.Linear
}
// newNet builds a Net whose layers are created under the variable-store
// path vs. Layers are constructed in the order conv1, conv2, fc.
// NOTE(review): the numeric arguments are presumably (input size, output
// size, kernel size) for the conv layers and (input size, output size) for
// the linear layer; confirm against the nn package documentation.
func newNet(vs nn.Path) Net {
	return Net{
		conv1: nn.NewConv2D(vs, 1, 16, 2, nn.DefaultConv2DConfig()),
		conv2: nn.NewConv2D(vs, 16, 10, 2, nn.DefaultConv2DConfig()),
		fc:    nn.NewLinear(vs, 10, 10, nn.DefaultLinearConfig()),
	}
}
// ForwardT runs the forward pass of the network on xs and returns the
// ReLU-activated output of the final linear layer. The train argument is
// accepted but not used by this body.
//
// The caller keeps ownership of xs; every intermediate tensor created here
// is released before returning, so only the returned tensor remains for the
// caller to drop.
// NOTE(review): the trailing boolean on MustView / MaxPool2DDefault /
// MustRelu is presumably a "drop the receiver" flag; confirm against the
// tensor package documentation.
func (n Net) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
	// Bind the reshaped view to its own local instead of overwriting xs, so
	// it can be released; the original left this tensor undropped.
	view := xs.MustView([]int64{-1, 1, 8, 8}, false)
	defer view.MustDrop()

	outC1 := view.Apply(n.conv1)
	outMP1 := outC1.MaxPool2DDefault(2, true)
	defer outMP1.MustDrop()

	outC2 := outMP1.Apply(n.conv2)
	outMP2 := outC2.MaxPool2DDefault(2, true)
	outView2 := outMP2.MustView([]int64{-1, 10}, true)
	defer outView2.MustDrop()

	outFC := outView2.Apply(&n.fc)
	return outFC.MustRelu(true)
}
// main builds the example network on a CPU variable store, runs one forward
// pass (train=false) on an 8x8 float tensor of ones, and prints the result.
func main() {
	store := nn.NewVarStore(gotch.CPU)
	model := newNet(store.Root())

	input := ts.MustOnes([]int64{8, 8}, gotch.Float, gotch.CPU)
	out := model.ForwardT(input, false)
	out.Print()
}
// 0.0000 0.0000 0.0000 0.2477 0.2437 0.0000 0.0000 0.0000 0.0000 0.0171
//[ CPUFloatType{1,10} ]
- Real application examples can be found in the example folder.
Getting Started
-
See pkg.go.dev for detailed API documentation.
License
GoTch is Apache 2.0 licensed.
Acknowledgement
- This project was inspired by, and uses many concepts from, tch-rs, the Libtorch Rust binding.