gotch/ts/other.go
2023-07-07 22:30:08 +10:00

28 lines
1000 B
Go

package ts
// Other tensor methods
// CrossEntropyForLogits computes the cross-entropy loss based on some logits and targets.
//
// It applies log-softmax over the last dimension of the receiver (in the
// receiver's own dtype) and passes the result to NLL loss with:
//   - mean reduction;
//   - no per-class weights;
//   - an ignore index of -100, which matches PyTorch's nll_loss default.
//
// NOTE(review): the softmax runs over dim -1, while NLL loss takes classes from
// dim 1. The logits are therefore presumably expected to be 2-D (batch, classes)
// — confirm before using with higher-rank inputs. targets presumably holds
// integer class indices.
//
// The receiver is passed with del=false, so it is not consumed. The intermediate
// log-softmax tensor is passed to MustNllLoss with del=true.
func (ts *Tensor) CrossEntropyForLogits(targets *Tensor) (retVal *Tensor) {
	const (
		reduction   = int64(1)    // Mean of loss
		ignoreIndex = int64(-100) // targets equal to this value contribute nothing
	)

	// NewTensor() presumably yields an undefined tensor, which libtorch treats as
	// "no class weights". The handle is only needed for the NLL call, so release
	// it afterwards instead of leaking it on every invocation.
	weight := NewTensor()
	defer weight.MustDrop()

	dtype := ts.DType()
	logSm := ts.MustLogSoftmax(-1, dtype, false)
	// del=true: logSm is an intermediate we own, so let the op release it.
	return logSm.MustNllLoss(targets, weight, reduction, ignoreIndex, true)
}
// AccuracyForLogits returns the average accuracy for some given logits assuming that
// targets represent ground-truth.
//
// The computation runs in four steps:
//  1. Take the argmax over the last dimension of the receiver as the predicted class.
//  2. Compare the prediction element-wise against targets.
//  3. Cast the match mask to the receiver's dtype.
//  4. Return the mean of the mask (the fraction of correct predictions) in that dtype.
//
// targets is presumably a tensor of class indices shaped like the logits minus
// their last dimension — confirm with callers. Intermediates are passed with
// del=true; the receiver and targets are not consumed.
func (ts *Tensor) AccuracyForLogits(targets *Tensor) (retVal *Tensor) {
	lastDim := []int64{-1}
	predicted := ts.MustArgmax(lastDim, false, false)

	// del=true on each intermediate: predicted, then hits, then the cast mask.
	hits := predicted.MustEqTensor(targets, true)
	kind := ts.DType()
	hitMask := hits.MustTotype(kind, true)
	retVal = hitMask.MustMean(kind, true)
	return retVal
}
// MaxPool2DDefault applies 2-D max pooling with the following settings:
//   - a square ksize×ksize window;
//   - a stride equal to the window size (non-overlapping windows);
//   - zero padding;
//   - unit dilation;
//   - ceil mode disabled.
//
// del is forwarded unchanged to MustMaxPool2d. By this package's del convention
// it presumably controls whether the receiver is freed afterwards.
func (ts *Tensor) MaxPool2DDefault(ksize int64, del bool) (retVal *Tensor) {
	window := []int64{ksize, ksize}
	step := []int64{ksize, ksize}
	var (
		noPadding    = []int64{0, 0}
		unitDilation = []int64{1, 1}
		ceilMode     = false
	)
	retVal = ts.MustMaxPool2d(window, step, noPadding, unitDilation, ceilMode, del)
	return retVal
}