added nn.MSELoss()
This commit is contained in:
parent
c472134d94
commit
cc5792ecbf
|
@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
|||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
- Added `nn.MSELoss()`
|
||||
|
||||
## [Nofix]
|
||||
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.
|
||||
|
|
13
nn/loss.go
13
nn/loss.go
|
@ -107,3 +107,16 @@ func BCELoss(logits, target *ts.Tensor, opts ...LossFnOption) *ts.Tensor {
|
|||
loss := logits.MustSqueeze(false).MustBinaryCrossEntropyWithLogits(target, ws, posWeight, reduction, true)
|
||||
return loss
|
||||
}
|
||||
|
||||
// MSELoss calculates Mean-Square Loss.
|
||||
//
|
||||
// - reductionOpt: either 0 ("none"); 1 ("mean"); 2 ("sum"). Default=mean
|
||||
func MSELoss(logits, labels *ts.Tensor, reductionOpt ...int64) *ts.Tensor {
|
||||
reduction := int64(1)
|
||||
if len(reductionOpt) > 0 {
|
||||
reduction = reductionOpt[0]
|
||||
}
|
||||
out := logits.MustMseLoss(labels, reduction, false)
|
||||
|
||||
return out
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user