added new API: ConstantPadNdWithVal
This commit is contained in:
parent
b1fa1004e0
commit
6e07f2cca1
|
@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- added `nn.Path.Paths()` method
|
||||
- added `nn.VarStore.Summary()` method
|
||||
- fixed incorrect tensor method `ts.Meshgrid` -> `Meshgrid`
|
||||
- added new API `ConstantPadNdWithVal` `ato_constant_pad_nd` with padding value.
|
||||
|
||||
## [Nofix]
|
||||
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.
|
||||
|
|
|
@ -864,3 +864,9 @@ func AtmEval(m Cmodule) {
|
|||
func AtmTrain(m Cmodule) {
|
||||
C.atm_train(m)
|
||||
}
|
||||
|
||||
func AtoConstantPadNd(ptr *Ctensor, self Ctensor, padData []int64, padLen int, value Cscalar) {
|
||||
cpadDataPtr := (*C.int64_t)(unsafe.Pointer(&padData[0]))
|
||||
cpadLen := *(*C.int)(unsafe.Pointer(&padLen))
|
||||
C.ato_constant_pad_nd(ptr, self, cpadDataPtr, cpadLen, value)
|
||||
}
|
||||
|
|
|
@ -510,6 +510,13 @@ void ato_set_learning_rate_group(optimizer t, size_t group,
|
|||
set_lr_group<torch::optim::SGDOptions>(t, group, learning_rate);)
|
||||
}
|
||||
|
||||
void ato_constant_pad_nd(tensor *out__, tensor self, int64_t *pad_data, int pad_len, scalar value) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::constant_pad_nd(*self, torch::IntArrayRef(pad_data, pad_len), *value);
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
// ============ set/get learning rates ==============================
|
||||
// TT. added for learning rate scheduler
|
||||
// lr scheduler APIs will be in Pytorch 1.9?
|
||||
|
|
|
@ -141,6 +141,9 @@ int64_t ato_param_group_num(optimizer);
|
|||
void ato_get_learning_rates(optimizer, double *lrs, int *ngroup);
|
||||
void ato_add_param_group(optimizer, tensor *params, int param_num);
|
||||
|
||||
// TT. added option pad value. Original generated API `atg_constant_pad_nd` no option of adding pad value.
|
||||
void ato_constant_pad_nd(tensor *, tensor self, int64_t *pad_data, int pad_len, scalar value);
|
||||
|
||||
scalar ats_int(int64_t);
|
||||
scalar ats_float(double);
|
||||
int64_t ats_to_int(scalar);
|
||||
|
|
|
@ -1243,3 +1243,28 @@ func SaveMultiNew(namedTensors []NamedTensor, path string) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts *Tensor) ConstantPadNdWithVal(pad []int64, value *Scalar, del bool) (retVal *Tensor, err error) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
lib.AtoConstantPadNd(ptr, ts.ctensor, pad, len(pad), value.cscalar)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
retVal = &Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
func (ts *Tensor) MustConstantPadNdWithVal(pad []int64, value *Scalar, del bool) (retVal *Tensor) {
|
||||
|
||||
retVal, err := ts.ConstantPadNdWithVal(pad, value, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user