feat(tensor/patched): add NLLLoss
This commit is contained in:
parent
2bd23fcc89
commit
98ca761d30
|
@ -152,3 +152,36 @@ func (ts Tensor) MustTopK(k int64, dim int64, largest bool, sorted bool) (ts1 Te
|
|||
|
||||
return ts1, ts2
|
||||
}
|
||||
|
||||
// NOTE. `NLLLoss` is a version of `NllLoss` in tensor-generated
|
||||
// with default weight, reduction and ignoreIndex
|
||||
func (ts Tensor) NLLLoss(target Tensor, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
weight := NewTensor()
|
||||
|
||||
reduction := int64(1) // Mean of loss
|
||||
ignoreIndex := int64(-100)
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgNLLLoss(ptr, ts.ctensor, target.ctensor, weight.ctensor, reduction, ignoreIndex)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustNllLoss(target Tensor, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.NllLoss(target, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user