Force server methods to be async and client calls to use await

This commit is contained in:
Lasse Blaauwbroek 2023-06-08 03:56:57 +02:00
parent a69bc72a0b
commit 4b5c4211f1
11 changed files with 193 additions and 310 deletions

View file

@ -36,6 +36,7 @@ import threading as _threading
import traceback as _traceback
import warnings as _warnings
import weakref as _weakref
import traceback as _traceback
from types import ModuleType as _ModuleType
from operator import attrgetter as _attrgetter
@ -84,7 +85,7 @@ def void_task_done_callback(method_name, _VoidPromiseFulfiller fulfiller, task):
exc = task.exception()
if exc is not None:
fulfiller.fulfiller.reject(makeException(capnp.StringPtr(str(exc))))
fulfiller.fulfiller.reject(makeException(capnp.StringPtr(''.join(_traceback.format_exception(exc)))))
return
res = task.result()
@ -123,27 +124,16 @@ cdef api VoidPromise * call_server_method(object server,
func = getattr(server, method_name+'_context', None)
if func is not None:
ret = func(context)
if ret is not None:
if type(ret) is _VoidPromise:
return new VoidPromise(moveVoidPromise(deref((<_VoidPromise>ret).thisptr)))
elif type(ret) is _Promise:
return new VoidPromise(helpers.convert_to_voidpromise(move((<_Promise>ret).thisptr)))
elif asyncio.iscoroutine(ret):
task = asyncio.create_task(ret)
callback = _partial(void_task_done_callback, method_name)
return new VoidPromise(helpers.taskToPromise(
capnp.heap[PyRefCounter](<PyObject*>task),
<PyObject*>callback))
else:
try:
warning_msg = (
"Server function ({}) returned a value that was not a Promise: return = {}"
.format(method_name, str(ret)))
except Exception:
warning_msg = 'Server function (%s) returned a value that was not a Promise' % (method_name)
_warnings.warn_explicit(
warning_msg, UserWarning, _inspect.getsourcefile(func), _inspect.getsourcelines(func)[1])
if asyncio.iscoroutine(ret):
task = asyncio.create_task(ret)
callback = _partial(void_task_done_callback, method_name)
return new VoidPromise(helpers.taskToPromise(
capnp.heap[PyRefCounter](<PyObject*>task),
<PyObject*>callback))
else:
raise ValueError(
"Server function ({}) is not a coroutine"
.format(method_name, str(ret)))
else:
func = getattr(server, method_name) # will raise if no function found
params = context.params
@ -151,21 +141,18 @@ cdef api VoidPromise * call_server_method(object server,
params_dict['_context'] = context
ret = func(**params_dict)
if ret is not None:
if type(ret) is _VoidPromise:
return new VoidPromise(moveVoidPromise(deref((<_VoidPromise>ret).thisptr)))
elif type(ret) is _Promise:
return new VoidPromise(helpers.convert_to_voidpromise(move((<_Promise>ret).thisptr)))
elif asyncio.iscoroutine(ret):
async def finalize():
fill_context(method_name, context, await ret)
task = asyncio.create_task(finalize())
callback = _partial(void_task_done_callback, method_name)
return new VoidPromise(helpers.taskToPromise(
capnp.heap[PyRefCounter](<PyObject*>task),
<PyObject*>callback))
else:
fill_context(method_name, context, ret)
if asyncio.iscoroutine(ret):
async def finalize():
fill_context(method_name, context, await ret)
task = asyncio.create_task(finalize())
callback = _partial(void_task_done_callback, method_name)
return new VoidPromise(helpers.taskToPromise(
capnp.heap[PyRefCounter](<PyObject*>task),
<PyObject*>callback))
else:
raise ValueError(
"Server function ({}) is not a coroutine"
.format(method_name, str(ret)))
return NULL
@ -1970,9 +1957,11 @@ cdef _promise_to_asyncio(PromiseTypes promise):
fut = asyncio.get_running_loop().create_future()
# Attach the promise to the future, so that it doesn't get destroyed
fut.kjpromise = promise.then(
fut.kjpromise = _promise_then(
promise,
lambda res: fut.set_result(res) if not fut.cancelled() else None,
lambda err: fut.set_exception(err) if not fut.cancelled() else None)
lambda err: fut.set_exception(err) if not fut.cancelled() else None,
1)
del promise
fut.add_done_callback(
lambda fut: fut.kjpromise.cancel() if fut.cancelled() else None)
@ -1990,15 +1979,6 @@ cdef class _Promise:
self.thisptr = capnp.heap[PyPromise](movePromise(other))
return self
cpdef wait(self) except +reraise_kj_exception:
_promise_check_consumed(self)
cdef Own[PyPromise] prom = move(self.thisptr) # Explicit move to not leave thisptr dangling
cdef Own[PyRefCounter] ret
cdef _EventLoop loop = C_DEFAULT_EVENT_LOOP_GETTER()
with nogil:
ret = move(prom.get().wait(deref(loop.waitScope)))
return <object>ret.get().obj
async def a_wait(self):
"""
Asyncio version of wait().
@ -2011,9 +1991,6 @@ cdef class _Promise:
def __await__(self):
return _promise_to_asyncio(self).__await__()
cpdef then(self, func, error_func=None) except +reraise_kj_exception:
return _promise_then(self, func, error_func, 1)
cpdef cancel(self) except +reraise_kj_exception:
self.thisptr = Own[PyPromise]()
@ -2027,13 +2004,6 @@ cdef class _VoidPromise:
self.thisptr = capnp.heap[VoidPromise](moveVoidPromise(other))
return self
cpdef wait(self) except +reraise_kj_exception:
_promise_check_consumed(self)
cdef Own[VoidPromise] prom = move(self.thisptr) # Explicit move to not leave thisptr dangling
cdef _EventLoop loop = C_DEFAULT_EVENT_LOOP_GETTER()
with nogil:
prom.get().wait(deref(loop.waitScope))
async def a_wait(self):
"""
Asyncio version of wait().
@ -2051,9 +2021,6 @@ cdef class _VoidPromise:
_promise_check_consumed(self)
return _Promise()._init(helpers.convert_to_pypromise(move(self.thisptr)))
cpdef then(self, func, error_func=None) except +reraise_kj_exception:
return _promise_then(self, func, error_func, 0)
cpdef cancel(self) except +reraise_kj_exception:
self.thisptr = Own[VoidPromise]()
@ -2073,14 +2040,6 @@ cdef class _RemotePromise:
self._parent = parent
return self
cpdef wait(self) except +reraise_kj_exception:
"""Wait on the promise. This will block until the promise has completed."""
_promise_check_consumed(self)
cdef _EventLoop loop = C_DEFAULT_EVENT_LOOP_GETTER()
with nogil:
response = helpers.waitRemote(move(self.thisptr), deref(loop.waitScope))
return _Response()._init_childptr(response, None)
async def a_wait(self):
"""
Asyncio version of wait().
@ -2133,11 +2092,6 @@ cdef class _RemotePromise:
def to_dict(self, verbose=False, ordered=False):
return _to_dict(self, verbose, ordered)
cpdef then(self, func, error_func=None) except +reraise_kj_exception:
parent = self._parent
self._parent = None # We don't need parent anymore. Setting to none allows quicker garbage collection
return _promise_then(self, func, error_func, 1, attach=parent)
cpdef cancel(self) except +reraise_kj_exception:
self.thisptr = Own[RemotePromise]()
self._parent = None # We don't need parent anymore. Setting to none allows quicker garbage collection

View file

@ -13,7 +13,7 @@ class PowerFunction(calculator_capnp.Calculator.Function.Server):
we're implementing this on the client side and will pass a reference to
the server. The server will then be able to make calls back to the client."""
def call(self, params, **kwargs):
async def call(self, params, **kwargs):
"""Note the **kwargs. This is very necessary to include, since
protocols can add parameters over time. Also, by default, a _context
variable is passed to all server methods, but you can also return

View file

@ -48,7 +48,7 @@ class ValueImpl(calculator_capnp.Calculator.Value.Server):
def __init__(self, value):
self.value = value
def read(self, **kwargs):
async def read(self, **kwargs):
return self.value
@ -79,7 +79,7 @@ class OperatorImpl(calculator_capnp.Calculator.Function.Server):
def __init__(self, op):
self.op = op
def call(self, params, **kwargs):
async def call(self, params, **kwargs):
assert len(params) == 2
op = self.op
@ -102,10 +102,10 @@ class CalculatorImpl(calculator_capnp.Calculator.Server):
async def evaluate(self, expression, _context, **kwargs):
return ValueImpl(await evaluate_impl(expression))
def defFunction(self, paramCount, body, _context, **kwargs):
async def defFunction(self, paramCount, body, _context, **kwargs):
return FunctionImpl(paramCount, body)
def getOperator(self, op, **kwargs):
async def getOperator(self, op, **kwargs):
return OperatorImpl(op)

View file

@ -19,7 +19,7 @@ class PowerFunction(calculator_capnp.Calculator.Function.Server):
we're implementing this on the client side and will pass a reference to
the server. The server will then be able to make calls back to the client."""
def call(self, params, **kwargs):
async def call(self, params, **kwargs):
"""Note the **kwargs. This is very necessary to include, since
protocols can add parameters over time. Also, by default, a _context
variable is passed to all server methods, but you can also return

View file

@ -17,15 +17,7 @@ logger.setLevel(logging.DEBUG)
this_dir = os.path.dirname(os.path.abspath(__file__))
def read_value(value):
"""Helper function to asynchronously call read() on a Calculator::Value and
return a promise for the result. (In the future, the generated code might
include something like this automatically.)"""
return value.read().then(lambda result: result.value)
def evaluate_impl(expression, params=None):
async def evaluate_impl(expression, params=None):
"""Implementation of CalculatorImpl::evaluate(), also shared by
FunctionImpl::call(). In the latter case, `params` are the parameter
values passed to the function; in the former case, `params` is just an
@ -34,26 +26,23 @@ def evaluate_impl(expression, params=None):
which = expression.which()
if which == "literal":
return capnp.Promise(expression.literal)
return expression.literal
elif which == "previousResult":
return read_value(expression.previousResult)
return (await expression.previousResult.read()).value
elif which == "parameter":
assert expression.parameter < len(params)
return capnp.Promise(params[expression.parameter])
return params[expression.parameter]
elif which == "call":
call = expression.call
func = call.function
# Evaluate each parameter.
paramPromises = [evaluate_impl(param, params) for param in call.params]
vals = await asyncio.gather(*paramPromises)
joinedParams = capnp.join_promises(paramPromises)
# When the parameters are complete, call the function.
ret = joinedParams.then(lambda vals: func.call(vals)).then(
lambda result: result.value
)
return ret
result = await func.call(vals)
return result.value
else:
raise ValueError("Unknown expression type: " + which)
@ -64,7 +53,7 @@ class ValueImpl(calculator_capnp.Calculator.Value.Server):
def __init__(self, value):
self.value = value
def read(self, **kwargs):
async def read(self, **kwargs):
return self.value
@ -77,17 +66,14 @@ class FunctionImpl(calculator_capnp.Calculator.Function.Server):
self.paramCount = paramCount
self.body = body.as_builder()
def call(self, params, _context, **kwargs):
async def call(self, params, _context, **kwargs):
"""Note that we're returning a Promise object here, and bypassing the
helper functionality that normally sets the results struct from the
returned object. Instead, we set _context.results directly inside of
another promise"""
assert len(params) == self.paramCount
# using setattr because '=' is not allowed inside of lambdas
return evaluate_impl(self.body, params).then(
lambda value: setattr(_context.results, "value", value)
)
return await evaluate_impl(self.body, params)
class OperatorImpl(calculator_capnp.Calculator.Function.Server):
@ -98,7 +84,7 @@ class OperatorImpl(calculator_capnp.Calculator.Function.Server):
def __init__(self, op):
self.op = op
def call(self, params, **kwargs):
async def call(self, params, **kwargs):
assert len(params) == 2
op = self.op
@ -118,15 +104,13 @@ class OperatorImpl(calculator_capnp.Calculator.Function.Server):
class CalculatorImpl(calculator_capnp.Calculator.Server):
"Implementation of the Calculator Cap'n Proto interface."
def evaluate(self, expression, _context, **kwargs):
return evaluate_impl(expression).then(
lambda value: setattr(_context.results, "value", ValueImpl(value))
)
async def evaluate(self, expression, _context, **kwargs):
return ValueImpl(await evaluate_impl(expression))
def defFunction(self, paramCount, body, _context, **kwargs):
async def defFunction(self, paramCount, body, _context, **kwargs):
return FunctionImpl(paramCount, body)
def getOperator(self, op, **kwargs):
async def getOperator(self, op, **kwargs):
return OperatorImpl(op)

View file

@ -28,7 +28,7 @@ class ExampleImpl(thread_capnp.Example.Server):
async def longRunning(self, **kwargs):
await asyncio.sleep(0.1)
def alive(self, **kwargs):
async def alive(self, **kwargs):
return True

View file

@ -8,27 +8,25 @@ class Server(capability.TestInterface.Server):
def __init__(self, val=1):
self.val = val
def foo(self, i, j, **kwargs):
async def foo(self, i, j, **kwargs):
extra = 0
if j:
extra = 1
return str(i * 5 + extra + self.val)
def buz(self, i, **kwargs):
async def buz(self, i, **kwargs):
return i.host + "_test"
def bam(self, i, **kwargs):
async def bam(self, i, **kwargs):
return str(i) + "_test", i
class PipelineServer(capability.TestPipeline.Server):
def getCap(self, n, inCap, _context, **kwargs):
def _then(response):
_results = _context.results
_results.s = response.x + "_foo"
_results.outBox.cap = Server(100)
return inCap.foo(i=n).then(_then)
async def getCap(self, n, inCap, _context, **kwargs):
response = await inCap.foo(i=n)
_results = _context.results
_results.s = response.x + "_foo"
_results.outBox.cap = Server(100)
async def test_client():
@ -38,7 +36,7 @@ async def test_client():
req.i = 5
remote = req.send()
response = remote.wait()
response = await remote
assert response.x == "26"
@ -46,7 +44,7 @@ async def test_client():
req.i = 5
remote = req.send()
response = remote.wait()
response = await remote
assert response.x == "26"
@ -68,42 +66,42 @@ async def test_simple_client():
client = capability.TestInterface._new_client(Server())
remote = client._send("foo", i=5)
response = remote.wait()
response = await remote
assert response.x == "26"
remote = client.foo(i=5)
response = remote.wait()
response = await remote
assert response.x == "26"
remote = client.foo(i=5, j=True)
response = remote.wait()
response = await remote
assert response.x == "27"
remote = client.foo(5)
response = remote.wait()
response = await remote
assert response.x == "26"
remote = client.foo(5, True)
response = remote.wait()
response = await remote
assert response.x == "27"
remote = client.foo(5, j=True)
response = remote.wait()
response = await remote
assert response.x == "27"
remote = client.buz(capability.TestSturdyRefHostId.new_message(host="localhost"))
response = remote.wait()
response = await remote
assert response.x == "localhost_test"
remote = client.bam(i=5)
response = remote.wait()
response = await remote
assert response.x == "5_test"
assert response.i == 5
@ -133,10 +131,10 @@ async def test_pipeline():
outCap = remote.outBox.cap
pipelinePromise = outCap.foo(i=10)
response = pipelinePromise.wait()
response = await pipelinePromise
assert response.x == "150"
response = remote.wait()
response = await remote
assert response.s == "26_foo"
@ -144,7 +142,7 @@ class BadServer(capability.TestInterface.Server):
def __init__(self, val=1):
self.val = val
def foo(self, i, j, **kwargs):
async def foo(self, i, j, **kwargs):
extra = 0
if j:
extra = 1
@ -156,21 +154,16 @@ async def test_exception_client():
remote = client._send("foo", i=5)
with pytest.raises(capnp.KjException):
remote.wait()
await remote
class BadPipelineServer(capability.TestPipeline.Server):
def getCap(self, n, inCap, _context, **kwargs):
def _then(response):
_results = _context.results
_results.s = response.x + "_foo"
_results.outBox.cap = Server(100)
def _error(error):
async def getCap(self, n, inCap, _context, **kwargs):
try:
await inCap.foo(i=n)
except capnp.KjException:
raise Exception("test was a success")
return inCap.foo(i=n).then(_then, _error)
async def test_exception_chain():
client = capability.TestPipeline._new_client(BadPipelineServer())
@ -179,7 +172,7 @@ async def test_exception_chain():
remote = client.getCap(n=5, inCap=foo_client)
try:
remote.wait()
await remote
except Exception as e:
assert "test was a success" in str(e)
@ -194,10 +187,10 @@ async def test_pipeline_exception():
pipelinePromise = outCap.foo(i=10)
with pytest.raises(Exception):
pipelinePromise.wait()
await pipelinePromise
with pytest.raises(Exception):
remote.wait()
await remote
async def test_casting():
@ -213,7 +206,7 @@ class TailCallOrder(capability.TestCallOrder.Server):
def __init__(self):
self.count = -1
def getCallSequence(self, expected, **kwargs):
async def getCallSequence(self, expected, **kwargs):
self.count += 1
return self.count
@ -222,18 +215,18 @@ class TailCaller(capability.TestTailCaller.Server):
def __init__(self):
self.count = 0
def foo(self, i, callee, _context, **kwargs):
async def foo(self, i, callee, _context, **kwargs):
self.count += 1
tail = callee.foo_request(i=i, t="from TailCaller")
return _context.tail_call(tail)
return await _context.tail_call(tail)
class TailCallee(capability.TestTailCallee.Server):
def __init__(self):
self.count = 0
def foo(self, i, t, _context, **kwargs):
async def foo(self, i, t, _context, **kwargs):
self.count += 1
results = _context.results
@ -252,7 +245,7 @@ async def test_tail_call():
promise = caller.foo(i=456, callee=callee)
dependent_call1 = promise.c.getCallSequence()
response = promise.wait()
response = await promise
assert response.i == 456
assert response.i == 456
@ -260,11 +253,11 @@ async def test_tail_call():
dependent_call2 = response.c.getCallSequence()
dependent_call3 = response.c.getCallSequence()
result = dependent_call1.wait()
result = await dependent_call1
assert result.n == 0
result = dependent_call2.wait()
result = await dependent_call2
assert result.n == 1
result = dependent_call3.wait()
result = await dependent_call3
assert result.n == 2
assert callee_server.count == 1
@ -281,22 +274,21 @@ async def test_cancel():
remote.cancel()
with pytest.raises(Exception):
remote.wait()
await remote
req = client.foo(5)
trans = req.then(lambda x: 5)
await req
req.cancel() # Cancel a promise that was already consumed
assert trans.wait() == 5
req = client.foo(5)
req.cancel()
with pytest.raises(Exception):
trans = req.then(lambda x: 5)
await req
req = client.foo(5)
assert req.wait().x == "26"
assert (await req).x == "26"
with pytest.raises(Exception):
req.wait()
await req
async def test_double_send():
@ -305,48 +297,18 @@ async def test_double_send():
req = client._request("foo")
req.i = 5
req.send()
await req.send()
with pytest.raises(Exception):
req.send()
async def test_then_args():
capnp.Promise(0).then(lambda x: 1)
with pytest.raises(Exception):
capnp.Promise(0).then(lambda: 1)
with pytest.raises(Exception):
capnp.Promise(0).then(lambda x, y: 1)
client = capability.TestInterface._new_client(Server())
client.foo(i=5).then(lambda x: 1)
with pytest.raises(Exception):
client.foo(i=5).then(lambda: 1)
with pytest.raises(Exception):
client.foo(i=5).then(lambda x, y: 1)
await req.send()
class PromiseJoinServer(capability.TestPipeline.Server):
def getCap(self, n, inCap, _context, **kwargs):
def _then(response):
_results = _context.results
_results.s = response.x + "_bar"
_results.outBox.cap = inCap
return (
inCap.foo(i=n)
.then(
lambda res: capnp.Promise(int(res.x))
) # Make sure that Promise is flattened
.then(
lambda x: inCap.foo(i=x + 1)
) # Make sure that RemotePromise is flattened
.then(_then)
)
async def getCap(self, n, inCap, _context, **kwargs):
res = await inCap.foo(i=n)
response = await inCap.foo(i = int(res.x) + 1)
_results = _context.results
_results.s = response.x + "_bar"
_results.outBox.cap = inCap
async def test_promise_joining():
@ -354,54 +316,52 @@ async def test_promise_joining():
foo_client = capability.TestInterface._new_client(Server())
remote = client.getCap(n=5, inCap=foo_client)
assert remote.wait().s == "136_bar"
assert (await remote).s == "136_bar"
class ExtendsServer(Server):
def qux(self, **kwargs):
async def qux(self, **kwargs):
pass
async def test_inheritance():
client = capability.TestExtends._new_client(ExtendsServer())
client.qux().wait()
await client.qux()
remote = client.foo(i=5)
response = remote.wait()
response = await remote
assert response.x == "26"
class PassedCapTest(capability.TestPassedCap.Server):
def foo(self, cap, _context, **kwargs):
def set_result(res):
_context.results.x = res.x
return cap.foo(5).then(set_result)
async def foo(self, cap, _context, **kwargs):
res = await cap.foo(5)
_context.results.x = res.x
async def test_null_cap():
client = capability.TestPassedCap._new_client(PassedCapTest())
assert client.foo(Server()).wait().x == "26"
assert (await client.foo(Server())).x == "26"
with pytest.raises(capnp.KjException):
client.foo().wait()
await client.foo()
class StructArgTest(capability.TestStructArg.Server):
def bar(self, a, b, **kwargs):
async def bar(self, a, b, **kwargs):
return a + str(b)
async def test_struct_args():
client = capability.TestStructArg._new_client(StructArgTest())
assert client.bar(a="test", b=1).wait().c == "test1"
assert (await client.bar(a="test", b=1)).c == "test1"
with pytest.raises(capnp.KjException):
assert client.bar("test", 1).wait().c == "test1"
assert (await client.bar("test", 1)).c == "test1"
class GenericTest(capability.TestGeneric.Server):
def foo(self, a, **kwargs):
async def foo(self, a, **kwargs):
return a.as_text() + "test"
@ -410,4 +370,4 @@ async def test_generic():
obj = capnp._MallocMessageBuilder().get_root_as_any()
obj.set_as_text("anypointer_")
assert client.foo(obj).wait().b == "anypointer_test"
assert (await client.foo(obj)).b == "anypointer_test"

View file

@ -18,13 +18,13 @@ class Server:
def __init__(self, val=1):
self.val = val
def foo_context(self, context):
async def foo_context(self, context):
extra = 0
if context.params.j:
extra = 1
context.results.x = str(context.params.i * 5 + extra + self.val)
def buz_context(self, context):
async def buz_context(self, context):
context.results.x = context.params.i.host + "_test"
@ -32,14 +32,12 @@ class PipelineServer:
def __init__(self, capability):
self.capability = capability
def getCap_context(self, context):
def _then(response):
context.results.s = response.x + "_foo"
context.results.outBox.cap = self.capability.TestInterface._new_server(
Server(100)
)
return context.params.inCap.foo(i=context.params.n).then(_then)
async def getCap_context(self, context):
response = await context.params.inCap.foo(i=context.params.n)
context.results.s = response.x + "_foo"
context.results.outBox.cap = self.capability.TestInterface._new_server(
Server(100)
)
async def test_client_context(capability):
@ -49,7 +47,7 @@ async def test_client_context(capability):
req.i = 5
remote = req.send()
response = remote.wait()
response = await remote
assert response.x == "26"
@ -57,7 +55,7 @@ async def test_client_context(capability):
req.i = 5
remote = req.send()
response = remote.wait()
response = await remote
assert response.x == "26"
@ -79,37 +77,37 @@ async def test_simple_client_context(capability):
client = capability.TestInterface._new_client(Server())
remote = client._send("foo", i=5)
response = remote.wait()
response = await remote
assert response.x == "26"
remote = client.foo(i=5)
response = remote.wait()
response = await remote
assert response.x == "26"
remote = client.foo(i=5, j=True)
response = remote.wait()
response = await remote
assert response.x == "27"
remote = client.foo(5)
response = remote.wait()
response = await remote
assert response.x == "26"
remote = client.foo(5, True)
response = remote.wait()
response = await remote
assert response.x == "27"
remote = client.foo(5, j=True)
response = remote.wait()
response = await remote
assert response.x == "27"
remote = client.buz(capability.TestSturdyRefHostId.new_message(host="localhost"))
response = remote.wait()
response = await remote
assert response.x == "localhost_test"
@ -138,10 +136,10 @@ async def test_pipeline_context(capability):
outCap = remote.outBox.cap
pipelinePromise = outCap.foo(i=10)
response = pipelinePromise.wait()
response = await pipelinePromise
assert response.x == "150"
response = remote.wait()
response = await remote
assert response.s == "26_foo"
@ -149,7 +147,7 @@ class BadServer:
def __init__(self, val=1):
self.val = val
def foo_context(self, context):
async def foo_context(self, context):
context.results.x = str(context.params.i * 5 + self.val)
context.results.x2 = 5 # raises exception
@ -159,25 +157,19 @@ async def test_exception_client_context(capability):
remote = client._send("foo", i=5)
with pytest.raises(capnp.KjException):
remote.wait()
await remote
class BadPipelineServer:
def __init__(self, capability):
self.capability = capability
def getCap_context(self, context):
def _then(response):
context.results.s = response.x + "_foo"
context.results.outBox.cap = self.capability.TestInterface._new_server(
Server(100)
)
def _error(error):
async def getCap_context(self, context):
try:
await context.params.inCap.foo(i=context.params.n)
except capnp.KjException:
raise Exception("test was a success")
return context.params.inCap.foo(i=context.params.n).then(_then, _error)
async def test_exception_chain_context(capability):
client = capability.TestPipeline._new_client(BadPipelineServer(capability))
@ -186,7 +178,7 @@ async def test_exception_chain_context(capability):
remote = client.getCap(n=5, inCap=foo_client)
try:
remote.wait()
await remote
except Exception as e:
assert "test was a success" in str(e)
@ -201,10 +193,10 @@ async def test_pipeline_exception_context(capability):
pipelinePromise = outCap.foo(i=10)
with pytest.raises(Exception):
pipelinePromise.wait()
await pipelinePromise
with pytest.raises(Exception):
remote.wait()
await remote
async def test_casting_context(capability):
@ -220,7 +212,7 @@ class TailCallOrder:
def __init__(self):
self.count = -1
def getCallSequence_context(self, context):
async def getCallSequence_context(self, context):
self.count += 1
context.results.n = self.count
@ -229,13 +221,13 @@ class TailCaller:
def __init__(self):
self.count = 0
def foo_context(self, context):
async def foo_context(self, context):
self.count += 1
tail = context.params.callee.foo_request(
i=context.params.i, t="from TailCaller"
)
return context.tail_call(tail)
await context.tail_call(tail)
class TailCallee:
@ -243,7 +235,7 @@ class TailCallee:
self.count = 0
self.capability = capability
def foo_context(self, context):
async def foo_context(self, context):
self.count += 1
results = context.results
@ -262,7 +254,7 @@ async def test_tail_call(capability):
promise = caller.foo(i=456, callee=callee)
dependent_call1 = promise.c.getCallSequence()
response = promise.wait()
response = await promise
assert response.i == 456
assert response.i == 456
@ -270,11 +262,11 @@ async def test_tail_call(capability):
dependent_call2 = response.c.getCallSequence()
dependent_call3 = response.c.getCallSequence()
result = dependent_call1.wait()
result = await dependent_call1
assert result.n == 0
result = dependent_call2.wait()
result = await dependent_call2
assert result.n == 1
result = dependent_call3.wait()
result = await dependent_call3
assert result.n == 2
assert callee_server.count == 1

View file

@ -17,13 +17,13 @@ class Server:
def __init__(self, val=1):
self.val = val
def foo(self, i, j, **kwargs):
async def foo(self, i, j, **kwargs):
extra = 0
if j:
extra = 1
return str(i * 5 + extra + self.val)
def buz(self, i, **kwargs):
async def buz(self, i, **kwargs):
return i.host + "_test"
@ -31,13 +31,11 @@ class PipelineServer:
def __init__(self, capability):
self.capability = capability
def getCap(self, n, inCap, _context, **kwargs):
def _then(response):
_results = _context.results
_results.s = response.x + "_foo"
_results.outBox.cap = self.capability.TestInterface._new_server(Server(100))
return inCap.foo(i=n).then(_then)
async def getCap(self, n, inCap, _context, **kwargs):
response = await inCap.foo(i=n)
_results = _context.results
_results.s = response.x + "_foo"
_results.outBox.cap = self.capability.TestInterface._new_server(Server(100))
async def test_client(capability):
@ -47,7 +45,7 @@ async def test_client(capability):
req.i = 5
remote = req.send()
response = remote.wait()
response = await remote
assert response.x == "26"
@ -55,7 +53,7 @@ async def test_client(capability):
req.i = 5
remote = req.send()
response = remote.wait()
response = await remote
assert response.x == "26"
@ -77,37 +75,37 @@ async def test_simple_client(capability):
client = capability.TestInterface._new_client(Server())
remote = client._send("foo", i=5)
response = remote.wait()
response = await remote
assert response.x == "26"
remote = client.foo(i=5)
response = remote.wait()
response = await remote
assert response.x == "26"
remote = client.foo(i=5, j=True)
response = remote.wait()
response = await remote
assert response.x == "27"
remote = client.foo(5)
response = remote.wait()
response = await remote
assert response.x == "26"
remote = client.foo(5, True)
response = remote.wait()
response = await remote
assert response.x == "27"
remote = client.foo(5, j=True)
response = remote.wait()
response = await remote
assert response.x == "27"
remote = client.buz(capability.TestSturdyRefHostId.new_message(host="localhost"))
response = remote.wait()
response = await remote
assert response.x == "localhost_test"
@ -136,10 +134,10 @@ async def test_pipeline(capability):
outCap = remote.outBox.cap
pipelinePromise = outCap.foo(i=10)
response = pipelinePromise.wait()
response = await pipelinePromise
assert response.x == "150"
response = remote.wait()
response = await remote
assert response.s == "26_foo"
@ -147,7 +145,7 @@ class BadServer:
def __init__(self, val=1):
self.val = val
def foo(self, i, j, **kwargs):
async def foo(self, i, j, **kwargs):
extra = 0
if j:
extra = 1
@ -159,24 +157,19 @@ async def test_exception_client(capability):
remote = client._send("foo", i=5)
with pytest.raises(capnp.KjException):
remote.wait()
await remote
class BadPipelineServer:
def __init__(self, capability):
self.capability = capability
def getCap(self, n, inCap, _context, **kwargs):
def _then(response):
_results = _context.results
_results.s = response.x + "_foo"
_results.outBox.cap = self.capability.TestInterface._new_server(Server(100))
def _error(error):
async def getCap(self, n, inCap, _context, **kwargs):
try:
await inCap.foo(i=n)
except capnp.KjException:
raise Exception("test was a success")
return inCap.foo(i=n).then(_then, _error)
async def test_exception_chain(capability):
client = capability.TestPipeline._new_client(BadPipelineServer(capability))
@ -185,7 +178,7 @@ async def test_exception_chain(capability):
remote = client.getCap(n=5, inCap=foo_client)
try:
remote.wait()
await remote
except Exception as e:
assert "test was a success" in str(e)
@ -200,10 +193,10 @@ async def test_pipeline_exception(capability):
pipelinePromise = outCap.foo(i=10)
with pytest.raises(Exception):
pipelinePromise.wait()
await pipelinePromise
with pytest.raises(Exception):
remote.wait()
await remote
async def test_casting(capability):
@ -219,7 +212,7 @@ class TailCallOrder:
def __init__(self):
self.count = -1
def getCallSequence(self, expected, **kwargs):
async def getCallSequence(self, expected, **kwargs):
self.count += 1
return self.count
@ -228,11 +221,11 @@ class TailCaller:
def __init__(self):
self.count = 0
def foo(self, i, callee, _context, **kwargs):
async def foo(self, i, callee, _context, **kwargs):
self.count += 1
tail = callee.foo_request(i=i, t="from TailCaller")
return _context.tail_call(tail)
await _context.tail_call(tail)
class TailCallee:
@ -240,7 +233,7 @@ class TailCallee:
self.count = 0
self.capability = capability
def foo(self, i, t, _context, **kwargs):
async def foo(self, i, t, _context, **kwargs):
self.count += 1
results = _context.results
@ -259,7 +252,7 @@ async def test_tail_call(capability):
promise = caller.foo(i=456, callee=callee)
dependent_call1 = promise.c.getCallSequence()
response = promise.wait()
response = await promise
assert response.i == 456
assert response.i == 456
@ -267,11 +260,11 @@ async def test_tail_call(capability):
dependent_call2 = response.c.getCallSequence()
dependent_call3 = response.c.getCallSequence()
result = dependent_call1.wait()
result = await dependent_call1
assert result.n == 0
result = dependent_call2.wait()
result = await dependent_call2
assert result.n == 1
result = dependent_call3.wait()
result = await dependent_call3
assert result.n == 2
assert callee_server.count == 1

View file

@ -5,7 +5,7 @@ class FooServer(test_response_capnp.Foo.Server):
def __init__(self, val=1):
self.val = val
def foo(self, **kwargs):
async def foo(self, **kwargs):
return 1
@ -13,27 +13,27 @@ class BazServer(test_response_capnp.Baz.Server):
def __init__(self, val=1):
self.val = val
def grault(self, **kwargs):
async def grault(self, **kwargs):
return {"foo": FooServer()}
async def test_response_reference():
baz = test_response_capnp.Baz._new_client(BazServer())
bar = baz.grault().wait().bar
bar = (await baz.grault()).bar
foo = bar.foo
# This used to cause an exception about invalid pointers because the response got garbage collected
assert foo.foo().wait().val == 1
assert (await foo.foo()).val == 1
async def test_response_reference2():
baz = test_response_capnp.Baz._new_client(BazServer())
bar = baz.grault().wait().bar
bar = (await baz.grault()).bar
# This always worked since it saved the intermediate response object
response = baz.grault().wait()
response = await baz.grault()
bar = response.bar
foo = bar.foo
assert foo.foo().wait().val == 1
assert (await foo.foo()).val == 1

View file

@ -13,7 +13,7 @@ class Server(test_capability_capnp.TestInterface.Server):
def __init__(self, val=100):
self.val = val
def foo(self, i, j, **kwargs):
async def foo(self, i, j, **kwargs):
return str(i * 5 + self.val)