More work on tring to make torch work

This commit is contained in:
2024-04-23 00:14:35 +01:00
parent 703fea46f2
commit a4a9ade71f
7 changed files with 109 additions and 43 deletions

View File

@@ -35,6 +35,7 @@ func LoadImagesAndLables(db db.Db, m *types.BaseModel, mode types.DATA_POINT_MOD
if err != nil {
return
}
pimgs = append(pimgs, img)
t_label := make([]int, size)
@@ -46,6 +47,7 @@ func LoadImagesAndLables(db db.Db, m *types.BaseModel, mode types.DATA_POINT_MOD
if err != nil {
return
}
plabels = append(plabels, label)
}
@@ -62,13 +64,16 @@ func LoadImagesAndLables(db db.Db, m *types.BaseModel, mode types.DATA_POINT_MOD
count = len(pimgs)
imgs, err = torch.Stack(pimgs, 0)
if err != nil {
return
}
labels, err = labels.ToDtype(gotch.Float, false, false, true)
imgs, err = imgs.ToDtype(gotch.Float, false, false, true)
if err != nil {
return
}
imgs, err = imgs.ToDtype(gotch.Float, false, false, true)
labels, err = labels.ToDtype(gotch.Float, false, false, true)
if err != nil {
return
}
@@ -129,21 +134,46 @@ func (ds *Dataset) TestIter(batchSize int64) *torch.Iter2 {
}
func (ds *Dataset) TrainIter(batchSize int64) (iter *torch.Iter2, err error) {
train_images, err := ds.TrainImages.DetachCopy(false)
// Create a clone of the trainimages
size, err := ds.TrainImages.Size()
if err != nil {
return
}
train_labels, err := ds.TrainLabels.DetachCopy(false)
train_images, err := torch.Zeros(size, gotch.Float, ds.Device)
if err != nil {
return
}
iter, err = torch.NewIter2(train_images, train_labels, batchSize)
ds.TrainImages, err = ds.TrainImages.Clone(train_images, false)
if err != nil {
return
}
// Create a clone of the labels
size, err = ds.TrainLabels.Size()
if err != nil {
return
}
train_labels, err := torch.Zeros(size, gotch.Float, ds.Device)
if err != nil {
return
}
ds.TrainLabels, err = ds.TrainLabels.Clone(train_labels, false)
if err != nil {
return
}
iter, err = torch.NewIter2(train_images, train_labels, batchSize)
if err != nil {
return
}
return
}