feat(tensor/patched): add NLLLoss
This commit is contained in:
parent
2bd23fcc89
commit
98ca761d30
|
@ -152,3 +152,36 @@ func (ts Tensor) MustTopK(k int64, dim int64, largest bool, sorted bool) (ts1 Te
|
||||||
|
|
||||||
return ts1, ts2
|
return ts1, ts2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE. `NLLLoss` is a version of `NllLoss` in tensor-generated
|
||||||
|
// with default weight, reduction and ignoreIndex
|
||||||
|
func (ts Tensor) NLLLoss(target Tensor, del bool) (retVal Tensor, err error) {
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
if del {
|
||||||
|
defer ts.MustDrop()
|
||||||
|
}
|
||||||
|
|
||||||
|
weight := NewTensor()
|
||||||
|
|
||||||
|
reduction := int64(1) // Mean of loss
|
||||||
|
ignoreIndex := int64(-100)
|
||||||
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
|
||||||
|
lib.AtgNLLLoss(ptr, ts.ctensor, target.ctensor, weight.ctensor, reduction, ignoreIndex)
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = Tensor{ctensor: *ptr}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustNllLoss(target Tensor, del bool) (retVal Tensor) {
|
||||||
|
retVal, err := ts.NllLoss(target, del)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user