corrected Int64Option and Float64Option param functions
This commit is contained in:
parent
b2bc72b1d4
commit
c38c909977
|
@ -116,9 +116,9 @@ func main() {
|
|||
sumLoss += loss.Float64Values()[0]
|
||||
cntLoss += 1.0
|
||||
|
||||
batchTs.MustDrop()
|
||||
batchNarrow.MustDrop()
|
||||
xsOnehotTmp.MustDrop()
|
||||
// batchTs.MustDrop()
|
||||
// batchNarrow.MustDrop()
|
||||
// xsOnehotTmp.MustDrop()
|
||||
xsOnehot.MustDrop()
|
||||
ys.MustDrop()
|
||||
lstmOut.MustDrop()
|
||||
|
|
|
@ -117,21 +117,21 @@ func runCNN1() {
|
|||
logits := net.ForwardT(bImages, true)
|
||||
loss := logits.CrossEntropyForLogits(bLabels)
|
||||
|
||||
// loss = loss.MustSetRequiresGrad(true)
|
||||
// loss = loss.MustSetRequiresGrad(true, false)
|
||||
opt.BackwardStep(loss)
|
||||
|
||||
epocLoss = loss.MustShallowClone()
|
||||
epocLoss.Detach_()
|
||||
|
||||
// fmt.Printf("completed \t %v batches\t %.2f\n", i, loss.Values()[0])
|
||||
// fmt.Printf("completed \t %v batches\t %.2f\n", i, loss.Float64Values()[0])
|
||||
|
||||
bImages.MustDrop()
|
||||
bLabels.MustDrop()
|
||||
}
|
||||
|
||||
vs.Freeze()
|
||||
// vs.Freeze()
|
||||
testAccuracy := nn.BatchAccuracyForLogits(vs, net, testImages, testLabels, vs.Device(), 1024)
|
||||
vs.Unfreeze()
|
||||
// vs.Unfreeze()
|
||||
fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Float64Values()[0], testAccuracy*100.0)
|
||||
if testAccuracy > bestAccuracy {
|
||||
bestAccuracy = testAccuracy
|
||||
|
|
|
@ -45,7 +45,7 @@ func runLinear() {
|
|||
})
|
||||
|
||||
testLogits := ds.TestImages.MustMm(ws, false).MustAdd(bs, true)
|
||||
testAccuracy := testLogits.MustArgmax(-1, false, true).MustEq1(ds.TestLabels, true).MustTotype(gotch.Float, true).MustMean(gotch.Float, true).MustView([]int64{-1}, true).MustFloat64Value([]int64{0})
|
||||
testAccuracy := testLogits.MustArgmax([]int64{-1}, false, true).MustEq1(ds.TestLabels, true).MustTotype(gotch.Float, true).MustMean(gotch.Float, true).MustView([]int64{-1}, true).MustFloat64Value([]int64{0})
|
||||
|
||||
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Float64Values()[0], testAccuracy*100)
|
||||
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -271,7 +271,7 @@ func upsample(prevChannels int64) (retVal1 int64, retVal2 interface{}) {
|
|||
h := res[2]
|
||||
w := res[3]
|
||||
|
||||
return xs.MustUpsampleNearest2d([]int64{h * 2, w * 2}, 2.0, 2.0, false)
|
||||
return xs.MustUpsampleNearest2d([]int64{h * 2, w * 2}, []float64{2.0}, []float64{2.0}, false)
|
||||
})
|
||||
|
||||
return prevChannels, Layer{Val: layer}
|
||||
|
|
22
gen/gen.ml
22
gen/gen.ml
|
@ -742,7 +742,16 @@ let write_wrapper funcs filename =
|
|||
; "Split"
|
||||
; "SplitWithSizes"
|
||||
; "Unbind"
|
||||
; "Where" ]
|
||||
; "Where"
|
||||
; "Atleast1d1"
|
||||
; "Atleast2d1"
|
||||
; "Atleast3d1"
|
||||
; "Dequantize1"
|
||||
; "QuantizePerTensor1"
|
||||
; "UnsafeChunk"
|
||||
; "UnsafeSplit"
|
||||
; "UnsafeSplitWithSizes"
|
||||
; "AlignTensors" ]
|
||||
in
|
||||
if
|
||||
List.exists excluded_funcs ~f:(fun name ->
|
||||
|
@ -848,7 +857,16 @@ let write_must_wrapper funcs filename =
|
|||
; "Split"
|
||||
; "SplitWithSizes"
|
||||
; "Unbind"
|
||||
; "Where" ]
|
||||
; "Where"
|
||||
; "Atleast1d1"
|
||||
; "Atleast2d1"
|
||||
; "Atleast3d1"
|
||||
; "Dequantize1"
|
||||
; "QuantizePerTensor1"
|
||||
; "UnsafeChunk"
|
||||
; "UnsafeSplit"
|
||||
; "UnsafeSplitWithSizes"
|
||||
; "AlignTensors" ]
|
||||
in
|
||||
if
|
||||
List.exists excluded_funcs ~f:(fun name ->
|
||||
|
|
|
@ -5860,6 +5860,7 @@ func AtgSparseResizeAndClear_(ptr *Ctensor, self Ctensor, sizeData []int64, size
|
|||
C.atg_sparse_resize_and_clear_(ptr, self, csizeDataPtr, csizeLen, csparseDim, cdenseDim)
|
||||
}
|
||||
|
||||
|
||||
func AtgSqrt(ptr *Ctensor, self Ctensor){
|
||||
C.atg_sqrt(ptr, self)
|
||||
}
|
||||
|
@ -6305,6 +6306,8 @@ func AtgUniqueDimConsecutive(ptr *Ctensor, self Ctensor, dim int64, returnInvers
|
|||
C.atg_unique_dim_consecutive(ptr, self, cdim, creturnInverse, creturnCounts)
|
||||
}
|
||||
|
||||
|
||||
|
||||
func AtgUnsqueeze(ptr *Ctensor, self Ctensor, dim int64){
|
||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
||||
C.atg_unsqueeze(ptr, self, cdim)
|
||||
|
|
|
@ -254,6 +254,40 @@ func BatchAccuracyForLogits(vs *VarStore, m ts.ModuleT, xs, ys *ts.Tensor, d got
|
|||
return sumAccuracy / sampleCount
|
||||
}
|
||||
|
||||
func BatchAccuracyForLogitsOld(vs *VarStore, m ts.ModuleT, xs, ys *ts.Tensor, d gotch.Device, batchSize int) (retVal float64) {
|
||||
|
||||
var (
|
||||
sumAccuracy float64 = 0.0
|
||||
sampleCount float64 = 0.0
|
||||
)
|
||||
|
||||
vs.Freeze()
|
||||
defer vs.Unfreeze()
|
||||
|
||||
iter2 := ts.MustNewIter2(xs, ys, int64(batchSize))
|
||||
for {
|
||||
item, ok := iter2.Next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
size := float64(item.Data.MustSize()[0])
|
||||
bImages := item.Data.MustTo(d, true)
|
||||
bLabels := item.Label.MustTo(d, true)
|
||||
|
||||
logits := m.ForwardT(bImages, false)
|
||||
acc := logits.AccuracyForLogits(bLabels)
|
||||
sumAccuracy += acc.Float64Values()[0] * size
|
||||
sampleCount += size
|
||||
|
||||
bImages.MustDrop()
|
||||
bLabels.MustDrop()
|
||||
acc.MustDrop()
|
||||
}
|
||||
|
||||
return sumAccuracy / sampleCount
|
||||
}
|
||||
|
||||
// BatchAccuracyForLogitIdx is an alternative of BatchAccuracyForLogits to
|
||||
// calculate accuracy for specified batch on module weight. It uses tensor
|
||||
// indexing instead of Iter2
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -19,7 +19,7 @@ func (ts *Tensor) CrossEntropyForLogits(targets *Tensor) (retVal *Tensor) {
|
|||
// AccuracyForLogits returns the average accuracy for some given logits assuming that
|
||||
// targets represent ground-truth.
|
||||
func (ts *Tensor) AccuracyForLogits(targets *Tensor) (retVal *Tensor) {
|
||||
argmax := ts.MustArgmax(-1, false, true)
|
||||
argmax := ts.MustArgmax([]int64{-1}, false, true)
|
||||
eq1 := argmax.MustEq1(targets, true)
|
||||
return eq1.MustTotype(gotch.Float, true).MustMean(gotch.Float, true)
|
||||
}
|
||||
|
|
|
@ -3579,22 +3579,6 @@ func (ts *Tensor) Atleast1d(del bool) (retVal *Tensor, err error) {
|
|||
return retVal, err
|
||||
}
|
||||
|
||||
func Atleast1d1(tensors []Tensor) (retVal []Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
var ctensors []lib.Ctensor
|
||||
for _, t := range tensors {
|
||||
ctensors = append(ctensors, t.ctensor)
|
||||
}
|
||||
lib.AtgAtleast1d1(ptr, ctensors, len(ctensors))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
retVal = &Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
func (ts *Tensor) Atleast2d(del bool) (retVal *Tensor, err error) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
|
@ -3610,22 +3594,6 @@ func (ts *Tensor) Atleast2d(del bool) (retVal *Tensor, err error) {
|
|||
return retVal, err
|
||||
}
|
||||
|
||||
func Atleast2d1(tensors []Tensor) (retVal []Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
var ctensors []lib.Ctensor
|
||||
for _, t := range tensors {
|
||||
ctensors = append(ctensors, t.ctensor)
|
||||
}
|
||||
lib.AtgAtleast2d1(ptr, ctensors, len(ctensors))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
retVal = &Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
func (ts *Tensor) Atleast3d(del bool) (retVal *Tensor, err error) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
|
@ -3641,22 +3609,6 @@ func (ts *Tensor) Atleast3d(del bool) (retVal *Tensor, err error) {
|
|||
return retVal, err
|
||||
}
|
||||
|
||||
func Atleast3d1(tensors []Tensor) (retVal []Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
var ctensors []lib.Ctensor
|
||||
for _, t := range tensors {
|
||||
ctensors = append(ctensors, t.ctensor)
|
||||
}
|
||||
lib.AtgAtleast3d1(ptr, ctensors, len(ctensors))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
retVal = &Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
func (ts *Tensor) AvgPool1d(kernelSize []int64, stride []int64, padding []int64, ceilMode bool, countIncludePad bool, del bool) (retVal *Tensor, err error) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
|
@ -6077,22 +6029,6 @@ func (ts *Tensor) Dequantize(del bool) (retVal *Tensor, err error) {
|
|||
return retVal, err
|
||||
}
|
||||
|
||||
func Dequantize1(tensors []Tensor) (retVal []Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
var ctensors []lib.Ctensor
|
||||
for _, t := range tensors {
|
||||
ctensors = append(ctensors, t.ctensor)
|
||||
}
|
||||
lib.AtgDequantize1(ptr, ctensors, len(ctensors))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
retVal = &Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
func (ts *Tensor) Det(del bool) (retVal *Tensor, err error) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
|
@ -14649,22 +14585,6 @@ func (ts *Tensor) QuantizePerTensor(scale float64, zeroPoint int64, dtype gotch.
|
|||
return retVal, err
|
||||
}
|
||||
|
||||
func QuantizePerTensor1(tensors []Tensor, scales *Tensor, zeroPoints *Tensor, dtype gotch.DType) (retVal []Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
var ctensors []lib.Ctensor
|
||||
for _, t := range tensors {
|
||||
ctensors = append(ctensors, t.ctensor)
|
||||
}
|
||||
lib.AtgQuantizePerTensor1(ptr, ctensors, len(ctensors), scales.ctensor, zeroPoints.ctensor, dtype.CInt())
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
retVal = &Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
func QuantizedBatchNorm(input *Tensor, weight *Tensor, bias *Tensor, mean *Tensor, vari *Tensor, eps float64, outputScale float64, outputZeroPoint int64) (retVal *Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
|
@ -18377,51 +18297,6 @@ func (ts *Tensor) Uniform_(from float64, to float64) (err error) {
|
|||
return err
|
||||
}
|
||||
|
||||
func (ts *Tensor) UnsafeChunk(chunks int64, dim int64, del bool) (retVal []Tensor, err error) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
lib.AtgUnsafeChunk(ptr, ts.ctensor, chunks, dim)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
retVal = &Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
func (ts *Tensor) UnsafeSplit(splitSize int64, dim int64, del bool) (retVal []Tensor, err error) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
lib.AtgUnsafeSplit(ptr, ts.ctensor, splitSize, dim)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
retVal = &Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
func (ts *Tensor) UnsafeSplitWithSizes(splitSizes []int64, dim int64, del bool) (retVal []Tensor, err error) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
lib.AtgUnsafeSplitWithSizes(ptr, ts.ctensor, splitSizes, len(splitSizes), dim)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
retVal = &Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
func (ts *Tensor) Unsqueeze(dim int64, del bool) (retVal *Tensor, err error) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
|
|
|
@ -1171,7 +1171,7 @@ func (ts *Tensor) Swish() *Tensor {
|
|||
}
|
||||
|
||||
func (ts *Tensor) AvgPool2DDefault(ksize int64, del bool) *Tensor {
|
||||
return ts.MustAvgPool2d([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, false, true, 1, del)
|
||||
return ts.MustAvgPool2d([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, false, true, []int64{1}, del)
|
||||
}
|
||||
|
||||
// SaveMultiNew saves a slice of named tensors to the given file path.
|
||||
|
|
|
@ -103,7 +103,7 @@ func densenet(p *nn.Path, cIn, cOut, bnSize int64, blockConfig []int64, growth i
|
|||
|
||||
seq.AddFn(nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor {
|
||||
tmp1 := xs.MustRelu(false)
|
||||
tmp2 := tmp1.MustAvgPool2d([]int64{7, 7}, []int64{1, 1}, []int64{0, 0}, false, true, 1, true)
|
||||
tmp2 := tmp1.MustAvgPool2d([]int64{7, 7}, []int64{1, 1}, []int64{0, 0}, false, true, []int64{1}, true)
|
||||
res := tmp2.FlatView()
|
||||
tmp2.MustDrop()
|
||||
return res
|
||||
|
|
|
@ -78,7 +78,7 @@ func inceptionA(p *nn.Path, cIn, cPool int64) ts.ModuleT {
|
|||
b3Ts := b3Tmp2.ApplyT(b33, train)
|
||||
b3Tmp2.MustDrop()
|
||||
|
||||
bpoolTmp := xs.MustAvgPool2d([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, 9, false)
|
||||
bpoolTmp := xs.MustAvgPool2d([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, []int64{9}, false)
|
||||
bpoolTs := bpoolTmp.ApplyT(bpool, train)
|
||||
|
||||
res := ts.MustCat([]ts.Tensor{*b1Ts, *b2Ts, *b3Ts, *bpoolTs}, 1)
|
||||
|
@ -145,7 +145,7 @@ func inceptionC(p *nn.Path, cIn int64, c7 int64) ts.ModuleT {
|
|||
b3Ts := b3Tmp4.ApplyT(b35, train)
|
||||
b3Tmp4.MustDrop()
|
||||
|
||||
bpTmp1 := xs.MustAvgPool2d([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, 9, false)
|
||||
bpTmp1 := xs.MustAvgPool2d([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, []int64{9}, false)
|
||||
bpoolTs := bpTmp1.ApplyT(bpool, train)
|
||||
|
||||
return ts.MustCat([]ts.Tensor{*b1Ts, *b2Ts, *b3Ts, *bpoolTs}, 1)
|
||||
|
@ -211,7 +211,7 @@ func inceptionE(p *nn.Path, cIn int64) ts.ModuleT {
|
|||
b3bTs := b3Tmp2.ApplyT(b33b, train)
|
||||
b3Ts := ts.MustCat([]ts.Tensor{*b3aTs, *b3bTs}, 1)
|
||||
|
||||
bpTmp1 := xs.MustAvgPool2d([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, 9, false)
|
||||
bpTmp1 := xs.MustAvgPool2d([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, []int64{9}, false)
|
||||
bpoolTs := bpTmp1.ApplyT(bpool, train)
|
||||
|
||||
return ts.MustCat([]ts.Tensor{*b1Ts, *b2Ts, *b3Ts, *bpoolTs}, 1)
|
||||
|
|
Loading…
Reference in New Issue
Block a user