gotch/example/mnist/main.go

43 lines
613 B
Go
Raw Normal View History

package main
2020-06-18 08:14:48 +01:00
import (
	"flag"
	"log"

	"git.andr3h3nriqu3s.com/andr3/gotch"
)
2023-07-04 14:03:49 +01:00
var (
	// model is the -model flag value; main uses it to pick which example
	// network to run.
	model string

	// deviceOpt is the raw -device flag value. main resolves it into device
	// before any model is run.
	deviceOpt string
	device    gotch.Device
)
2020-06-18 08:14:48 +01:00
func init() {
flag.StringVar(&model, "model", "linear", "specify a model to run")
2023-07-04 14:03:49 +01:00
flag.StringVar(&deviceOpt, "device", "cpu", "specify device to run on. Eitheir 'cpu' or 'cuda'")
2020-06-18 08:14:48 +01:00
}
func main() {
2020-06-18 08:14:48 +01:00
flag.Parse()
2023-07-04 14:03:49 +01:00
if deviceOpt == "cuda" {
device = gotch.CudaIfAvailable()
} else {
device = gotch.CPU
}
2020-06-18 08:14:48 +01:00
switch model {
case "linear":
runLinear()
case "nn":
runNN()
case "cnn":
// runCNN2()
runCNN1()
2020-06-18 08:14:48 +01:00
default:
panic("No specified model to run")
}
}