Started working on moving to torch
This commit is contained in:
149
logic/models/train/torch/modelloader/modelloader.go
Normal file
149
logic/models/train/torch/modelloader/modelloader.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package imageloader
|
||||
|
||||
import (
|
||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
types "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
"github.com/sugarme/gotch"
|
||||
torch "github.com/sugarme/gotch/ts"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
type Dataset struct {
|
||||
TrainImages *torch.Tensor
|
||||
TrainLabels *torch.Tensor
|
||||
TestImages *torch.Tensor
|
||||
TestLabels *torch.Tensor
|
||||
TrainImagesSize int
|
||||
TestImagesSize int
|
||||
Device gotch.Device
|
||||
}
|
||||
|
||||
func LoadImagesAndLables(db db.Db, m *types.BaseModel, mode types.DATA_POINT_MODE, classStart int, classEnd int) (imgs, labels *torch.Tensor, count int, err error) {
|
||||
train_points, err := m.DataPoints(db, types.DATA_POINT_MODE_TRAINING)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
size := int64(classEnd - classStart + 1)
|
||||
|
||||
pimgs := []*torch.Tensor{}
|
||||
plabels := []*torch.Tensor{}
|
||||
|
||||
for _, point := range train_points {
|
||||
var img, label *torch.Tensor
|
||||
img, err = vision.Load(point.Path)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
pimgs = append(pimgs, img)
|
||||
|
||||
t_label := make([]int, size)
|
||||
if point.Class <= classEnd && point.Class >= classStart {
|
||||
t_label[point.Class-classStart] = 1
|
||||
}
|
||||
|
||||
label, err = torch.OfSlice(t_label)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
plabels = append(plabels, label)
|
||||
}
|
||||
|
||||
imgs, err = torch.Concat(pimgs, 0)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
labels, err = torch.Stack(plabels, 0)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
count = len(pimgs)
|
||||
|
||||
imgs, err = torch.Stack(pimgs, 0)
|
||||
|
||||
labels, err = labels.ToDtype(gotch.Float, false, false, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
imgs, err = imgs.ToDtype(gotch.Float, false, false, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func NewDataset(db db.Db, m *types.BaseModel, classStart int, classEnd int) (ds *Dataset, err error) {
|
||||
trainImages, trainLabels, train_count, err := LoadImagesAndLables(db, m, types.DATA_POINT_MODE_TRAINING, classStart, classEnd)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
testImages, testLabels, test_count, err := LoadImagesAndLables(db, m, types.DATA_POINT_MODE_TESTING, classStart, classEnd)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ds = &Dataset{
|
||||
TrainImages: trainImages,
|
||||
TrainLabels: trainLabels,
|
||||
TestImages: testImages,
|
||||
TestLabels: testLabels,
|
||||
TrainImagesSize: train_count,
|
||||
TestImagesSize: test_count,
|
||||
Device: gotch.CPU,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (ds *Dataset) To(device gotch.Device) (err error) {
|
||||
ds.TrainImages, err = ds.TrainImages.ToDevice(device, ds.TrainImages.DType(), device.IsCuda(), true, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ds.TrainLabels, err = ds.TrainLabels.ToDevice(device, ds.TrainLabels.DType(), device.IsCuda(), true, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ds.TestImages, err = ds.TestImages.ToDevice(device, ds.TestImages.DType(), device.IsCuda(), true, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ds.TestLabels, err = ds.TestLabels.ToDevice(device, ds.TestLabels.DType(), device.IsCuda(), true, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ds.Device = device
|
||||
return
|
||||
}
|
||||
|
||||
func (ds *Dataset) TestIter(batchSize int64) *torch.Iter2 {
|
||||
return torch.MustNewIter2(ds.TestImages, ds.TestLabels, batchSize)
|
||||
}
|
||||
|
||||
func (ds *Dataset) TrainIter(batchSize int64) (iter *torch.Iter2, err error) {
|
||||
|
||||
train_images, err := ds.TrainImages.DetachCopy(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
train_labels, err := ds.TrainLabels.DetachCopy(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
iter, err = torch.NewIter2(train_images, train_labels, batchSize)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
Reference in New Issue
Block a user