tensor/kind: added kind for tensor
This commit is contained in:
parent
51d5d127dc
commit
bbf8bface1
|
@ -2,8 +2,8 @@ package tensor
|
|||
|
||||
import (
|
||||
"log"
|
||||
"reflect"
|
||||
"unsafe"
|
||||
|
||||
gotch "github.com/sugarme/gotch"
|
||||
)
|
||||
|
||||
// CInt is equal to C type int. Go type is int32
|
||||
|
@ -99,4 +99,49 @@ func (k Kind) EltSizeInBytes() uint {
|
|||
return uint(0)
|
||||
}
|
||||
|
||||
// TODO: continue with devices...
|
||||
type KindDevice struct {
|
||||
Kind Kind
|
||||
Device gotch.Device
|
||||
}
|
||||
|
||||
var (
|
||||
FloatCPU KindDevice = KindDevice{Float, gotch.CPU}
|
||||
DoubleCPU KindDevice = KindDevice{Double, gotch.CPU}
|
||||
Int64CPU KindDevice = KindDevice{Int64, gotch.CPU}
|
||||
|
||||
FloatCUDA KindDevice = KindDevice{Float, gotch.CudaBuilder(0)}
|
||||
DoubleCUDA KindDevice = KindDevice{Double, gotch.CudaBuilder(0)}
|
||||
Int64CUDA KindDevice = KindDevice{Int64, gotch.CudaBuilder(0)}
|
||||
)
|
||||
|
||||
type KindTrait interface {
|
||||
GetKind() Kind
|
||||
}
|
||||
|
||||
type KindUint8 struct{}
|
||||
|
||||
func (k KindUint8) GetKind() Kind { return Uint8 }
|
||||
|
||||
type KindInt8 struct{}
|
||||
|
||||
func (k KindInt8) GetKind() Kind { return Int8 }
|
||||
|
||||
type KindInt16 struct{}
|
||||
|
||||
func (k KindInt16) GetKind() Kind { return Int16 }
|
||||
|
||||
type KindInt64 struct{}
|
||||
|
||||
func (k KindInt64) GetKind() Kind { return Int64 }
|
||||
|
||||
type KindFloat32 struct{}
|
||||
|
||||
func (k KindFloat32) GetKind() Kind { return Float }
|
||||
|
||||
type KindFloat64 struct{}
|
||||
|
||||
func (k KindFloat64) GetKind() Kind { return Double }
|
||||
|
||||
type KindBool struct{}
|
||||
|
||||
func (k KindBool) GetKind() Kind { return Bool }
|
||||
|
|
Loading…
Reference in New Issue
Block a user