aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRalf W. Grosse-Kunstleve <rwgk@google.com>2023-06-27 15:08:32 -0700
committerGitHub <noreply@github.com>2023-06-27 15:08:32 -0700
commit2fb3d7cbde264a0b3f921e802f287195387e8263 (patch)
tree154266ac3d6ac45f68877003dbd67f3dd935f0ce
parente10da79b6ee2554be364ef14df1c988f94df02ea (diff)
downloadpybind11-2fb3d7cbde264a0b3f921e802f287195387e8263.tar.gz
Trivial refactoring to make the capsule API more user friendly. (#4720)
* Trivial refactoring to make the capsule API more user friendly. * Use new API in production code. Thanks @Lalaland for pointing this out.
-rw-r--r--include/pybind11/pybind11.h2
-rw-r--r--include/pybind11/pytypes.h51
-rw-r--r--tests/test_pytypes.cpp9
-rw-r--r--tests/test_pytypes.py13
4 files changed, 54 insertions, 21 deletions
diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h
index 28ebc222..3bce1a01 100644
--- a/include/pybind11/pybind11.h
+++ b/include/pybind11/pybind11.h
@@ -508,8 +508,8 @@ protected:
rec->def->ml_flags = METH_VARARGS | METH_KEYWORDS;
capsule rec_capsule(unique_rec.release(),
+ detail::get_function_record_capsule_name(),
[](void *ptr) { destruct((detail::function_record *) ptr); });
- rec_capsule.set_name(detail::get_function_record_capsule_name());
guarded_strdup.release();
object scope_module;
diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h
index f5d3f34f..c93e3d3b 100644
--- a/include/pybind11/pytypes.h
+++ b/include/pybind11/pytypes.h
@@ -1925,28 +1925,13 @@ public:
}
}
+ /// Capsule name is nullptr.
capsule(const void *value, void (*destructor)(void *)) {
- m_ptr = PyCapsule_New(const_cast<void *>(value), nullptr, [](PyObject *o) {
- // guard if destructor called while err indicator is set
- error_scope error_guard;
- auto destructor = reinterpret_cast<void (*)(void *)>(PyCapsule_GetContext(o));
- if (destructor == nullptr && PyErr_Occurred()) {
- throw error_already_set();
- }
- const char *name = get_name_in_error_scope(o);
- void *ptr = PyCapsule_GetPointer(o, name);
- if (ptr == nullptr) {
- throw error_already_set();
- }
-
- if (destructor != nullptr) {
- destructor(ptr);
- }
- });
+ initialize_with_void_ptr_destructor(value, nullptr, destructor);
+ }
- if (!m_ptr || PyCapsule_SetContext(m_ptr, reinterpret_cast<void *>(destructor)) != 0) {
- throw error_already_set();
- }
+ capsule(const void *value, const char *name, void (*destructor)(void *)) {
+ initialize_with_void_ptr_destructor(value, name, destructor);
}
explicit capsule(void (*destructor)()) {
@@ -2014,6 +1999,32 @@ private:
return name;
}
+
+ void initialize_with_void_ptr_destructor(const void *value,
+ const char *name,
+ void (*destructor)(void *)) {
+ m_ptr = PyCapsule_New(const_cast<void *>(value), name, [](PyObject *o) {
+ // guard if destructor called while err indicator is set
+ error_scope error_guard;
+ auto destructor = reinterpret_cast<void (*)(void *)>(PyCapsule_GetContext(o));
+ if (destructor == nullptr && PyErr_Occurred()) {
+ throw error_already_set();
+ }
+ const char *name = get_name_in_error_scope(o);
+ void *ptr = PyCapsule_GetPointer(o, name);
+ if (ptr == nullptr) {
+ throw error_already_set();
+ }
+
+ if (destructor != nullptr) {
+ destructor(ptr);
+ }
+ });
+
+ if (!m_ptr || PyCapsule_SetContext(m_ptr, reinterpret_cast<void *>(destructor)) != 0) {
+ throw error_already_set();
+ }
+ }
};
class tuple : public object {
diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp
index 1028bb58..b4ee6428 100644
--- a/tests/test_pytypes.cpp
+++ b/tests/test_pytypes.cpp
@@ -260,6 +260,15 @@ TEST_SUBMODULE(pytypes, m) {
});
});
+ m.def("return_capsule_with_destructor_3", []() {
+ py::print("creating capsule");
+ auto cap = py::capsule((void *) 1233, "oname", [](void *ptr) {
+ py::print("destructing capsule: {}"_s.format((size_t) ptr));
+ });
+ py::print("original name: {}"_s.format(cap.name()));
+ return cap;
+ });
+
m.def("return_renamed_capsule_with_destructor_2", []() {
py::print("creating capsule");
auto cap = py::capsule((void *) 1234, [](void *ptr) {
diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py
index afb7a1ce..eda7a20a 100644
--- a/tests/test_pytypes.py
+++ b/tests/test_pytypes.py
@@ -320,6 +320,19 @@ def test_capsule(capture):
)
with capture:
+ a = m.return_capsule_with_destructor_3()
+ del a
+ pytest.gc_collect()
+ assert (
+ capture.unordered
+ == """
+ creating capsule
+ destructing capsule: 1233
+ original name: oname
+ """
+ )
+
+ with capture:
a = m.return_renamed_capsule_with_destructor_2()
del a
pytest.gc_collect()