added API Path.Remove()
This commit is contained in:
parent
10cf9ff568
commit
664928551b
|
@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
|||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
- Added API `Path.Remove()`; `Path.MustRemove()`
|
||||
|
||||
## [Nofix]
|
||||
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.
|
||||
|
|
|
@ -606,6 +606,30 @@ func (p *Path) MustAdd(name string, x *ts.Tensor, trainable bool, opts ...AddOpt
|
|||
return x
|
||||
}
|
||||
|
||||
// Remove removes a variable from `VarStore`
|
||||
func (p *Path) Remove(name string) error {
|
||||
p.varstore.Lock()
|
||||
defer p.varstore.Unlock()
|
||||
|
||||
_, ok := p.varstore.vars[name]
|
||||
if !ok {
|
||||
err := fmt.Errorf("Path.Remove() failed: cannot find a variable with name %q in VarStore.", name)
|
||||
return err
|
||||
}
|
||||
|
||||
delete(p.varstore.vars, name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MustRemove removes a variable from `VarStore`
|
||||
func (p *Path) MustRemove(name string) {
|
||||
err := p.Remove(name)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Path.MustRemove() failed: %w", err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Path) getOrAddWithLock(name string, tensor *ts.Tensor, trainable bool, opts ...AddOpt) (*ts.Tensor, error) {
|
||||
path := p.getpath(name)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user