package train

import (
	types "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
	my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn"

	"git.andr3h3nriqu3s.com/andr3/gotch"
	torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"

	"github.com/charmbracelet/log"
)

type IForwardable interface {
	Forward(xs *torch.Tensor) *torch.Tensor
}

// ContainerModel is a container for a sequential model built from MyLayer layers.
type ContainerModel struct {
	Layers []my_nn.MyLayer
	Vs     *my_nn.VarStore
	path   *my_nn.Path
}

// ForwardT runs the input through every layer in order and returns the output
// of the last layer.
func (n *ContainerModel) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
	if len(n.Layers) == 0 {
		return x.MustShallowClone()
	}

	if len(n.Layers) == 1 {
		return n.Layers[0].ForwardT(x, train)
	}

	// forward sequentially
	outs := make([]*torch.Tensor, len(n.Layers))
	for i := 0; i < len(n.Layers); i++ {
		if i == 0 {
			outs[0] = n.Layers[i].ForwardT(x, train)
			//defer outs[0].MustDrop()
		} else if i == len(n.Layers)-1 {
			return n.Layers[i].ForwardT(outs[i-1], train)
		} else {
			outs[i] = n.Layers[i].ForwardT(outs[i-1], train)
			//defer outs[i].MustDrop()
		}
	}
	panic("unreachable: the last layer always returns")
}

// To moves the underlying VarStore to the given device and re-links every layer
// to the moved variables.
func (n *ContainerModel) To(device gotch.Device) {
	n.Vs.ToDevice(device)
	for _, layer := range n.Layers {
		layer.ExtractFromVarstore(n.Vs)
	}
}

// Refresh re-links every layer to the variables currently held by the VarStore.
func (n *ContainerModel) Refresh() {
	for _, layer := range n.Layers {
		layer.ExtractFromVarstore(n.Vs)
	}
}

// BuildModel builds a ContainerModel from the stored layer definitions, tracking
// the running output shape so each dense layer receives the correct input size.
func BuildModel(layers []*types.Layer, _lastLinearSize int64, addSigmoid bool) *ContainerModel {

	base_vs := my_nn.NewVarStore(gotch.CPU)
	vs := base_vs.Root()

	m_layers := []my_nn.MyLayer{}

	var lastLinearSize int64 = _lastLinearSize
	lastLinearConv := []int64{}

	for _, layer := range layers {
		if layer.LayerType == types.LAYER_INPUT {
			lastLinearConv = layer.GetShape()
			log.Info("Input: ", "In:", lastLinearConv)
		} else if layer.LayerType == types.LAYER_DENSE {
			shape := layer.GetShape()
			log.Info("New Dense: ", "In:", lastLinearSize, "out:", shape[0])
			m_layers = append(m_layers, NewLinear(vs, lastLinearSize, shape[0]))
			lastLinearSize = shape[0]
		} else if layer.LayerType == types.LAYER_FLATTEN {
			m_layers = append(m_layers, NewFlatten())
			lastLinearSize = 1
			for _, i := range lastLinearConv {
				lastLinearSize *= i
			}
			log.Info("Flatten: ", "In:", lastLinearConv, "out:", lastLinearSize)
		} else if layer.LayerType == types.LAYER_SIMPLE_BLOCK {
			log.Info("New Block: ", "In:", lastLinearConv, "out:", []int64{lastLinearConv[1] / 2, lastLinearConv[2] / 2, 128})
			// Simple blocks are not implemented yet.
			panic("TODO")
			//m_layers = append(m_layers, NewSimpleBlock(vs, lastLinearConv[0]))
			//lastLinearConv[0] = 128
			//lastLinearConv[1] /= 2
			//lastLinearConv[2] /= 2
		}
	}

	if addSigmoid {
		m_layers = append(m_layers, NewSigmoid())
	}

	b := &ContainerModel{
		Layers: m_layers,
		Vs:     base_vs,
		path:   vs,
	}
	return b
}

// Debug prints debug information for every layer.
func (model *ContainerModel) Debug() {
	for _, v := range model.Layers {
		v.Debug()
	}
}

// SaveModel moves the weights back to the CPU and writes them to modelFn.
func SaveModel(model *ContainerModel, modelFn string) (err error) {
	model.Vs.ToDevice(gotch.CPU)
	return model.Vs.Save(modelFn)
}
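
// The sketch below is not part of the original file; it is a minimal,
// hypothetical usage example of the helpers above. It assumes the layer
// definitions and the input tensor are loaded elsewhere, and that the layer
// list starts with an input layer followed by a flatten layer so BuildModel
// can recompute the running linear size itself. The name exampleBuildAndSave
// and the choice of arguments are illustrative only.
func exampleBuildAndSave(layers []*types.Layer, input *torch.Tensor, modelFn string) error {
	// Build the model on the CPU; addSigmoid=true appends a final sigmoid layer.
	model := BuildModel(layers, 0, true)

	// Run a single forward pass in inference mode (train=false) and free the output.
	out := model.ForwardT(input, false)
	defer out.MustDrop()

	// Print per-layer debug information, then persist the weights to disk.
	model.Debug()
	return SaveModel(model, modelFn)
}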