added CudaIfAvailable() and switch back to install libtorch inside gotch; bump new version
This commit is contained in:
parent
105434ed22
commit
8eb03fbcf6
11
CHANGELOG.md
11
CHANGELOG.md
|
@ -65,11 +65,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- nn/sequential: fixed missing case number of layers = 1 causing panic
|
||||
- nn/varstore: fixed(nn/varstore): fixed nil pointer at LoadPartial due to not break loop
|
||||
|
||||
# [0.3.4]
|
||||
## [0.3.4]
|
||||
|
||||
### Added
|
||||
- [#4] Automatically download and install Libtorch and setup environment variables.
|
||||
|
||||
## [0.3.5]
|
||||
|
||||
### Added
|
||||
- Added function `gotch.CudaIfAvailable()`. NOTE that: `device := gotch.NewCuda().CudaIfAvailable()` will throw error if CUDA is not available.
|
||||
|
||||
### Changed
|
||||
- Switched back to install libtorch inside gotch library as go init() function is triggered after cgo called.
|
||||
|
||||
|
||||
[#10]: https://github.com/sugarme/gotch/issues/10
|
||||
[#6]: https://github.com/sugarme/gotch/issues/6
|
||||
[#4]: https://github.com/sugarme/gotch/issues/4
|
||||
|
|
14
README.md
14
README.md
|
@ -22,14 +22,22 @@ Gotch is in active development mode and may have API breaking changes. Feel free
|
|||
|
||||
## Installation
|
||||
|
||||
When installing `gotch` with `go get`, it will automatically download and install corresponding `Libtorch` and cache at `$HOME/.cache/gotch/LIBTORCH_VERSION` directory.
|
||||
|
||||
- Default CUDA version is `10.1` if CUDA is available otherwise using CPU version.
|
||||
- Default Pytorch C++ API version is `1.7.0`
|
||||
|
||||
```bash
|
||||
|
||||
go get github.com/sugarme/gotch@v0.3.4
|
||||
wget https://github.com/sugarme/gotch/blob/v0.3.5/setup.sh
|
||||
chmod +x setup.sh
|
||||
|
||||
# Default
|
||||
sudo bash setup.sh
|
||||
|
||||
# Specify CUDA version
|
||||
export CUDA_VER=YOUR_PC_CUDA_VERSION && sudo bash setup.sh
|
||||
|
||||
# CPU
|
||||
export CUDA_VER=cpu && sudo bash setup.sh
|
||||
|
||||
```
|
||||
|
||||
|
|
10
device.go
10
device.go
|
@ -114,3 +114,13 @@ func (d Device) IsCuda() bool {
|
|||
|
||||
return true
|
||||
}
|
||||
|
||||
// CudaIfAvailable returns a GPU device if available, else CPU.
|
||||
func CudaIfAvailable() Device {
|
||||
switch {
|
||||
case CUDA.IsAvailable():
|
||||
return CudaBuilder(0)
|
||||
default:
|
||||
return CPU
|
||||
}
|
||||
}
|
||||
|
|
194
file-util.go
194
file-util.go
|
@ -1,194 +0,0 @@
|
|||
package gotch
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// downloadFile downloads file from URL and stores it in local filepath.
|
||||
// It writes to the destination file as it downloads it, without loading
|
||||
// the entire file into memory. An `io.TeeReader` is passed into Copy()
|
||||
// to report progress on the download.
|
||||
func downloadFile(url string, filepath string) error {
|
||||
|
||||
// Create the file with .tmp extension, so that we won't overwrite a
|
||||
// file until it's downloaded fully
|
||||
out, err := os.Create(filepath + ".tmp")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
// Get the data
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Create our bytes counter and pass it to be used alongside our writer
|
||||
counter := &writeCounter{}
|
||||
_, err = io.Copy(out, io.TeeReader(resp.Body, counter))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// The progress use the same line so print a new line once it's finished downloading
|
||||
fmt.Println()
|
||||
|
||||
// Rename the tmp file back to the original file
|
||||
err = os.Rename(filepath+".tmp", filepath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeCounter counts the number of bytes written to it. By implementing the Write method,
|
||||
// it is of the io.Writer interface and we can pass this into io.TeeReader()
|
||||
// Every write to this writer, will print the progress of the file write.
|
||||
type writeCounter struct {
|
||||
Total uint64
|
||||
}
|
||||
|
||||
func (wc *writeCounter) Write(p []byte) (int, error) {
|
||||
n := len(p)
|
||||
wc.Total += uint64(n)
|
||||
wc.printProgress()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// PrintProgress prints the progress of a file write
|
||||
func (wc writeCounter) printProgress() {
|
||||
// Clear the line by using a character return to go back to the start and remove
|
||||
// the remaining characters by filling it with spaces
|
||||
fmt.Printf("\r%s", strings.Repeat(" ", 50))
|
||||
|
||||
// Return again and print current status of download
|
||||
fmt.Printf("\rDownloading... %s complete", byteCountIEC(wc.Total))
|
||||
}
|
||||
|
||||
// byteCountIEC converts bytes to human-readable string in binary (IEC) format.
|
||||
func byteCountIEC(b uint64) string {
|
||||
const unit = 1024
|
||||
if b < unit {
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
div, exp := uint64(unit), 0
|
||||
for n := b / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %ciB",
|
||||
float64(b)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
func copyFile(src, dst string) error {
|
||||
sourceFileStat, err := os.Stat(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !sourceFileStat.Mode().IsRegular() {
|
||||
return fmt.Errorf("%s is not a regular file", src)
|
||||
}
|
||||
|
||||
source, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer source.Close()
|
||||
|
||||
destination, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer destination.Close()
|
||||
_, err = io.Copy(destination, source)
|
||||
return err
|
||||
}
|
||||
|
||||
// Unzip unzips file
|
||||
// Ref. https://stackoverflow.com/questions/20357223
|
||||
func Unzip(src, dest string) error {
|
||||
r, err := zip.OpenReader(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := r.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
os.MkdirAll(dest, 0755)
|
||||
|
||||
// Closure to address file descriptors issue with all the deferred .Close() methods
|
||||
extractAndWriteFile := func(f *zip.File) error {
|
||||
rc, err := f.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := rc.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
path := filepath.Join(dest, f.Name)
|
||||
|
||||
// Check for ZipSlip (Directory traversal)
|
||||
if !strings.HasPrefix(path, filepath.Clean(dest)+string(os.PathSeparator)) {
|
||||
return fmt.Errorf("illegal file path: %s", path)
|
||||
}
|
||||
|
||||
if f.FileInfo().IsDir() {
|
||||
os.MkdirAll(path, f.Mode())
|
||||
} else {
|
||||
os.MkdirAll(filepath.Dir(path), f.Mode())
|
||||
f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := f.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = io.Copy(f, rc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, f := range r.File {
|
||||
err := extractAndWriteFile(f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CurrDir returns absolute path of the file contains code calling this function.
|
||||
func CurrDir() string {
|
||||
// pc, file, line, ok := runtime.Caller(1)
|
||||
_, file, _, ok := runtime.Caller(1)
|
||||
if !ok {
|
||||
log.Fatal("Runtime Caller error: cannot get information of current caller.")
|
||||
}
|
||||
|
||||
return filepath.Dir(file)
|
||||
}
|
119
init.go
119
init.go
|
@ -1,119 +0,0 @@
|
|||
package gotch
|
||||
|
||||
import (
|
||||
// "flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
GotchCacheDir string
|
||||
TorchDir string
|
||||
IsCudaAvailable bool = false
|
||||
HomeDir string
|
||||
)
|
||||
|
||||
func init() {
|
||||
if v, ok := os.LookupEnv("TORCH_VERSION"); ok {
|
||||
TorchVersion = v
|
||||
}
|
||||
if v, ok := os.LookupEnv("TORCH_DIR"); ok {
|
||||
TorchDir = v
|
||||
}
|
||||
|
||||
IsCudaAvailable = isCudaAvailable()
|
||||
if !isCudaAvailable() {
|
||||
CudaVersion = "cpu"
|
||||
}
|
||||
|
||||
HomeDir = homeDir()
|
||||
GotchCacheDir = fmt.Sprintf("%v/.cache/gotch", HomeDir)
|
||||
|
||||
InstallLibtorch()
|
||||
}
|
||||
|
||||
// isCudaAvailable check whether cuda is installed using 'nvcc'
|
||||
func isCudaAvailable() bool {
|
||||
_, err := exec.Command("nvcc", "--version").Output()
|
||||
if err != nil {
|
||||
log.Printf("CUDA is not detected using 'nvcc --version'\n")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// homeDir returns home directory of current user.
|
||||
func homeDir() string {
|
||||
usr, err := user.Current()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return usr.HomeDir
|
||||
}
|
||||
|
||||
func InstallLibtorch() {
|
||||
dev := strings.ReplaceAll(CudaVersion, ".", "")
|
||||
cu := "cpu"
|
||||
if !strings.Contains(dev, "cpu") {
|
||||
cu = fmt.Sprintf("cu%v", dev)
|
||||
}
|
||||
url := fmt.Sprintf("https://download.pytorch.org/libtorch/%v/libtorch-cxx11-abi-shared-with-deps-%v%%2B%v.zip", cu, TorchVersion, cu)
|
||||
|
||||
// Create dir if not exist
|
||||
TorchDir = fmt.Sprintf("%v/libtorch-%v-%v", GotchCacheDir, TorchVersion, cu)
|
||||
if _, err := os.Stat(TorchDir); os.IsNotExist(err) {
|
||||
os.MkdirAll(TorchDir, 0755)
|
||||
}
|
||||
// install libtorch if not done yet.
|
||||
installLibtorch(url)
|
||||
|
||||
// Export to current context so that shell script can catch them
|
||||
libtorch := fmt.Sprintf("%v/libtorch", TorchDir)
|
||||
os.Setenv("GOTCH_LIBTORCH", libtorch)
|
||||
os.Setenv("GOTCH_CUDA_VERSION", CudaVersion)
|
||||
os.Setenv("GOTCH_VER", GotchVersion)
|
||||
|
||||
envFile := fmt.Sprintf("%v/env.sh", CurrDir())
|
||||
_, err := exec.Command(envFile).Output()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func installLibtorch(url string) {
|
||||
filename := path.Base(url)
|
||||
cachedFileCandidate := fmt.Sprintf("%s/%v", TorchDir, filename)
|
||||
|
||||
// Check whether zip file exists otherwise, download it
|
||||
if _, err := os.Stat(cachedFileCandidate); err == nil {
|
||||
// zip file exists
|
||||
// check one file 'libc10.so' if exists, assuming libtorch has been installed.
|
||||
libc10 := fmt.Sprintf("%v/libtorch/lib/libc10.so", TorchDir)
|
||||
if _, err := os.Stat(libc10); err == nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// download zip file
|
||||
err := downloadFile(url, cachedFileCandidate)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Unzip
|
||||
log.Printf("Unzipping...")
|
||||
err := Unzip(cachedFileCandidate, TorchDir)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Hello() {
|
||||
fmt.Println("Hello")
|
||||
}
|
|
@ -1,11 +1,19 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Catch info from Go exported environment variables
|
||||
GOTCH_VERSION="${GOTCH_VER:-NotSpecified}"
|
||||
LIBTORCH="${GOTCH_LIBTORCH:-NotSpecified}" # Libtorch root path
|
||||
CUDA_VERSION="${GOTCH_CUDA_VERSION:-NotSpecified}" # e.g 10.1; cpu
|
||||
GOTCH_VERSION="${GOTCH_VER:-v0.3.5}"
|
||||
LIBTORCH_VERSION="${LIBTORCH_VER:-1.7.0}"
|
||||
CUDA_VERSION="${CUDA_VER:-10.1}"
|
||||
|
||||
if [ $CUDA_VERSION == "cpu" ]
|
||||
then
|
||||
CU_VERSION="cpu"
|
||||
else
|
||||
CU_VERSION="cu${CUDA_VERSION//./}"
|
||||
fi
|
||||
|
||||
# Libtorch paths
|
||||
GOTCH="$GOPATH/pkg/mod/github.com/sugarme/gotch@$GOTCH_VERSION"
|
||||
LIBTORCH="$GOPATH/pkg/mod/github.com/sugarme/gotch@$GOTCH_VERSION/libtch/libtorch"
|
||||
LIBRARY_PATH="$LIBTORCH/lib"
|
||||
# CPATH="$CPATH:$LIBTORCH/lib:$LIBTORCH/include:$LIBTORCH/include/torch/csrc/api/include"
|
||||
CPATH="$LIBTORCH/lib:$LIBTORCH/include:$LIBTORCH/include/torch/csrc/api/include"
|
||||
|
@ -14,19 +22,48 @@ if [ $CUDA_VERSION == "cpu" ]
|
|||
then
|
||||
# LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$LIBTORCH/lib"
|
||||
LD_LIBRARY_PATH="$LIBTORCH/lib"
|
||||
|
||||
# prepare C lib for CPU version
|
||||
sudo rm $GOTCH/libtch/lib.go
|
||||
sudo cp $GOTCH/libtch/lib.go.cpu $GOTCH/libtch/lib.go
|
||||
sudo mv $GOTCH/libtch/dummy_cuda_dependency.cpp $GOTCH/libtch/dummy_cuda_dependency.cpp.gpu
|
||||
sudo mv $GOTCH/libtch/fake_cuda_dependency.cpp.cpu $GOTCH/libtch/fake_cuda_dependency.cpp
|
||||
else
|
||||
# LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$LIBTORCH/lib:/usr/lib64-nvidia:/usr/local/cuda-${CUDA_VERSION}/lib64"
|
||||
LD_LIBRARY_PATH="$LIBTORCH/lib:/usr/lib64-nvidia:/usr/local/cuda-${CUDA_VERSION}/lib64"
|
||||
fi
|
||||
|
||||
#update .bashrc
|
||||
FILE="$HOME/.bashrc"
|
||||
# Update current shell environment variables for newly installed Libtorch
|
||||
export LIBRARY_PATH=$LIBRARY_PATH
|
||||
export CPATH=$CPATH
|
||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH
|
||||
|
||||
# Install gotch
|
||||
#==============
|
||||
cwd=$(pwd)
|
||||
mkdir /tmp/gotch-test
|
||||
cd /tmp/gotch-test
|
||||
go mod init "github.com/sugarme/gotch-test"
|
||||
go get "github.com/sugarme/gotch@$GOTCH_VERSION"
|
||||
rm -rf /tmp/gotch-test
|
||||
cd $cwd
|
||||
|
||||
if [ $CUDA_VERSION == "cpu" ]
|
||||
then
|
||||
# prepare C lib for CPU version
|
||||
sudo rm $GOTCH/libtch/lib.go
|
||||
sudo cp $GOTCH/libtch/lib.go.cpu $GOTCH/libtch/lib.go
|
||||
sudo mv $GOTCH/libtch/dummy_cuda_dependency.cpp $GOTCH/libtch/dummy_cuda_dependency.cpp.gpu
|
||||
sudo mv $GOTCH/libtch/fake_cuda_dependency.cpp.cpu $GOTCH/libtch/fake_cuda_dependency.cpp
|
||||
fi
|
||||
|
||||
# Install Libtorch
|
||||
#=================
|
||||
LIBTORCH_ZIP="libtorch-cxx11-abi-shared-with-deps-${LIBTORCH_VERSION}%2B${CU_VERSION}.zip"
|
||||
LIBTORCH_URL="https://download.pytorch.org/libtorch/${CU_VERSION}/${LIBTORCH_ZIP}"
|
||||
wget -O "/tmp/$LIBTORCH_ZIP" "$LIBTORCH_URL"
|
||||
# delete old libtorch if existing
|
||||
sudo rm -rf $LIBTORCH
|
||||
sudo unzip "/tmp/$LIBTORCH_ZIP" -d $GOTCH/libtch
|
||||
rm "/tmp/$LIBTORCH_ZIP"
|
||||
|
||||
# Update .bashrc
|
||||
#===============
|
||||
FILE="$HOME/.bashrc-test"
|
||||
LN_GOTCH_LIBTORCH="export GOTCH_LIBTORCH=$LIBTORCH"
|
||||
LN_LIBRARY_PATH="export LIBRARY_PATH=$LIBRARY_PATH"
|
||||
LN_CPATH="export CPATH=$CPATH"
|
||||
|
@ -42,3 +79,4 @@ grep -qF 'export LD_LIBRARY_PATH' "$FILE" && sed -i 's|^export LD_LIBRARY_PATH.*
|
|||
# refresh environment for all next opening shells.
|
||||
exec "$BASH"
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
package gotch
|
||||
|
||||
var (
|
||||
GotchVersion string = "v0.3.4"
|
||||
TorchVersion string = "1.7.0"
|
||||
CudaVersion string = "10.1"
|
||||
)
|
Loading…
Reference in New Issue
Block a user