feat(tensor/patched): add NLLLoss

This commit is contained in:
sugarme 2020-08-02 16:17:49 +10:00
parent 2bd23fcc89
commit 98ca761d30

View File

@ -152,3 +152,36 @@ func (ts Tensor) MustTopK(k int64, dim int64, largest bool, sorted bool) (ts1 Te
return ts1, ts2
}
// NOTE. `NLLLoss` is a version of `NllLoss` in tensor-generated
// with default weight, reduction and ignoreIndex
func (ts Tensor) NLLLoss(target Tensor, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
if del {
defer ts.MustDrop()
}
weight := NewTensor()
reduction := int64(1) // Mean of loss
ignoreIndex := int64(-100)
defer C.free(unsafe.Pointer(ptr))
lib.AtgNLLLoss(ptr, ts.ctensor, target.ctensor, weight.ctensor, reduction, ignoreIndex)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
func (ts Tensor) MustNllLoss(target Tensor, del bool) (retVal Tensor) {
retVal, err := ts.NllLoss(target, del)
if err != nil {
log.Fatal(err)
}
return retVal
}