feat(example/jit)

This commit is contained in:
sugarme 2020-06-30 20:01:01 +10:00
parent 47283314de
commit fcbc4ca870
8 changed files with 197 additions and 15 deletions

2
.gitignore vendored
View File

@ -11,6 +11,8 @@
*.txt
*.json
*.pt
*.jpg
target/
_build/

58
example/jit/main.go Normal file
View File

@ -0,0 +1,58 @@
package main
// This example illustrates how to use a PyTorch model trained and exported using the
// Python JIT API.
// See https://pytorch.org/tutorials/advanced/cpp_export.html for more details.
import (
"flag"
"fmt"
"log"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
"github.com/sugarme/gotch/vision"
)
var (
modelPath string
imageFile string
)
// init registers the command-line flags before main runs.
// Defaults: -modelpath=model.pt, -image=image.jpg.
func init() {
	flag.StringVar(&modelPath, "modelpath", "model.pt", "full path to exported pytorch model.")
	flag.StringVar(&imageFile, "image", "image.jpg", "full path to image file.")
}
// main loads a TorchScript module exported from Python, runs one image
// through it, and prints the top-5 ImageNet predictions with probabilities.
func main() {
	flag.Parse()

	imageNet := vision.NewImageNet()

	// Load the image file. NOTE(review): most ImageNet models expect a
	// 224x224 input (see the commented LoadImageAndResize224 call);
	// LoadImage keeps the original dimensions — confirm the exported model
	// accepts them.
	// image, err := imageNet.LoadImageAndResize224(imageFile)
	image, err := imageNet.LoadImage(imageFile)
	if err != nil {
		log.Fatal(err)
	}
	// image.MustSave("resize224.jpg")

	// Load the Python-saved (JIT) module.
	model, err := ts.ModuleLoad(modelPath)
	if err != nil {
		log.Fatal(err)
	}

	// Forward pass: add a batch dimension, apply the module, then softmax
	// the logits over the last dimension to get probabilities.
	output := image.MustUnsqueeze(int64(0), false).ApplyCModule(model).MustSoftmax(-1, gotch.Float, true)

	// Print the top 5 categories for this image.
	top5 := imageNet.Top(output, int64(5))
	for _, item := range top5 {
		fmt.Printf("%v \t\t\t: %.2f%%\n", item.Label, item.Pvalue*100)
	}
}

View File

@ -41,6 +41,11 @@ func AtgDetach_(ptr *Ctensor, self Ctensor) {
C.atg_detach_(ptr, self)
}
// void atg_detach(tensor *, tensor self);
// AtgDetach wraps the C library's atg_detach, storing the resulting
// tensor handle through ptr.
func AtgDetach(ptr *Ctensor, self Ctensor) {
	C.atg_detach(ptr, self)
}
// void atg_zero_(tensor *, tensor self);
func AtgZero_(ptr *Ctensor, self Ctensor) {
C.atg_zero_(ptr, self)
@ -556,3 +561,11 @@ func AtgAdaptiveAvgPool2d(ptr *Ctensor, self Ctensor, outputSizeData []int64, ou
C.atg_adaptive_avg_pool2d(ptr, self, outputSizeDataPtr, coutputSizeLen)
}
// void atg_softmax(tensor *, tensor self, int64_t dim, int dtype);
// AtgSoftmax wraps the C library's atg_softmax, storing the resulting
// tensor handle through ptr. dim selects the softmax dimension and dtype
// is the C-level scalar type code.
func AtgSoftmax(ptr *Ctensor, self Ctensor, dim int64, dtype int32) {
	// Reinterpret the Go scalars as their C counterparts (same bit width),
	// following this file's convention for passing scalar arguments.
	cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
	cdtype := *(*C.int)(unsafe.Pointer(&dtype))
	C.atg_softmax(ptr, self, cdim, cdtype)
}

View File

@ -1062,3 +1062,25 @@ func (cm CModule) To(device gotch.Device, kind gotch.DType, nonBlocking bool) {
log.Fatalf("CModule To method call err: %v\n", err)
}
}
// Implement Module for CModule:
// =============================
// Forward implements the Module interface for CModule by running a single
// tensor through the module via ForwardTs.
func (cm CModule) Forward(tensor Tensor) (retVal Tensor, err error) {
	return cm.ForwardTs([]Tensor{tensor})
}
// Tensor methods for CModule:
// ======================================
// Apply forwards tensor itself through a module.
// ApplyCModule forwards the tensor itself through a module, exiting the
// program on error.
func (ts Tensor) ApplyCModule(m CModule) (retVal Tensor) {
	out, err := m.Forward(ts)
	if err != nil {
		log.Fatal(err)
	}
	return out
}

View File

@ -93,6 +93,28 @@ func (ts Tensor) Detach_() {
}
}
// Detach returns a new tensor produced by the C library's atg_detach,
// or an error if the underlying call failed.
func (ts Tensor) Detach() (retVal Tensor, err error) {
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	lib.AtgDetach(ptr, ts.ctensor)
	err = TorchErr()
	if err != nil {
		return retVal, err
	}
	retVal = Tensor{ctensor: *ptr}
	return retVal, nil
}
// MustDetach is like Detach but exits the program on error.
func (ts Tensor) MustDetach() (retVal Tensor) {
	detached, err := ts.Detach()
	if err != nil {
		log.Fatal(err)
	}
	return detached
}
func (ts Tensor) Zero_() {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtgZero_(ptr, ts.ctensor)
@ -341,6 +363,15 @@ func (ts Tensor) Select(dim int64, index int64, del bool) (retVal Tensor, err er
return retVal, nil
}
// MustSelect is like Select but exits the program on error.
func (ts Tensor) MustSelect(dim int64, index int64, del bool) (retVal Tensor) {
	selected, err := ts.Select(dim, index, del)
	if err != nil {
		log.Fatal(err)
	}
	return selected
}
// Narrow creates a new tensor from current tensor given dim and start index
// and length.
func (ts Tensor) Narrow(dim int64, start int64, length int64, del bool) (retVal Tensor, err error) {
@ -359,6 +390,15 @@ func (ts Tensor) Narrow(dim int64, start int64, length int64, del bool) (retVal
return retVal, nil
}
// MustNarrow is like Narrow but exits the program on error.
func (ts Tensor) MustNarrow(dim int64, start int64, length int64, del bool) (retVal Tensor) {
	narrowed, err := ts.Narrow(dim, start, length, del)
	if err != nil {
		log.Fatal(err)
	}
	return narrowed
}
// IndexSelect creates a new tensor from current tensor given dim and index
// tensor.
func (ts Tensor) IndexSelect(dim int64, index Tensor, del bool) (retVal Tensor, err error) {
@ -1652,3 +1692,29 @@ func (ts Tensor) MustAdaptiveAvgPool2D(outputSizeData []int64) (retVal Tensor) {
return retVal
}
// Softmax applies the C library's atg_softmax along dim, returning the
// result as a new tensor with the requested dtype. When del is true the
// receiver is dropped after the call (via deferred MustDrop).
func (ts Tensor) Softmax(dim int64, dtype gotch.DType, del bool) (retVal Tensor, err error) {
	if del {
		defer ts.MustDrop()
	}
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
	lib.AtgSoftmax(ptr, ts.ctensor, dim, dtype.CInt())
	if err = TorchErr(); err != nil {
		return retVal, err
	}
	return Tensor{ctensor: *ptr}, nil
}
// MustSoftmax is like Softmax but exits the program on error.
func (ts Tensor) MustSoftmax(dim int64, dtype gotch.DType, del bool) (retVal Tensor) {
	result, err := ts.Softmax(dim, dtype, del)
	if err != nil {
		log.Fatal(err)
	}
	return result
}

View File

@ -367,6 +367,14 @@ func (ts Tensor) Int64Value(idx []int64) (retVal int64, err error) {
return retVal, err
}
// MustInt64Value is like Int64Value but exits the program on error.
func (ts Tensor) MustInt64Value(idx []int64) (retVal int64) {
	value, err := ts.Int64Value(idx)
	if err != nil {
		log.Fatal(err)
	}
	return value
}
// RequiresGrad returns true if gradient are currently tracked for this tensor.
func (ts Tensor) RequiresGrad() (retVal bool, err error) {
retVal = lib.AtRequiresGrad(ts.ctensor)
@ -991,11 +999,17 @@ func (r Reduction) ToInt() (retVal int) {
// Values returns values of tensor in a slice of float64.
func (ts Tensor) Values() []float64 {
clone := ts.MustShallowClone()
clone.Detach_()
// NOTE: this for 2D tensor.
// TODO: flatten nd tensor to slice
return []float64{clone.MustView([]int64{-1}, true).MustFloat64Value([]int64{-1})}
clone := ts.MustShallowClone().MustDetach().MustView([]int64{-1}, true)
n := clone.MustSize()[0]
var values []float64
for i := 0; i < int(n); i++ {
val := clone.MustFloat64Value([]int64{int64(i)})
values = append(values, val)
}
return values
}
// FlatView flattens a tensor.

View File

@ -130,6 +130,8 @@ func resizePreserveAspectRatioHWC(t ts.Tensor, outW int64, outH int64) (retVal t
err = fmt.Errorf("resizePreserveAspectRatioHWC - ts.Narrow() method call err: %v\n", err)
return retVal, err
}
} else {
tensorW = tensor
}
if resizeH == outH {

View File

@ -33,7 +33,7 @@ func (in ImageNet) Normalize(tensor ts.Tensor) (retVal ts.Tensor, err error) {
in.mutex.Lock()
defer in.mutex.Unlock()
res, err := tensor.Totype(gotch.Float, true)
res, err := tensor.Totype(gotch.Float, false)
if err != nil {
return retVal, err
}
@ -42,6 +42,7 @@ func (in ImageNet) Normalize(tensor ts.Tensor) (retVal ts.Tensor, err error) {
if err != nil {
return retVal, err
}
resMean, err := resDiv1.Sub(in.mean, true)
if err != nil {
return retVal, err
@ -132,6 +133,7 @@ func (in ImageNet) LoadImageAndResize224(path string) (retVal ts.Tensor, err err
err = fmt.Errorf("ImageNet - LoadImageAndResize224/LoadImageAndResize method call: %v", err)
return retVal, err
}
return in.Normalize(tensor)
}
@ -1288,30 +1290,33 @@ type TopItem struct {
func (in ImageNet) Top(input ts.Tensor, k int64) (retVal []TopItem) {
var tensor ts.Tensor
shape := tensor.MustSize()
shape := input.MustSize()
switch {
case reflect.DeepEqual(shape, []int64{imagenetClassCount}):
tensor = input.MustShallowClone()
case reflect.DeepEqual(shape, []int64{1, imagenetClassCount}):
// TODO: check whether []int64{imagenetClassCount, -1}
tensor = input.MustView([]int64{imagenetClassCount, -1}, false)
tensor = input.MustView([]int64{imagenetClassCount}, false) // shape: [1000]
case reflect.DeepEqual(shape, []int64{1, 1, imagenetClassCount}):
tensor = input.MustView([]int64{imagenetClassCount, -1}, false)
tensor = input.MustView([]int64{imagenetClassCount}, false) // shape: [1000]
default:
log.Fatalf("Unexpected tensor shape: %v\n", shape)
}
valsTs, idxsTs := tensor.MustTopK(k, 0, true, true)
vals := valsTs.Values()
idxs := idxsTs.Values()
var topItems []TopItem
for i := 0; i < len(vals); i++ {
vals := valsTs.Values()
idxs := idxsTs.Values()
for i := 0; i < int(k); i++ {
val := vals[i]
idx := idxs[i]
item := TopItem{
Pvalue: vals[i],
Label: imagenetClasses[int(idxs[i])],
Pvalue: val,
Label: imagenetClasses[int(idx)],
}
topItems = append(topItems, item)
}