This commit is contained in:
sugarme 2023-07-07 13:20:51 +10:00
parent 34e87b1302
commit f9cb2f5cc6
2 changed files with 20 additions and 1 deletions

View File

@ -1,7 +1,6 @@
package nn
import (
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/ts"
)

View File

@ -482,6 +482,11 @@ func (vs *VarStore) ToBFloat16() {
vs.Root().ToBFloat16()
}
func (vs *VarStore) ToDevice(device gotch.Device) {
p := vs.Root()
p.ToDevice(device)
}
// Path methods:
// =============
@ -745,6 +750,21 @@ func (p *Path) ToBFloat16() {
p.toFloat(gotch.BFloat16)
}
func (p *Path) ToDevice(device gotch.Device) {
p.varstore.Lock()
defer p.varstore.Unlock()
path := strings.Join(p.path, SEP)
for name, v := range p.varstore.vars {
if strings.Contains(name, path) {
newVar := v
newVar.Tensor = v.Tensor.MustTo(device, true)
p.varstore.vars[name] = newVar
}
}
ts.CleanUp(2000)
}
// ZerosNoTrain creates a new variable initialized with zeros.
//
// The new variable is named according to the name parameter and