feat(example/jit)
This commit is contained in:
parent
47283314de
commit
fcbc4ca870
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -11,6 +11,8 @@
|
|||
|
||||
*.txt
|
||||
*.json
|
||||
*.pt
|
||||
*.jpg
|
||||
|
||||
target/
|
||||
_build/
|
||||
|
|
58
example/jit/main.go
Normal file
58
example/jit/main.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package main
|
||||
|
||||
// This example illustrates how to use a PyTorch model trained and exported using the
|
||||
// Python JIT API.
|
||||
// See https://pytorch.org/tutorials/advanced/cpp_export.html for more details.
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
var (
|
||||
modelPath string
|
||||
imageFile string
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&modelPath, "modelpath", "model.pt", "full path to exported pytorch model.")
|
||||
flag.StringVar(&imageFile, "image", "image.jpg", "full path to image file.")
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
imageNet := vision.NewImageNet()
|
||||
|
||||
// Load the image file and resize it to the usual imagenet dimension of 224x224.
|
||||
// image, err := imageNet.LoadImageAndResize224(imageFile)
|
||||
image, err := imageNet.LoadImage(imageFile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// image.MustSave("resize224.jpg")
|
||||
|
||||
// Load the Python saved module.
|
||||
model, err := ts.ModuleLoad(modelPath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Apply the forward pass of the model to get the logits.
|
||||
output := image.MustUnsqueeze(int64(0), false).ApplyCModule(model).MustSoftmax(-1, gotch.Float, true)
|
||||
|
||||
// Print the top 5 categories for this image.
|
||||
var top5 []vision.TopItem
|
||||
|
||||
top5 = imageNet.Top(output, int64(5))
|
||||
|
||||
for _, i := range top5 {
|
||||
fmt.Printf("%v \t\t\t: %.2f%%\n", i.Label, i.Pvalue*100)
|
||||
}
|
||||
}
|
|
@ -41,6 +41,11 @@ func AtgDetach_(ptr *Ctensor, self Ctensor) {
|
|||
C.atg_detach_(ptr, self)
|
||||
}
|
||||
|
||||
// void atg_detach(tensor *, tensor self);
|
||||
func AtgDetach(ptr *Ctensor, self Ctensor) {
|
||||
C.atg_detach(ptr, self)
|
||||
}
|
||||
|
||||
// void atg_zero_(tensor *, tensor self);
|
||||
func AtgZero_(ptr *Ctensor, self Ctensor) {
|
||||
C.atg_zero_(ptr, self)
|
||||
|
@ -556,3 +561,11 @@ func AtgAdaptiveAvgPool2d(ptr *Ctensor, self Ctensor, outputSizeData []int64, ou
|
|||
|
||||
C.atg_adaptive_avg_pool2d(ptr, self, outputSizeDataPtr, coutputSizeLen)
|
||||
}
|
||||
|
||||
// void atg_softmax(tensor *, tensor self, int64_t dim, int dtype);
|
||||
func AtgSoftmax(ptr *Ctensor, self Ctensor, dim int64, dtype int32) {
|
||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
||||
cdtype := *(*C.int)(unsafe.Pointer(&dtype))
|
||||
|
||||
C.atg_softmax(ptr, self, cdim, cdtype)
|
||||
}
|
||||
|
|
|
@ -1062,3 +1062,25 @@ func (cm CModule) To(device gotch.Device, kind gotch.DType, nonBlocking bool) {
|
|||
log.Fatalf("CModule To method call err: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Implement Module for CModule:
|
||||
// =============================
|
||||
|
||||
func (cm CModule) Forward(tensor Tensor) (retVal Tensor, err error) {
|
||||
|
||||
var tensors []Tensor = []Tensor{tensor}
|
||||
return cm.ForwardTs(tensors)
|
||||
}
|
||||
|
||||
// Tensor methods for CModule:
|
||||
// ======================================
|
||||
|
||||
// Apply forwards tensor itself through a module.
|
||||
func (ts Tensor) ApplyCModule(m CModule) (retVal Tensor) {
|
||||
retVal, err := m.Forward(ts)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
|
@ -93,6 +93,28 @@ func (ts Tensor) Detach_() {
|
|||
}
|
||||
}
|
||||
|
||||
func (ts Tensor) Detach() (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
lib.AtgDetach(ptr, ts.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return Tensor{ctensor: *ptr}, nil
|
||||
|
||||
}
|
||||
|
||||
func (ts Tensor) MustDetach() (retVal Tensor) {
|
||||
retVal, err := ts.Detach()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
|
||||
}
|
||||
|
||||
func (ts Tensor) Zero_() {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
lib.AtgZero_(ptr, ts.ctensor)
|
||||
|
@ -341,6 +363,15 @@ func (ts Tensor) Select(dim int64, index int64, del bool) (retVal Tensor, err er
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustSelect(dim int64, index int64, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Select(dim, index, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// Narrow creates a new tensor from current tensor given dim and start index
|
||||
// and length.
|
||||
func (ts Tensor) Narrow(dim int64, start int64, length int64, del bool) (retVal Tensor, err error) {
|
||||
|
@ -359,6 +390,15 @@ func (ts Tensor) Narrow(dim int64, start int64, length int64, del bool) (retVal
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustNarrow(dim int64, start int64, length int64, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Narrow(dim, start, length, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// IndexSelect creates a new tensor from current tensor given dim and index
|
||||
// tensor.
|
||||
func (ts Tensor) IndexSelect(dim int64, index Tensor, del bool) (retVal Tensor, err error) {
|
||||
|
@ -1652,3 +1692,29 @@ func (ts Tensor) MustAdaptiveAvgPool2D(outputSizeData []int64) (retVal Tensor) {
|
|||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Softmax(dim int64, dtype gotch.DType, del bool) (retVal Tensor, err error) {
|
||||
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
lib.AtgSoftmax(ptr, ts.ctensor, dim, dtype.CInt())
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustSoftmax(dim int64, dtype gotch.DType, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Softmax(dim, dtype, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
|
@ -367,6 +367,14 @@ func (ts Tensor) Int64Value(idx []int64) (retVal int64, err error) {
|
|||
return retVal, err
|
||||
}
|
||||
|
||||
func (ts Tensor) MustInt64Value(idx []int64) (retVal int64) {
|
||||
retVal, err := ts.Int64Value(idx)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return retVal
|
||||
}
|
||||
|
||||
// RequiresGrad returns true if gradient are currently tracked for this tensor.
|
||||
func (ts Tensor) RequiresGrad() (retVal bool, err error) {
|
||||
retVal = lib.AtRequiresGrad(ts.ctensor)
|
||||
|
@ -991,11 +999,17 @@ func (r Reduction) ToInt() (retVal int) {
|
|||
|
||||
// Values returns values of tensor in a slice of float64.
|
||||
func (ts Tensor) Values() []float64 {
|
||||
clone := ts.MustShallowClone()
|
||||
clone.Detach_()
|
||||
// NOTE: this for 2D tensor.
|
||||
// TODO: flatten nd tensor to slice
|
||||
return []float64{clone.MustView([]int64{-1}, true).MustFloat64Value([]int64{-1})}
|
||||
clone := ts.MustShallowClone().MustDetach().MustView([]int64{-1}, true)
|
||||
|
||||
n := clone.MustSize()[0]
|
||||
|
||||
var values []float64
|
||||
for i := 0; i < int(n); i++ {
|
||||
val := clone.MustFloat64Value([]int64{int64(i)})
|
||||
values = append(values, val)
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
|
||||
// FlatView flattens a tensor.
|
||||
|
|
|
@ -130,6 +130,8 @@ func resizePreserveAspectRatioHWC(t ts.Tensor, outW int64, outH int64) (retVal t
|
|||
err = fmt.Errorf("resizePreserveAspectRatioHWC - ts.Narrow() method call err: %v\n", err)
|
||||
return retVal, err
|
||||
}
|
||||
} else {
|
||||
tensorW = tensor
|
||||
}
|
||||
|
||||
if resizeH == outH {
|
||||
|
|
|
@ -33,7 +33,7 @@ func (in ImageNet) Normalize(tensor ts.Tensor) (retVal ts.Tensor, err error) {
|
|||
in.mutex.Lock()
|
||||
defer in.mutex.Unlock()
|
||||
|
||||
res, err := tensor.Totype(gotch.Float, true)
|
||||
res, err := tensor.Totype(gotch.Float, false)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
@ -42,6 +42,7 @@ func (in ImageNet) Normalize(tensor ts.Tensor) (retVal ts.Tensor, err error) {
|
|||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
resMean, err := resDiv1.Sub(in.mean, true)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
|
@ -132,6 +133,7 @@ func (in ImageNet) LoadImageAndResize224(path string) (retVal ts.Tensor, err err
|
|||
err = fmt.Errorf("ImageNet - LoadImageAndResize224/LoadImageAndResize method call: %v", err)
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return in.Normalize(tensor)
|
||||
}
|
||||
|
||||
|
@ -1288,30 +1290,33 @@ type TopItem struct {
|
|||
func (in ImageNet) Top(input ts.Tensor, k int64) (retVal []TopItem) {
|
||||
|
||||
var tensor ts.Tensor
|
||||
shape := tensor.MustSize()
|
||||
shape := input.MustSize()
|
||||
|
||||
switch {
|
||||
case reflect.DeepEqual(shape, []int64{imagenetClassCount}):
|
||||
tensor = input.MustShallowClone()
|
||||
case reflect.DeepEqual(shape, []int64{1, imagenetClassCount}):
|
||||
// TODO: check whether []int64{imagenetClassCount, -1}
|
||||
tensor = input.MustView([]int64{imagenetClassCount, -1}, false)
|
||||
tensor = input.MustView([]int64{imagenetClassCount}, false) // shape: [1000]
|
||||
case reflect.DeepEqual(shape, []int64{1, 1, imagenetClassCount}):
|
||||
tensor = input.MustView([]int64{imagenetClassCount, -1}, false)
|
||||
tensor = input.MustView([]int64{imagenetClassCount}, false) // shape: [1000]
|
||||
default:
|
||||
log.Fatalf("Unexpected tensor shape: %v\n", shape)
|
||||
}
|
||||
|
||||
valsTs, idxsTs := tensor.MustTopK(k, 0, true, true)
|
||||
vals := valsTs.Values()
|
||||
idxs := idxsTs.Values()
|
||||
|
||||
var topItems []TopItem
|
||||
|
||||
for i := 0; i < len(vals); i++ {
|
||||
vals := valsTs.Values()
|
||||
idxs := idxsTs.Values()
|
||||
|
||||
for i := 0; i < int(k); i++ {
|
||||
val := vals[i]
|
||||
idx := idxs[i]
|
||||
|
||||
item := TopItem{
|
||||
Pvalue: vals[i],
|
||||
Label: imagenetClasses[int(idxs[i])],
|
||||
Pvalue: val,
|
||||
Label: imagenetClasses[int(idx)],
|
||||
}
|
||||
topItems = append(topItems, item)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user