From 8baf32d860339aebfac3c6a11434f3b2b73a237a Mon Sep 17 00:00:00 2001 From: sugarme Date: Fri, 23 Jul 2021 16:26:39 +1000 Subject: [PATCH] fixed initializing CUDA device at example/mnist cnn model --- example/mnist/cnn.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/example/mnist/cnn.go b/example/mnist/cnn.go index 1e5aaa8..b8020c8 100644 --- a/example/mnist/cnn.go +++ b/example/mnist/cnn.go @@ -74,9 +74,8 @@ func runCNN1() { testImages := ds.TestImages testLabels := ds.TestLabels - cuda := gotch.CudaBuilder(0) - vs := nn.NewVarStore(cuda.CudaIfAvailable()) - // vs := nn.NewVarStore(gotch.CPU) + device := gotch.CudaIfAvailable() + vs := nn.NewVarStore(device) net := newNet(vs.Root()) opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)