180 lines
3.9 KiB
Go
180 lines
3.9 KiB
Go
package imageloader
|
|
|
|
import (
	"fmt"

	"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
	types "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
	"git.andr3h3nriqu3s.com/andr3/gotch"
	torch "git.andr3h3nriqu3s.com/andr3/gotch/ts"
	"git.andr3h3nriqu3s.com/andr3/gotch/vision"
)
|
|
|
|
type Dataset struct {
|
|
TrainImages *torch.Tensor
|
|
TrainLabels *torch.Tensor
|
|
TestImages *torch.Tensor
|
|
TestLabels *torch.Tensor
|
|
TrainImagesSize int
|
|
TestImagesSize int
|
|
Device gotch.Device
|
|
}
|
|
|
|
func LoadImagesAndLables(db db.Db, m *types.BaseModel, mode types.DATA_POINT_MODE, classStart int, classEnd int) (imgs, labels *torch.Tensor, count int, err error) {
|
|
train_points, err := m.DataPoints(db, types.DATA_POINT_MODE_TRAINING)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
size := int64(classEnd - classStart + 1)
|
|
|
|
pimgs := []*torch.Tensor{}
|
|
plabels := []*torch.Tensor{}
|
|
|
|
for _, point := range train_points {
|
|
var img, label *torch.Tensor
|
|
img, err = vision.Load(point.Path)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
pimgs = append(pimgs, img)
|
|
|
|
t_label := make([]int, size)
|
|
if point.Class <= classEnd && point.Class >= classStart {
|
|
t_label[point.Class-classStart] = 1
|
|
}
|
|
|
|
label, err = torch.OfSlice(t_label)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
plabels = append(plabels, label)
|
|
}
|
|
|
|
imgs, err = torch.Concat(pimgs, 0)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
labels, err = torch.Stack(plabels, 0)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
count = len(pimgs)
|
|
|
|
imgs, err = torch.Stack(pimgs, 0)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
imgs, err = imgs.ToDtype(gotch.Float, false, false, true)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
labels, err = labels.ToDtype(gotch.Float, false, false, true)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func NewDataset(db db.Db, m *types.BaseModel, classStart int, classEnd int) (ds *Dataset, err error) {
|
|
trainImages, trainLabels, train_count, err := LoadImagesAndLables(db, m, types.DATA_POINT_MODE_TRAINING, classStart, classEnd)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
testImages, testLabels, test_count, err := LoadImagesAndLables(db, m, types.DATA_POINT_MODE_TESTING, classStart, classEnd)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
ds = &Dataset{
|
|
TrainImages: trainImages,
|
|
TrainLabels: trainLabels,
|
|
TestImages: testImages,
|
|
TestLabels: testLabels,
|
|
TrainImagesSize: train_count,
|
|
TestImagesSize: test_count,
|
|
Device: gotch.CPU,
|
|
}
|
|
return
|
|
}
|
|
|
|
func (ds *Dataset) To(device gotch.Device) (err error) {
|
|
ds.TrainImages, err = ds.TrainImages.ToDevice(device, ds.TrainImages.DType(), device.IsCuda(), true, true)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
ds.TrainLabels, err = ds.TrainLabels.ToDevice(device, ds.TrainLabels.DType(), device.IsCuda(), true, true)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
ds.TestImages, err = ds.TestImages.ToDevice(device, ds.TestImages.DType(), device.IsCuda(), true, true)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
ds.TestLabels, err = ds.TestLabels.ToDevice(device, ds.TestLabels.DType(), device.IsCuda(), true, true)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
ds.Device = device
|
|
return
|
|
}
|
|
|
|
func (ds *Dataset) TestIter(batchSize int64) *torch.Iter2 {
|
|
return torch.MustNewIter2(ds.TestImages, ds.TestLabels, batchSize)
|
|
}
|
|
|
|
func (ds *Dataset) TrainIter(batchSize int64) (iter *torch.Iter2, err error) {
|
|
|
|
// Create a clone of the trainimages
|
|
size, err := ds.TrainImages.Size()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
train_images, err := torch.Zeros(size, gotch.Float, ds.Device)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
ds.TrainImages, err = ds.TrainImages.Clone(train_images, false)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
|
|
|
|
// Create a clone of the labels
|
|
size, err = ds.TrainLabels.Size()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
train_labels, err := torch.Zeros(size, gotch.Float, ds.Device)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
ds.TrainLabels, err = ds.TrainLabels.Clone(train_labels, false)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
|
|
iter, err = torch.NewIter2(train_images, train_labels, batchSize)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
return
|
|
}
|