diff --git a/example/mnist/cnn.go b/example/mnist/cnn.go index 1e5aaa8..b8020c8 100644 --- a/example/mnist/cnn.go +++ b/example/mnist/cnn.go @@ -74,9 +74,8 @@ func runCNN1() { testImages := ds.TestImages testLabels := ds.TestLabels - cuda := gotch.CudaBuilder(0) - vs := nn.NewVarStore(cuda.CudaIfAvailable()) - // vs := nn.NewVarStore(gotch.CPU) + device := gotch.CudaIfAvailable() + vs := nn.NewVarStore(device) net := newNet(vs.Root()) opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)