diff --git a/CHANGELOG.md b/CHANGELOG.md index 20085ee..246ee57 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -65,11 +65,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - nn/sequential: fixed missing case number of layers = 1 causing panic - nn/varstore: fixed(nn/varstore): fixed nil pointer at LoadPartial due to not break loop -# [0.3.4] +## [0.3.4] ### Added - [#4] Automatically download and install Libtorch and setup environment variables. +## [0.3.5] + +### Added +- Added function `gotch.CudaIfAvailable()`. NOTE that: `device := gotch.NewCuda().CudaIfAvailable()` will throw error if CUDA is not available. + +### Changed +- Switched back to install libtorch inside gotch library as go init() function is triggered after cgo called. + + [#10]: https://github.com/sugarme/gotch/issues/10 [#6]: https://github.com/sugarme/gotch/issues/6 [#4]: https://github.com/sugarme/gotch/issues/4 diff --git a/README.md b/README.md index f8897b1..77da0e7 100644 --- a/README.md +++ b/README.md @@ -22,14 +22,22 @@ Gotch is in active development mode and may have API breaking changes. Feel free ## Installation -When installing `gotch` with `go get`, it will automatically download and install corresponding `Libtorch` and cache at `$HOME/.cache/gotch/LIBTORCH_VERSION` directory. - - Default CUDA version is `10.1` if CUDA is available otherwise using CPU version. - Default Pytorch C++ API version is `1.7.0` ```bash - go get github.com/sugarme/gotch@v0.3.4 + wget https://github.com/sugarme/gotch/blob/v0.3.5/setup.sh + chmod +x setup.sh + + # Default + sudo bash setup.sh + + # Specify CUDA version + export CUDA_VER=YOUR_PC_CUDA_VERSION && sudo bash setup.sh + + # CPU + export CUDA_VER=cpu && sudo bash setup.sh ``` diff --git a/device.go b/device.go index 9180466..38df73e 100644 --- a/device.go +++ b/device.go @@ -114,3 +114,13 @@ func (d Device) IsCuda() bool { return true } + +// CudaIfAvailable returns a GPU device if available, else CPU. +func CudaIfAvailable() Device { + switch { + case CUDA.IsAvailable(): + return CudaBuilder(0) + default: + return CPU + } +} diff --git a/file-util.go b/file-util.go deleted file mode 100644 index e692981..0000000 --- a/file-util.go +++ /dev/null @@ -1,194 +0,0 @@ -package gotch - -import ( - "archive/zip" - "fmt" - "io" - "log" - "net/http" - "os" - "path/filepath" - "runtime" - "strings" -) - -// downloadFile downloads file from URL and stores it in local filepath. -// It writes to the destination file as it downloads it, without loading -// the entire file into memory. An `io.TeeReader` is passed into Copy() -// to report progress on the download. -func downloadFile(url string, filepath string) error { - - // Create the file with .tmp extension, so that we won't overwrite a - // file until it's downloaded fully - out, err := os.Create(filepath + ".tmp") - if err != nil { - return err - } - defer out.Close() - - // Get the data - resp, err := http.Get(url) - if err != nil { - return err - } - defer resp.Body.Close() - - // Create our bytes counter and pass it to be used alongside our writer - counter := &writeCounter{} - _, err = io.Copy(out, io.TeeReader(resp.Body, counter)) - if err != nil { - return err - } - - // The progress use the same line so print a new line once it's finished downloading - fmt.Println() - - // Rename the tmp file back to the original file - err = os.Rename(filepath+".tmp", filepath) - if err != nil { - return err - } - - return nil -} - -// writeCounter counts the number of bytes written to it. By implementing the Write method, -// it is of the io.Writer interface and we can pass this into io.TeeReader() -// Every write to this writer, will print the progress of the file write. -type writeCounter struct { - Total uint64 -} - -func (wc *writeCounter) Write(p []byte) (int, error) { - n := len(p) - wc.Total += uint64(n) - wc.printProgress() - return n, nil -} - -// PrintProgress prints the progress of a file write -func (wc writeCounter) printProgress() { - // Clear the line by using a character return to go back to the start and remove - // the remaining characters by filling it with spaces - fmt.Printf("\r%s", strings.Repeat(" ", 50)) - - // Return again and print current status of download - fmt.Printf("\rDownloading... %s complete", byteCountIEC(wc.Total)) -} - -// byteCountIEC converts bytes to human-readable string in binary (IEC) format. -func byteCountIEC(b uint64) string { - const unit = 1024 - if b < unit { - return fmt.Sprintf("%d B", b) - } - div, exp := uint64(unit), 0 - for n := b / unit; n >= unit; n /= unit { - div *= unit - exp++ - } - return fmt.Sprintf("%.1f %ciB", - float64(b)/float64(div), "KMGTPE"[exp]) -} - -func copyFile(src, dst string) error { - sourceFileStat, err := os.Stat(src) - if err != nil { - return err - } - - if !sourceFileStat.Mode().IsRegular() { - return fmt.Errorf("%s is not a regular file", src) - } - - source, err := os.Open(src) - if err != nil { - return err - } - defer source.Close() - - destination, err := os.Create(dst) - if err != nil { - return err - } - defer destination.Close() - _, err = io.Copy(destination, source) - return err -} - -// Unzip unzips file -// Ref. https://stackoverflow.com/questions/20357223 -func Unzip(src, dest string) error { - r, err := zip.OpenReader(src) - if err != nil { - return err - } - defer func() { - if err := r.Close(); err != nil { - panic(err) - } - }() - - os.MkdirAll(dest, 0755) - - // Closure to address file descriptors issue with all the deferred .Close() methods - extractAndWriteFile := func(f *zip.File) error { - rc, err := f.Open() - if err != nil { - return err - } - defer func() { - if err := rc.Close(); err != nil { - panic(err) - } - }() - - path := filepath.Join(dest, f.Name) - - // Check for ZipSlip (Directory traversal) - if !strings.HasPrefix(path, filepath.Clean(dest)+string(os.PathSeparator)) { - return fmt.Errorf("illegal file path: %s", path) - } - - if f.FileInfo().IsDir() { - os.MkdirAll(path, f.Mode()) - } else { - os.MkdirAll(filepath.Dir(path), f.Mode()) - f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) - if err != nil { - return err - } - defer func() { - if err := f.Close(); err != nil { - panic(err) - } - }() - - _, err = io.Copy(f, rc) - if err != nil { - return err - } - } - return nil - } - - for _, f := range r.File { - err := extractAndWriteFile(f) - if err != nil { - return err - } - } - - return nil -} - -// CurrDir returns absolute path of the file contains code calling this function. -func CurrDir() string { - // pc, file, line, ok := runtime.Caller(1) - _, file, _, ok := runtime.Caller(1) - if !ok { - log.Fatal("Runtime Caller error: cannot get information of current caller.") - } - - return filepath.Dir(file) -} diff --git a/init.go b/init.go deleted file mode 100644 index 36deebe..0000000 --- a/init.go +++ /dev/null @@ -1,119 +0,0 @@ -package gotch - -import ( - // "flag" - "fmt" - "log" - "os" - "os/exec" - "os/user" - "path" - "strings" -) - -var ( - GotchCacheDir string - TorchDir string - IsCudaAvailable bool = false - HomeDir string -) - -func init() { - if v, ok := os.LookupEnv("TORCH_VERSION"); ok { - TorchVersion = v - } - if v, ok := os.LookupEnv("TORCH_DIR"); ok { - TorchDir = v - } - - IsCudaAvailable = isCudaAvailable() - if !isCudaAvailable() { - CudaVersion = "cpu" - } - - HomeDir = homeDir() - GotchCacheDir = fmt.Sprintf("%v/.cache/gotch", HomeDir) - - InstallLibtorch() -} - -// isCudaAvailable check whether cuda is installed using 'nvcc' -func isCudaAvailable() bool { - _, err := exec.Command("nvcc", "--version").Output() - if err != nil { - log.Printf("CUDA is not detected using 'nvcc --version'\n") - return false - } - return true -} - -// homeDir returns home directory of current user. -func homeDir() string { - usr, err := user.Current() - if err != nil { - log.Fatal(err) - } - - return usr.HomeDir -} - -func InstallLibtorch() { - dev := strings.ReplaceAll(CudaVersion, ".", "") - cu := "cpu" - if !strings.Contains(dev, "cpu") { - cu = fmt.Sprintf("cu%v", dev) - } - url := fmt.Sprintf("https://download.pytorch.org/libtorch/%v/libtorch-cxx11-abi-shared-with-deps-%v%%2B%v.zip", cu, TorchVersion, cu) - - // Create dir if not exist - TorchDir = fmt.Sprintf("%v/libtorch-%v-%v", GotchCacheDir, TorchVersion, cu) - if _, err := os.Stat(TorchDir); os.IsNotExist(err) { - os.MkdirAll(TorchDir, 0755) - } - // install libtorch if not done yet. - installLibtorch(url) - - // Export to current context so that shell script can catch them - libtorch := fmt.Sprintf("%v/libtorch", TorchDir) - os.Setenv("GOTCH_LIBTORCH", libtorch) - os.Setenv("GOTCH_CUDA_VERSION", CudaVersion) - os.Setenv("GOTCH_VER", GotchVersion) - - envFile := fmt.Sprintf("%v/env.sh", CurrDir()) - _, err := exec.Command(envFile).Output() - if err != nil { - log.Fatal(err) - } -} - -func installLibtorch(url string) { - filename := path.Base(url) - cachedFileCandidate := fmt.Sprintf("%s/%v", TorchDir, filename) - - // Check whether zip file exists otherwise, download it - if _, err := os.Stat(cachedFileCandidate); err == nil { - // zip file exists - // check one file 'libc10.so' if exists, assuming libtorch has been installed. - libc10 := fmt.Sprintf("%v/libtorch/lib/libc10.so", TorchDir) - if _, err := os.Stat(libc10); err == nil { - return - } - } else { - // download zip file - err := downloadFile(url, cachedFileCandidate) - if err != nil { - log.Fatal(err) - } - } - - // Unzip - log.Printf("Unzipping...") - err := Unzip(cachedFileCandidate, TorchDir) - if err != nil { - log.Fatal(err) - } -} - -func Hello() { - fmt.Println("Hello") -} diff --git a/env.sh b/setup.sh similarity index 61% rename from env.sh rename to setup.sh index 159c22b..6b7dc2d 100755 --- a/env.sh +++ b/setup.sh @@ -1,11 +1,19 @@ #!/bin/bash -# Catch info from Go exported environment variables -GOTCH_VERSION="${GOTCH_VER:-NotSpecified}" -LIBTORCH="${GOTCH_LIBTORCH:-NotSpecified}" # Libtorch root path -CUDA_VERSION="${GOTCH_CUDA_VERSION:-NotSpecified}" # e.g 10.1; cpu +GOTCH_VERSION="${GOTCH_VER:-v0.3.5}" +LIBTORCH_VERSION="${LIBTORCH_VER:-1.7.0}" +CUDA_VERSION="${CUDA_VER:-10.1}" +if [ $CUDA_VERSION == "cpu" ] +then + CU_VERSION="cpu" +else + CU_VERSION="cu${CUDA_VERSION//./}" +fi + +# Libtorch paths GOTCH="$GOPATH/pkg/mod/github.com/sugarme/gotch@$GOTCH_VERSION" +LIBTORCH="$GOPATH/pkg/mod/github.com/sugarme/gotch@$GOTCH_VERSION/libtch/libtorch" LIBRARY_PATH="$LIBTORCH/lib" # CPATH="$CPATH:$LIBTORCH/lib:$LIBTORCH/include:$LIBTORCH/include/torch/csrc/api/include" CPATH="$LIBTORCH/lib:$LIBTORCH/include:$LIBTORCH/include/torch/csrc/api/include" @@ -14,19 +22,48 @@ if [ $CUDA_VERSION == "cpu" ] then # LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$LIBTORCH/lib" LD_LIBRARY_PATH="$LIBTORCH/lib" - - # prepare C lib for CPU version - sudo rm $GOTCH/libtch/lib.go - sudo cp $GOTCH/libtch/lib.go.cpu $GOTCH/libtch/lib.go - sudo mv $GOTCH/libtch/dummy_cuda_dependency.cpp $GOTCH/libtch/dummy_cuda_dependency.cpp.gpu - sudo mv $GOTCH/libtch/fake_cuda_dependency.cpp.cpu $GOTCH/libtch/fake_cuda_dependency.cpp else # LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$LIBTORCH/lib:/usr/lib64-nvidia:/usr/local/cuda-${CUDA_VERSION}/lib64" LD_LIBRARY_PATH="$LIBTORCH/lib:/usr/lib64-nvidia:/usr/local/cuda-${CUDA_VERSION}/lib64" fi -#update .bashrc -FILE="$HOME/.bashrc" +# Update current shell environment variables for newly installed Libtorch +export LIBRARY_PATH=$LIBRARY_PATH +export CPATH=$CPATH +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH + +# Install gotch +#============== +cwd=$(pwd) +mkdir /tmp/gotch-test +cd /tmp/gotch-test +go mod init "github.com/sugarme/gotch-test" +go get "github.com/sugarme/gotch@$GOTCH_VERSION" +rm -rf /tmp/gotch-test +cd $cwd + +if [ $CUDA_VERSION == "cpu" ] +then + # prepare C lib for CPU version + sudo rm $GOTCH/libtch/lib.go + sudo cp $GOTCH/libtch/lib.go.cpu $GOTCH/libtch/lib.go + sudo mv $GOTCH/libtch/dummy_cuda_dependency.cpp $GOTCH/libtch/dummy_cuda_dependency.cpp.gpu + sudo mv $GOTCH/libtch/fake_cuda_dependency.cpp.cpu $GOTCH/libtch/fake_cuda_dependency.cpp +fi + +# Install Libtorch +#================= +LIBTORCH_ZIP="libtorch-cxx11-abi-shared-with-deps-${LIBTORCH_VERSION}%2B${CU_VERSION}.zip" +LIBTORCH_URL="https://download.pytorch.org/libtorch/${CU_VERSION}/${LIBTORCH_ZIP}" +wget -O "/tmp/$LIBTORCH_ZIP" "$LIBTORCH_URL" +# delete old libtorch if existing +sudo rm -rf $LIBTORCH +sudo unzip "/tmp/$LIBTORCH_ZIP" -d $GOTCH/libtch +rm "/tmp/$LIBTORCH_ZIP" + +# Update .bashrc +#=============== +FILE="$HOME/.bashrc-test" LN_GOTCH_LIBTORCH="export GOTCH_LIBTORCH=$LIBTORCH" LN_LIBRARY_PATH="export LIBRARY_PATH=$LIBRARY_PATH" LN_CPATH="export CPATH=$CPATH" @@ -42,3 +79,4 @@ grep -qF 'export LD_LIBRARY_PATH' "$FILE" && sed -i 's|^export LD_LIBRARY_PATH.* # refresh environment for all next opening shells. exec "$BASH" + diff --git a/version.go b/version.go deleted file mode 100644 index 8a6cbd1..0000000 --- a/version.go +++ /dev/null @@ -1,7 +0,0 @@ -package gotch - -var ( - GotchVersion string = "v0.3.4" - TorchVersion string = "1.7.0" - CudaVersion string = "10.1" -)