updated
This commit is contained in:
parent
e8429ee6c1
commit
724223dff3
|
@ -10,6 +10,7 @@ package main
|
|||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
|
@ -128,7 +129,7 @@ func main() {
|
|||
opt.SetLR(learningRate(epoch))
|
||||
|
||||
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(64))
|
||||
iter.Shuffle()
|
||||
// iter.Shuffle()
|
||||
// iter = iter.ToDevice(device)
|
||||
|
||||
for {
|
||||
|
@ -140,9 +141,10 @@ func main() {
|
|||
devicedData := item.Data.MustTo(vs.Device(), true)
|
||||
devicedLabel := item.Label.MustTo(vs.Device(), true)
|
||||
|
||||
bimages := vision.Augmentation(devicedData, true, 4, 8)
|
||||
// bimages := vision.Augmentation(devicedData, true, 4, 8)
|
||||
|
||||
logits := net.ForwardT(bimages, true)
|
||||
// logits := net.ForwardT(bimages, true)
|
||||
logits := net.ForwardT(devicedData, true)
|
||||
|
||||
// logits := net.ForwardT(item.Data.MustTo(vs.Device(), true), false)
|
||||
loss := logits.CrossEntropyForLogits(devicedLabel)
|
||||
|
@ -155,7 +157,7 @@ func main() {
|
|||
// item.Label.MustDrop()
|
||||
devicedData.MustDrop()
|
||||
devicedLabel.MustDrop()
|
||||
bimages.MustDrop()
|
||||
// bimages.MustDrop()
|
||||
loss.MustDrop()
|
||||
|
||||
}
|
||||
|
@ -165,6 +167,17 @@ func main() {
|
|||
fmt.Printf("Epoch:\t %v\t Memory Used:\t [%8.2f MiB]\tLoss: \t %.3f\n", epoch, memUsed, lossVal)
|
||||
iter.Drop()
|
||||
|
||||
// Print out GPU used
|
||||
nvidia := "nvidia-smi"
|
||||
cmd := exec.Command(nvidia)
|
||||
stdout, err := cmd.Output()
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err.Error())
|
||||
}
|
||||
|
||||
fmt.Println(string(stdout))
|
||||
|
||||
}
|
||||
|
||||
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
|
||||
|
|
Loading…
Reference in New Issue
Block a user