// Package imageloader loads a model's image data points into gotch tensors
// for training and evaluation.
package imageloader

import (
	"errors"
	"fmt"

	"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
	types "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
	"github.com/sugarme/gotch"
	torch "github.com/sugarme/gotch/ts"
	"github.com/sugarme/gotch/vision"
)

// ErrNoDataPoints is returned (wrapped) by LoadImagesAndLables when the model
// has no data points for the requested mode.
var ErrNoDataPoints = errors.New("no data points")

// Dataset holds the training and testing splits of a model's data points as
// stacked tensors, together with the device they currently live on.
type Dataset struct {
	// TrainImages holds every training image stacked along dimension 0, as Float.
	TrainImages *torch.Tensor
	// TrainLabels holds one Float row per training image with a 1 at the
	// position of the image's class relative to classStart (all zeros when the
	// class is outside the requested range).
	TrainLabels *torch.Tensor
	// TestImages and TestLabels are the testing-split equivalents.
	TestImages *torch.Tensor
	TestLabels *torch.Tensor
	// TrainImagesSize and TestImagesSize are the number of samples per split.
	TrainImagesSize int
	TestImagesSize  int
	// Device is where the tensors currently reside (gotch.CPU after NewDataset).
	Device gotch.Device
}

// dropAll releases the native memory of every non-nil tensor in tensors.
// It is best-effort: a failure to free is ignored because callers use it on
// cleanup paths where nothing useful can be done about it.
func dropAll(tensors []*torch.Tensor) {
	for _, t := range tensors {
		if t != nil {
			_ = t.Drop()
		}
	}
}

// LoadImagesAndLables loads every data point of model m for the given mode
// (training or testing) and returns:
//
//   - imgs: the images loaded with vision.Load, stacked along a new leading
//     dimension and converted to Float (pixel values are not rescaled);
//   - labels: a Float tensor with one row of length classEnd-classStart+1 per
//     image, holding 1 at index point.Class-classStart when the class lies in
//     [classStart, classEnd] and all zeros otherwise;
//   - count: the number of images loaded.
//
// It returns an error if classEnd < classStart, if the data points cannot be
// listed, if the mode has no data points (wrapping ErrNoDataPoints), or if any
// image or label cannot be loaded, stacked or converted.
//
// The name keeps its historical spelling so existing callers still compile.
func LoadImagesAndLables(db db.Db, m *types.BaseModel, mode types.DATA_POINT_MODE, classStart int, classEnd int) (imgs, labels *torch.Tensor, count int, err error) {
	if classEnd < classStart {
		return nil, nil, 0, fmt.Errorf("invalid class range [%d, %d]", classStart, classEnd)
	}
	size := classEnd - classStart + 1

	// Query the split the caller asked for; NewDataset relies on this to get
	// distinct training and testing sets.
	points, err := m.DataPoints(db, mode)
	if err != nil {
		return nil, nil, 0, fmt.Errorf("listing data points: %w", err)
	}
	if len(points) == 0 {
		return nil, nil, 0, fmt.Errorf("%w for mode %v", ErrNoDataPoints, mode)
	}

	imgParts := make([]*torch.Tensor, 0, len(points))
	labelParts := make([]*torch.Tensor, 0, len(points))
	// The per-sample tensors are only inputs to Stack, which copies them into a
	// new tensor, so they can be released on every exit path.
	defer func() {
		dropAll(imgParts)
		dropAll(labelParts)
	}()

	for _, point := range points {
		var img, label *torch.Tensor
		if img, err = vision.Load(point.Path); err != nil {
			return nil, nil, 0, fmt.Errorf("loading image %q: %w", point.Path, err)
		}
		imgParts = append(imgParts, img)

		// One-hot encode the class relative to classStart; classes outside the
		// range produce an all-zero row.
		oneHot := make([]float32, size)
		if point.Class >= classStart && point.Class <= classEnd {
			oneHot[point.Class-classStart] = 1
		}
		if label, err = torch.OfSlice(oneHot); err != nil {
			return nil, nil, 0, fmt.Errorf("building label for %q: %w", point.Path, err)
		}
		labelParts = append(labelParts, label)
	}

	stackedImgs, err := torch.Stack(imgParts, 0)
	if err != nil {
		return nil, nil, 0, fmt.Errorf("stacking images: %w", err)
	}
	labels, err = torch.Stack(labelParts, 0)
	if err != nil {
		_ = stackedImgs.Drop()
		return nil, nil, 0, fmt.Errorf("stacking labels: %w", err)
	}
	// del=true lets gotch release the unconverted stack once it is consumed.
	imgs, err = stackedImgs.ToDtype(gotch.Float, false, false, true)
	if err != nil {
		_ = labels.Drop()
		return nil, nil, 0, fmt.Errorf("converting images to float: %w", err)
	}
	return imgs, labels, len(imgParts), nil
}

// NewDataset loads the training and testing splits of model m, labelling the
// classes classStart..classEnd (inclusive), and returns them as a Dataset
// whose Device is gotch.CPU.
func NewDataset(db db.Db, m *types.BaseModel, classStart int, classEnd int) (ds *Dataset, err error) {
	trainImages, trainLabels, trainCount, err := LoadImagesAndLables(db, m, types.DATA_POINT_MODE_TRAINING, classStart, classEnd)
	if err != nil {
		return nil, fmt.Errorf("loading training split: %w", err)
	}
	testImages, testLabels, testCount, err := LoadImagesAndLables(db, m, types.DATA_POINT_MODE_TESTING, classStart, classEnd)
	if err != nil {
		dropAll([]*torch.Tensor{trainImages, trainLabels})
		return nil, fmt.Errorf("loading testing split: %w", err)
	}
	return &Dataset{
		TrainImages:     trainImages,
		TrainLabels:     trainLabels,
		TestImages:      testImages,
		TestLabels:      testLabels,
		TrainImagesSize: trainCount,
		TestImagesSize:  testCount,
		Device:          gotch.CPU,
	}, nil
}

// To moves every tensor of the dataset to device, keeping each tensor's dtype,
// and records the new device. It passes del=true to ToDevice, so the previous
// tensors are released. The copy is requested as non-blocking when device is
// a CUDA device.
//
// If a move fails, the fields already moved stay on device, the failing and
// remaining fields are not updated, and Device is left unchanged. The dataset
// should not be used further in that case.
func (ds *Dataset) To(device gotch.Device) (err error) {
	fields := []**torch.Tensor{&ds.TrainImages, &ds.TrainLabels, &ds.TestImages, &ds.TestLabels}
	for _, field := range fields {
		moved, moveErr := (*field).ToDevice(device, (*field).DType(), device.IsCuda(), true, true)
		if moveErr != nil {
			// Do not store the (nil) result over the field.
			return fmt.Errorf("moving dataset to %v: %w", device, moveErr)
		}
		*field = moved
	}
	ds.Device = device
	return nil
}

// TestIter returns an iterator over the testing images and labels in batches
// of batchSize. It panics if the iterator cannot be created
// (torch.MustNewIter2).
//
// NOTE(review): unlike TrainIter, this wraps ds.TestImages/ds.TestLabels
// directly instead of copies. Confirm that operations on the returned iterator
// (e.g. shuffling) cannot release or replace the dataset's own tensors.
func (ds *Dataset) TestIter(batchSize int64) *torch.Iter2 {
	return torch.MustNewIter2(ds.TestImages, ds.TestLabels, batchSize)
}

// TrainIter returns an iterator over detached copies of the training images
// and labels in batches of batchSize. The copies presumably exist so that
// consuming or shuffling the iterator cannot affect ds.TrainImages and
// ds.TrainLabels. On error, any copies already made are released.
func (ds *Dataset) TrainIter(batchSize int64) (iter *torch.Iter2, err error) {
	images, err := ds.TrainImages.DetachCopy(false)
	if err != nil {
		return nil, fmt.Errorf("copying training images: %w", err)
	}
	labels, err := ds.TrainLabels.DetachCopy(false)
	if err != nil {
		_ = images.Drop()
		return nil, fmt.Errorf("copying training labels: %w", err)
	}
	iter, err = torch.NewIter2(images, labels, batchSize)
	if err != nil {
		dropAll([]*torch.Tensor{images, labels})
		return nil, fmt.Errorf("creating training iterator: %w", err)
	}
	return iter, nil
}