fix(example/neural-style-transfer): fixed panic at GPU test on Colab
This commit is contained in:
parent
d755672250
commit
8fc6efca1c
|
@ -8,6 +8,7 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
|
@ -143,6 +144,7 @@ func main() {
|
|||
log.Fatal(err)
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
styleWeight := ts.FloatScalar(StyleWeight)
|
||||
for stepIdx := 1; stepIdx <= int(TotalSteps); stepIdx++ {
|
||||
inputLayers := net.ForwardAllT(inputVar, false, maxLayer)
|
||||
|
@ -156,7 +158,7 @@ func main() {
|
|||
l.MustDrop()
|
||||
}
|
||||
for _, idx := range ContentIndexes {
|
||||
// l := inputLayers[idx].MustMseLoss(contentLayers[idx], ts.ReductionMean.ToInt(), true)
|
||||
// NOTE: set `del` = true called panic at GPU train (tested on Colab)
|
||||
l := inputLayers[idx].MustMseLoss(contentLayers[idx], ts.ReductionMean.ToInt(), false)
|
||||
cLoss = cLoss.MustAdd(l, true)
|
||||
l.MustDrop()
|
||||
|
@ -173,18 +175,20 @@ func main() {
|
|||
if (stepIdx % 1000) == 0 {
|
||||
clone := inputVar.MustShallowClone()
|
||||
img := clone.MustDetach()
|
||||
imageTs := img.MustTo(device, true)
|
||||
clone.MustDrop()
|
||||
err := in.SaveImage(img, fmt.Sprintf("../../data/neural-style-transfer/out%v.jpg", stepIdx))
|
||||
err := in.SaveImage(imageTs, fmt.Sprintf("../../data/neural-style-transfer/out%v.jpg", stepIdx))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
img.MustDrop()
|
||||
imageTs.MustDrop()
|
||||
}
|
||||
|
||||
fmt.Printf("Step %v ... Done. Loss %10.1f\n", stepIdx, loss.Values()[0])
|
||||
cLoss.MustDrop()
|
||||
loss.MustDrop()
|
||||
|
||||
}
|
||||
|
||||
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user