added new API: ConstantPadNdWithVal

This commit is contained in:
sugarme 2022-01-17 21:41:16 +11:00
parent b1fa1004e0
commit 6e07f2cca1
5 changed files with 42 additions and 0 deletions

View File

@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- added `nn.Path.Paths()` method
- added `nn.VarStore.Summary()` method
- fixed incorrect tensor method `ts.Meshgrid` -> `Meshgrid`
- added new API `ConstantPadNdWithVal` `ato_constant_pad_nd` with padding value.
## [Nofix]
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.

View File

@ -864,3 +864,9 @@ func AtmEval(m Cmodule) {
func AtmTrain(m Cmodule) {
C.atm_train(m)
}
func AtoConstantPadNd(ptr *Ctensor, self Ctensor, padData []int64, padLen int, value Cscalar) {
cpadDataPtr := (*C.int64_t)(unsafe.Pointer(&padData[0]))
cpadLen := *(*C.int)(unsafe.Pointer(&padLen))
C.ato_constant_pad_nd(ptr, self, cpadDataPtr, cpadLen, value)
}

View File

@ -510,6 +510,13 @@ void ato_set_learning_rate_group(optimizer t, size_t group,
set_lr_group<torch::optim::SGDOptions>(t, group, learning_rate);)
}
void ato_constant_pad_nd(tensor *out__, tensor self, int64_t *pad_data, int pad_len, scalar value) {
PROTECT(
auto outputs__ = torch::constant_pad_nd(*self, torch::IntArrayRef(pad_data, pad_len), *value);
out__[0] = new torch::Tensor(outputs__);
)
}
// ============ set/get learning rates ==============================
// TT. added for learning rate scheduler
// lr scheduler APIs will be in Pytorch 1.9?

View File

@ -141,6 +141,9 @@ int64_t ato_param_group_num(optimizer);
void ato_get_learning_rates(optimizer, double *lrs, int *ngroup);
void ato_add_param_group(optimizer, tensor *params, int param_num);
// TT. added option pad value. Original generated API `atg_constant_pad_nd` no option of adding pad value.
void ato_constant_pad_nd(tensor *, tensor self, int64_t *pad_data, int pad_len, scalar value);
scalar ats_int(int64_t);
scalar ats_float(double);
int64_t ats_to_int(scalar);

View File

@ -1243,3 +1243,28 @@ func SaveMultiNew(namedTensors []NamedTensor, path string) error {
return nil
}
func (ts *Tensor) ConstantPadNdWithVal(pad []int64, value *Scalar, del bool) (retVal *Tensor, err error) {
if del {
defer ts.MustDrop()
}
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtoConstantPadNd(ptr, ts.ctensor, pad, len(pad), value.cscalar)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = &Tensor{ctensor: *ptr}
return retVal, err
}
func (ts *Tensor) MustConstantPadNdWithVal(pad []int64, value *Scalar, del bool) (retVal *Tensor) {
retVal, err := ts.ConstantPadNdWithVal(pad, value, del)
if err != nil {
log.Fatal(err)
}
return retVal
}