diff --git a/capnp/helpers/asyncHelper.h b/capnp/helpers/asyncHelper.h index 6bfef0b..0ab2e39 100644 --- a/capnp/helpers/asyncHelper.h +++ b/capnp/helpers/asyncHelper.h @@ -2,6 +2,7 @@ #include "kj/async.h" #include "Python.h" +#include "capabilityHelper.h" class PyEventPort: public kj::EventPort { public: @@ -10,14 +11,17 @@ public: // Py_INCREF(py_event_port); } virtual void wait() { + GILAcquire gil; PyObject_CallMethod(py_event_port, const_cast("wait"), NULL); } virtual void poll() { + GILAcquire gil; PyObject_CallMethod(py_event_port, const_cast("poll"), NULL); } virtual void setRunnable(bool runnable) { + GILAcquire gil; PyObject * arg = Py_False; if (runnable) arg = Py_True; @@ -29,9 +33,25 @@ private: }; void waitNeverDone(kj::WaitScope & scope) { + GILRelease gil; kj::NEVER_DONE.wait(scope); } kj::Timer * getTimer(kj::AsyncIoContext * context) { return &context->lowLevelProvider->getTimer(); } + +void waitVoidPromise(kj::Promise * promise, kj::WaitScope & scope) { + GILRelease gil; + promise->wait(scope); +} + +PyObject * waitPyPromise(kj::Promise * promise, kj::WaitScope & scope) { + GILRelease gil; + return promise->wait(scope); +} + +capnp::Response< ::capnp::DynamicStruct> * waitRemote(capnp::RemotePromise< ::capnp::DynamicStruct> * promise, kj::WaitScope & scope) { + GILRelease gil; + return new capnp::Response< ::capnp::DynamicStruct>(promise->wait(scope)); +} diff --git a/capnp/helpers/capabilityHelper.h b/capnp/helpers/capabilityHelper.h index 4a33977..cad1b92 100644 --- a/capnp/helpers/capabilityHelper.h +++ b/capnp/helpers/capabilityHelper.h @@ -3,7 +3,6 @@ #include "capnp/dynamic.h" #include #include "Python.h" -#include extern "C" { PyObject * wrap_remote_call(PyObject * func, capnp::Response &); @@ -17,12 +16,38 @@ extern "C" { ::capnp::RemotePromise< ::capnp::DynamicStruct> * extract_remote_promise(PyObject *); } +class GILAcquire { +public: + GILAcquire() : gstate(PyGILState_Ensure()) {} + ~GILAcquire() { + PyGILState_Release(gstate); + } + + PyGILState_STATE gstate; +}; + +class GILRelease { +public: + GILRelease() { + Py_UNBLOCK_THREADS + } + ~GILRelease() { + Py_BLOCK_THREADS + } + + PyThreadState *_save; // The macros above read/write from this variable +}; + ::kj::Promise convert_to_pypromise(capnp::RemotePromise & promise) { return promise.then([](capnp::Response&& response) { return wrap_dynamic_struct_reader(response); } ); } ::kj::Promise convert_to_pypromise(kj::Promise & promise) { - return promise.then([]() { Py_RETURN_NONE;} ); + return promise.then([]() { + GILAcquire gil; + Py_INCREF( Py_None ); + return Py_None; + }); } template @@ -31,6 +56,7 @@ template } void reraise_kj_exception() { + GILAcquire gil; try { if (PyErr_Occurred()) ; // let the latest Python exn pass through and ignore the current one @@ -51,6 +77,7 @@ void reraise_kj_exception() { } void check_py_error() { + GILAcquire gil; PyObject * err = PyErr_Occurred(); if(err) { PyObject * ptype, *pvalue, *ptraceback; @@ -80,6 +107,7 @@ void check_py_error() { } kj::Promise wrapPyFunc(PyObject * func, PyObject * arg) { + GILAcquire gil; auto arg_promise = extract_promise(arg); if(arg_promise == NULL) { @@ -102,6 +130,7 @@ kj::Promise wrapPyFunc(PyObject * func, PyObject * arg) { } kj::Promise wrapPyFuncNoArg(PyObject * func) { + GILAcquire gil; PyObject * result = PyObject_CallFunctionObjArgs(func, NULL); check_py_error(); @@ -116,6 +145,7 @@ kj::Promise wrapPyFuncNoArg(PyObject * func) { } kj::Promise wrapRemoteCall(PyObject * func, capnp::Response & arg) { + GILAcquire gil; PyObject * ret = wrap_remote_call(func, arg); check_py_error(); @@ -163,10 +193,12 @@ public: PythonInterfaceDynamicImpl(capnp::InterfaceSchema & schema, PyObject * _py_server) : capnp::DynamicCapability::Server(schema), py_server(_py_server) { + GILAcquire gil; Py_INCREF(_py_server); } ~PythonInterfaceDynamicImpl() { + GILAcquire gil; Py_DECREF(py_server); } @@ -192,14 +224,17 @@ public: PyObject * obj; PyRefCounter(PyObject * o) : obj(o) { + GILAcquire gil; Py_INCREF(obj); } PyRefCounter(const PyRefCounter & ref) : obj(ref.obj) { + GILAcquire gil; Py_INCREF(obj); } ~PyRefCounter() { + GILAcquire gil; Py_DECREF(obj); } }; diff --git a/capnp/helpers/helpers.pxd b/capnp/helpers/helpers.pxd index 1e842c2..8de5569 100644 --- a/capnp/helpers/helpers.pxd +++ b/capnp/helpers/helpers.pxd @@ -1,4 +1,4 @@ -from .capnp.includes.capnp_cpp cimport Maybe, DynamicStruct, Request, PyPromise, VoidPromise, PyPromiseArray, RemotePromise, DynamicCapability, InterfaceSchema, EnumSchema, StructSchema, DynamicValue, Capability, RpcSystem, MessageBuilder, MessageReader, TwoPartyVatNetwork, PyRestorer, AnyPointer, DynamicStruct_Builder, WaitScope, AsyncIoContext, StringPtr, TaskSet, Timer +from .capnp.includes.capnp_cpp cimport Maybe, DynamicStruct, Request, Response, PyPromise, VoidPromise, PyPromiseArray, RemotePromise, DynamicCapability, InterfaceSchema, EnumSchema, StructSchema, DynamicValue, Capability, RpcSystem, MessageBuilder, MessageReader, TwoPartyVatNetwork, PyRestorer, AnyPointer, DynamicStruct_Builder, WaitScope, AsyncIoContext, StringPtr, TaskSet, Timer from .capnp.includes.schema_cpp cimport ByteArray @@ -38,4 +38,7 @@ cdef extern from "../helpers/serialize.h": cdef extern from "../helpers/asyncHelper.h": void waitNeverDone(WaitScope&) + Response * waitRemote(RemotePromise *, WaitScope&) + PyObject * waitPyPromise(PyPromise *, WaitScope&) + void waitVoidPromise(VoidPromise *, WaitScope&) Timer * getTimer(AsyncIoContext *) except +reraise_kj_exception diff --git a/capnp/helpers/rpcHelper.h b/capnp/helpers/rpcHelper.h index 7f89011..5885dc4 100644 --- a/capnp/helpers/rpcHelper.h +++ b/capnp/helpers/rpcHelper.h @@ -22,6 +22,7 @@ public: // } capnp::Capability::Client restore(capnp::AnyPointer::Reader objectId) override { + GILAcquire gil; capnp::Capability::Client * ret = call_py_restorer(py_restorer, objectId); check_py_error(); capnp::Capability::Client stack_ret(*ret); @@ -113,17 +114,17 @@ void acceptLoop(kj::TaskSet & tasks, PyRestorer & restorer, kj::Own connectServer(kj::TaskSet & tasks, PyRestorer & restorer, kj::AsyncIoContext * context, kj::StringPtr bindAddress) { - auto paf = kj::newPromiseAndFulfiller(); + auto paf = kj::newPromiseAndFulfiller(); auto portPromise = paf.promise.fork(); tasks.add(context->provider->getNetwork().parseAddress(bindAddress) .then(kj::mvCapture(paf.fulfiller, - [&](kj::Own>&& portFulfiller, + [&](kj::Own>&& portFulfiller, kj::Own&& addr) { auto listener = addr->listen(); portFulfiller->fulfill(listener->getPort()); acceptLoop(tasks, restorer, kj::mv(listener)); }))); - return portPromise.addBranch().then([&](uint port) { return PyLong_FromUnsignedLong(port); }); + return portPromise.addBranch().then([&](unsigned int port) { return PyLong_FromUnsignedLong(port); }); } diff --git a/capnp/includes/capnp_cpp.pxd b/capnp/includes/capnp_cpp.pxd index 0d264fa..be01719 100644 --- a/capnp/includes/capnp_cpp.pxd +++ b/capnp/includes/capnp_cpp.pxd @@ -51,7 +51,7 @@ cdef extern from "kj/memory.h" namespace " ::kj": Own[PyRefCounter] makePyRefCounter" ::kj::heap< PyRefCounter >"(PyObject *) cdef extern from "kj/async.h" namespace " ::kj": - cdef cppclass Promise[T]: + cdef cppclass Promise[T] nogil: Promise() Promise(Promise) Promise(T) diff --git a/capnp/includes/types.pxd b/capnp/includes/types.pxd index 10567ce..c6dd3ba 100644 --- a/capnp/includes/types.pxd +++ b/capnp/includes/types.pxd @@ -1,5 +1,4 @@ from cpython.ref cimport PyObject, Py_INCREF, Py_DECREF -from cpython.exc cimport PyErr_Clear from libc.stdint cimport * ctypedef unsigned int uint ctypedef uint8_t byte diff --git a/capnp/lib/capnp.pyx b/capnp/lib/capnp.pyx index 6071c02..0ab1be9 100644 --- a/capnp/lib/capnp.pyx +++ b/capnp/lib/capnp.pyx @@ -12,6 +12,7 @@ from .capnp.helpers.helpers cimport makeRpcClientWithRestorer from libc.stdlib cimport malloc, free from cython.operator cimport dereference as deref +from cpython.exc cimport PyErr_Clear from types import ModuleType as _ModuleType import os as _os @@ -32,10 +33,10 @@ _CAPNP_VERSION_MICRO = capnp.CAPNP_VERSION_MICRO _CAPNP_VERSION = capnp.CAPNP_VERSION # By making it public, we'll be able to call it from capabilityHelper.h -cdef public object wrap_dynamic_struct_reader(Response & r): +cdef public object wrap_dynamic_struct_reader(Response & r) with gil: return _Response()._init_childptr(new Response(moveResponse(r)), None) -cdef public PyObject * wrap_remote_call(PyObject * func, Response & r) except *: +cdef public PyObject * wrap_remote_call(PyObject * func, Response & r) except * with gil: response = _Response()._init_childptr(new Response(moveResponse(r)), None) func_obj = func @@ -46,7 +47,7 @@ cdef public PyObject * wrap_remote_call(PyObject * func, Response & r) except *: cdef _find_field_order(struct_node): return [f.name for f in sorted(struct_node.fields, key=_attrgetter('codeOrder'))] -cdef public VoidPromise * call_server_method(PyObject * _server, char * _method_name, CallContext & _context) except *: +cdef public VoidPromise * call_server_method(PyObject * _server, char * _method_name, CallContext & _context) except * with gil: server = _server method_name = _method_name @@ -101,7 +102,7 @@ cdef public VoidPromise * call_server_method(PyObject * _server, char * _method_ return NULL -cdef public C_Capability.Client * call_py_restorer(PyObject * _restorer, C_DynamicObject.Reader & _reader) except *: +cdef public C_Capability.Client * call_py_restorer(PyObject * _restorer, C_DynamicObject.Reader & _reader) except * with gil: restorer = _restorer reader = _DynamicObjectReader()._init(_reader, None) @@ -112,10 +113,10 @@ cdef public C_Capability.Client * call_py_restorer(PyObject * _restorer, C_Dynam return new C_Capability.Client(helpers.server_to_client(schema.thisptr, server)) -cdef public convert_array_pyobject(PyArray & arr): +cdef public convert_array_pyobject(PyArray & arr) with gil: return [arr[i] for i in range(arr.size())] -cdef public PyPromise * extract_promise(object obj): +cdef public PyPromise * extract_promise(object obj) with gil: if type(obj) is Promise: promise = obj @@ -126,7 +127,7 @@ cdef public PyPromise * extract_promise(object obj): return NULL -cdef public RemotePromise * extract_remote_promise(object obj): +cdef public RemotePromise * extract_remote_promise(object obj) with gil: if type(obj) is _RemotePromise: promise = <_RemotePromise>obj promise.is_consumed = True @@ -236,14 +237,14 @@ class KjException(Exception): def __str__(self): return self.message -cdef public object wrap_kj_exception(capnp.Exception & exception): +cdef public object wrap_kj_exception(capnp.Exception & exception) with gil: PyErr_Clear() wrapper = _KjExceptionWrapper()._init(exception) ret = KjException(wrapper=wrapper) return ret -cdef public object wrap_kj_exception_for_reraise(capnp.Exception & exception): +cdef public object wrap_kj_exception_for_reraise(capnp.Exception & exception) with gil: wrapper = _KjExceptionWrapper()._init(exception) wrapper_msg = str(wrapper) @@ -265,13 +266,12 @@ cdef public object wrap_kj_exception_for_reraise(capnp.Exception & exception): ret = KjException(wrapper=wrapper) return ret -cdef public object get_exception_info(object exc_type, object exc_obj, object exc_tb): +cdef public object get_exception_info(object exc_type, object exc_obj, object exc_tb) with gil: try: return (exc_tb.tb_frame.f_code.co_filename.encode(), exc_tb.tb_lineno, (repr(exc_type) + ':' + str(exc_obj)).encode()) except: return (b'', 0, b"Couldn't determine python exception") - ctypedef fused _DynamicStructReaderOrBuilder: _DynamicStructReader _DynamicStructBuilder @@ -1341,6 +1341,7 @@ cdef class _EventLoop: cdef _EventLoop C_DEFAULT_EVENT_LOOP = _EventLoop() _C_DEFAULT_EVENT_LOOP_LOCAL = None +_THREAD_LOCAL_EVENT_LOOPS = [] cdef _EventLoop C_DEFAULT_EVENT_LOOP_GETTER(): 'Optimization for not having to deal with threadlocal event loops unless we need to' @@ -1369,11 +1370,28 @@ cdef class Timer: def getTimer(): return Timer()._init(helpers.getTimer(C_DEFAULT_EVENT_LOOP_GETTER().thisptr)) -cpdef remove_event_loop(): +cpdef remove_event_loop(ignore_errors=False): 'Remove the global event loop' global C_DEFAULT_EVENT_LOOP - C_DEFAULT_EVENT_LOOP._remove() - C_DEFAULT_EVENT_LOOP = None + global _THREAD_LOCAL_EVENT_LOOPS + global _C_DEFAULT_EVENT_LOOP_LOCAL + + if C_DEFAULT_EVENT_LOOP: + try: + C_DEFAULT_EVENT_LOOP._remove() + except: + if not ignore_errors: + raise + C_DEFAULT_EVENT_LOOP = None + if len(_THREAD_LOCAL_EVENT_LOOPS) > 0: + for loop in _THREAD_LOCAL_EVENT_LOOPS: + try: + loop._remove() + except: + if not ignore_errors: + raise + _THREAD_LOCAL_EVENT_LOOPS = [] + _C_DEFAULT_EVENT_LOOP_LOCAL = None cpdef create_event_loop(threaded=True): '''Create a new global event loop. This will not remove the previous @@ -1383,7 +1401,9 @@ cpdef create_event_loop(threaded=True): if threaded: if _C_DEFAULT_EVENT_LOOP_LOCAL is None: _C_DEFAULT_EVENT_LOOP_LOCAL = _threading.local() - _C_DEFAULT_EVENT_LOOP_LOCAL.loop = _EventLoop() + loop = _EventLoop() + _C_DEFAULT_EVENT_LOOP_LOCAL.loop = loop + _THREAD_LOCAL_EVENT_LOOPS.append(loop) else: C_DEFAULT_EVENT_LOOP = _EventLoop() @@ -1458,7 +1478,7 @@ cdef class Promise: if self.is_consumed: raise ValueError('Promise was already used in a consuming operation. You can no longer use this Promise object') - ret = self.thisptr.wait(deref(self._event_loop.thisptr).waitScope) + ret = helpers.waitPyPromise(self.thisptr, deref(self._event_loop.thisptr).waitScope) Py_DECREF(ret) self.is_consumed = True @@ -1524,7 +1544,8 @@ cdef class _VoidPromise: if self.is_consumed: raise ValueError('Promise was already used in a consuming operation. You can no longer use this Promise object') - self.thisptr.wait(deref(self._event_loop.thisptr).waitScope) + helpers.waitVoidPromise(self.thisptr, deref(self._event_loop.thisptr).waitScope) + self.is_consumed = True cpdef then(self, func, error_func=None) except +reraise_kj_exception: @@ -1590,7 +1611,7 @@ cdef class _RemotePromise: if self.is_consumed: raise ValueError('Promise was already used in a consuming operation. You can no longer use this Promise object') - ret = _Response()._init_child(self.thisptr.wait(deref(self._event_loop.thisptr).waitScope), self._parent) + ret = _Response()._init_childptr(helpers.waitRemote(self.thisptr, deref(self._event_loop.thisptr).waitScope), self._parent) self.is_consumed = True return ret diff --git a/test/test_threads.py b/test/test_threads.py new file mode 100644 index 0000000..29f1632 --- /dev/null +++ b/test/test_threads.py @@ -0,0 +1,66 @@ +import capnp +import pytest +import test_capability_capnp +import socket +import threading +import platform + + +def test_making_event_loop(): + capnp.remove_event_loop(True) + capnp.create_event_loop() + + capnp.remove_event_loop() + capnp.create_event_loop() + + +def test_making_threaded_event_loop(): + capnp.remove_event_loop(True) + capnp.create_event_loop(True) + + capnp.remove_event_loop() + capnp.create_event_loop(True) + + +class Server(test_capability_capnp.TestInterface.Server): + + def __init__(self, val=1): + self.val = val + + def foo(self, i, j, **kwargs): + return str(i * 5 + self.val) + + +class SimpleRestorer(test_capability_capnp.TestSturdyRefObjectId.Restorer): + + def restore(self, ref_id): + assert ref_id.tag == 'testInterface' + return Server(100) + + +@pytest.mark.skipif(platform.python_implementation() == 'PyPy', reason="pycapnp's GIL handling isn't working properly at the moment for PyPy") +def test_using_threads(): + capnp.remove_event_loop(True) + capnp.create_event_loop(True) + + read, write = socket.socketpair(socket.AF_UNIX) + + def run_server(): + restorer = SimpleRestorer() + server = capnp.TwoPartyServer(write, restorer) + capnp.wait_forever() + + server_thread = threading.Thread(target=run_server) + server_thread.daemon = True + server_thread.start() + + client = capnp.TwoPartyClient(read) + + ref = test_capability_capnp.TestSturdyRefObjectId.new_message(tag='testInterface') + cap = client.restore(ref) + cap = cap.cast_as(test_capability_capnp.TestInterface) + + remote = cap.foo(i=5) + response = remote.wait() + + assert response.x == '125'