diff --git a/capnp/lib/capnp.pyx b/capnp/lib/capnp.pyx index 014b1eb..2e54375 100644 --- a/capnp/lib/capnp.pyx +++ b/capnp/lib/capnp.pyx @@ -1859,7 +1859,7 @@ def _asyncio_close_patch(loop, oldclose, _EventLoop kjloop): # references to it are gone. Then, if a new asyncio loop ever gets started, a new kj-loop can also be # started. _C_DEFAULT_EVENT_LOOP_LOCAL.loop = _weakref.ref(kjloop) - loop.close = oldclose() + loop.close = oldclose return oldclose() cdef class _EventLoop: @@ -1875,20 +1875,12 @@ cdef class _EventLoop: self._init() cdef _init(self) except +reraise_kj_exception: - try: - loop = asyncio.get_running_loop() - self.customPort = new AsyncIoEventPort(loop) - kjLoop = self.customPort.getKjLoop() - self.waitScope = new WaitScope(deref(kjLoop)) - loop.close = _partial(_asyncio_close_patch, loop, loop.close, self) - self.in_asyncio_mode = True - except RuntimeError: - ptr = new capnp.AsyncIoContext(capnp.setupAsyncIo()) - self.lowLevelProvider = move(ptr.lowLevelProvider) - self.provider = move(ptr.provider) - self.waitScope = &ptr.waitScope - del ptr - self.in_asyncio_mode = False + loop = asyncio.get_running_loop() + self.customPort = new AsyncIoEventPort(loop) + kjLoop = self.customPort.getKjLoop() + self.waitScope = new WaitScope(deref(kjLoop)) + loop.close = _partial(_asyncio_close_patch, loop, loop.close, self) + self.in_asyncio_mode = True def __dealloc__(self): if not self.customPort == NULL: @@ -1931,16 +1923,6 @@ cdef _EventLoop C_DEFAULT_EVENT_LOOP_GETTER(): return _C_DEFAULT_EVENT_LOOP_LOCAL.loop -cpdef remove_event_loop(): - '''Remove the event loop''' - global _C_DEFAULT_EVENT_LOOP_LOCAL - - loop = getattr(_C_DEFAULT_EVENT_LOOP_LOCAL, 'loop', None) - if loop is not None: - loop._remove() - del _C_DEFAULT_EVENT_LOOP_LOCAL.loop - - def wait_forever(): """ Use libcapnp event loop to poll/wait forever @@ -2031,6 +2013,7 @@ cdef class _Promise: cdef Own[PyPromise] thisptr def __init__(self, obj=None): + C_DEFAULT_EVENT_LOOP_GETTER() if obj is not None: self.thisptr = capnp.heap[PyPromise](capnp.heap[PyRefCounter](obj)) @@ -2071,6 +2054,7 @@ cdef class _VoidPromise: cdef _init(self, VoidPromise other): + C_DEFAULT_EVENT_LOOP_GETTER() self.thisptr = capnp.heap[VoidPromise](moveVoidPromise(other)) return self @@ -2742,7 +2726,7 @@ cdef class _PyAsyncIoStreamProtocol(DummyBaseClass, asyncio.BufferedProtocol): cdef cbool read_eof # TODO: Temporary. This is an overflow buffer, which is needed for two blatant violations of the protocol. - # The first violation is int the SSL transport implementation. + # The first violation is in the the SSL transport implementation. # See https://github.com/python/cpython/issues/89322, fixed in Python 3.11. This bug causes the # SSL transport to force data upon us even when we've asked it to pause sending us data. Therefore, # we have to store the data in a overflow buffer. diff --git a/examples/async_calculator_client.py b/examples/async_calculator_client.py index 41ba8f8..5db3ac2 100755 --- a/examples/async_calculator_client.py +++ b/examples/async_calculator_client.py @@ -33,9 +33,7 @@ at the given address and does some RPCs" return parser.parse_args() -async def main(host): - host, port = parse_args().host.split(":") - connection = await capnp.AsyncIoStream.create_connection(host=host, port=port) +async def main(connection): client = capnp.TwoPartyClient(connection) # Bootstrap the Calculator interface @@ -302,6 +300,9 @@ async def main(host): print("PASS") +async def cmd_main(host): + host, port = host.split(":") + await main(await capnp.AsyncIoStream.create_connection(host=host, port=port)) if __name__ == "__main__": - asyncio.run(main(parse_args().host)) + asyncio.run(cmd_main(parse_args().host)) diff --git a/examples/calculator_client.py b/examples/calculator_client.py deleted file mode 100755 index 9cb9fa4..0000000 --- a/examples/calculator_client.py +++ /dev/null @@ -1,304 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import capnp - -import calculator_capnp - - -class PowerFunction(calculator_capnp.Calculator.Function.Server): - - """An implementation of the Function interface wrapping pow(). Note that - we're implementing this on the client side and will pass a reference to - the server. The server will then be able to make calls back to the client.""" - - def call(self, params, **kwargs): - """Note the **kwargs. This is very necessary to include, since - protocols can add parameters over time. Also, by default, a _context - variable is passed to all server methods, but you can also return - results directly as python objects, and they'll be added to the - results struct in the correct order""" - - return pow(params[0], params[1]) - - -def parse_args(): - parser = argparse.ArgumentParser( - usage="Connects to the Calculator server \ -at the given address and does some RPCs" - ) - parser.add_argument("host", help="HOST:PORT") - - return parser.parse_args() - - -def main(host): - client = capnp.TwoPartyClient(host) - - # Bootstrap the server capability and cast it to the Calculator interface - calculator = client.bootstrap().cast_as(calculator_capnp.Calculator) - - """Make a request that just evaluates the literal value 123. - - What's interesting here is that evaluate() returns a "Value", which is - another interface and therefore points back to an object living on the - server. We then have to call read() on that object to read it. - However, even though we are making two RPC's, this block executes in - *one* network round trip because of promise pipelining: we do not wait - for the first call to complete before we send the second call to the - server.""" - - print("Evaluating a literal... ", end="") - - # Make the request. Note we are using the shorter function form (instead - # of evaluate_request), and we are passing a dictionary that represents a - # struct and its member to evaluate - eval_promise = calculator.evaluate({"literal": 123}) - - # This is equivalent to: - """ - request = calculator.evaluate_request() - request.expression.literal = 123 - - # Send it, which returns a promise for the result (without blocking). - eval_promise = request.send() - """ - - # Using the promise, create a pipelined request to call read() on the - # returned object. Note that here we are using the shortened method call - # syntax read(), which is mostly just sugar for read_request().send() - read_promise = eval_promise.value.read() - - # Now that we've sent all the requests, wait for the response. Until this - # point, we haven't waited at all! - response = read_promise.wait() - assert response.value == 123 - - print("PASS") - - """Make a request to evaluate 123 + 45 - 67. - - The Calculator interface requires that we first call getOperator() to - get the addition and subtraction functions, then call evaluate() to use - them. But, once again, we can get both functions, call evaluate(), and - then read() the result -- four RPCs -- in the time of *one* network - round trip, because of promise pipelining.""" - - print("Using add and subtract... ", end="") - - # Get the "add" function from the server. - add = calculator.getOperator(op="add").func - # Get the "subtract" function from the server. - subtract = calculator.getOperator(op="subtract").func - - # Build the request to evaluate 123 + 45 - 67. Note the form is 'evaluate' - # + '_request', where 'evaluate' is the name of the method we want to call - request = calculator.evaluate_request() - subtract_call = request.expression.init("call") - subtract_call.function = subtract - subtract_params = subtract_call.init("params", 2) - subtract_params[1].literal = 67.0 - - add_call = subtract_params[0].init("call") - add_call.function = add - add_params = add_call.init("params", 2) - add_params[0].literal = 123 - add_params[1].literal = 45 - - # Send the evaluate() request, read() the result, and wait for read() to finish. - eval_promise = request.send() - read_promise = eval_promise.value.read() - - response = read_promise.wait() - assert response.value == 101 - - print("PASS") - - """ - Note: a one liner version of building the previous request (I highly - recommend not doing it this way for such a complicated structure, but I - just wanted to demonstrate it is possible to set all of the fields with a - dictionary): - - eval_promise = calculator.evaluate( -{'call': {'function': subtract, - 'params': [{'call': {'function': add, - 'params': [{'literal': 123}, - {'literal': 45}]}}, - {'literal': 67.0}]}}) - """ - - """Make a request to evaluate 4 * 6, then use the result in two more - requests that add 3 and 5. - - Since evaluate() returns its result wrapped in a `Value`, we can pass - that `Value` back to the server in subsequent requests before the first - `evaluate()` has actually returned. Thus, this example again does only - one network round trip.""" - - print("Pipelining eval() calls... ", end="") - - # Get the "add" function from the server. - add = calculator.getOperator(op="add").func - # Get the "multiply" function from the server. - multiply = calculator.getOperator(op="multiply").func - - # Build the request to evaluate 4 * 6 - request = calculator.evaluate_request() - - multiply_call = request.expression.init("call") - multiply_call.function = multiply - multiply_params = multiply_call.init("params", 2) - multiply_params[0].literal = 4 - multiply_params[1].literal = 6 - - multiply_result = request.send().value - - # Use the result in two calls that add 3 and add 5. - - add_3_request = calculator.evaluate_request() - add_3_call = add_3_request.expression.init("call") - add_3_call.function = add - add_3_params = add_3_call.init("params", 2) - add_3_params[0].previousResult = multiply_result - add_3_params[1].literal = 3 - add_3_promise = add_3_request.send().value.read() - - add_5_request = calculator.evaluate_request() - add_5_call = add_5_request.expression.init("call") - add_5_call.function = add - add_5_params = add_5_call.init("params", 2) - add_5_params[0].previousResult = multiply_result - add_5_params[1].literal = 5 - add_5_promise = add_5_request.send().value.read() - - # Now wait for the results. - assert add_3_promise.wait().value == 27 - assert add_5_promise.wait().value == 29 - - print("PASS") - - """Our calculator interface supports defining functions. Here we use it - to define two functions and then make calls to them as follows: - - f(x, y) = x * 100 + y - g(x) = f(x, x + 1) * 2; - f(12, 34) - g(21) - - Once again, the whole thing takes only one network round trip.""" - - print("Defining functions... ", end="") - - # Get the "add" function from the server. - add = calculator.getOperator(op="add").func - # Get the "multiply" function from the server. - multiply = calculator.getOperator(op="multiply").func - - # Define f. - request = calculator.defFunction_request() - request.paramCount = 2 - - # Build the function body. - add_call = request.body.init("call") - add_call.function = add - add_params = add_call.init("params", 2) - add_params[1].parameter = 1 # y - - multiply_call = add_params[0].init("call") - multiply_call.function = multiply - multiply_params = multiply_call.init("params", 2) - multiply_params[0].parameter = 0 # x - multiply_params[1].literal = 100 - - f = request.send().func - - # Define g. - request = calculator.defFunction_request() - request.paramCount = 1 - - # Build the function body. - multiply_call = request.body.init("call") - multiply_call.function = multiply - multiply_params = multiply_call.init("params", 2) - multiply_params[1].literal = 2 - - f_call = multiply_params[0].init("call") - f_call.function = f - f_params = f_call.init("params", 2) - f_params[0].parameter = 0 - - add_call = f_params[1].init("call") - add_call.function = add - add_params = add_call.init("params", 2) - add_params[0].parameter = 0 - add_params[1].literal = 1 - - g = request.send().func - - # OK, we've defined all our functions. Now create our eval requests. - - # f(12, 34) - f_eval_request = calculator.evaluate_request() - f_call = f_eval_request.expression.init("call") - f_call.function = f - f_params = f_call.init("params", 2) - f_params[0].literal = 12 - f_params[1].literal = 34 - f_eval_promise = f_eval_request.send().value.read() - - # g(21) - g_eval_request = calculator.evaluate_request() - g_call = g_eval_request.expression.init("call") - g_call.function = g - g_call.init("params", 1)[0].literal = 21 - g_eval_promise = g_eval_request.send().value.read() - - # Wait for the results. - assert f_eval_promise.wait().value == 1234 - assert g_eval_promise.wait().value == 4244 - - print("PASS") - - """Make a request that will call back to a function defined locally. - - Specifically, we will compute 2^(4 + 5). However, exponent is not - defined by the Calculator server. So, we'll implement the Function - interface locally and pass it to the server for it to use when - evaluating the expression. - - This example requires two network round trips to complete, because the - server calls back to the client once before finishing. In this - particular case, this could potentially be optimized by using a tail - call on the server side -- see CallContext::tailCall(). However, to - keep the example simpler, we haven't implemented this optimization in - the sample server.""" - - print("Using a callback... ", end="") - - # Get the "add" function from the server. - add = calculator.getOperator(op="add").func - - # Build the eval request for 2^(4+5). - request = calculator.evaluate_request() - - pow_call = request.expression.init("call") - pow_call.function = PowerFunction() - pow_params = pow_call.init("params", 2) - pow_params[0].literal = 2 - - add_call = pow_params[1].init("call") - add_call.function = add - add_params = add_call.init("params", 2) - add_params[0].literal = 4 - add_params[1].literal = 5 - - # Send the request and wait. - response = request.send().value.read().wait() - assert response.value == 512 - - print("PASS") - - -if __name__ == "__main__": - main(parse_args().host) diff --git a/examples/calculator_server.py b/examples/calculator_server.py deleted file mode 100755 index 8464439..0000000 --- a/examples/calculator_server.py +++ /dev/null @@ -1,145 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import capnp -import time - -import calculator_capnp - - -def read_value(value): - """Helper function to asynchronously call read() on a Calculator::Value and - return a promise for the result. (In the future, the generated code might - include something like this automatically.)""" - - return value.read().then(lambda result: result.value) - - -def evaluate_impl(expression, params=None): - """Implementation of CalculatorImpl::evaluate(), also shared by - FunctionImpl::call(). In the latter case, `params` are the parameter - values passed to the function; in the former case, `params` is just an - empty list.""" - - which = expression.which() - - if which == "literal": - return capnp.Promise(expression.literal) - elif which == "previousResult": - return read_value(expression.previousResult) - elif which == "parameter": - assert expression.parameter < len(params) - return capnp.Promise(params[expression.parameter]) - elif which == "call": - call = expression.call - func = call.function - - # Evaluate each parameter. - paramPromises = [evaluate_impl(param, params) for param in call.params] - - joinedParams = capnp.join_promises(paramPromises) - # When the parameters are complete, call the function. - ret = joinedParams.then(lambda vals: func.call(vals)).then( - lambda result: result.value - ) - - return ret - else: - raise ValueError("Unknown expression type: " + which) - - -class ValueImpl(calculator_capnp.Calculator.Value.Server): - "Simple implementation of the Calculator.Value Cap'n Proto interface." - - def __init__(self, value): - self.value = value - - def read(self, **kwargs): - return self.value - - -class FunctionImpl(calculator_capnp.Calculator.Function.Server): - - """Implementation of the Calculator.Function Cap'n Proto interface, where the - function is defined by a Calculator.Expression.""" - - def __init__(self, paramCount, body): - self.paramCount = paramCount - self.body = body.as_builder() - - def call(self, params, _context, **kwargs): - """Note that we're returning a Promise object here, and bypassing the - helper functionality that normally sets the results struct from the - returned object. Instead, we set _context.results directly inside of - another promise""" - - assert len(params) == self.paramCount - # using setattr because '=' is not allowed inside of lambdas - return evaluate_impl(self.body, params).then( - lambda value: setattr(_context.results, "value", value) - ) - - -class OperatorImpl(calculator_capnp.Calculator.Function.Server): - - """Implementation of the Calculator.Function Cap'n Proto interface, wrapping - basic binary arithmetic operators.""" - - def __init__(self, op): - self.op = op - - def call(self, params, **kwargs): - assert len(params) == 2 - - op = self.op - - if op == "add": - return params[0] + params[1] - elif op == "subtract": - return params[0] - params[1] - elif op == "multiply": - return params[0] * params[1] - elif op == "divide": - return params[0] / params[1] - else: - raise ValueError("Unknown operator") - - -class CalculatorImpl(calculator_capnp.Calculator.Server): - "Implementation of the Calculator Cap'n Proto interface." - - def evaluate(self, expression, _context, **kwargs): - return evaluate_impl(expression).then( - lambda value: setattr(_context.results, "value", ValueImpl(value)) - ) - - def defFunction(self, paramCount, body, _context, **kwargs): - return FunctionImpl(paramCount, body) - - def getOperator(self, op, **kwargs): - return OperatorImpl(op) - - -def parse_args(): - parser = argparse.ArgumentParser( - usage="""Runs the server bound to the\ -given address/port ADDRESS may be '*' to bind to all local addresses.\ -:PORT may be omitted to choose a port automatically. """ - ) - - parser.add_argument("address", help="ADDRESS[:PORT]") - - return parser.parse_args() - - -def main(): - address = parse_args().address - - server = capnp.TwoPartyServer(address, bootstrap=CalculatorImpl()) - while True: - server.poll_once() - time.sleep(0.001) - - -if __name__ == "__main__": - main() diff --git a/examples/thread_client.py b/examples/thread_client.py deleted file mode 100755 index 317bc62..0000000 --- a/examples/thread_client.py +++ /dev/null @@ -1,56 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import threading -import time -import capnp - -import thread_capnp - - -def parse_args(): - parser = argparse.ArgumentParser( - usage="Connects to the Example thread server \ -at the given address and does some RPCs" - ) - parser.add_argument("host", help="HOST:PORT") - - return parser.parse_args() - - -class StatusSubscriber(thread_capnp.Example.StatusSubscriber.Server): - - """An implementation of the StatusSubscriber interface""" - - def status(self, value, **kwargs): - print("status: {}".format(time.time())) - - -def start_status_thread(host): - client = capnp.TwoPartyClient(host) - cap = client.bootstrap().cast_as(thread_capnp.Example) - - subscriber = StatusSubscriber() - promise = cap.subscribeStatus(subscriber) - promise.wait() - - -def main(host): - client = capnp.TwoPartyClient(host) - cap = client.bootstrap().cast_as(thread_capnp.Example) - - status_thread = threading.Thread(target=start_status_thread, args=(host,)) - status_thread.daemon = True - status_thread.start() - - print("main: {}".format(time.time())) - cap.longRunning().wait() - print("main: {}".format(time.time())) - cap.longRunning().wait() - print("main: {}".format(time.time())) - cap.longRunning().wait() - print("main: {}".format(time.time())) - - -if __name__ == "__main__": - main(parse_args().host) diff --git a/examples/thread_server.py b/examples/thread_server.py deleted file mode 100755 index 79b0dea..0000000 --- a/examples/thread_server.py +++ /dev/null @@ -1,42 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import capnp - -import thread_capnp - - -class ExampleImpl(thread_capnp.Example.Server): - "Implementation of the Example threading Cap'n Proto interface." - - def subscribeStatus(self, subscriber, **kwargs): - return ( - subscriber.status(True) - .then(lambda _: self.subscribeStatus(subscriber)) - ) - - def longRunning(self, **kwargs): - return - - -def parse_args(): - parser = argparse.ArgumentParser( - usage="""Runs the server bound to the\ -given address/port ADDRESS may be '*' to bind to all local addresses.\ -:PORT may be omitted to choose a port automatically. """ - ) - - parser.add_argument("address", help="ADDRESS[:PORT]") - - return parser.parse_args() - - -def main(): - address = parse_args().address - - server = capnp.TwoPartyServer(address, bootstrap=ExampleImpl()) - server.run_forever() - - -if __name__ == "__main__": - main() diff --git a/pyproject.toml b/pyproject.toml index 801edb2..fd9aa9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,5 @@ [build-system] requires = ["setuptools", "wheel", "pkgconfig", "cython"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" \ No newline at end of file diff --git a/test/test_capability.py b/test/test_capability.py index 19229f6..f52211b 100644 --- a/test/test_capability.py +++ b/test/test_capability.py @@ -1,5 +1,4 @@ import pytest -import time import capnp import test_capability_capnp as capability @@ -32,7 +31,7 @@ class PipelineServer(capability.TestPipeline.Server): return inCap.foo(i=n).then(_then) -def test_client(): +async def test_client(): client = capability.TestInterface._new_client(Server()) req = client._request("foo") @@ -65,7 +64,7 @@ def test_client(): req.baz = 1 -def test_simple_client(): +async def test_simple_client(): client = capability.TestInterface._new_client(Server()) remote = client._send("foo", i=5) @@ -125,7 +124,7 @@ def test_simple_client(): remote = client.foo(baz=5) -def test_pipeline(): +async def test_pipeline(): client = capability.TestPipeline._new_client(PipelineServer()) foo_client = capability.TestInterface._new_client(Server()) @@ -152,7 +151,7 @@ class BadServer(capability.TestInterface.Server): return str(i * 5 + extra + self.val), 10 # returning too many args -def test_exception_client(): +async def test_exception_client(): client = capability.TestInterface._new_client(BadServer()) remote = client._send("foo", i=5) @@ -173,7 +172,7 @@ class BadPipelineServer(capability.TestPipeline.Server): return inCap.foo(i=n).then(_then, _error) -def test_exception_chain(): +async def test_exception_chain(): client = capability.TestPipeline._new_client(BadPipelineServer()) foo_client = capability.TestInterface._new_client(BadServer()) @@ -185,7 +184,7 @@ def test_exception_chain(): assert "test was a success" in str(e) -def test_pipeline_exception(): +async def test_pipeline_exception(): client = capability.TestPipeline._new_client(BadPipelineServer()) foo_client = capability.TestInterface._new_client(BadServer()) @@ -201,7 +200,7 @@ def test_pipeline_exception(): remote.wait() -def test_casting(): +async def test_casting(): client = capability.TestExtends._new_client(Server()) client2 = client.upcast(capability.TestInterface) _ = client2.cast_as(capability.TestInterface) @@ -243,7 +242,7 @@ class TailCallee(capability.TestTailCallee.Server): results.c = TailCallOrder() -def test_tail_call(): +async def test_tail_call(): callee_server = TailCallee() caller_server = TailCaller() @@ -272,7 +271,7 @@ def test_tail_call(): assert caller_server.count == 1 -def test_cancel(): +async def test_cancel(): client = capability.TestInterface._new_client(Server()) req = client._request("foo") @@ -300,7 +299,7 @@ def test_cancel(): req.wait() -def test_double_send(): +async def test_double_send(): client = capability.TestInterface._new_client(Server()) req = client._request("foo") @@ -311,7 +310,7 @@ def test_double_send(): req.send() -def test_then_args(): +async def test_then_args(): capnp.Promise(0).then(lambda x: 1) with pytest.raises(Exception): @@ -350,7 +349,7 @@ class PromiseJoinServer(capability.TestPipeline.Server): ) -def test_promise_joining(): +async def test_promise_joining(): client = capability.TestPipeline._new_client(PromiseJoinServer()) foo_client = capability.TestInterface._new_client(Server()) @@ -363,7 +362,7 @@ class ExtendsServer(Server): pass -def test_inheritance(): +async def test_inheritance(): client = capability.TestExtends._new_client(ExtendsServer()) client.qux().wait() @@ -381,7 +380,7 @@ class PassedCapTest(capability.TestPassedCap.Server): return cap.foo(5).then(set_result) -def test_null_cap(): +async def test_null_cap(): client = capability.TestPassedCap._new_client(PassedCapTest()) assert client.foo(Server()).wait().x == "26" @@ -394,7 +393,7 @@ class StructArgTest(capability.TestStructArg.Server): return a + str(b) -def test_struct_args(): +async def test_struct_args(): client = capability.TestStructArg._new_client(StructArgTest()) assert client.bar(a="test", b=1).wait().c == "test1" with pytest.raises(capnp.KjException): @@ -406,7 +405,7 @@ class GenericTest(capability.TestGeneric.Server): return a.as_text() + "test" -def test_generic(): +async def test_generic(): client = capability.TestGeneric._new_client(GenericTest()) obj = capnp._MallocMessageBuilder().get_root_as_any() diff --git a/test/test_capability_context.py b/test/test_capability_context.py index e669191..6d47631 100644 --- a/test/test_capability_context.py +++ b/test/test_capability_context.py @@ -39,7 +39,7 @@ class PipelineServer: return context.params.inCap.foo(i=context.params.n).then(_then) -def test_client_context(capability): +async def test_client_context(capability): client = capability.TestInterface._new_client(Server()) req = client._request("foo") @@ -72,7 +72,7 @@ def test_client_context(capability): req.baz = 1 -def test_simple_client_context(capability): +async def test_simple_client_context(capability): client = capability.TestInterface._new_client(Server()) remote = client._send("foo", i=5) @@ -159,7 +159,7 @@ class BadServer: context.results.x2 = 5 # raises exception -def test_exception_client_context(capability): +async def test_exception_client_context(capability): client = capability.TestInterface._new_client(BadServer()) remote = client._send("foo", i=5) @@ -181,7 +181,7 @@ class BadPipelineServer: return context.params.inCap.foo(i=context.params.n).then(_then, _error) -def test_exception_chain_context(capability): +async def test_exception_chain_context(capability): client = capability.TestPipeline._new_client(BadPipelineServer()) foo_client = capability.TestInterface._new_client(BadServer()) @@ -193,7 +193,7 @@ def test_exception_chain_context(capability): assert "test was a success" in str(e) -def test_pipeline_exception_context(capability): +async def test_pipeline_exception_context(capability): client = capability.TestPipeline._new_client(BadPipelineServer()) foo_client = capability.TestInterface._new_client(BadServer()) @@ -209,7 +209,7 @@ def test_pipeline_exception_context(capability): remote.wait() -def test_casting_context(capability): +async def test_casting_context(capability): client = capability.TestExtends._new_client(Server()) client2 = client.upcast(capability.TestInterface) _ = client2.cast_as(capability.TestInterface) diff --git a/test/test_capability_old.py b/test/test_capability_old.py index c99e493..2d1c4a0 100644 --- a/test/test_capability_old.py +++ b/test/test_capability_old.py @@ -37,7 +37,7 @@ class PipelineServer: return inCap.foo(i=n).then(_then) -def test_client(capability): +async def test_client(capability): client = capability.TestInterface._new_client(Server()) req = client._request("foo") @@ -70,7 +70,7 @@ def test_client(capability): req.baz = 1 -def test_simple_client(capability): +async def test_simple_client(capability): client = capability.TestInterface._new_client(Server()) remote = client._send("foo", i=5) @@ -159,7 +159,7 @@ class BadServer: return str(i * 5 + extra + self.val), 10 # returning too many args -def test_exception_client(capability): +async def test_exception_client(capability): client = capability.TestInterface._new_client(BadServer()) remote = client._send("foo", i=5) @@ -180,7 +180,7 @@ class BadPipelineServer: return inCap.foo(i=n).then(_then, _error) -def test_exception_chain(capability): +async def test_exception_chain(capability): client = capability.TestPipeline._new_client(BadPipelineServer()) foo_client = capability.TestInterface._new_client(BadServer()) @@ -192,7 +192,7 @@ def test_exception_chain(capability): assert "test was a success" in str(e) -def test_pipeline_exception(capability): +async def test_pipeline_exception(capability): client = capability.TestPipeline._new_client(BadPipelineServer()) foo_client = capability.TestInterface._new_client(BadServer()) @@ -208,7 +208,7 @@ def test_pipeline_exception(capability): remote.wait() -def test_casting(capability): +async def test_casting(capability): client = capability.TestExtends._new_client(Server()) client2 = client.upcast(capability.TestInterface) _ = client2.cast_as(capability.TestInterface) diff --git a/test/test_examples.py b/test/test_examples.py index 40d496b..c99afd3 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -124,16 +124,6 @@ def test_async_calculator_example(cleanup): run_subprocesses(address, server, client) -@pytest.mark.xfail( - reason="Some versions of python don't like to share ports, don't worry if this fails" -) -def test_thread_example(cleanup): - address = "{}:36433".format(hostname) - server = "thread_server.py" - client = "thread_client.py" - run_subprocesses(address, server, client, wildcard_server=True) - - def test_addressbook_example(cleanup): proc = subprocess.Popen( [sys.executable, os.path.join(examples_dir, "addressbook.py")] diff --git a/test/test_response.py b/test/test_response.py index 0d49e2a..c277ce1 100644 --- a/test/test_response.py +++ b/test/test_response.py @@ -17,7 +17,7 @@ class BazServer(test_response_capnp.Baz.Server): return {"foo": FooServer()} -def test_response_reference(): +async def test_response_reference(): baz = test_response_capnp.Baz._new_client(BazServer()) bar = baz.grault().wait().bar @@ -27,7 +27,7 @@ def test_response_reference(): assert foo.foo().wait().val == 1 -def test_response_reference2(): +async def test_response_reference2(): baz = test_response_capnp.Baz._new_client(BazServer()) bar = baz.grault().wait().bar diff --git a/test/test_rpc.py b/test/test_rpc.py index fa5daa7..975fa20 100644 --- a/test/test_rpc.py +++ b/test/test_rpc.py @@ -17,8 +17,10 @@ class Server(test_capability_capnp.TestInterface.Server): return str(i * 5 + self.val) -def test_simple_rpc_with_options(): +async def test_simple_rpc_with_options(): read, write = socket.socketpair() + read = await capnp.AsyncIoStream.create_connection(sock = read) + write = await capnp.AsyncIoStream.create_connection(sock = write) _ = capnp.TwoPartyServer(write, bootstrap=Server()) # This traversal limit is too low to receive the response in, so we expect @@ -32,8 +34,10 @@ def test_simple_rpc_with_options(): _ = remote.wait() -def test_simple_rpc_bootstrap(): +async def test_simple_rpc_bootstrap(): read, write = socket.socketpair() + read = await capnp.AsyncIoStream.create_connection(sock = read) + write = await capnp.AsyncIoStream.create_connection(sock = write) _ = capnp.TwoPartyServer(write, bootstrap=Server(100)) client = capnp.TwoPartyClient(read) @@ -42,6 +46,6 @@ def test_simple_rpc_bootstrap(): cap = cap.cast_as(test_capability_capnp.TestInterface) remote = cap.foo(i=5) - response = remote.wait() + response = await remote assert response.x == "125" diff --git a/test/test_rpc_calculator.py b/test/test_rpc_calculator.py index ca5fd16..d101bdf 100644 --- a/test/test_rpc_calculator.py +++ b/test/test_rpc_calculator.py @@ -9,57 +9,20 @@ import capnp examples_dir = os.path.join(os.path.dirname(__file__), "..", "examples") sys.path.append(examples_dir) -import calculator_client # noqa: E402 -import calculator_server # noqa: E402 - -# Uses run_subprocesses function -import test_examples # noqa: E402 - -processes = [] +import async_calculator_client # noqa: E402 +import async_calculator_server # noqa: E402 -@pytest.fixture -def cleanup(): - yield - for p in processes: - p.kill() - - -def test_calculator(): +async def test_calculator(): read, write = socket.socketpair() + read = await capnp.AsyncIoStream.create_connection(sock = read) + write = await capnp.AsyncIoStream.create_connection(sock = write) - _ = capnp.TwoPartyServer(write, bootstrap=calculator_server.CalculatorImpl()) - calculator_client.main(read) + _ = capnp.TwoPartyServer(write, bootstrap=async_calculator_server.CalculatorImpl()) + await async_calculator_client.main(read) -@pytest.mark.xfail( - reason="Some versions of python don't like to share ports, don't worry if this fails" -) -def test_calculator_tcp(cleanup): - address = "localhost:36431" - test_examples.run_subprocesses( - address, "calculator_server.py", "calculator_client.py", wildcard_server=True - ) - - -@pytest.mark.xfail( - reason="Some versions of python don't like to share ports, don't worry if this fails" -) -@pytest.mark.skipif(os.name == "nt", reason="socket.AF_UNIX not supported on Windows") -def test_calculator_unix(cleanup): - path = "/tmp/pycapnp-test" - try: - os.unlink(path) - except OSError: - pass - - address = "unix:" + path - test_examples.run_subprocesses( - address, "calculator_server.py", "calculator_client.py" - ) - - -def test_calculator_gc(): +async def test_calculator_gc(): def new_evaluate_impl(old_evaluate_impl): def call(*args, **kwargs): gc.collect() @@ -68,12 +31,14 @@ def test_calculator_gc(): return call read, write = socket.socketpair() + read = await capnp.AsyncIoStream.create_connection(sock = read) + write = await capnp.AsyncIoStream.create_connection(sock = write) # inject a gc.collect to the beginning of every evaluate_impl call - evaluate_impl_orig = calculator_server.evaluate_impl - calculator_server.evaluate_impl = new_evaluate_impl(evaluate_impl_orig) + evaluate_impl_orig = async_calculator_server.evaluate_impl + async_calculator_server.evaluate_impl = new_evaluate_impl(evaluate_impl_orig) - _ = capnp.TwoPartyServer(write, bootstrap=calculator_server.CalculatorImpl()) - calculator_client.main(read) + _ = capnp.TwoPartyServer(write, bootstrap=async_calculator_server.CalculatorImpl()) + await async_calculator_client.main(read) - calculator_server.evaluate_impl = evaluate_impl_orig + async_calculator_server.evaluate_impl = evaluate_impl_orig diff --git a/test/test_threads.py b/test/test_threads.py deleted file mode 100644 index 4756294..0000000 --- a/test/test_threads.py +++ /dev/null @@ -1,55 +0,0 @@ -""" -thread test -""" - -import platform -import socket -import threading - -import pytest - -import capnp - -import test_capability_capnp - - -class Server(test_capability_capnp.TestInterface.Server): - """ - Server - """ - - def __init__(self, val=100): - self.val = val - - def foo(self, i, j, **kwargs): - """ - foo - """ - return str(i * 5 + self.val) - - -@pytest.mark.skipif( - platform.python_implementation() == "PyPy", - reason="pycapnp's GIL handling isn't working properly at the moment for PyPy", -) -def test_using_threads(): - """ - Thread test - """ - read, write = socket.socketpair() - - def run_server(): - _ = capnp.TwoPartyServer(write, bootstrap=Server()) - capnp.wait_forever() - - server_thread = threading.Thread(target=run_server) - server_thread.daemon = True - server_thread.start() - - client = capnp.TwoPartyClient(read) - cap = client.bootstrap().cast_as(test_capability_capnp.TestInterface) - - remote = cap.foo(i=5) - response = remote.wait() - - assert response.x == "125"