gotch/ts/module.go
package ts

// Module is a container with a single method, Forward.
//
// The following is the `module` concept from the PyTorch documentation:
// Base class for all neural network modules. Your models should also subclass this class.
// Modules can also contain other Modules, allowing them to be nested in a tree structure.
// You can assign the submodules as regular attributes. Submodules assigned in this way will
// be registered, and will have their parameters converted too when you call .cuda(), etc.
type Module interface {
// ModuleT
Forward(xs *Tensor) *Tensor
}
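
// As an illustration, a minimal sketch of a Module implementation; the
// ReluModule type and the MustRelu call are assumptions, not part of this
// package:
//
//	// ReluModule applies a ReLU activation in Forward.
//	type ReluModule struct{}
//
//	func (m ReluModule) Forward(xs *Tensor) *Tensor {
//		return xs.MustRelu(false)
//	}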

// ModuleT is a Module with an additional train parameter.
// The train parameter is commonly used to switch behavior
// between training and evaluation, e.g. when using dropout or batch normalization.
type ModuleT interface {
// Forward(xs Tensor) Tensor
ForwardT(xs *Tensor, train bool) *Tensor
}
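
// A sketch of a ModuleT implementation whose behavior depends on the train
// flag; the DropoutModule type and the MustDropout call are assumptions for
// illustration:
//
//	// DropoutModule applies dropout only while training.
//	type DropoutModule struct{ P float64 }
//
//	func (m DropoutModule) ForwardT(xs *Tensor, train bool) *Tensor {
//		if !train {
//			return xs.MustShallowClone()
//		}
//		return xs.MustDropout(m.P, train, false)
//	}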

/*
 * // DefaultModuleT implements the default method `BatchAccuracyForLogits`.
 * // NOTE: when creating a struct that implements `ModuleT`, it should
 * // embed `DefaultModuleT` so that the 'default' method `BatchAccuracyForLogits`
 * // is automatically implemented.
 * // Concept taken from the Rust language trait **Default Implementation**.
 * // Ref: https://doc.rust-lang.org/1.22.1/book/second-edition/ch10-02-traits.html
 * //
 * // Example:
 * //
 * // type FooModule struct{
 * //   DefaultModuleT
 * //   OtherField string
 * // }
 * type DefaultModuleT struct{}
 *
 * func (dmt DefaultModuleT) Forward(xs *Tensor) (retVal *Tensor) {
 * 	// TODO: implement
 *
 * 	return
 * }
 *
 * // ForwardT implements the ModuleT interface.
 * func (dmt DefaultModuleT) ForwardT(xs *Tensor, train bool) (retVal *Tensor) {
 * 	// TODO: implement
 *
 * 	return dmt.Forward(xs)
 * }
 * */

// NOTE: this func has been moved to `nn/sequential` because `NoGradGuard`
// does not seem to work in Go and the function needs a varstore variable
// parameter. It lives in `nn` to avoid an import cycle.
/*
 * // BatchAccuracyForLogits calculates accuracy over batches.
 * //
 * // TODO: it would be nice if this were a method on an object that
 * // implements the ModuleT interface.
 * func BatchAccuracyForLogits(m ModuleT, xs, ys *Tensor, d gotch.Device, batchSize int) (retVal float64) {
*
* var (
* sumAccuracy float64 = 0.0
* sampleCount float64 = 0.0
* )
*
* _ = MustGradSetEnabled(false)
*
* iter2 := MustNewIter2(xs, ys, int64(batchSize))
* for {
* item, ok := iter2.Next()
* if !ok {
* break
* }
*
* size := float64(item.Data.MustSize()[0])
* bImages := item.Data.MustTo(d, true)
* bLabels := item.Label.MustTo(d, true)
*
* logits := m.ForwardT(bImages, false)
* acc := logits.AccuracyForLogits(bLabels)
* sumAccuracy += acc.Values()[0] * size
* sampleCount += size
*
* bImages.MustDrop()
* bLabels.MustDrop()
* acc.MustDrop()
* }
*
* _ = MustGradSetEnabled(true)
*
* return sumAccuracy / sampleCount
*
* }
* */

// Tensor methods for Module and ModuleT:
// ======================================

// Apply forwards the tensor itself through a module.
func (ts *Tensor) Apply(m Module) (retVal *Tensor) {
return m.Forward(ts)
}

// ApplyT forwards the tensor itself through a ModuleT with the given train flag.
func (ts *Tensor) ApplyT(m ModuleT, train bool) (retVal *Tensor) {
return m.ForwardT(ts, train)
}
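
// For example (xs, m, and mt here are hypothetical values; m implements
// Module and mt implements ModuleT):
//
//	out := xs.Apply(m)           // equivalent to m.Forward(xs)
//	outT := xs.ApplyT(mt, true)  // equivalent to mt.ForwardT(xs, true)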

// ApplyOpt forwards the tensor itself through a module if one is given,
// and shallow-copies the tensor otherwise.
func (ts *Tensor) ApplyOpt(opts ...ModuleOption) (retVal *Tensor) {
switch {
case len(opts) > 0:
m := opts[0]()
return m.Forward(ts)
default:
return ts.MustShallowClone()
}
}

// ModuleOption is a functional option that provides a Module to ApplyOpt.
type ModuleOption func() Module

// WithModule wraps a Module as a ModuleOption.
func WithModule(m Module) ModuleOption {
	return func() Module {
		return m
	}
}
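
// For example, ApplyOpt either forwards or shallow-copies (xs and m are
// hypothetical values):
//
//	cloned := xs.ApplyOpt()                  // no option: shallow clone of xs
//	forwarded := xs.ApplyOpt(WithModule(m))  // forwards: m.Forward(xs)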

// ApplyOptT forwards the tensor itself through a ModuleT if one is given,
// and shallow-copies the tensor otherwise.
func (ts *Tensor) ApplyOptT(train bool, opts ...ModuleTOption) (retVal *Tensor) {
switch {
case len(opts) > 0:
m := opts[0]()
return m.ForwardT(ts, train)
default:
return ts.MustShallowClone()
}
}

// ModuleTOption is a functional option that provides a ModuleT to ApplyOptT.
type ModuleTOption func() ModuleT

// WithModuleT wraps a ModuleT as a ModuleTOption.
func WithModuleT(m ModuleT) ModuleTOption {
	return func() ModuleT {
		return m
	}
}
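
// For example, ApplyOptT threads the train flag through to the optional
// ModuleT (xs and mt are hypothetical values):
//
//	trainOut := xs.ApplyOptT(true, WithModuleT(mt))  // mt.ForwardT(xs, true)
//	evalOut := xs.ApplyOptT(false, WithModuleT(mt))  // mt.ForwardT(xs, false)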