gotch/example/tensor-io/main.go

75 lines
1.2 KiB
Go
Raw Normal View History

2020-06-08 08:06:35 +01:00
package main
import (
2020-06-10 22:24:10 +01:00
"fmt"
2024-04-21 15:15:00 +01:00
"git.andr3h3nriqu3s.com/andr3/gotch"
"git.andr3h3nriqu3s.com/andr3/gotch/ts"
2020-06-08 08:06:35 +01:00
)
func main() {
2022-03-12 07:20:20 +00:00
x, err := ts.OfSlice([]float64{1.3, 29.7})
2020-06-08 08:06:35 +01:00
if err != nil {
panic(err)
}
path := "file.pt"
2022-03-12 07:20:20 +00:00
x.MustSave(path)
2020-06-08 08:06:35 +01:00
2022-03-12 07:20:20 +00:00
loadedTs := ts.MustLoad(path)
2020-06-08 08:06:35 +01:00
loadedTs.Print()
2022-03-12 07:20:20 +00:00
ts1 := ts.MustOfSlice([]float64{1.3, 29.7})
ts2 := ts.MustOfSlice([]float64{2.1, 31.2})
2022-03-12 07:20:20 +00:00
var namedTensors []ts.NamedTensor = []ts.NamedTensor{
{Name: "ts1", Tensor: ts1},
{Name: "ts2", Tensor: ts2},
}
pathMulti := "file_multi.pt"
2022-03-12 07:20:20 +00:00
// err = ts.SaveMulti(namedTensors, pathMulti)
// if err != nil {
// panic(err)
// }
2022-03-12 07:20:20 +00:00
err = ts.SaveMultiNew(namedTensors, pathMulti)
if err != nil {
panic(err)
}
2022-03-12 07:20:20 +00:00
var data []ts.NamedTensor
2020-06-10 09:31:07 +01:00
2022-03-12 07:20:20 +00:00
data = ts.MustLoadMulti(pathMulti)
2020-06-10 09:31:07 +01:00
for _, v := range data {
v.Tensor.Print()
}
device := gotch.NewCuda()
2022-03-12 07:20:20 +00:00
data = ts.MustLoadMultiWithDevice(pathMulti, device)
for _, v := range data {
v.Tensor.Print()
}
2022-03-12 07:20:20 +00:00
tsString := x.MustToString(80)
2020-06-10 22:24:10 +01:00
fmt.Printf("Tensor String: \n%v\n", tsString)
imagePath := "mnist-sample.png"
2022-03-12 07:20:20 +00:00
imageTs, err := ts.LoadHwc(imagePath)
if err != nil {
panic(err)
}
err = imageTs.Save("mnist-tensor-saved.png")
if err != nil {
panic(err)
}
2020-06-08 08:06:35 +01:00
}