package my_nn

import (
	"fmt"
	"reflect"
	"sort"
	"strings"
	"sync"

	"github.com/charmbracelet/log"

	"git.andr3h3nriqu3s.com/andr3/gotch"
	"git.andr3h3nriqu3s.com/andr3/gotch/nn"
	"git.andr3h3nriqu3s.com/andr3/gotch/ts"
)

// SEP is the separator between path elements in tensor names.
const SEP = "."

type Var struct {
	Tensor     *ts.Tensor
	Group      uint   // optimizer parameter group
	Type       string // either "parameter" or "buffer"
	Trainable  bool   // marks whether this variable is trainable. For the "buffer" type it is always `false`.
	Persistent bool   // only applies to the "buffer" type; "parameter" variables are always saved by VarStore.Save().
}

// VarStore is used to store variables used by one or multiple layers.
// It specifies a SINGLE device where all variables are stored.
type VarStore struct {
	sync.Mutex
	device gotch.Device
	vars   map[string]Var
}

// Path is a variable store with an associated path used for variable naming.
type Path struct {
	path     []string
	varstore *VarStore
	group    uint // optimizer parameter group
}

// Entry holds an entry corresponding to a given name in Path.
type Entry struct {
	name string
	path *Path
}

// NewVarStore creates a new variable store located on the specified device.
func NewVarStore(device gotch.Device) *VarStore {
	return &VarStore{
		device: device,
		vars:   make(map[string]Var, 0),
	}
}

// NOTE:
// To get (initiate) a path, call vs.Root()

// VarStore methods:
// =================

// Device returns the device for this VarStore.
func (vs *VarStore) Device() gotch.Device {
	return vs.device
}

// Len returns the number of tensors currently kept in this VarStore.
func (vs *VarStore) Len() int {
	vs.Lock()
	defer vs.Unlock()
	return len(vs.vars)
}

// IsEmpty returns true if no tensors are currently kept in this VarStore.
func (vs *VarStore) IsEmpty() bool {
	vs.Lock()
	defer vs.Unlock()
	return len(vs.vars) == 0
}

// TrainableVariables returns references to all trainable variables kept in this VarStore.
func (vs *VarStore) TrainableVariables() []*ts.Tensor {
	vs.Lock()
	defer vs.Unlock()

	var trainables []*ts.Tensor
	for _, v := range vs.vars {
		x := v.Tensor
		if x.MustRequiresGrad() {
			trainables = append(trainables, x)
		}
	}

	return trainables
}

// GetVarOfName returns the variable stored under the given (full) name.
//
// NOTE: callers are responsible for synchronization; this accessor does not
// take the VarStore lock.
func (vs *VarStore) GetVarOfName(name string) Var {
	return vs.vars[name]
}

// GetTensorOfVar returns the tensor of the variable stored under the given (full) name.
func (vs *VarStore) GetTensorOfVar(name string) *ts.Tensor {
	return vs.vars[name].Tensor
}

// UpdateVarTensor replaces the tensor of the named variable, optionally
// dropping the old tensor first.
func (vs *VarStore) UpdateVarTensor(name string, tensor *ts.Tensor, del bool) {
	if del {
		or_panic(vs.vars[name].Tensor.Drop())
	}
	v := vs.vars[name]
	v.Tensor = tensor
	vs.vars[name] = v
}

// Variables returns references to all variables and their names in a map[variable_name]Tensor.
//
// NOTE: the returned map includes variables of both "parameter" and "buffer" type.
func (vs *VarStore) Variables() map[string]ts.Tensor {
	vs.Lock()
	defer vs.Unlock()

	namedTensors := make(map[string]ts.Tensor, 0)
	for k, v := range vs.vars {
		namedTensors[k] = *v.Tensor
	}

	return namedTensors
}

// Root gets the root path for this VarStore.
//
// NOTE: Variables are named and organized using paths. This function returns
// the top-level path for the var store; sub-paths created from it (see
// Path.Sub) name variables hierarchically.
func (vs *VarStore) Root() *Path {
	return &Path{
		path:     []string{},
		varstore: vs,
		group:    0,
	}
}
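// Usage sketch (illustrative only, not part of the API; assumes a CPU device
// and an initially empty store):
//
//	vs := NewVarStore(gotch.CPU)
//	root := vs.Root()
//	w := root.MustZeros("weight", []int64{4, 3}) // stored under the name "weight"
//	fmt.Println(vs.Len())                        // 1
//	fmt.Println(w.MustSize())                    // [4 3]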
// Save saves the VarStore variable values to a file.
//
// NOTE: Weight values for all the tensors currently stored in the
// var store get saved to the given file.
func (vs *VarStore) Save(filepath string) error {
	vs.Lock()
	defer vs.Unlock()

	var namedTensors []ts.NamedTensor
	for k, v := range vs.vars {
		if v.Type == "parameter" || (v.Type == "buffer" && v.Persistent) {
			namedTensors = append(namedTensors, ts.NamedTensor{
				Name:   k,
				Tensor: v.Tensor,
			})
		}
	}

	// return ts.SaveMulti(namedTensors, filepath)
	return ts.SaveMultiNew(namedTensors, filepath)
}

// Load loads VarStore variable values from a file.
//
// NOTE: Weight values for all the tensors currently stored in the
// VarStore get loaded from the given file. The set of variables stored in
// the VarStore is not changed, only the values of those tensors are
// modified. An error is returned if a variable kept in the VarStore cannot
// be found among the loaded named tensors.
func (vs *VarStore) Load(filepath string) error {
	namedTensors, err := ts.LoadMultiWithDevice(filepath, vs.device)
	if err != nil {
		return err
	}

	namedTensorsMap := make(map[string]*ts.Tensor, 0)
	for _, namedTensor := range namedTensors {
		namedTensorsMap[namedTensor.Name] = namedTensor.Tensor
	}

	// Match and in-place copy values (update) from the newly loaded tensors
	// to the existing named tensors. Return an error on a missing name or a
	// shape mismatch.
	vs.Lock()
	defer vs.Unlock()

	for name, v := range vs.vars {
		// missing variable
		currTs, ok := namedTensorsMap[name]
		if !ok {
			return fmt.Errorf("VarStore.Load() failed: there's a tensor with name %q in VarStore, but it was not found in the loaded weights", name)
		}

		// mismatched shape
		sourceShape := currTs.MustSize()
		destShape := v.Tensor.MustSize()
		if !reflect.DeepEqual(destShape, sourceShape) {
			return fmt.Errorf("mismatched shape for variable name %v - at store: %v - at source: %v", name, destShape, sourceShape)
		}

		ts.NoGrad(func() {
			v.Tensor.Copy_(currTs)
		})
	}

	for _, x := range namedTensors {
		x.Tensor.MustDrop()
	}

	ts.CleanUp()

	return nil
}

// LoadWeights loads pretrained weights into the VarStore.
func (vs *VarStore) LoadWeights(namedTensors []ts.NamedTensor) error {
	namedTensorsMap := make(map[string]*ts.Tensor, 0)
	for _, namedTensor := range namedTensors {
		namedTensorsMap[namedTensor.Name] = namedTensor.Tensor
	}

	// Match and in-place copy values (update) from the newly loaded tensors
	// to the existing named tensors. Return an error on a missing name or a
	// shape mismatch.
	vs.Lock()
	defer vs.Unlock()

	for name, v := range vs.vars {
		// missing variable
		currTs, ok := namedTensorsMap[name]
		if !ok {
			return fmt.Errorf("VarStore.LoadWeights() failed: there's a tensor with name %q in VarStore, but it was not found in the loaded weights", name)
		}

		// mismatched shape
		sourceShape := currTs.MustSize()
		destShape := v.Tensor.MustSize()
		if !reflect.DeepEqual(destShape, sourceShape) {
			return fmt.Errorf("VarStore.LoadWeights() failed: mismatched shape for variable name %v - at store: %v - at source: %v", name, destShape, sourceShape)
		}

		ts.NoGrad(func() {
			v.Tensor.Copy_(currTs)
		})
	}

	ts.CleanUp()

	return nil
}
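// Round-trip sketch (illustrative only; the checkpoint file name is
// hypothetical). Load requires the receiving store to already hold variables
// with matching names and shapes:
//
//	vs := NewVarStore(gotch.CPU)
//	vs.Root().MustZeros("weight", []int64{4, 3})
//	if err := vs.Save("model.gt"); err != nil {
//		log.Fatal(err)
//	}
//	// ...later, with a VarStore holding the same variable names/shapes:
//	if err := vs.Load("model.gt"); err != nil {
//		log.Fatal(err)
//	}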
// Returns a string slice containing the names of the missing variables.
func (vs *VarStore) LoadPartial(filepath string) ([]string, error) {
	namedTensors, err := ts.LoadMultiWithDevice(filepath, vs.device)
	if err != nil {
		return nil, err
	}

	namedTensorsMap := make(map[string]*ts.Tensor, 0)
	for _, namedTensor := range namedTensors {
		namedTensorsMap[namedTensor.Name] = namedTensor.Tensor
	}

	var missingVariables []string

	// Match and in-place copy values (update) from the newly loaded tensors
	// to the existing named tensors. Collect missing or mismatched names.
	vs.Lock()
	defer vs.Unlock()

	for name, v := range vs.vars {
		var currTs *ts.Tensor
		var ok bool

		// missing variable
		if currTs, ok = namedTensorsMap[name]; !ok {
			missingVariables = append(missingVariables, name)
			continue
		}

		// mismatched shape
		sourceShape := currTs.MustSize()
		destShape := v.Tensor.MustSize()
		if !reflect.DeepEqual(destShape, sourceShape) {
			fmt.Printf("WARNING: Mismatched shape for variable name: %v - at store: %v - at source: %v. Skipping this weight...\n", name, destShape, sourceShape)
			missingVariables = append(missingVariables, name)
			continue
		}

		ts.NoGrad(func() {
			v.Tensor.Copy_(currTs)
		})
	}

	for _, x := range namedTensors {
		x.Tensor.MustDrop()
	}

	ts.CleanUp()

	return missingVariables, nil
}

// LoadWeightsPartial loads VarStore variable values from the given named
// tensors where names match.
//
// Weight values for the tensors present in both the var store and the given
// set get loaded. If a variable in the var store is not present in the given
// set, it is skipped and its values are not updated. This method should be
// used if pretrained weights are only available for parts of the model.
// Note that the set of variables stored in the var store is not changed, only
// the values of those tensors are modified.
//
// Returns a string slice containing the names of the missing variables.
func (vs *VarStore) LoadWeightsPartial(namedTensors []ts.NamedTensor) ([]string, error) {
	namedTensorsMap := make(map[string]*ts.Tensor, 0)
	for _, namedTensor := range namedTensors {
		namedTensorsMap[namedTensor.Name] = namedTensor.Tensor
	}

	var missingVariables []string

	// Match and in-place copy values (update) from the newly loaded tensors
	// to the existing named tensors. Collect missing or mismatched names.
	vs.Lock()
	defer vs.Unlock()

	for name, v := range vs.vars {
		var currTs *ts.Tensor
		var ok bool

		// missing variable
		if currTs, ok = namedTensorsMap[name]; !ok {
			missingVariables = append(missingVariables, name)
			continue
		}

		// mismatched shape
		sourceShape := currTs.MustSize()
		destShape := v.Tensor.MustSize()
		if !reflect.DeepEqual(destShape, sourceShape) {
			fmt.Printf("WARNING: Mismatched shape for variable name: %v - at store: %v - at source: %v. Skipping this weight...\n", name, destShape, sourceShape)
			missingVariables = append(missingVariables, name)
			continue
		}

		ts.NoGrad(func() {
			v.Tensor.Copy_(currTs)
		})
	}

	ts.CleanUp()

	return missingVariables, nil
}

// Freeze freezes this VarStore.
//
// Gradients for the variables in this store are no longer tracked.
func (vs *VarStore) Freeze() error {
	vs.Lock()
	defer vs.Unlock()

	for name, v := range vs.vars {
		err := v.Tensor.RequiresGrad_(false)
		if err != nil {
			return fmt.Errorf("VarStore.Freeze() failed to set 'requiresGrad' for tensor %q: %w", name, err)
		}
	}

	return nil
}
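// Partial-loading sketch (illustrative only; "pretrained.gt" is a
// hypothetical checkpoint covering only part of the model):
//
//	missing, err := vs.LoadPartial("pretrained.gt")
//	if err != nil {
//		log.Fatal(err)
//	}
//	for _, name := range missing {
//		log.Printf("not loaded (kept initial value): %s", name)
//	}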
// Unfreeze unfreezes a VarStore.
//
// Gradients for the variables in this store are tracked again.
func (vs *VarStore) Unfreeze() error {
	vs.Lock()
	defer vs.Unlock()

	for name, v := range vs.vars {
		if v.Type == "parameter" && v.Trainable {
			err := v.Tensor.RequiresGrad_(true)
			if err != nil {
				return fmt.Errorf("VarStore.Unfreeze() failed to set 'requiresGrad' for tensor %q: %w", name, err)
			}
		}
	}

	return nil
}

// Copy copies variable values from a source VarStore to this VarStore.
//
// All the variables in this var store have to exist with the same
// name in the source var store, otherwise an error is returned.
func (vs *VarStore) Copy(src *VarStore) error {
	vs.Lock()
	defer vs.Unlock()
	src.Lock()
	defer src.Unlock()

	srcVars := src.vars
	device := vs.device

	for k := range vs.vars {
		if _, ok := srcVars[k]; !ok {
			return fmt.Errorf("VarStore.Copy() failed: cannot find %q in the source VarStore", k)
		}
	}

	for k, v := range vs.vars {
		srcV := srcVars[k]
		srcDevTs, err := srcV.Tensor.To(device, false)
		if err != nil {
			return err
		}
		ts.NoGrad(func() {
			v.Tensor.Copy_(srcDevTs)
		})
		srcDevTs.MustDrop()
	}

	ts.CleanUp()

	return nil
}

// Summary prints a simple list of all named variables with their shapes.
func (vs *VarStore) Summary() {
	vars := vs.vars
	layers := make([]string, 0, len(vars))
	for name := range vars {
		layers = append(layers, name)
	}
	sort.Strings(layers)

	var dtype gotch.DType
	isFirst := true
	for _, l := range layers {
		v := vars[l]
		x := v.Tensor
		// Report the DType of the first tensor only, as a representative.
		if isFirst {
			dtype = x.DType()
			isFirst = false
		}
		if v.Type == "buffer" {
			fmt.Printf("%s - [buffer] - %+v\n", l, x.MustSize())
		} else {
			fmt.Printf("%s - %+v\n", l, x.MustSize())
		}
	}

	fmt.Printf("Num of layers: %v\n", len(vars))
	fmt.Printf("DType: %v\n", dtype)
}

// Destroy drops all tensors kept in the VarStore and removes them from the store.
func (vs *VarStore) Destroy() {
	vs.Lock()
	defer vs.Unlock()
	for n, v := range vs.vars {
		v.Tensor.MustDrop()
		delete(vs.vars, n)
	}
}

// ToDType casts all variables in the VarStore to the specified DType.
//
// NOTE: only float-like types (Half, BFloat16, Float, Double) are guaranteed convertible.
func (vs *VarStore) ToDType(dtype gotch.DType) {
	vs.Root().ToDType(dtype)
}

// ToFloat casts all float-like variables in the VarStore to the `Float` dtype.
//
// NOTE: float-like includes the `Half`, `BFloat16`, `Float` and `Double` dtypes.
func (vs *VarStore) ToFloat() {
	vs.Root().ToFloat()
}

// ToDouble casts all float-like variables in the VarStore to the `Double` dtype.
func (vs *VarStore) ToDouble() {
	vs.Root().ToDouble()
}

// ToHalf casts all float-like variables in the VarStore to the `Half` dtype.
func (vs *VarStore) ToHalf() {
	vs.Root().ToHalf()
}

// ToBFloat16 casts all float-like variables in the VarStore to the `BFloat16` dtype.
func (vs *VarStore) ToBFloat16() {
	vs.Root().ToBFloat16()
}

// ToDevice moves all variables in the VarStore to the given device.
func (vs *VarStore) ToDevice(device gotch.Device) {
	p := vs.Root()
	p.ToDevice(device)
}

// Path methods:
// =============

// Sub gets a sub-path of the given path.
func (p *Path) Sub(str string) *Path {
	if strings.Contains(str, SEP) {
		log.Fatalf("Path.Sub() failed: name cannot contain %v (%v)\n", SEP, str)
	}

	// Copy the parent path before appending so sibling sub-paths never share
	// a backing array (appending to a shared slice could let one sub-path
	// overwrite another's elements).
	path := make([]string, len(p.path), len(p.path)+1)
	copy(path, p.path)
	path = append(path, str)
	return &Path{
		path:     path,
		varstore: p.varstore,
		group:    p.group,
	}
}
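// Naming sketch (illustrative only): variables are named by joining path
// components with SEP, so nested sub-paths produce dotted names:
//
//	vs := NewVarStore(gotch.CPU)
//	layer1 := vs.Root().Sub("encoder").Sub("layer1")
//	layer1.MustZeros("weight", []int64{8, 8}) // stored as "encoder.layer1.weight"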
// Paths returns the components of the current path.
func (p *Path) Paths() []string {
	return p.path
}

// Device gets the device where the VarStore variables are stored.
func (p *Path) Device() gotch.Device {
	return p.varstore.device
}

// NOTE: cannot be named `path` as there is already a field with that name.
func (p *Path) getpath(name string) string {
	if strings.Contains(name, SEP) {
		log.Fatalf("Sub name cannot contain %v (%v)\n", SEP, name)
	}

	if len(p.path) == 0 {
		return name
	}

	return fmt.Sprintf("%v%v%v", strings.Join(p.path, SEP), SEP, name)
}

func (p *Path) addNamed(name string, newTs *ts.Tensor, trainable bool, varType string, persistent bool) (*ts.Tensor, string, error) {
	path := p.getpath(name)

	p.varstore.Lock()
	defer p.varstore.Unlock()

	// Disambiguate duplicated names with a numeric suffix.
	if _, ok := p.varstore.vars[path]; ok {
		path = fmt.Sprintf("%v__%v", path, len(p.varstore.vars))
	}

	var (
		tensor *ts.Tensor
		err    error
	)
	if trainable {
		tensor, err = newTs.SetRequiresGrad(true, false)
		if err != nil {
			return nil, path, err
		}
	} else {
		tensor, err = newTs.ShallowClone()
		if err != nil {
			return nil, path, err
		}
	}

	v := Var{
		Tensor:     tensor,
		Group:      p.group,
		Trainable:  trainable,
		Type:       varType,
		Persistent: persistent,
	}
	p.varstore.vars[path] = v

	return tensor, path, nil
}

func (p *Path) add(name string, newTs *ts.Tensor, trainable bool, varType string, persistent bool) (*ts.Tensor, error) {
	// NOTE: do not name this variable `ts`; that would shadow the imported `ts` package.
	tensor, _, err := p.addNamed(name, newTs, trainable, varType, persistent)
	return tensor, err
}

type AddOpts struct {
	VarType    string
	Persistent bool
}

type AddOpt func(*AddOpts)

func defaultAddOpts() *AddOpts {
	return &AddOpts{
		VarType:    "parameter",
		Persistent: true,
	}
}

func WithVarType(v string) AddOpt {
	if v != "parameter" && v != "buffer" {
		log.Fatalf("WithVarType() failed: invalid variable type. Input must be either 'parameter' or 'buffer'.")
	}

	return func(o *AddOpts) {
		o.VarType = v
	}
}

func WithPersistent(v bool) AddOpt {
	return func(o *AddOpts) {
		o.Persistent = v
	}
}

// Add adds a tensor to a given path.
//
// Args
// - name: intended name of the variable in the VarStore (if duplicated, a numeric suffix is appended)
// - x: tensor holding the values to keep in the VarStore
// - trainable: marks whether the tensor is trainable
// - o.VarType: variable type, i.e., either "parameter" or "buffer"
// - o.Persistent: whether to save this variable when `VarStore.Save()` is called. Only applies to the "buffer" type.
//
// Returns a reference to the tensor stored in the VarStore and an error if one occurred.
func (p *Path) Add(name string, x *ts.Tensor, trainable bool, opts ...AddOpt) (*ts.Tensor, error) {
	o := defaultAddOpts()
	for _, opt := range opts {
		opt(o)
	}

	return p.add(name, x, trainable, o.VarType, o.Persistent)
}

// AddNamed adds a tensor to a given path and also returns the full name it was stored under.
func (p *Path) AddNamed(name string, x *ts.Tensor, trainable bool, opts ...AddOpt) (*ts.Tensor, string, error) {
	o := defaultAddOpts()
	for _, opt := range opts {
		opt(o)
	}

	return p.addNamed(name, x, trainable, o.VarType, o.Persistent)
}

// MustAdd adds a tensor to a given path. It panics if an error occurs.
//
// Args
// - name: intended name of the variable in the VarStore (if duplicated, a numeric suffix is appended)
// - x: tensor holding the values to keep in the VarStore
// - trainable: marks whether the tensor is trainable
// - o.VarType: variable type, i.e., either "parameter" or "buffer"
// - o.Persistent: whether to save this variable when `VarStore.Save()` is called. Only applies to the "buffer" type.
//
// Returns a reference to the tensor stored in the VarStore.
func (p *Path) MustAdd(name string, x *ts.Tensor, trainable bool, opts ...AddOpt) *ts.Tensor {
	x, err := p.Add(name, x, trainable, opts...)
	if err != nil {
		log.Fatal(err)
	}
	return x
}
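// Buffer sketch (illustrative only): registering a non-trainable buffer that
// is excluded from VarStore.Save(). The store keeps its own handle, so the
// source tensor can be dropped afterwards (same pattern as ZerosNoTrain):
//
//	runningMean := ts.MustZeros([]int64{16}, gotch.DefaultDType, vs.Device())
//	buf := vs.Root().MustAdd("running_mean", runningMean, false,
//		WithVarType("buffer"), WithPersistent(false))
//	_ = buf
//	runningMean.MustDrop()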
// Remove removes a variable from the VarStore.
//
// NOTE: `name` is the full variable name, including any path prefix.
func (p *Path) Remove(name string) error {
	p.varstore.Lock()
	defer p.varstore.Unlock()

	_, ok := p.varstore.vars[name]
	if !ok {
		return fmt.Errorf("Path.Remove() failed: cannot find a variable with name %q in VarStore", name)
	}

	delete(p.varstore.vars, name)
	return nil
}

// MustRemove removes a variable from the VarStore. It panics if an error occurs.
func (p *Path) MustRemove(name string) {
	err := p.Remove(name)
	if err != nil {
		log.Fatal(fmt.Errorf("Path.MustRemove() failed: %w", err))
	}
}

func (p *Path) getOrAddWithLock(name string, tensor *ts.Tensor, trainable bool, opts ...AddOpt) (*ts.Tensor, error) {
	path := p.getpath(name)

	// If found, return it. The lookup takes the lock; Add locks internally,
	// so the lock must be released before calling it.
	p.varstore.Lock()
	v, ok := p.varstore.vars[path]
	p.varstore.Unlock()
	if ok {
		return v.Tensor, nil
	}

	// not found, add it
	return p.Add(name, tensor, trainable, opts...)
}

// SetGroup sets the optimizer parameter group for variables at this path.
func (p *Path) SetGroup(g uint) {
	p.varstore.Lock()
	defer p.varstore.Unlock()

	// TODO: set group for individual variables.
	// TBD: variables of the current path only, or all sub-paths as well?
	// For now, just set the group for all variables at this exact path.
	path := strings.Join(p.path, SEP)
	for name, v := range p.varstore.vars {
		vpaths := strings.Split(name, SEP)
		vpath := strings.Join(vpaths[:len(vpaths)-1], SEP)
		if vpath == path {
			v.Group = g
			p.varstore.vars[name] = v
		}
	}
	p.group = g
}

// isUnderPath reports whether a fully qualified variable name belongs to the
// given path prefix or one of its sub-paths. An empty prefix (the root path)
// matches every name.
func isUnderPath(name, path string) bool {
	if path == "" {
		return true
	}
	return name == path || strings.HasPrefix(name, path+SEP)
}

// ToDType casts all variables in this path and its sub-paths to the specified dtype.
//
// NOTE: this method should be used for floating-point conversion, i.e.,
// "gotch.Float", "gotch.Half", "gotch.BFloat16", "gotch.Double".
func (p *Path) ToDType(dtype gotch.DType) {
	p.varstore.Lock()
	defer p.varstore.Unlock()

	path := strings.Join(p.path, SEP)
	for name, v := range p.varstore.vars {
		if isUnderPath(name, path) {
			newVar := v
			newVar.Tensor = v.Tensor.MustTotype(dtype, true)
			p.varstore.vars[name] = newVar
		}
	}
}

// toFloat casts all float-like variables in the current path and its sub-paths
// to the specified dtype.
func (p *Path) toFloat(dtype gotch.DType) {
	p.varstore.Lock()
	defer p.varstore.Unlock()

	path := strings.Join(p.path, SEP)
	for name, v := range p.varstore.vars {
		if isUnderPath(name, path) {
			// Only convert tensors that already hold a float-like dtype, and
			// convert them to the target dtype given as parameter.
			if gotch.IsFloatDType(v.Tensor.DType()) {
				newVar := v
				newVar.Tensor = v.Tensor.MustTotype(dtype, true)
				p.varstore.vars[name] = newVar
			}
		}
	}

	ts.CleanUp()
}

// ToFloat casts all float-like variables in the current path and sub-paths to `Float` precision.
func (p *Path) ToFloat(floatDTypeOpt ...gotch.DType) {
	dtype := gotch.Float
	if len(floatDTypeOpt) > 0 {
		dt := floatDTypeOpt[0]
		if !gotch.IsFloatDType(dt) {
			// Ignore the option
			if gotch.Debug {
				log.Printf("WARNING: nn.Path.ToFloat() input dtype is an invalid float DType %v. Ignoring...\n", dt)
			}
		} else {
			dtype = dt
		}
	}

	p.toFloat(dtype)
}

// ToDouble casts all float-like variables in the current path and sub-paths to `Double` precision.
func (p *Path) ToDouble() {
	p.toFloat(gotch.Double)
}

// ToHalf casts all float-like variables in the current path and sub-paths to `Half` precision.
func (p *Path) ToHalf() {
	p.toFloat(gotch.Half)
}

// ToBFloat16 casts all float-like variables in the current path and sub-paths to the `BFloat16` dtype.
func (p *Path) ToBFloat16() {
	p.toFloat(gotch.BFloat16)
}
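// Grouping sketch (illustrative only): optimizer parameter groups can be
// assigned per path, e.g. to give a head a different learning rate, and
// dtype casts can be scoped to a sub-path:
//
//	head := vs.Root().Sub("head")
//	head.SetGroup(1)                   // variables added under "head" get group 1
//	head.MustZeros("bias", []int64{10})
//	vs.Root().Sub("backbone").ToHalf() // cast only backbone weights to Half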
// ToDevice moves all variables in the current path and its sub-paths to the
// given device, preserving their requires_grad and retains_grad states.
func (p *Path) ToDevice(device gotch.Device) {
	var err error
	p.varstore.Lock()
	defer p.varstore.Unlock()

	path := strings.Join(p.path, SEP)
	for name, v := range p.varstore.vars {
		if isUnderPath(name, path) {
			newVar := v

			var requires bool
			requires, err = v.Tensor.RequiresGrad()
			or_panic(err)

			var retain bool
			retain, err = v.Tensor.RetainsGrad(false)
			or_panic(err)

			newVar.Tensor, err = v.Tensor.To(device, true)
			or_panic(err)

			newVar.Tensor, err = newVar.Tensor.SetRequiresGrad(requires, true)
			or_panic(err)

			if retain {
				err = newVar.Tensor.RetainGrad(false)
				or_panic(err)
			}

			p.varstore.vars[name] = newVar
		}
	}

	ts.CleanUp()
}

// ZerosNoTrain creates a new variable initialized with zeros.
//
// The new variable is named according to the name parameter and
// has the specified shape. The variable will not be trainable, so
// gradients will not be tracked.
// The variable uses a float tensor initialized with zeros.
func (p *Path) ZerosNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
	device := p.Device()
	dtype := gotch.DefaultDType
	z, err := ts.Zeros(dims, dtype, device)
	if err != nil {
		return nil, fmt.Errorf("Path.ZerosNoTrain() failed: %w", err)
	}

	out, err := p.Add(name, z, false, opts...)
	if err != nil {
		return nil, err
	}
	z.MustDrop()

	return out, nil
}

// MustZerosNoTrain creates a new variable initialized with zeros. It panics
// if an error occurs.
//
// The new variable is named according to the name parameter and
// has the specified shape. The variable will not be trainable, so
// gradients will not be tracked.
// The variable uses a float tensor initialized with zeros.
func (p *Path) MustZerosNoTrain(name string, dims []int64, opts ...AddOpt) *ts.Tensor {
	x, err := p.ZerosNoTrain(name, dims, opts...)
	if err != nil {
		log.Fatal(err)
	}
	return x
}

// OnesNoTrain creates a new variable initialized with ones.
//
// The new variable is named according to the name parameter and
// has the specified shape. The variable will not be trainable, so
// gradients will not be tracked.
// The variable uses a float tensor initialized with ones.
func (p *Path) OnesNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
	device := p.Device()
	dtype := gotch.DefaultDType
	z, err := ts.Ones(dims, dtype, device)
	if err != nil {
		return nil, fmt.Errorf("Path.OnesNoTrain() failed: %w", err)
	}

	out, err := p.Add(name, z, false, opts...)
	if err != nil {
		return nil, err
	}
	z.MustDrop()

	return out, nil
}

// MustOnesNoTrain creates a new variable initialized with ones. It panics
// if an error occurs.
//
// The new variable is named according to the name parameter and
// has the specified shape. The variable will not be trainable, so
// gradients will not be tracked.
// The variable uses a float tensor initialized with ones.
func (p *Path) MustOnesNoTrain(name string, dims []int64, opts ...AddOpt) *ts.Tensor {
	x, err := p.OnesNoTrain(name, dims, opts...)
	if err != nil {
		log.Fatal(err)
	}
	return x
}

// NewVarNamed creates a new variable and also returns the full name it was
// stored under.
//
// The new variable is named according to the name parameter and
// has the specified shape. The variable is trainable; its gradient
// will be tracked.
// The variable uses a float tensor initialized as per the
// related init argument.
func (p *Path) NewVarNamed(name string, dims []int64, ini nn.Init, opts ...AddOpt) (*ts.Tensor, string, error) {
	dtype := gotch.DefaultDType
	v := ini.InitTensor(dims, p.varstore.device, dtype)

	out, path, err := p.AddNamed(name, v, true, opts...)
	if err != nil {
		return nil, "", err
	}

	err = v.Drop()
	or_panic(err)

	return out, path, err
}
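// Device-move sketch (illustrative only; assumes upstream gotch's
// gotch.CudaIfAvailable() helper is available in this fork):
//
//	vs := NewVarStore(gotch.CPU)
//	vs.Root().MustZerosNoTrain("mask", []int64{1, 64}) // non-trainable variable
//	vs.ToDevice(gotch.CudaIfAvailable())               // moves every variable, keeping grad flags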
// NewVar creates a new variable.
//
// The new variable is named according to the name parameter and
// has the specified shape. The variable is trainable; its gradient
// will be tracked.
// The variable uses a float tensor initialized as per the
// related init argument.
func (p *Path) NewVar(name string, dims []int64, ini nn.Init, opts ...AddOpt) (*ts.Tensor, error) {
	dtype := gotch.DefaultDType
	v := ini.InitTensor(dims, p.varstore.device, dtype)

	out, err := p.Add(name, v, true, opts...)
	if err != nil {
		return nil, err
	}

	err = v.Drop()
	or_panic(err)

	return out, err
}

// MustNewVar creates a new variable. It panics if an error occurs.
func (p *Path) MustNewVar(name string, dims []int64, ini nn.Init, opts ...AddOpt) *ts.Tensor {
	x, err := p.NewVar(name, dims, ini, opts...)
	if err != nil {
		log.Fatal(err)
	}
	return x
}

// Zeros creates a new variable initialized with zeros.
//
// The new variable is named according to the name parameter and
// has the specified shape. The variable is trainable; its gradient
// will be tracked.
// The variable uses a float tensor initialized with zeros.
func (p *Path) Zeros(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
	return p.NewVar(name, dims, nn.NewConstInit(0.0), opts...)
}

// MustZeros creates a new variable with zero values. It panics if an error occurs.
func (p *Path) MustZeros(name string, dims []int64, opts ...AddOpt) *ts.Tensor {
	x, err := p.Zeros(name, dims, opts...)
	if err != nil {
		log.Fatal(err)
	}
	return x
}

// Ones creates a new variable initialized with ones.
//
// The new variable is named according to the name parameter and
// has the specified shape. The variable is trainable; its gradient
// will be tracked.
// The variable uses a float tensor initialized with ones.
func (p *Path) Ones(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
	return p.NewVar(name, dims, nn.NewConstInit(1.0), opts...)
}

// MustOnes creates a new variable initialized with ones. It panics if an error occurs.
func (p *Path) MustOnes(name string, dims []int64, opts ...AddOpt) *ts.Tensor {
	x, err := p.Ones(name, dims, opts...)
	if err != nil {
		log.Fatal(err)
	}
	return x
}

// RandnStandard creates a new variable initialized randomly from a standard
// normal distribution.
//
// The new variable is named according to the name parameter and
// has the specified shape. The variable is trainable; its gradient
// will be tracked.
// The variable uses a float tensor initialized randomly using a
// standard normal distribution.
func (p *Path) RandnStandard(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
	return p.NewVar(name, dims, nn.NewRandnInit(0.0, 1.0), opts...)
}

// MustRandnStandard creates a new variable initialized randomly from a standard
// normal distribution. It panics if an error occurs.
func (p *Path) MustRandnStandard(name string, dims []int64, opts ...AddOpt) *ts.Tensor {
	x, err := p.RandnStandard(name, dims, opts...)
	if err != nil {
		log.Fatal(err)
	}
	return x
}

// Randn creates a new variable initialized randomly from a normal distribution.
//
// The new variable is named according to the name parameter and
// has the specified shape. The variable is trainable; its gradient
// will be tracked.
// The variable uses a float tensor initialized randomly using a
// normal distribution with the specified mean and standard deviation.
func (p *Path) Randn(name string, dims []int64, mean float64, stdev float64, opts ...AddOpt) (*ts.Tensor, error) {
	return p.NewVar(name, dims, nn.NewRandnInit(mean, stdev), opts...)
}
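// Initialization sketch (illustrative only): typical weight/bias creation for
// a linear layer, using the initializers provided by the gotch nn package:
//
//	p := vs.Root().Sub("fc1")
//	w := p.MustRandn("weight", []int64{128, 64}, 0.0, 0.02) // mean 0.0, stdev 0.02
//	b := p.MustZeros("bias", []int64{128})
//	_, _ = w, b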
// MustRandn creates a new variable initialized randomly from a normal
// distribution. It panics if an error occurs.
func (p *Path) MustRandn(name string, dims []int64, mean float64, stdev float64, opts ...AddOpt) *ts.Tensor {
	x, err := p.Randn(name, dims, mean, stdev, opts...)
	if err != nil {
		log.Fatal(err)
	}
	return x
}

// Uniform creates a new variable initialized randomly from a uniform distribution.
//
// The new variable is named according to the name parameter and
// has the specified shape. The variable is trainable; its gradient
// will be tracked.
// The variable uses a float tensor initialized randomly using a
// uniform distribution between the specified bounds.
func (p *Path) Uniform(name string, dims []int64, lo, up float64, opts ...AddOpt) (*ts.Tensor, error) {
	return p.NewVar(name, dims, nn.NewUniformInit(lo, up), opts...)
}

// MustUniform creates a new variable initialized randomly from a uniform
// distribution. It panics if an error occurs.
func (p *Path) MustUniform(name string, dims []int64, lo, up float64, opts ...AddOpt) *ts.Tensor {
	x, err := p.Uniform(name, dims, lo, up, opts...)
	if err != nil {
		log.Fatal(err)
	}
	return x
}

// KaimingUniform creates a new variable initialized randomly with Kaiming uniform.
//
// The new variable is named according to the name parameter and
// has the specified shape. The variable is trainable; its gradient
// will be tracked.
// The variable uses a float tensor initialized randomly using a
// uniform distribution whose bounds follow Kaiming initialization.
func (p *Path) KaimingUniform(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
	return p.NewVar(name, dims, nn.NewKaimingUniformInit(), opts...)
}

// MustKaimingUniform creates a new variable initialized randomly with Kaiming
// uniform. It panics if an error occurs.
func (p *Path) MustKaimingUniform(name string, dims []int64, opts ...AddOpt) *ts.Tensor {
	x, err := p.KaimingUniform(name, dims, opts...)
	if err != nil {
		log.Fatal(err)
	}
	return x
}

// VarCopy creates a new variable initialized by copying an existing tensor.
//
// The new variable is named according to the name parameter and has the same
// shape as the source tensor. The variable is trainable; its gradient will be
// tracked.
// The variable uses a float tensor initialized by copying the given tensor.
func (p *Path) VarCopy(name string, t *ts.Tensor) (*ts.Tensor, error) {
	size, err := t.Size()
	if err != nil {
		return nil, fmt.Errorf("Path.VarCopy() failed: %w", err)
	}
	v, err := p.Zeros(name, size)
	if err != nil {
		return nil, err
	}

	ts.NoGrad(func() {
		ts.Copy_(v, t)
	})

	return v, nil
}

// MustVarCopy creates a new variable initialized by copying an existing tensor.
// It panics if an error occurs.
func (p *Path) MustVarCopy(name string, t *ts.Tensor) *ts.Tensor {
	x, err := p.VarCopy(name, t)
	if err != nil {
		log.Fatal(err)
	}
	return x
}

// Get gets a reference to the tensor corresponding to a given name, if present.
//
// NOTE: `name` is the full variable name, including any path prefix.
func (p *Path) Get(name string) (*ts.Tensor, error) {
	p.varstore.Lock()
	defer p.varstore.Unlock()

	v, ok := p.varstore.vars[name]
	if !ok {
		return nil, fmt.Errorf("Path.Get() failed: cannot find variable with name: %v", name)
	}

	return v.Tensor, nil
}
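// Copy/lookup sketch (illustrative only): clone an existing tensor into the
// store, then fetch it back by its full name:
//
//	src := ts.MustOnes([]int64{3}, gotch.DefaultDType, vs.Device())
//	vs.Root().MustVarCopy("probe", src)
//	got := vs.Root().MustGet("probe") // full name: no path prefix at root
//	_ = got
//	src.MustDrop()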
// MustGet gets a reference to the tensor corresponding to a given name, if
// present. It panics if an error occurs.
func (p *Path) MustGet(name string) *ts.Tensor {
	x, err := p.Get(name)
	if err != nil {
		log.Fatal(err)
	}
	return x
}

// Entry gets the entry corresponding to a given name for in-place manipulation.
func (p *Path) Entry(name string) *Entry {
	p.varstore.Lock()
	defer p.varstore.Unlock()

	return &Entry{
		name: name,
		path: p,
	}
}

// Entry methods:
// ==============

// OrVar returns the existing entry if found, otherwise creates a new variable.
//
// If this entry name matches the name of a variable stored in the
// var store, the corresponding tensor is returned. Otherwise a new
// variable is added to the var store with the entry name and is
// initialized according to the init parameter.
func (e *Entry) OrVar(dims []int64, init nn.Init, opts ...AddOpt) (*ts.Tensor, error) {
	v := init.InitTensor(dims, e.path.varstore.device)
	out, err := e.path.getOrAddWithLock(e.name, v, true, opts...)
	if err != nil {
		return nil, err
	}
	v.MustDrop()

	return out, nil
}

// MustOrVar returns the existing entry if found, otherwise creates a new
// variable. It panics if an error occurs.
func (e *Entry) MustOrVar(dims []int64, init nn.Init, opts ...AddOpt) *ts.Tensor {
	x, err := e.OrVar(dims, init, opts...)
	if err != nil {
		log.Fatal(err)
	}
	return x
}

// OrZeros returns the existing entry if found, otherwise creates a new variable.
func (e *Entry) OrZeros(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
	return e.OrVar(dims, nn.NewConstInit(0.0), opts...)
}

// MustOrZeros returns the existing entry if found, otherwise creates a new variable.
func (e *Entry) MustOrZeros(dims []int64, opts ...AddOpt) *ts.Tensor {
	x, err := e.OrZeros(dims, opts...)
	if err != nil {
		log.Fatal(err)
	}
	return x
}

// OrVarCopy returns the existing entry if found, otherwise creates a new variable.
func (e *Entry) OrVarCopy(tensor *ts.Tensor) (*ts.Tensor, error) {
	size, err := tensor.Size()
	if err != nil {
		return nil, err
	}
	v, err := e.OrZeros(size)
	if err != nil {
		return nil, err
	}

	ts.NoGrad(func() {
		ts.Copy_(v, tensor)
	})

	return v, nil
}

// MustOrVarCopy returns the existing entry if found, otherwise creates a new variable.
func (e *Entry) MustOrVarCopy(tensor *ts.Tensor) *ts.Tensor {
	x, err := e.OrVarCopy(tensor)
	if err != nil {
		log.Fatal(err)
	}
	return x
}

// OrKaimingUniform returns the existing entry if found, otherwise creates a new variable.
func (e *Entry) OrKaimingUniform(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
	return e.OrVar(dims, nn.NewKaimingUniformInit(), opts...)
}

// MustOrKaimingUniform returns the existing entry if found, otherwise creates a new variable.
func (e *Entry) MustOrKaimingUniform(dims []int64, opts ...AddOpt) *ts.Tensor {
	x, err := e.OrKaimingUniform(dims, opts...)
	if err != nil {
		log.Fatal(err)
	}
	return x
}

// OrOnes returns the existing entry if found, otherwise creates a new variable.
func (e *Entry) OrOnes(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
	return e.OrVar(dims, nn.NewConstInit(1.0), opts...)
}

// MustOrOnes returns the existing entry if found, otherwise creates a new variable.
func (e *Entry) MustOrOnes(dims []int64, opts ...AddOpt) *ts.Tensor {
	x, err := e.OrOnes(dims, opts...)
	if err != nil {
		log.Fatal(err)
	}
	return x
}

// OrOnesNoTrain returns the existing entry if found, otherwise creates a new
// non-trainable variable initialized with ones.
func (e *Entry) OrOnesNoTrain(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
	dtype := gotch.DefaultDType
	o := ts.MustOnes(dims, dtype, e.path.Device())
	out, err := e.path.getOrAddWithLock(e.name, o, false, opts...)
	if err != nil {
		return nil, err
	}
	o.MustDrop()

	return out, nil
}
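// Entry sketch (illustrative only): get-or-create semantics, useful when the
// same code path may run both for fresh models and for restored ones:
//
//	e := vs.Root().Entry("scale")
//	s := e.MustOrOnes([]int64{1})                         // created on first call...
//	s2 := vs.Root().Entry("scale").MustOrOnes([]int64{1}) // ...reused afterwards
//	_, _ = s, s2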
// MustOrOnesNoTrain returns the existing entry if found, otherwise creates a
// new non-trainable variable initialized with ones.
func (e *Entry) MustOrOnesNoTrain(dims []int64, opts ...AddOpt) *ts.Tensor {
	x, err := e.OrOnesNoTrain(dims, opts...)
	if err != nil {
		log.Fatal(err)
	}
	return x
}

// OrRandn returns the existing entry if found, otherwise creates a new variable.
func (e *Entry) OrRandn(dims []int64, mean, stdev float64, opts ...AddOpt) (*ts.Tensor, error) {
	return e.OrVar(dims, nn.NewRandnInit(mean, stdev), opts...)
}

// MustOrRandn returns the existing entry if found, otherwise creates a new variable.
func (e *Entry) MustOrRandn(dims []int64, mean, stdev float64, opts ...AddOpt) *ts.Tensor {
	x, err := e.OrRandn(dims, mean, stdev, opts...)
	if err != nil {
		log.Fatal(err)
	}
	return x
}

// OrRandnStandard returns the existing entry if found, otherwise creates a new variable.
func (e *Entry) OrRandnStandard(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
	return e.OrVar(dims, nn.NewRandnInit(0.0, 1.0), opts...)
}

// MustOrRandnStandard returns the existing entry if found, otherwise creates a new variable.
func (e *Entry) MustOrRandnStandard(dims []int64, opts ...AddOpt) *ts.Tensor {
	x, err := e.OrRandnStandard(dims, opts...)
	if err != nil {
		log.Fatal(err)
	}
	return x
}

// OrUniform returns the existing entry if found, otherwise creates a new variable.
func (e *Entry) OrUniform(dims []int64, lo, up float64, opts ...AddOpt) (*ts.Tensor, error) {
	return e.OrVar(dims, nn.NewUniformInit(lo, up), opts...)
}

// MustOrUniform returns the existing entry if found, otherwise creates a new variable.
func (e *Entry) MustOrUniform(dims []int64, lo, up float64, opts ...AddOpt) *ts.Tensor {
	x, err := e.OrUniform(dims, lo, up, opts...)
	if err != nil {
		log.Fatal(err)
	}
	return x
}

// OrZerosNoTrain returns the existing entry if found, otherwise creates a new
// non-trainable variable initialized with zeros.
func (e *Entry) OrZerosNoTrain(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
	dtype := gotch.DefaultDType
	z := ts.MustZeros(dims, dtype, e.path.Device())
	out, err := e.path.getOrAddWithLock(e.name, z, false, opts...)
	if err != nil {
		return nil, err
	}
	z.MustDrop()

	return out, nil
}

// MustOrZerosNoTrain returns the existing entry if found, otherwise creates a
// new non-trainable variable initialized with zeros.
func (e *Entry) MustOrZerosNoTrain(dims []int64, opts ...AddOpt) *ts.Tensor {
	x, err := e.OrZerosNoTrain(dims, opts...)
	if err != nil {
		log.Fatal(err)
	}
	return x
}
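// End-to-end sketch (illustrative only; `myLinear` and its field names are
// hypothetical, not part of this package):
//
//	type myLinear struct{ w, b *ts.Tensor }
//
//	func newMyLinear(p *Path, in, out int64) *myLinear {
//		return &myLinear{
//			w: p.MustKaimingUniform("weight", []int64{out, in}),
//			b: p.MustZeros("bias", []int64{out}),
//		}
//	}
//
//	vs := NewVarStore(gotch.CPU)
//	_ = newMyLinear(vs.Root().Sub("fc"), 784, 10)
//	vs.Summary() // prints "fc.weight - [10 784]" and "fc.bias - [10]"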