diff --git a/DockerfileServer b/DockerfileServer index 0c06cf5..8817521 100644 --- a/DockerfileServer +++ b/DockerfileServer @@ -2,7 +2,7 @@ FROM docker.io/nvidia/cuda:11.8.0-devel-ubuntu22.04 ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update -RUN apt-get install -y wget sudo pkg-config libopencv-dev unzip python3-pip +RUN apt-get install -y wget sudo pkg-config libopencv-dev unzip python3-pip vim RUN pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 @@ -25,30 +25,36 @@ RUN go install || true WORKDIR /root -RUN wget https://github.com/sugarme/gotch/releases/download/v0.9.0/setup-libtorch.sh +RUN wget https://git.andr3h3nriqu3s.com/andr3/gotch/raw/commit/22e75becf0432cda41a7c055a4d60ea435f76599/setup-libtorch.sh RUN chmod +x setup-libtorch.sh ENV CUDA_VER=11.8 -ENV GOTCH_VER=v0.9.1 +ENV GOTCH_VER=v0.9.2 RUN bash setup-libtorch.sh ENV GOTCH_LIBTORCH="/usr/local/lib/libtorch" + +ENV REFRESH_SETUP=0 + ENV LIBRARY_PATH="$LIBRARY_PATH:$GOTCH_LIBTORCH/lib" ENV export CPATH="$CPATH:$GOTCH_LIBTORCH/lib:$GOTCH_LIBTORCH/include:$GOTCH_LIBTORCH/include/torch/csrc/api/include" ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$GOTCH_LIBTORCH/lib:/usr/lib64-nvidia:/usr/local/cuda-${CUDA_VERSION}/lib64" -RUN wget https://github.com/sugarme/gotch/releases/download/v0.9.0/setup-gotch.sh +RUN wget https://git.andr3h3nriqu3s.com/andr3/gotch/raw/branch/master/setup-gotch.sh RUN chmod +x setup-gotch.sh RUN echo 'root ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers RUN bash setup-gotch.sh RUN ln -s /usr/local/lib/libtorch/include/torch/csrc /usr/local/lib/libtorch/include/torch/csrc/api/include/torch -RUN mkdir -p /go/pkg/mod/github.com/sugarme/gotch@v0.9.1/libtch/libtorch/include/torch/csrc/api -RUN find /usr/local/lib/libtorch/include -maxdepth 1 -type d | tail -n +2 | grep -ve 'torch$' | xargs -I{} ln -s {} /go/pkg/mod/github.com/sugarme/gotch@v0.9.1/libtch/libtorch/include -RUN ln -s /usr/local/lib/libtorch/include/torch/csrc/api/include /go/pkg/mod/github.com/sugarme/gotch@v0.9.1/libtch/libtorch/include/torch/csrc/api/include -RUN find /usr/local/lib/libtorch/include/torch -maxdepth 1 -type f | xargs -I{} ln -s {} /go/pkg/mod/github.com/sugarme/gotch@v0.9.1/libtch/libtorch/include/torch +RUN mkdir -p /go/pkg/mod/git.andr3h3nriqu3s.com/andr3/gotch@v0.9.2/libtch/libtorch/include/torch/csrc/api +RUN find /usr/local/lib/libtorch/include -maxdepth 1 -type d | tail -n +2 | grep -ve 'torch$' | xargs -I{} ln -s {} /go/pkg/mod/git.andr3h3nriqu3s.com/andr3/gotch@v0.9.2/libtch/libtorch/include +RUN ln -s /usr/local/lib/libtorch/include/torch/csrc/api/include /go/pkg/mod/git.andr3h3nriqu3s.com/andr3/gotch@v0.9.2/libtch/libtorch/include/torch/csrc/api/include +RUN find /usr/local/lib/libtorch/include/torch -maxdepth 1 -type f | xargs -I{} ln -s {} /go/pkg/mod/git.andr3h3nriqu3s.com/andr3/gotch@v0.9.2/libtch/libtorch/include/torch RUN ln -s /usr/local/lib/libtorch/lib/libcudnn.so.8 /usr/local/lib/libcudnn.so WORKDIR /app +ENV CGO_CXXFLAGS="-I/usr/local/lib/libtorch/include/torch/csrc/api/include/ -I/usr/local/lib/libtorch/include" +ENV CGO_CFLAGS="-I/usr/local/lib/libtorch/include/torch/csrc/api/include/ -I/usr/local/lib/libtorch/include" + ADD . . 
-RUN go install || true +RUN go build -x || true CMD ["bash", "-c", "go run ."] diff --git a/go.mod b/go.mod index f93209a..2b2c7d7 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/google/uuid v1.6.0 github.com/lib/pq v1.10.9 golang.org/x/crypto v0.19.0 + git.andr3h3nriqu3s.com/andr3/gotch v0.9.2 ) require ( @@ -32,7 +33,6 @@ require ( github.com/muesli/termenv v0.15.2 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/rivo/uniseg v0.4.6 // indirect - github.com/sugarme/gotch v0.9.1 // indirect golang.org/x/exp v0.0.0-20240119083558-1b970713d09a // indirect golang.org/x/net v0.21.0 // indirect golang.org/x/sync v0.1.0 // indirect diff --git a/go.sum b/go.sum index c95a5a6..c732162 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +git.andr3h3nriqu3s.com/andr3/gotch v0.9.1 h1:1q34JKV8cX80n7LXbJswlXCiRtNbzcvJ/vbgb6an1tA= +git.andr3h3nriqu3s.com/andr3/gotch v0.9.1/go.mod h1:FXusE3CHt8NLf5wynUGaHtIbToRuYifsZaC5EZH0pJY= +git.andr3h3nriqu3s.com/andr3/gotch v0.9.2 h1:aZcsPgDVGVhrEFoer0upSkzPqJWNMxdUHRktP4s6MSc= +git.andr3h3nriqu3s.com/andr3/gotch v0.9.2/go.mod h1:FXusE3CHt8NLf5wynUGaHtIbToRuYifsZaC5EZH0pJY= github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8= github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= @@ -70,10 +74,6 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/sugarme/gotch v0.9.1 h1:J6JCE1C2AfPmM1xk0p46LdzWtfNvbvZZnWdkj9v54jo= -github.com/sugarme/gotch v0.9.1/go.mod h1:dien16KQcZPg/g+YiEH3q3ldHlKO2//2I2i2Gp5OQcI= -github.com/wangkuiyi/gotorch v0.0.0-20201028015551-9afed2f3ad7b h1:oJfm5gCGdy9k2Yb+qmMR+HMRQ89CbVDsDi6DD9AZSTk= -github.com/wangkuiyi/gotorch v0.0.0-20201028015551-9afed2f3ad7b/go.mod h1:WC7g+ojb7tPOZhHI2+ZI7ZXTW7uzF9uFOZfZgIX+SjI= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= diff --git a/logic/models/train/torch/modelloader/modelloader.go b/logic/models/train/torch/modelloader/modelloader.go index f5e6e64..bace047 100644 --- a/logic/models/train/torch/modelloader/modelloader.go +++ b/logic/models/train/torch/modelloader/modelloader.go @@ -3,9 +3,9 @@ package imageloader import ( "git.andr3h3nriqu3s.com/andr3/fyp/logic/db" types "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" - "github.com/sugarme/gotch" - torch "github.com/sugarme/gotch/ts" - "github.com/sugarme/gotch/vision" + "git.andr3h3nriqu3s.com/andr3/gotch" + torch "git.andr3h3nriqu3s.com/andr3/gotch/ts" + "git.andr3h3nriqu3s.com/andr3/gotch/vision" ) type Dataset struct { diff --git a/logic/models/train/torch/nn/linear.go b/logic/models/train/torch/nn/linear.go new file mode 100644 index 0000000..c479ae4 --- /dev/null +++ b/logic/models/train/torch/nn/linear.go @@ -0,0 +1,168 @@ +package my_nn + +// linear is a fully-connected layer + +import ( + "math" + + "git.andr3h3nriqu3s.com/andr3/gotch/nn" + "git.andr3h3nriqu3s.com/andr3/gotch/ts" +) + +// LinearConfig is a configuration for a linear layer +type LinearConfig 
struct { + WsInit nn.Init // initial weights + BsInit nn.Init // optional initial bias + Bias bool +} + +// DefaultLinearConfig creates a default LinearConfig with +// weights initialized using KaimingUniform and Bias set to true +func DefaultLinearConfig() *LinearConfig { + negSlope := math.Sqrt(5) + return &LinearConfig{ + // NOTE. KaimingUniform causes a mem leak due to ts.Uniform()!!! + // Avoid using it for now. + WsInit: nn.NewKaimingUniformInit(nn.WithKaimingNegativeSlope(negSlope)), + BsInit: nil, + Bias: true, + } +} + +// Linear is a linear fully-connected layer +type Linear struct { + Ws *ts.Tensor + weight_name string + Bs *ts.Tensor + bias_name string +} + +// NewLinear creates a new linear layer +// y = x*wT + b +// inDim - input dimension (x) [input features - columns] +// outDim - output dimension (y) [output features - columns] +// NOTE: w will have shape{outDim, inDim}; b will have shape{outDim} +func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear { + var bias_name string + var bs *ts.Tensor + var err error + if c.Bias { + switch { + case c.BsInit == nil: + shape := []int64{inDim, outDim} + fanIn, _, err := nn.CalculateFans(shape) + or_panic(err) + bound := 0.0 + if fanIn > 0 { + bound = 1 / math.Sqrt(float64(fanIn)) + } + bsInit := nn.NewUniformInit(-bound, bound) + bs, bias_name, err = vs.NewVarNamed("bias", []int64{outDim}, bsInit) + or_panic(err) + + // TODO: find a better way to do this + bs, err = bs.T(true) + or_panic(err) + bs, err = bs.T(true) + or_panic(err) + + bs, err = bs.SetRequiresGrad(true, true) + or_panic(err) + + err = bs.RetainGrad(false) + or_panic(err) + + vs.varstore.UpdateVarTensor(bias_name, bs, true) + + case c.BsInit != nil: + bs, bias_name, err = vs.NewVarNamed("bias", []int64{outDim}, c.BsInit) + or_panic(err) + } + } + + ws, weight_name, err := vs.NewVarNamed("weight", []int64{outDim, inDim}, c.WsInit) + or_panic(err) + + ws, err = ws.T(true) + or_panic(err) + + ws, err = ws.SetRequiresGrad(true, true) + or_panic(err) + + err = ws.RetainGrad(false) + or_panic(err) + + + vs.varstore.UpdateVarTensor(weight_name, ws, true) + + + return &Linear{ + Ws: ws, + weight_name: weight_name, + Bs: bs, + bias_name: bias_name, + } +} + +func (l *Linear) ExtractFromVarstore(vs *VarStore) { + l.Ws = vs.GetTensorOfVar(l.weight_name) + l.Bs = vs.GetTensorOfVar(l.bias_name) +} + +// Implement `Module` for `Linear` struct: +// ======================================= + +// Forward passes the input node through the linear layer. +// NOTE: +// - It assumes the input node is a 2-D matrix. +// For the matrix multiplication to work, the input node must have +// the same number of **columns** as the `Linear` `Ws` weights matrix, +// which is transposed before being multiplied with the input node. +// (Both use `inDim`.) +// - The input node should have shape `shape{batch size, input features}` +// (shape{batchSize, inDim}). The input feature size is `inDim` while the +// output feature size is `outDim` in the `LinearConfig` struct. +// +// Example: +// +// inDim := 3 +// outDim := 2 +// batchSize := 4 +// weights: 2x3 +// [ 1 1 1 +// 1 1 1 ] +// +// input node: 4x3 +// [ 1 1 1 +// 1 1 1 +// 1 1 1 +// 1 1 1 ] +func (l *Linear) Forward(xs *ts.Tensor) (retVal *ts.Tensor) { + mul, err := xs.Matmul(l.Ws, false) + or_panic(err) + if l.Bs != nil { + mul, err = mul.Add(l.Bs, false) + or_panic(err) + } + + out, err := mul.Relu(false) + or_panic(err) + + return out +} + +// ForwardT implements ModuleT interface for Linear layer. 
+// +// NOTE: train param will not be used. +func (l *Linear) ForwardT(xs *ts.Tensor, train bool) (retVal *ts.Tensor) { + mul, err := xs.Matmul(l.Ws, true) + or_panic(err) + + + mul, err = mul.Add(l.Bs, true) + or_panic(err) + + out, err := mul.Relu(true) + or_panic(err) + return out +} diff --git a/logic/models/train/torch/nn/optimizer.go b/logic/models/train/torch/nn/optimizer.go new file mode 100644 index 0000000..3fd55e5 --- /dev/null +++ b/logic/models/train/torch/nn/optimizer.go @@ -0,0 +1,603 @@ +package my_nn + +// Optimizers to be used for gradient-descent based training. + +import ( + "fmt" + "math" + + "github.com/charmbracelet/log" + "git.andr3h3nriqu3s.com/andr3/gotch/ts" +) + +// Optimizer is a struct object to run gradient descent. +type Optimizer struct { + varstore *VarStore + opt *ts.COptimizer + // variablesInOptimizer uint8 + variablesInOptimizer map[string]struct{} + config OptimizerConfig //interface{} + stepCount int + lr float64 +} + +func (o *Optimizer) Debug() { + for n, _ := range o.variablesInOptimizer { + v := o.varstore.GetVarOfName(n) + leaf, err := v.Tensor.IsLeaf(false) + or_panic(err) + + retains, err := v.Tensor.RetainsGrad(false) + or_panic(err) + + log.Info("[opt] var test", "n", n, "leaf", leaf, "retains", retains) + } +} + +func (o *Optimizer) RefreshValues() (err error) { + opt, err := o.config.buildCOpt(o.lr) + if err != nil { + return + } + + for name := range o.variablesInOptimizer { + v := o.varstore.GetVarOfName(name) + if v.Trainable { + if err = opt.AddParameter(v.Tensor, v.Group); err != nil { + err = fmt.Errorf("Optimizer defaultBuild - AddParameter failed: %w\n", err) + return + } + } + } + + o.opt = opt + return +} + +// OptimizerConfig defines Optimizer configurations. These configs can be used to build optimizer. +type OptimizerConfig interface { + buildCOpt(lr float64) (*ts.COptimizer, error) + + // Build builds an optimizer with the specified learning rate handling variables stored in `vs`. + // + // NOTE: Build is a 'default' method. It can be called by wrapping + // 'DefaultBuild' function + // E.g. AdamOptimizerConfig struct have a method to fullfil `Build` method of + // OptimizerConfig by wrapping `DefaultBuild` like + // (config AdamOptimizerConfig) Build(vs VarStore, lr float64) (retVal Optimizer, err error){ + // return defaultBuild(config, vs, lr) + // } + Build(vs *VarStore, lr float64) (*Optimizer, error) +} + +// defaultBuild is `default` Build method for OptimizerConfig interface +func defaultBuild(config OptimizerConfig, vs *VarStore, lr float64) (*Optimizer, error) { + opt, err := config.buildCOpt(lr) + if err != nil { + return nil, err + } + + names := make(map[string]struct{}) + for name, v := range vs.vars { + if v.Trainable { + log.Info("Adding parameter", "name", name, "g", v.Group) + if err = opt.AddParameter(v.Tensor, v.Group); err != nil { + err = fmt.Errorf("Optimizer defaultBuild - AddParameter failed: %w\n", err) + return nil, err + } + } + names[name] = struct{}{} + } + + return &Optimizer{ + varstore: vs, + opt: opt, + variablesInOptimizer: names, + config: config, + stepCount: 0, + lr: 0, + }, nil +} + +// SGD Optimizer: +//=============== + +// SGDConfig holds parameters for building the SGD (Stochastic Gradient Descent) optimizer. +type SGDConfig struct { + Momentum float64 + Dampening float64 + Wd float64 + Nesterov bool +} + +// DefaultSGDConfig creates SGDConfig with default values. 
+func DefaultSGDConfig() *SGDConfig { + return &SGDConfig{ + Momentum: 0.0, + Dampening: 0.0, + Wd: 0.0, + Nesterov: false, + } +} + +// NewSGDConfig creates the configuration for an SGD optimizer with specified values +func NewSGDConfig(momentum, dampening, wd float64, nesterov bool) *SGDConfig { + return &SGDConfig{ + Momentum: momentum, + Dampening: dampening, + Wd: wd, + Nesterov: nesterov, + } +} + +// Implement OptimizerConfig interface for SGDConfig +func (c *SGDConfig) buildCOpt(lr float64) (*ts.COptimizer, error) { + return ts.Sgd(lr, c.Momentum, c.Dampening, c.Wd, c.Nesterov) +} + +func (c *SGDConfig) Build(vs *VarStore, lr float64) (*Optimizer, error) { + return defaultBuild(c, vs, lr) +} + +// Adam optimizer: +// =============== + +type AdamConfig struct { + Beta1 float64 + Beta2 float64 + Wd float64 +} + +// DefaultAdamConfig creates AdamConfig with default values +func DefaultAdamConfig() *AdamConfig { + return &AdamConfig{ + Beta1: 0.9, + Beta2: 0.999, + Wd: 0.0, + } +} + +// NewAdamConfig creates AdamConfig with specified values +func NewAdamConfig(beta1, beta2, wd float64) *AdamConfig { + return &AdamConfig{ + Beta1: beta1, + Beta2: beta2, + Wd: wd, + } +} + +// Implement OptimizerConfig interface for AdamConfig +func (c *AdamConfig) buildCOpt(lr float64) (*ts.COptimizer, error) { + return ts.Adam(lr, c.Beta1, c.Beta2, c.Wd) +} + +func (c *AdamConfig) Build(vs *VarStore, lr float64) (*Optimizer, error) { + return defaultBuild(c, vs, lr) +} + +// AdamW optimizer: +// =============== + +type AdamWConfig struct { + Beta1 float64 + Beta2 float64 + Wd float64 +} + +// DefaultAdamWConfig creates AdamWConfig with default values +func DefaultAdamWConfig() *AdamWConfig { + return &AdamWConfig{ + Beta1: 0.9, + Beta2: 0.999, + Wd: 0.01, + } +} + +// NewAdamWConfig creates AdamWConfig with specified values +func NewAdamWConfig(beta1, beta2, wd float64) *AdamWConfig { + return &AdamWConfig{ + Beta1: beta1, + Beta2: beta2, + Wd: wd, + } +} + +// Implement OptimizerConfig interface for AdamWConfig +func (c *AdamWConfig) buildCOpt(lr float64) (*ts.COptimizer, error) { + return ts.AdamW(lr, c.Beta1, c.Beta2, c.Wd) +} + +// Build builds AdamW optimizer +func (c *AdamWConfig) Build(vs *VarStore, lr float64) (*Optimizer, error) { + return defaultBuild(c, vs, lr) +} + +// RMSProp optimizer: +// =============== + +type RMSPropConfig struct { + Alpha float64 + Eps float64 + Wd float64 + Momentum float64 + Centered bool +} + +// DefaultRMSPropConfig creates RMSPropConfig with default values +func DefaultRMSPropConfig() *RMSPropConfig { + return &RMSPropConfig{ + Alpha: 0.99, + Eps: 1e-8, + Wd: 0.0, + Momentum: 0.0, + Centered: false, + } +} + +// NewRMSPropConfig creates RMSPropConfig with specified values +func NewRMSPropConfig(alpha, eps, wd, momentum float64, centered bool) *RMSPropConfig { + return &RMSPropConfig{ + Alpha: alpha, + Eps: eps, + Wd: wd, + Momentum: momentum, + Centered: centered, + } +} + +// Implement OptimizerConfig interface for RMSPropConfig +func (c *RMSPropConfig) buildCOpt(lr float64) (*ts.COptimizer, error) { + return ts.RmsProp(lr, c.Alpha, c.Eps, c.Wd, c.Momentum, c.Centered) +} + +func (c *RMSPropConfig) Build(vs *VarStore, lr float64) (*Optimizer, error) { + return defaultBuild(c, vs, lr) +} + +// Optimizer methods: +// ================== + +func (opt *Optimizer) addMissingVariables() { + type param struct { + tensor *ts.Tensor + group uint + } + trainables := make(map[string]param) + for name, v := range opt.varstore.vars { + if v.Trainable { + trainables[name] = param{tensor: 
v.Tensor, group: v.Group} + } + } + missingVariables := len(trainables) - len(opt.variablesInOptimizer) + if missingVariables > 0 { + log.Info("INFO: Optimizer.addMissingVariables()...") + for name, x := range trainables { + if _, ok := opt.variablesInOptimizer[name]; !ok { + opt.opt.AddParameter(x.tensor, x.group) + opt.variablesInOptimizer[name] = struct{}{} + } + } + } +} + +// ZeroGrad zeroes the gradient for the tensors tracked by this optimizer. +func (opt *Optimizer) ZeroGrad() error { + if err := opt.opt.ZeroGrad(); err != nil { + err = fmt.Errorf("Optimizer.ZeroGrad() failed: %w\n", err) + return err + } + return nil +} + +// MustZeroGrad zeroes the gradient for the tensors tracked by this optimizer. +func (opt *Optimizer) MustZeroGrad() { + err := opt.ZeroGrad() + if err != nil { + log.Fatal(err) + } +} + +// Clips gradient value at some specified maximum value. +func (opt *Optimizer) ClipGradValue(max float64) { + opt.varstore.Lock() + defer opt.varstore.Unlock() + + for _, v := range opt.varstore.vars { + if v.Trainable { + // v.Tensor.MustGrad().Clamp_(ts.FloatScalar(-max), ts.FloatScalar(max)) + gradTs := v.Tensor.MustGrad(false) + gradTs.Clamp_(ts.FloatScalar(-max), ts.FloatScalar(max)) + } + } +} + +// Step performs an optimization step, updating the tracked tensors based on their gradients. +func (opt *Optimizer) Step() error { + err := opt.opt.Step() + if err != nil { + err = fmt.Errorf("Optimizer.Step() failed: %w\n", err) + return err + } + opt.stepCount += 1 + + return nil +} + +// MustStep performs an optimization step, updating the tracked tensors based on their gradients. +func (opt *Optimizer) MustStep() { + err := opt.Step() + if err != nil { + log.Fatal(err) + } +} + +// ResetStepCount set step count to zero. +func (opt *Optimizer) ResetStepCount() { + opt.stepCount = 0 +} + +// StepCount get current step count. +func (opt *Optimizer) StepCount() int { + return opt.stepCount +} + +// BackwardStep applies a backward step pass, update the gradients, and performs an optimization step. +func (opt *Optimizer) BackwardStep(loss *ts.Tensor) error { + err := opt.opt.ZeroGrad() + if err != nil { + err = fmt.Errorf("Optimizer.BackwardStep() failed: %w\n", err) + return err + } + + loss.MustBackward() + err = opt.opt.Step() + if err != nil { + err = fmt.Errorf("Optimizer.BackwardStep() failed: %w\n", err) + return err + } + + return nil +} + +// MustBackwardStep applies a backward step pass, update the gradients, and performs an optimization step. +func (opt *Optimizer) MustBackwardStep(loss *ts.Tensor) { + err := opt.BackwardStep(loss) + if err != nil { + log.Fatal(err) + } +} + +// BackwardStepClip applies a backward step pass, update the gradients, and performs an optimization step. +// +// The gradients are clipped based on `max` before being applied. +func (opt *Optimizer) BackwardStepClip(loss *ts.Tensor, max float64) error { + err := opt.opt.ZeroGrad() + if err != nil { + err = fmt.Errorf("Optimizer.BackwardStepClip() failed: %w\n", err) + return err + } + loss.MustBackward() + opt.ClipGradValue(max) + err = opt.opt.Step() + if err != nil { + err = fmt.Errorf("Optimizer.BackwardStepClip() failed: %w\n", err) + return err + } + return nil +} + +// MustBackwardStepClip applies a backward step pass, update the gradients, and performs an optimization step. +// +// The gradients are clipped based on `max` before being applied. 
+func (opt *Optimizer) MustBackwardStepClip(loss *ts.Tensor, max float64) { + err := opt.BackwardStepClip(loss, max) + if err != nil { + log.Fatal(err) + } +} + +type ClipOpts struct { + NormType float64 + ErrorIfNonFinite bool +} + +type ClipOpt func(*ClipOpts) + +func defaultClipOpts() *ClipOpts { + return &ClipOpts{ + NormType: 2.0, + ErrorIfNonFinite: false, // will switch to "true" in the future. + } +} + +func WithNormType(v float64) ClipOpt { + return func(o *ClipOpts) { + o.NormType = v + } +} + +func WithErrorIfNonFinite(v bool) ClipOpt { + return func(o *ClipOpts) { + o.ErrorIfNonFinite = v + } +} + +// / Clips gradient L2 norm over all trainable parameters. +// +// The norm is computed over all gradients together, as if they were +// concatenated into a single vector. +// +// / Args: +// - max: max norm of the gradient +// - o.NormType. Type of the used p-norm, can be "inf" for infinity norm. Default= 2.0 +// - o.ErrorIfNonFinite bool. If true, throw error if total norm of the gradients from paramters is "nan", "inf" or "-inf". Default=false +// Returns: total norm of the parameters (viewed as a single vector) +// ref. https://github.com/pytorch/pytorch/blob/cb4aeff7d8e4c70bb638cf159878c5204d0cc2da/torch/nn/utils/clip_grad.py#L59 +func (opt *Optimizer) ClipGradNorm(max float64, opts ...ClipOpt) error { + o := defaultClipOpts() + for _, option := range opts { + option(o) + } + + opt.varstore.Lock() + defer opt.varstore.Unlock() + parameters := opt.varstore.TrainableVariables() + if len(parameters) == 0 { + // return ts.MustOfSlice([]float64{0.0}), nil + return nil + } + + var ( + norms []*ts.Tensor + totalNorm *ts.Tensor + ) + + device := opt.varstore.device + + // FIXME. What about mixed-precision? + dtype := parameters[0].DType() + + if o.NormType == math.Inf(1) { + for _, v := range opt.varstore.vars { + n := v.Tensor.MustGrad(false).MustDetach(true).MustAbs(true).MustMax(true).MustTo(device, true) + norms = append(norms, n) + } + // total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms)) + totalNorm = ts.MustStack(norms, 0).MustMax(true) + } else { + for _, v := range opt.varstore.vars { + // x := v.Tensor.MustGrad(false).MustNorm(true) + + // NOTE. tensor.Norm() is going to be deprecated. So use linalg_norm + // Ref. https://pytorch.org/docs/stable/generated/torch.linalg.norm.html#torch.linalg.norm + x := v.Tensor.MustGrad(false).MustDetach(true).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, dtype, true) + norms = append(norms, x) + } + } + + // totalNorm = ts.MustStack(norms, 0).MustNorm(true).MustAddScalar(ts.FloatScalar(1e-6), true) + // total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) + totalNorm = ts.MustStack(norms, 0).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, dtype, true) + for _, x := range norms { + x.MustDrop() + } + + totalNormVal := totalNorm.Float64Values(true)[0] + // if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + if o.ErrorIfNonFinite && (math.IsNaN(totalNormVal) || math.IsInf(totalNormVal, 1)) { + err := fmt.Errorf("The total norm of order (%v) for gradients from 'parameters' is non-finite, so it cannot be clipped. 
To disable this error and scale the gradients by the non-finite norm anyway, set option.ErrorIfNonFinite= false", o.NormType) + return err + } + + // clip_coef = max_norm / (total_norm + 1e-6) + // clipCoefTs := ts.TensorFrom([]float64{max}).MustDiv(totalNorm, true) + clipCoef := max / (totalNormVal + 1e-6) + // NOTE: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + // avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + // when the gradients do not reside in CPU memory. + // clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + if clipCoef > 1.0 { + clipCoef = 1.0 + } + for _, v := range opt.varstore.vars { + if v.Trainable { + // p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device)) + // v.Tensor.MustGrad(false).MustDetach(true).MustMulScalar_(ts.FloatScalar(clipCoef)) + v.Tensor.MustGrad(false).MustMulScalar_(ts.FloatScalar(clipCoef)) + } + } + + return nil +} + +// BackwardStepClipNorm applies a backward step pass, update the gradients, and performs an optimization step. +// +// The gradients L2 norm is clipped based on `max`. +func (opt *Optimizer) BackwardStepClipNorm(loss *ts.Tensor, max float64, opts ...ClipOpt) error { + err := opt.opt.ZeroGrad() + if err != nil { + err := fmt.Errorf("Optimizer.BackwardStepClipNorm() failed: %w\n", err) + return err + } + err = loss.Backward() + if err != nil { + err := fmt.Errorf("Optimizer.BackwardStepClipNorm() failed: %w\n", err) + return err + } + + err = opt.ClipGradNorm(max, opts...) + if err != nil { + err := fmt.Errorf("Optimizer.BackwardStepClipNorm() failed: %w\n", err) + return err + } + + err = opt.Step() + if err != nil { + err := fmt.Errorf("Optimizer.BackwardStepClipNorm() failed: %w\n", err) + return err + } + + return nil +} + +// MustBackwardStepClipNorm applies a backward step pass, update the gradients, and performs an optimization step. +// +// The gradients L2 norm is clipped based on `max`. +func (opt *Optimizer) MustBackwardStepClipNorm(loss *ts.Tensor, max float64, opts ...ClipOpt) { + err := opt.BackwardStepClipNorm(loss, max, opts...) + if err != nil { + log.Fatal(err) + } +} + +// SetLR sets the optimizer learning rate. +// +// NOTE. it sets a SINGLE value of learning rate for all parameter groups. +// Most of the time, there's one parameter group. +func (opt *Optimizer) SetLR(lr float64) { + err := opt.opt.SetLearningRate(lr) + if err != nil { + log.Fatalf("Optimizer - SetLR method call error: %v\n", err) + } +} + +func (opt *Optimizer) GetLRs() []float64 { + lrs, err := opt.opt.GetLearningRates() + if err != nil { + log.Fatalf("Optimizer - GetLRs method call error: %v\n", err) + } + + return lrs +} + +// SetLRs sets learning rates for ALL parameter groups respectively. +func (opt *Optimizer) SetLRs(lrs []float64) { + err := opt.opt.SetLearningRates(lrs) + if err != nil { + log.Fatalf("Optimizer - SetLRs method call error: %v\n", err) + } +} + +// SetMomentum sets the optimizer momentum. 
+func (opt *Optimizer) SetMomentum(m float64) { + err := opt.opt.SetMomentum(m) + if err != nil { + log.Fatalf("Optimizer - SetMomentum method call error: %v\n", err) + } +} + +func (opt *Optimizer) ParamGroupNum() int { + ngroup, err := opt.opt.ParamGroupNum() + if err != nil { + log.Fatalf("Optimizer - ParamGroupNum method call error: %v\n", err) + } + + return int(ngroup) +} + +func (opt *Optimizer) AddParamGroup(tensors []*ts.Tensor) { + err := opt.opt.AddParamGroup(tensors) + if err != nil { + log.Fatalf("Optimizer - AddParamGroup method call error: %v\n", err) + } +} diff --git a/logic/models/train/torch/nn/utils.go b/logic/models/train/torch/nn/utils.go new file mode 100644 index 0000000..17070bc --- /dev/null +++ b/logic/models/train/torch/nn/utils.go @@ -0,0 +1,17 @@ +package my_nn + +import ( + torch "git.andr3h3nriqu3s.com/andr3/gotch/ts" +) + +func or_panic(err error) { + if err != nil { + panic(err) + } +} + +type MyLayer interface { + torch.ModuleT + + ExtractFromVarstore(vs *VarStore) +} diff --git a/logic/models/train/torch/nn/varstore.go b/logic/models/train/torch/nn/varstore.go new file mode 100644 index 0000000..4940fc9 --- /dev/null +++ b/logic/models/train/torch/nn/varstore.go @@ -0,0 +1,1359 @@ +package my_nn + +import ( + "fmt" + "reflect" + "sort" + "strings" + "sync" + + "github.com/charmbracelet/log" + "git.andr3h3nriqu3s.com/andr3/gotch" + "git.andr3h3nriqu3s.com/andr3/gotch/nn" + "git.andr3h3nriqu3s.com/andr3/gotch/ts" +) + +// SEP is a separator to separate path elements in the tensor names. +const SEP = "." + +type Var struct { + Tensor *ts.Tensor + Group uint // optimizer parameter group + Type string // can be "parameter" or "buffer" + Trainable bool // marks whether this variable is trainable. For "buffer" type, it is always `false` + Persitent bool // only applies to "buffer" type. All parameters are persistent (saved by VarStore.Save()). +} + +// VarStore is used to store variables used by one or multiple layers. +// It specifies a SINGLE device where all variables are stored. +type VarStore struct { + sync.Mutex + device gotch.Device + vars map[string]Var +} + +// Path is a variable store with an associated path for variable naming. +type Path struct { + path []string + varstore *VarStore + group uint // optimizer parameter group +} + +// Entry holds an entry corresponding to a given name in Path. +type Entry struct { + name string + path *Path +} + +// NewVarStore creates a new variable store located on the specified device +func NewVarStore(device gotch.Device) *VarStore { + return &VarStore{ + device: device, + vars: make(map[string]Var, 0), + } +} + +// NOTE: +// To get (initiate) a path, call vs.Root() + +// VarStore methods: +// ================= + +// Device returns device for this VarStore. +func (vs *VarStore) Device() gotch.Device { + return vs.device +} + +// Len returns the number of tensors currently kept in this VarStore. +func (vs *VarStore) Len() int { + vs.Lock() + defer vs.Unlock() + return len(vs.vars) +} + +// IsEmpty returns true if no tensors are currently kept in this VarStore. +func (vs *VarStore) IsEmpty() bool { + vs.Lock() + defer vs.Unlock() + return (len(vs.vars) == 0) +} + +// TrainableVariables returns references to all trainable variables kept in VarStore. 
+func (vs *VarStore) TrainableVariables() []*ts.Tensor { + vs.Lock() + defer vs.Unlock() + + var trainables []*ts.Tensor + for _, v := range vs.vars { + x := v.Tensor + if x.MustRequiresGrad() { + trainables = append(trainables, x) + } + } + + return trainables +} + +func (vs *VarStore) GetVarOfName(name string) Var { + return vs.vars[name] +} + +func (vs *VarStore) GetTensorOfVar(name string) *ts.Tensor { + return vs.vars[name].Tensor +} + +func (vs *VarStore) UpdateVarTensor(name string, tensor *ts.Tensor, del bool) { + if (del) { + or_panic(vs.vars[name].Tensor.Drop()) + } + v := vs.vars[name] + v.Tensor = tensor + vs.vars[name] = v +} + +// Variables returns reference of all variables and their names in a map[variable_name]Tensor +// +// NOTE. returned map includes all variables of "parameter" and "buffer" type. +func (vs *VarStore) Variables() map[string]ts.Tensor { + vs.Lock() + defer vs.Unlock() + + namedTensors := make(map[string]ts.Tensor, 0) + for k, v := range vs.vars { + namedTensors[k] = *v.Tensor + } + + return namedTensors +} + +// Root gets the root path for this VarStore. +// +// NOTE: Variables are named and organized using paths. This function returns +// the top level path for the var store and can be combined with '/' +// to create sub-paths. +func (vs *VarStore) Root() *Path { + return &Path{ + path: []string{}, + varstore: vs, + group: 0, + } +} + +// Save saves the VarStore variable values to a file. +// +// NOTE: Weight values for all the tensors currently stored in the +// var-store gets saved in the given file. +func (vs *VarStore) Save(filepath string) error { + vs.Lock() + defer vs.Unlock() + + var namedTensors []ts.NamedTensor + for k, v := range vs.vars { + if v.Type == "parameter" || (v.Type == "buffer" && v.Persitent) { + namedTensors = append(namedTensors, ts.NamedTensor{ + Name: k, + Tensor: v.Tensor, + }) + } + } + + // return ts.SaveMulti(namedTensors, filepath) + return ts.SaveMultiNew(namedTensors, filepath) +} + +// Load loads VarStore variable values from a file. +// +// NOTE: Weight values for all the tensors currently stored in the +// VarStore gets loaded from the given file. Note that the set of +// variables stored in the VarStore is not changed, only the values +// for these tensors are modified. +// It will throw error if name of the loaded tensors can not find +// in the current VarStore named tensors set. +func (vs *VarStore) Load(filepath string) error { + namedTensors, err := ts.LoadMultiWithDevice(filepath, vs.device) + if err != nil { + return err + } + + var namedTensorsMap map[string]*ts.Tensor = make(map[string]*ts.Tensor, 0) + for _, namedTensor := range namedTensors { + namedTensorsMap[namedTensor.Name] = namedTensor.Tensor + } + + // Match and in-place copy value (update) from newly loaded tensors + // to existing named tensors if name is matched. Throw error otherwise. 
+ vs.Lock() + defer vs.Unlock() + + for name, v := range vs.vars { + // missing variable + currTs, ok := namedTensorsMap[name] + if !ok { + err = fmt.Errorf("VarStore.Load() failed: there's a tensor with name %q in VarStore, but not found in the loaded weights.\n", name) + return err + } + + // mismatched shape + sourceShape := currTs.MustSize() + destShape := v.Tensor.MustSize() + if !reflect.DeepEqual(destShape, sourceShape) { + err = fmt.Errorf("Mismatched shape error for variable name: %v - At store: %v - At source %v\n", name, destShape, sourceShape) + return err + } + + ts.NoGrad(func() { + v.Tensor.Copy_(currTs) + }) + } + + for _, x := range namedTensors { + x.Tensor.MustDrop() + } + + ts.CleanUp() + + return nil +} + +// LoadWeights loads pretrained weights to VarStore. +func (vs *VarStore) LoadWeights(namedTensors []ts.NamedTensor) error { + var namedTensorsMap map[string]*ts.Tensor = make(map[string]*ts.Tensor, 0) + for _, namedTensor := range namedTensors { + namedTensorsMap[namedTensor.Name] = namedTensor.Tensor + } + + // Match and in-place copy value (update) from newly loaded tensors + // to existing named tensors if name is matched. Throw error otherwise. + vs.Lock() + defer vs.Unlock() + + for name, v := range vs.vars { + // missing variable + currTs, ok := namedTensorsMap[name] + if !ok { + err := fmt.Errorf("VarStore.LoadWeights() failed: there's a tensor with name %q in VarStore, but not found in the loaded weights.\n", name) + return err + } + + // mismatched shape + sourceShape := currTs.MustSize() + destShape := v.Tensor.MustSize() + if !reflect.DeepEqual(destShape, sourceShape) { + err := fmt.Errorf("VarStore.LoadWeights() failed. Mismatched shape error for variable name: %v - At store: %v - At source %v\n", name, destShape, sourceShape) + return err + } + + ts.NoGrad(func() { + v.Tensor.Copy_(currTs) + }) + } + + ts.CleanUp() + + return nil +} + +// LoadPartial loads the VarStore variable values from a file if it exists. +// +// Weight values for the tensors currently stored in the var-store and the given file get +// loaded from the given file. If a variable in the var store is not present in the given file, +// it is skipped and its values are not updated. This method should be used if pre-trained +// weight for only parts of the model are available. +// Note that the set of variables stored in the var-store is not changed, only the values +// for these tensors are modified. +// +// Returns a String Vector containing the names of missing variables. +func (vs *VarStore) LoadPartial(filepath string) ([]string, error) { + namedTensors, err := ts.LoadMultiWithDevice(filepath, vs.device) + if err != nil { + return nil, err + } + + var namedTensorsMap map[string]*ts.Tensor = make(map[string]*ts.Tensor, 0) + for _, namedTensor := range namedTensors { + namedTensorsMap[namedTensor.Name] = namedTensor.Tensor + } + + var missingVariables []string + + // Match and in-place copy value (update) from newly loaded tensors + // to existing named tensors if name is matched. Throw error otherwise. + vs.Lock() + defer vs.Unlock() + + for name, v := range vs.vars { + var currTs *ts.Tensor + var ok bool + + // missing variable + if currTs, ok = namedTensorsMap[name]; !ok { + missingVariables = append(missingVariables, name) + continue + } + + // mismatched shape + destShape := currTs.MustSize() + sourceShape := v.Tensor.MustSize() + if !reflect.DeepEqual(destShape, sourceShape) { + fmt.Printf("WARNING: Mismatched shape error for variable name: %v - At store: %v - At source %v. 
Skip loading this weight...\n", name, destShape, sourceShape) + missingVariables = append(missingVariables, name) + continue + } + + ts.NoGrad(func() { + v.Tensor.Copy_(currTs) + }) + } + + for _, x := range namedTensors { + x.Tensor.MustDrop() + } + + ts.CleanUp() + + return missingVariables, nil +} + +// LoadWeightsPartial loads VarStore variable values from the given named tensors. +// +// Weight values for the tensors currently stored in the var-store get loaded from +// the given named tensors. If a variable in the var store is not present in the given +// tensors, it is skipped and its values are not updated. This method should be used if +// pre-trained weights for only parts of the model are available. +// Note that the set of variables stored in the var-store is not changed, only the values +// for these tensors are modified. +// +// Returns a string slice containing the names of missing variables. +func (vs *VarStore) LoadWeightsPartial(namedTensors []ts.NamedTensor) ([]string, error) { + var namedTensorsMap map[string]*ts.Tensor = make(map[string]*ts.Tensor, 0) + for _, namedTensor := range namedTensors { + namedTensorsMap[namedTensor.Name] = namedTensor.Tensor + } + + var missingVariables []string + + // Match and in-place copy value (update) from newly loaded tensors + // to existing named tensors if name is matched. Throw error otherwise. + vs.Lock() + defer vs.Unlock() + + for name, v := range vs.vars { + var currTs *ts.Tensor + var ok bool + + // missing variable + if currTs, ok = namedTensorsMap[name]; !ok { + missingVariables = append(missingVariables, name) + continue + } + + // mismatched shape + destShape := currTs.MustSize() + sourceShape := v.Tensor.MustSize() + if !reflect.DeepEqual(destShape, sourceShape) { + fmt.Printf("WARNING: Mismatched shape error for variable name: %v - At store: %v - At source %v. Skip loading this weight...\n", name, destShape, sourceShape) + missingVariables = append(missingVariables, name) + continue + } + + ts.NoGrad(func() { + v.Tensor.Copy_(currTs) + }) + } + + ts.CleanUp() + + return missingVariables, nil +} + +// Freeze freezes this VarStore. +// +// Gradients for the variables in this store are not tracked anymore. +func (vs *VarStore) Freeze() error { + vs.Lock() + defer vs.Unlock() + + for name, v := range vs.vars { + err := v.Tensor.RequiresGrad_(false) + if err != nil { + err = fmt.Errorf("VarStore.Freeze() set 'requiresGrad' for tensor %q failed: %w", name, err) + return err + } + } + + return nil +} + +// Unfreeze unfreezes a VarStore. +// +// Gradients for the variables in this store are tracked again. +func (vs *VarStore) Unfreeze() error { + vs.Lock() + defer vs.Unlock() + + for name, v := range vs.vars { + if v.Type == "parameter" && v.Trainable { + err := v.Tensor.RequiresGrad_(true) + if err != nil { + err = fmt.Errorf("VarStore.Unfreeze() set 'requiresGrad' for tensor %q failed: %w", name, err) + return err + } + } + } + return nil +} + +// Copy copies variable values from a source VarStore to this VarStore. +// +// All the variables in this var store have to exist with the same +// name in the source var store, otherwise an error is returned. 
+func (vs *VarStore) Copy(src *VarStore) error { + vs.Lock() + defer vs.Unlock() + src.Lock() + defer src.Unlock() + + srcVars := src.vars + device := vs.device + + for k := range vs.vars { + if _, ok := srcVars[k]; !ok { + err := fmt.Errorf("VarStore.Copy() failed: cannot find %q in the source VarStore.\n", k) + return err + } + } + + for k, v := range vs.vars { + srcV := srcVars[k] + srcDevTs, err := srcV.Tensor.To(device, false) + if err != nil { + return err + } + ts.NoGrad(func() { + v.Tensor.Copy_(srcDevTs) + }) + srcDevTs.MustDrop() + } + + ts.CleanUp() + + return nil +} + +// Summary prints a simple list of all named variables with their shapes. +func (vs *VarStore) Summary() { + vars := vs.vars + layers := make([]string, 0, len(vars)) + for name := range vars { + layers = append(layers, name) + } + sort.Strings(layers) + var dtype gotch.DType + isFirst := true + for _, l := range layers { + var x *ts.Tensor + var isBuffer bool + for name, v := range vars { + if name == l { + x = v.Tensor + + // Get DType of first tensor for representation only + if isFirst { + dtype = x.DType() + } + isFirst = false + + isBuffer = v.Type == "buffer" + break + } + } + if isBuffer { + fmt.Printf("%s - [buffer] - %+v\n", l, x.MustSize()) + } else { + fmt.Printf("%s - %+v\n", l, x.MustSize()) + } + } + + fmt.Printf("Num of layers: %v\n", len(vars)) + fmt.Printf("DType: %v\n", dtype) +} + +// Destroy deletes all tensors in varstore and set it to nil. +func (vs *VarStore) Destroy() { + vs.Lock() + for n, v := range vs.vars { + v.Tensor.MustDrop() + + delete(vs.vars, n) + } + + vs.Unlock() + + vs = nil +} + +// ToDType casts all variables in VarStore to specified DType. +// +// NOTE. only float-like types (Half, BFloat16, Float, Double) can ensure convertible. +func (vs *VarStore) ToDType(dtype gotch.DType) { + vs.Root().ToDType(dtype) +} + +// ToFloat casts all float-like variables in VarStore to `Float` dtype. +// +// NOTE. float-like includes `Half`,`BFloat16`, `Float` and `Double` dtype. +func (vs *VarStore) ToFloat() { + vs.Root().ToFloat() +} + +// ToDouble casts all float-like variables in VarStore to `Double` dtype. +// +// NOTE. float-like includes `Half`, `Float` and `Double` dtype. +func (vs *VarStore) ToDouble() { + vs.Root().ToDouble() +} + +// ToHalf casts all float-like variables in VarStore to `Half` dtype. +// +// NOTE. float-like includes `Half`, `Float` and `Double` dtype. +func (vs *VarStore) ToHalf() { + vs.Root().ToHalf() +} + +// ToBFloat16 casts all float-like variables in VarStore to `BFloat16` dtype. +// +// NOTE. float-like includes `Half`, `Float` and `Double` dtype. +func (vs *VarStore) ToBFloat16() { + vs.Root().ToBFloat16() +} + +func (vs *VarStore) ToDevice(device gotch.Device) { + p := vs.Root() + p.ToDevice(device) +} + +// Path methods: +// ============= + +// Sub gets a sub-path of the given path. +func (p *Path) Sub(str string) *Path { + if strings.Contains(str, SEP) { + log.Fatalf("Path.Sub() failed: name cannot contain %v (%v)\n", SEP, str) + } + + path := p.path + path = append(path, str) + return &Path{ + path: path, + varstore: p.varstore, + group: p.group, + } +} + +// Paths returns all sub paths from current path. +func (p *Path) Paths() []string { + return p.path +} + +// Device gets the device where the VarStore variables are stored. 
+func (p *Path) Device() gotch.Device { + return p.varstore.device +} + +// NOTE: Cannot name as `path` as having a field name `path` +func (p *Path) getpath(name string) string { + if strings.Contains(name, SEP) { + log.Fatalf("Sub name cannot contain %v (%v)\n", SEP, name) + } + + if len(p.path) == 0 { + return name + } else { + return fmt.Sprintf("%v%v%v", strings.Join(p.path, SEP), SEP, name) + } +} + +func (p *Path) addNamed(name string, newTs *ts.Tensor, trainable bool, varType string, persistent bool) (*ts.Tensor, string, error) { + path := p.getpath(name) + + p.varstore.Lock() + defer p.varstore.Unlock() + + if _, ok := p.varstore.vars[path]; ok { + path = fmt.Sprintf("%v__%v", path, len(p.varstore.vars)) + } + + var ( + tensor *ts.Tensor + err error + ) + + if trainable { + tensor, err = newTs.SetRequiresGrad(true, false) + if err != nil { + return nil, path, err + } + } else { + tensor, err = newTs.ShallowClone() + if err != nil { + return nil, path, err + } + } + + v := Var{ + Tensor: tensor, + Group: p.group, + Trainable: trainable, + Type: varType, + Persitent: persistent, + } + p.varstore.vars[path] = v + + return tensor, path, nil +} + +func (p *Path) add(name string, newTs *ts.Tensor, trainable bool, varType string, persistent bool) (*ts.Tensor, error) { + ts, _, err := p.addNamed(name, newTs, trainable, varType, persistent) + return ts, err +} + +type AddOpts struct { + VarType string + Persistent bool +} + +type AddOpt func(*AddOpts) + +func defaultAddOpts() *AddOpts { + return &AddOpts{ + VarType: "parameter", + Persistent: true, + } +} + +func WithVarType(v string) AddOpt { + if v != "parameter" && v != "buffer" { + log.Fatalf("WithVarType() failed(): invalid option variable type. Input must be either 'parameter' or 'buffer'.") + } + + return func(o *AddOpts) { + o.VarType = v + } +} + +func WithPersistent(v bool) AddOpt { + return func(o *AddOpts) { + o.Persistent = v + } +} + +// Add adds a tensor to a given path. +// +// Args +// - name: intention name of variable in VarStore (if duplicated, it will be added a suffix number) +// - x: tensor holding values to keep in VarStore +// - trainable: marked whether tensor is trainable. +// - o.VarType: variable type, i.e., either "parameter" or "buffer" +// - o.Persistent: whether to save this variables when `VarStore.Save()` is called. Only applied to `buffer` type. +// Returns a reference to a tensor stored in VarStore and error if occurred. +func (p *Path) Add(name string, x *ts.Tensor, trainable bool, opts ...AddOpt) (*ts.Tensor, error) { + o := defaultAddOpts() + for _, opt := range opts { + opt(o) + } + + return p.add(name, x, trainable, o.VarType, o.Persistent) +} + +func (p *Path) AddNamed(name string, x *ts.Tensor, trainable bool, opts ...AddOpt) (*ts.Tensor, string, error) { + o := defaultAddOpts() + for _, opt := range opts { + opt(o) + } + + return p.addNamed(name, x, trainable, o.VarType, o.Persistent) +} + +// MustAdd adds a tensor to a given path. +// +// Args +// - name: intention name of variable in VarStore (if duplicated, it will be added a suffix number) +// - x: tensor holding values to keep in VarStore +// - trainable: marked whether tensor is trainable. +// - o.VarType: variable type, i.e., either "parameter" or "buffer" +// - o.Persistent: whether to save this variables when `VarStore.Save()` is called. Only applied to `buffer` type. +// Returns a reference to a tensor stored in VarStore. 
+func (p *Path) MustAdd(name string, x *ts.Tensor, trainable bool, opts ...AddOpt) *ts.Tensor { + x, err := p.Add(name, x, trainable, opts...) + if err != nil { + log.Fatal(err) + } + return x +} + +// Remove removes a variable from `VarStore` +func (p *Path) Remove(name string) error { + p.varstore.Lock() + defer p.varstore.Unlock() + + _, ok := p.varstore.vars[name] + if !ok { + err := fmt.Errorf("Path.Remove() failed: cannot find a variable with name %q in VarStore.", name) + return err + } + + delete(p.varstore.vars, name) + return nil +} + +// MustRemove removes a variable from `VarStore` +func (p *Path) MustRemove(name string) { + err := p.Remove(name) + if err != nil { + err = fmt.Errorf("Path.MustRemove() failed: %w", err) + log.Fatal(err) + } +} + +func (p *Path) getOrAddWithLock(name string, tensor *ts.Tensor, trainable bool, opts ...AddOpt) (*ts.Tensor, error) { + path := p.getpath(name) + + // if found, return it + if v, ok := p.varstore.vars[path]; ok { + return v.Tensor, nil + } + + // not found, add it + return p.Add(name, tensor, trainable, opts...) +} + +func (p *Path) SetGroup(g uint) { + p.varstore.Lock() + defer p.varstore.Unlock() + + // TODO. set group for individual variables. + // TBD. variables of current path only or all sub paths as well? + // For now, just set group for all variable at the path + path := strings.Join(p.path, SEP) + for name, v := range p.varstore.vars { + vpaths := strings.Split(name, SEP) + vpath := strings.Join(vpaths[:len(vpaths)-1], SEP) + if vpath == path { + v.Group = g + p.varstore.vars[name] = v + } + } + p.group = g +} + +// ToDType casts all variables in this path and its sub-paths to the specified dtype. +// +// NOTE. this method should be used for floating-point conversion, i.e., +// "gotch.Float", "gotch.Half", "gotch.BFloat16", "gotch.Double". +func (p *Path) ToDType(dtype gotch.DType) { + p.varstore.Lock() + defer p.varstore.Unlock() + path := strings.Join(p.path, SEP) + for name, v := range p.varstore.vars { + if strings.Contains(name, path) { + newVar := v + newVar.Tensor = v.Tensor.MustTotype(dtype, true) + p.varstore.vars[name] = newVar + } + } +} + +// toFloat casts all float-like variables in this current path and sub-paths to specified dtype. +func (p *Path) toFloat(dtype gotch.DType) { + p.varstore.Lock() + defer p.varstore.Unlock() + path := strings.Join(p.path, SEP) + for name, v := range p.varstore.vars { + if strings.Contains(name, path) { + dtype := v.Tensor.DType() + if gotch.IsFloatDType(dtype) { + newVar := v + newVar.Tensor = v.Tensor.MustTotype(dtype, true) + p.varstore.vars[name] = newVar + } + } + } + + ts.CleanUp() +} + +// ToFloat casts all variables in current path and subpaths to `Float` precision. +func (p *Path) ToFloat(floatDTypeOpt ...gotch.DType) { + dtype := gotch.Float + if len(floatDTypeOpt) > 0 { + dt := floatDTypeOpt[0] + if !gotch.IsFloatDType(dt) { + // Ingore the option + if gotch.Debug { + log.Printf("WARNING: nn.Path.ToFloat() input dtype is invalid float DType %v. Just ignoring...\n", dt) + } + } else { + dtype = dt + } + } + + p.toFloat(dtype) +} + +// ToDouble casts all variables in current path and subpaths to `Double` precision dtype. +func (p *Path) ToDouble() { + p.toFloat(gotch.Double) +} + +// ToHalf casts all variables in current path and subpaths to `Half` precision dtype. +func (p *Path) ToHalf() { + p.toFloat(gotch.Half) +} + +// ToBFloat16() converts all variables in current path and subpaths to `BFloat16` dtype. 
+func (p *Path) ToBFloat16() { + p.toFloat(gotch.BFloat16) +} + +func (p *Path) ToDevice(device gotch.Device) { + var err error + + p.varstore.Lock() + defer p.varstore.Unlock() + path := strings.Join(p.path, SEP) + + for name, v := range p.varstore.vars { + if strings.Contains(name, path) { + newVar := v + + var requires bool + requires, err = v.Tensor.RequiresGrad() + or_panic(err) + + var retain bool + retain, err = v.Tensor.RetainsGrad(false) + or_panic(err) + + newVar.Tensor, err = v.Tensor.To(device, true) + or_panic(err) + + newVar.Tensor, err = newVar.Tensor.SetRequiresGrad(requires, true) + or_panic(err) + + if retain { + err = newVar.Tensor.RetainGrad(false) + or_panic(err) + } + + p.varstore.vars[name] = newVar + + } + } + + ts.CleanUp() +} + +// ZerosNoTrain creates a new variable initialized with zeros. +// +// The new variable is named according to the name parameter and +// has the specified shape. The variable will not be trainable so +// gradients will not be tracked. +// The variable uses a float tensor initialized with zeros. +func (p *Path) ZerosNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) { + device := p.Device() + dtype := gotch.DefaultDType + z, err := ts.Zeros(dims, dtype, device) + if err != nil { + err = fmt.Errorf("Path.ZerosNoTrain() failed: %w", err) + return nil, err + } + + out, err := p.Add(name, z, false, opts...) + if err != nil { + return nil, err + } + z.MustDrop() + + return out, nil +} + +// MustZerosNoTrain creates a new variable initialized with zeros. +// +// The new variable is named according to the name parameter and +// has the specified shape. The variable will not be trainable so +// gradients will not be tracked. +// The variable uses a float tensor initialized with zeros. +func (p *Path) MustZerosNoTrain(name string, dims []int64, opts ...AddOpt) *ts.Tensor { + x, err := p.ZerosNoTrain(name, dims, opts...) + if err != nil { + log.Fatal(err) + } + return x +} + +// OnesNoTrain creates a new variable initialized with ones. +// +// The new variable is named according to the name parameter and +// has the specified shape. The variable will not be trainable so +// gradients will not be tracked. +// The variable uses a float tensor initialized with ones. +func (p *Path) OnesNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) { + device := p.Device() + dtype := gotch.DefaultDType + z, err := ts.Ones(dims, dtype, device) + if err != nil { + err = fmt.Errorf("Path.OneNoTrain() failed: %w", err) + return nil, err + } + + out, err := p.Add(name, z, false, opts...) + if err != nil { + return nil, err + } + z.MustDrop() + + return out, nil +} + +// MustOnesNoTrain creates a new variable initialized with ones. +// +// The new variable is named according to the name parameter and +// has the specified shape. The variable will not be trainable so +// gradients will not be tracked. +// The variable uses a float tensor initialized with ones. +func (p *Path) MustOnesNoTrain(name string, dims []int64, opts ...AddOpt) *ts.Tensor { + x, err := p.OnesNoTrain(name, dims, opts...) + if err != nil { + log.Fatal(err) + } + return x +} + +// NewVar creates a new variable. +// +// The new variable is named according to the name parameter and +// has the specified shape. The variable is trainable, its gradient +// will be tracked. +// The variable uses a float tensor initialized as per the +// related argument. 
+func (p *Path) NewVarNamed(name string, dims []int64, ini nn.Init, opts ...AddOpt) (*ts.Tensor, string, error) { + dtype := gotch.DefaultDType + // v := ini.InitTensor(dims, p.varstore.device, dtype) + var v *ts.Tensor + + v = ini.InitTensor(dims, p.varstore.device, dtype) + + out, path, err := p.AddNamed(name, v, true, opts...) + if err != nil { + return nil, "", err + } + + err = v.Drop() + or_panic(err) + + return out,path, err +} + +// NewVar creates a new variable. +// +// The new variable is named according to the name parameter and +// has the specified shape. The variable is trainable, its gradient +// will be tracked. +// The variable uses a float tensor initialized as per the +// related argument. +func (p *Path) NewVar(name string, dims []int64, ini nn.Init, opts ...AddOpt) (*ts.Tensor, error) { + dtype := gotch.DefaultDType + // v := ini.InitTensor(dims, p.varstore.device, dtype) + var v *ts.Tensor + + v = ini.InitTensor(dims, p.varstore.device, dtype) + + out, err := p.Add(name, v, true, opts...) + if err != nil { + return nil, err + } + + err = v.Drop() + or_panic(err) + + return out, err +} + +// MustNewVar create a new variable. It panics if error. +func (p *Path) MustNewVar(name string, dims []int64, ini nn.Init, opts ...AddOpt) *ts.Tensor { + x, err := p.NewVar(name, dims, ini, opts...) + if err != nil { + log.Fatal(err) + } + return x +} + +// Zeros creates a new variable initialized with zeros. +// +// The new variable is named according to the name parameter and +// has the specified shape. The variable is trainable, its gradient +// will be tracked. +// The variable uses a float tensor initialized with zeros. +func (p *Path) Zeros(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) { + return p.NewVar(name, dims, nn.NewConstInit(0.0), opts...) +} + +// MustZeros create a new variables with zero values. It panics if error. +func (p *Path) MustZeros(name string, dims []int64, opts ...AddOpt) *ts.Tensor { + x, err := p.Zeros(name, dims, opts...) + if err != nil { + log.Fatal(err) + } + return x +} + +// Ones creates a new variable initialized with ones. +// +// The new variable is named according to the name parameter and +// has the specified shape. The variable is trainable, its gradient +// will be tracked. +// The variable uses a float tensor initialized with ones. +func (p *Path) Ones(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) { + return p.NewVar(name, dims, nn.NewConstInit(1.0), opts...) +} + +// MustOnes creates a new variable initialized with ones. It panics if error occurred. +func (p *Path) MustOnes(name string, dims []int64, opts ...AddOpt) *ts.Tensor { + x, err := p.Ones(name, dims, opts...) + if err != nil { + log.Fatal(err) + } + return x +} + +// RandnStandard creates a new variable initialized randomly with normal distribution. +// +// The new variable is named according to the name parameter and +// has the specified shape. The variable is trainable, its gradient +// will be tracked. +// The variable uses a float tensor initialized randomly using a +// standard normal distribution. +func (p *Path) RandnStandard(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) { + + return p.NewVar(name, dims, nn.NewRandnInit(0.0, 1.0), opts...) +} + +// MustRandnStandard creates a new variable initialized randomly with normal distribution. It panics if error occurred. +func (p *Path) MustRandnStandard(name string, dims []int64, opts ...AddOpt) *ts.Tensor { + x, err := p.RandnStandard(name, dims, opts...) 
+ if err != nil { + log.Fatal(err) + } + + return x +} + +// Randn creates a new variable initialized randomly with normal distribution. +// +// The new variable is named according to the name parameter and +// has the specified shape. The variable is trainable, its gradient +// will be tracked. +// The variable uses a float tensor initialized randomly using a +// normal distribution with the specified mean and standard deviation. +func (p *Path) Randn(name string, dims []int64, mean float64, stdev float64, opts ...AddOpt) (*ts.Tensor, error) { + return p.NewVar(name, dims, nn.NewRandnInit(mean, stdev), opts...) +} + +// MustRandn creates a new variable initialized randomly with normal distribution. It panics if error occurred. +func (p *Path) MustRandn(name string, dims []int64, mean float64, stdev float64, opts ...AddOpt) *ts.Tensor { + x, err := p.Randn(name, dims, mean, stdev, opts...) + if err != nil { + log.Fatal(err) + } + return x +} + +// Uniform creates a new variable initialized randomly with uniform distribution. +// +// The new variable is named according to the name parameter and +// has the specified shape. The variable is trainable, its gradient +// will be tracked. +// The variable uses a float tensor initialized randomly using a +// uniform distribution between the specified bounds. +func (p *Path) Uniform(name string, dims []int64, lo, up float64, opts ...AddOpt) (*ts.Tensor, error) { + return p.NewVar(name, dims, nn.NewUniformInit(lo, up), opts...) +} + +// MustUniform creates a new variable initialized randomly with uniform distribution. It panics if error occurred. +func (p *Path) MustUniform(name string, dims []int64, lo, up float64, opts ...AddOpt) *ts.Tensor { + x, err := p.Uniform(name, dims, lo, up, opts...) + if err != nil { + log.Fatal(err) + } + return x +} + +// KaimingUniform creates a new variable initialized randomly with kaiming uniform. +// +// The new variable is named according to the name parameter and +// has the specified shape. The variable is trainable, its gradient +// will be tracked. +// The variable uses a float tensor initialized randomly using a +// uniform distribution which bounds follow Kaiming initialization. +func (p *Path) KaimingUniform(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) { + return p.NewVar(name, dims, nn.NewKaimingUniformInit(), opts...) +} + +// MustKaimingUniform creates a new variable initialized randomly with kaiming uniforms. It panics if error occurred. +func (p *Path) MustKaimingUniform(name string, dims []int64, opts ...AddOpt) *ts.Tensor { + x, err := p.KaimingUniform(name, dims, opts...) + if err != nil { + log.Fatal(err) + } + return x +} + +// VarCopy creates a new variable initialized by copying an existing tensor. +// +// The new variable is named according to the name parameter and +// has the specified shape. The variable is trainable, its gradient +// will be tracked. +// The variable uses a float tensor initialized by copying some +// given tensor. +func (p *Path) VarCopy(name string, t *ts.Tensor) (*ts.Tensor, error) { + size, err := t.Size() + if err != nil { + err = fmt.Errorf("Path.VarCopy() failed: %w\n", err) + return nil, err + } + v, err := p.Zeros(name, size) + if err != nil { + return nil, err + } + + ts.NoGrad(func() { + ts.Copy_(v, t) + }) + + return v, nil +} + +// VarCopy creates a new variable initialized by copying an existing tensor. 
+func (p *Path) MustVarCopy(name string, t *ts.Tensor) *ts.Tensor {
+	x, err := p.VarCopy(name, t)
+	if err != nil {
+		log.Fatal(err)
+	}
+	return x
+}
+
+// Get gets a reference to tensor corresponding to a given name if present.
+func (p *Path) Get(name string) (*ts.Tensor, error) {
+	p.varstore.Lock()
+	defer p.varstore.Unlock()
+
+	v, ok := p.varstore.vars[name]
+	if !ok {
+		err := fmt.Errorf("Path.Get() failed: Cannot find variable for name: %v\n", name)
+		return nil, err
+	}
+
+	return v.Tensor, nil
+}
+
+// MustGet gets a reference to a tensor corresponding to a given name if present. It panics if error occurred.
+func (p *Path) MustGet(name string) *ts.Tensor {
+	x, err := p.Get(name)
+	if err != nil {
+		log.Fatal(err)
+	}
+	return x
+}
+
+// Entry gets the entry corresponding to a given name for in-place manipulation.
+func (p *Path) Entry(name string) *Entry {
+	p.varstore.Lock()
+	defer p.varstore.Unlock()
+
+	return &Entry{
+		name: name,
+		path: p,
+	}
+}
+
+// Entry methods:
+// ==============
+
+// OrVar returns the existing entry if found, otherwise creates a new variable.
+//
+// If this entry name matches the name of a variable stored in the
+// var store, the corresponding tensor is returned. Otherwise a new
+// variable is added to the var-store with the entry name and is
+// initialized according to the init parameter.
+func (e *Entry) OrVar(dims []int64, init nn.Init, opts ...AddOpt) (*ts.Tensor, error) {
+	v := init.InitTensor(dims, e.path.varstore.device)
+	out, err := e.path.getOrAddWithLock(e.name, v, true, opts...)
+	if err != nil {
+		return nil, err
+	}
+	v.MustDrop()
+
+	return out, nil
+}
+
+// MustOrVar returns the existing entry if found, otherwise creates a new variable. It panics if error.
+func (e *Entry) MustOrVar(dims []int64, init nn.Init, opts ...AddOpt) *ts.Tensor {
+	x, err := e.OrVar(dims, init, opts...)
+	if err != nil {
+		log.Fatal(err)
+	}
+	return x
+}
+
+// OrZeros returns the existing entry if found, otherwise creates a new variable.
+func (e *Entry) OrZeros(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
+	return e.OrVar(dims, nn.NewConstInit(0.0))
+}
+
+// MustOrZeros returns the existing entry if found, otherwise creates a new variable.
+func (e *Entry) MustOrZeros(dims []int64, opts ...AddOpt) *ts.Tensor {
+	x, err := e.OrZeros(dims, opts...)
+	if err != nil {
+		log.Fatal(err)
+	}
+	return x
+}
+
+// OrVarCopy returns the existing entry if found, otherwise creates a new variable.
+func (e *Entry) OrVarCopy(tensor *ts.Tensor) (*ts.Tensor, error) {
+	size, err := tensor.Size()
+	if err != nil {
+		return nil, err
+	}
+	v, err := e.OrZeros(size)
+	if err != nil {
+		return nil, err
+	}
+
+	ts.NoGrad(func() {
+		ts.Copy_(v, tensor)
+	})
+
+	return v, nil
+}
+
+// MustOrVarCopy returns the existing entry if found, otherwise creates a new variable.
+func (e *Entry) MustOrVarCopy(tensor *ts.Tensor) *ts.Tensor {
+	x, err := e.OrVarCopy(tensor)
+	if err != nil {
+		log.Fatal(err)
+	}
+	return x
+}
+
+// OrKaimingUniform returns the existing entry if found, otherwise creates a new variable.
+func (e *Entry) OrKaimingUniform(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
+	return e.OrVar(dims, nn.NewKaimingUniformInit(), opts...)
+}
+
+// MustOrKaimingUniform returns the existing entry if found, otherwise creates a new variable.
+func (e *Entry) MustOrKaimingUniform(dims []int64, opts ...AddOpt) *ts.Tensor {
+	x, err := e.OrKaimingUniform(dims, opts...)
+	if err != nil {
+		log.Fatal(err)
+	}
+	return x
+}
+
+// OrOnes returns the existing entry if found, otherwise creates a new variable.
+func (e *Entry) OrOnes(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
+	return e.OrVar(dims, nn.NewConstInit(1.0), opts...)
+}
+
+// MustOrOnes returns the existing entry if found, otherwise creates a new variable.
+func (e *Entry) MustOrOnes(dims []int64, opts ...AddOpt) *ts.Tensor {
+	x, err := e.OrOnes(dims, opts...)
+	if err != nil {
+		log.Fatal(err)
+	}
+	return x
+}
+
+// OrOnesNoTrain returns the existing entry if found, otherwise creates a new variable.
+func (e *Entry) OrOnesNoTrain(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
+	dtype := gotch.DefaultDType
+	o := ts.MustOnes(dims, dtype, e.path.Device())
+	out, err := e.path.getOrAddWithLock(e.name, o, true, opts...)
+	if err != nil {
+		return nil, err
+	}
+	o.MustDrop()
+
+	return out, nil
+}
+
+// MustOrOnesNoTrain returns the existing entry if found, otherwise creates a new variable.
+func (e *Entry) MustOrOnesNoTrain(dims []int64, opts ...AddOpt) *ts.Tensor {
+	x, err := e.OrOnesNoTrain(dims, opts...)
+	if err != nil {
+		log.Fatal(err)
+	}
+	return x
+}
+
+// OrRandn returns the existing entry if found, otherwise creates a new variable.
+func (e *Entry) OrRandn(dims []int64, mean, stdev float64, opts ...AddOpt) (*ts.Tensor, error) {
+	return e.OrVar(dims, nn.NewRandnInit(mean, stdev), opts...)
+}
+
+// MustOrRandn returns the existing entry if found, otherwise creates a new variable.
+func (e *Entry) MustOrRandn(dims []int64, mean, stdev float64, opts ...AddOpt) *ts.Tensor {
+	x, err := e.OrRandn(dims, mean, stdev, opts...)
+	if err != nil {
+		log.Fatal(err)
+	}
+	return x
+}
+
+// OrRandnStandard returns the existing entry if found, otherwise creates a new variable.
+func (e *Entry) OrRandnStandard(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
+	return e.OrVar(dims, nn.NewRandnInit(0.0, 1.0), opts...)
+}
+
+// MustOrRandnStandard returns the existing entry if found, otherwise creates a new variable.
+func (e *Entry) MustOrRandnStandard(dims []int64, opts ...AddOpt) *ts.Tensor {
+	x, err := e.OrRandnStandard(dims, opts...)
+	if err != nil {
+		log.Fatal(err)
+	}
+	return x
+}
+
+// OrUniform returns the existing entry if found, otherwise creates a new variable.
+func (e *Entry) OrUniform(dims []int64, lo, up float64, opts ...AddOpt) (*ts.Tensor, error) {
+	return e.OrVar(dims, nn.NewUniformInit(lo, up), opts...)
+}
+
+// MustOrUniform returns the existing entry if found, otherwise creates a new variable.
+func (e *Entry) MustOrUniform(dims []int64, lo, up float64, opts ...AddOpt) *ts.Tensor {
+	x, err := e.OrUniform(dims, lo, up, opts...)
+	if err != nil {
+		log.Fatal(err)
+	}
+	return x
+}
+
+// OrZerosNoTrain returns the existing entry if found, otherwise creates a new variable.
+func (e *Entry) OrZerosNoTrain(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
+	dtype := gotch.DefaultDType
+	z := ts.MustZeros(dims, dtype, e.path.Device())
+	out, err := e.path.getOrAddWithLock(e.name, z, true, opts...)
+	if err != nil {
+		return nil, err
+	}
+	z.MustDrop()
+
+	return out, nil
+}
+
+// MustOrZerosNoTrain returns the existing entry if found, otherwise creates a new variable.
+func (e *Entry) MustOrZerosNoTrain(dims []int64, opts ...AddOpt) *ts.Tensor {
+	x, err := e.OrZerosNoTrain(dims, opts...)
+ if err != nil { + log.Fatal(err) + } + return x +} diff --git a/logic/models/train/torch/torch.go b/logic/models/train/torch/torch.go index 7a34723..a485daf 100644 --- a/logic/models/train/torch/torch.go +++ b/logic/models/train/torch/torch.go @@ -2,14 +2,12 @@ package train import ( types "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" + my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn" + "git.andr3h3nriqu3s.com/andr3/gotch" "github.com/charmbracelet/log" - "github.com/sugarme/gotch" - "github.com/sugarme/gotch/nn" - //"github.com/sugarme/gotch" - //"github.com/sugarme/gotch/vision" - torch "github.com/sugarme/gotch/ts" + torch "git.andr3h3nriqu3s.com/andr3/gotch/ts" ) type IForwardable interface { @@ -18,23 +16,55 @@ type IForwardable interface { // Container for a model type ContainerModel struct { - Seq *nn.SequentialT - Vs *nn.VarStore + Layers []my_nn.MyLayer + Vs *my_nn.VarStore + path *my_nn.Path } func (n *ContainerModel) ForwardT(x *torch.Tensor, train bool) *torch.Tensor { - return n.Seq.ForwardT(x, train) + if len(n.Layers) == 0 { + return x.MustShallowClone() + } + + if len(n.Layers) == 1 { + return n.Layers[0].ForwardT(x, train) + } + + // forward sequentially + outs := make([]*torch.Tensor, len(n.Layers)) + for i := 0; i < len(n.Layers); i++ { + if i == 0 { + outs[0] = n.Layers[i].ForwardT(x, train) + defer outs[0].MustDrop() + } else if i == len(n.Layers)-1 { + return n.Layers[i].ForwardT(outs[i-1], train) + } else { + outs[i] = n.Layers[i].ForwardT(outs[i-1], train) + defer outs[i].MustDrop() + } + } + panic("Do not reach here") } func (n *ContainerModel) To(device gotch.Device) { n.Vs.ToDevice(device) + for _, layer := range n.Layers { + layer.ExtractFromVarstore(n.Vs) + } +} + +func (n *ContainerModel) Refresh() { + for _, layer := range n.Layers { + layer.ExtractFromVarstore(n.Vs) + } } func BuildModel(layers []*types.Layer, _lastLinearSize int64, addSigmoid bool) *ContainerModel { - base_vs := nn.NewVarStore(gotch.CPU) + base_vs := my_nn.NewVarStore(gotch.CPU) vs := base_vs.Root() - seq := nn.SeqT() + + m_layers := []my_nn.MyLayer{} var lastLinearSize int64 = _lastLinearSize lastLinearConv := []int64{} @@ -46,18 +76,19 @@ func BuildModel(layers []*types.Layer, _lastLinearSize int64, addSigmoid bool) * } else if layer.LayerType == types.LAYER_DENSE { shape := layer.GetShape() log.Info("New Dense: ", "In:", lastLinearSize, "out:", shape[0]) - seq.Add(NewLinear(vs, lastLinearSize, shape[0])) + m_layers = append(m_layers, NewLinear(vs, lastLinearSize, shape[0])) lastLinearSize = shape[0] } else if layer.LayerType == types.LAYER_FLATTEN { - seq.Add(NewFlatten()) + m_layers = append(m_layers, NewFlatten()) lastLinearSize = 1 for _, i := range lastLinearConv { lastLinearSize *= i } log.Info("Flatten: ", "In:", lastLinearConv, "out:", lastLinearSize) } else if layer.LayerType == types.LAYER_SIMPLE_BLOCK { + panic("TODO") log.Info("New Block: ", "In:", lastLinearConv, "out:", []int64{lastLinearConv[1] / 2, lastLinearConv[2] / 2, 128}) - seq.Add(NewSimpleBlock(vs, lastLinearConv[0])) + //m_layers = append(m_layers, NewSimpleBlock(vs, lastLinearConv[0])) lastLinearConv[0] = 128 lastLinearConv[1] /= 2 lastLinearConv[2] /= 2 @@ -65,12 +96,13 @@ func BuildModel(layers []*types.Layer, _lastLinearSize int64, addSigmoid bool) * } if addSigmoid { - seq.Add(NewSigmoid()) + m_layers = append(m_layers, NewSigmoid()) } b := &ContainerModel{ - Seq: seq, - Vs: base_vs, + Layers: m_layers, + Vs: base_vs, + path: vs, } return b } diff --git 
a/logic/models/train/torch/utils.go b/logic/models/train/torch/utils.go index 6f7083c..6e58035 100644 --- a/logic/models/train/torch/utils.go +++ b/logic/models/train/torch/utils.go @@ -1,10 +1,14 @@ package train import ( + "unsafe" + + my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn" + "github.com/charmbracelet/log" - "github.com/sugarme/gotch/nn" - torch "github.com/sugarme/gotch/ts" + "git.andr3h3nriqu3s.com/andr3/gotch/nn" + torch "git.andr3h3nriqu3s.com/andr3/gotch/ts" ) func or_panic(err error) { @@ -19,7 +23,9 @@ type SimpleBlock struct { } // BasicBlock returns a BasicBlockModule instance -func NewSimpleBlock(vs *nn.Path, inplanes int64) *SimpleBlock { +func NewSimpleBlock(_vs *my_nn.Path, inplanes int64) *SimpleBlock { + vs := (*nn.Path)(unsafe.Pointer(_vs)) + conf1 := nn.DefaultConv2DConfig() conf1.Stride = []int64{2, 2} @@ -85,40 +91,11 @@ func (b *SimpleBlock) ForwardT(x *torch.Tensor, train bool) *torch.Tensor { return out } -type MyLinear struct { - FC1 *nn.Linear -} // BasicBlock returns a BasicBlockModule instance -func NewLinear(vs *nn.Path, in, out int64) *MyLinear { - config := nn.DefaultLinearConfig() - b := &MyLinear{ - FC1: nn.NewLinear(vs, in, out, config), - } - return b -} - -// Forward method -func (b *MyLinear) Forward(x *torch.Tensor) *torch.Tensor { - var err error - - out := b.FC1.Forward(x) - - out, err = out.Relu(false) - or_panic(err) - - return out -} - -func (b *MyLinear) ForwardT(x *torch.Tensor, train bool) *torch.Tensor { - var err error - - out := b.FC1.ForwardT(x, train) - - out, err = out.Relu(false) - or_panic(err) - - return out +func NewLinear(vs *my_nn.Path, in, out int64) *my_nn.Linear { + config := my_nn.DefaultLinearConfig() + return my_nn.NewLinear(vs, in, out, config) } type Flatten struct{} @@ -128,6 +105,9 @@ func NewFlatten() *Flatten { return &Flatten{} } +// The flatten layer does not to move anything to the device +func (b *Flatten) ExtractFromVarstore(vs *my_nn.VarStore) {} + // Forward method func (b *Flatten) Forward(x *torch.Tensor) *torch.Tensor { @@ -151,6 +131,9 @@ func NewSigmoid() *Sigmoid { return &Sigmoid{} } +// The sigmoid layer does not need to move anything to another device +func (b *Sigmoid) ExtractFromVarstore(vs *my_nn.VarStore) {} + func (b *Sigmoid) Forward(x *torch.Tensor) *torch.Tensor { out, err := x.Sigmoid(false) or_panic(err) diff --git a/logic/models/train/train_normal.go b/logic/models/train/train_normal.go index c73ee58..f05af29 100644 --- a/logic/models/train/train_normal.go +++ b/logic/models/train/train_normal.go @@ -16,16 +16,17 @@ import ( "git.andr3h3nriqu3s.com/andr3/fyp/logic/db" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" + my_torch "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch" modelloader "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/modelloader" + my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" . 
"git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" + "git.andr3h3nriqu3s.com/andr3/gotch" + torch "git.andr3h3nriqu3s.com/andr3/gotch/ts" "github.com/charmbracelet/log" "github.com/goccy/go-json" - "github.com/sugarme/gotch" - "github.com/sugarme/gotch/nn" - torch "github.com/sugarme/gotch/ts" ) const EPOCH_PER_RUN = 20 @@ -132,11 +133,12 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor } model = my_torch.BuildModel(layers, 0, true) + } // TODO Make the runner provide this - // device := gotch.CudaIfAvailable() - device := gotch.CPU + device := gotch.CudaIfAvailable() + // device := gotch.CPU result_path := path.Join(getDir(), "savedData", m.Id, "defs", def.Id) err = os.MkdirAll(result_path, os.ModePerm) @@ -144,6 +146,16 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor return } + /* opt1, err := my_nn.DefaultAdamConfig().Build(model.Vs, 0.001) + if err != nil { + return + } + + opt1.Debug() */ + + //log.Info("\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n") + + // TODO remove this model.To(device) defer model.To(gotch.CPU) @@ -153,23 +165,18 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor return } - err = ds.To(device) - if err != nil { - return - } - - opt, err := nn.DefaultAdamConfig().Build(model.Vs, 0.001) + opt, err := my_nn.DefaultAdamConfig().Build(model.Vs, 0.001) if err != nil { return } for epoch := 0; epoch < EPOCH_PER_RUN; epoch++ { - var trainIter *torch.Iter2 - trainIter, err = ds.TrainIter(64) - if err != nil { - return - } - // trainIter.ToDevice(device) + var trainIter *torch.Iter2 + trainIter, err = ds.TrainIter(32) + if err != nil { + return + } + // trainIter.ToDevice(device) log.Info("epoch", "epoch", epoch) @@ -184,19 +191,49 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor continue } - pred := model.ForwardT(item.Data, true) - - // Calculate loss - - loss, err = pred.BinaryCrossEntropyWithLogits(item.Label, &torch.Tensor{}, &torch.Tensor{}, 1, false) + data := item.Data + data, err = data.ToDevice(device, gotch.Float, false, true, false) if err != nil { return } + data, err = data.SetRequiresGrad(true, true) + if err != nil { + return + } + err = data.RetainGrad(false) + if err != nil { + return + } + + pred := model.ForwardT(data, true) + pred, err = pred.SetRequiresGrad(true, true) + if err != nil { + return + } + + pred.RetainGrad(false) + + label := item.Label + label, err = label.ToDevice(device, gotch.Float, false, true, false) + if err != nil { + return + } + label, err = label.SetRequiresGrad(true, true) + if err != nil { + return + } + label.RetainGrad(false) + + // Calculate loss + loss, err = pred.BinaryCrossEntropyWithLogits(label, &torch.Tensor{}, &torch.Tensor{}, 1, false) + if err != nil { + return + } loss, err = loss.SetRequiresGrad(true, false) - if err != nil { - return - } + if err != nil { + return + } err = opt.ZeroGrad() if err != nil { @@ -213,11 +250,32 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor return } + vars := model.Vs.Variables() + + for k, v := range vars { + var grad *torch.Tensor + grad, err = v.Grad(false) + if err != nil { + return + } + + grad, err = grad.Abs(false) + if err != nil { + return + } + + grad, err = grad.Max(false) + if err != nil { + return + } + + log.Info("[grad check]", "k", k, "grad", grad.Float64Values()) + } + trainLoss = loss.Float64Values()[0] // Calculate accuracy - - var p_pred, p_labels *torch.Tensor + /*var p_pred, p_labels *torch.Tensor p_pred, 
err = pred.Argmax([]int64{1}, true, false) if err != nil { return @@ -235,9 +293,13 @@ func trainDefinition(c BasePack, m *BaseModel, def *Definition, in_model *my_tor if floats[i] == floats_labels[i] { trainCorrect += 1 } - } + } */ + + // panic("fornow") } + //v := []float64{} + log.Info("model training epoch done loss", "loss", trainLoss, "correct", trainCorrect, "out", ds.TrainImagesSize, "accuracy", trainCorrect/float64(ds.TrainImagesSize)) /*correct := int64(0) diff --git a/run.sh b/run.sh old mode 100644 new mode 100755 index 393055b..a0ef521 --- a/run.sh +++ b/run.sh @@ -1,2 +1,3 @@ +#!/bin/fish podman run --rm --network host --gpus all -ti -v (pwd):/app -e "TERM=xterm-256color" fyp-server bash diff --git a/test.go b/test.go new file mode 100644 index 0000000..dcc9582 --- /dev/null +++ b/test.go @@ -0,0 +1,108 @@ +package main + +import ( + "git.andr3h3nriqu3s.com/andr3/gotch" + + dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" + "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch" + my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn" + + torch "git.andr3h3nriqu3s.com/andr3/gotch/ts" + "github.com/charmbracelet/log" +) + +func main_() { + + log.Info("Hello world") + + m := train.BuildModel([]*dbtypes.Layer{ + &dbtypes.Layer{ + LayerType: dbtypes.LAYER_INPUT, + Shape: "[ 2, 3, 3 ]", + }, + &dbtypes.Layer{ + LayerType: dbtypes.LAYER_FLATTEN, + }, + &dbtypes.Layer{ + LayerType: dbtypes.LAYER_DENSE, + Shape: "[ 10 ]", + }, + }, 0, true) + + var err error + + d := gotch.CudaIfAvailable() + + log.Info("device", "d", d) + + m.To(d) + + + opt, err := my_nn.DefaultAdamConfig().Build(m.Vs, 0.001) + if err != nil { + return + } + + ones := torch.MustOnes([]int64{1, 2, 3, 3}, gotch.Float, d) + ones = ones.MustSetRequiresGrad(true, true) + ones.RetainGrad(false) + + res := m.ForwardT(ones, true) + res = res.MustSetRequiresGrad(true, true) + res.RetainGrad(false) + + outs := torch.MustOnes([]int64{1, 10}, gotch.Float, d) + outs = outs.MustSetRequiresGrad(true, true) + outs.RetainsGrad(false) + + + loss, err := res.BinaryCrossEntropyWithLogits(outs, &torch.Tensor{}, &torch.Tensor{}, 1, false) + if err != nil { + return + } + loss = loss.MustSetRequiresGrad(true, false) + + opt.ZeroGrad() + + + log.Info("loss", "loss", loss.Float64Values()) + + loss.MustBackward() + + + opt.Step() + + // log.Info(mean.MustGrad(false).Float64Values()) + log.Info(res.MustGrad(false).Float64Values()) + log.Info(ones.MustGrad(false).Float64Values()) + log.Info(outs.MustGrad(false).Float64Values()) + + vars := m.Vs.Variables() + + for k, v := range vars { + + log.Info("[grad check]", "k", k) + + var grad *torch.Tensor + grad, err = v.Grad(false) + if err != nil { + log.Error(err) + return + } + + grad, err = grad.Abs(false) + if err != nil { + log.Error(err) + return + } + + grad, err = grad.Max(false) + if err != nil { + log.Error(err) + return + } + + log.Info("[grad check]", "k", k, "grad", grad.Float64Values()) + } + +}
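
Reviewer note: the patch above swaps gotch's nn.SequentialT for a plain slice of my_nn.MyLayer inside ContainerModel. The MyLayer interface itself is defined in the my_nn package and is not part of this diff, so the sketch below only infers its shape from the call sites (ForwardT in ContainerModel.ForwardT, ExtractFromVarstore in Flatten, Sigmoid and ContainerModel.To). Both the local MyLayer declaration and the ReluLayer type are hypothetical illustrations, not code from this repository; the import paths reuse the ones introduced by the patch.

package my_nn_example

import (
	my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn"
	torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"
)

// MyLayer is the method set that the layers in this patch appear to satisfy;
// the authoritative definition lives in the my_nn package.
type MyLayer interface {
	ForwardT(x *torch.Tensor, train bool) *torch.Tensor
	ExtractFromVarstore(vs *my_nn.VarStore)
}

// ReluLayer is a hypothetical stateless layer used only to show the contract.
type ReluLayer struct{}

// ForwardT applies ReLU to the input; a stateless layer just transforms x.
func (r *ReluLayer) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
	out, err := x.Relu(false)
	if err != nil {
		panic(err)
	}
	return out
}

// ExtractFromVarstore is a no-op because ReLU owns no weights to move between
// devices, mirroring the Flatten and Sigmoid implementations in utils.go.
func (r *ReluLayer) ExtractFromVarstore(vs *my_nn.VarStore) {}

// Compile-time check that ReluLayer satisfies the assumed interface.
var _ MyLayer = (*ReluLayer)(nil)

If the inferred method set matches the real my_nn.MyLayer, such a layer could be appended in BuildModel with m_layers = append(m_layers, &ReluLayer{}), and ContainerModel.To would simply invoke its no-op ExtractFromVarstore when moving the model between devices.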