corrected Int64Option and Float64Option param functions
This commit is contained in:
parent
b2bc72b1d4
commit
c38c909977
|
@ -116,9 +116,9 @@ func main() {
|
||||||
sumLoss += loss.Float64Values()[0]
|
sumLoss += loss.Float64Values()[0]
|
||||||
cntLoss += 1.0
|
cntLoss += 1.0
|
||||||
|
|
||||||
batchTs.MustDrop()
|
// batchTs.MustDrop()
|
||||||
batchNarrow.MustDrop()
|
// batchNarrow.MustDrop()
|
||||||
xsOnehotTmp.MustDrop()
|
// xsOnehotTmp.MustDrop()
|
||||||
xsOnehot.MustDrop()
|
xsOnehot.MustDrop()
|
||||||
ys.MustDrop()
|
ys.MustDrop()
|
||||||
lstmOut.MustDrop()
|
lstmOut.MustDrop()
|
||||||
|
|
|
@ -117,21 +117,21 @@ func runCNN1() {
|
||||||
logits := net.ForwardT(bImages, true)
|
logits := net.ForwardT(bImages, true)
|
||||||
loss := logits.CrossEntropyForLogits(bLabels)
|
loss := logits.CrossEntropyForLogits(bLabels)
|
||||||
|
|
||||||
// loss = loss.MustSetRequiresGrad(true)
|
// loss = loss.MustSetRequiresGrad(true, false)
|
||||||
opt.BackwardStep(loss)
|
opt.BackwardStep(loss)
|
||||||
|
|
||||||
epocLoss = loss.MustShallowClone()
|
epocLoss = loss.MustShallowClone()
|
||||||
epocLoss.Detach_()
|
epocLoss.Detach_()
|
||||||
|
|
||||||
// fmt.Printf("completed \t %v batches\t %.2f\n", i, loss.Values()[0])
|
// fmt.Printf("completed \t %v batches\t %.2f\n", i, loss.Float64Values()[0])
|
||||||
|
|
||||||
bImages.MustDrop()
|
bImages.MustDrop()
|
||||||
bLabels.MustDrop()
|
bLabels.MustDrop()
|
||||||
}
|
}
|
||||||
|
|
||||||
vs.Freeze()
|
// vs.Freeze()
|
||||||
testAccuracy := nn.BatchAccuracyForLogits(vs, net, testImages, testLabels, vs.Device(), 1024)
|
testAccuracy := nn.BatchAccuracyForLogits(vs, net, testImages, testLabels, vs.Device(), 1024)
|
||||||
vs.Unfreeze()
|
// vs.Unfreeze()
|
||||||
fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Float64Values()[0], testAccuracy*100.0)
|
fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Float64Values()[0], testAccuracy*100.0)
|
||||||
if testAccuracy > bestAccuracy {
|
if testAccuracy > bestAccuracy {
|
||||||
bestAccuracy = testAccuracy
|
bestAccuracy = testAccuracy
|
||||||
|
|
|
@ -45,7 +45,7 @@ func runLinear() {
|
||||||
})
|
})
|
||||||
|
|
||||||
testLogits := ds.TestImages.MustMm(ws, false).MustAdd(bs, true)
|
testLogits := ds.TestImages.MustMm(ws, false).MustAdd(bs, true)
|
||||||
testAccuracy := testLogits.MustArgmax(-1, false, true).MustEq1(ds.TestLabels, true).MustTotype(gotch.Float, true).MustMean(gotch.Float, true).MustView([]int64{-1}, true).MustFloat64Value([]int64{0})
|
testAccuracy := testLogits.MustArgmax([]int64{-1}, false, true).MustEq1(ds.TestLabels, true).MustTotype(gotch.Float, true).MustMean(gotch.Float, true).MustView([]int64{-1}, true).MustFloat64Value([]int64{0})
|
||||||
|
|
||||||
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Float64Values()[0], testAccuracy*100)
|
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Float64Values()[0], testAccuracy*100)
|
||||||
|
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -271,7 +271,7 @@ func upsample(prevChannels int64) (retVal1 int64, retVal2 interface{}) {
|
||||||
h := res[2]
|
h := res[2]
|
||||||
w := res[3]
|
w := res[3]
|
||||||
|
|
||||||
return xs.MustUpsampleNearest2d([]int64{h * 2, w * 2}, 2.0, 2.0, false)
|
return xs.MustUpsampleNearest2d([]int64{h * 2, w * 2}, []float64{2.0}, []float64{2.0}, false)
|
||||||
})
|
})
|
||||||
|
|
||||||
return prevChannels, Layer{Val: layer}
|
return prevChannels, Layer{Val: layer}
|
||||||
|
|
22
gen/gen.ml
22
gen/gen.ml
|
@ -742,7 +742,16 @@ let write_wrapper funcs filename =
|
||||||
; "Split"
|
; "Split"
|
||||||
; "SplitWithSizes"
|
; "SplitWithSizes"
|
||||||
; "Unbind"
|
; "Unbind"
|
||||||
; "Where" ]
|
; "Where"
|
||||||
|
; "Atleast1d1"
|
||||||
|
; "Atleast2d1"
|
||||||
|
; "Atleast3d1"
|
||||||
|
; "Dequantize1"
|
||||||
|
; "QuantizePerTensor1"
|
||||||
|
; "UnsafeChunk"
|
||||||
|
; "UnsafeSplit"
|
||||||
|
; "UnsafeSplitWithSizes"
|
||||||
|
; "AlignTensors" ]
|
||||||
in
|
in
|
||||||
if
|
if
|
||||||
List.exists excluded_funcs ~f:(fun name ->
|
List.exists excluded_funcs ~f:(fun name ->
|
||||||
|
@ -848,7 +857,16 @@ let write_must_wrapper funcs filename =
|
||||||
; "Split"
|
; "Split"
|
||||||
; "SplitWithSizes"
|
; "SplitWithSizes"
|
||||||
; "Unbind"
|
; "Unbind"
|
||||||
; "Where" ]
|
; "Where"
|
||||||
|
; "Atleast1d1"
|
||||||
|
; "Atleast2d1"
|
||||||
|
; "Atleast3d1"
|
||||||
|
; "Dequantize1"
|
||||||
|
; "QuantizePerTensor1"
|
||||||
|
; "UnsafeChunk"
|
||||||
|
; "UnsafeSplit"
|
||||||
|
; "UnsafeSplitWithSizes"
|
||||||
|
; "AlignTensors" ]
|
||||||
in
|
in
|
||||||
if
|
if
|
||||||
List.exists excluded_funcs ~f:(fun name ->
|
List.exists excluded_funcs ~f:(fun name ->
|
||||||
|
|
|
@ -5860,6 +5860,7 @@ func AtgSparseResizeAndClear_(ptr *Ctensor, self Ctensor, sizeData []int64, size
|
||||||
C.atg_sparse_resize_and_clear_(ptr, self, csizeDataPtr, csizeLen, csparseDim, cdenseDim)
|
C.atg_sparse_resize_and_clear_(ptr, self, csizeDataPtr, csizeLen, csparseDim, cdenseDim)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func AtgSqrt(ptr *Ctensor, self Ctensor){
|
func AtgSqrt(ptr *Ctensor, self Ctensor){
|
||||||
C.atg_sqrt(ptr, self)
|
C.atg_sqrt(ptr, self)
|
||||||
}
|
}
|
||||||
|
@ -6305,6 +6306,8 @@ func AtgUniqueDimConsecutive(ptr *Ctensor, self Ctensor, dim int64, returnInvers
|
||||||
C.atg_unique_dim_consecutive(ptr, self, cdim, creturnInverse, creturnCounts)
|
C.atg_unique_dim_consecutive(ptr, self, cdim, creturnInverse, creturnCounts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
func AtgUnsqueeze(ptr *Ctensor, self Ctensor, dim int64){
|
func AtgUnsqueeze(ptr *Ctensor, self Ctensor, dim int64){
|
||||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
||||||
C.atg_unsqueeze(ptr, self, cdim)
|
C.atg_unsqueeze(ptr, self, cdim)
|
||||||
|
|
|
@ -254,6 +254,40 @@ func BatchAccuracyForLogits(vs *VarStore, m ts.ModuleT, xs, ys *ts.Tensor, d got
|
||||||
return sumAccuracy / sampleCount
|
return sumAccuracy / sampleCount
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BatchAccuracyForLogitsOld(vs *VarStore, m ts.ModuleT, xs, ys *ts.Tensor, d gotch.Device, batchSize int) (retVal float64) {
|
||||||
|
|
||||||
|
var (
|
||||||
|
sumAccuracy float64 = 0.0
|
||||||
|
sampleCount float64 = 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
vs.Freeze()
|
||||||
|
defer vs.Unfreeze()
|
||||||
|
|
||||||
|
iter2 := ts.MustNewIter2(xs, ys, int64(batchSize))
|
||||||
|
for {
|
||||||
|
item, ok := iter2.Next()
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
size := float64(item.Data.MustSize()[0])
|
||||||
|
bImages := item.Data.MustTo(d, true)
|
||||||
|
bLabels := item.Label.MustTo(d, true)
|
||||||
|
|
||||||
|
logits := m.ForwardT(bImages, false)
|
||||||
|
acc := logits.AccuracyForLogits(bLabels)
|
||||||
|
sumAccuracy += acc.Float64Values()[0] * size
|
||||||
|
sampleCount += size
|
||||||
|
|
||||||
|
bImages.MustDrop()
|
||||||
|
bLabels.MustDrop()
|
||||||
|
acc.MustDrop()
|
||||||
|
}
|
||||||
|
|
||||||
|
return sumAccuracy / sampleCount
|
||||||
|
}
|
||||||
|
|
||||||
// BatchAccuracyForLogitIdx is an alternative of BatchAccuracyForLogits to
|
// BatchAccuracyForLogitIdx is an alternative of BatchAccuracyForLogits to
|
||||||
// calculate accuracy for specified batch on module weight. It uses tensor
|
// calculate accuracy for specified batch on module weight. It uses tensor
|
||||||
// indexing instead of Iter2
|
// indexing instead of Iter2
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -19,7 +19,7 @@ func (ts *Tensor) CrossEntropyForLogits(targets *Tensor) (retVal *Tensor) {
|
||||||
// AccuracyForLogits returns the average accuracy for some given logits assuming that
|
// AccuracyForLogits returns the average accuracy for some given logits assuming that
|
||||||
// targets represent ground-truth.
|
// targets represent ground-truth.
|
||||||
func (ts *Tensor) AccuracyForLogits(targets *Tensor) (retVal *Tensor) {
|
func (ts *Tensor) AccuracyForLogits(targets *Tensor) (retVal *Tensor) {
|
||||||
argmax := ts.MustArgmax(-1, false, true)
|
argmax := ts.MustArgmax([]int64{-1}, false, true)
|
||||||
eq1 := argmax.MustEq1(targets, true)
|
eq1 := argmax.MustEq1(targets, true)
|
||||||
return eq1.MustTotype(gotch.Float, true).MustMean(gotch.Float, true)
|
return eq1.MustTotype(gotch.Float, true).MustMean(gotch.Float, true)
|
||||||
}
|
}
|
||||||
|
|
|
@ -3579,22 +3579,6 @@ func (ts *Tensor) Atleast1d(del bool) (retVal *Tensor, err error) {
|
||||||
return retVal, err
|
return retVal, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func Atleast1d1(tensors []Tensor) (retVal []Tensor, err error) {
|
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
||||||
|
|
||||||
var ctensors []lib.Ctensor
|
|
||||||
for _, t := range tensors {
|
|
||||||
ctensors = append(ctensors, t.ctensor)
|
|
||||||
}
|
|
||||||
lib.AtgAtleast1d1(ptr, ctensors, len(ctensors))
|
|
||||||
if err = TorchErr(); err != nil {
|
|
||||||
return retVal, err
|
|
||||||
}
|
|
||||||
retVal = &Tensor{ctensor: *ptr}
|
|
||||||
|
|
||||||
return retVal, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ts *Tensor) Atleast2d(del bool) (retVal *Tensor, err error) {
|
func (ts *Tensor) Atleast2d(del bool) (retVal *Tensor, err error) {
|
||||||
if del {
|
if del {
|
||||||
defer ts.MustDrop()
|
defer ts.MustDrop()
|
||||||
|
@ -3610,22 +3594,6 @@ func (ts *Tensor) Atleast2d(del bool) (retVal *Tensor, err error) {
|
||||||
return retVal, err
|
return retVal, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func Atleast2d1(tensors []Tensor) (retVal []Tensor, err error) {
|
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
||||||
|
|
||||||
var ctensors []lib.Ctensor
|
|
||||||
for _, t := range tensors {
|
|
||||||
ctensors = append(ctensors, t.ctensor)
|
|
||||||
}
|
|
||||||
lib.AtgAtleast2d1(ptr, ctensors, len(ctensors))
|
|
||||||
if err = TorchErr(); err != nil {
|
|
||||||
return retVal, err
|
|
||||||
}
|
|
||||||
retVal = &Tensor{ctensor: *ptr}
|
|
||||||
|
|
||||||
return retVal, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ts *Tensor) Atleast3d(del bool) (retVal *Tensor, err error) {
|
func (ts *Tensor) Atleast3d(del bool) (retVal *Tensor, err error) {
|
||||||
if del {
|
if del {
|
||||||
defer ts.MustDrop()
|
defer ts.MustDrop()
|
||||||
|
@ -3641,22 +3609,6 @@ func (ts *Tensor) Atleast3d(del bool) (retVal *Tensor, err error) {
|
||||||
return retVal, err
|
return retVal, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func Atleast3d1(tensors []Tensor) (retVal []Tensor, err error) {
|
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
||||||
|
|
||||||
var ctensors []lib.Ctensor
|
|
||||||
for _, t := range tensors {
|
|
||||||
ctensors = append(ctensors, t.ctensor)
|
|
||||||
}
|
|
||||||
lib.AtgAtleast3d1(ptr, ctensors, len(ctensors))
|
|
||||||
if err = TorchErr(); err != nil {
|
|
||||||
return retVal, err
|
|
||||||
}
|
|
||||||
retVal = &Tensor{ctensor: *ptr}
|
|
||||||
|
|
||||||
return retVal, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ts *Tensor) AvgPool1d(kernelSize []int64, stride []int64, padding []int64, ceilMode bool, countIncludePad bool, del bool) (retVal *Tensor, err error) {
|
func (ts *Tensor) AvgPool1d(kernelSize []int64, stride []int64, padding []int64, ceilMode bool, countIncludePad bool, del bool) (retVal *Tensor, err error) {
|
||||||
if del {
|
if del {
|
||||||
defer ts.MustDrop()
|
defer ts.MustDrop()
|
||||||
|
@ -6077,22 +6029,6 @@ func (ts *Tensor) Dequantize(del bool) (retVal *Tensor, err error) {
|
||||||
return retVal, err
|
return retVal, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func Dequantize1(tensors []Tensor) (retVal []Tensor, err error) {
|
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
||||||
|
|
||||||
var ctensors []lib.Ctensor
|
|
||||||
for _, t := range tensors {
|
|
||||||
ctensors = append(ctensors, t.ctensor)
|
|
||||||
}
|
|
||||||
lib.AtgDequantize1(ptr, ctensors, len(ctensors))
|
|
||||||
if err = TorchErr(); err != nil {
|
|
||||||
return retVal, err
|
|
||||||
}
|
|
||||||
retVal = &Tensor{ctensor: *ptr}
|
|
||||||
|
|
||||||
return retVal, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ts *Tensor) Det(del bool) (retVal *Tensor, err error) {
|
func (ts *Tensor) Det(del bool) (retVal *Tensor, err error) {
|
||||||
if del {
|
if del {
|
||||||
defer ts.MustDrop()
|
defer ts.MustDrop()
|
||||||
|
@ -14649,22 +14585,6 @@ func (ts *Tensor) QuantizePerTensor(scale float64, zeroPoint int64, dtype gotch.
|
||||||
return retVal, err
|
return retVal, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func QuantizePerTensor1(tensors []Tensor, scales *Tensor, zeroPoints *Tensor, dtype gotch.DType) (retVal []Tensor, err error) {
|
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
||||||
|
|
||||||
var ctensors []lib.Ctensor
|
|
||||||
for _, t := range tensors {
|
|
||||||
ctensors = append(ctensors, t.ctensor)
|
|
||||||
}
|
|
||||||
lib.AtgQuantizePerTensor1(ptr, ctensors, len(ctensors), scales.ctensor, zeroPoints.ctensor, dtype.CInt())
|
|
||||||
if err = TorchErr(); err != nil {
|
|
||||||
return retVal, err
|
|
||||||
}
|
|
||||||
retVal = &Tensor{ctensor: *ptr}
|
|
||||||
|
|
||||||
return retVal, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func QuantizedBatchNorm(input *Tensor, weight *Tensor, bias *Tensor, mean *Tensor, vari *Tensor, eps float64, outputScale float64, outputZeroPoint int64) (retVal *Tensor, err error) {
|
func QuantizedBatchNorm(input *Tensor, weight *Tensor, bias *Tensor, mean *Tensor, vari *Tensor, eps float64, outputScale float64, outputZeroPoint int64) (retVal *Tensor, err error) {
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
|
||||||
|
@ -18377,51 +18297,6 @@ func (ts *Tensor) Uniform_(from float64, to float64) (err error) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts *Tensor) UnsafeChunk(chunks int64, dim int64, del bool) (retVal []Tensor, err error) {
|
|
||||||
if del {
|
|
||||||
defer ts.MustDrop()
|
|
||||||
}
|
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
||||||
|
|
||||||
lib.AtgUnsafeChunk(ptr, ts.ctensor, chunks, dim)
|
|
||||||
if err = TorchErr(); err != nil {
|
|
||||||
return retVal, err
|
|
||||||
}
|
|
||||||
retVal = &Tensor{ctensor: *ptr}
|
|
||||||
|
|
||||||
return retVal, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ts *Tensor) UnsafeSplit(splitSize int64, dim int64, del bool) (retVal []Tensor, err error) {
|
|
||||||
if del {
|
|
||||||
defer ts.MustDrop()
|
|
||||||
}
|
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
||||||
|
|
||||||
lib.AtgUnsafeSplit(ptr, ts.ctensor, splitSize, dim)
|
|
||||||
if err = TorchErr(); err != nil {
|
|
||||||
return retVal, err
|
|
||||||
}
|
|
||||||
retVal = &Tensor{ctensor: *ptr}
|
|
||||||
|
|
||||||
return retVal, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ts *Tensor) UnsafeSplitWithSizes(splitSizes []int64, dim int64, del bool) (retVal []Tensor, err error) {
|
|
||||||
if del {
|
|
||||||
defer ts.MustDrop()
|
|
||||||
}
|
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
||||||
|
|
||||||
lib.AtgUnsafeSplitWithSizes(ptr, ts.ctensor, splitSizes, len(splitSizes), dim)
|
|
||||||
if err = TorchErr(); err != nil {
|
|
||||||
return retVal, err
|
|
||||||
}
|
|
||||||
retVal = &Tensor{ctensor: *ptr}
|
|
||||||
|
|
||||||
return retVal, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ts *Tensor) Unsqueeze(dim int64, del bool) (retVal *Tensor, err error) {
|
func (ts *Tensor) Unsqueeze(dim int64, del bool) (retVal *Tensor, err error) {
|
||||||
if del {
|
if del {
|
||||||
defer ts.MustDrop()
|
defer ts.MustDrop()
|
||||||
|
|
|
@ -1171,7 +1171,7 @@ func (ts *Tensor) Swish() *Tensor {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts *Tensor) AvgPool2DDefault(ksize int64, del bool) *Tensor {
|
func (ts *Tensor) AvgPool2DDefault(ksize int64, del bool) *Tensor {
|
||||||
return ts.MustAvgPool2d([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, false, true, 1, del)
|
return ts.MustAvgPool2d([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, false, true, []int64{1}, del)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveMultiNew saves a slice of named tensors to the given file path.
|
// SaveMultiNew saves a slice of named tensors to the given file path.
|
||||||
|
|
|
@ -103,7 +103,7 @@ func densenet(p *nn.Path, cIn, cOut, bnSize int64, blockConfig []int64, growth i
|
||||||
|
|
||||||
seq.AddFn(nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor {
|
seq.AddFn(nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor {
|
||||||
tmp1 := xs.MustRelu(false)
|
tmp1 := xs.MustRelu(false)
|
||||||
tmp2 := tmp1.MustAvgPool2d([]int64{7, 7}, []int64{1, 1}, []int64{0, 0}, false, true, 1, true)
|
tmp2 := tmp1.MustAvgPool2d([]int64{7, 7}, []int64{1, 1}, []int64{0, 0}, false, true, []int64{1}, true)
|
||||||
res := tmp2.FlatView()
|
res := tmp2.FlatView()
|
||||||
tmp2.MustDrop()
|
tmp2.MustDrop()
|
||||||
return res
|
return res
|
||||||
|
|
|
@ -78,7 +78,7 @@ func inceptionA(p *nn.Path, cIn, cPool int64) ts.ModuleT {
|
||||||
b3Ts := b3Tmp2.ApplyT(b33, train)
|
b3Ts := b3Tmp2.ApplyT(b33, train)
|
||||||
b3Tmp2.MustDrop()
|
b3Tmp2.MustDrop()
|
||||||
|
|
||||||
bpoolTmp := xs.MustAvgPool2d([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, 9, false)
|
bpoolTmp := xs.MustAvgPool2d([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, []int64{9}, false)
|
||||||
bpoolTs := bpoolTmp.ApplyT(bpool, train)
|
bpoolTs := bpoolTmp.ApplyT(bpool, train)
|
||||||
|
|
||||||
res := ts.MustCat([]ts.Tensor{*b1Ts, *b2Ts, *b3Ts, *bpoolTs}, 1)
|
res := ts.MustCat([]ts.Tensor{*b1Ts, *b2Ts, *b3Ts, *bpoolTs}, 1)
|
||||||
|
@ -145,7 +145,7 @@ func inceptionC(p *nn.Path, cIn int64, c7 int64) ts.ModuleT {
|
||||||
b3Ts := b3Tmp4.ApplyT(b35, train)
|
b3Ts := b3Tmp4.ApplyT(b35, train)
|
||||||
b3Tmp4.MustDrop()
|
b3Tmp4.MustDrop()
|
||||||
|
|
||||||
bpTmp1 := xs.MustAvgPool2d([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, 9, false)
|
bpTmp1 := xs.MustAvgPool2d([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, []int64{9}, false)
|
||||||
bpoolTs := bpTmp1.ApplyT(bpool, train)
|
bpoolTs := bpTmp1.ApplyT(bpool, train)
|
||||||
|
|
||||||
return ts.MustCat([]ts.Tensor{*b1Ts, *b2Ts, *b3Ts, *bpoolTs}, 1)
|
return ts.MustCat([]ts.Tensor{*b1Ts, *b2Ts, *b3Ts, *bpoolTs}, 1)
|
||||||
|
@ -211,7 +211,7 @@ func inceptionE(p *nn.Path, cIn int64) ts.ModuleT {
|
||||||
b3bTs := b3Tmp2.ApplyT(b33b, train)
|
b3bTs := b3Tmp2.ApplyT(b33b, train)
|
||||||
b3Ts := ts.MustCat([]ts.Tensor{*b3aTs, *b3bTs}, 1)
|
b3Ts := ts.MustCat([]ts.Tensor{*b3aTs, *b3bTs}, 1)
|
||||||
|
|
||||||
bpTmp1 := xs.MustAvgPool2d([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, 9, false)
|
bpTmp1 := xs.MustAvgPool2d([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, []int64{9}, false)
|
||||||
bpoolTs := bpTmp1.ApplyT(bpool, train)
|
bpoolTs := bpTmp1.ApplyT(bpool, train)
|
||||||
|
|
||||||
return ts.MustCat([]ts.Tensor{*b1Ts, *b2Ts, *b3Ts, *bpoolTs}, 1)
|
return ts.MustCat([]ts.Tensor{*b1Ts, *b2Ts, *b3Ts, *bpoolTs}, 1)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user