2024-04-19 15:39:51 +01:00
|
|
|
package train
|
|
|
|
|
|
|
|
import (
|
|
|
|
types "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
2024-04-22 00:09:07 +01:00
|
|
|
my_nn "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train/torch/nn"
|
2024-04-19 15:39:51 +01:00
|
|
|
|
2024-04-22 00:09:07 +01:00
|
|
|
"git.andr3h3nriqu3s.com/andr3/gotch"
|
2024-04-19 15:39:51 +01:00
|
|
|
"github.com/charmbracelet/log"
|
|
|
|
|
2024-04-22 00:09:07 +01:00
|
|
|
torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"
|
2024-04-19 15:39:51 +01:00
|
|
|
)
|
|
|
|
|
|
|
|
type IForwardable interface {
|
|
|
|
Forward(xs *torch.Tensor) *torch.Tensor
|
|
|
|
}
|
|
|
|
|
|
|
|
// Container for a model
|
|
|
|
type ContainerModel struct {
|
2024-04-22 00:09:07 +01:00
|
|
|
Layers []my_nn.MyLayer
|
|
|
|
Vs *my_nn.VarStore
|
|
|
|
path *my_nn.Path
|
2024-04-19 15:39:51 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
func (n *ContainerModel) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
|
2024-04-22 00:09:07 +01:00
|
|
|
if len(n.Layers) == 0 {
|
|
|
|
return x.MustShallowClone()
|
|
|
|
}
|
|
|
|
|
|
|
|
if len(n.Layers) == 1 {
|
2024-04-23 11:54:30 +01:00
|
|
|
log.Info("here")
|
2024-04-22 00:09:07 +01:00
|
|
|
return n.Layers[0].ForwardT(x, train)
|
|
|
|
}
|
|
|
|
|
|
|
|
// forward sequentially
|
|
|
|
outs := make([]*torch.Tensor, len(n.Layers))
|
|
|
|
for i := 0; i < len(n.Layers); i++ {
|
|
|
|
if i == 0 {
|
|
|
|
outs[0] = n.Layers[i].ForwardT(x, train)
|
2024-04-23 00:14:35 +01:00
|
|
|
//defer outs[0].MustDrop()
|
2024-04-22 00:09:07 +01:00
|
|
|
} else if i == len(n.Layers)-1 {
|
|
|
|
return n.Layers[i].ForwardT(outs[i-1], train)
|
|
|
|
} else {
|
|
|
|
outs[i] = n.Layers[i].ForwardT(outs[i-1], train)
|
2024-04-23 00:14:35 +01:00
|
|
|
//defer outs[i].MustDrop()
|
2024-04-22 00:09:07 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
panic("Do not reach here")
|
2024-04-19 15:39:51 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
func (n *ContainerModel) To(device gotch.Device) {
|
|
|
|
n.Vs.ToDevice(device)
|
2024-04-22 00:09:07 +01:00
|
|
|
for _, layer := range n.Layers {
|
|
|
|
layer.ExtractFromVarstore(n.Vs)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (n *ContainerModel) Refresh() {
|
|
|
|
for _, layer := range n.Layers {
|
|
|
|
layer.ExtractFromVarstore(n.Vs)
|
|
|
|
}
|
2024-04-19 15:39:51 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
func BuildModel(layers []*types.Layer, _lastLinearSize int64, addSigmoid bool) *ContainerModel {
|
|
|
|
|
2024-04-22 00:09:07 +01:00
|
|
|
base_vs := my_nn.NewVarStore(gotch.CPU)
|
2024-04-19 15:39:51 +01:00
|
|
|
vs := base_vs.Root()
|
2024-04-22 00:09:07 +01:00
|
|
|
|
|
|
|
m_layers := []my_nn.MyLayer{}
|
2024-04-19 15:39:51 +01:00
|
|
|
|
|
|
|
var lastLinearSize int64 = _lastLinearSize
|
|
|
|
lastLinearConv := []int64{}
|
|
|
|
|
|
|
|
for _, layer := range layers {
|
|
|
|
if layer.LayerType == types.LAYER_INPUT {
|
|
|
|
lastLinearConv = layer.GetShape()
|
|
|
|
log.Info("Input: ", "In:", lastLinearConv)
|
|
|
|
} else if layer.LayerType == types.LAYER_DENSE {
|
|
|
|
shape := layer.GetShape()
|
|
|
|
log.Info("New Dense: ", "In:", lastLinearSize, "out:", shape[0])
|
2024-04-22 00:09:07 +01:00
|
|
|
m_layers = append(m_layers, NewLinear(vs, lastLinearSize, shape[0]))
|
2024-04-19 15:39:51 +01:00
|
|
|
lastLinearSize = shape[0]
|
|
|
|
} else if layer.LayerType == types.LAYER_FLATTEN {
|
2024-04-22 00:09:07 +01:00
|
|
|
m_layers = append(m_layers, NewFlatten())
|
2024-04-19 15:39:51 +01:00
|
|
|
lastLinearSize = 1
|
|
|
|
for _, i := range lastLinearConv {
|
|
|
|
lastLinearSize *= i
|
|
|
|
}
|
|
|
|
log.Info("Flatten: ", "In:", lastLinearConv, "out:", lastLinearSize)
|
|
|
|
} else if layer.LayerType == types.LAYER_SIMPLE_BLOCK {
|
2024-04-22 00:09:07 +01:00
|
|
|
panic("TODO")
|
2024-04-19 15:39:51 +01:00
|
|
|
log.Info("New Block: ", "In:", lastLinearConv, "out:", []int64{lastLinearConv[1] / 2, lastLinearConv[2] / 2, 128})
|
2024-04-22 00:09:07 +01:00
|
|
|
//m_layers = append(m_layers, NewSimpleBlock(vs, lastLinearConv[0]))
|
2024-04-19 15:39:51 +01:00
|
|
|
lastLinearConv[0] = 128
|
|
|
|
lastLinearConv[1] /= 2
|
|
|
|
lastLinearConv[2] /= 2
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if addSigmoid {
|
2024-04-22 00:09:07 +01:00
|
|
|
m_layers = append(m_layers, NewSigmoid())
|
2024-04-19 15:39:51 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
b := &ContainerModel{
|
2024-04-22 00:09:07 +01:00
|
|
|
Layers: m_layers,
|
|
|
|
Vs: base_vs,
|
|
|
|
path: vs,
|
2024-04-19 15:39:51 +01:00
|
|
|
}
|
|
|
|
return b
|
|
|
|
}
|
|
|
|
|
2024-04-23 00:14:35 +01:00
|
|
|
func (model *ContainerModel) Debug() {
|
|
|
|
for _, v := range model.Layers {
|
|
|
|
v.Debug()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-04-19 15:39:51 +01:00
|
|
|
func SaveModel(model *ContainerModel, modelFn string) (err error) {
|
|
|
|
model.Vs.ToDevice(gotch.CPU)
|
|
|
|
return model.Vs.Save(modelFn)
|
|
|
|
}
|