From 98ca761d3014cdd786841969c44e809af3261c6c Mon Sep 17 00:00:00 2001 From: sugarme Date: Sun, 2 Aug 2020 16:17:49 +1000 Subject: [PATCH] feat(tensor/patched): add NLLLoss --- tensor/patch.go | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tensor/patch.go b/tensor/patch.go index 43a6d69..2d96582 100644 --- a/tensor/patch.go +++ b/tensor/patch.go @@ -152,3 +152,36 @@ func (ts Tensor) MustTopK(k int64, dim int64, largest bool, sorted bool) (ts1 Te return ts1, ts2 } + +// NOTE. `NLLLoss` is a version of `NllLoss` in tensor-generated +// with default weight, reduction and ignoreIndex +func (ts Tensor) NLLLoss(target Tensor, del bool) (retVal Tensor, err error) { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + if del { + defer ts.MustDrop() + } + + weight := NewTensor() + + reduction := int64(1) // Mean of loss + ignoreIndex := int64(-100) + defer C.free(unsafe.Pointer(ptr)) + + lib.AtgNLLLoss(ptr, ts.ctensor, target.ctensor, weight.ctensor, reduction, ignoreIndex) + if err = TorchErr(); err != nil { + return retVal, err + } + + retVal = Tensor{ctensor: *ptr} + + return retVal, nil +} + +func (ts Tensor) MustNllLoss(target Tensor, del bool) (retVal Tensor) { + retVal, err := ts.NllLoss(target, del) + if err != nil { + log.Fatal(err) + } + + return retVal +}