added CudaIfAvailable() and switch back to install libtorch inside gotch; bump new version

This commit is contained in:
sugarme 2020-12-31 00:32:09 +11:00
parent 105434ed22
commit 8eb03fbcf6
7 changed files with 81 additions and 336 deletions

View File

@ -65,11 +65,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- nn/sequential: fixed missing case number of layers = 1 causing panic
- nn/varstore: fixed(nn/varstore): fixed nil pointer at LoadPartial due to not break loop
# [0.3.4]
## [0.3.4]
### Added
- [#4] Automatically download and install Libtorch and setup environment variables.
## [0.3.5]
### Added
- Added function `gotch.CudaIfAvailable()`. NOTE that: `device := gotch.NewCuda().CudaIfAvailable()` will throw error if CUDA is not available.
### Changed
- Switched back to install libtorch inside gotch library as go init() function is triggered after cgo called.
[#10]: https://github.com/sugarme/gotch/issues/10
[#6]: https://github.com/sugarme/gotch/issues/6
[#4]: https://github.com/sugarme/gotch/issues/4

View File

@ -22,14 +22,22 @@ Gotch is in active development mode and may have API breaking changes. Feel free
## Installation
When installing `gotch` with `go get`, it will automatically download and install corresponding `Libtorch` and cache at `$HOME/.cache/gotch/LIBTORCH_VERSION` directory.
- Default CUDA version is `10.1` if CUDA is available otherwise using CPU version.
- Default Pytorch C++ API version is `1.7.0`
```bash
go get github.com/sugarme/gotch@v0.3.4
wget https://github.com/sugarme/gotch/blob/v0.3.5/setup.sh
chmod +x setup.sh
# Default
sudo bash setup.sh
# Specify CUDA version
export CUDA_VER=YOUR_PC_CUDA_VERSION && sudo bash setup.sh
# CPU
export CUDA_VER=cpu && sudo bash setup.sh
```

View File

@ -114,3 +114,13 @@ func (d Device) IsCuda() bool {
return true
}
// CudaIfAvailable returns a GPU device if available, else CPU.
func CudaIfAvailable() Device {
switch {
case CUDA.IsAvailable():
return CudaBuilder(0)
default:
return CPU
}
}

View File

@ -1,194 +0,0 @@
package gotch
import (
"archive/zip"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"runtime"
"strings"
)
// downloadFile downloads file from URL and stores it in local filepath.
// It writes to the destination file as it downloads it, without loading
// the entire file into memory. An `io.TeeReader` is passed into Copy()
// to report progress on the download.
func downloadFile(url string, filepath string) error {
// Create the file with .tmp extension, so that we won't overwrite a
// file until it's downloaded fully
out, err := os.Create(filepath + ".tmp")
if err != nil {
return err
}
defer out.Close()
// Get the data
resp, err := http.Get(url)
if err != nil {
return err
}
defer resp.Body.Close()
// Create our bytes counter and pass it to be used alongside our writer
counter := &writeCounter{}
_, err = io.Copy(out, io.TeeReader(resp.Body, counter))
if err != nil {
return err
}
// The progress use the same line so print a new line once it's finished downloading
fmt.Println()
// Rename the tmp file back to the original file
err = os.Rename(filepath+".tmp", filepath)
if err != nil {
return err
}
return nil
}
// writeCounter counts the number of bytes written to it. By implementing the Write method,
// it is of the io.Writer interface and we can pass this into io.TeeReader()
// Every write to this writer, will print the progress of the file write.
type writeCounter struct {
Total uint64
}
func (wc *writeCounter) Write(p []byte) (int, error) {
n := len(p)
wc.Total += uint64(n)
wc.printProgress()
return n, nil
}
// PrintProgress prints the progress of a file write
func (wc writeCounter) printProgress() {
// Clear the line by using a character return to go back to the start and remove
// the remaining characters by filling it with spaces
fmt.Printf("\r%s", strings.Repeat(" ", 50))
// Return again and print current status of download
fmt.Printf("\rDownloading... %s complete", byteCountIEC(wc.Total))
}
// byteCountIEC converts bytes to human-readable string in binary (IEC) format.
func byteCountIEC(b uint64) string {
const unit = 1024
if b < unit {
return fmt.Sprintf("%d B", b)
}
div, exp := uint64(unit), 0
for n := b / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB",
float64(b)/float64(div), "KMGTPE"[exp])
}
func copyFile(src, dst string) error {
sourceFileStat, err := os.Stat(src)
if err != nil {
return err
}
if !sourceFileStat.Mode().IsRegular() {
return fmt.Errorf("%s is not a regular file", src)
}
source, err := os.Open(src)
if err != nil {
return err
}
defer source.Close()
destination, err := os.Create(dst)
if err != nil {
return err
}
defer destination.Close()
_, err = io.Copy(destination, source)
return err
}
// Unzip unzips file
// Ref. https://stackoverflow.com/questions/20357223
func Unzip(src, dest string) error {
r, err := zip.OpenReader(src)
if err != nil {
return err
}
defer func() {
if err := r.Close(); err != nil {
panic(err)
}
}()
os.MkdirAll(dest, 0755)
// Closure to address file descriptors issue with all the deferred .Close() methods
extractAndWriteFile := func(f *zip.File) error {
rc, err := f.Open()
if err != nil {
return err
}
defer func() {
if err := rc.Close(); err != nil {
panic(err)
}
}()
path := filepath.Join(dest, f.Name)
// Check for ZipSlip (Directory traversal)
if !strings.HasPrefix(path, filepath.Clean(dest)+string(os.PathSeparator)) {
return fmt.Errorf("illegal file path: %s", path)
}
if f.FileInfo().IsDir() {
os.MkdirAll(path, f.Mode())
} else {
os.MkdirAll(filepath.Dir(path), f.Mode())
f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
if err != nil {
return err
}
defer func() {
if err := f.Close(); err != nil {
panic(err)
}
}()
_, err = io.Copy(f, rc)
if err != nil {
return err
}
}
return nil
}
for _, f := range r.File {
err := extractAndWriteFile(f)
if err != nil {
return err
}
}
return nil
}
// CurrDir returns absolute path of the file contains code calling this function.
func CurrDir() string {
// pc, file, line, ok := runtime.Caller(1)
_, file, _, ok := runtime.Caller(1)
if !ok {
log.Fatal("Runtime Caller error: cannot get information of current caller.")
}
return filepath.Dir(file)
}

119
init.go
View File

@ -1,119 +0,0 @@
package gotch
import (
// "flag"
"fmt"
"log"
"os"
"os/exec"
"os/user"
"path"
"strings"
)
var (
GotchCacheDir string
TorchDir string
IsCudaAvailable bool = false
HomeDir string
)
func init() {
if v, ok := os.LookupEnv("TORCH_VERSION"); ok {
TorchVersion = v
}
if v, ok := os.LookupEnv("TORCH_DIR"); ok {
TorchDir = v
}
IsCudaAvailable = isCudaAvailable()
if !isCudaAvailable() {
CudaVersion = "cpu"
}
HomeDir = homeDir()
GotchCacheDir = fmt.Sprintf("%v/.cache/gotch", HomeDir)
InstallLibtorch()
}
// isCudaAvailable check whether cuda is installed using 'nvcc'
func isCudaAvailable() bool {
_, err := exec.Command("nvcc", "--version").Output()
if err != nil {
log.Printf("CUDA is not detected using 'nvcc --version'\n")
return false
}
return true
}
// homeDir returns home directory of current user.
func homeDir() string {
usr, err := user.Current()
if err != nil {
log.Fatal(err)
}
return usr.HomeDir
}
func InstallLibtorch() {
dev := strings.ReplaceAll(CudaVersion, ".", "")
cu := "cpu"
if !strings.Contains(dev, "cpu") {
cu = fmt.Sprintf("cu%v", dev)
}
url := fmt.Sprintf("https://download.pytorch.org/libtorch/%v/libtorch-cxx11-abi-shared-with-deps-%v%%2B%v.zip", cu, TorchVersion, cu)
// Create dir if not exist
TorchDir = fmt.Sprintf("%v/libtorch-%v-%v", GotchCacheDir, TorchVersion, cu)
if _, err := os.Stat(TorchDir); os.IsNotExist(err) {
os.MkdirAll(TorchDir, 0755)
}
// install libtorch if not done yet.
installLibtorch(url)
// Export to current context so that shell script can catch them
libtorch := fmt.Sprintf("%v/libtorch", TorchDir)
os.Setenv("GOTCH_LIBTORCH", libtorch)
os.Setenv("GOTCH_CUDA_VERSION", CudaVersion)
os.Setenv("GOTCH_VER", GotchVersion)
envFile := fmt.Sprintf("%v/env.sh", CurrDir())
_, err := exec.Command(envFile).Output()
if err != nil {
log.Fatal(err)
}
}
func installLibtorch(url string) {
filename := path.Base(url)
cachedFileCandidate := fmt.Sprintf("%s/%v", TorchDir, filename)
// Check whether zip file exists otherwise, download it
if _, err := os.Stat(cachedFileCandidate); err == nil {
// zip file exists
// check one file 'libc10.so' if exists, assuming libtorch has been installed.
libc10 := fmt.Sprintf("%v/libtorch/lib/libc10.so", TorchDir)
if _, err := os.Stat(libc10); err == nil {
return
}
} else {
// download zip file
err := downloadFile(url, cachedFileCandidate)
if err != nil {
log.Fatal(err)
}
}
// Unzip
log.Printf("Unzipping...")
err := Unzip(cachedFileCandidate, TorchDir)
if err != nil {
log.Fatal(err)
}
}
func Hello() {
fmt.Println("Hello")
}

View File

@ -1,11 +1,19 @@
#!/bin/bash
# Catch info from Go exported environment variables
GOTCH_VERSION="${GOTCH_VER:-NotSpecified}"
LIBTORCH="${GOTCH_LIBTORCH:-NotSpecified}" # Libtorch root path
CUDA_VERSION="${GOTCH_CUDA_VERSION:-NotSpecified}" # e.g 10.1; cpu
GOTCH_VERSION="${GOTCH_VER:-v0.3.5}"
LIBTORCH_VERSION="${LIBTORCH_VER:-1.7.0}"
CUDA_VERSION="${CUDA_VER:-10.1}"
if [ $CUDA_VERSION == "cpu" ]
then
CU_VERSION="cpu"
else
CU_VERSION="cu${CUDA_VERSION//./}"
fi
# Libtorch paths
GOTCH="$GOPATH/pkg/mod/github.com/sugarme/gotch@$GOTCH_VERSION"
LIBTORCH="$GOPATH/pkg/mod/github.com/sugarme/gotch@$GOTCH_VERSION/libtch/libtorch"
LIBRARY_PATH="$LIBTORCH/lib"
# CPATH="$CPATH:$LIBTORCH/lib:$LIBTORCH/include:$LIBTORCH/include/torch/csrc/api/include"
CPATH="$LIBTORCH/lib:$LIBTORCH/include:$LIBTORCH/include/torch/csrc/api/include"
@ -14,19 +22,48 @@ if [ $CUDA_VERSION == "cpu" ]
then
# LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$LIBTORCH/lib"
LD_LIBRARY_PATH="$LIBTORCH/lib"
# prepare C lib for CPU version
sudo rm $GOTCH/libtch/lib.go
sudo cp $GOTCH/libtch/lib.go.cpu $GOTCH/libtch/lib.go
sudo mv $GOTCH/libtch/dummy_cuda_dependency.cpp $GOTCH/libtch/dummy_cuda_dependency.cpp.gpu
sudo mv $GOTCH/libtch/fake_cuda_dependency.cpp.cpu $GOTCH/libtch/fake_cuda_dependency.cpp
else
# LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$LIBTORCH/lib:/usr/lib64-nvidia:/usr/local/cuda-${CUDA_VERSION}/lib64"
LD_LIBRARY_PATH="$LIBTORCH/lib:/usr/lib64-nvidia:/usr/local/cuda-${CUDA_VERSION}/lib64"
fi
#update .bashrc
FILE="$HOME/.bashrc"
# Update current shell environment variables for newly installed Libtorch
export LIBRARY_PATH=$LIBRARY_PATH
export CPATH=$CPATH
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH
# Install gotch
#==============
cwd=$(pwd)
mkdir /tmp/gotch-test
cd /tmp/gotch-test
go mod init "github.com/sugarme/gotch-test"
go get "github.com/sugarme/gotch@$GOTCH_VERSION"
rm -rf /tmp/gotch-test
cd $cwd
if [ $CUDA_VERSION == "cpu" ]
then
# prepare C lib for CPU version
sudo rm $GOTCH/libtch/lib.go
sudo cp $GOTCH/libtch/lib.go.cpu $GOTCH/libtch/lib.go
sudo mv $GOTCH/libtch/dummy_cuda_dependency.cpp $GOTCH/libtch/dummy_cuda_dependency.cpp.gpu
sudo mv $GOTCH/libtch/fake_cuda_dependency.cpp.cpu $GOTCH/libtch/fake_cuda_dependency.cpp
fi
# Install Libtorch
#=================
LIBTORCH_ZIP="libtorch-cxx11-abi-shared-with-deps-${LIBTORCH_VERSION}%2B${CU_VERSION}.zip"
LIBTORCH_URL="https://download.pytorch.org/libtorch/${CU_VERSION}/${LIBTORCH_ZIP}"
wget -O "/tmp/$LIBTORCH_ZIP" "$LIBTORCH_URL"
# delete old libtorch if existing
sudo rm -rf $LIBTORCH
sudo unzip "/tmp/$LIBTORCH_ZIP" -d $GOTCH/libtch
rm "/tmp/$LIBTORCH_ZIP"
# Update .bashrc
#===============
FILE="$HOME/.bashrc-test"
LN_GOTCH_LIBTORCH="export GOTCH_LIBTORCH=$LIBTORCH"
LN_LIBRARY_PATH="export LIBRARY_PATH=$LIBRARY_PATH"
LN_CPATH="export CPATH=$CPATH"
@ -42,3 +79,4 @@ grep -qF 'export LD_LIBRARY_PATH' "$FILE" && sed -i 's|^export LD_LIBRARY_PATH.*
# refresh environment for all next opening shells.
exec "$BASH"

View File

@ -1,7 +0,0 @@
package gotch
var (
GotchVersion string = "v0.3.4"
TorchVersion string = "1.7.0"
CudaVersion string = "10.1"
)