package gotch
import (
// This file provides functions to work with local dataset cache, ...
// ModelUrls maps a model name to the URL of its pretrained weights.
// The entries are taken from the individual model definitions in the
// pytorch/vision repository.
var ModelUrls = map[string]string{
	"alexnet": "",

	"convnext_tiny":  "",
	"convnext_small": "",
	"convnext_base":  "",
	"convnext_large": "",

	"densenet121": "",
	"densenet169": "",
	"densenet201": "",
	"densenet161": "",

	// EfficientNet B0-B4 weights are ported from an external implementation.
	// NOTE(review): the source reference was truncated here; restore it.
	"efficientnet_b0": "",
	"efficientnet_b1": "",
	"efficientnet_b2": "",
	"efficientnet_b3": "",
	"efficientnet_b4": "",

	// EfficientNet B5-B7 weights are ported from an external implementation.
	// NOTE(review): the source reference was truncated here; restore it.
	"efficientnet_b5": "",
	"efficientnet_b6": "",
	"efficientnet_b7": "",

	// GoogLeNet weights ported from TensorFlow.
	"googlenet": "",

	// Inception v3 weights ported from TensorFlow.
	"inception_v3_google": "",

	"mnasnet0_5":  "",
	"mnasnet0_75": "",
	"mnasnet1_0":  "",
	"mnasnet1_3":  "",

	"mobilenet_v2":       "",
	"mobilenet_v3_large": "",
	"mobilenet_v3_small": "",

	"regnet_y_400mf": "",
	"regnet_y_800mf": "",
	"regnet_y_1_6gf": "",
	"regnet_y_3_2gf": "",
	"regnet_y_8gf":   "",
	"regnet_y_16gf":  "",
	"regnet_y_32gf":  "",

	"regnet_x_400mf": "",
	"regnet_x_800mf": "",
	"regnet_x_1_6gf": "",
	"regnet_x_3_2gf": "",
	"regnet_x_8gf":   "",
	"regnet_x_16gf":  "",
	"regnet_x_32gf":  "",

	"resnet18":  "",
	"resnet34":  "",
	"resnet50":  "",
	"resnet101": "",
	"resnet152": "",

	"resnext50_32x4d":  "",
	"resnext101_32x8d": "",

	"wide_resnet50_2":  "",
	"wide_resnet101_2": "",

	"shufflenetv2_x0.5": "",
	"shufflenetv2_x1.0": "",
	"shufflenetv2_x1.5": "",
	"shufflenetv2_x2.0": "",

	"squeezenet1_0": "",
	"squeezenet1_1": "",

	"vgg11": "",
	"vgg13": "",
	"vgg16": "",
	"vgg19": "",

	"vgg11_bn": "",
	"vgg13_bn": "",
	"vgg16_bn": "",
	"vgg19_bn": "",

	"vit_b_16": "",
	"vit_b_32": "",
	"vit_l_16": "",
	"vit_l_32": "",
}
// CachedPath resolves filenameOrUrl to a file inside CacheDir and returns the
// full path to that cached file.
//
// Parameters:
//   - filenameOrUrl: a path to a local file, or a URL.
//
// Resolution proceeds in order:
//  1. If CacheDir already holds a non-directory entry with the same base name
//     as filenameOrUrl, its path is returned without further work.
//  2. If filenameOrUrl names an existing local file, it is copied into
//     CacheDir (creating the directory if needed) and the cached path is
//     returned.
//  3. If filenameOrUrl is accepted by isValidURL, it is downloaded into
//     CacheDir and the cached path is returned.
//
// An error is returned when copying or downloading fails, or when the input
// cannot be resolved by any step. Note that the cache is keyed on the base
// name only, so different sources sharing a base name map to the same file.
func CachedPath(filenameOrUrl string) (resolvedPath string, err error) {
	filename := path.Base(filenameOrUrl)
	cachedFileCandidate := fmt.Sprintf("%s/%s", CacheDir, filename)

	// 1. Already cached. A directory (e.g. "CacheDir/." for empty input) is
	// never a valid cached file.
	if info, err := os.Stat(cachedFileCandidate); err == nil && !info.IsDir() {
		return cachedFileCandidate, nil
	}

	// 2. Existing local file: copy it into the cache.
	if _, err := os.Stat(filenameOrUrl); err == nil {
		if err := os.MkdirAll(CacheDir, 0755); err != nil {
			return "", fmt.Errorf("creating cache dir %q: %w", CacheDir, err)
		}
		if err := copyFile(filenameOrUrl, cachedFileCandidate); err != nil {
			return "", err
		}
		return cachedFileCandidate, nil
	}

	// 3. Remote resource. downloadFile performs the request and checks the
	// response status itself, so no separate probing request is made (that
	// would download twice and leak an unclosed response body).
	if isValidURL(filenameOrUrl) {
		if err := downloadFile(filenameOrUrl, cachedFileCandidate); err != nil {
			return "", fmt.Errorf("downloading %q: %w", filenameOrUrl, err)
		}
		return cachedFileCandidate, nil
	}

	return "", fmt.Errorf("unable to parse %q as a URL or as a local path", filenameOrUrl)
}
// isValidURL reports whether url looks like a resource downloadFile can fetch.
// The string must start with "http://" or "https://" (case-insensitive) and
// have at least one character after the scheme. These are the only schemes
// http.Get supports, so anything else would fail to download anyway.
func isValidURL(url string) bool {
	for _, scheme := range []string{"http://", "https://"} {
		if len(url) > len(scheme) && strings.EqualFold(url[:len(scheme)], scheme) {
			return true
		}
	}
	return false
}
// downloadFile downloads file from URL and stores it in local filepath.
// It writes to the destination file as it downloads it, without loading
// the entire file into memory. An `io.TeeReader` is passed into Copy()
// to report progress on the download.
func downloadFile(url string, filepath string) error {
// Create path if not existing
dir := path.Dir(filepath)
filename := path.Base(filepath)
if _, err := os.Stat(dir); os.IsNotExist(err) {
if err := os.MkdirAll(dir, 0755); err != nil {
// Create the file with .tmp extension, so that we won't overwrite a
// file until it's downloaded fully
out, err := os.Create(filepath + ".tmp")
if err != nil {
return err
defer out.Close()
// Get the data
resp, err := http.Get(url)
if err != nil {
return err
defer resp.Body.Close()
// Check server response
if resp.StatusCode != http.StatusOK {
err := fmt.Errorf("bad status: %s(%v)", resp.Status, resp.StatusCode)
if resp.StatusCode == 404 {
err = fmt.Errorf("download file not found: %q for downloading", url)
} else {
err = fmt.Errorf("download file failed: %q", url)
return err
// the total file size to download
size, _ := strconv.Atoi(resp.Header.Get("Content-Length"))
downloadSize := uint64(size)
// Create our bytes counter and pass it to be used alongside our writer
counter := &writeCounter{FileSize: downloadSize}
_, err = io.Copy(out, io.TeeReader(resp.Body, counter))
if err != nil {
return err
fmt.Printf("\r%s... %s/%s completed", filename, byteCountIEC(counter.Total), byteCountIEC(counter.FileSize))
// The progress use the same line so print a new line once it's finished downloading
// Rename the tmp file back to the original file
err = os.Rename(filepath+".tmp", filepath)
if err != nil {
return err
return nil
// writeCounter counts the number of bytes written to it. By implementing the Write method,
// it is of the io.Writer interface and we can pass this into io.TeeReader()
// Every write to this writer, will print the progress of the file write.
type writeCounter struct {
Total uint64
FileSize uint64
func (wc *writeCounter) Write(p []byte) (int, error) {
n := len(p)
wc.Total += uint64(n)
return n, nil
// PrintProgress prints the progress of a file write
func (wc writeCounter) printProgress() {
// Clear the line by using a character return to go back to the start and remove
// the remaining characters by filling it with spaces
fmt.Printf("\r%s", strings.Repeat(" ", 50))
// Return again and print current status of download
fmt.Printf("\rDownloading... %s/%s", byteCountIEC(wc.Total), byteCountIEC(wc.FileSize))
// byteCountIEC formats a byte count as a human-readable string in binary
// (IEC) units, with one decimal place, e.g. 1536 -> "1.5 KiB". Counts below
// 1024 are printed as plain bytes, e.g. 512 -> "512 B".
func byteCountIEC(b uint64) string {
	const unit = 1024
	if b < unit {
		return fmt.Sprintf("%d B", b)
	}
	// Find the largest power of 1024 (div) that does not exceed b. exp
	// indexes the matching prefix in "KMGTPE" and must advance with div.
	div, exp := uint64(unit), 0
	for n := b / unit; n >= unit; n /= unit {
		div *= unit
		exp++
	}
	return fmt.Sprintf("%.1f %ciB", float64(b)/float64(div), "KMGTPE"[exp])
}
// copyFile copies the contents of the regular file src to dst, creating or
// truncating dst.
//
// If dst already refers to the same file as src, nothing is done, because
// truncating dst would otherwise destroy src. If the copy or the final Close
// fails, the partially written dst is removed so it cannot later be mistaken
// for a complete file.
func copyFile(src, dst string) error {
	srcInfo, err := os.Stat(src)
	if err != nil {
		return err
	}
	if !srcInfo.Mode().IsRegular() {
		return fmt.Errorf("%s is not a regular file", src)
	}
	if dstInfo, err := os.Stat(dst); err == nil && os.SameFile(srcInfo, dstInfo) {
		return nil
	}

	source, err := os.Open(src)
	if err != nil {
		return err
	}
	defer source.Close()

	destination, err := os.Create(dst)
	if err != nil {
		return err
	}
	if _, err := io.Copy(destination, source); err != nil {
		// Best-effort cleanup; the copy error is the one worth reporting.
		_ = destination.Close()
		_ = os.Remove(dst)
		return err
	}
	// A Close error may indicate that not all data reached dst.
	if err := destination.Close(); err != nil {
		_ = os.Remove(dst)
		return err
	}
	return nil
}