added API Path.Remove()

This commit is contained in:
sugarme 2022-03-12 21:57:23 +11:00
parent 10cf9ff568
commit 664928551b
2 changed files with 25 additions and 0 deletions

View File

@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
- Added API `Path.Remove()`; `Path.MustRemove()`
## [Nofix]
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.

View File

@ -606,6 +606,30 @@ func (p *Path) MustAdd(name string, x *ts.Tensor, trainable bool, opts ...AddOpt
return x
}
// Remove removes a variable from `VarStore`
func (p *Path) Remove(name string) error {
p.varstore.Lock()
defer p.varstore.Unlock()
_, ok := p.varstore.vars[name]
if !ok {
err := fmt.Errorf("Path.Remove() failed: cannot find a variable with name %q in VarStore.", name)
return err
}
delete(p.varstore.vars, name)
return nil
}
// MustRemove removes a variable from `VarStore`
func (p *Path) MustRemove(name string) {
err := p.Remove(name)
if err != nil {
err = fmt.Errorf("Path.MustRemove() failed: %w", err)
log.Fatal(err)
}
}
func (p *Path) getOrAddWithLock(name string, tensor *ts.Tensor, trainable bool, opts ...AddOpt) (*ts.Tensor, error) {
path := p.getpath(name)