resolved #4 and bump version
This commit is contained in:
parent
66efd1583a
commit
105434ed22
12
CHANGELOG.md
12
CHANGELOG.md
|
@ -59,12 +59,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
### Added
|
||||
- [#6]: Go native tensor print using `fmt.Formatter` interface. Now, a tensor can be printed out like: `fmt.Printf("%.3f", tensor)` (for float type)
|
||||
|
||||
[#10]: https://github.com/sugarme/gotch/issues/10
|
||||
[#6]: https://github.com/sugarme/gotch/issues/6
|
||||
|
||||
## [0.3.3]
|
||||
|
||||
### Fixed
|
||||
- nn/sequential: fixed missing case number of layers = 1 causing panic
|
||||
- nn/varstore: fixed(nn/varstore): fixed nil pointer at LoadPartial due to not break loop
|
||||
|
||||
# [0.3.4]
|
||||
|
||||
### Added
|
||||
- [#4] Automatically download and install Libtorch and setup environment variables.
|
||||
|
||||
[#10]: https://github.com/sugarme/gotch/issues/10
|
||||
[#6]: https://github.com/sugarme/gotch/issues/6
|
||||
[#4]: https://github.com/sugarme/gotch/issues/4
|
||||
|
||||
|
|
32
README.md
32
README.md
|
@ -1,9 +1,9 @@
|
|||
# GoTch [![License](https://img.shields.io/:license-apache-blue.svg)](https://opensource.org/licenses/Apache-2.0)[![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-007d9c?logo=go&logoColor=white&style=flat-square)](https://pkg.go.dev/github.com/sugarme/gotch?tab=doc)[![Travis CI](https://api.travis-ci.org/sugarme/gotch.svg?branch=master)](https://travis-ci.org/sugarme/gotch)[![Go Report Card](https://goreportcard.com/badge/github.com/sugarme/gotch)](https://goreportcard.com/report/github.com/sugarme/gotch)
|
||||
# Gotch [![License](https://img.shields.io/:license-apache-blue.svg)](https://opensource.org/licenses/Apache-2.0)[![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-007d9c?logo=go&logoColor=white&style=flat-square)](https://pkg.go.dev/github.com/sugarme/gotch?tab=doc)[![Travis CI](https://api.travis-ci.org/sugarme/gotch.svg?branch=master)](https://travis-ci.org/sugarme/gotch)[![Go Report Card](https://goreportcard.com/badge/github.com/sugarme/gotch)](https://goreportcard.com/report/github.com/sugarme/gotch)
|
||||
|
||||
|
||||
## Overview
|
||||
|
||||
GoTch creates a thin wrapper to Pytorch C++ APIs (Libtorch) to make use of its already optimized C++ tensor APIs (~ over 1400) and dynamic graph computation with CUDA support and provides idiomatic Go APIs for developing and implementing Deep Learning in Go.
|
||||
Gotch creates a thin wrapper to Pytorch C++ APIs (Libtorch) to make use of its already optimized C++ tensor APIs (~ over 1400) and dynamic graph computation with CUDA support and provides idiomatic Go APIs for developing and implementing Deep Learning in Go.
|
||||
|
||||
**Some features are**
|
||||
- [x] Comprehensive Pytorch tensor APIs (~ 1404)
|
||||
|
@ -14,34 +14,24 @@ GoTch creates a thin wrapper to Pytorch C++ APIs (Libtorch) to make use of its a
|
|||
- [x] Most recent image models
|
||||
- [ ] NLP Language models - [Transformer](https://github.com/sugarme/transformer) in separate package built with GoTch and [pure Go Tokenizer](https://github.com/sugarme/tokenizer).
|
||||
|
||||
GoTch is in active development mode and may have API breaking changes. Feel free to pull request, report issues or discuss any concerns. All contributions are welcome.
|
||||
Gotch is in active development mode and may have API breaking changes. Feel free to pull request, report issues or discuss any concerns. All contributions are welcome.
|
||||
|
||||
## Dependencies
|
||||
|
||||
- **Libtorch** C++ v1.7.0 library of [Pytorch](https://pytorch.org/)
|
||||
|
||||
|
||||
## Installation
|
||||
|
||||
- **CPU**
|
||||
When installing `gotch` with `go get`, it will automatically download and install corresponding `Libtorch` and cache at `$HOME/.cache/gotch/LIBTORCH_VERSION` directory.
|
||||
|
||||
Default values: `LIBTORCH_VER=1.7.0` and `GOTCH_VER=v0.3.2`
|
||||
- Default CUDA version is `10.1` if CUDA is available otherwise using CPU version.
|
||||
- Default Pytorch C++ API version is `1.7.0`
|
||||
|
||||
```bash
|
||||
go get -u github.com/sugarme/gotch@v0.3.2
|
||||
bash ${GOPATH}/pkg/mod/github.com/sugarme/gotch@v0.3.2/setup-cpu.sh
|
||||
```bash
|
||||
|
||||
```
|
||||
go get github.com/sugarme/gotch@v0.3.4
|
||||
|
||||
- **GPU**
|
||||
|
||||
Default values: `LIBTORCH_VER=1.7.0`, `CUDA_VER=10.1` and `GOTCH_VER=v0.3.2`
|
||||
|
||||
```bash
|
||||
go get -u github.com/sugarme/gotch@v0.3.2
|
||||
bash ${GOPATH}/pkg/mod/github.com/sugarme/gotch@v0.3.2/setup-gpu.sh
|
||||
|
||||
```
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
|
@ -198,7 +188,7 @@ func main() {
|
|||
//Logits: 0.000 0.000 0.000 0.225 0.321 0.147 0.000 0.207 0.000 0.000
|
||||
```
|
||||
|
||||
## Play with GoTch on Google Colab or locally
|
||||
## Play with `gotch` on Google Colab or locally
|
||||
|
||||
- [Tensor Initiation](example/basic) <a href="https://colab.research.google.com/github/sugarme/nb/blob/master/tensor/tensor-initiation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
|
||||
- [Tensor Indexing](example/basic) <a href="https://colab.research.google.com/github/sugarme/nb/blob/master/tensor/tensor-indexing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
|
||||
|
@ -219,7 +209,7 @@ More coming soon...
|
|||
|
||||
## License
|
||||
|
||||
GoTch is Apache 2.0 licensed.
|
||||
Gotch is Apache 2.0 licensed.
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
|
|
44
env.sh
Executable file
44
env.sh
Executable file
|
@ -0,0 +1,44 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Catch info from Go exported environment variables
|
||||
GOTCH_VERSION="${GOTCH_VER:-NotSpecified}"
|
||||
LIBTORCH="${GOTCH_LIBTORCH:-NotSpecified}" # Libtorch root path
|
||||
CUDA_VERSION="${GOTCH_CUDA_VERSION:-NotSpecified}" # e.g 10.1; cpu
|
||||
|
||||
GOTCH="$GOPATH/pkg/mod/github.com/sugarme/gotch@$GOTCH_VERSION"
|
||||
LIBRARY_PATH="$LIBTORCH/lib"
|
||||
# CPATH="$CPATH:$LIBTORCH/lib:$LIBTORCH/include:$LIBTORCH/include/torch/csrc/api/include"
|
||||
CPATH="$LIBTORCH/lib:$LIBTORCH/include:$LIBTORCH/include/torch/csrc/api/include"
|
||||
|
||||
if [ $CUDA_VERSION == "cpu" ]
|
||||
then
|
||||
# LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$LIBTORCH/lib"
|
||||
LD_LIBRARY_PATH="$LIBTORCH/lib"
|
||||
|
||||
# prepare C lib for CPU version
|
||||
sudo rm $GOTCH/libtch/lib.go
|
||||
sudo cp $GOTCH/libtch/lib.go.cpu $GOTCH/libtch/lib.go
|
||||
sudo mv $GOTCH/libtch/dummy_cuda_dependency.cpp $GOTCH/libtch/dummy_cuda_dependency.cpp.gpu
|
||||
sudo mv $GOTCH/libtch/fake_cuda_dependency.cpp.cpu $GOTCH/libtch/fake_cuda_dependency.cpp
|
||||
else
|
||||
# LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$LIBTORCH/lib:/usr/lib64-nvidia:/usr/local/cuda-${CUDA_VERSION}/lib64"
|
||||
LD_LIBRARY_PATH="$LIBTORCH/lib:/usr/lib64-nvidia:/usr/local/cuda-${CUDA_VERSION}/lib64"
|
||||
fi
|
||||
|
||||
#update .bashrc
|
||||
FILE="$HOME/.bashrc"
|
||||
LN_GOTCH_LIBTORCH="export GOTCH_LIBTORCH=$LIBTORCH"
|
||||
LN_LIBRARY_PATH="export LIBRARY_PATH=$LIBRARY_PATH"
|
||||
LN_CPATH="export CPATH=$CPATH"
|
||||
LN_LD_LIBRARY_PATH="export LD_LIBRARY_PATH=$LD_LIBRARY_PATH"
|
||||
|
||||
# replace line if matching pattern otherwise, insert a new line to the bottom.
|
||||
# -qF quiet, plain text
|
||||
grep -qF 'export GOTCH_LIBTORCH' "$FILE" && sed -i 's|^export GOTCH_LIBTORCH.*|'"$LN_GOTCH_LIBTORCH"'|g' "$FILE" || echo "$LN_GOTCH_LIBTORCH" >> "$FILE"
|
||||
grep -qF 'export LIBRARY_PATH' "$FILE" && sed -i 's|^export LIBRARY_PATH.*|'"$LN_LIBRARY_PATH"'|g' "$FILE" || echo "$LN_LIBRARY_PATH" >> "$FILE"
|
||||
grep -qF 'export CPATH' "$FILE" && sed -i 's|^export CPATH.*|'"$LN_CPATH"'|g' "$FILE" || echo "$LN_CPATH" >> "$FILE"
|
||||
grep -qF 'export LD_LIBRARY_PATH' "$FILE" && sed -i 's|^export LD_LIBRARY_PATH.*|'"$LN_LD_LIBRARY_PATH"'|g' "$FILE" || echo "$LN_LD_LIBRARY_PATH" >> "$FILE"
|
||||
|
||||
# refresh environment for all next opening shells.
|
||||
exec "$BASH"
|
||||
|
194
file-util.go
Normal file
194
file-util.go
Normal file
|
@ -0,0 +1,194 @@
|
|||
package gotch
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// downloadFile downloads file from URL and stores it in local filepath.
|
||||
// It writes to the destination file as it downloads it, without loading
|
||||
// the entire file into memory. An `io.TeeReader` is passed into Copy()
|
||||
// to report progress on the download.
|
||||
func downloadFile(url string, filepath string) error {
|
||||
|
||||
// Create the file with .tmp extension, so that we won't overwrite a
|
||||
// file until it's downloaded fully
|
||||
out, err := os.Create(filepath + ".tmp")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
// Get the data
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Create our bytes counter and pass it to be used alongside our writer
|
||||
counter := &writeCounter{}
|
||||
_, err = io.Copy(out, io.TeeReader(resp.Body, counter))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// The progress use the same line so print a new line once it's finished downloading
|
||||
fmt.Println()
|
||||
|
||||
// Rename the tmp file back to the original file
|
||||
err = os.Rename(filepath+".tmp", filepath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeCounter counts the number of bytes written to it. By implementing the Write method,
|
||||
// it is of the io.Writer interface and we can pass this into io.TeeReader()
|
||||
// Every write to this writer, will print the progress of the file write.
|
||||
type writeCounter struct {
|
||||
Total uint64
|
||||
}
|
||||
|
||||
func (wc *writeCounter) Write(p []byte) (int, error) {
|
||||
n := len(p)
|
||||
wc.Total += uint64(n)
|
||||
wc.printProgress()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// PrintProgress prints the progress of a file write
|
||||
func (wc writeCounter) printProgress() {
|
||||
// Clear the line by using a character return to go back to the start and remove
|
||||
// the remaining characters by filling it with spaces
|
||||
fmt.Printf("\r%s", strings.Repeat(" ", 50))
|
||||
|
||||
// Return again and print current status of download
|
||||
fmt.Printf("\rDownloading... %s complete", byteCountIEC(wc.Total))
|
||||
}
|
||||
|
||||
// byteCountIEC converts bytes to human-readable string in binary (IEC) format.
|
||||
func byteCountIEC(b uint64) string {
|
||||
const unit = 1024
|
||||
if b < unit {
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
div, exp := uint64(unit), 0
|
||||
for n := b / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %ciB",
|
||||
float64(b)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
func copyFile(src, dst string) error {
|
||||
sourceFileStat, err := os.Stat(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !sourceFileStat.Mode().IsRegular() {
|
||||
return fmt.Errorf("%s is not a regular file", src)
|
||||
}
|
||||
|
||||
source, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer source.Close()
|
||||
|
||||
destination, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer destination.Close()
|
||||
_, err = io.Copy(destination, source)
|
||||
return err
|
||||
}
|
||||
|
||||
// Unzip unzips file
|
||||
// Ref. https://stackoverflow.com/questions/20357223
|
||||
func Unzip(src, dest string) error {
|
||||
r, err := zip.OpenReader(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := r.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
os.MkdirAll(dest, 0755)
|
||||
|
||||
// Closure to address file descriptors issue with all the deferred .Close() methods
|
||||
extractAndWriteFile := func(f *zip.File) error {
|
||||
rc, err := f.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := rc.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
path := filepath.Join(dest, f.Name)
|
||||
|
||||
// Check for ZipSlip (Directory traversal)
|
||||
if !strings.HasPrefix(path, filepath.Clean(dest)+string(os.PathSeparator)) {
|
||||
return fmt.Errorf("illegal file path: %s", path)
|
||||
}
|
||||
|
||||
if f.FileInfo().IsDir() {
|
||||
os.MkdirAll(path, f.Mode())
|
||||
} else {
|
||||
os.MkdirAll(filepath.Dir(path), f.Mode())
|
||||
f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := f.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = io.Copy(f, rc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, f := range r.File {
|
||||
err := extractAndWriteFile(f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CurrDir returns absolute path of the file contains code calling this function.
|
||||
func CurrDir() string {
|
||||
// pc, file, line, ok := runtime.Caller(1)
|
||||
_, file, _, ok := runtime.Caller(1)
|
||||
if !ok {
|
||||
log.Fatal("Runtime Caller error: cannot get information of current caller.")
|
||||
}
|
||||
|
||||
return filepath.Dir(file)
|
||||
}
|
119
init.go
Normal file
119
init.go
Normal file
|
@ -0,0 +1,119 @@
|
|||
package gotch
|
||||
|
||||
import (
|
||||
// "flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
GotchCacheDir string
|
||||
TorchDir string
|
||||
IsCudaAvailable bool = false
|
||||
HomeDir string
|
||||
)
|
||||
|
||||
func init() {
|
||||
if v, ok := os.LookupEnv("TORCH_VERSION"); ok {
|
||||
TorchVersion = v
|
||||
}
|
||||
if v, ok := os.LookupEnv("TORCH_DIR"); ok {
|
||||
TorchDir = v
|
||||
}
|
||||
|
||||
IsCudaAvailable = isCudaAvailable()
|
||||
if !isCudaAvailable() {
|
||||
CudaVersion = "cpu"
|
||||
}
|
||||
|
||||
HomeDir = homeDir()
|
||||
GotchCacheDir = fmt.Sprintf("%v/.cache/gotch", HomeDir)
|
||||
|
||||
InstallLibtorch()
|
||||
}
|
||||
|
||||
// isCudaAvailable check whether cuda is installed using 'nvcc'
|
||||
func isCudaAvailable() bool {
|
||||
_, err := exec.Command("nvcc", "--version").Output()
|
||||
if err != nil {
|
||||
log.Printf("CUDA is not detected using 'nvcc --version'\n")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// homeDir returns home directory of current user.
|
||||
func homeDir() string {
|
||||
usr, err := user.Current()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return usr.HomeDir
|
||||
}
|
||||
|
||||
func InstallLibtorch() {
|
||||
dev := strings.ReplaceAll(CudaVersion, ".", "")
|
||||
cu := "cpu"
|
||||
if !strings.Contains(dev, "cpu") {
|
||||
cu = fmt.Sprintf("cu%v", dev)
|
||||
}
|
||||
url := fmt.Sprintf("https://download.pytorch.org/libtorch/%v/libtorch-cxx11-abi-shared-with-deps-%v%%2B%v.zip", cu, TorchVersion, cu)
|
||||
|
||||
// Create dir if not exist
|
||||
TorchDir = fmt.Sprintf("%v/libtorch-%v-%v", GotchCacheDir, TorchVersion, cu)
|
||||
if _, err := os.Stat(TorchDir); os.IsNotExist(err) {
|
||||
os.MkdirAll(TorchDir, 0755)
|
||||
}
|
||||
// install libtorch if not done yet.
|
||||
installLibtorch(url)
|
||||
|
||||
// Export to current context so that shell script can catch them
|
||||
libtorch := fmt.Sprintf("%v/libtorch", TorchDir)
|
||||
os.Setenv("GOTCH_LIBTORCH", libtorch)
|
||||
os.Setenv("GOTCH_CUDA_VERSION", CudaVersion)
|
||||
os.Setenv("GOTCH_VER", GotchVersion)
|
||||
|
||||
envFile := fmt.Sprintf("%v/env.sh", CurrDir())
|
||||
_, err := exec.Command(envFile).Output()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func installLibtorch(url string) {
|
||||
filename := path.Base(url)
|
||||
cachedFileCandidate := fmt.Sprintf("%s/%v", TorchDir, filename)
|
||||
|
||||
// Check whether zip file exists otherwise, download it
|
||||
if _, err := os.Stat(cachedFileCandidate); err == nil {
|
||||
// zip file exists
|
||||
// check one file 'libc10.so' if exists, assuming libtorch has been installed.
|
||||
libc10 := fmt.Sprintf("%v/libtorch/lib/libc10.so", TorchDir)
|
||||
if _, err := os.Stat(libc10); err == nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// download zip file
|
||||
err := downloadFile(url, cachedFileCandidate)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Unzip
|
||||
log.Printf("Unzipping...")
|
||||
err := Unzip(cachedFileCandidate, TorchDir)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Hello() {
|
||||
fmt.Println("Hello")
|
||||
}
|
35
setup-cpu.sh
35
setup-cpu.sh
|
@ -1,35 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Env
|
||||
GOTCH_VERSION="${GOTCH_VER:-v0.3.3}"
|
||||
LIBTORCH_VERSION="${LIBTORCH_VER:-1.7.0}"
|
||||
|
||||
GOTCH="$GOPATH/pkg/mod/github.com/sugarme/gotch@$GOTCH_VERSION"
|
||||
LIBTORCH="$GOPATH/pkg/mod/github.com/sugarme/gotch@$GOTCH_VERSION/libtch/libtorch"
|
||||
LIBRARY_PATH="$LIBTORCH/lib"
|
||||
CPATH="$LIBTORCH/lib:$LIBTORCH/include:$LIBTORCH/include/torch/csrc:$LIBTORCH/include/torch/csrc/api/include"
|
||||
LD_LIBRARY_PATH="$LIBTORCH/lib"
|
||||
|
||||
# Precompiled libtorch
|
||||
sudo rm -rf $LIBTORCH
|
||||
sudo mkdir -p $LIBTORCH
|
||||
wget -O /tmp/libtorch-cxx11-abi-shared-with-deps-${LIBTORCH_VERSION}+cpu.zip https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-${LIBTORCH_VERSION}%2Bcpu.zip
|
||||
sudo unzip /tmp/libtorch-cxx11-abi-shared-with-deps-${LIBTORCH_VERSION}+cpu.zip -d $GOTCH/libtch
|
||||
|
||||
# Update .bashrc
|
||||
FILE="$HOME/.bashrc"
|
||||
LN_LIBTORCH="export LIBTORCH=$GOPATH/pkg/mod/github.com/sugarme/gotch@$GOTCH_VERSION/libtch/libtorch"
|
||||
LN_LIBRARY_PATH="export LIBRARY_PATH=$LIBTORCH/lib"
|
||||
LN_CPATH="export CPATH=$LIBTORCH/lib:$LIBTORCH/include:$LIBTORCH/include/torch/csrc:$LIBTORCH/include/torch/csrc/api/include"
|
||||
LN_LD_LIBRARY_PATH="export LD_LIBRARY_PATH=$LIBTORCH/lib:$LD_LIBRARY_PATH"
|
||||
sudo grep -xqF -- "$LN_LIBTORCH" "$FILE" || sudo echo "$LN_LIBTORCH" >> "$FILE"
|
||||
sudo grep -xqF -- "$LN_LIBRARY_PATH" "$FILE" || sudo echo "$LN_LIBRARY_PATH" >> "$FILE"
|
||||
sudo grep -xqF -- "$LN_CPATH" "$FILE" || sudo echo "$LN_CPATH" >> "$FILE"
|
||||
sudo grep -xqF -- "$LN_LD_LIBRARY_PATH" "$FILE" || sudo echo "$LN_LD_LIBRARY_PATH" >> "$FILE"
|
||||
|
||||
sudo rm $GOTCH/libtch/lib.go
|
||||
sudo cp $GOTCH/libtch/lib.go.cpu $GOTCH/libtch/lib.go
|
||||
sudo mv $GOTCH/libtch/dummy_cuda_dependency.cpp $GOTCH/libtch/dummy_cuda_dependency.cpp.gpu
|
||||
sudo mv $GOTCH/libtch/fake_cuda_dependency.cpp.cpu $GOTCH/libtch/fake_cuda_dependency.cpp
|
||||
|
||||
source "$FILE"
|
32
setup-gpu.sh
32
setup-gpu.sh
|
@ -1,32 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
GOTCH_VERSION="${GOTCH_VER:-v0.3.3}"
|
||||
LIBTORCH_VERSION="${LIBTORCH_VER:-1.7.0}"
|
||||
CUDA_VERSION="${CUDA_VER:-10.1}"
|
||||
CU_VERSION="${CUDA_VERSION//./}"
|
||||
|
||||
GOTCH="$GOPATH/pkg/mod/github.com/sugarme/gotch@$GOTCH_VERSION"
|
||||
LIBTORCH="$GOPATH/pkg/mod/github.com/sugarme/gotch@$GOTCH_VERSION/libtch/libtorch"
|
||||
LIBRARY_PATH="$LIBTORCH/lib"
|
||||
CPATH="$CPATH:$LIBTORCH/lib:$LIBTORCH/include:$LIBTORCH/include/torch/csrc/api/include"
|
||||
LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$LIBTORCH/lib:/usr/lib64-nvidia:/usr/local/cuda-${CUDA_VERSION}/lib64"
|
||||
|
||||
sudo rm -rf $LIBTORCH
|
||||
sudo mkdir -p $LIBTORCH
|
||||
|
||||
|
||||
wget -O /tmp/libtorch-cxx11-abi-shared-with-deps-${LIBTORCH_VERSION}%2Bcu${CU_VERSION}.zip https://download.pytorch.org/libtorch/cu${CU_VERSION}/libtorch-cxx11-abi-shared-with-deps-${LIBTORCH_VERSION}%2Bcu${CU_VERSION}.zip
|
||||
sudo unzip /tmp/libtorch-cxx11-abi-shared-with-deps-${LIBTORCH_VERSION}%2Bcu${CU_VERSION}.zip -d $GOTCH/libtch
|
||||
|
||||
# update .bashrc
|
||||
FILE="$HOME/.bashrc"
|
||||
LN_LIBTORCH="export LIBTORCH=$GOPATH/pkg/mod/github.com/sugarme/gotch@$GOTCH_VERSION/libtch/libtorch"
|
||||
LN_LIBRARY_PATH="export LIBRARY_PATH=$LIBTORCH/lib"
|
||||
LN_CPATH="export CPATH=$CPATH:$LIBTORCH/lib:$LIBTORCH/include:$LIBTORCH/include/torch/csrc/api/include"
|
||||
LN_LD_LIBRARY_PATH="export LD_LIBRARY_PATH=$LIBTORCH/lib:/usr/lib64-nvidia:/usr/local/cuda-${CUDA_VERSION}/lib64"
|
||||
sudo grep -xqF -- "$LN_LIBTORCH" "$FILE" || sudo echo "$LN_LIBTORCH" >> "$FILE"
|
||||
sudo grep -xqF -- "$LN_LIBRARY_PATH" "$FILE" || sudo echo "$LN_LIBRARY_PATH" >> "$FILE"
|
||||
sudo grep -xqF -- "$LN_CPATH" "$FILE" || sudo echo "$LN_CPATH" >> "$FILE"
|
||||
sudo grep -xqF -- "$LN_LD_LIBRARY_PATH" "$FILE" || sudo echo "$LN_LD_LIBRARY_PATH" >> "$FILE"
|
||||
|
||||
source "$FILE"
|
18
setup.sh
18
setup.sh
|
@ -1,18 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
export GOTCH_VERSION="${GOTCH_VER:-v0.3.2}"
|
||||
export LIBTORCH_VERSION="${LIBTORCH_VER:-1.7.0}"
|
||||
export CUDA_VERSION="${CUDA_VER:-10.1}"
|
||||
export CU_VERSION="${CUDA_VERSION//./}"
|
||||
|
||||
export GOTCH="$HOME/projects/sugarme/gotch"
|
||||
export LIBTORCH="$HOME/projects/sugarme/gotch/libtch/libtorch"
|
||||
export LIBRARY_PATH="$LIBTORCH/lib"
|
||||
export CPATH="$LIBTORCH/lib:$LIBTORCH/include:$LIBTORCH/include/torch/csrc/api/include"
|
||||
export LD_LIBRARY_PATH="$LIBTORCH/lib:/usr/lib64-nvidia:/usr/local/cuda-${CUDA_VERSION}/lib64"
|
||||
|
||||
sudo rm -rf $LIBTORCH
|
||||
sudo mkdir -p $LIBTORCH
|
||||
|
||||
wget -O /tmp/libtorch-cxx11-abi-shared-with-deps-${LIBTORCH_VERSION}%2Bcu${CU_VERSION}.zip https://download.pytorch.org/libtorch/cu${CU_VERSION}/libtorch-cxx11-abi-shared-with-deps-${LIBTORCH_VERSION}%2Bcu${CU_VERSION}.zip
|
||||
sudo unzip /tmp/libtorch-cxx11-abi-shared-with-deps-${LIBTORCH_VERSION}%2Bcu${CU_VERSION}.zip -d $GOTCH/libtch
|
7
version.go
Normal file
7
version.go
Normal file
|
@ -0,0 +1,7 @@
|
|||
package gotch
|
||||
|
||||
var (
|
||||
GotchVersion string = "v0.3.4"
|
||||
TorchVersion string = "1.7.0"
|
||||
CudaVersion string = "10.1"
|
||||
)
|
Loading…
Reference in New Issue
Block a user