chore(update)

This commit is contained in:
sugarme 2020-07-05 14:38:54 +10:00
parent 19b4122347
commit 6bdf7d585c

View File

@ -132,64 +132,59 @@ func main() {
fmt.Printf("max layer: %v\n", maxLayer)
// styleLayers := net.ForwardAllT(styleImg, false, maxLayer)
// contentLayers := net.ForwardAllT(contentImg, false, maxLayer)
styleLayers := net.ForwardAllT(styleImg, false, maxLayer)
contentLayers := net.ForwardAllT(contentImg, false, maxLayer)
vs := nn.NewVarStore(device)
path := vs.Root()
inputVar := path.VarCopy("img", contentImg)
// opt, err := nn.DefaultAdamConfig().Build(vs, LearningRate)
// if err != nil {
// log.Fatal(err)
// }
opt, err := nn.DefaultAdamConfig().Build(vs, LearningRate)
if err != nil {
log.Fatal(err)
}
// styleWeight := ts.FloatScalar(StyleWeight)
styleWeight := ts.FloatScalar(StyleWeight)
for stepIdx := 1; stepIdx <= int(TotalSteps); stepIdx++ {
inputLayers := net.ForwardAllT(inputVar, false, maxLayer)
fmt.Printf("Step %v ...done\n", stepIdx)
// var sLoss ts.Tensor
sLoss := ts.MustZeros([]int64{1}, gotch.Float.CInt(), device.CInt())
cLoss := ts.MustZeros([]int64{1}, gotch.Float.CInt(), device.CInt())
for _, _ = range StyleIndexes {
// for _, idx := range StyleIndexes {
// l := styleLoss(inputLayers[idx], styleLayers[idx])
// sLoss = sLoss.MustAdd(l, true)
// l.MustDrop()
}
for _, idx := range ContentIndexes {
l := inputLayers[idx].MustMseLoss(contentLayers[idx], ts.ReductionMean.ToInt(), true)
cLoss = cLoss.MustAdd(l, true)
l.MustDrop()
}
for _, t := range inputLayers {
t.MustDrop()
}
/*
* // var sLoss ts.Tensor
* sLoss := ts.MustZeros([]int64{1}, gotch.Float.CInt(), device.CInt())
* cLoss := ts.MustZeros([]int64{1}, gotch.Float.CInt(), device.CInt())
* for _, idx := range StyleIndexes {
* l := styleLoss(inputLayers[idx], styleLayers[idx])
* sLoss = sLoss.MustAdd(l, true)
* l.MustDrop()
* }
* for _, idx := range ContentIndexes {
* l := inputLayers[idx].MustMseLoss(contentLayers[idx], ts.ReductionMean.ToInt(), true)
* cLoss = cLoss.MustAdd(l, true)
* l.MustDrop()
* }
*
* for _, t := range inputLayers {
* t.MustDrop()
* }
*
* lossMul := sLoss.MustMul1(styleWeight, true)
* loss := lossMul.MustAdd(cLoss, true)
* opt.BackwardStep(loss)
*
* if (stepIdx % 10) == 0 {
* clone := inputVar.MustShallowClone()
* img := clone.MustDetach()
* clone.MustDrop()
* err := in.SaveImage(img, fmt.Sprintf("../../data/neural-style-transfer/out%v.jpg", stepIdx))
* if err != nil {
* log.Fatal(err)
* }
* img.MustDrop()
* }
*
* fmt.Printf("Step %v ... Done. Loss %10.1f\n", stepIdx, loss.Values()[0])
* cLoss.MustDrop()
* loss.MustDrop()
* */
lossMul := sLoss.MustMul1(styleWeight, true)
loss := lossMul.MustAdd(cLoss, true)
opt.BackwardStep(loss)
if (stepIdx % 1000) == 0 {
clone := inputVar.MustShallowClone()
img := clone.MustDetach()
clone.MustDrop()
err := in.SaveImage(img, fmt.Sprintf("../../data/neural-style-transfer/out%v.jpg", stepIdx))
if err != nil {
log.Fatal(err)
}
img.MustDrop()
}
fmt.Printf("Step %v ... Done. Loss %10.1f\n", stepIdx, loss.Values()[0])
cLoss.MustDrop()
loss.MustDrop()
}
}