More work on tring to make torch work
This commit is contained in:
@@ -35,6 +35,7 @@ func LoadImagesAndLables(db db.Db, m *types.BaseModel, mode types.DATA_POINT_MOD
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
pimgs = append(pimgs, img)
|
||||
|
||||
t_label := make([]int, size)
|
||||
@@ -46,6 +47,7 @@ func LoadImagesAndLables(db db.Db, m *types.BaseModel, mode types.DATA_POINT_MOD
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
plabels = append(plabels, label)
|
||||
}
|
||||
|
||||
@@ -62,13 +64,16 @@ func LoadImagesAndLables(db db.Db, m *types.BaseModel, mode types.DATA_POINT_MOD
|
||||
count = len(pimgs)
|
||||
|
||||
imgs, err = torch.Stack(pimgs, 0)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
labels, err = labels.ToDtype(gotch.Float, false, false, true)
|
||||
imgs, err = imgs.ToDtype(gotch.Float, false, false, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
imgs, err = imgs.ToDtype(gotch.Float, false, false, true)
|
||||
labels, err = labels.ToDtype(gotch.Float, false, false, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -129,21 +134,46 @@ func (ds *Dataset) TestIter(batchSize int64) *torch.Iter2 {
|
||||
}
|
||||
|
||||
func (ds *Dataset) TrainIter(batchSize int64) (iter *torch.Iter2, err error) {
|
||||
|
||||
train_images, err := ds.TrainImages.DetachCopy(false)
|
||||
|
||||
// Create a clone of the trainimages
|
||||
size, err := ds.TrainImages.Size()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
train_labels, err := ds.TrainLabels.DetachCopy(false)
|
||||
train_images, err := torch.Zeros(size, gotch.Float, ds.Device)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
iter, err = torch.NewIter2(train_images, train_labels, batchSize)
|
||||
ds.TrainImages, err = ds.TrainImages.Clone(train_images, false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Create a clone of the labels
|
||||
size, err = ds.TrainLabels.Size()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
train_labels, err := torch.Zeros(size, gotch.Float, ds.Device)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ds.TrainLabels, err = ds.TrainLabels.Clone(train_labels, false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
iter, err = torch.NewIter2(train_images, train_labels, batchSize)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user