diff --git a/tensor/patch.go b/tensor/patch.go index 43a6d69..2d96582 100644 --- a/tensor/patch.go +++ b/tensor/patch.go @@ -152,3 +152,36 @@ func (ts Tensor) MustTopK(k int64, dim int64, largest bool, sorted bool) (ts1 Te return ts1, ts2 } + +// NOTE. `NLLLoss` is a version of `NllLoss` in tensor-generated +// with default weight, reduction and ignoreIndex +func (ts Tensor) NLLLoss(target Tensor, del bool) (retVal Tensor, err error) { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + if del { + defer ts.MustDrop() + } + + weight := NewTensor() + + reduction := int64(1) // Mean of loss + ignoreIndex := int64(-100) + defer C.free(unsafe.Pointer(ptr)) + + lib.AtgNLLLoss(ptr, ts.ctensor, target.ctensor, weight.ctensor, reduction, ignoreIndex) + if err = TorchErr(); err != nil { + return retVal, err + } + + retVal = Tensor{ctensor: *ptr} + + return retVal, nil +} + +func (ts Tensor) MustNllLoss(target Tensor, del bool) (retVal Tensor) { + retVal, err := ts.NllLoss(target, del) + if err != nil { + log.Fatal(err) + } + + return retVal +}