added nn.MSELoss()

This commit is contained in:
sugarme 2022-03-16 20:46:04 +11:00
parent c472134d94
commit cc5792ecbf
2 changed files with 14 additions and 0 deletions

View File

@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
- Added `nn.MSELoss()`
## [Nofix]
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.

View File

@ -107,3 +107,16 @@ func BCELoss(logits, target *ts.Tensor, opts ...LossFnOption) *ts.Tensor {
loss := logits.MustSqueeze(false).MustBinaryCrossEntropyWithLogits(target, ws, posWeight, reduction, true)
return loss
}
// MSELoss calculates Mean-Square Loss.
//
// - reductionOpt: either 0 ("none"); 1 ("mean"); 2 ("sum"). Default=mean
func MSELoss(logits, labels *ts.Tensor, reductionOpt ...int64) *ts.Tensor {
reduction := int64(1)
if len(reductionOpt) > 0 {
reduction = reductionOpt[0]
}
out := logits.MustMseLoss(labels, reduction, false)
return out
}