From cf61333fab596df239d576a279bd448eb3c1e786 Mon Sep 17 00:00:00 2001 From: sugarme Date: Tue, 16 Jun 2020 01:59:41 +1000 Subject: [PATCH] WIP(example/mnist): auto-grad still not working --- example/mnist/linear.go | 32 ++-- libtch/c-generated-sample.go | 33 ++++ tensor/tensor-generated-sample.go | 243 +++++++++++++++++++++++++++++- tensor/tensor.go | 17 +++ vision/mnist.go | 205 ++++++++++++------------- 5 files changed, 411 insertions(+), 119 deletions(-) diff --git a/example/mnist/linear.go b/example/mnist/linear.go index 1395f58..589a4cc 100644 --- a/example/mnist/linear.go +++ b/example/mnist/linear.go @@ -2,7 +2,6 @@ package main import ( "fmt" - "log" "github.com/sugarme/gotch" ts "github.com/sugarme/gotch/tensor" @@ -27,21 +26,32 @@ func runLinear() { fmt.Printf("Test label size: %v\n", ds.TestLabels.MustSize()) device := (gotch.CPU).CInt() - dtype := (gotch.Double).CInt() + dtype := (gotch.Float).CInt() ws := ts.MustZeros([]int64{ImageDim, Label}, dtype, device).MustSetRequiresGrad(true) - bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true) - fmt.Println(ws.MustSize()) - fmt.Println(bs.MustSize()) - for epoch := 0; epoch < epochs; epoch++ { - } -} + logits := ds.TrainImages.MustMm(ws).MustAdd(bs) + loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels) + + ws.ZeroGrad() + bs.ZeroGrad() + loss.Backward() + + wsGrad := ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)) + bsGrad := bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)) + + wsClone := ws.MustShallowClone() + bsClone := bs.MustShallowClone() + + // wsClone.MustAdd_(wsGrad) + // bsClone.MustAdd_(bsGrad) + + testLogits := ds.TestImages.MustMm(wsClone.MustAdd(wsGrad)).MustAdd(bsClone.MustAdd(bsGrad)) + testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0}) + + fmt.Printf("Epoch: %v - Train loss: %v - Test accuracy: %v\n", epoch, 
loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}), testAccuracy*100) -func handleError(err error) { - if err != nil { - log.Fatal(err) } } diff --git a/libtch/c-generated-sample.go b/libtch/c-generated-sample.go index 633bbfa..44b35ad 100644 --- a/libtch/c-generated-sample.go +++ b/libtch/c-generated-sample.go @@ -57,11 +57,26 @@ func AtgMul(ptr *Ctensor, self Ctensor, other Ctensor) { C.atg_mul(ptr, self, other) } +// void atg_mul_(tensor *, tensor self, tensor other); +func AtgMul_(ptr *Ctensor, self Ctensor, other Ctensor) { + C.atg_mul_(ptr, self, other) +} + +// void atg_mul1(tensor *, tensor self, scalar other); +func AtgMul1(ptr *Ctensor, self Ctensor, other Cscalar) { + C.atg_mul1(ptr, self, other) +} + // void atg_add(tensor *, tensor self, tensor other); func AtgAdd(ptr *Ctensor, self Ctensor, other Ctensor) { C.atg_add(ptr, self, other) } +// void atg_add_(tensor *, tensor self, tensor other); +func AtgAdd_(ptr *Ctensor, self Ctensor, other Ctensor) { + C.atg_add_(ptr, self, other) +} + // void atg_totype(tensor *, tensor self, int scalar_type); func AtgTotype(ptr *Ctensor, self Ctensor, scalar_type int32) { cscalar_type := *(*C.int)(unsafe.Pointer(&scalar_type)) @@ -200,3 +215,21 @@ func AtgStack(ptr *Ctensor, tensorsData []Ctensor, tensorsLen int, dim int64) { C.atg_stack(ptr, tensorsDataPtr, ctensorsLen, cdim) } + +// void atg_mm(tensor *, tensor self, tensor mat2); +func AtgMm(ptr *Ctensor, self Ctensor, mat2 Ctensor) { + C.atg_mm(ptr, self, mat2) +} + +// void atg_view(tensor *, tensor self, int64_t *size_data, int size_len); +func AtgView(ptr *Ctensor, self Ctensor, sizeData []int64, sizeLen int) { + sizeDataPtr := (*C.int64_t)(unsafe.Pointer(&sizeData[0])) + csizeLen := *(*C.int)(unsafe.Pointer(&sizeLen)) + + C.atg_view(ptr, self, sizeDataPtr, csizeLen) +} + +// void atg_div1(tensor *, tensor self, scalar other); +func AtgDiv1(ptr *Ctensor, self Ctensor, other Cscalar) { + C.atg_div1(ptr, self, other) +} diff --git 
a/tensor/tensor-generated-sample.go b/tensor/tensor-generated-sample.go index b21e3d2..429db20 100644 --- a/tensor/tensor-generated-sample.go +++ b/tensor/tensor-generated-sample.go @@ -101,25 +101,23 @@ func (ts Tensor) MustDetach_() (retVal Tensor) { return retVal } -func (ts Tensor) Zero_() (retVal Tensor, err error) { +func (ts Tensor) Zero_() (err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) defer C.free(unsafe.Pointer(ptr)) lib.AtgZero_(ptr, ts.ctensor) if err = TorchErr(); err != nil { - return retVal, err + return err } - return Tensor{ctensor: *ptr}, nil + return nil } -func (ts Tensor) MustZero_() (retVal Tensor) { - retVal, err := ts.Zero_() +func (ts Tensor) MustZero_() { + err := ts.Zero_() if err != nil { log.Fatal(err) } - - return retVal } func (ts Tensor) SetRequiresGrad(rb bool) (retVal Tensor, err error) { @@ -170,6 +168,46 @@ func (ts Tensor) MustMul(other Tensor) (retVal Tensor) { return retVal } +func (ts Tensor) Mul1(other Scalar) (retVal Tensor, err error) { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + defer C.free(unsafe.Pointer(ptr)) + lib.AtgMul1(ptr, ts.ctensor, other.cscalar) + + if err = TorchErr(); err != nil { + return retVal, err + } + + return Tensor{ctensor: *ptr}, nil +} + +func (ts Tensor) MustMul1(other Scalar) (retVal Tensor) { + retVal, err := ts.Mul1(other) + if err != nil { + log.Fatal(err) + } + + return retVal +} + +func (ts Tensor) Mul_(other Tensor) (err error) { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + defer C.free(unsafe.Pointer(ptr)) + lib.AtgMul_(ptr, ts.ctensor, other.ctensor) + + if err = TorchErr(); err != nil { + return err + } + + return nil +} + +func (ts Tensor) MustMul_(other Tensor) { + err := ts.Mul_(other) + if err != nil { + log.Fatal(err) + } +} + func (ts Tensor) Add(other Tensor) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) defer C.free(unsafe.Pointer(ptr)) @@ -191,6 +229,26 @@ func (ts Tensor) MustAdd(other Tensor) (retVal 
Tensor) { return retVal } +func (ts Tensor) Add_(other Tensor) (err error) { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + defer C.free(unsafe.Pointer(ptr)) + lib.AtgAdd_(ptr, ts.ctensor, other.ctensor) + + if err = TorchErr(); err != nil { + return err + } + + return nil + +} + +func (ts Tensor) MustAdd_(other Tensor) { + err := ts.Add_(other) + if err != nil { + log.Fatal(err) + } +} + func (ts Tensor) AddG(other Tensor) (err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) defer C.free(unsafe.Pointer(ptr)) @@ -468,3 +526,174 @@ func Stack(tensors []Tensor, dim int64) (retVal Tensor, err error) { return retVal, nil } + +func (ts Tensor) Mm(mat2 Tensor) (retVal Tensor, err error) { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + defer C.free(unsafe.Pointer(ptr)) + + lib.AtgMm(ptr, ts.ctensor, mat2.ctensor) + if err = TorchErr(); err != nil { + return retVal, err + } + + retVal = Tensor{ctensor: *ptr} + + return retVal, nil +} + +func (ts Tensor) MustMm(mat2 Tensor) (retVal Tensor) { + retVal, err := ts.Mm(mat2) + if err != nil { + log.Fatal(err) + } + + return retVal +} + +func (ts Tensor) LogSoftmax(dim int64, dtype int32) (retVal Tensor, err error) { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + defer C.free(unsafe.Pointer(ptr)) + + lib.AtgLogSoftmax(ptr, ts.ctensor, dim, dtype) + if err = TorchErr(); err != nil { + return retVal, err + } + + retVal = Tensor{ctensor: *ptr} + + return retVal, nil +} + +func (ts Tensor) MustLogSoftmax(dim int64, dtype int32) (retVal Tensor) { + retVal, err := ts.LogSoftmax(dim, dtype) + if err != nil { + log.Fatal(err) + } + + return retVal +} + +func (ts Tensor) NllLoss(target Tensor) (retVal Tensor, err error) { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + + weight := NewTensor() + + reduction := int64(1) // Mean of loss + ignoreIndex := int64(-100) + defer C.free(unsafe.Pointer(ptr)) + + lib.AtgNllLoss(ptr, ts.ctensor, target.ctensor, weight.ctensor, reduction, ignoreIndex) + if err 
= TorchErr(); err != nil { + return retVal, err + } + + retVal = Tensor{ctensor: *ptr} + + return retVal, nil +} + +func (ts Tensor) MustNllLoss(target Tensor) (retVal Tensor) { + retVal, err := ts.NllLoss(target) + if err != nil { + log.Fatal(err) + } + + return retVal +} + +func (ts Tensor) Argmax(dim int64, keepDim bool) (retVal Tensor, err error) { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + defer C.free(unsafe.Pointer(ptr)) + + var ckeepDim int = 0 + if keepDim { + ckeepDim = 1 + } + + lib.AtgArgmax(ptr, ts.ctensor, dim, ckeepDim) + if err = TorchErr(); err != nil { + return retVal, err + } + + retVal = Tensor{ctensor: *ptr} + + return retVal, nil +} + +func (ts Tensor) MustArgmax(dim int64, keepDim bool) (retVal Tensor) { + retVal, err := ts.Argmax(dim, keepDim) + if err != nil { + log.Fatal(err) + } + + return retVal +} + +func (ts Tensor) Mean(dtype int32) (retVal Tensor, err error) { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + defer C.free(unsafe.Pointer(ptr)) + + lib.AtgMean(ptr, ts.ctensor, dtype) + if err = TorchErr(); err != nil { + return retVal, err + } + + retVal = Tensor{ctensor: *ptr} + + return retVal, nil +} + +func (ts Tensor) MustMean(dtype int32) (retVal Tensor) { + retVal, err := ts.Mean(dtype) + if err != nil { + log.Fatal(err) + } + + return retVal +} + +func (ts Tensor) View(sizeData []int64) (retVal Tensor, err error) { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + defer C.free(unsafe.Pointer(ptr)) + + lib.AtgView(ptr, ts.ctensor, sizeData, len(sizeData)) + if err = TorchErr(); err != nil { + return retVal, err + } + + retVal = Tensor{ctensor: *ptr} + + return retVal, nil +} + +func (ts Tensor) MustView(sizeData []int64) (retVal Tensor) { + retVal, err := ts.View(sizeData) + if err != nil { + log.Fatal(err) + } + + return retVal +} + +func (ts Tensor) Div1(other Scalar) (retVal Tensor, err error) { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + defer C.free(unsafe.Pointer(ptr)) + + lib.AtgDiv1(ptr, 
ts.ctensor, other.cscalar) + if err = TorchErr(); err != nil { + return retVal, err + } + + retVal = Tensor{ctensor: *ptr} + + return retVal, nil +} + +func (ts Tensor) MustDiv1(other Scalar) (retVal Tensor) { + retVal, err := ts.Div1(other) + if err != nil { + log.Fatal(err) + } + + return retVal +} diff --git a/tensor/tensor.go b/tensor/tensor.go index 2849c5f..2f2359c 100644 --- a/tensor/tensor.go +++ b/tensor/tensor.go @@ -311,6 +311,15 @@ func (ts Tensor) Eq1(other Tensor) (retVal Tensor, err error) { } +func (ts Tensor) MustEq1(other Tensor) (retVal Tensor) { + retVal, err := ts.Eq1(other) + if err != nil { + log.Fatal(err) + } + + return retVal +} + // Float64Value returns a float value on tensors holding a single element. // An error is returned otherwise. // double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len); @@ -330,6 +339,14 @@ func (ts Tensor) Float64Value(idx []int64) (retVal float64, err error) { return retVal, err } +func (ts Tensor) MustFloat64Value(idx []int64) (retVal float64) { + retVal, err := ts.Float64Value(idx) + if err != nil { + log.Fatal(err) + } + return retVal +} + // Int64Value returns an int value on tensors holding a single element. An error is // returned otherwise. func (ts Tensor) Int64Value(idx []int64) (retVal int64, err error) { diff --git a/vision/mnist.go b/vision/mnist.go index c7b9138..2f711e4 100644 --- a/vision/mnist.go +++ b/vision/mnist.go @@ -6,123 +6,126 @@ package vision // http://yann.lecun.com/exdb/mnist/ import ( - "encoding/binary" - "io" + "fmt" "log" "os" "path/filepath" + "github.com/sugarme/gotch" ts "github.com/sugarme/gotch/tensor" ) -// Image holds the pixel intensities of an image. -// 255 is foreground (black), 0 is background (white). 
-type RawImage []byte - -const numLabels = 10 -const pixelRange = 255 - -const ( - imageMagic = 0x00000803 - labelMagic = 0x00000801 - // Width of the input tensor / picture - Width = 28 - // Height of the input tensor / picture - Height = 28 -) - -func readLabels(r io.Reader, e error) (retVal ts.Tensor) { - if e != nil { - log.Fatalf("readLabels errors: %v\n", e) +// readInt32 reads 4 bytes and converts them to an MSB-first (big-endian) integer. +func readInt32(f *os.File) (retVal int, err error) { + buf := make([]byte, 4) + n, err := f.Read(buf) + switch { + case err != nil: + return 0, err + case n != 4: + err = fmt.Errorf("Invalid format: %v", f.Name()) + return 0, err } - var ( - magic int32 - n int32 - err error - ) - - // Check magic number - if err = binary.Read(r, binary.BigEndian, &magic); err != nil { - log.Fatalf("readLabels - binary.Read error: %v\n", err) - } - if magic != labelMagic { - log.Fatal(os.ErrInvalid) + // flip to big endian + var v int = 0 + for _, i := range buf { + v = v*256 + int(i) } - // Now decode number - if err = binary.Read(r, binary.BigEndian, &n); err != nil { - log.Fatalf("readLabels - binary.Read error: %v\n", err) - } + return v, nil +} - // label is a digit number range 0 - 9 - labels := make([]uint8, n) - for i := 0; i < int(n); i++ { - var l uint8 - if err := binary.Read(r, binary.BigEndian, &l); err != nil { - log.Fatalf("readLabels - binary.Read error: %v\n", err) - } - labels[i] = l - } - - retVal, err = ts.OfSlice(labels) +// checkMagicNumber checks the magic number located at the first 4 bytes of +// mnist files. 
+func checkMagicNumber(f *os.File, wantNumber int) (err error) { + gotNumber, err := readInt32(f) if err != nil { - log.Fatalf("readLabels - ts.OfSlice error: %v\n", err) + return err } + if gotNumber != wantNumber { + err = fmt.Errorf("incorrect magic number: got %v want %v\n", gotNumber, wantNumber) + return err + } + + return nil +} + +func readLabels(filename string) (retVal ts.Tensor) { + + f, err := os.Open(filename) + if err != nil { + log.Fatalf("readLabels errors: %v\n", err) + } + defer f.Close() + + if err = checkMagicNumber(f, 2049); err != nil { + log.Fatal(err) + } + + samples, err := readInt32(f) + if err != nil { + log.Fatal(err) + } + + var data []uint8 = make([]uint8, samples) + len, err := f.Read(data) + if err != nil || len != samples { + err = fmt.Errorf("invalid format %v", f.Name()) + log.Fatal(err) + } + + labelsTs, err := ts.OfSlice(data) + if err != nil { + err = fmt.Errorf("create label tensor err.") + log.Fatal(err) + } + + retVal = labelsTs.MustTotype(gotch.Int64) + return retVal } -func readImages(r io.Reader, e error) (retVal ts.Tensor) { - if e != nil { - log.Fatalf("readLabels errors: %v\n", e) - } - - var ( - magic int32 - n int32 - nrow int32 - ncol int32 - err error - ) - - // Check magic number - if err = binary.Read(r, binary.BigEndian, &magic); err != nil { - log.Fatalf("readImages - binary.Read error: %v\n", err) - } - - if magic != imageMagic { - log.Fatalf("readImages - incorrect imageMagic: %v\n", err) // err is os.ErrInvalid - } - - // Now, decode image - if err = binary.Read(r, binary.BigEndian, &n); err != nil { - log.Fatalf("readImages - binary.Read error: %v\n", err) - } - if err = binary.Read(r, binary.BigEndian, &nrow); err != nil { - log.Fatalf("readImages - binary.Read error: %v\n", err) - } - if err = binary.Read(r, binary.BigEndian, &ncol); err != nil { - log.Fatalf("readImages - binary.Read error: %v\n", err) - } - - imgs := make([]RawImage, n) - m := int(nrow * ncol) - for i := 0; i < int(n); i++ { - imgs[i] = 
make(RawImage, m) - m_, err := io.ReadFull(r, imgs[i]) - if err != nil { - log.Fatalf("readImages - io.ReadFull error: %v\n", err) - } - if m_ != int(m) { - log.Fatalf("readImages - image matrix size mismatched error: %v\n", os.ErrInvalid) - } - } - - retVal, err = ts.NewTensorFromData(imgs, []int64{int64(n), int64(nrow * ncol)}) +func readImages(filename string) (retVal ts.Tensor) { + f, err := os.Open(filename) if err != nil { - log.Fatalf("readImages - ts.NewTensorFromData error: %v\n", err) + log.Fatalf("readImages errors: %v\n", err) } + defer f.Close() + + if err = checkMagicNumber(f, 2051); err != nil { + log.Fatal(err) + } + + samples, err := readInt32(f) + if err != nil { + log.Fatal(err) + } + + rows, err := readInt32(f) + if err != nil { + log.Fatal(err) + } + cols, err := readInt32(f) + if err != nil { + log.Fatal(err) + } + + dataLen := samples * rows * cols + var data []uint8 = make([]uint8, dataLen) + len, err := f.Read(data) + if err != nil || len != dataLen { + err = fmt.Errorf("invalid format %v", f.Name()) + log.Fatal(err) + } + + imagesTs, err := ts.OfSlice(data) + if err != nil { + err = fmt.Errorf("create images tensor err.") + log.Fatal(err) + } + retVal = imagesTs.MustView([]int64{int64(samples), int64(rows * cols)}).MustTotype(gotch.Float).MustDiv1(ts.FloatScalar(255.0)) return retVal } @@ -141,10 +144,10 @@ func LoadMNISTDir(dir string) (retVal Dataset) { testLabelsFile := filepath.Join(dir, testLabels) testImagesFile := filepath.Join(dir, testImages) - trainImagesTs := readImages(os.Open(trainImagesFile)) - trainLabelsTs := readLabels(os.Open(trainLabelsFile)) - testImagesTs := readImages(os.Open(testImagesFile)) - testLabelsTs := readLabels(os.Open(testLabelsFile)) + trainImagesTs := readImages(trainImagesFile) + trainLabelsTs := readLabels(trainLabelsFile) + testImagesTs := readImages(testImagesFile) + testLabelsTs := readLabels(testLabelsFile) return Dataset{ TrainImages: trainImagesTs,