Started working on moving to torch
This commit is contained in:
81
logic/models/train/torch/torch.go
Normal file
81
logic/models/train/torch/torch.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package train
|
||||
|
||||
import (
|
||||
types "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
|
||||
//"github.com/sugarme/gotch"
|
||||
//"github.com/sugarme/gotch/vision"
|
||||
torch "github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
type IForwardable interface {
|
||||
Forward(xs *torch.Tensor) *torch.Tensor
|
||||
}
|
||||
|
||||
// Container for a model
|
||||
type ContainerModel struct {
|
||||
Seq *nn.SequentialT
|
||||
Vs *nn.VarStore
|
||||
}
|
||||
|
||||
func (n *ContainerModel) ForwardT(x *torch.Tensor, train bool) *torch.Tensor {
|
||||
return n.Seq.ForwardT(x, train)
|
||||
}
|
||||
|
||||
func (n *ContainerModel) To(device gotch.Device) {
|
||||
n.Vs.ToDevice(device)
|
||||
}
|
||||
|
||||
func BuildModel(layers []*types.Layer, _lastLinearSize int64, addSigmoid bool) *ContainerModel {
|
||||
|
||||
base_vs := nn.NewVarStore(gotch.CPU)
|
||||
vs := base_vs.Root()
|
||||
seq := nn.SeqT()
|
||||
|
||||
var lastLinearSize int64 = _lastLinearSize
|
||||
lastLinearConv := []int64{}
|
||||
|
||||
for _, layer := range layers {
|
||||
if layer.LayerType == types.LAYER_INPUT {
|
||||
lastLinearConv = layer.GetShape()
|
||||
log.Info("Input: ", "In:", lastLinearConv)
|
||||
} else if layer.LayerType == types.LAYER_DENSE {
|
||||
shape := layer.GetShape()
|
||||
log.Info("New Dense: ", "In:", lastLinearSize, "out:", shape[0])
|
||||
seq.Add(NewLinear(vs, lastLinearSize, shape[0]))
|
||||
lastLinearSize = shape[0]
|
||||
} else if layer.LayerType == types.LAYER_FLATTEN {
|
||||
seq.Add(NewFlatten())
|
||||
lastLinearSize = 1
|
||||
for _, i := range lastLinearConv {
|
||||
lastLinearSize *= i
|
||||
}
|
||||
log.Info("Flatten: ", "In:", lastLinearConv, "out:", lastLinearSize)
|
||||
} else if layer.LayerType == types.LAYER_SIMPLE_BLOCK {
|
||||
log.Info("New Block: ", "In:", lastLinearConv, "out:", []int64{lastLinearConv[1] / 2, lastLinearConv[2] / 2, 128})
|
||||
seq.Add(NewSimpleBlock(vs, lastLinearConv[0]))
|
||||
lastLinearConv[0] = 128
|
||||
lastLinearConv[1] /= 2
|
||||
lastLinearConv[2] /= 2
|
||||
}
|
||||
}
|
||||
|
||||
if addSigmoid {
|
||||
seq.Add(NewSigmoid())
|
||||
}
|
||||
|
||||
b := &ContainerModel{
|
||||
Seq: seq,
|
||||
Vs: base_vs,
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func SaveModel(model *ContainerModel, modelFn string) (err error) {
|
||||
model.Vs.ToDevice(gotch.CPU)
|
||||
return model.Vs.Save(modelFn)
|
||||
}
|
||||
Reference in New Issue
Block a user