Clean up
This commit is contained in:
parent
34e87b1302
commit
f9cb2f5cc6
|
@ -1,7 +1,6 @@
|
|||
package nn
|
||||
|
||||
import (
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
|
|
|
@ -482,6 +482,11 @@ func (vs *VarStore) ToBFloat16() {
|
|||
vs.Root().ToBFloat16()
|
||||
}
|
||||
|
||||
func (vs *VarStore) ToDevice(device gotch.Device) {
|
||||
p := vs.Root()
|
||||
p.ToDevice(device)
|
||||
}
|
||||
|
||||
// Path methods:
|
||||
// =============
|
||||
|
||||
|
@ -745,6 +750,21 @@ func (p *Path) ToBFloat16() {
|
|||
p.toFloat(gotch.BFloat16)
|
||||
}
|
||||
|
||||
func (p *Path) ToDevice(device gotch.Device) {
|
||||
p.varstore.Lock()
|
||||
defer p.varstore.Unlock()
|
||||
path := strings.Join(p.path, SEP)
|
||||
for name, v := range p.varstore.vars {
|
||||
if strings.Contains(name, path) {
|
||||
newVar := v
|
||||
newVar.Tensor = v.Tensor.MustTo(device, true)
|
||||
p.varstore.vars[name] = newVar
|
||||
}
|
||||
}
|
||||
|
||||
ts.CleanUp(2000)
|
||||
}
|
||||
|
||||
// ZerosNoTrain creates a new variable initialized with zeros.
|
||||
//
|
||||
// The new variable is named according to the name parameter and
|
||||
|
|
Loading…
Reference in New Issue
Block a user