gotch/pickle/storage.go

812 lines
17 KiB
Go

package pickle
import (
"encoding/binary"
"fmt"
"io"
"math"
"git.andr3h3nriqu3s.com/andr3/gotch"
"git.andr3h3nriqu3s.com/andr3/gotch/half"
)
// This file implements Pytorch storage data types.
// ref: https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/storage.py
/*
torch.double: 'DoubleStorage',
torch.float: 'FloatStorage',
torch.half: 'HalfStorage',
torch.long: 'LongStorage',
torch.int: 'IntStorage',
torch.int16: 'ShortStorage',
torch.int8: 'CharStorage',
torch.uint8: 'ByteStorage',
torch.bool: 'BoolStorage',
torch.bfloat16: 'BFloat16Storage',
torch.cdouble: 'ComplexDoubleStorage',
torch.cfloat: 'ComplexFloatStorage',
torch.qint8: 'QInt8Storage',
torch.qint32: 'QInt32Storage',
torch.quint8: 'QUInt8Storage',
torch.quint4x2: 'QUInt4x2Storage',
torch.quint2x4: 'QUInt2x4Storage',
*/
// StorageClass defines interface for types to be used in Storage.
type StorageClass interface {
New(size int, location string) Storage
}
// Storage define Storage interface.
type Storage interface {
SetFromFile(r io.Reader) error
SetFromFileWithSize(r io.Reader, size int) error
DType() gotch.DType
GetData() interface{}
Device() gotch.Device
}
// BaseStorage represents a base storage.
type BaseStorage struct {
Size int
Location string
}
// HalfStorage:
// ============
type HalfStorageClass struct{}
var _ StorageClass = &HalfStorageClass{}
func (s *HalfStorageClass) New(size int, location string) Storage {
return &HalfStorage{
BaseStorage: BaseStorage{Size: size, Location: location},
Data: nil,
}
}
type HalfStorage struct {
BaseStorage
// Data []float32
Data []half.Float16
}
var _ Storage = &HalfStorage{}
func (s *HalfStorage) SetFromFile(r io.Reader) error {
return setFromFile(s, r)
}
func (s *HalfStorage) SetFromFileWithSize(r io.Reader, size int) error {
data := make([]half.Float16, size)
br := NewLimitedBufferReader(r, size, 2, 512)
for i := 0; i < size; i++ {
bytes, err := br.ReadNext()
if err != nil {
return err
}
u16 := binary.LittleEndian.Uint16(bytes)
data[i] = half.Float16(u16)
}
s.Data = data
return nil
}
func (s *HalfStorage) GetData() interface{} {
return s.Data
}
func (s *HalfStorage) DType() gotch.DType {
return gotch.Half
}
func (s *HalfStorage) Device() gotch.Device {
switch s.Location {
case "cuda":
return gotch.CudaIfAvailable()
default:
return gotch.CPU
}
}
// BFloat16Storage:
// ================
type BFloat16StorageClass struct{}
var _ StorageClass = &BFloat16StorageClass{}
func (s *BFloat16StorageClass) New(size int, location string) Storage {
return &BFloat16Storage{
BaseStorage: BaseStorage{Size: size, Location: location},
Data: nil,
}
}
type BFloat16Storage struct {
BaseStorage
Data []half.BFloat16
}
var _ Storage = &BFloat16Storage{}
func (s *BFloat16Storage) SetFromFile(r io.Reader) error {
return setFromFile(s, r)
}
func (s *BFloat16Storage) SetFromFileWithSize(r io.Reader, size int) error {
data := make([]half.BFloat16, size)
br := NewLimitedBufferReader(r, size, 2, 512)
for i := 0; i < size; i++ {
bytes, err := br.ReadNext()
if err != nil {
return err
}
u16 := binary.LittleEndian.Uint16(bytes)
data[i] = half.BFloat16(u16)
}
s.Data = data
return nil
}
func (s *BFloat16Storage) GetData() interface{} {
return s.Data
}
func (s *BFloat16Storage) DType() gotch.DType {
return gotch.BFloat16
}
func (s *BFloat16Storage) Device() gotch.Device {
switch s.Location {
case "cuda":
return gotch.CudaIfAvailable()
default:
return gotch.CPU
}
}
// FloatStorage:
// =============
type FloatStorageClass struct{}
var _ StorageClass = &FloatStorageClass{}
func (s *FloatStorageClass) New(size int, location string) Storage {
return &FloatStorage{
BaseStorage: BaseStorage{Size: size, Location: location},
Data: nil,
}
}
type FloatStorage struct {
BaseStorage
Data []float32
}
var _ Storage = &FloatStorage{}
func (s *FloatStorage) SetFromFile(r io.Reader) error {
return setFromFile(s, r)
}
func (s *FloatStorage) SetFromFileWithSize(r io.Reader, size int) error {
data := make([]float32, size)
br := NewLimitedBufferReader(r, size, 4, 512)
for i := 0; i < size; i++ {
bytes, err := br.ReadNext()
if err != nil {
return err
}
data[i] = math.Float32frombits(binary.LittleEndian.Uint32(bytes))
}
s.Data = data
return nil
}
func (s *FloatStorage) GetData() interface{} {
return s.Data
}
func (s *FloatStorage) DType() gotch.DType {
return gotch.Float
}
func (s *FloatStorage) Device() gotch.Device {
switch s.Location {
case "cuda":
return gotch.CudaIfAvailable()
default:
return gotch.CPU
}
}
// DoubleStorage:
// ==============
type DoubleStorageClass struct{}
var _ StorageClass = &DoubleStorageClass{}
func (s *DoubleStorageClass) New(size int, location string) Storage {
return &DoubleStorage{
BaseStorage: BaseStorage{Size: size, Location: location},
Data: nil,
}
}
type DoubleStorage struct {
BaseStorage
Data []float64
}
var _ Storage = &DoubleStorage{}
func (s *DoubleStorage) SetFromFile(r io.Reader) error {
return setFromFile(s, r)
}
func (s *DoubleStorage) SetFromFileWithSize(r io.Reader, size int) error {
data := make([]float64, size)
br := NewLimitedBufferReader(r, size, 8, 512)
for i := 0; i < size; i++ {
bytes, err := br.ReadNext()
if err != nil {
return err
}
data[i] = math.Float64frombits(binary.LittleEndian.Uint64(bytes))
}
s.Data = data
return nil
}
func (s *DoubleStorage) GetData() interface{} {
return s.Data
}
func (s *DoubleStorage) DType() gotch.DType {
return gotch.Double
}
func (s *DoubleStorage) Device() gotch.Device {
switch s.Location {
case "cuda":
return gotch.CudaIfAvailable()
default:
return gotch.CPU
}
}
// CharStorage:
// ============
type CharStorageClass struct{}
var _ StorageClass = &CharStorageClass{}
func (s *CharStorageClass) New(size int, location string) Storage {
return &CharStorage{
BaseStorage: BaseStorage{Size: size, Location: location},
Data: nil,
}
}
type CharStorage struct {
BaseStorage
Data []int8
}
var _ Storage = &CharStorage{}
func (s *CharStorage) SetFromFile(r io.Reader) error {
return setFromFile(s, r)
}
func (s *CharStorage) SetFromFileWithSize(r io.Reader, size int) error {
data := make([]int8, size)
br := NewLimitedBufferReader(r, size, 1, 512)
for i := 0; i < size; i++ {
bytes, err := br.ReadNext()
if err != nil {
return err
}
data[i] = int8(bytes[0])
}
s.Data = data
return nil
}
func (s *CharStorage) GetData() interface{} {
return s.Data
}
func (s *CharStorage) DType() gotch.DType {
return gotch.Int8
}
func (s *CharStorage) Device() gotch.Device {
switch s.Location {
case "cuda":
return gotch.CudaIfAvailable()
default:
return gotch.CPU
}
}
// ShortStorage:
// =============
type ShortStorageClass struct{}
var _ StorageClass = &ShortStorageClass{}
func (s *ShortStorageClass) New(size int, location string) Storage {
return &ShortStorage{
BaseStorage: BaseStorage{Size: size, Location: location},
Data: nil,
}
}
type ShortStorage struct {
BaseStorage
Data []int16
}
var _ Storage = &ShortStorage{}
func (s *ShortStorage) SetFromFile(r io.Reader) error {
return setFromFile(s, r)
}
func (s *ShortStorage) SetFromFileWithSize(r io.Reader, size int) error {
data := make([]int16, size)
br := NewLimitedBufferReader(r, size, 2, 512)
for i := 0; i < size; i++ {
bytes, err := br.ReadNext()
if err != nil {
return err
}
data[i] = int16(binary.LittleEndian.Uint16(bytes))
}
s.Data = data
return nil
}
func (s *ShortStorage) GetData() interface{} {
return s.Data
}
func (s *ShortStorage) DType() gotch.DType {
return gotch.Int16
}
func (s *ShortStorage) Device() gotch.Device {
switch s.Location {
case "cuda":
return gotch.CudaIfAvailable()
default:
return gotch.CPU
}
}
// IntStorage:
// ===========
type IntStorageClass struct{}
var _ StorageClass = &IntStorageClass{}
func (s *IntStorageClass) New(size int, location string) Storage {
return &IntStorage{
BaseStorage: BaseStorage{Size: size, Location: location},
Data: nil,
}
}
type IntStorage struct {
BaseStorage
Data []int32
}
var _ Storage = &IntStorage{}
func (s *IntStorage) SetFromFile(r io.Reader) error {
return setFromFile(s, r)
}
func (s *IntStorage) SetFromFileWithSize(r io.Reader, size int) error {
data := make([]int32, size)
br := NewLimitedBufferReader(r, size, 4, 512)
for i := 0; i < size; i++ {
bytes, err := br.ReadNext()
if err != nil {
return err
}
data[i] = int32(binary.LittleEndian.Uint32(bytes))
}
s.Data = data
return nil
}
func (s *IntStorage) GetData() interface{} {
return s.Data
}
func (s *IntStorage) DType() gotch.DType {
return gotch.Int
}
func (s *IntStorage) Device() gotch.Device {
switch s.Location {
case "cuda":
return gotch.CudaIfAvailable()
default:
return gotch.CPU
}
}
// LongStorage:
// ============
type LongStorageClass struct{}
var _ StorageClass = &LongStorageClass{}
func (s *LongStorageClass) New(size int, location string) Storage {
return &LongStorage{
BaseStorage: BaseStorage{Size: size, Location: location},
Data: nil,
}
}
type LongStorage struct {
BaseStorage
Data []int64
}
var _ Storage = &LongStorage{}
func (s *LongStorage) SetFromFile(r io.Reader) error {
return setFromFile(s, r)
}
func (s *LongStorage) SetFromFileWithSize(r io.Reader, size int) error {
data := make([]int64, size)
br := NewLimitedBufferReader(r, size, 8, 512)
for i := 0; i < size; i++ {
bytes, err := br.ReadNext()
if err != nil {
return err
}
data[i] = int64(binary.LittleEndian.Uint64(bytes))
}
s.Data = data
return nil
}
func (s *LongStorage) GetData() interface{} {
return s.Data
}
func (s *LongStorage) DType() gotch.DType {
return gotch.Int64
}
func (s *LongStorage) Device() gotch.Device {
switch s.Location {
case "cuda":
return gotch.CudaIfAvailable()
default:
return gotch.CPU
}
}
// ByteStorage:
// ============
type ByteStorageClass struct{}
var _ StorageClass = &ByteStorageClass{}
func (s *ByteStorageClass) New(size int, location string) Storage {
return &ByteStorage{
BaseStorage: BaseStorage{Size: size, Location: location},
Data: nil,
}
}
type ByteStorage struct {
BaseStorage
Data []uint8
}
var _ Storage = &ByteStorage{}
func (s *ByteStorage) SetFromFile(r io.Reader) error {
return setFromFile(s, r)
}
func (s *ByteStorage) SetFromFileWithSize(r io.Reader, size int) error {
data := make([]uint8, size)
br := NewLimitedBufferReader(r, size, 1, 512)
for i := 0; i < size; i++ {
bytes, err := br.ReadNext()
if err != nil {
return err
}
data[i] = bytes[0]
}
s.Data = data
return nil
}
func (s *ByteStorage) GetData() interface{} {
return s.Data
}
func (s *ByteStorage) DType() gotch.DType {
return gotch.Uint8
}
func (s *ByteStorage) Device() gotch.Device {
switch s.Location {
case "cuda":
return gotch.CudaIfAvailable()
default:
return gotch.CPU
}
}
// BoolStorage:
// ============
type BoolStorageClass struct{}
var _ StorageClass = &BoolStorageClass{}
func (s *BoolStorageClass) New(size int, location string) Storage {
return &BoolStorage{
BaseStorage: BaseStorage{Size: size, Location: location},
Data: nil,
}
}
type BoolStorage struct {
BaseStorage
Data []bool
}
var _ Storage = &BoolStorage{}
func (s *BoolStorage) SetFromFile(r io.Reader) error {
return setFromFile(s, r)
}
func (s *BoolStorage) SetFromFileWithSize(r io.Reader, size int) error {
data := make([]bool, size)
br := NewLimitedBufferReader(r, size, 1, 512)
for i := 0; i < size; i++ {
bytes, err := br.ReadNext()
if err != nil {
return err
}
data[i] = bytes[0] == 1
}
s.Data = data
return nil
}
func (s *BoolStorage) GetData() interface{} {
return s.Data
}
func (s *BoolStorage) DType() gotch.DType {
return gotch.Float
}
func (s *BoolStorage) Device() gotch.Device {
switch s.Location {
case "cuda":
return gotch.CudaIfAvailable()
default:
return gotch.CPU
}
}
func setFromFile(s Storage, r io.Reader) error {
sizeBuf := make([]byte, 8)
_, err := r.Read(sizeBuf)
if err != nil {
return err
}
size := int(binary.LittleEndian.Uint64(sizeBuf))
return s.SetFromFileWithSize(r, size)
}
// StorageTensor:
// ===============
type StorageTensor struct {
Source Storage
StorageOffset int64
Size []int64
Stride []int64
RequiresGrad bool
}
// Rebuild Tensor:
// ===============
// ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L132
// ref. def _rebuild_tensor(storage, storage_offset, size, stride):
type RebuildTensor struct{}
var _ Callable = &RebuildTensor{}
func (r *RebuildTensor) Call(args ...interface{}) (interface{}, error) {
if len(args) != 4 {
return nil, fmt.Errorf("RebuildTensor.Call() failed. Expected 4 args, got %d: %#v", len(args), args)
}
storage, storageOk := args[0].(Storage)
storageOffset, storageOffsetOk := args[1].(int)
size, sizeOk := args[2].(*Tuple)
stride, strideOk := args[3].(*Tuple)
if !storageOk || !storageOffsetOk || !sizeOk || !strideOk {
return nil, fmt.Errorf("RebuildTensor.Call() unexpected args: %#v", args)
}
tensor := &StorageTensor{
Source: storage,
StorageOffset: int64(storageOffset),
RequiresGrad: false,
}
var err error
tensor.Size, err = tupleToInt64Slice(size)
if err != nil {
return nil, err
}
tensor.Stride, err = tupleToInt64Slice(stride)
if err != nil {
return nil, err
}
return tensor, nil
}
// RebuildTensorV2 represents a struct to rebuild tensor back from pickle object.
type RebuildTensorV2 struct{}
var _ Callable = &RebuildTensorV2{}
func (r *RebuildTensorV2) Call(args ...interface{}) (interface{}, error) {
if len(args) != 6 {
return nil, fmt.Errorf("RebuildTensorV2 unexpected args: %#v", args)
}
storage, storageOk := args[0].(Storage)
storageOffset, storageOffsetOk := args[1].(int)
size, sizeOk := args[2].(*Tuple)
stride, strideOk := args[3].(*Tuple)
requiresGrad, requiresGradOk := args[4].(bool)
// arg[5] "backward hooks" is unused
if !storageOk || !storageOffsetOk || !sizeOk || !strideOk ||
!requiresGradOk {
return nil, fmt.Errorf("RebuildTensorV2 unexpected args: %#v", args)
}
tensor := &StorageTensor{
Source: storage,
StorageOffset: int64(storageOffset),
RequiresGrad: requiresGrad,
}
var err error
tensor.Size, err = tupleToInt64Slice(size)
if err != nil {
return nil, err
}
tensor.Stride, err = tupleToInt64Slice(stride)
if err != nil {
return nil, err
}
return tensor, nil
}
// Rebuild Parameter:
// ==================
// RebuildTensor represents a struct to rebuild tensor back from pickle object.
// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L240
type RebuildParameter struct{}
var _ Callable = &RebuildParameter{}
func (r *RebuildParameter) Call(args ...interface{}) (interface{}, error) {
if len(args) != 3 { // data(*StorageTensor), requires_grad, backward_hooks
return nil, fmt.Errorf("RebuildParameter unexpected 3 args, got %d: %#v", len(args), args)
}
tensor, ok := args[0].(*StorageTensor)
if !ok {
err := fmt.Errorf("RebuildParameter.Call() failed: unexpected arg: %#v\n", args)
return nil, err
}
return tensor, nil
}
func tupleToInt64Slice(tuple *Tuple) ([]int64, error) {
length := tuple.Len()
slice := make([]int64, length)
for i := 0; i < length; i++ {
value, ok := tuple.Get(i).(int)
if !ok {
// return nil, fmt.Errorf("tuple of ints expected. Got %#v", tuple)
fmt.Printf("WARNING: tuple of ints expected. Got %#v\n", tuple)
continue
}
slice[i] = int64(value)
}
return slice, nil
}
// Rebuild Sparse Tensor:
// =======================
// ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L178
type RebuildSparseTensor struct{}
var _ Callable = &RebuildSparseTensor{}
func (r *RebuildSparseTensor) Call(args ...interface{}) (interface{}, error) {
// TODO.
panic("RebuildSparseTensor.Call(): NotImplementedError")
}
// Rebuild Sparse CSR Tensor:
// ==========================
// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L187
type RebuildSparseCsrTensor struct{}
var _ Callable = &RebuildSparseCsrTensor{}
func (r *RebuildSparseCsrTensor) Call(args ...interface{}) (interface{}, error) {
// TODO.
panic("RebuildSparseCsrTensor.Call(): NotImplementedError")
}
// Rebuild Device Tensor From Numpy:
// =================================
// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L197
type RebuildDeviceTensorFromNumpy struct{}
var _ Callable = &RebuildDeviceTensorFromNumpy{}
func (r *RebuildDeviceTensorFromNumpy) Call(args ...interface{}) (interface{}, error) {
// TODO.
panic("RebuildDeviceTensorFromNumpy.Call(): NotImplementedError")
}
// Rebuild Meta Tensor No Storage:
// ===============================
// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L208
type RebuildMetaTensorNoStorage struct{}
var _ Callable = &RebuildMetaTensorNoStorage{}
func (r *RebuildMetaTensorNoStorage) Call(args ...interface{}) (interface{}, error) {
// TODO.
panic("RebuildMetaTensorNoStorage.Call(): NotImplementedError")
}
// Rebuild QTensor:
// ================
// Ref. https://github.com/pytorch/pytorch/blob/c2255c36ec121fdb998ce3db8deb7508c814b567/torch/_utils.py#L214
type RebuildQtensor struct{}
var _ Callable = &RebuildQtensor{}
func (r *RebuildQtensor) Call(args ...interface{}) (interface{}, error) {
// TODO.
panic("RebuildQtensor.Call(): NotImplementedError")
}