28 lines
1000 B
Go
28 lines
1000 B
Go
package ts
|
|
|
|
// Other tensor methods
|
|
|
|
// CrossEntropyForLogits computes the cross-entropy loss based on some logits and targets.
|
|
func (ts *Tensor) CrossEntropyForLogits(targets *Tensor) (retVal *Tensor) {
|
|
weight := NewTensor()
|
|
reduction := int64(1) // Mean of loss
|
|
ignoreIndex := int64(-100)
|
|
|
|
dtype := ts.DType()
|
|
logSm := ts.MustLogSoftmax(-1, dtype, false)
|
|
return logSm.MustNllLoss(targets, weight, reduction, ignoreIndex, true)
|
|
}
|
|
|
|
// AccuracyForLogits returns the average accuracy for some given logits assuming that
|
|
// targets represent ground-truth.
|
|
func (ts *Tensor) AccuracyForLogits(targets *Tensor) (retVal *Tensor) {
|
|
argmax := ts.MustArgmax([]int64{-1}, false, false)
|
|
eq1 := argmax.MustEqTensor(targets, true)
|
|
dtype := ts.DType()
|
|
return eq1.MustTotype(dtype, true).MustMean(dtype, true)
|
|
}
|
|
|
|
func (ts *Tensor) MaxPool2DDefault(ksize int64, del bool) (retVal *Tensor) {
|
|
return ts.MustMaxPool2d([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, []int64{1, 1}, false, del)
|
|
}
|