Make pycapnp GIL friendly

Now calling `wait` from one thread will not block all threads
This commit is contained in:
Jason Paryani 2014-07-09 00:49:35 -07:00
parent 5befda532f
commit a1f7d32853
8 changed files with 171 additions and 26 deletions

View file

@ -2,6 +2,7 @@
#include "kj/async.h"
#include "Python.h"
#include "capabilityHelper.h"
class PyEventPort: public kj::EventPort {
public:
@ -10,14 +11,17 @@ public:
// Py_INCREF(py_event_port);
}
virtual void wait() {
GILAcquire gil;
PyObject_CallMethod(py_event_port, const_cast<char *>("wait"), NULL);
}
virtual void poll() {
GILAcquire gil;
PyObject_CallMethod(py_event_port, const_cast<char *>("poll"), NULL);
}
virtual void setRunnable(bool runnable) {
GILAcquire gil;
PyObject * arg = Py_False;
if (runnable)
arg = Py_True;
@ -29,9 +33,25 @@ private:
};
void waitNeverDone(kj::WaitScope & scope) {
GILRelease gil;
kj::NEVER_DONE.wait(scope);
}
kj::Timer * getTimer(kj::AsyncIoContext * context) {
return &context->lowLevelProvider->getTimer();
}
void waitVoidPromise(kj::Promise<void> * promise, kj::WaitScope & scope) {
GILRelease gil;
promise->wait(scope);
}
PyObject * waitPyPromise(kj::Promise<PyObject *> * promise, kj::WaitScope & scope) {
GILRelease gil;
return promise->wait(scope);
}
capnp::Response< ::capnp::DynamicStruct> * waitRemote(capnp::RemotePromise< ::capnp::DynamicStruct> * promise, kj::WaitScope & scope) {
GILRelease gil;
return new capnp::Response< ::capnp::DynamicStruct>(promise->wait(scope));
}

View file

@ -3,7 +3,6 @@
#include "capnp/dynamic.h"
#include <stdexcept>
#include "Python.h"
#include <iostream>
extern "C" {
PyObject * wrap_remote_call(PyObject * func, capnp::Response<capnp::DynamicStruct> &);
@ -17,12 +16,38 @@ extern "C" {
::capnp::RemotePromise< ::capnp::DynamicStruct> * extract_remote_promise(PyObject *);
}
class GILAcquire {
public:
GILAcquire() : gstate(PyGILState_Ensure()) {}
~GILAcquire() {
PyGILState_Release(gstate);
}
PyGILState_STATE gstate;
};
class GILRelease {
public:
GILRelease() {
Py_UNBLOCK_THREADS
}
~GILRelease() {
Py_BLOCK_THREADS
}
PyThreadState *_save; // The macros above read/write from this variable
};
::kj::Promise<PyObject *> convert_to_pypromise(capnp::RemotePromise<capnp::DynamicStruct> & promise) {
return promise.then([](capnp::Response<capnp::DynamicStruct>&& response) { return wrap_dynamic_struct_reader(response); } );
}
::kj::Promise<PyObject *> convert_to_pypromise(kj::Promise<void> & promise) {
return promise.then([]() { Py_RETURN_NONE;} );
return promise.then([]() {
GILAcquire gil;
Py_INCREF( Py_None );
return Py_None;
});
}
template<class T>
@ -31,6 +56,7 @@ template<class T>
}
void reraise_kj_exception() {
GILAcquire gil;
try {
if (PyErr_Occurred())
; // let the latest Python exn pass through and ignore the current one
@ -51,6 +77,7 @@ void reraise_kj_exception() {
}
void check_py_error() {
GILAcquire gil;
PyObject * err = PyErr_Occurred();
if(err) {
PyObject * ptype, *pvalue, *ptraceback;
@ -80,6 +107,7 @@ void check_py_error() {
}
kj::Promise<PyObject *> wrapPyFunc(PyObject * func, PyObject * arg) {
GILAcquire gil;
auto arg_promise = extract_promise(arg);
if(arg_promise == NULL) {
@ -102,6 +130,7 @@ kj::Promise<PyObject *> wrapPyFunc(PyObject * func, PyObject * arg) {
}
kj::Promise<PyObject *> wrapPyFuncNoArg(PyObject * func) {
GILAcquire gil;
PyObject * result = PyObject_CallFunctionObjArgs(func, NULL);
check_py_error();
@ -116,6 +145,7 @@ kj::Promise<PyObject *> wrapPyFuncNoArg(PyObject * func) {
}
kj::Promise<PyObject *> wrapRemoteCall(PyObject * func, capnp::Response<capnp::DynamicStruct> & arg) {
GILAcquire gil;
PyObject * ret = wrap_remote_call(func, arg);
check_py_error();
@ -163,10 +193,12 @@ public:
PythonInterfaceDynamicImpl(capnp::InterfaceSchema & schema, PyObject * _py_server)
: capnp::DynamicCapability::Server(schema), py_server(_py_server) {
GILAcquire gil;
Py_INCREF(_py_server);
}
~PythonInterfaceDynamicImpl() {
GILAcquire gil;
Py_DECREF(py_server);
}
@ -192,14 +224,17 @@ public:
PyObject * obj;
PyRefCounter(PyObject * o) : obj(o) {
GILAcquire gil;
Py_INCREF(obj);
}
PyRefCounter(const PyRefCounter & ref) : obj(ref.obj) {
GILAcquire gil;
Py_INCREF(obj);
}
~PyRefCounter() {
GILAcquire gil;
Py_DECREF(obj);
}
};

View file

@ -1,4 +1,4 @@
from .capnp.includes.capnp_cpp cimport Maybe, DynamicStruct, Request, PyPromise, VoidPromise, PyPromiseArray, RemotePromise, DynamicCapability, InterfaceSchema, EnumSchema, StructSchema, DynamicValue, Capability, RpcSystem, MessageBuilder, MessageReader, TwoPartyVatNetwork, PyRestorer, AnyPointer, DynamicStruct_Builder, WaitScope, AsyncIoContext, StringPtr, TaskSet, Timer
from .capnp.includes.capnp_cpp cimport Maybe, DynamicStruct, Request, Response, PyPromise, VoidPromise, PyPromiseArray, RemotePromise, DynamicCapability, InterfaceSchema, EnumSchema, StructSchema, DynamicValue, Capability, RpcSystem, MessageBuilder, MessageReader, TwoPartyVatNetwork, PyRestorer, AnyPointer, DynamicStruct_Builder, WaitScope, AsyncIoContext, StringPtr, TaskSet, Timer
from .capnp.includes.schema_cpp cimport ByteArray
@ -38,4 +38,7 @@ cdef extern from "../helpers/serialize.h":
cdef extern from "../helpers/asyncHelper.h":
void waitNeverDone(WaitScope&)
Response * waitRemote(RemotePromise *, WaitScope&)
PyObject * waitPyPromise(PyPromise *, WaitScope&)
void waitVoidPromise(VoidPromise *, WaitScope&)
Timer * getTimer(AsyncIoContext *) except +reraise_kj_exception

View file

@ -22,6 +22,7 @@ public:
// }
capnp::Capability::Client restore(capnp::AnyPointer::Reader objectId) override {
GILAcquire gil;
capnp::Capability::Client * ret = call_py_restorer(py_restorer, objectId);
check_py_error();
capnp::Capability::Client stack_ret(*ret);
@ -113,17 +114,17 @@ void acceptLoop(kj::TaskSet & tasks, PyRestorer & restorer, kj::Own<kj::Connecti
}
kj::Promise<PyObject *> connectServer(kj::TaskSet & tasks, PyRestorer & restorer, kj::AsyncIoContext * context, kj::StringPtr bindAddress) {
auto paf = kj::newPromiseAndFulfiller<uint>();
auto paf = kj::newPromiseAndFulfiller<unsigned int>();
auto portPromise = paf.promise.fork();
tasks.add(context->provider->getNetwork().parseAddress(bindAddress)
.then(kj::mvCapture(paf.fulfiller,
[&](kj::Own<kj::PromiseFulfiller<uint>>&& portFulfiller,
[&](kj::Own<kj::PromiseFulfiller<unsigned int>>&& portFulfiller,
kj::Own<kj::NetworkAddress>&& addr) {
auto listener = addr->listen();
portFulfiller->fulfill(listener->getPort());
acceptLoop(tasks, restorer, kj::mv(listener));
})));
return portPromise.addBranch().then([&](uint port) { return PyLong_FromUnsignedLong(port); });
return portPromise.addBranch().then([&](unsigned int port) { return PyLong_FromUnsignedLong(port); });
}

View file

@ -51,7 +51,7 @@ cdef extern from "kj/memory.h" namespace " ::kj":
Own[PyRefCounter] makePyRefCounter" ::kj::heap< PyRefCounter >"(PyObject *)
cdef extern from "kj/async.h" namespace " ::kj":
cdef cppclass Promise[T]:
cdef cppclass Promise[T] nogil:
Promise()
Promise(Promise)
Promise(T)

View file

@ -1,5 +1,4 @@
from cpython.ref cimport PyObject, Py_INCREF, Py_DECREF
from cpython.exc cimport PyErr_Clear
from libc.stdint cimport *
ctypedef unsigned int uint
ctypedef uint8_t byte

View file

@ -12,6 +12,7 @@ from .capnp.helpers.helpers cimport makeRpcClientWithRestorer
from libc.stdlib cimport malloc, free
from cython.operator cimport dereference as deref
from cpython.exc cimport PyErr_Clear
from types import ModuleType as _ModuleType
import os as _os
@ -32,10 +33,10 @@ _CAPNP_VERSION_MICRO = capnp.CAPNP_VERSION_MICRO
_CAPNP_VERSION = capnp.CAPNP_VERSION
# By making it public, we'll be able to call it from capabilityHelper.h
cdef public object wrap_dynamic_struct_reader(Response & r):
cdef public object wrap_dynamic_struct_reader(Response & r) with gil:
return _Response()._init_childptr(new Response(moveResponse(r)), None)
cdef public PyObject * wrap_remote_call(PyObject * func, Response & r) except *:
cdef public PyObject * wrap_remote_call(PyObject * func, Response & r) except * with gil:
response = _Response()._init_childptr(new Response(moveResponse(r)), None)
func_obj = <object>func
@ -46,7 +47,7 @@ cdef public PyObject * wrap_remote_call(PyObject * func, Response & r) except *:
cdef _find_field_order(struct_node):
return [f.name for f in sorted(struct_node.fields, key=_attrgetter('codeOrder'))]
cdef public VoidPromise * call_server_method(PyObject * _server, char * _method_name, CallContext & _context) except *:
cdef public VoidPromise * call_server_method(PyObject * _server, char * _method_name, CallContext & _context) except * with gil:
server = <object>_server
method_name = <object>_method_name
@ -101,7 +102,7 @@ cdef public VoidPromise * call_server_method(PyObject * _server, char * _method_
return NULL
cdef public C_Capability.Client * call_py_restorer(PyObject * _restorer, C_DynamicObject.Reader & _reader) except *:
cdef public C_Capability.Client * call_py_restorer(PyObject * _restorer, C_DynamicObject.Reader & _reader) except * with gil:
restorer = <object>_restorer
reader = _DynamicObjectReader()._init(_reader, None)
@ -112,10 +113,10 @@ cdef public C_Capability.Client * call_py_restorer(PyObject * _restorer, C_Dynam
return new C_Capability.Client(helpers.server_to_client(schema.thisptr, <PyObject *>server))
cdef public convert_array_pyobject(PyArray & arr):
cdef public convert_array_pyobject(PyArray & arr) with gil:
return [<object>arr[i] for i in range(arr.size())]
cdef public PyPromise * extract_promise(object obj):
cdef public PyPromise * extract_promise(object obj) with gil:
if type(obj) is Promise:
promise = <Promise>obj
@ -126,7 +127,7 @@ cdef public PyPromise * extract_promise(object obj):
return NULL
cdef public RemotePromise * extract_remote_promise(object obj):
cdef public RemotePromise * extract_remote_promise(object obj) with gil:
if type(obj) is _RemotePromise:
promise = <_RemotePromise>obj
promise.is_consumed = True
@ -236,14 +237,14 @@ class KjException(Exception):
def __str__(self):
return self.message
cdef public object wrap_kj_exception(capnp.Exception & exception):
cdef public object wrap_kj_exception(capnp.Exception & exception) with gil:
PyErr_Clear()
wrapper = _KjExceptionWrapper()._init(exception)
ret = KjException(wrapper=wrapper)
return ret
cdef public object wrap_kj_exception_for_reraise(capnp.Exception & exception):
cdef public object wrap_kj_exception_for_reraise(capnp.Exception & exception) with gil:
wrapper = _KjExceptionWrapper()._init(exception)
wrapper_msg = str(wrapper)
@ -265,13 +266,12 @@ cdef public object wrap_kj_exception_for_reraise(capnp.Exception & exception):
ret = KjException(wrapper=wrapper)
return ret
cdef public object get_exception_info(object exc_type, object exc_obj, object exc_tb):
cdef public object get_exception_info(object exc_type, object exc_obj, object exc_tb) with gil:
try:
return (exc_tb.tb_frame.f_code.co_filename.encode(), exc_tb.tb_lineno, (repr(exc_type) + ':' + str(exc_obj)).encode())
except:
return (b'', 0, b"Couldn't determine python exception")
ctypedef fused _DynamicStructReaderOrBuilder:
_DynamicStructReader
_DynamicStructBuilder
@ -1341,6 +1341,7 @@ cdef class _EventLoop:
cdef _EventLoop C_DEFAULT_EVENT_LOOP = _EventLoop()
_C_DEFAULT_EVENT_LOOP_LOCAL = None
_THREAD_LOCAL_EVENT_LOOPS = []
cdef _EventLoop C_DEFAULT_EVENT_LOOP_GETTER():
'Optimization for not having to deal with threadlocal event loops unless we need to'
@ -1369,11 +1370,28 @@ cdef class Timer:
def getTimer():
return Timer()._init(helpers.getTimer(C_DEFAULT_EVENT_LOOP_GETTER().thisptr))
cpdef remove_event_loop():
cpdef remove_event_loop(ignore_errors=False):
'Remove the global event loop'
global C_DEFAULT_EVENT_LOOP
global _THREAD_LOCAL_EVENT_LOOPS
global _C_DEFAULT_EVENT_LOOP_LOCAL
if C_DEFAULT_EVENT_LOOP:
try:
C_DEFAULT_EVENT_LOOP._remove()
except:
if not ignore_errors:
raise
C_DEFAULT_EVENT_LOOP = None
if len(_THREAD_LOCAL_EVENT_LOOPS) > 0:
for loop in _THREAD_LOCAL_EVENT_LOOPS:
try:
loop._remove()
except:
if not ignore_errors:
raise
_THREAD_LOCAL_EVENT_LOOPS = []
_C_DEFAULT_EVENT_LOOP_LOCAL = None
cpdef create_event_loop(threaded=True):
'''Create a new global event loop. This will not remove the previous
@ -1383,7 +1401,9 @@ cpdef create_event_loop(threaded=True):
if threaded:
if _C_DEFAULT_EVENT_LOOP_LOCAL is None:
_C_DEFAULT_EVENT_LOOP_LOCAL = _threading.local()
_C_DEFAULT_EVENT_LOOP_LOCAL.loop = _EventLoop()
loop = _EventLoop()
_C_DEFAULT_EVENT_LOOP_LOCAL.loop = loop
_THREAD_LOCAL_EVENT_LOOPS.append(loop)
else:
C_DEFAULT_EVENT_LOOP = _EventLoop()
@ -1458,7 +1478,7 @@ cdef class Promise:
if self.is_consumed:
raise ValueError('Promise was already used in a consuming operation. You can no longer use this Promise object')
ret = <object>self.thisptr.wait(deref(self._event_loop.thisptr).waitScope)
ret = <object>helpers.waitPyPromise(self.thisptr, deref(self._event_loop.thisptr).waitScope)
Py_DECREF(ret)
self.is_consumed = True
@ -1524,7 +1544,8 @@ cdef class _VoidPromise:
if self.is_consumed:
raise ValueError('Promise was already used in a consuming operation. You can no longer use this Promise object')
self.thisptr.wait(deref(self._event_loop.thisptr).waitScope)
helpers.waitVoidPromise(self.thisptr, deref(self._event_loop.thisptr).waitScope)
self.is_consumed = True
cpdef then(self, func, error_func=None) except +reraise_kj_exception:
@ -1590,7 +1611,7 @@ cdef class _RemotePromise:
if self.is_consumed:
raise ValueError('Promise was already used in a consuming operation. You can no longer use this Promise object')
ret = _Response()._init_child(self.thisptr.wait(deref(self._event_loop.thisptr).waitScope), self._parent)
ret = _Response()._init_childptr(helpers.waitRemote(self.thisptr, deref(self._event_loop.thisptr).waitScope), self._parent)
self.is_consumed = True
return ret

66
test/test_threads.py Normal file
View file

@ -0,0 +1,66 @@
import capnp
import pytest
import test_capability_capnp
import socket
import threading
import platform
def test_making_event_loop():
capnp.remove_event_loop(True)
capnp.create_event_loop()
capnp.remove_event_loop()
capnp.create_event_loop()
def test_making_threaded_event_loop():
capnp.remove_event_loop(True)
capnp.create_event_loop(True)
capnp.remove_event_loop()
capnp.create_event_loop(True)
class Server(test_capability_capnp.TestInterface.Server):
def __init__(self, val=1):
self.val = val
def foo(self, i, j, **kwargs):
return str(i * 5 + self.val)
class SimpleRestorer(test_capability_capnp.TestSturdyRefObjectId.Restorer):
def restore(self, ref_id):
assert ref_id.tag == 'testInterface'
return Server(100)
@pytest.mark.skipif(platform.python_implementation() == 'PyPy', reason="pycapnp's GIL handling isn't working properly at the moment for PyPy")
def test_using_threads():
capnp.remove_event_loop(True)
capnp.create_event_loop(True)
read, write = socket.socketpair(socket.AF_UNIX)
def run_server():
restorer = SimpleRestorer()
server = capnp.TwoPartyServer(write, restorer)
capnp.wait_forever()
server_thread = threading.Thread(target=run_server)
server_thread.daemon = True
server_thread.start()
client = capnp.TwoPartyClient(read)
ref = test_capability_capnp.TestSturdyRefObjectId.new_message(tag='testInterface')
cap = client.restore(ref)
cap = cap.cast_as(test_capability_capnp.TestInterface)
remote = cap.foo(i=5)
response = remote.wait()
assert response.x == '125'