chore(colab)

This commit is contained in:
sugarme 2020-07-05 14:44:08 +10:00
parent 19d451be0a
commit d755672250

View File

@ -132,7 +132,7 @@ func main() {
fmt.Printf("max layer: %v\n", maxLayer)
// styleLayers := net.ForwardAllT(styleImg, false, maxLayer)
styleLayers := net.ForwardAllT(styleImg, false, maxLayer)
contentLayers := net.ForwardAllT(contentImg, false, maxLayer)
vs := nn.NewVarStore(device)
@ -150,14 +150,14 @@ func main() {
// var sLoss ts.Tensor
sLoss := ts.MustZeros([]int64{1}, gotch.Float.CInt(), device.CInt())
cLoss := ts.MustZeros([]int64{1}, gotch.Float.CInt(), device.CInt())
for _, _ = range StyleIndexes {
// for _, idx := range StyleIndexes {
// l := styleLoss(inputLayers[idx], styleLayers[idx])
// sLoss = sLoss.MustAdd(l, true)
// l.MustDrop()
for _, idx := range StyleIndexes {
l := styleLoss(inputLayers[idx], styleLayers[idx])
sLoss = sLoss.MustAdd(l, true)
l.MustDrop()
}
for _, idx := range ContentIndexes {
l := inputLayers[idx].MustMseLoss(contentLayers[idx], ts.ReductionMean.ToInt(), true)
// l := inputLayers[idx].MustMseLoss(contentLayers[idx], ts.ReductionMean.ToInt(), true)
l := inputLayers[idx].MustMseLoss(contentLayers[idx], ts.ReductionMean.ToInt(), false)
cLoss = cLoss.MustAdd(l, true)
l.MustDrop()
}