Applying black formatting

- Fixing flake8 configuration to agree with black
- Adding black validation check to github actions
This commit is contained in:
Jacob Alexander 2021-10-01 11:00:22 -07:00
parent 5dade41aeb
commit 6e7fffd7de
Failed to generate hash of commit
51 changed files with 2536 additions and 1837 deletions

6
.flake8 Normal file
View file

@ -0,0 +1,6 @@
[flake8]
max-line-length = 120
extend-ignore = E203,E211,E225,E226,E227,E231,E251,E261,E262,E265,E402,E999
max-complexity = 10
per-file-ignores =
test/test_examples.py: C901

View file

@ -29,11 +29,12 @@ jobs:
run: | run: |
python setup.py build python setup.py build
pip install . pip install .
- name: Lint with flake8 - name: Lint with flake8 and check black
run: | run: |
pip install flake8 pip install black flake8
flake8 . --filename '*.py,*.pyx,*.pxd' --count --max-complexity=10 --max-line-length=120 --ignore=E211,E225,E226,E227,E231,E251,E261,E262,E265,E402,E999 --show-source --statistics --exclude benchmark,build,capnp/templates/module.pyx flake8 . --filename '*.py,*.pyx,*.pxd' --count --show-source --statistics --exclude benchmark,build,capnp/templates/module.pyx
flake8 . --count --max-complexity=10 --max-line-length=120 --show-source --statistics --exclude benchmark,build flake8 . --count --show-source --statistics --exclude benchmark,build
black . --check --diff --color
- name: Packaging - name: Packaging
run: | run: |
python setup.py bdist_wheel python setup.py bdist_wheel

View file

@ -2,50 +2,50 @@ import os
import capnp import capnp
this_dir = os.path.dirname(__file__) this_dir = os.path.dirname(__file__)
addressbook = capnp.load(os.path.join(this_dir, 'addressbook.capnp')) addressbook = capnp.load(os.path.join(this_dir, "addressbook.capnp"))
print = lambda *x: x print = lambda *x: x
def writeAddressBook(): def writeAddressBook():
addressBook = addressbook.AddressBook.new_message() addressBook = addressbook.AddressBook.new_message()
people = addressBook.init_resizable_list('people') people = addressBook.init_resizable_list("people")
alice = people.add() alice = people.add()
alice.id = 123 alice.id = 123
alice.name = 'Alice' alice.name = "Alice"
alice.email = 'alice@example.com' alice.email = "alice@example.com"
alicePhones = alice.init('phones', 1) alicePhones = alice.init("phones", 1)
alicePhones[0].number = "555-1212" alicePhones[0].number = "555-1212"
alicePhones[0].type = 'mobile' alicePhones[0].type = "mobile"
bob = people.add() bob = people.add()
bob.id = 456 bob.id = 456
bob.name = 'Bob' bob.name = "Bob"
bob.email = 'bob@example.com' bob.email = "bob@example.com"
bobPhones = bob.init('phones', 2) bobPhones = bob.init("phones", 2)
bobPhones[0].number = "555-4567" bobPhones[0].number = "555-4567"
bobPhones[0].type = 'home' bobPhones[0].type = "home"
bobPhones[1].number = "555-7654" bobPhones[1].number = "555-7654"
bobPhones[1].type = 'work' bobPhones[1].type = "work"
people.finish() people.finish()
msg_bytes = addressBook.to_bytes() msg_bytes = addressBook.to_bytes()
return msg_bytes return msg_bytes
def printAddressBook(msg_bytes): def printAddressBook(msg_bytes):
addressBook = addressbook.AddressBook.from_bytes(msg_bytes) addressBook = addressbook.AddressBook.from_bytes(msg_bytes)
for person in addressBook.people: for person in addressBook.people:
print(person.name, ':', person.email) print(person.name, ":", person.email)
for phone in person.phones: for phone in person.phones:
print(phone.type, ':', phone.number) print(phone.type, ":", phone.number)
print() print()
if __name__ == '__main__': if __name__ == "__main__":
for i in range(10000): for i in range(10000):
msg_bytes = writeAddressBook() msg_bytes = writeAddressBook()
printAddressBook(msg_bytes) printAddressBook(msg_bytes)

View file

@ -6,35 +6,38 @@ try:
except: except:
profile = lambda func: func profile = lambda func: func
this_dir = os.path.dirname(__file__) this_dir = os.path.dirname(__file__)
addressbook = capnp.load(os.path.join(this_dir, 'addressbook.capnp')) addressbook = capnp.load(os.path.join(this_dir, "addressbook.capnp"))
print = lambda *x: x print = lambda *x: x
@profile @profile
def writeAddressBook(): def writeAddressBook():
addressBook = addressbook.AddressBook.new_message() addressBook = addressbook.AddressBook.new_message()
people = addressBook.init('people', 2) people = addressBook.init("people", 2)
alice = people[0] alice = people[0]
alice.id = 123 alice.id = 123
alice.name = 'Alice' alice.name = "Alice"
alice.email = 'alice@example.com' alice.email = "alice@example.com"
alicePhones = alice.init('phones', 1) alicePhones = alice.init("phones", 1)
alicePhones[0].number = "555-1212" alicePhones[0].number = "555-1212"
alicePhones[0].type = 'mobile' alicePhones[0].type = "mobile"
bob = people[1] bob = people[1]
bob.id = 456 bob.id = 456
bob.name = 'Bob' bob.name = "Bob"
bob.email = 'bob@example.com' bob.email = "bob@example.com"
bobPhones = bob.init('phones', 2) bobPhones = bob.init("phones", 2)
bobPhones[0].number = "555-4567" bobPhones[0].number = "555-4567"
bobPhones[0].type = 'home' bobPhones[0].type = "home"
bobPhones[1].number = "555-7654" bobPhones[1].number = "555-7654"
bobPhones[1].type = 'work' bobPhones[1].type = "work"
msg_bytes = addressBook.to_bytes() msg_bytes = addressBook.to_bytes()
return msg_bytes return msg_bytes
@profile @profile
def printAddressBook(msg_bytes): def printAddressBook(msg_bytes):
addressBook = addressbook.AddressBook.from_bytes(msg_bytes) addressBook = addressbook.AddressBook.from_bytes(msg_bytes)
@ -44,31 +47,34 @@ def printAddressBook(msg_bytes):
for phone in person.phones: for phone in person.phones:
phone.type, phone.number phone.type, phone.number
@profile @profile
def writeAddressBookDict(): def writeAddressBookDict():
addressBook = addressbook.AddressBook.new_message() addressBook = addressbook.AddressBook.new_message()
people = addressBook.init('people', 2) people = addressBook.init("people", 2)
alice = people[0] alice = people[0]
alice.id = 123 alice.id = 123
alice.name = 'Alice' alice.name = "Alice"
alice.email = 'alice@example.com' alice.email = "alice@example.com"
alicePhones = alice.init('phones', 1) alicePhones = alice.init("phones", 1)
alicePhones[0].number = "555-1212" alicePhones[0].number = "555-1212"
alicePhones[0].type = 'mobile' alicePhones[0].type = "mobile"
bob = people[1] bob = people[1]
bob.id = 456 bob.id = 456
bob.name = 'Bob' bob.name = "Bob"
bob.email = 'bob@example.com' bob.email = "bob@example.com"
bobPhones = bob.init('phones', 2) bobPhones = bob.init("phones", 2)
bobPhones[0].number = "555-4567" bobPhones[0].number = "555-4567"
bobPhones[0].type = 'home' bobPhones[0].type = "home"
bobPhones[1].number = "555-7654" bobPhones[1].number = "555-7654"
bobPhones[1].type = 'work' bobPhones[1].type = "work"
msg = addressBook.to_dict() msg = addressBook.to_dict()
return msg return msg
@profile @profile
def printAddressBookDict(msg): def printAddressBookDict(msg):
addressBook = addressbook.AddressBook.new_message(**msg) addressBook = addressbook.AddressBook.new_message(**msg)
@ -79,7 +85,7 @@ def printAddressBookDict(msg):
phone.type, phone.number phone.type, phone.number
if __name__ == '__main__': if __name__ == "__main__":
# for i in range(10000): # for i in range(10000):
# msg_bytes = writeAddressBook() # msg_bytes = writeAddressBook()
@ -88,4 +94,3 @@ if __name__ == '__main__':
msg = writeAddressBookDict() msg = writeAddressBookDict()
printAddressBookDict(msg) printAddressBookDict(msg)

View file

@ -9,16 +9,16 @@ def writeAddressBook():
alice = addressBook.person.add() alice = addressBook.person.add()
alice.id = 123 alice.id = 123
alice.name = 'Alice' alice.name = "Alice"
alice.email = 'alice@example.com' alice.email = "alice@example.com"
alicePhones = [alice.phone.add()] alicePhones = [alice.phone.add()]
alicePhones[0].number = "555-1212" alicePhones[0].number = "555-1212"
alicePhones[0].type = addressbook.Person.MOBILE alicePhones[0].type = addressbook.Person.MOBILE
bob = addressBook.person.add() bob = addressBook.person.add()
bob.id = 456 bob.id = 456
bob.name = 'Bob' bob.name = "Bob"
bob.email = 'bob@example.com' bob.email = "bob@example.com"
bobPhones = [bob.phone.add(), bob.phone.add()] bobPhones = [bob.phone.add(), bob.phone.add()]
bobPhones[0].number = "555-4567" bobPhones[0].number = "555-4567"
bobPhones[0].type = addressbook.Person.HOME bobPhones[0].type = addressbook.Person.HOME
@ -34,15 +34,14 @@ def printAddressBook(message_string):
addressBook.ParseFromString(message_string) addressBook.ParseFromString(message_string)
for person in addressBook.person: for person in addressBook.person:
print(person.name, ':', person.email) print(person.name, ":", person.email)
for phone in person.phone: for phone in person.phone:
print(phone.type, ':', phone.number) print(phone.type, ":", phone.number)
print() print()
if __name__ == '__main__': if __name__ == "__main__":
for i in range(10000): for i in range(10000):
message_string = writeAddressBook() message_string = writeAddressBook()
printAddressBook(message_string) printAddressBook(message_string)

View file

@ -2,204 +2,276 @@
# source: addressbook.proto # source: addressbook.proto
import sys import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1"))
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message from google.protobuf import message as _message
from google.protobuf import reflection as _reflection from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2 from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor( DESCRIPTOR = _descriptor.FileDescriptor(
name='addressbook.proto', name="addressbook.proto",
package='tutorial', package="tutorial",
syntax='proto2', syntax="proto2",
serialized_pb=_b('\n\x11\x61\x64\x64ressbook.proto\x12\x08tutorial\"\xda\x01\n\x06Person\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\n\n\x02id\x18\x02 \x02(\x05\x12\r\n\x05\x65mail\x18\x03 \x02(\t\x12+\n\x05phone\x18\x04 \x03(\x0b\x32\x1c.tutorial.Person.PhoneNumber\x1aM\n\x0bPhoneNumber\x12\x0e\n\x06number\x18\x01 \x02(\t\x12.\n\x04type\x18\x02 \x01(\x0e\x32\x1a.tutorial.Person.PhoneType:\x04HOME\"+\n\tPhoneType\x12\n\n\x06MOBILE\x10\x00\x12\x08\n\x04HOME\x10\x01\x12\x08\n\x04WORK\x10\x02\"/\n\x0b\x41\x64\x64ressBook\x12 \n\x06person\x18\x01 \x03(\x0b\x32\x10.tutorial.Person') serialized_pb=_b(
'\n\x11\x61\x64\x64ressbook.proto\x12\x08tutorial"\xda\x01\n\x06Person\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\n\n\x02id\x18\x02 \x02(\x05\x12\r\n\x05\x65mail\x18\x03 \x02(\t\x12+\n\x05phone\x18\x04 \x03(\x0b\x32\x1c.tutorial.Person.PhoneNumber\x1aM\n\x0bPhoneNumber\x12\x0e\n\x06number\x18\x01 \x02(\t\x12.\n\x04type\x18\x02 \x01(\x0e\x32\x1a.tutorial.Person.PhoneType:\x04HOME"+\n\tPhoneType\x12\n\n\x06MOBILE\x10\x00\x12\x08\n\x04HOME\x10\x01\x12\x08\n\x04WORK\x10\x02"/\n\x0b\x41\x64\x64ressBook\x12 \n\x06person\x18\x01 \x03(\x0b\x32\x10.tutorial.Person'
),
) )
_PERSON_PHONETYPE = _descriptor.EnumDescriptor( _PERSON_PHONETYPE = _descriptor.EnumDescriptor(
name='PhoneType', name="PhoneType",
full_name='tutorial.Person.PhoneType', full_name="tutorial.Person.PhoneType",
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
values=[ values=[
_descriptor.EnumValueDescriptor( _descriptor.EnumValueDescriptor(
name='MOBILE', index=0, number=0, name="MOBILE", index=0, number=0, options=None, type=None
options=None, ),
type=None), _descriptor.EnumValueDescriptor(
_descriptor.EnumValueDescriptor( name="HOME", index=1, number=1, options=None, type=None
name='HOME', index=1, number=1, ),
options=None, _descriptor.EnumValueDescriptor(
type=None), name="WORK", index=2, number=2, options=None, type=None
_descriptor.EnumValueDescriptor( ),
name='WORK', index=2, number=2, ],
options=None, containing_type=None,
type=None), options=None,
], serialized_start=207,
containing_type=None, serialized_end=250,
options=None,
serialized_start=207,
serialized_end=250,
) )
_sym_db.RegisterEnumDescriptor(_PERSON_PHONETYPE) _sym_db.RegisterEnumDescriptor(_PERSON_PHONETYPE)
_PERSON_PHONENUMBER = _descriptor.Descriptor( _PERSON_PHONENUMBER = _descriptor.Descriptor(
name='PhoneNumber', name="PhoneNumber",
full_name='tutorial.Person.PhoneNumber', full_name="tutorial.Person.PhoneNumber",
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='number', full_name='tutorial.Person.PhoneNumber.number', index=0, name="number",
number=1, type=9, cpp_type=9, label=2, full_name="tutorial.Person.PhoneNumber.number",
has_default_value=False, default_value=_b("").decode('utf-8'), index=0,
message_type=None, enum_type=None, containing_type=None, number=1,
is_extension=False, extension_scope=None, type=9,
options=None), cpp_type=9,
_descriptor.FieldDescriptor( label=2,
name='type', full_name='tutorial.Person.PhoneNumber.type', index=1, has_default_value=False,
number=2, type=14, cpp_type=8, label=1, default_value=_b("").decode("utf-8"),
has_default_value=True, default_value=1, message_type=None,
message_type=None, enum_type=None, containing_type=None, enum_type=None,
is_extension=False, extension_scope=None, containing_type=None,
options=None), is_extension=False,
], extension_scope=None,
extensions=[ options=None,
], ),
nested_types=[], _descriptor.FieldDescriptor(
enum_types=[ name="type",
], full_name="tutorial.Person.PhoneNumber.type",
options=None, index=1,
is_extendable=False, number=2,
syntax='proto2', type=14,
extension_ranges=[], cpp_type=8,
oneofs=[ label=1,
], has_default_value=True,
serialized_start=128, default_value=1,
serialized_end=205, message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None,
),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax="proto2",
extension_ranges=[],
oneofs=[],
serialized_start=128,
serialized_end=205,
) )
_PERSON = _descriptor.Descriptor( _PERSON = _descriptor.Descriptor(
name='Person', name="Person",
full_name='tutorial.Person', full_name="tutorial.Person",
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='name', full_name='tutorial.Person.name', index=0, name="name",
number=1, type=9, cpp_type=9, label=2, full_name="tutorial.Person.name",
has_default_value=False, default_value=_b("").decode('utf-8'), index=0,
message_type=None, enum_type=None, containing_type=None, number=1,
is_extension=False, extension_scope=None, type=9,
options=None), cpp_type=9,
_descriptor.FieldDescriptor( label=2,
name='id', full_name='tutorial.Person.id', index=1, has_default_value=False,
number=2, type=5, cpp_type=1, label=2, default_value=_b("").decode("utf-8"),
has_default_value=False, default_value=0, message_type=None,
message_type=None, enum_type=None, containing_type=None, enum_type=None,
is_extension=False, extension_scope=None, containing_type=None,
options=None), is_extension=False,
_descriptor.FieldDescriptor( extension_scope=None,
name='email', full_name='tutorial.Person.email', index=2, options=None,
number=3, type=9, cpp_type=9, label=2, ),
has_default_value=False, default_value=_b("").decode('utf-8'), _descriptor.FieldDescriptor(
message_type=None, enum_type=None, containing_type=None, name="id",
is_extension=False, extension_scope=None, full_name="tutorial.Person.id",
options=None), index=1,
_descriptor.FieldDescriptor( number=2,
name='phone', full_name='tutorial.Person.phone', index=3, type=5,
number=4, type=11, cpp_type=10, label=3, cpp_type=1,
has_default_value=False, default_value=[], label=2,
message_type=None, enum_type=None, containing_type=None, has_default_value=False,
is_extension=False, extension_scope=None, default_value=0,
options=None), message_type=None,
], enum_type=None,
extensions=[ containing_type=None,
], is_extension=False,
nested_types=[_PERSON_PHONENUMBER, ], extension_scope=None,
enum_types=[ options=None,
_PERSON_PHONETYPE, ),
], _descriptor.FieldDescriptor(
options=None, name="email",
is_extendable=False, full_name="tutorial.Person.email",
syntax='proto2', index=2,
extension_ranges=[], number=3,
oneofs=[ type=9,
], cpp_type=9,
serialized_start=32, label=2,
serialized_end=250, has_default_value=False,
default_value=_b("").decode("utf-8"),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None,
),
_descriptor.FieldDescriptor(
name="phone",
full_name="tutorial.Person.phone",
index=3,
number=4,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None,
),
],
extensions=[],
nested_types=[
_PERSON_PHONENUMBER,
],
enum_types=[
_PERSON_PHONETYPE,
],
options=None,
is_extendable=False,
syntax="proto2",
extension_ranges=[],
oneofs=[],
serialized_start=32,
serialized_end=250,
) )
_ADDRESSBOOK = _descriptor.Descriptor( _ADDRESSBOOK = _descriptor.Descriptor(
name='AddressBook', name="AddressBook",
full_name='tutorial.AddressBook', full_name="tutorial.AddressBook",
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='person', full_name='tutorial.AddressBook.person', index=0, name="person",
number=1, type=11, cpp_type=10, label=3, full_name="tutorial.AddressBook.person",
has_default_value=False, default_value=[], index=0,
message_type=None, enum_type=None, containing_type=None, number=1,
is_extension=False, extension_scope=None, type=11,
options=None), cpp_type=10,
], label=3,
extensions=[ has_default_value=False,
], default_value=[],
nested_types=[], message_type=None,
enum_types=[ enum_type=None,
], containing_type=None,
options=None, is_extension=False,
is_extendable=False, extension_scope=None,
syntax='proto2', options=None,
extension_ranges=[], ),
oneofs=[ ],
], extensions=[],
serialized_start=252, nested_types=[],
serialized_end=299, enum_types=[],
options=None,
is_extendable=False,
syntax="proto2",
extension_ranges=[],
oneofs=[],
serialized_start=252,
serialized_end=299,
) )
_PERSON_PHONENUMBER.fields_by_name['type'].enum_type = _PERSON_PHONETYPE _PERSON_PHONENUMBER.fields_by_name["type"].enum_type = _PERSON_PHONETYPE
_PERSON_PHONENUMBER.containing_type = _PERSON _PERSON_PHONENUMBER.containing_type = _PERSON
_PERSON.fields_by_name['phone'].message_type = _PERSON_PHONENUMBER _PERSON.fields_by_name["phone"].message_type = _PERSON_PHONENUMBER
_PERSON_PHONETYPE.containing_type = _PERSON _PERSON_PHONETYPE.containing_type = _PERSON
_ADDRESSBOOK.fields_by_name['person'].message_type = _PERSON _ADDRESSBOOK.fields_by_name["person"].message_type = _PERSON
DESCRIPTOR.message_types_by_name['Person'] = _PERSON DESCRIPTOR.message_types_by_name["Person"] = _PERSON
DESCRIPTOR.message_types_by_name['AddressBook'] = _ADDRESSBOOK DESCRIPTOR.message_types_by_name["AddressBook"] = _ADDRESSBOOK
_sym_db.RegisterFileDescriptor(DESCRIPTOR) _sym_db.RegisterFileDescriptor(DESCRIPTOR)
Person = _reflection.GeneratedProtocolMessageType('Person', (_message.Message,), dict( Person = _reflection.GeneratedProtocolMessageType(
"Person",
PhoneNumber = _reflection.GeneratedProtocolMessageType('PhoneNumber', (_message.Message,), dict( (_message.Message,),
DESCRIPTOR = _PERSON_PHONENUMBER, dict(
__module__ = 'addressbook_pb2' PhoneNumber=_reflection.GeneratedProtocolMessageType(
# @@protoc_insertion_point(class_scope:tutorial.Person.PhoneNumber) "PhoneNumber",
)) (_message.Message,),
, dict(
DESCRIPTOR = _PERSON, DESCRIPTOR=_PERSON_PHONENUMBER,
__module__ = 'addressbook_pb2' __module__="addressbook_pb2"
# @@protoc_insertion_point(class_scope:tutorial.Person) # @@protoc_insertion_point(class_scope:tutorial.Person.PhoneNumber)
)) ),
),
DESCRIPTOR=_PERSON,
__module__="addressbook_pb2"
# @@protoc_insertion_point(class_scope:tutorial.Person)
),
)
_sym_db.RegisterMessage(Person) _sym_db.RegisterMessage(Person)
_sym_db.RegisterMessage(Person.PhoneNumber) _sym_db.RegisterMessage(Person.PhoneNumber)
AddressBook = _reflection.GeneratedProtocolMessageType('AddressBook', (_message.Message,), dict( AddressBook = _reflection.GeneratedProtocolMessageType(
DESCRIPTOR = _ADDRESSBOOK, "AddressBook",
__module__ = 'addressbook_pb2' (_message.Message,),
# @@protoc_insertion_point(class_scope:tutorial.AddressBook) dict(
)) DESCRIPTOR=_ADDRESSBOOK,
__module__="addressbook_pb2"
# @@protoc_insertion_point(class_scope:tutorial.AddressBook)
),
)
_sym_db.RegisterMessage(AddressBook) _sym_db.RegisterMessage(AddressBook)

File diff suppressed because it is too large Load diff

View file

@ -6,7 +6,18 @@ from random import choice
MAKES = ["Toyota", "GM", "Ford", "Honda", "Tesla"] MAKES = ["Toyota", "GM", "Ford", "Honda", "Tesla"]
MODELS = ["Camry", "Prius", "Volt", "Accord", "Leaf", "Model S"] MODELS = ["Camry", "Prius", "Volt", "Accord", "Leaf", "Model S"]
COLORS = ["black", "white", "red", "green", "blue", "cyan", "magenta", "yellow", "silver"] COLORS = [
"black",
"white",
"red",
"green",
"blue",
"cyan",
"magenta",
"yellow",
"silver",
]
def random_car(car): def random_car(car):
car.make = choice(MAKES) car.make = choice(MAKES)
@ -42,6 +53,7 @@ def random_car(car):
car.cup_holders = rand_int(12) car.cup_holders = rand_int(12)
car.has_nav_system = rand_bool() car.has_nav_system = rand_bool()
def calc_value(car): def calc_value(car):
result = 0 result = 0
@ -57,9 +69,9 @@ def calc_value(car):
result += engine.horsepower * 40 result += engine.horsepower * 40
if engine.uses_electric: if engine.uses_electric:
if engine.uses_gas: if engine.uses_gas:
result += 5000 result += 5000
else: else:
result += 3000 result += 3000
result += 100 if car.has_power_windows else 0 result += 100 if car.has_power_windows else 0
result += 200 if car.has_power_steering else 0 result += 200 if car.has_power_steering else 0
@ -70,6 +82,7 @@ def calc_value(car):
return result return result
class Benchmark: class Benchmark:
def __init__(self, compression): def __init__(self, compression):
self.Request = carsales_pb2.ParkingLot self.Request = carsales_pb2.ParkingLot
@ -81,17 +94,17 @@ class Benchmark:
def setup(self, request): def setup(self, request):
result = 0 result = 0
for _ in range(rand_int(200)): for _ in range(rand_int(200)):
car = request.car.add() car = request.car.add()
random_car(car) random_car(car)
result += calc_value(car) result += calc_value(car)
return result return result
def handle(self, request, response): def handle(self, request, response):
result = 0 result = 0
for car in request.car: for car in request.car:
result += calc_value(car) result += calc_value(car)
response.amount = result response.amount = result
def check(self, response, expected): def check(self, response, expected):
return response.amount == expected return response.amount == expected

View file

@ -7,7 +7,18 @@ from random import choice
MAKES = ["Toyota", "GM", "Ford", "Honda", "Tesla"] MAKES = ["Toyota", "GM", "Ford", "Honda", "Tesla"]
MODELS = ["Camry", "Prius", "Volt", "Accord", "Leaf", "Model S"] MODELS = ["Camry", "Prius", "Volt", "Accord", "Leaf", "Model S"]
COLORS = ["black", "white", "red", "green", "blue", "cyan", "magenta", "yellow", "silver"] COLORS = [
"black",
"white",
"red",
"green",
"blue",
"cyan",
"magenta",
"yellow",
"silver",
]
def random_car(car): def random_car(car):
car.make = choice(MAKES) car.make = choice(MAKES)
@ -17,7 +28,7 @@ def random_car(car):
car.seats = 2 + rand_int(6) car.seats = 2 + rand_int(6)
car.doors = 2 + rand_int(3) car.doors = 2 + rand_int(3)
for wheel in car.init('wheels', 4): for wheel in car.init("wheels", 4):
wheel.diameter = 25 + rand_int(15) wheel.diameter = 25 + rand_int(15)
wheel.airPressure = 30 + rand_double(20) wheel.airPressure = 30 + rand_double(20)
wheel.snowTires = rand_int(16) == 0 wheel.snowTires = rand_int(16) == 0
@ -27,7 +38,7 @@ def random_car(car):
car.height = 54 + rand_int(48) car.height = 54 + rand_int(48)
car.weight = car.length * car.width * car.height // 200 car.weight = car.length * car.width * car.height // 200
engine = car.init('engine') engine = car.init("engine")
engine.horsepower = 100 * rand_int(400) engine.horsepower = 100 * rand_int(400)
engine.cylinders = 4 + 2 * rand_int(3) engine.cylinders = 4 + 2 * rand_int(3)
engine.cc = 800 + rand_int(10000) engine.cc = 800 + rand_int(10000)
@ -42,6 +53,7 @@ def random_car(car):
car.cupHolders = rand_int(12) car.cupHolders = rand_int(12)
car.hasNavSystem = rand_bool() car.hasNavSystem = rand_bool()
def calc_value(car): def calc_value(car):
result = 0 result = 0
@ -57,9 +69,9 @@ def calc_value(car):
result += engine.horsepower * 40 result += engine.horsepower * 40
if engine.usesElectric: if engine.usesElectric:
if engine.usesGas: if engine.usesGas:
result += 5000 result += 5000
else: else:
result += 3000 result += 3000
result += 100 if car.hasPowerWindows else 0 result += 100 if car.hasPowerWindows else 0
result += 200 if car.hasPowerSteering else 0 result += 200 if car.hasPowerSteering else 0
@ -70,11 +82,12 @@ def calc_value(car):
return result return result
class Benchmark: class Benchmark:
def __init__(self, compression): def __init__(self, compression):
self.Request = carsales_capnp.ParkingLot.new_message self.Request = carsales_capnp.ParkingLot.new_message
self.Response = carsales_capnp.TotalValue.new_message self.Response = carsales_capnp.TotalValue.new_message
if compression == 'packed': if compression == "packed":
self.from_bytes_request = carsales_capnp.ParkingLot.from_bytes_packed self.from_bytes_request = carsales_capnp.ParkingLot.from_bytes_packed
self.from_bytes_response = carsales_capnp.TotalValue.from_bytes_packed self.from_bytes_response = carsales_capnp.TotalValue.from_bytes_packed
self.to_bytes = lambda x: x.to_bytes_packed() self.to_bytes = lambda x: x.to_bytes_packed()
@ -85,17 +98,17 @@ class Benchmark:
def setup(self, request): def setup(self, request):
result = 0 result = 0
for car in request.init('cars', rand_int(200)): for car in request.init("cars", rand_int(200)):
random_car(car) random_car(car)
result += calc_value(car) result += calc_value(car)
return result return result
def handle(self, request, response): def handle(self, request, response):
result = 0 result = 0
for car in request.cars: for car in request.cars:
result += calc_value(car) result += calc_value(car)
response.amount = result response.amount = result
def check(self, response, expected): def check(self, response, expected):
return response.amount == expected return response.amount == expected

View file

@ -2,121 +2,163 @@
# source: catrank.proto # source: catrank.proto
import sys import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1"))
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message from google.protobuf import message as _message
from google.protobuf import reflection as _reflection from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2 from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor( DESCRIPTOR = _descriptor.FileDescriptor(
name='catrank.proto', name="catrank.proto",
package='capnp.benchmark.protobuf', package="capnp.benchmark.protobuf",
syntax='proto2', syntax="proto2",
serialized_pb=_b('\n\rcatrank.proto\x12\x18\x63\x61pnp.benchmark.protobuf\"J\n\x10SearchResultList\x12\x36\n\x06result\x18\x01 \x03(\x0b\x32&.capnp.benchmark.protobuf.SearchResult\";\n\x0cSearchResult\x12\x0b\n\x03url\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x01\x12\x0f\n\x07snippet\x18\x03 \x01(\t') serialized_pb=_b(
'\n\rcatrank.proto\x12\x18\x63\x61pnp.benchmark.protobuf"J\n\x10SearchResultList\x12\x36\n\x06result\x18\x01 \x03(\x0b\x32&.capnp.benchmark.protobuf.SearchResult";\n\x0cSearchResult\x12\x0b\n\x03url\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x01\x12\x0f\n\x07snippet\x18\x03 \x01(\t'
),
) )
_SEARCHRESULTLIST = _descriptor.Descriptor( _SEARCHRESULTLIST = _descriptor.Descriptor(
name='SearchResultList', name="SearchResultList",
full_name='capnp.benchmark.protobuf.SearchResultList', full_name="capnp.benchmark.protobuf.SearchResultList",
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='result', full_name='capnp.benchmark.protobuf.SearchResultList.result', index=0, name="result",
number=1, type=11, cpp_type=10, label=3, full_name="capnp.benchmark.protobuf.SearchResultList.result",
has_default_value=False, default_value=[], index=0,
message_type=None, enum_type=None, containing_type=None, number=1,
is_extension=False, extension_scope=None, type=11,
options=None), cpp_type=10,
], label=3,
extensions=[ has_default_value=False,
], default_value=[],
nested_types=[], message_type=None,
enum_types=[ enum_type=None,
], containing_type=None,
options=None, is_extension=False,
is_extendable=False, extension_scope=None,
syntax='proto2', options=None,
extension_ranges=[], ),
oneofs=[ ],
], extensions=[],
serialized_start=43, nested_types=[],
serialized_end=117, enum_types=[],
options=None,
is_extendable=False,
syntax="proto2",
extension_ranges=[],
oneofs=[],
serialized_start=43,
serialized_end=117,
) )
_SEARCHRESULT = _descriptor.Descriptor( _SEARCHRESULT = _descriptor.Descriptor(
name='SearchResult', name="SearchResult",
full_name='capnp.benchmark.protobuf.SearchResult', full_name="capnp.benchmark.protobuf.SearchResult",
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='url', full_name='capnp.benchmark.protobuf.SearchResult.url', index=0, name="url",
number=1, type=9, cpp_type=9, label=1, full_name="capnp.benchmark.protobuf.SearchResult.url",
has_default_value=False, default_value=_b("").decode('utf-8'), index=0,
message_type=None, enum_type=None, containing_type=None, number=1,
is_extension=False, extension_scope=None, type=9,
options=None), cpp_type=9,
_descriptor.FieldDescriptor( label=1,
name='score', full_name='capnp.benchmark.protobuf.SearchResult.score', index=1, has_default_value=False,
number=2, type=1, cpp_type=5, label=1, default_value=_b("").decode("utf-8"),
has_default_value=False, default_value=float(0), message_type=None,
message_type=None, enum_type=None, containing_type=None, enum_type=None,
is_extension=False, extension_scope=None, containing_type=None,
options=None), is_extension=False,
_descriptor.FieldDescriptor( extension_scope=None,
name='snippet', full_name='capnp.benchmark.protobuf.SearchResult.snippet', index=2, options=None,
number=3, type=9, cpp_type=9, label=1, ),
has_default_value=False, default_value=_b("").decode('utf-8'), _descriptor.FieldDescriptor(
message_type=None, enum_type=None, containing_type=None, name="score",
is_extension=False, extension_scope=None, full_name="capnp.benchmark.protobuf.SearchResult.score",
options=None), index=1,
], number=2,
extensions=[ type=1,
], cpp_type=5,
nested_types=[], label=1,
enum_types=[ has_default_value=False,
], default_value=float(0),
options=None, message_type=None,
is_extendable=False, enum_type=None,
syntax='proto2', containing_type=None,
extension_ranges=[], is_extension=False,
oneofs=[ extension_scope=None,
], options=None,
serialized_start=119, ),
serialized_end=178, _descriptor.FieldDescriptor(
name="snippet",
full_name="capnp.benchmark.protobuf.SearchResult.snippet",
index=2,
number=3,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode("utf-8"),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None,
),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax="proto2",
extension_ranges=[],
oneofs=[],
serialized_start=119,
serialized_end=178,
) )
_SEARCHRESULTLIST.fields_by_name['result'].message_type = _SEARCHRESULT _SEARCHRESULTLIST.fields_by_name["result"].message_type = _SEARCHRESULT
DESCRIPTOR.message_types_by_name['SearchResultList'] = _SEARCHRESULTLIST DESCRIPTOR.message_types_by_name["SearchResultList"] = _SEARCHRESULTLIST
DESCRIPTOR.message_types_by_name['SearchResult'] = _SEARCHRESULT DESCRIPTOR.message_types_by_name["SearchResult"] = _SEARCHRESULT
_sym_db.RegisterFileDescriptor(DESCRIPTOR) _sym_db.RegisterFileDescriptor(DESCRIPTOR)
SearchResultList = _reflection.GeneratedProtocolMessageType('SearchResultList', (_message.Message,), dict( SearchResultList = _reflection.GeneratedProtocolMessageType(
DESCRIPTOR = _SEARCHRESULTLIST, "SearchResultList",
__module__ = 'catrank_pb2' (_message.Message,),
# @@protoc_insertion_point(class_scope:capnp.benchmark.protobuf.SearchResultList) dict(
)) DESCRIPTOR=_SEARCHRESULTLIST,
__module__="catrank_pb2"
# @@protoc_insertion_point(class_scope:capnp.benchmark.protobuf.SearchResultList)
),
)
_sym_db.RegisterMessage(SearchResultList) _sym_db.RegisterMessage(SearchResultList)
SearchResult = _reflection.GeneratedProtocolMessageType('SearchResult', (_message.Message,), dict( SearchResult = _reflection.GeneratedProtocolMessageType(
DESCRIPTOR = _SEARCHRESULT, "SearchResult",
__module__ = 'catrank_pb2' (_message.Message,),
# @@protoc_insertion_point(class_scope:capnp.benchmark.protobuf.SearchResult) dict(
)) DESCRIPTOR=_SEARCHRESULT,
__module__="catrank_pb2"
# @@protoc_insertion_point(class_scope:capnp.benchmark.protobuf.SearchResult)
),
)
_sym_db.RegisterMessage(SearchResult) _sym_db.RegisterMessage(SearchResult)

View file

@ -3,13 +3,15 @@
from common import rand_int, rand_double, rand_bool, WORDS, from_bytes_helper from common import rand_int, rand_double, rand_bool, WORDS, from_bytes_helper
from random import choice from random import choice
from string import ascii_letters from string import ascii_letters
try: try:
# Python 2 # Python 2
from itertools import izip from itertools import izip
except ImportError: except ImportError:
izip = zip izip = zip
import catrank_pb2 import catrank_pb2
class Benchmark: class Benchmark:
def __init__(self, compression): def __init__(self, compression):
self.Request = catrank_pb2.SearchResultList self.Request = catrank_pb2.SearchResultList
@ -26,7 +28,9 @@ class Benchmark:
result = request.result.add() result = request.result.add()
result.score = 1000 - i result.score = 1000 - i
url_size = rand_int(100) url_size = rand_int(100)
result.url = "http://example.com/" + ''.join([choice(ascii_letters) for _ in range(url_size)]) result.url = "http://example.com/" + "".join(
[choice(ascii_letters) for _ in range(url_size)]
)
isCat = rand_bool() isCat = rand_bool()
isDog = rand_bool() isDog = rand_bool()
@ -42,7 +46,7 @@ class Benchmark:
snippet += [choice(WORDS) for i in range(rand_int(20))] snippet += [choice(WORDS) for i in range(rand_int(20))]
result.snippet = ''.join(snippet) result.snippet = "".join(snippet)
return goodCount return goodCount
@ -60,10 +64,9 @@ class Benchmark:
resp.url = req.url resp.url = req.url
resp.snippet = req.snippet resp.snippet = req.snippet
def check(self, response, expected): def check(self, response, expected):
goodCount = 0 goodCount = 0
for result in response.result: for result in response.result:
if result.score > 1001: if result.score > 1001:
goodCount += 1 goodCount += 1

View file

@ -5,17 +5,19 @@ import catrank_capnp
from common import rand_int, rand_double, rand_bool, WORDS from common import rand_int, rand_double, rand_bool, WORDS
from random import choice from random import choice
from string import ascii_letters from string import ascii_letters
try: try:
# Python 2 # Python 2
from itertools import izip from itertools import izip
except ImportError: except ImportError:
izip = zip izip = zip
class Benchmark: class Benchmark:
def __init__(self, compression): def __init__(self, compression):
self.Request = catrank_capnp.SearchResultList.new_message self.Request = catrank_capnp.SearchResultList.new_message
self.Response = catrank_capnp.SearchResultList.new_message self.Response = catrank_capnp.SearchResultList.new_message
if compression == 'packed': if compression == "packed":
self.from_bytes_request = catrank_capnp.SearchResultList.from_bytes_packed self.from_bytes_request = catrank_capnp.SearchResultList.from_bytes_packed
self.from_bytes_response = catrank_capnp.SearchResultList.from_bytes_packed self.from_bytes_response = catrank_capnp.SearchResultList.from_bytes_packed
self.to_bytes = lambda x: x.to_bytes_packed() self.to_bytes = lambda x: x.to_bytes_packed()
@ -28,12 +30,14 @@ class Benchmark:
goodCount = 0 goodCount = 0
count = rand_int(1000) count = rand_int(1000)
results = request.init('results', count) results = request.init("results", count)
for i, result in enumerate(results): for i, result in enumerate(results):
result.score = 1000 - i result.score = 1000 - i
url_size = rand_int(100) url_size = rand_int(100)
result.url = "http://example.com/" + ''.join([choice(ascii_letters) for _ in range(url_size)]) result.url = "http://example.com/" + "".join(
[choice(ascii_letters) for _ in range(url_size)]
)
isCat = rand_bool() isCat = rand_bool()
isDog = rand_bool() isDog = rand_bool()
@ -49,12 +53,12 @@ class Benchmark:
snippet += [choice(WORDS) for i in range(rand_int(20))] snippet += [choice(WORDS) for i in range(rand_int(20))]
result.snippet = ''.join(snippet) result.snippet = "".join(snippet)
return goodCount return goodCount
def handle(self, request, response): def handle(self, request, response):
results = response.init('results', len(request.results)) results = response.init("results", len(request.results))
for req, resp in izip(request.results, results): for req, resp in izip(request.results, results):
score = req.score score = req.score
@ -68,7 +72,6 @@ class Benchmark:
resp.url = req.url resp.url = req.url
resp.snippet = req.snippet resp.snippet = req.snippet
def check(self, response, expected): def check(self, response, expected):
goodCount = 0 goodCount = 0

View file

@ -1,19 +1,37 @@
from random import random from random import random
import pyximport import pyximport
importers = pyximport.install() importers = pyximport.install()
from common_fast import rand_int, rand_double, rand_bool from common_fast import rand_int, rand_double, rand_bool
pyximport.uninstall(*importers) pyximport.uninstall(*importers)
WORDS = ["foo ", "bar ", "baz ", "qux ", "quux ", "corge ", "grault ", "garply ", "waldo ", "fred ", WORDS = [
"plugh ", "xyzzy ", "thud "] "foo ",
"bar ",
"baz ",
"qux ",
"quux ",
"corge ",
"grault ",
"garply ",
"waldo ",
"fred ",
"plugh ",
"xyzzy ",
"thud ",
]
def from_bytes_helper(klass): def from_bytes_helper(klass):
def helper(text): def helper(text):
obj = klass() obj = klass()
obj.ParseFromString(text) obj.ParseFromString(text)
return obj return obj
return helper return helper
def pass_by_object(reuse, iters, benchmark): def pass_by_object(reuse, iters, benchmark):
for _ in range(iters): for _ in range(iters):
request = benchmark.Request() request = benchmark.Request()
@ -23,7 +41,8 @@ def pass_by_object(reuse, iters, benchmark):
benchmark.handle(request, response) benchmark.handle(request, response)
if not benchmark.check(response, expected): if not benchmark.check(response, expected):
raise ValueError('Expected {}'.format(expected)) raise ValueError("Expected {}".format(expected))
def pass_by_bytes(reuse, iters, benchmark): def pass_by_bytes(reuse, iters, benchmark):
for _ in range(iters): for _ in range(iters):
@ -38,7 +57,8 @@ def pass_by_bytes(reuse, iters, benchmark):
response2 = benchmark.from_bytes_response(resp_bytes) response2 = benchmark.from_bytes_response(resp_bytes)
if not benchmark.check(response2, expected): if not benchmark.check(response2, expected):
raise ValueError('Expected {}'.format(expected)) raise ValueError("Expected {}".format(expected))
def do_benchmark(mode, *args, **kwargs): def do_benchmark(mode, *args, **kwargs):
if mode == "client": if mode == "client":
@ -49,6 +69,8 @@ def do_benchmark(mode, *args, **kwargs):
return pass_by_bytes(*args, **kwargs) return pass_by_bytes(*args, **kwargs)
else: else:
raise ValueError("Unknown mode: " + str(mode)) raise ValueError("Unknown mode: " + str(mode))
# typedef typename BenchmarkTypes::template BenchmarkMethods<TestCase, Reuse, Compression> # typedef typename BenchmarkTypes::template BenchmarkMethods<TestCase, Reuse, Compression>
# BenchmarkMethods; # BenchmarkMethods;
# if (mode == "client") { # if (mode == "client") {
@ -69,4 +91,4 @@ def do_benchmark(mode, *args, **kwargs):
# fprintf(stderr, "Unknown mode: %s\n", mode.c_str()); # fprintf(stderr, "Unknown mode: %s\n", mode.c_str());
# exit(1); # exit(1);
# } # }
# } # }

View file

@ -2,58 +2,55 @@
# source: eval.proto # source: eval.proto
import sys import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1"))
from google.protobuf.internal import enum_type_wrapper from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message from google.protobuf import message as _message
from google.protobuf import reflection as _reflection from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2 from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor( DESCRIPTOR = _descriptor.FileDescriptor(
name='eval.proto', name="eval.proto",
package='capnp.benchmark.protobuf', package="capnp.benchmark.protobuf",
syntax='proto2', syntax="proto2",
serialized_pb=_b('\n\neval.proto\x12\x18\x63\x61pnp.benchmark.protobuf\"\xe5\x01\n\nExpression\x12/\n\x02op\x18\x01 \x02(\x0e\x32#.capnp.benchmark.protobuf.Operation\x12\x12\n\nleft_value\x18\x02 \x01(\x05\x12=\n\x0fleft_expression\x18\x03 \x01(\x0b\x32$.capnp.benchmark.protobuf.Expression\x12\x13\n\x0bright_value\x18\x04 \x01(\x05\x12>\n\x10right_expression\x18\x05 \x01(\x0b\x32$.capnp.benchmark.protobuf.Expression\"!\n\x10\x45valuationResult\x12\r\n\x05value\x18\x01 \x02(\x11*I\n\tOperation\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\x0c\n\x08SUBTRACT\x10\x01\x12\x0c\n\x08MULTIPLY\x10\x02\x12\n\n\x06\x44IVIDE\x10\x03\x12\x0b\n\x07MODULUS\x10\x04') serialized_pb=_b(
'\n\neval.proto\x12\x18\x63\x61pnp.benchmark.protobuf"\xe5\x01\n\nExpression\x12/\n\x02op\x18\x01 \x02(\x0e\x32#.capnp.benchmark.protobuf.Operation\x12\x12\n\nleft_value\x18\x02 \x01(\x05\x12=\n\x0fleft_expression\x18\x03 \x01(\x0b\x32$.capnp.benchmark.protobuf.Expression\x12\x13\n\x0bright_value\x18\x04 \x01(\x05\x12>\n\x10right_expression\x18\x05 \x01(\x0b\x32$.capnp.benchmark.protobuf.Expression"!\n\x10\x45valuationResult\x12\r\n\x05value\x18\x01 \x02(\x11*I\n\tOperation\x12\x07\n\x03\x41\x44\x44\x10\x00\x12\x0c\n\x08SUBTRACT\x10\x01\x12\x0c\n\x08MULTIPLY\x10\x02\x12\n\n\x06\x44IVIDE\x10\x03\x12\x0b\n\x07MODULUS\x10\x04'
),
) )
_OPERATION = _descriptor.EnumDescriptor( _OPERATION = _descriptor.EnumDescriptor(
name='Operation', name="Operation",
full_name='capnp.benchmark.protobuf.Operation', full_name="capnp.benchmark.protobuf.Operation",
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
values=[ values=[
_descriptor.EnumValueDescriptor( _descriptor.EnumValueDescriptor(
name='ADD', index=0, number=0, name="ADD", index=0, number=0, options=None, type=None
options=None, ),
type=None), _descriptor.EnumValueDescriptor(
_descriptor.EnumValueDescriptor( name="SUBTRACT", index=1, number=1, options=None, type=None
name='SUBTRACT', index=1, number=1, ),
options=None, _descriptor.EnumValueDescriptor(
type=None), name="MULTIPLY", index=2, number=2, options=None, type=None
_descriptor.EnumValueDescriptor( ),
name='MULTIPLY', index=2, number=2, _descriptor.EnumValueDescriptor(
options=None, name="DIVIDE", index=3, number=3, options=None, type=None
type=None), ),
_descriptor.EnumValueDescriptor( _descriptor.EnumValueDescriptor(
name='DIVIDE', index=3, number=3, name="MODULUS", index=4, number=4, options=None, type=None
options=None, ),
type=None), ],
_descriptor.EnumValueDescriptor( containing_type=None,
name='MODULUS', index=4, number=4, options=None,
options=None, serialized_start=307,
type=None), serialized_end=380,
],
containing_type=None,
options=None,
serialized_start=307,
serialized_end=380,
) )
_sym_db.RegisterEnumDescriptor(_OPERATION) _sym_db.RegisterEnumDescriptor(_OPERATION)
@ -65,116 +62,177 @@ DIVIDE = 3
MODULUS = 4 MODULUS = 4
_EXPRESSION = _descriptor.Descriptor( _EXPRESSION = _descriptor.Descriptor(
name='Expression', name="Expression",
full_name='capnp.benchmark.protobuf.Expression', full_name="capnp.benchmark.protobuf.Expression",
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='op', full_name='capnp.benchmark.protobuf.Expression.op', index=0, name="op",
number=1, type=14, cpp_type=8, label=2, full_name="capnp.benchmark.protobuf.Expression.op",
has_default_value=False, default_value=0, index=0,
message_type=None, enum_type=None, containing_type=None, number=1,
is_extension=False, extension_scope=None, type=14,
options=None), cpp_type=8,
_descriptor.FieldDescriptor( label=2,
name='left_value', full_name='capnp.benchmark.protobuf.Expression.left_value', index=1, has_default_value=False,
number=2, type=5, cpp_type=1, label=1, default_value=0,
has_default_value=False, default_value=0, message_type=None,
message_type=None, enum_type=None, containing_type=None, enum_type=None,
is_extension=False, extension_scope=None, containing_type=None,
options=None), is_extension=False,
_descriptor.FieldDescriptor( extension_scope=None,
name='left_expression', full_name='capnp.benchmark.protobuf.Expression.left_expression', index=2, options=None,
number=3, type=11, cpp_type=10, label=1, ),
has_default_value=False, default_value=None, _descriptor.FieldDescriptor(
message_type=None, enum_type=None, containing_type=None, name="left_value",
is_extension=False, extension_scope=None, full_name="capnp.benchmark.protobuf.Expression.left_value",
options=None), index=1,
_descriptor.FieldDescriptor( number=2,
name='right_value', full_name='capnp.benchmark.protobuf.Expression.right_value', index=3, type=5,
number=4, type=5, cpp_type=1, label=1, cpp_type=1,
has_default_value=False, default_value=0, label=1,
message_type=None, enum_type=None, containing_type=None, has_default_value=False,
is_extension=False, extension_scope=None, default_value=0,
options=None), message_type=None,
_descriptor.FieldDescriptor( enum_type=None,
name='right_expression', full_name='capnp.benchmark.protobuf.Expression.right_expression', index=4, containing_type=None,
number=5, type=11, cpp_type=10, label=1, is_extension=False,
has_default_value=False, default_value=None, extension_scope=None,
message_type=None, enum_type=None, containing_type=None, options=None,
is_extension=False, extension_scope=None, ),
options=None), _descriptor.FieldDescriptor(
], name="left_expression",
extensions=[ full_name="capnp.benchmark.protobuf.Expression.left_expression",
], index=2,
nested_types=[], number=3,
enum_types=[ type=11,
], cpp_type=10,
options=None, label=1,
is_extendable=False, has_default_value=False,
syntax='proto2', default_value=None,
extension_ranges=[], message_type=None,
oneofs=[ enum_type=None,
], containing_type=None,
serialized_start=41, is_extension=False,
serialized_end=270, extension_scope=None,
options=None,
),
_descriptor.FieldDescriptor(
name="right_value",
full_name="capnp.benchmark.protobuf.Expression.right_value",
index=3,
number=4,
type=5,
cpp_type=1,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None,
),
_descriptor.FieldDescriptor(
name="right_expression",
full_name="capnp.benchmark.protobuf.Expression.right_expression",
index=4,
number=5,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None,
),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax="proto2",
extension_ranges=[],
oneofs=[],
serialized_start=41,
serialized_end=270,
) )
_EVALUATIONRESULT = _descriptor.Descriptor( _EVALUATIONRESULT = _descriptor.Descriptor(
name='EvaluationResult', name="EvaluationResult",
full_name='capnp.benchmark.protobuf.EvaluationResult', full_name="capnp.benchmark.protobuf.EvaluationResult",
filename=None, filename=None,
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
fields=[ fields=[
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='value', full_name='capnp.benchmark.protobuf.EvaluationResult.value', index=0, name="value",
number=1, type=17, cpp_type=1, label=2, full_name="capnp.benchmark.protobuf.EvaluationResult.value",
has_default_value=False, default_value=0, index=0,
message_type=None, enum_type=None, containing_type=None, number=1,
is_extension=False, extension_scope=None, type=17,
options=None), cpp_type=1,
], label=2,
extensions=[ has_default_value=False,
], default_value=0,
nested_types=[], message_type=None,
enum_types=[ enum_type=None,
], containing_type=None,
options=None, is_extension=False,
is_extendable=False, extension_scope=None,
syntax='proto2', options=None,
extension_ranges=[], ),
oneofs=[ ],
], extensions=[],
serialized_start=272, nested_types=[],
serialized_end=305, enum_types=[],
options=None,
is_extendable=False,
syntax="proto2",
extension_ranges=[],
oneofs=[],
serialized_start=272,
serialized_end=305,
) )
_EXPRESSION.fields_by_name['op'].enum_type = _OPERATION _EXPRESSION.fields_by_name["op"].enum_type = _OPERATION
_EXPRESSION.fields_by_name['left_expression'].message_type = _EXPRESSION _EXPRESSION.fields_by_name["left_expression"].message_type = _EXPRESSION
_EXPRESSION.fields_by_name['right_expression'].message_type = _EXPRESSION _EXPRESSION.fields_by_name["right_expression"].message_type = _EXPRESSION
DESCRIPTOR.message_types_by_name['Expression'] = _EXPRESSION DESCRIPTOR.message_types_by_name["Expression"] = _EXPRESSION
DESCRIPTOR.message_types_by_name['EvaluationResult'] = _EVALUATIONRESULT DESCRIPTOR.message_types_by_name["EvaluationResult"] = _EVALUATIONRESULT
DESCRIPTOR.enum_types_by_name['Operation'] = _OPERATION DESCRIPTOR.enum_types_by_name["Operation"] = _OPERATION
_sym_db.RegisterFileDescriptor(DESCRIPTOR) _sym_db.RegisterFileDescriptor(DESCRIPTOR)
Expression = _reflection.GeneratedProtocolMessageType('Expression', (_message.Message,), dict( Expression = _reflection.GeneratedProtocolMessageType(
DESCRIPTOR = _EXPRESSION, "Expression",
__module__ = 'eval_pb2' (_message.Message,),
# @@protoc_insertion_point(class_scope:capnp.benchmark.protobuf.Expression) dict(
)) DESCRIPTOR=_EXPRESSION,
__module__="eval_pb2"
# @@protoc_insertion_point(class_scope:capnp.benchmark.protobuf.Expression)
),
)
_sym_db.RegisterMessage(Expression) _sym_db.RegisterMessage(Expression)
EvaluationResult = _reflection.GeneratedProtocolMessageType('EvaluationResult', (_message.Message,), dict( EvaluationResult = _reflection.GeneratedProtocolMessageType(
DESCRIPTOR = _EVALUATIONRESULT, "EvaluationResult",
__module__ = 'eval_pb2' (_message.Message,),
# @@protoc_insertion_point(class_scope:capnp.benchmark.protobuf.EvaluationResult) dict(
)) DESCRIPTOR=_EVALUATIONRESULT,
__module__="eval_pb2"
# @@protoc_insertion_point(class_scope:capnp.benchmark.protobuf.EvaluationResult)
),
)
_sym_db.RegisterMessage(EvaluationResult) _sym_db.RegisterMessage(EvaluationResult)

View file

@ -4,15 +4,11 @@ from common import rand_int, rand_double, rand_bool, from_bytes_helper
from random import choice from random import choice
import eval_pb2 import eval_pb2
MAX_INT = 2**31 - 1 MAX_INT = 2 ** 31 - 1
MIN_INT = -(2**31) MIN_INT = -(2 ** 31)
OPERATIONS = ["add", "subtract", "multiply", "divide", "modulus"]
OPERATIONS = [
"add",
"subtract",
"multiply",
"divide",
"modulus"]
def clamp(res): def clamp(res):
if res > MAX_INT: if res > MAX_INT:
@ -22,6 +18,7 @@ def clamp(res):
else: else:
return res return res
def div(a, b): def div(a, b):
if b == 0: if b == 0:
return MAX_INT return MAX_INT
@ -30,6 +27,7 @@ def div(a, b):
return a // b return a // b
def mod(a, b): def mod(a, b):
if b == 0: if b == 0:
return MAX_INT return MAX_INT
@ -38,6 +36,7 @@ def mod(a, b):
return a % b return a % b
def make_expression(exp, depth): def make_expression(exp, depth):
exp.op = rand_int(len(OPERATIONS)) exp.op = rand_int(len(OPERATIONS))
@ -45,13 +44,13 @@ def make_expression(exp, depth):
left = rand_int(128) + 1 left = rand_int(128) + 1
exp.left_value = left exp.left_value = left
else: else:
left = make_expression(exp.left_expression, depth+1) left = make_expression(exp.left_expression, depth + 1)
if rand_int(8) < depth: if rand_int(8) < depth:
right = rand_int(128) + 1 right = rand_int(128) + 1
exp.right_value = right exp.right_value = right
else: else:
right = make_expression(exp.right_expression, depth+1) right = make_expression(exp.right_expression, depth + 1)
op = exp.op op = exp.op
if op == 0: if op == 0:
@ -66,21 +65,21 @@ def make_expression(exp, depth):
return mod(left, right) return mod(left, right)
raise RuntimeError("op wasn't a valid value: " + str(op)) raise RuntimeError("op wasn't a valid value: " + str(op))
def evaluate_expression(exp): def evaluate_expression(exp):
left = 0 left = 0
right = 0 right = 0
if exp.HasField('left_value'): if exp.HasField("left_value"):
left = exp.left_value left = exp.left_value
else: else:
left = evaluate_expression(exp.left_expression) left = evaluate_expression(exp.left_expression)
if exp.HasField('right_value'): if exp.HasField("right_value"):
right = exp.right_value right = exp.right_value
else: else:
right = evaluate_expression(exp.right_expression) right = evaluate_expression(exp.right_expression)
op = exp.op op = exp.op
if op == 0: if op == 0:
return clamp(left + right) return clamp(left + right)
@ -94,6 +93,7 @@ def evaluate_expression(exp):
return mod(left, right) return mod(left, right)
raise RuntimeError("op wasn't a valid value: " + str(op)) raise RuntimeError("op wasn't a valid value: " + str(op))
class Benchmark: class Benchmark:
def __init__(self, compression): def __init__(self, compression):
self.Request = eval_pb2.Expression self.Request = eval_pb2.Expression

View file

@ -5,15 +5,11 @@ import eval_capnp
from common import rand_int, rand_double, rand_bool from common import rand_int, rand_double, rand_bool
from random import choice from random import choice
MAX_INT = 2**31 - 1 MAX_INT = 2 ** 31 - 1
MIN_INT = -(2**31) MIN_INT = -(2 ** 31)
OPERATIONS = ["add", "subtract", "multiply", "divide", "modulus"]
OPERATIONS = [
"add",
"subtract",
"multiply",
"divide",
"modulus"]
def clamp(res): def clamp(res):
if res > MAX_INT: if res > MAX_INT:
@ -23,6 +19,7 @@ def clamp(res):
else: else:
return res return res
def div(a, b): def div(a, b):
if b == 0: if b == 0:
return MAX_INT return MAX_INT
@ -31,6 +28,7 @@ def div(a, b):
return a // b return a // b
def mod(a, b): def mod(a, b):
if b == 0: if b == 0:
return MAX_INT return MAX_INT
@ -39,6 +37,7 @@ def mod(a, b):
return a % b return a % b
def make_expression(exp, depth): def make_expression(exp, depth):
exp.op = choice(OPERATIONS) exp.op = choice(OPERATIONS)
@ -46,62 +45,63 @@ def make_expression(exp, depth):
left = rand_int(128) + 1 left = rand_int(128) + 1
exp.left.value = left exp.left.value = left
else: else:
left = make_expression(exp.left.init('expression'), depth+1) left = make_expression(exp.left.init("expression"), depth + 1)
if rand_int(8) < depth: if rand_int(8) < depth:
right = rand_int(128) + 1 right = rand_int(128) + 1
exp.right.value = right exp.right.value = right
else: else:
right = make_expression(exp.right.init('expression'), depth+1) right = make_expression(exp.right.init("expression"), depth + 1)
op = exp.op op = exp.op
if op == 'add': if op == "add":
return clamp(left + right) return clamp(left + right)
elif op == 'subtract': elif op == "subtract":
return clamp(left - right) return clamp(left - right)
elif op == 'multiply': elif op == "multiply":
return clamp(left * right) return clamp(left * right)
elif op == 'divide': elif op == "divide":
return div(left, right) return div(left, right)
elif op == 'modulus': elif op == "modulus":
return mod(left, right) return mod(left, right)
raise RuntimeError("op wasn't a valid value: " + str(op)) raise RuntimeError("op wasn't a valid value: " + str(op))
def evaluate_expression(exp): def evaluate_expression(exp):
left = 0 left = 0
right = 0 right = 0
which = exp.left.which() which = exp.left.which()
if which == 'value': if which == "value":
left = exp.left.value left = exp.left.value
elif which == 'expression': elif which == "expression":
left = evaluate_expression(exp.left.expression) left = evaluate_expression(exp.left.expression)
which = exp.right.which() which = exp.right.which()
if which == 'value': if which == "value":
right = exp.right.value right = exp.right.value
elif which == 'expression': elif which == "expression":
right = evaluate_expression(exp.right.expression) right = evaluate_expression(exp.right.expression)
op = exp.op op = exp.op
if op == 'add': if op == "add":
return clamp(left + right) return clamp(left + right)
elif op == 'subtract': elif op == "subtract":
return clamp(left - right) return clamp(left - right)
elif op == 'multiply': elif op == "multiply":
return clamp(left * right) return clamp(left * right)
elif op == 'divide': elif op == "divide":
return div(left, right) return div(left, right)
elif op == 'modulus': elif op == "modulus":
return mod(left, right) return mod(left, right)
raise RuntimeError("op wasn't a valid value: " + str(op)) raise RuntimeError("op wasn't a valid value: " + str(op))
class Benchmark: class Benchmark:
def __init__(self, compression): def __init__(self, compression):
self.Request = eval_capnp.Expression.new_message self.Request = eval_capnp.Expression.new_message
self.Response = eval_capnp.EvaluationResult.new_message self.Response = eval_capnp.EvaluationResult.new_message
if compression == 'packed': if compression == "packed":
self.from_bytes_request = eval_capnp.Expression.from_bytes_packed self.from_bytes_request = eval_capnp.Expression.from_bytes_packed
self.from_bytes_response = eval_capnp.EvaluationResult.from_bytes_packed self.from_bytes_response = eval_capnp.EvaluationResult.from_bytes_packed
self.to_bytes = lambda x: x.to_bytes_packed() self.to_bytes = lambda x: x.to_bytes_packed()
@ -117,4 +117,4 @@ class Benchmark:
response.value = evaluate_expression(request) response.value = evaluate_expression(request)
def check(self, response, expected): def check(self, response, expected):
return response.value == expected return response.value == expected

View file

@ -1,4 +1,4 @@
'Build the bundled capnp distribution' "Build the bundled capnp distribution"
import subprocess import subprocess
import os import os
@ -8,53 +8,53 @@ import sys
def build_libcapnp(bundle_dir, build_dir): # noqa: C901 def build_libcapnp(bundle_dir, build_dir): # noqa: C901
''' """
Build capnproto Build capnproto
''' """
bundle_dir = os.path.abspath(bundle_dir) bundle_dir = os.path.abspath(bundle_dir)
capnp_dir = os.path.join(bundle_dir, 'capnproto-c++') capnp_dir = os.path.join(bundle_dir, "capnproto-c++")
build_dir = os.path.abspath(build_dir) build_dir = os.path.abspath(build_dir)
tmp_dir = os.path.join(capnp_dir, 'build{}'.format(8 * struct.calcsize("P"))) tmp_dir = os.path.join(capnp_dir, "build{}".format(8 * struct.calcsize("P")))
# Clean the tmp build directory every time # Clean the tmp build directory every time
if os.path.exists(tmp_dir): if os.path.exists(tmp_dir):
shutil.rmtree(tmp_dir) shutil.rmtree(tmp_dir)
os.mkdir(tmp_dir) os.mkdir(tmp_dir)
cxxflags = os.environ.get('CXXFLAGS', None) cxxflags = os.environ.get("CXXFLAGS", None)
ldflags = os.environ.get('LDFLAGS', None) ldflags = os.environ.get("LDFLAGS", None)
os.environ['CXXFLAGS'] = (cxxflags or '') + ' -O2 -DNDEBUG' os.environ["CXXFLAGS"] = (cxxflags or "") + " -O2 -DNDEBUG"
os.environ['LDFLAGS'] = (ldflags or '') os.environ["LDFLAGS"] = ldflags or ""
# Enable ninja for compilation if available # Enable ninja for compilation if available
build_type = [] build_type = []
if shutil.which('ninja'): if shutil.which("ninja"):
build_type = ['-G', 'Ninja'] build_type = ["-G", "Ninja"]
# Determine python shell architecture for Windows # Determine python shell architecture for Windows
python_arch = 8 * struct.calcsize("P") python_arch = 8 * struct.calcsize("P")
build_arch = [] build_arch = []
build_flags = [] build_flags = []
if os.name == 'nt': if os.name == "nt":
if python_arch == 64: if python_arch == 64:
build_arch_flag = "x64" build_arch_flag = "x64"
elif python_arch == 32: elif python_arch == 32:
build_arch_flag = "Win32" build_arch_flag = "Win32"
else: else:
raise RuntimeError('Unknown windows build arch') raise RuntimeError("Unknown windows build arch")
build_arch = ['-A', build_arch_flag] build_arch = ["-A", build_arch_flag]
build_flags = ['--config', 'Release'] build_flags = ["--config", "Release"]
print('Building module for {}'.format(python_arch)) print("Building module for {}".format(python_arch))
if not shutil.which('cmake'): if not shutil.which("cmake"):
raise RuntimeError('Could not find cmake in your path!') raise RuntimeError("Could not find cmake in your path!")
args = [ args = [
'cmake', "cmake",
'-DCMAKE_POSITION_INDEPENDENT_CODE=1', "-DCMAKE_POSITION_INDEPENDENT_CODE=1",
'-DBUILD_TESTING=OFF', "-DBUILD_TESTING=OFF",
'-DBUILD_SHARED_LIBS=OFF', "-DBUILD_SHARED_LIBS=OFF",
'-DCMAKE_INSTALL_PREFIX:PATH={}'.format(build_dir), "-DCMAKE_INSTALL_PREFIX:PATH={}".format(build_dir),
capnp_dir, capnp_dir,
] ]
args.extend(build_type) args.extend(build_type)
@ -62,26 +62,26 @@ def build_libcapnp(bundle_dir, build_dir): # noqa: C901
conf = subprocess.Popen(args, cwd=tmp_dir, stdout=sys.stdout) conf = subprocess.Popen(args, cwd=tmp_dir, stdout=sys.stdout)
returncode = conf.wait() returncode = conf.wait()
if returncode != 0: if returncode != 0:
raise RuntimeError('CMake failed {}'.format(returncode)) raise RuntimeError("CMake failed {}".format(returncode))
# Run build through cmake # Run build through cmake
args = [ args = [
'cmake', "cmake",
'--build', "--build",
'.', ".",
'--target', "--target",
'install', "install",
] ]
args.extend(build_flags) args.extend(build_flags)
build = subprocess.Popen(args, cwd=tmp_dir, stdout=sys.stdout) build = subprocess.Popen(args, cwd=tmp_dir, stdout=sys.stdout)
returncode = build.wait() returncode = build.wait()
if cxxflags is None: if cxxflags is None:
del os.environ['CXXFLAGS'] del os.environ["CXXFLAGS"]
else: else:
os.environ['CXXFLAGS'] = cxxflags os.environ["CXXFLAGS"] = cxxflags
if ldflags is None: if ldflags is None:
del os.environ['LDFLAGS'] del os.environ["LDFLAGS"]
else: else:
os.environ['LDFLAGS'] = ldflags os.environ["LDFLAGS"] = ldflags
if returncode != 0: if returncode != 0:
raise RuntimeError('capnproto compilation failed: {}'.format(returncode)) raise RuntimeError("capnproto compilation failed: {}".format(returncode))

View file

@ -41,7 +41,7 @@ ROOT = os.path.dirname(HERE)
def untgz(archive): def untgz(archive):
"""Remove .tar.gz""" """Remove .tar.gz"""
return archive.replace('.tar.gz', '') return archive.replace(".tar.gz", "")
def localpath(*args): def localpath(*args):
@ -60,7 +60,7 @@ def fetch_archive(savedir, url, fname, force=False):
if not os.path.exists(savedir): if not os.path.exists(savedir):
os.makedirs(savedir) os.makedirs(savedir)
req = urlopen(url) req = urlopen(url)
with open(dest, 'wb') as f: with open(dest, "wb") as f:
f.write(req.read()) f.write(req.read())
return dest return dest
@ -76,7 +76,7 @@ def fetch_libcapnp(savedir, url=None):
if url is None: if url is None:
url = libcapnp_url url = libcapnp_url
is_preconfigured = True is_preconfigured = True
dest = pjoin(savedir, 'capnproto-c++') dest = pjoin(savedir, "capnproto-c++")
if os.path.exists(dest): if os.path.exists(dest):
print("already have %s" % dest) print("already have %s" % dest)
return return
@ -89,5 +89,5 @@ def fetch_libcapnp(savedir, url=None):
if is_preconfigured: if is_preconfigured:
shutil.move(with_version, dest) shutil.move(with_version, dest)
else: else:
cpp_dir = os.path.join(with_version, 'c++') cpp_dir = os.path.join(with_version, "c++")
shutil.move(cpp_dir, dest) shutil.move(cpp_dir, dest)

View file

@ -8,59 +8,73 @@ import schema_capnp
def find_type(code, id): def find_type(code, id):
for node in code['nodes']: for node in code["nodes"]:
if node['id'] == id: if node["id"] == id:
return node return node
return None return None
def main(): def main():
env = Environment(loader=PackageLoader('capnp', 'templates')) env = Environment(loader=PackageLoader("capnp", "templates"))
env.filters['format_name'] = lambda name: name[name.find(':') + 1:] env.filters["format_name"] = lambda name: name[name.find(":") + 1 :]
code = schema_capnp.CodeGeneratorRequest.read(sys.stdin) code = schema_capnp.CodeGeneratorRequest.read(sys.stdin)
code = code.to_dict() code = code.to_dict()
code['nodes'] = [node for node in code['nodes'] if 'struct' in node and node['scopeId'] != 0] code["nodes"] = [
for node in code['nodes']: node for node in code["nodes"] if "struct" in node and node["scopeId"] != 0
displayName = node['displayName'] ]
parent, path = displayName.split(':') for node in code["nodes"]:
node['module_path'] = parent.replace('.', '_') + '.' + '.'.join([x[0].upper() + x[1:] for x in path.split('.')]) displayName = node["displayName"]
node['module_name'] = path.replace('.', '_') parent, path = displayName.split(":")
node['c_module_path'] = '::'.join([x[0].upper() + x[1:] for x in path.split('.')]) node["module_path"] = (
node['schema'] = '_{}_Schema'.format(node['module_name']) parent.replace(".", "_")
+ "."
+ ".".join([x[0].upper() + x[1:] for x in path.split(".")])
)
node["module_name"] = path.replace(".", "_")
node["c_module_path"] = "::".join(
[x[0].upper() + x[1:] for x in path.split(".")]
)
node["schema"] = "_{}_Schema".format(node["module_name"])
is_union = False is_union = False
for field in node['struct']['fields']: for field in node["struct"]["fields"]:
if field['discriminantValue'] != 65535: if field["discriminantValue"] != 65535:
is_union = True is_union = True
field['c_name'] = field['name'][0].upper() + field['name'][1:] field["c_name"] = field["name"][0].upper() + field["name"][1:]
if 'slot' in field: if "slot" in field:
field['type'] = list(field['slot']['type'].keys())[0] field["type"] = list(field["slot"]["type"].keys())[0]
if not isinstance(field['slot']['type'][field['type']], dict): if not isinstance(field["slot"]["type"][field["type"]], dict):
continue continue
sub_type = field['slot']['type'][field['type']].get('typeId', None) sub_type = field["slot"]["type"][field["type"]].get("typeId", None)
if sub_type: if sub_type:
field['sub_type'] = find_type(code, sub_type) field["sub_type"] = find_type(code, sub_type)
sub_type = field['slot']['type'][field['type']].get('elementType', None) sub_type = field["slot"]["type"][field["type"]].get("elementType", None)
if sub_type: if sub_type:
field['sub_type'] = sub_type field["sub_type"] = sub_type
else: else:
field['type'] = find_type(code, field['group']['typeId']) field["type"] = find_type(code, field["group"]["typeId"])
node['is_union'] = is_union node["is_union"] = is_union
include_dir = os.path.abspath(os.path.join(os.path.dirname(capnp.__file__), '..')) include_dir = os.path.abspath(os.path.join(os.path.dirname(capnp.__file__), ".."))
module = env.get_template('module.pyx') module = env.get_template("module.pyx")
for f in code['requestedFiles']: for f in code["requestedFiles"]:
filename = f['filename'].replace('.', '_') + '_cython.pyx' filename = f["filename"].replace(".", "_") + "_cython.pyx"
file_code = dict(code) file_code = dict(code)
file_code['nodes'] = [node for node in file_code['nodes'] if node['displayName'].startswith(f['filename'])] file_code["nodes"] = [
with open(filename, 'w') as out: node
for node in file_code["nodes"]
if node["displayName"].startswith(f["filename"])
]
with open(filename, "w") as out:
out.write(module.render(code=file_code, file=f, include_dir=include_dir)) out.write(module.render(code=file_code, file=f, include_dir=include_dir))
setup = env.get_template('setup.py.tmpl') setup = env.get_template("setup.py.tmpl")
with open('setup_capnp.py', 'w') as out: with open("setup_capnp.py", "w") as out:
out.write(setup.render(code=code)) out.write(setup.render(code=code))
print('You now need to build the cython module by running `python setup_capnp.py build_ext --inplace`.') print(
"You now need to build the cython module by running `python setup_capnp.py build_ext --inplace`."
)
print() print()

View file

@ -1,6 +1,6 @@
''' """
Docs configuration Docs configuration
''' """
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
# capnp documentation build configuration file, created by # capnp documentation build configuration file, created by
@ -15,6 +15,7 @@ Docs configuration
# serve to show the default. # serve to show the default.
import string import string
# import sys, os # import sys, os
import capnp import capnp
@ -31,27 +32,27 @@ import capnp
# Add any Sphinx extension module names here, as strings. They can be extensions # Add any Sphinx extension module names here, as strings. They can be extensions
# coming with Sphinx (named 'sphinx.ext.*') or your custom ones. # coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
extensions = [ extensions = [
'sphinx.ext.autodoc', "sphinx.ext.autodoc",
'sphinx.ext.viewcode', "sphinx.ext.viewcode",
'sphinx.ext.intersphinx', "sphinx.ext.intersphinx",
'sphinx_multiversion', "sphinx_multiversion",
] ]
# Add any paths that contain templates here, relative to this directory. # Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates'] templates_path = ["_templates"]
# The suffix of source filenames. # The suffix of source filenames.
source_suffix = '.rst' source_suffix = ".rst"
# The encoding of source files. # The encoding of source files.
# source_encoding = 'utf-8-sig' # source_encoding = 'utf-8-sig'
# The master toctree document. # The master toctree document.
master_doc = 'index' master_doc = "index"
# General information about the project. # General information about the project.
project = u'capnp' project = u"capnp"
copyright = u'2013-2019 (Jason Paryani), 2019-2020 (Jacob Alexander)' copyright = u"2013-2019 (Jason Paryani), 2019-2020 (Jacob Alexander)"
# The version info for the project you're documenting, acts as replacement for # The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the # |version| and |release|, also used in various other places throughout the
@ -77,7 +78,7 @@ release = vs
# List of patterns, relative to source directory, that match files and # List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files. # directories to ignore when looking for source files.
exclude_patterns = ['_build'] exclude_patterns = ["_build"]
# The reST default role (used for this markup: `text`) to use for all documents. # The reST default role (used for this markup: `text`) to use for all documents.
# default_role = None # default_role = None
@ -94,7 +95,7 @@ exclude_patterns = ['_build']
# show_authors = False # show_authors = False
# The name of the Pygments (syntax highlighting) style to use. # The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx' pygments_style = "sphinx"
# A list of ignored prefixes for module index sorting. # A list of ignored prefixes for module index sorting.
# modindex_common_prefix = [] # modindex_common_prefix = []
@ -104,7 +105,7 @@ pygments_style = 'sphinx'
# The theme to use for HTML and HTML Help pages. See the documentation for # The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes. # a list of builtin themes.
html_theme = 'nature' html_theme = "nature"
# Theme options are theme-specific and customize the look and feel of a theme # Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the # further. For a list of options available for each theme, see the
@ -144,7 +145,15 @@ html_theme = 'nature'
# html_use_smartypants = True # html_use_smartypants = True
# Custom sidebar templates, maps document names to template names. # Custom sidebar templates, maps document names to template names.
html_sidebars = {'**': ['globaltoc.html', 'relations.html', 'sourcelink.html', 'searchbox.html', 'versioning.html']} html_sidebars = {
"**": [
"globaltoc.html",
"relations.html",
"sourcelink.html",
"searchbox.html",
"versioning.html",
]
}
# Additional templates that should be rendered to pages, maps page names to # Additional templates that should be rendered to pages, maps page names to
# template names. # template names.
@ -177,7 +186,7 @@ html_sidebars = {'**': ['globaltoc.html', 'relations.html', 'sourcelink.html', '
# html_file_suffix = None # html_file_suffix = None
# Output file base name for HTML help builder. # Output file base name for HTML help builder.
htmlhelp_basename = 'capnpdoc' htmlhelp_basename = "capnpdoc"
# -- Options for LaTeX output -------------------------------------------------- # -- Options for LaTeX output --------------------------------------------------
@ -190,14 +199,12 @@ htmlhelp_basename = 'capnpdoc'
# Additional stuff for the LaTeX preamble. # Additional stuff for the LaTeX preamble.
# 'preamble': '', # 'preamble': '',
latex_elements = { latex_elements = {}
}
# Grouping the document tree into LaTeX files. List of tuples # Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title, author, documentclass [howto/manual]). # (source start file, target name, title, author, documentclass [howto/manual]).
latex_documents = [ latex_documents = [
('index', 'capnp.tex', u'capnp Documentation', ("index", "capnp.tex", u"capnp Documentation", u"Author", "manual"),
u'Author', 'manual'),
] ]
# The name of an image file (relative to this directory) to place at the top of # The name of an image file (relative to this directory) to place at the top of
@ -225,10 +232,7 @@ latex_documents = [
# One entry per manual page. List of tuples # One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section). # (source start file, name, description, authors, manual section).
man_pages = [ man_pages = [("index", "capnp", u"capnp Documentation", [u"Author"], 1)]
('index', 'capnp', u'capnp Documentation',
[u'Author'], 1)
]
# If true, show URL addresses after external links. # If true, show URL addresses after external links.
# man_show_urls = False # man_show_urls = False
@ -240,9 +244,15 @@ man_pages = [
# (source start file, target name, title, author, # (source start file, target name, title, author,
# dir menu entry, description, category) # dir menu entry, description, category)
texinfo_documents = [ texinfo_documents = [
('index', 'capnp', u'capnp Documentation', (
u'Author', 'capnp', 'One line description of project.', "index",
'Miscellaneous'), "capnp",
u"capnp Documentation",
u"Author",
"capnp",
"One line description of project.",
"Miscellaneous",
),
] ]
# Documents to append as an appendix to all manuals. # Documents to append as an appendix to all manuals.
@ -258,10 +268,10 @@ texinfo_documents = [
# -- Options for Epub output --------------------------------------------------- # -- Options for Epub output ---------------------------------------------------
# Bibliographic Dublin Core info. # Bibliographic Dublin Core info.
epub_title = u'capnp' epub_title = u"capnp"
epub_author = u'Author' epub_author = u"Author"
epub_publisher = u'Author' epub_publisher = u"Author"
epub_copyright = u'2013, Author' epub_copyright = u"2013, Author"
# The language of the text. It defaults to the language option # The language of the text. It defaults to the language option
# or en if the language is not set. # or en if the language is not set.
@ -297,6 +307,6 @@ epub_copyright = u'2013, Author'
# Allow duplicate toc entries. # Allow duplicate toc entries.
# epub_tocdup = True # epub_tocdup = True
intersphinx_mapping = {'http://docs.python.org/': None} intersphinx_mapping = {"http://docs.python.org/": None}
smv_branch_whitelist = r'^master$' smv_branch_whitelist = r"^master$"

View file

@ -7,26 +7,26 @@ import addressbook_capnp
def writeAddressBook(file): def writeAddressBook(file):
addresses = addressbook_capnp.AddressBook.new_message() addresses = addressbook_capnp.AddressBook.new_message()
people = addresses.init('people', 2) people = addresses.init("people", 2)
alice = people[0] alice = people[0]
alice.id = 123 alice.id = 123
alice.name = 'Alice' alice.name = "Alice"
alice.email = 'alice@example.com' alice.email = "alice@example.com"
alicePhones = alice.init('phones', 1) alicePhones = alice.init("phones", 1)
alicePhones[0].number = "555-1212" alicePhones[0].number = "555-1212"
alicePhones[0].type = 'mobile' alicePhones[0].type = "mobile"
alice.employment.school = "MIT" alice.employment.school = "MIT"
bob = people[1] bob = people[1]
bob.id = 456 bob.id = 456
bob.name = 'Bob' bob.name = "Bob"
bob.email = 'bob@example.com' bob.email = "bob@example.com"
bobPhones = bob.init('phones', 2) bobPhones = bob.init("phones", 2)
bobPhones[0].number = "555-4567" bobPhones[0].number = "555-4567"
bobPhones[0].type = 'home' bobPhones[0].type = "home"
bobPhones[1].number = "555-7654" bobPhones[1].number = "555-7654"
bobPhones[1].type = 'work' bobPhones[1].type = "work"
bob.employment.unemployed = None bob.employment.unemployed = None
addresses.write(file) addresses.write(file)
@ -36,27 +36,27 @@ def printAddressBook(file):
addresses = addressbook_capnp.AddressBook.read(file) addresses = addressbook_capnp.AddressBook.read(file)
for person in addresses.people: for person in addresses.people:
print(person.name, ':', person.email) print(person.name, ":", person.email)
for phone in person.phones: for phone in person.phones:
print(phone.type, ':', phone.number) print(phone.type, ":", phone.number)
which = person.employment.which() which = person.employment.which()
print(which) print(which)
if which == 'unemployed': if which == "unemployed":
print('unemployed') print("unemployed")
elif which == 'employer': elif which == "employer":
print('employer:', person.employment.employer) print("employer:", person.employment.employer)
elif which == 'school': elif which == "school":
print('student at:', person.employment.school) print("student at:", person.employment.school)
elif which == 'selfEmployed': elif which == "selfEmployed":
print('self employed') print("self employed")
print() print()
if __name__ == '__main__': if __name__ == "__main__":
f = open('example', 'w') f = open("example", "w")
writeAddressBook(f) writeAddressBook(f)
f = open('example', 'r') f = open("example", "r")
printAddressBook(f) printAddressBook(f)

View file

@ -10,16 +10,16 @@ import calculator_capnp
class PowerFunction(calculator_capnp.Calculator.Function.Server): class PowerFunction(calculator_capnp.Calculator.Function.Server):
'''An implementation of the Function interface wrapping pow(). Note that """An implementation of the Function interface wrapping pow(). Note that
we're implementing this on the client side and will pass a reference to we're implementing this on the client side and will pass a reference to
the server. The server will then be able to make calls back to the client.''' the server. The server will then be able to make calls back to the client."""
def call(self, params, **kwargs): def call(self, params, **kwargs):
'''Note the **kwargs. This is very necessary to include, since """Note the **kwargs. This is very necessary to include, since
protocols can add parameters over time. Also, by default, a _context protocols can add parameters over time. Also, by default, a _context
variable is passed to all server methods, but you can also return variable is passed to all server methods, but you can also return
results directly as python objects, and they'll be added to the results directly as python objects, and they'll be added to the
results struct in the correct order''' results struct in the correct order"""
return pow(params[0], params[1]) return pow(params[0], params[1])
@ -38,29 +38,29 @@ async def mywriter(client, writer):
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(usage='Connects to the Calculator server \ parser = argparse.ArgumentParser(
at the given address and does some RPCs') usage="Connects to the Calculator server \
at the given address and does some RPCs"
)
parser.add_argument("host", help="HOST:PORT") parser.add_argument("host", help="HOST:PORT")
return parser.parse_args() return parser.parse_args()
async def main(host): async def main(host):
host = host.split(':') host = host.split(":")
addr = host[0] addr = host[0]
port = host[1] port = host[1]
# Handle both IPv4 and IPv6 cases # Handle both IPv4 and IPv6 cases
try: try:
print("Try IPv4") print("Try IPv4")
reader, writer = await asyncio.open_connection( reader, writer = await asyncio.open_connection(
addr, port, addr, port, family=socket.AF_INET
family=socket.AF_INET
) )
except Exception: except Exception:
print("Try IPv6") print("Try IPv6")
reader, writer = await asyncio.open_connection( reader, writer = await asyncio.open_connection(
addr, port, addr, port, family=socket.AF_INET6
family=socket.AF_INET6
) )
# Start TwoPartyClient using TwoWayPipe (takes no arguments in this mode) # Start TwoPartyClient using TwoWayPipe (takes no arguments in this mode)
@ -73,7 +73,7 @@ async def main(host):
# Bootstrap the Calculator interface # Bootstrap the Calculator interface
calculator = client.bootstrap().cast_as(calculator_capnp.Calculator) calculator = client.bootstrap().cast_as(calculator_capnp.Calculator)
'''Make a request that just evaluates the literal value 123. """Make a request that just evaluates the literal value 123.
What's interesting here is that evaluate() returns a "Value", which is What's interesting here is that evaluate() returns a "Value", which is
another interface and therefore points back to an object living on the another interface and therefore points back to an object living on the
@ -81,9 +81,9 @@ async def main(host):
However, even though we are making two RPC's, this block executes in However, even though we are making two RPC's, this block executes in
*one* network round trip because of promise pipelining: we do not wait *one* network round trip because of promise pipelining: we do not wait
for the first call to complete before we send the second call to the for the first call to complete before we send the second call to the
server.''' server."""
print('Evaluating a literal... ', end="") print("Evaluating a literal... ", end="")
# Make the request. Note we are using the shorter function form (instead # Make the request. Note we are using the shorter function form (instead
# of evaluate_request), and we are passing a dictionary that represents a # of evaluate_request), and we are passing a dictionary that represents a
@ -91,13 +91,13 @@ async def main(host):
eval_promise = calculator.evaluate({"literal": 123}) eval_promise = calculator.evaluate({"literal": 123})
# This is equivalent to: # This is equivalent to:
''' """
request = calculator.evaluate_request() request = calculator.evaluate_request()
request.expression.literal = 123 request.expression.literal = 123
# Send it, which returns a promise for the result (without blocking). # Send it, which returns a promise for the result (without blocking).
eval_promise = request.send() eval_promise = request.send()
''' """
# Using the promise, create a pipelined request to call read() on the # Using the promise, create a pipelined request to call read() on the
# returned object. Note that here we are using the shortened method call # returned object. Note that here we are using the shortened method call
@ -111,32 +111,32 @@ async def main(host):
print("PASS") print("PASS")
'''Make a request to evaluate 123 + 45 - 67. """Make a request to evaluate 123 + 45 - 67.
The Calculator interface requires that we first call getOperator() to The Calculator interface requires that we first call getOperator() to
get the addition and subtraction functions, then call evaluate() to use get the addition and subtraction functions, then call evaluate() to use
them. But, once again, we can get both functions, call evaluate(), and them. But, once again, we can get both functions, call evaluate(), and
then read() the result -- four RPCs -- in the time of *one* network then read() the result -- four RPCs -- in the time of *one* network
round trip, because of promise pipelining.''' round trip, because of promise pipelining."""
print("Using add and subtract... ", end='') print("Using add and subtract... ", end="")
# Get the "add" function from the server. # Get the "add" function from the server.
add = calculator.getOperator(op='add').func add = calculator.getOperator(op="add").func
# Get the "subtract" function from the server. # Get the "subtract" function from the server.
subtract = calculator.getOperator(op='subtract').func subtract = calculator.getOperator(op="subtract").func
# Build the request to evaluate 123 + 45 - 67. Note the form is 'evaluate' # Build the request to evaluate 123 + 45 - 67. Note the form is 'evaluate'
# + '_request', where 'evaluate' is the name of the method we want to call # + '_request', where 'evaluate' is the name of the method we want to call
request = calculator.evaluate_request() request = calculator.evaluate_request()
subtract_call = request.expression.init('call') subtract_call = request.expression.init("call")
subtract_call.function = subtract subtract_call.function = subtract
subtract_params = subtract_call.init('params', 2) subtract_params = subtract_call.init("params", 2)
subtract_params[1].literal = 67.0 subtract_params[1].literal = 67.0
add_call = subtract_params[0].init('call') add_call = subtract_params[0].init("call")
add_call.function = add add_call.function = add
add_params = add_call.init('params', 2) add_params = add_call.init("params", 2)
add_params[0].literal = 123 add_params[0].literal = 123
add_params[1].literal = 45 add_params[1].literal = 45
@ -149,7 +149,7 @@ async def main(host):
print("PASS") print("PASS")
''' """
Note: a one liner version of building the previous request (I highly Note: a one liner version of building the previous request (I highly
recommend not doing it this way for such a complicated structure, but I recommend not doing it this way for such a complicated structure, but I
just wanted to demonstrate it is possible to set all of the fields with a just wanted to demonstrate it is possible to set all of the fields with a
@ -161,22 +161,22 @@ async def main(host):
'params': [{'literal': 123}, 'params': [{'literal': 123},
{'literal': 45}]}}, {'literal': 45}]}},
{'literal': 67.0}]}}) {'literal': 67.0}]}})
''' """
'''Make a request to evaluate 4 * 6, then use the result in two more """Make a request to evaluate 4 * 6, then use the result in two more
requests that add 3 and 5. requests that add 3 and 5.
Since evaluate() returns its result wrapped in a `Value`, we can pass Since evaluate() returns its result wrapped in a `Value`, we can pass
that `Value` back to the server in subsequent requests before the first that `Value` back to the server in subsequent requests before the first
`evaluate()` has actually returned. Thus, this example again does only `evaluate()` has actually returned. Thus, this example again does only
one network round trip.''' one network round trip."""
print("Pipelining eval() calls... ", end="") print("Pipelining eval() calls... ", end="")
# Get the "add" function from the server. # Get the "add" function from the server.
add = calculator.getOperator(op='add').func add = calculator.getOperator(op="add").func
# Get the "multiply" function from the server. # Get the "multiply" function from the server.
multiply = calculator.getOperator(op='multiply').func multiply = calculator.getOperator(op="multiply").func
# Build the request to evaluate 4 * 6 # Build the request to evaluate 4 * 6
request = calculator.evaluate_request() request = calculator.evaluate_request()
@ -213,7 +213,7 @@ async def main(host):
print("PASS") print("PASS")
'''Our calculator interface supports defining functions. Here we use it """Our calculator interface supports defining functions. Here we use it
to define two functions and then make calls to them as follows: to define two functions and then make calls to them as follows:
f(x, y) = x * 100 + y f(x, y) = x * 100 + y
@ -221,14 +221,14 @@ async def main(host):
f(12, 34) f(12, 34)
g(21) g(21)
Once again, the whole thing takes only one network round trip.''' Once again, the whole thing takes only one network round trip."""
print("Defining functions... ", end="") print("Defining functions... ", end="")
# Get the "add" function from the server. # Get the "add" function from the server.
add = calculator.getOperator(op='add').func add = calculator.getOperator(op="add").func
# Get the "multiply" function from the server. # Get the "multiply" function from the server.
multiply = calculator.getOperator(op='multiply').func multiply = calculator.getOperator(op="multiply").func
# Define f. # Define f.
request = calculator.defFunction_request() request = calculator.defFunction_request()
@ -286,7 +286,7 @@ async def main(host):
g_eval_request = calculator.evaluate_request() g_eval_request = calculator.evaluate_request()
g_call = g_eval_request.expression.init("call") g_call = g_eval_request.expression.init("call")
g_call.function = g g_call.function = g
g_call.init('params', 1)[0].literal = 21 g_call.init("params", 1)[0].literal = 21
g_eval_promise = g_eval_request.send().value.read() g_eval_promise = g_eval_request.send().value.read()
# Wait for the results. # Wait for the results.
@ -295,7 +295,7 @@ async def main(host):
print("PASS") print("PASS")
'''Make a request that will call back to a function defined locally. """Make a request that will call back to a function defined locally.
Specifically, we will compute 2^(4 + 5). However, exponent is not Specifically, we will compute 2^(4 + 5). However, exponent is not
defined by the Calculator server. So, we'll implement the Function defined by the Calculator server. So, we'll implement the Function
@ -307,12 +307,12 @@ async def main(host):
particular case, this could potentially be optimized by using a tail particular case, this could potentially be optimized by using a tail
call on the server side -- see CallContext::tailCall(). However, to call on the server side -- see CallContext::tailCall(). However, to
keep the example simpler, we haven't implemented this optimization in keep the example simpler, we haven't implemented this optimization in
the sample server.''' the sample server."""
print("Using a callback... ", end="") print("Using a callback... ", end="")
# Get the "add" function from the server. # Get the "add" function from the server.
add = calculator.getOperator(op='add').func add = calculator.getOperator(op="add").func
# Build the eval request for 2^(4+5). # Build the eval request for 2^(4+5).
request = calculator.evaluate_request() request = calculator.evaluate_request()
@ -334,5 +334,6 @@ async def main(host):
print("PASS") print("PASS")
if __name__ == '__main__':
if __name__ == "__main__":
asyncio.run(main(parse_args().host)) asyncio.run(main(parse_args().host))

View file

@ -18,10 +18,7 @@ class Server:
while self.retry: while self.retry:
try: try:
# Must be a wait_for so we don't block on read() # Must be a wait_for so we don't block on read()
data = await asyncio.wait_for( data = await asyncio.wait_for(self.reader.read(4096), timeout=0.1)
self.reader.read(4096),
timeout=0.1
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.debug("myreader timeout.") logger.debug("myreader timeout.")
continue continue
@ -36,10 +33,7 @@ class Server:
while self.retry: while self.retry:
try: try:
# Must be a wait_for so we don't block on read() # Must be a wait_for so we don't block on read()
data = await asyncio.wait_for( data = await asyncio.wait_for(self.server.read(4096), timeout=0.1)
self.server.read(4096),
timeout=0.1
)
self.writer.write(data.tobytes()) self.writer.write(data.tobytes())
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.debug("mywriter timeout.") logger.debug("mywriter timeout.")
@ -74,29 +68,29 @@ class Server:
def read_value(value): def read_value(value):
'''Helper function to asynchronously call read() on a Calculator::Value and """Helper function to asynchronously call read() on a Calculator::Value and
return a promise for the result. (In the future, the generated code might return a promise for the result. (In the future, the generated code might
include something like this automatically.)''' include something like this automatically.)"""
return value.read().then(lambda result: result.value) return value.read().then(lambda result: result.value)
def evaluate_impl(expression, params=None): def evaluate_impl(expression, params=None):
'''Implementation of CalculatorImpl::evaluate(), also shared by """Implementation of CalculatorImpl::evaluate(), also shared by
FunctionImpl::call(). In the latter case, `params` are the parameter FunctionImpl::call(). In the latter case, `params` are the parameter
values passed to the function; in the former case, `params` is just an values passed to the function; in the former case, `params` is just an
empty list.''' empty list."""
which = expression.which() which = expression.which()
if which == 'literal': if which == "literal":
return capnp.Promise(expression.literal) return capnp.Promise(expression.literal)
elif which == 'previousResult': elif which == "previousResult":
return read_value(expression.previousResult) return read_value(expression.previousResult)
elif which == 'parameter': elif which == "parameter":
assert expression.parameter < len(params) assert expression.parameter < len(params)
return capnp.Promise(params[expression.parameter]) return capnp.Promise(params[expression.parameter])
elif which == 'call': elif which == "call":
call = expression.call call = expression.call
func = call.function func = call.function
@ -105,9 +99,9 @@ def evaluate_impl(expression, params=None):
joinedParams = capnp.join_promises(paramPromises) joinedParams = capnp.join_promises(paramPromises)
# When the parameters are complete, call the function. # When the parameters are complete, call the function.
ret = (joinedParams ret = joinedParams.then(lambda vals: func.call(vals)).then(
.then(lambda vals: func.call(vals)) lambda result: result.value
.then(lambda result: result.value)) )
return ret return ret
else: else:
@ -127,28 +121,30 @@ class ValueImpl(calculator_capnp.Calculator.Value.Server):
class FunctionImpl(calculator_capnp.Calculator.Function.Server): class FunctionImpl(calculator_capnp.Calculator.Function.Server):
'''Implementation of the Calculator.Function Cap'n Proto interface, where the """Implementation of the Calculator.Function Cap'n Proto interface, where the
function is defined by a Calculator.Expression.''' function is defined by a Calculator.Expression."""
def __init__(self, paramCount, body): def __init__(self, paramCount, body):
self.paramCount = paramCount self.paramCount = paramCount
self.body = body.as_builder() self.body = body.as_builder()
def call(self, params, _context, **kwargs): def call(self, params, _context, **kwargs):
'''Note that we're returning a Promise object here, and bypassing the """Note that we're returning a Promise object here, and bypassing the
helper functionality that normally sets the results struct from the helper functionality that normally sets the results struct from the
returned object. Instead, we set _context.results directly inside of returned object. Instead, we set _context.results directly inside of
another promise''' another promise"""
assert len(params) == self.paramCount assert len(params) == self.paramCount
# using setattr because '=' is not allowed inside of lambdas # using setattr because '=' is not allowed inside of lambdas
return evaluate_impl(self.body, params).then(lambda value: setattr(_context.results, 'value', value)) return evaluate_impl(self.body, params).then(
lambda value: setattr(_context.results, "value", value)
)
class OperatorImpl(calculator_capnp.Calculator.Function.Server): class OperatorImpl(calculator_capnp.Calculator.Function.Server):
'''Implementation of the Calculator.Function Cap'n Proto interface, wrapping """Implementation of the Calculator.Function Cap'n Proto interface, wrapping
basic binary arithmetic operators.''' basic binary arithmetic operators."""
def __init__(self, op): def __init__(self, op):
self.op = op self.op = op
@ -158,16 +154,16 @@ class OperatorImpl(calculator_capnp.Calculator.Function.Server):
op = self.op op = self.op
if op == 'add': if op == "add":
return params[0] + params[1] return params[0] + params[1]
elif op == 'subtract': elif op == "subtract":
return params[0] - params[1] return params[0] - params[1]
elif op == 'multiply': elif op == "multiply":
return params[0] * params[1] return params[0] * params[1]
elif op == 'divide': elif op == "divide":
return params[0] / params[1] return params[0] / params[1]
else: else:
raise ValueError('Unknown operator') raise ValueError("Unknown operator")
class CalculatorImpl(calculator_capnp.Calculator.Server): class CalculatorImpl(calculator_capnp.Calculator.Server):
@ -175,7 +171,9 @@ class CalculatorImpl(calculator_capnp.Calculator.Server):
"Implementation of the Calculator Cap'n Proto interface." "Implementation of the Calculator Cap'n Proto interface."
def evaluate(self, expression, _context, **kwargs): def evaluate(self, expression, _context, **kwargs):
return evaluate_impl(expression).then(lambda value: setattr(_context.results, 'value', ValueImpl(value))) return evaluate_impl(expression).then(
lambda value: setattr(_context.results, "value", ValueImpl(value))
)
def defFunction(self, paramCount, body, _context, **kwargs): def defFunction(self, paramCount, body, _context, **kwargs):
return FunctionImpl(paramCount, body) return FunctionImpl(paramCount, body)
@ -185,8 +183,10 @@ class CalculatorImpl(calculator_capnp.Calculator.Server):
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(usage='''Runs the server bound to the\ parser = argparse.ArgumentParser(
given address/port ADDRESS. ''') usage="""Runs the server bound to the\
given address/port ADDRESS. """
)
parser.add_argument("address", help="ADDRESS:PORT") parser.add_argument("address", help="ADDRESS:PORT")
@ -200,7 +200,7 @@ async def new_connection(reader, writer):
async def main(): async def main():
address = parse_args().address address = parse_args().address
host = address.split(':') host = address.split(":")
addr = host[0] addr = host[0]
port = host[1] port = host[1]
@ -208,20 +208,17 @@ async def main():
try: try:
print("Try IPv4") print("Try IPv4")
server = await asyncio.start_server( server = await asyncio.start_server(
new_connection, new_connection, addr, port, family=socket.AF_INET
addr, port,
family=socket.AF_INET
) )
except Exception: except Exception:
print("Try IPv6") print("Try IPv6")
server = await asyncio.start_server( server = await asyncio.start_server(
new_connection, new_connection, addr, port, family=socket.AF_INET6
addr, port,
family=socket.AF_INET6
) )
async with server: async with server:
await server.serve_forever() await server.serve_forever()
if __name__ == '__main__':
if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View file

@ -13,18 +13,20 @@ capnp.create_event_loop(threaded=True)
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(usage='Connects to the Example thread server \ parser = argparse.ArgumentParser(
at the given address and does some RPCs') usage="Connects to the Example thread server \
at the given address and does some RPCs"
)
parser.add_argument("host", help="HOST:PORT") parser.add_argument("host", help="HOST:PORT")
return parser.parse_args() return parser.parse_args()
class StatusSubscriber(thread_capnp.Example.StatusSubscriber.Server): class StatusSubscriber(thread_capnp.Example.StatusSubscriber.Server):
'''An implementation of the StatusSubscriber interface''' """An implementation of the StatusSubscriber interface"""
def status(self, value, **kwargs): def status(self, value, **kwargs):
print('status: {}'.format(time.time())) print("status: {}".format(time.time()))
async def myreader(client, reader): async def myreader(client, reader):
@ -46,21 +48,19 @@ async def background(cap):
async def main(host): async def main(host):
host = host.split(':') host = host.split(":")
addr = host[0] addr = host[0]
port = host[1] port = host[1]
# Handle both IPv4 and IPv6 cases # Handle both IPv4 and IPv6 cases
try: try:
print("Try IPv4") print("Try IPv4")
reader, writer = await asyncio.open_connection( reader, writer = await asyncio.open_connection(
addr, port, addr, port, family=socket.AF_INET
family=socket.AF_INET
) )
except Exception: except Exception:
print("Try IPv6") print("Try IPv6")
reader, writer = await asyncio.open_connection( reader, writer = await asyncio.open_connection(
addr, port, addr, port, family=socket.AF_INET6
family=socket.AF_INET6
) )
# Start TwoPartyClient using TwoWayPipe (takes no arguments in this mode) # Start TwoPartyClient using TwoWayPipe (takes no arguments in this mode)
@ -76,13 +76,14 @@ async def main(host):
asyncio.gather(*tasks, return_exceptions=True) asyncio.gather(*tasks, return_exceptions=True)
# Run blocking tasks # Run blocking tasks
print('main: {}'.format(time.time())) print("main: {}".format(time.time()))
await cap.longRunning().a_wait() await cap.longRunning().a_wait()
print('main: {}'.format(time.time())) print("main: {}".format(time.time()))
await cap.longRunning().a_wait() await cap.longRunning().a_wait()
print('main: {}'.format(time.time())) print("main: {}".format(time.time()))
await cap.longRunning().a_wait() await cap.longRunning().a_wait()
print('main: {}'.format(time.time())) print("main: {}".format(time.time()))
if __name__ == '__main__':
if __name__ == "__main__":
asyncio.run(main(parse_args().host)) asyncio.run(main(parse_args().host))

View file

@ -17,18 +17,20 @@ capnp.create_event_loop(threaded=True)
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(usage='Connects to the Example thread server \ parser = argparse.ArgumentParser(
at the given address and does some RPCs') usage="Connects to the Example thread server \
at the given address and does some RPCs"
)
parser.add_argument("host", help="HOST:PORT") parser.add_argument("host", help="HOST:PORT")
return parser.parse_args() return parser.parse_args()
class StatusSubscriber(thread_capnp.Example.StatusSubscriber.Server): class StatusSubscriber(thread_capnp.Example.StatusSubscriber.Server):
'''An implementation of the StatusSubscriber interface''' """An implementation of the StatusSubscriber interface"""
def status(self, value, **kwargs): def status(self, value, **kwargs):
print('status: {}'.format(time.time())) print("status: {}".format(time.time()))
async def myreader(client, reader): async def myreader(client, reader):
@ -71,28 +73,26 @@ async def background(cap):
async def main(host): async def main(host):
host = host.split(':') host = host.split(":")
addr = host[0] addr = host[0]
port = host[1] port = host[1]
# Setup SSL context # Setup SSL context
ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=os.path.join(this_dir, 'selfsigned.cert')) ctx = ssl.create_default_context(
ssl.Purpose.SERVER_AUTH, cafile=os.path.join(this_dir, "selfsigned.cert")
)
# Handle both IPv4 and IPv6 cases # Handle both IPv4 and IPv6 cases
try: try:
print("Try IPv4") print("Try IPv4")
reader, writer = await asyncio.open_connection( reader, writer = await asyncio.open_connection(
addr, port, addr, port, ssl=ctx, family=socket.AF_INET
ssl=ctx,
family=socket.AF_INET
) )
except OSError: except OSError:
print("Try IPv6") print("Try IPv6")
try: try:
reader, writer = await asyncio.open_connection( reader, writer = await asyncio.open_connection(
addr, port, addr, port, ssl=ctx, family=socket.AF_INET6
ssl=ctx,
family=socket.AF_INET6
) )
except OSError: except OSError:
return False return False
@ -115,20 +115,21 @@ async def main(host):
overalltasks.append(asyncio.gather(*tasks, return_exceptions=True)) overalltasks.append(asyncio.gather(*tasks, return_exceptions=True))
# Run blocking tasks # Run blocking tasks
print('main: {}'.format(time.time())) print("main: {}".format(time.time()))
await cap.longRunning().a_wait() await cap.longRunning().a_wait()
print('main: {}'.format(time.time())) print("main: {}".format(time.time()))
await cap.longRunning().a_wait() await cap.longRunning().a_wait()
print('main: {}'.format(time.time())) print("main: {}".format(time.time()))
await cap.longRunning().a_wait() await cap.longRunning().a_wait()
print('main: {}'.format(time.time())) print("main: {}".format(time.time()))
for task in overalltasks: for task in overalltasks:
task.cancel() task.cancel()
return True return True
if __name__ == '__main__':
if __name__ == "__main__":
# Using asyncio.run hits an asyncio ssl bug # Using asyncio.run hits an asyncio ssl bug
# https://bugs.python.org/issue36709 # https://bugs.python.org/issue36709
# asyncio.run(main(parse_args().host), loop=loop, debug=True) # asyncio.run(main(parse_args().host), loop=loop, debug=True)

View file

@ -18,12 +18,15 @@ class ExampleImpl(thread_capnp.Example.Server):
"Implementation of the Example threading Cap'n Proto interface." "Implementation of the Example threading Cap'n Proto interface."
def subscribeStatus(self, subscriber, **kwargs): def subscribeStatus(self, subscriber, **kwargs):
return capnp.getTimer().after_delay(10**9) \ return (
.then(lambda: subscriber.status(True)) \ capnp.getTimer()
.after_delay(10 ** 9)
.then(lambda: subscriber.status(True))
.then(lambda _: self.subscribeStatus(subscriber)) .then(lambda _: self.subscribeStatus(subscriber))
)
def longRunning(self, **kwargs): def longRunning(self, **kwargs):
return capnp.getTimer().after_delay(1 * 10**9) return capnp.getTimer().after_delay(1 * 10 ** 9)
class Server: class Server:
@ -31,10 +34,7 @@ class Server:
while self.retry: while self.retry:
try: try:
# Must be a wait_for so we don't block on read() # Must be a wait_for so we don't block on read()
data = await asyncio.wait_for( data = await asyncio.wait_for(self.reader.read(4096), timeout=0.1)
self.reader.read(4096),
timeout=0.1
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.debug("myreader timeout.") logger.debug("myreader timeout.")
continue continue
@ -49,10 +49,7 @@ class Server:
while self.retry: while self.retry:
try: try:
# Must be a wait_for so we don't block on read() # Must be a wait_for so we don't block on read()
data = await asyncio.wait_for( data = await asyncio.wait_for(self.server.read(4096), timeout=0.1)
self.server.read(4096),
timeout=0.1
)
self.writer.write(data.tobytes()) self.writer.write(data.tobytes())
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.debug("mywriter timeout.") logger.debug("mywriter timeout.")
@ -87,8 +84,10 @@ class Server:
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(usage='''Runs the server bound to the\ parser = argparse.ArgumentParser(
given address/port ADDRESS. ''') usage="""Runs the server bound to the\
given address/port ADDRESS. """
)
parser.add_argument("address", help="ADDRESS:PORT") parser.add_argument("address", help="ADDRESS:PORT")
@ -102,7 +101,7 @@ async def new_connection(reader, writer):
async def main(): async def main():
address = parse_args().address address = parse_args().address
host = address.split(':') host = address.split(":")
addr = host[0] addr = host[0]
port = host[1] port = host[1]
@ -110,21 +109,17 @@ async def main():
try: try:
print("Try IPv4") print("Try IPv4")
server = await asyncio.start_server( server = await asyncio.start_server(
new_connection, new_connection, addr, port, family=socket.AF_INET
addr, port,
family=socket.AF_INET
) )
except Exception: except Exception:
print("Try IPv6") print("Try IPv6")
server = await asyncio.start_server( server = await asyncio.start_server(
new_connection, new_connection, addr, port, family=socket.AF_INET6
addr, port,
family=socket.AF_INET6
) )
async with server: async with server:
await server.serve_forever() await server.serve_forever()
if __name__ == '__main__': if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View file

@ -15,16 +15,16 @@ this_dir = os.path.dirname(os.path.abspath(__file__))
class PowerFunction(calculator_capnp.Calculator.Function.Server): class PowerFunction(calculator_capnp.Calculator.Function.Server):
'''An implementation of the Function interface wrapping pow(). Note that """An implementation of the Function interface wrapping pow(). Note that
we're implementing this on the client side and will pass a reference to we're implementing this on the client side and will pass a reference to
the server. The server will then be able to make calls back to the client.''' the server. The server will then be able to make calls back to the client."""
def call(self, params, **kwargs): def call(self, params, **kwargs):
'''Note the **kwargs. This is very necessary to include, since """Note the **kwargs. This is very necessary to include, since
protocols can add parameters over time. Also, by default, a _context protocols can add parameters over time. Also, by default, a _context
variable is passed to all server methods, but you can also return variable is passed to all server methods, but you can also return
results directly as python objects, and they'll be added to the results directly as python objects, and they'll be added to the
results struct in the correct order''' results struct in the correct order"""
return pow(params[0], params[1]) return pow(params[0], params[1])
@ -43,35 +43,35 @@ async def mywriter(client, writer):
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(usage='Connects to the Calculator server \ parser = argparse.ArgumentParser(
at the given address and does some RPCs') usage="Connects to the Calculator server \
at the given address and does some RPCs"
)
parser.add_argument("host", help="HOST:PORT") parser.add_argument("host", help="HOST:PORT")
return parser.parse_args() return parser.parse_args()
async def main(host): async def main(host):
host = host.split(':') host = host.split(":")
addr = host[0] addr = host[0]
port = host[1] port = host[1]
# Setup SSL context # Setup SSL context
ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=os.path.join(this_dir, 'selfsigned.cert')) ctx = ssl.create_default_context(
ssl.Purpose.SERVER_AUTH, cafile=os.path.join(this_dir, "selfsigned.cert")
)
# Handle both IPv4 and IPv6 cases # Handle both IPv4 and IPv6 cases
try: try:
print("Try IPv4") print("Try IPv4")
reader, writer = await asyncio.open_connection( reader, writer = await asyncio.open_connection(
addr, port, addr, port, ssl=ctx, family=socket.AF_INET
ssl=ctx,
family=socket.AF_INET
) )
except Exception: except Exception:
print("Try IPv6") print("Try IPv6")
reader, writer = await asyncio.open_connection( reader, writer = await asyncio.open_connection(
addr, port, addr, port, ssl=ctx, family=socket.AF_INET6
ssl=ctx,
family=socket.AF_INET6
) )
# Start TwoPartyClient using TwoWayPipe (takes no arguments in this mode) # Start TwoPartyClient using TwoWayPipe (takes no arguments in this mode)
@ -84,7 +84,7 @@ async def main(host):
# Bootstrap the Calculator interface # Bootstrap the Calculator interface
calculator = client.bootstrap().cast_as(calculator_capnp.Calculator) calculator = client.bootstrap().cast_as(calculator_capnp.Calculator)
'''Make a request that just evaluates the literal value 123. """Make a request that just evaluates the literal value 123.
What's interesting here is that evaluate() returns a "Value", which is What's interesting here is that evaluate() returns a "Value", which is
another interface and therefore points back to an object living on the another interface and therefore points back to an object living on the
@ -92,9 +92,9 @@ async def main(host):
However, even though we are making two RPC's, this block executes in However, even though we are making two RPC's, this block executes in
*one* network round trip because of promise pipelining: we do not wait *one* network round trip because of promise pipelining: we do not wait
for the first call to complete before we send the second call to the for the first call to complete before we send the second call to the
server.''' server."""
print('Evaluating a literal... ', end="") print("Evaluating a literal... ", end="")
# Make the request. Note we are using the shorter function form (instead # Make the request. Note we are using the shorter function form (instead
# of evaluate_request), and we are passing a dictionary that represents a # of evaluate_request), and we are passing a dictionary that represents a
@ -102,13 +102,13 @@ async def main(host):
eval_promise = calculator.evaluate({"literal": 123}) eval_promise = calculator.evaluate({"literal": 123})
# This is equivalent to: # This is equivalent to:
''' """
request = calculator.evaluate_request() request = calculator.evaluate_request()
request.expression.literal = 123 request.expression.literal = 123
# Send it, which returns a promise for the result (without blocking). # Send it, which returns a promise for the result (without blocking).
eval_promise = request.send() eval_promise = request.send()
''' """
# Using the promise, create a pipelined request to call read() on the # Using the promise, create a pipelined request to call read() on the
# returned object. Note that here we are using the shortened method call # returned object. Note that here we are using the shortened method call
@ -122,32 +122,32 @@ async def main(host):
print("PASS") print("PASS")
'''Make a request to evaluate 123 + 45 - 67. """Make a request to evaluate 123 + 45 - 67.
The Calculator interface requires that we first call getOperator() to The Calculator interface requires that we first call getOperator() to
get the addition and subtraction functions, then call evaluate() to use get the addition and subtraction functions, then call evaluate() to use
them. But, once again, we can get both functions, call evaluate(), and them. But, once again, we can get both functions, call evaluate(), and
then read() the result -- four RPCs -- in the time of *one* network then read() the result -- four RPCs -- in the time of *one* network
round trip, because of promise pipelining.''' round trip, because of promise pipelining."""
print("Using add and subtract... ", end='') print("Using add and subtract... ", end="")
# Get the "add" function from the server. # Get the "add" function from the server.
add = calculator.getOperator(op='add').func add = calculator.getOperator(op="add").func
# Get the "subtract" function from the server. # Get the "subtract" function from the server.
subtract = calculator.getOperator(op='subtract').func subtract = calculator.getOperator(op="subtract").func
# Build the request to evaluate 123 + 45 - 67. Note the form is 'evaluate' # Build the request to evaluate 123 + 45 - 67. Note the form is 'evaluate'
# + '_request', where 'evaluate' is the name of the method we want to call # + '_request', where 'evaluate' is the name of the method we want to call
request = calculator.evaluate_request() request = calculator.evaluate_request()
subtract_call = request.expression.init('call') subtract_call = request.expression.init("call")
subtract_call.function = subtract subtract_call.function = subtract
subtract_params = subtract_call.init('params', 2) subtract_params = subtract_call.init("params", 2)
subtract_params[1].literal = 67.0 subtract_params[1].literal = 67.0
add_call = subtract_params[0].init('call') add_call = subtract_params[0].init("call")
add_call.function = add add_call.function = add
add_params = add_call.init('params', 2) add_params = add_call.init("params", 2)
add_params[0].literal = 123 add_params[0].literal = 123
add_params[1].literal = 45 add_params[1].literal = 45
@ -160,7 +160,7 @@ async def main(host):
print("PASS") print("PASS")
''' """
Note: a one liner version of building the previous request (I highly Note: a one liner version of building the previous request (I highly
recommend not doing it this way for such a complicated structure, but I recommend not doing it this way for such a complicated structure, but I
just wanted to demonstrate it is possible to set all of the fields with a just wanted to demonstrate it is possible to set all of the fields with a
@ -172,22 +172,22 @@ async def main(host):
'params': [{'literal': 123}, 'params': [{'literal': 123},
{'literal': 45}]}}, {'literal': 45}]}},
{'literal': 67.0}]}}) {'literal': 67.0}]}})
''' """
'''Make a request to evaluate 4 * 6, then use the result in two more """Make a request to evaluate 4 * 6, then use the result in two more
requests that add 3 and 5. requests that add 3 and 5.
Since evaluate() returns its result wrapped in a `Value`, we can pass Since evaluate() returns its result wrapped in a `Value`, we can pass
that `Value` back to the server in subsequent requests before the first that `Value` back to the server in subsequent requests before the first
`evaluate()` has actually returned. Thus, this example again does only `evaluate()` has actually returned. Thus, this example again does only
one network round trip.''' one network round trip."""
print("Pipelining eval() calls... ", end="") print("Pipelining eval() calls... ", end="")
# Get the "add" function from the server. # Get the "add" function from the server.
add = calculator.getOperator(op='add').func add = calculator.getOperator(op="add").func
# Get the "multiply" function from the server. # Get the "multiply" function from the server.
multiply = calculator.getOperator(op='multiply').func multiply = calculator.getOperator(op="multiply").func
# Build the request to evaluate 4 * 6 # Build the request to evaluate 4 * 6
request = calculator.evaluate_request() request = calculator.evaluate_request()
@ -224,7 +224,7 @@ async def main(host):
print("PASS") print("PASS")
'''Our calculator interface supports defining functions. Here we use it """Our calculator interface supports defining functions. Here we use it
to define two functions and then make calls to them as follows: to define two functions and then make calls to them as follows:
f(x, y) = x * 100 + y f(x, y) = x * 100 + y
@ -232,14 +232,14 @@ async def main(host):
f(12, 34) f(12, 34)
g(21) g(21)
Once again, the whole thing takes only one network round trip.''' Once again, the whole thing takes only one network round trip."""
print("Defining functions... ", end="") print("Defining functions... ", end="")
# Get the "add" function from the server. # Get the "add" function from the server.
add = calculator.getOperator(op='add').func add = calculator.getOperator(op="add").func
# Get the "multiply" function from the server. # Get the "multiply" function from the server.
multiply = calculator.getOperator(op='multiply').func multiply = calculator.getOperator(op="multiply").func
# Define f. # Define f.
request = calculator.defFunction_request() request = calculator.defFunction_request()
@ -297,7 +297,7 @@ async def main(host):
g_eval_request = calculator.evaluate_request() g_eval_request = calculator.evaluate_request()
g_call = g_eval_request.expression.init("call") g_call = g_eval_request.expression.init("call")
g_call.function = g g_call.function = g
g_call.init('params', 1)[0].literal = 21 g_call.init("params", 1)[0].literal = 21
g_eval_promise = g_eval_request.send().value.read() g_eval_promise = g_eval_request.send().value.read()
# Wait for the results. # Wait for the results.
@ -306,7 +306,7 @@ async def main(host):
print("PASS") print("PASS")
'''Make a request that will call back to a function defined locally. """Make a request that will call back to a function defined locally.
Specifically, we will compute 2^(4 + 5). However, exponent is not Specifically, we will compute 2^(4 + 5). However, exponent is not
defined by the Calculator server. So, we'll implement the Function defined by the Calculator server. So, we'll implement the Function
@ -318,12 +318,12 @@ async def main(host):
particular case, this could potentially be optimized by using a tail particular case, this could potentially be optimized by using a tail
call on the server side -- see CallContext::tailCall(). However, to call on the server side -- see CallContext::tailCall(). However, to
keep the example simpler, we haven't implemented this optimization in keep the example simpler, we haven't implemented this optimization in
the sample server.''' the sample server."""
print("Using a callback... ", end="") print("Using a callback... ", end="")
# Get the "add" function from the server. # Get the "add" function from the server.
add = calculator.getOperator(op='add').func add = calculator.getOperator(op="add").func
# Build the eval request for 2^(4+5). # Build the eval request for 2^(4+5).
request = calculator.evaluate_request() request = calculator.evaluate_request()
@ -345,7 +345,8 @@ async def main(host):
print("PASS") print("PASS")
if __name__ == '__main__':
if __name__ == "__main__":
# Using asyncio.run hits an asyncio ssl bug # Using asyncio.run hits an asyncio ssl bug
# https://bugs.python.org/issue36709 # https://bugs.python.org/issue36709
# asyncio.run(main(parse_args().host), loop=loop, debug=True) # asyncio.run(main(parse_args().host), loop=loop, debug=True)

View file

@ -22,10 +22,7 @@ class Server:
while self.retry: while self.retry:
try: try:
# Must be a wait_for so we don't block on read() # Must be a wait_for so we don't block on read()
data = await asyncio.wait_for( data = await asyncio.wait_for(self.reader.read(4096), timeout=0.1)
self.reader.read(4096),
timeout=0.1
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.debug("myreader timeout.") logger.debug("myreader timeout.")
continue continue
@ -40,10 +37,7 @@ class Server:
while self.retry: while self.retry:
try: try:
# Must be a wait_for so we don't block on read() # Must be a wait_for so we don't block on read()
data = await asyncio.wait_for( data = await asyncio.wait_for(self.server.read(4096), timeout=0.1)
self.server.read(4096),
timeout=0.1
)
self.writer.write(data.tobytes()) self.writer.write(data.tobytes())
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.debug("mywriter timeout.") logger.debug("mywriter timeout.")
@ -78,29 +72,29 @@ class Server:
def read_value(value): def read_value(value):
'''Helper function to asynchronously call read() on a Calculator::Value and """Helper function to asynchronously call read() on a Calculator::Value and
return a promise for the result. (In the future, the generated code might return a promise for the result. (In the future, the generated code might
include something like this automatically.)''' include something like this automatically.)"""
return value.read().then(lambda result: result.value) return value.read().then(lambda result: result.value)
def evaluate_impl(expression, params=None): def evaluate_impl(expression, params=None):
'''Implementation of CalculatorImpl::evaluate(), also shared by """Implementation of CalculatorImpl::evaluate(), also shared by
FunctionImpl::call(). In the latter case, `params` are the parameter FunctionImpl::call(). In the latter case, `params` are the parameter
values passed to the function; in the former case, `params` is just an values passed to the function; in the former case, `params` is just an
empty list.''' empty list."""
which = expression.which() which = expression.which()
if which == 'literal': if which == "literal":
return capnp.Promise(expression.literal) return capnp.Promise(expression.literal)
elif which == 'previousResult': elif which == "previousResult":
return read_value(expression.previousResult) return read_value(expression.previousResult)
elif which == 'parameter': elif which == "parameter":
assert expression.parameter < len(params) assert expression.parameter < len(params)
return capnp.Promise(params[expression.parameter]) return capnp.Promise(params[expression.parameter])
elif which == 'call': elif which == "call":
call = expression.call call = expression.call
func = call.function func = call.function
@ -109,9 +103,9 @@ def evaluate_impl(expression, params=None):
joinedParams = capnp.join_promises(paramPromises) joinedParams = capnp.join_promises(paramPromises)
# When the parameters are complete, call the function. # When the parameters are complete, call the function.
ret = (joinedParams ret = joinedParams.then(lambda vals: func.call(vals)).then(
.then(lambda vals: func.call(vals)) lambda result: result.value
.then(lambda result: result.value)) )
return ret return ret
else: else:
@ -131,28 +125,30 @@ class ValueImpl(calculator_capnp.Calculator.Value.Server):
class FunctionImpl(calculator_capnp.Calculator.Function.Server): class FunctionImpl(calculator_capnp.Calculator.Function.Server):
'''Implementation of the Calculator.Function Cap'n Proto interface, where the """Implementation of the Calculator.Function Cap'n Proto interface, where the
function is defined by a Calculator.Expression.''' function is defined by a Calculator.Expression."""
def __init__(self, paramCount, body): def __init__(self, paramCount, body):
self.paramCount = paramCount self.paramCount = paramCount
self.body = body.as_builder() self.body = body.as_builder()
def call(self, params, _context, **kwargs): def call(self, params, _context, **kwargs):
'''Note that we're returning a Promise object here, and bypassing the """Note that we're returning a Promise object here, and bypassing the
helper functionality that normally sets the results struct from the helper functionality that normally sets the results struct from the
returned object. Instead, we set _context.results directly inside of returned object. Instead, we set _context.results directly inside of
another promise''' another promise"""
assert len(params) == self.paramCount assert len(params) == self.paramCount
# using setattr because '=' is not allowed inside of lambdas # using setattr because '=' is not allowed inside of lambdas
return evaluate_impl(self.body, params).then(lambda value: setattr(_context.results, 'value', value)) return evaluate_impl(self.body, params).then(
lambda value: setattr(_context.results, "value", value)
)
class OperatorImpl(calculator_capnp.Calculator.Function.Server): class OperatorImpl(calculator_capnp.Calculator.Function.Server):
'''Implementation of the Calculator.Function Cap'n Proto interface, wrapping """Implementation of the Calculator.Function Cap'n Proto interface, wrapping
basic binary arithmetic operators.''' basic binary arithmetic operators."""
def __init__(self, op): def __init__(self, op):
self.op = op self.op = op
@ -162,16 +158,16 @@ class OperatorImpl(calculator_capnp.Calculator.Function.Server):
op = self.op op = self.op
if op == 'add': if op == "add":
return params[0] + params[1] return params[0] + params[1]
elif op == 'subtract': elif op == "subtract":
return params[0] - params[1] return params[0] - params[1]
elif op == 'multiply': elif op == "multiply":
return params[0] * params[1] return params[0] * params[1]
elif op == 'divide': elif op == "divide":
return params[0] / params[1] return params[0] / params[1]
else: else:
raise ValueError('Unknown operator') raise ValueError("Unknown operator")
class CalculatorImpl(calculator_capnp.Calculator.Server): class CalculatorImpl(calculator_capnp.Calculator.Server):
@ -179,7 +175,9 @@ class CalculatorImpl(calculator_capnp.Calculator.Server):
"Implementation of the Calculator Cap'n Proto interface." "Implementation of the Calculator Cap'n Proto interface."
def evaluate(self, expression, _context, **kwargs): def evaluate(self, expression, _context, **kwargs):
return evaluate_impl(expression).then(lambda value: setattr(_context.results, 'value', ValueImpl(value))) return evaluate_impl(expression).then(
lambda value: setattr(_context.results, "value", ValueImpl(value))
)
def defFunction(self, paramCount, body, _context, **kwargs): def defFunction(self, paramCount, body, _context, **kwargs):
return FunctionImpl(paramCount, body) return FunctionImpl(paramCount, body)
@ -189,8 +187,10 @@ class CalculatorImpl(calculator_capnp.Calculator.Server):
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(usage='''Runs the server bound to the\ parser = argparse.ArgumentParser(
given address/port ADDRESS. ''') usage="""Runs the server bound to the\
given address/port ADDRESS. """
)
parser.add_argument("address", help="ADDRESS:PORT") parser.add_argument("address", help="ADDRESS:PORT")
@ -204,34 +204,32 @@ async def new_connection(reader, writer):
async def main(): async def main():
address = parse_args().address address = parse_args().address
host = address.split(':') host = address.split(":")
addr = host[0] addr = host[0]
port = host[1] port = host[1]
# Setup SSL context # Setup SSL context
ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ctx.load_cert_chain(os.path.join(this_dir, 'selfsigned.cert'), os.path.join(this_dir, 'selfsigned.key')) ctx.load_cert_chain(
os.path.join(this_dir, "selfsigned.cert"),
os.path.join(this_dir, "selfsigned.key"),
)
# Handle both IPv4 and IPv6 cases # Handle both IPv4 and IPv6 cases
try: try:
print("Try IPv4") print("Try IPv4")
server = await asyncio.start_server( server = await asyncio.start_server(
new_connection, new_connection, addr, port, ssl=ctx, family=socket.AF_INET
addr, port,
ssl=ctx,
family=socket.AF_INET
) )
except Exception: except Exception:
print("Try IPv6") print("Try IPv6")
server = await asyncio.start_server( server = await asyncio.start_server(
new_connection, new_connection, addr, port, ssl=ctx, family=socket.AF_INET6
addr, port,
ssl=ctx,
family=socket.AF_INET6
) )
async with server: async with server:
await server.serve_forever() await server.serve_forever()
if __name__ == '__main__':
if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View file

@ -14,18 +14,20 @@ this_dir = os.path.dirname(os.path.abspath(__file__))
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(usage='Connects to the Example thread server \ parser = argparse.ArgumentParser(
at the given address and does some RPCs') usage="Connects to the Example thread server \
at the given address and does some RPCs"
)
parser.add_argument("host", help="HOST:PORT") parser.add_argument("host", help="HOST:PORT")
return parser.parse_args() return parser.parse_args()
class StatusSubscriber(thread_capnp.Example.StatusSubscriber.Server): class StatusSubscriber(thread_capnp.Example.StatusSubscriber.Server):
'''An implementation of the StatusSubscriber interface''' """An implementation of the StatusSubscriber interface"""
def status(self, value, **kwargs): def status(self, value, **kwargs):
print('status: {}'.format(time.time())) print("status: {}".format(time.time()))
async def myreader(client, reader): async def myreader(client, reader):
@ -48,27 +50,25 @@ async def background(cap):
async def main(host): async def main(host):
host = host.split(':') host = host.split(":")
addr = host[0] addr = host[0]
port = host[1] port = host[1]
# Setup SSL context # Setup SSL context
ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=os.path.join(this_dir, 'selfsigned.cert')) ctx = ssl.create_default_context(
ssl.Purpose.SERVER_AUTH, cafile=os.path.join(this_dir, "selfsigned.cert")
)
# Handle both IPv4 and IPv6 cases # Handle both IPv4 and IPv6 cases
try: try:
print("Try IPv4") print("Try IPv4")
reader, writer = await asyncio.open_connection( reader, writer = await asyncio.open_connection(
addr, port, addr, port, ssl=ctx, family=socket.AF_INET
ssl=ctx,
family=socket.AF_INET
) )
except Exception: except Exception:
print("Try IPv6") print("Try IPv6")
reader, writer = await asyncio.open_connection( reader, writer = await asyncio.open_connection(
addr, port, addr, port, ssl=ctx, family=socket.AF_INET6
ssl=ctx,
family=socket.AF_INET6
) )
# Start TwoPartyClient using TwoWayPipe (takes no arguments in this mode) # Start TwoPartyClient using TwoWayPipe (takes no arguments in this mode)
@ -84,15 +84,16 @@ async def main(host):
asyncio.gather(*tasks, return_exceptions=True) asyncio.gather(*tasks, return_exceptions=True)
# Run blocking tasks # Run blocking tasks
print('main: {}'.format(time.time())) print("main: {}".format(time.time()))
await cap.longRunning().a_wait() await cap.longRunning().a_wait()
print('main: {}'.format(time.time())) print("main: {}".format(time.time()))
await cap.longRunning().a_wait() await cap.longRunning().a_wait()
print('main: {}'.format(time.time())) print("main: {}".format(time.time()))
await cap.longRunning().a_wait() await cap.longRunning().a_wait()
print('main: {}'.format(time.time())) print("main: {}".format(time.time()))
if __name__ == '__main__':
if __name__ == "__main__":
# Using asyncio.run hits an asyncio ssl bug # Using asyncio.run hits an asyncio ssl bug
# https://bugs.python.org/issue36709 # https://bugs.python.org/issue36709
# asyncio.run(main(parse_args().host), loop=loop, debug=True) # asyncio.run(main(parse_args().host), loop=loop, debug=True)

View file

@ -22,12 +22,15 @@ class ExampleImpl(thread_capnp.Example.Server):
"Implementation of the Example threading Cap'n Proto interface." "Implementation of the Example threading Cap'n Proto interface."
def subscribeStatus(self, subscriber, **kwargs): def subscribeStatus(self, subscriber, **kwargs):
return capnp.getTimer().after_delay(10**9) \ return (
.then(lambda: subscriber.status(True)) \ capnp.getTimer()
.after_delay(10 ** 9)
.then(lambda: subscriber.status(True))
.then(lambda _: self.subscribeStatus(subscriber)) .then(lambda _: self.subscribeStatus(subscriber))
)
def longRunning(self, **kwargs): def longRunning(self, **kwargs):
return capnp.getTimer().after_delay(1 * 10**9) return capnp.getTimer().after_delay(1 * 10 ** 9)
def alive(self, **kwargs): def alive(self, **kwargs):
return True return True
@ -38,10 +41,7 @@ class Server:
while self.retry: while self.retry:
try: try:
# Must be a wait_for so we don't block on read() # Must be a wait_for so we don't block on read()
data = await asyncio.wait_for( data = await asyncio.wait_for(self.reader.read(4096), timeout=0.1)
self.reader.read(4096),
timeout=0.1
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.debug("myreader timeout.") logger.debug("myreader timeout.")
continue continue
@ -56,10 +56,7 @@ class Server:
while self.retry: while self.retry:
try: try:
# Must be a wait_for so we don't block on read() # Must be a wait_for so we don't block on read()
data = await asyncio.wait_for( data = await asyncio.wait_for(self.server.read(4096), timeout=0.1)
self.server.read(4096),
timeout=0.1
)
self.writer.write(data.tobytes()) self.writer.write(data.tobytes())
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.debug("mywriter timeout.") logger.debug("mywriter timeout.")
@ -99,8 +96,10 @@ async def new_connection(reader, writer):
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(usage='''Runs the server bound to the\ parser = argparse.ArgumentParser(
given address/port ADDRESS. ''') usage="""Runs the server bound to the\
given address/port ADDRESS. """
)
parser.add_argument("address", help="ADDRESS:PORT") parser.add_argument("address", help="ADDRESS:PORT")
@ -109,20 +108,24 @@ given address/port ADDRESS. ''')
async def main(): async def main():
address = parse_args().address address = parse_args().address
host = address.split(':') host = address.split(":")
addr = host[0] addr = host[0]
port = host[1] port = host[1]
# Setup SSL context # Setup SSL context
ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ctx.load_cert_chain(os.path.join(this_dir, 'selfsigned.cert'), os.path.join(this_dir, 'selfsigned.key')) ctx.load_cert_chain(
os.path.join(this_dir, "selfsigned.cert"),
os.path.join(this_dir, "selfsigned.key"),
)
# Handle both IPv4 and IPv6 cases # Handle both IPv4 and IPv6 cases
try: try:
print("Try IPv4") print("Try IPv4")
server = await asyncio.start_server( server = await asyncio.start_server(
new_connection, new_connection,
addr, port, addr,
port,
ssl=ctx, ssl=ctx,
family=socket.AF_INET, family=socket.AF_INET,
) )
@ -130,7 +133,8 @@ async def main():
print("Try IPv6") print("Try IPv6")
server = await asyncio.start_server( server = await asyncio.start_server(
new_connection, new_connection,
addr, port, addr,
port,
ssl=ctx, ssl=ctx,
family=socket.AF_INET6, family=socket.AF_INET6,
) )
@ -138,5 +142,6 @@ async def main():
async with server: async with server:
await server.serve_forever() await server.serve_forever()
if __name__ == '__main__':
if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())

View file

@ -8,23 +8,25 @@ import calculator_capnp
class PowerFunction(calculator_capnp.Calculator.Function.Server): class PowerFunction(calculator_capnp.Calculator.Function.Server):
'''An implementation of the Function interface wrapping pow(). Note that """An implementation of the Function interface wrapping pow(). Note that
we're implementing this on the client side and will pass a reference to we're implementing this on the client side and will pass a reference to
the server. The server will then be able to make calls back to the client.''' the server. The server will then be able to make calls back to the client."""
def call(self, params, **kwargs): def call(self, params, **kwargs):
'''Note the **kwargs. This is very necessary to include, since """Note the **kwargs. This is very necessary to include, since
protocols can add parameters over time. Also, by default, a _context protocols can add parameters over time. Also, by default, a _context
variable is passed to all server methods, but you can also return variable is passed to all server methods, but you can also return
results directly as python objects, and they'll be added to the results directly as python objects, and they'll be added to the
results struct in the correct order''' results struct in the correct order"""
return pow(params[0], params[1]) return pow(params[0], params[1])
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(usage='Connects to the Calculator server \ parser = argparse.ArgumentParser(
at the given address and does some RPCs') usage="Connects to the Calculator server \
at the given address and does some RPCs"
)
parser.add_argument("host", help="HOST:PORT") parser.add_argument("host", help="HOST:PORT")
return parser.parse_args() return parser.parse_args()
@ -36,7 +38,7 @@ def main(host):
# Bootstrap the server capability and cast it to the Calculator interface # Bootstrap the server capability and cast it to the Calculator interface
calculator = client.bootstrap().cast_as(calculator_capnp.Calculator) calculator = client.bootstrap().cast_as(calculator_capnp.Calculator)
'''Make a request that just evaluates the literal value 123. """Make a request that just evaluates the literal value 123.
What's interesting here is that evaluate() returns a "Value", which is What's interesting here is that evaluate() returns a "Value", which is
another interface and therefore points back to an object living on the another interface and therefore points back to an object living on the
@ -44,9 +46,9 @@ def main(host):
However, even though we are making two RPC's, this block executes in However, even though we are making two RPC's, this block executes in
*one* network round trip because of promise pipelining: we do not wait *one* network round trip because of promise pipelining: we do not wait
for the first call to complete before we send the second call to the for the first call to complete before we send the second call to the
server.''' server."""
print('Evaluating a literal... ', end="") print("Evaluating a literal... ", end="")
# Make the request. Note we are using the shorter function form (instead # Make the request. Note we are using the shorter function form (instead
# of evaluate_request), and we are passing a dictionary that represents a # of evaluate_request), and we are passing a dictionary that represents a
@ -54,13 +56,13 @@ def main(host):
eval_promise = calculator.evaluate({"literal": 123}) eval_promise = calculator.evaluate({"literal": 123})
# This is equivalent to: # This is equivalent to:
''' """
request = calculator.evaluate_request() request = calculator.evaluate_request()
request.expression.literal = 123 request.expression.literal = 123
# Send it, which returns a promise for the result (without blocking). # Send it, which returns a promise for the result (without blocking).
eval_promise = request.send() eval_promise = request.send()
''' """
# Using the promise, create a pipelined request to call read() on the # Using the promise, create a pipelined request to call read() on the
# returned object. Note that here we are using the shortened method call # returned object. Note that here we are using the shortened method call
@ -74,32 +76,32 @@ def main(host):
print("PASS") print("PASS")
'''Make a request to evaluate 123 + 45 - 67. """Make a request to evaluate 123 + 45 - 67.
The Calculator interface requires that we first call getOperator() to The Calculator interface requires that we first call getOperator() to
get the addition and subtraction functions, then call evaluate() to use get the addition and subtraction functions, then call evaluate() to use
them. But, once again, we can get both functions, call evaluate(), and them. But, once again, we can get both functions, call evaluate(), and
then read() the result -- four RPCs -- in the time of *one* network then read() the result -- four RPCs -- in the time of *one* network
round trip, because of promise pipelining.''' round trip, because of promise pipelining."""
print("Using add and subtract... ", end='') print("Using add and subtract... ", end="")
# Get the "add" function from the server. # Get the "add" function from the server.
add = calculator.getOperator(op='add').func add = calculator.getOperator(op="add").func
# Get the "subtract" function from the server. # Get the "subtract" function from the server.
subtract = calculator.getOperator(op='subtract').func subtract = calculator.getOperator(op="subtract").func
# Build the request to evaluate 123 + 45 - 67. Note the form is 'evaluate' # Build the request to evaluate 123 + 45 - 67. Note the form is 'evaluate'
# + '_request', where 'evaluate' is the name of the method we want to call # + '_request', where 'evaluate' is the name of the method we want to call
request = calculator.evaluate_request() request = calculator.evaluate_request()
subtract_call = request.expression.init('call') subtract_call = request.expression.init("call")
subtract_call.function = subtract subtract_call.function = subtract
subtract_params = subtract_call.init('params', 2) subtract_params = subtract_call.init("params", 2)
subtract_params[1].literal = 67.0 subtract_params[1].literal = 67.0
add_call = subtract_params[0].init('call') add_call = subtract_params[0].init("call")
add_call.function = add add_call.function = add
add_params = add_call.init('params', 2) add_params = add_call.init("params", 2)
add_params[0].literal = 123 add_params[0].literal = 123
add_params[1].literal = 45 add_params[1].literal = 45
@ -112,7 +114,7 @@ def main(host):
print("PASS") print("PASS")
''' """
Note: a one liner version of building the previous request (I highly Note: a one liner version of building the previous request (I highly
recommend not doing it this way for such a complicated structure, but I recommend not doing it this way for such a complicated structure, but I
just wanted to demonstrate it is possible to set all of the fields with a just wanted to demonstrate it is possible to set all of the fields with a
@ -124,22 +126,22 @@ def main(host):
'params': [{'literal': 123}, 'params': [{'literal': 123},
{'literal': 45}]}}, {'literal': 45}]}},
{'literal': 67.0}]}}) {'literal': 67.0}]}})
''' """
'''Make a request to evaluate 4 * 6, then use the result in two more """Make a request to evaluate 4 * 6, then use the result in two more
requests that add 3 and 5. requests that add 3 and 5.
Since evaluate() returns its result wrapped in a `Value`, we can pass Since evaluate() returns its result wrapped in a `Value`, we can pass
that `Value` back to the server in subsequent requests before the first that `Value` back to the server in subsequent requests before the first
`evaluate()` has actually returned. Thus, this example again does only `evaluate()` has actually returned. Thus, this example again does only
one network round trip.''' one network round trip."""
print("Pipelining eval() calls... ", end="") print("Pipelining eval() calls... ", end="")
# Get the "add" function from the server. # Get the "add" function from the server.
add = calculator.getOperator(op='add').func add = calculator.getOperator(op="add").func
# Get the "multiply" function from the server. # Get the "multiply" function from the server.
multiply = calculator.getOperator(op='multiply').func multiply = calculator.getOperator(op="multiply").func
# Build the request to evaluate 4 * 6 # Build the request to evaluate 4 * 6
request = calculator.evaluate_request() request = calculator.evaluate_request()
@ -176,7 +178,7 @@ def main(host):
print("PASS") print("PASS")
'''Our calculator interface supports defining functions. Here we use it """Our calculator interface supports defining functions. Here we use it
to define two functions and then make calls to them as follows: to define two functions and then make calls to them as follows:
f(x, y) = x * 100 + y f(x, y) = x * 100 + y
@ -184,14 +186,14 @@ def main(host):
f(12, 34) f(12, 34)
g(21) g(21)
Once again, the whole thing takes only one network round trip.''' Once again, the whole thing takes only one network round trip."""
print("Defining functions... ", end="") print("Defining functions... ", end="")
# Get the "add" function from the server. # Get the "add" function from the server.
add = calculator.getOperator(op='add').func add = calculator.getOperator(op="add").func
# Get the "multiply" function from the server. # Get the "multiply" function from the server.
multiply = calculator.getOperator(op='multiply').func multiply = calculator.getOperator(op="multiply").func
# Define f. # Define f.
request = calculator.defFunction_request() request = calculator.defFunction_request()
@ -249,7 +251,7 @@ def main(host):
g_eval_request = calculator.evaluate_request() g_eval_request = calculator.evaluate_request()
g_call = g_eval_request.expression.init("call") g_call = g_eval_request.expression.init("call")
g_call.function = g g_call.function = g
g_call.init('params', 1)[0].literal = 21 g_call.init("params", 1)[0].literal = 21
g_eval_promise = g_eval_request.send().value.read() g_eval_promise = g_eval_request.send().value.read()
# Wait for the results. # Wait for the results.
@ -258,7 +260,7 @@ def main(host):
print("PASS") print("PASS")
'''Make a request that will call back to a function defined locally. """Make a request that will call back to a function defined locally.
Specifically, we will compute 2^(4 + 5). However, exponent is not Specifically, we will compute 2^(4 + 5). However, exponent is not
defined by the Calculator server. So, we'll implement the Function defined by the Calculator server. So, we'll implement the Function
@ -270,12 +272,12 @@ def main(host):
particular case, this could potentially be optimized by using a tail particular case, this could potentially be optimized by using a tail
call on the server side -- see CallContext::tailCall(). However, to call on the server side -- see CallContext::tailCall(). However, to
keep the example simpler, we haven't implemented this optimization in keep the example simpler, we haven't implemented this optimization in
the sample server.''' the sample server."""
print("Using a callback... ", end="") print("Using a callback... ", end="")
# Get the "add" function from the server. # Get the "add" function from the server.
add = calculator.getOperator(op='add').func add = calculator.getOperator(op="add").func
# Build the eval request for 2^(4+5). # Build the eval request for 2^(4+5).
request = calculator.evaluate_request() request = calculator.evaluate_request()
@ -298,5 +300,5 @@ def main(host):
print("PASS") print("PASS")
if __name__ == '__main__': if __name__ == "__main__":
main(parse_args().host) main(parse_args().host)

View file

@ -8,29 +8,29 @@ import calculator_capnp
def read_value(value): def read_value(value):
'''Helper function to asynchronously call read() on a Calculator::Value and """Helper function to asynchronously call read() on a Calculator::Value and
return a promise for the result. (In the future, the generated code might return a promise for the result. (In the future, the generated code might
include something like this automatically.)''' include something like this automatically.)"""
return value.read().then(lambda result: result.value) return value.read().then(lambda result: result.value)
def evaluate_impl(expression, params=None): def evaluate_impl(expression, params=None):
'''Implementation of CalculatorImpl::evaluate(), also shared by """Implementation of CalculatorImpl::evaluate(), also shared by
FunctionImpl::call(). In the latter case, `params` are the parameter FunctionImpl::call(). In the latter case, `params` are the parameter
values passed to the function; in the former case, `params` is just an values passed to the function; in the former case, `params` is just an
empty list.''' empty list."""
which = expression.which() which = expression.which()
if which == 'literal': if which == "literal":
return capnp.Promise(expression.literal) return capnp.Promise(expression.literal)
elif which == 'previousResult': elif which == "previousResult":
return read_value(expression.previousResult) return read_value(expression.previousResult)
elif which == 'parameter': elif which == "parameter":
assert expression.parameter < len(params) assert expression.parameter < len(params)
return capnp.Promise(params[expression.parameter]) return capnp.Promise(params[expression.parameter])
elif which == 'call': elif which == "call":
call = expression.call call = expression.call
func = call.function func = call.function
@ -39,9 +39,9 @@ def evaluate_impl(expression, params=None):
joinedParams = capnp.join_promises(paramPromises) joinedParams = capnp.join_promises(paramPromises)
# When the parameters are complete, call the function. # When the parameters are complete, call the function.
ret = (joinedParams ret = joinedParams.then(lambda vals: func.call(vals)).then(
.then(lambda vals: func.call(vals)) lambda result: result.value
.then(lambda result: result.value)) )
return ret return ret
else: else:
@ -61,28 +61,30 @@ class ValueImpl(calculator_capnp.Calculator.Value.Server):
class FunctionImpl(calculator_capnp.Calculator.Function.Server): class FunctionImpl(calculator_capnp.Calculator.Function.Server):
'''Implementation of the Calculator.Function Cap'n Proto interface, where the """Implementation of the Calculator.Function Cap'n Proto interface, where the
function is defined by a Calculator.Expression.''' function is defined by a Calculator.Expression."""
def __init__(self, paramCount, body): def __init__(self, paramCount, body):
self.paramCount = paramCount self.paramCount = paramCount
self.body = body.as_builder() self.body = body.as_builder()
def call(self, params, _context, **kwargs): def call(self, params, _context, **kwargs):
'''Note that we're returning a Promise object here, and bypassing the """Note that we're returning a Promise object here, and bypassing the
helper functionality that normally sets the results struct from the helper functionality that normally sets the results struct from the
returned object. Instead, we set _context.results directly inside of returned object. Instead, we set _context.results directly inside of
another promise''' another promise"""
assert len(params) == self.paramCount assert len(params) == self.paramCount
# using setattr because '=' is not allowed inside of lambdas # using setattr because '=' is not allowed inside of lambdas
return evaluate_impl(self.body, params).then(lambda value: setattr(_context.results, 'value', value)) return evaluate_impl(self.body, params).then(
lambda value: setattr(_context.results, "value", value)
)
class OperatorImpl(calculator_capnp.Calculator.Function.Server): class OperatorImpl(calculator_capnp.Calculator.Function.Server):
'''Implementation of the Calculator.Function Cap'n Proto interface, wrapping """Implementation of the Calculator.Function Cap'n Proto interface, wrapping
basic binary arithmetic operators.''' basic binary arithmetic operators."""
def __init__(self, op): def __init__(self, op):
self.op = op self.op = op
@ -92,16 +94,16 @@ class OperatorImpl(calculator_capnp.Calculator.Function.Server):
op = self.op op = self.op
if op == 'add': if op == "add":
return params[0] + params[1] return params[0] + params[1]
elif op == 'subtract': elif op == "subtract":
return params[0] - params[1] return params[0] - params[1]
elif op == 'multiply': elif op == "multiply":
return params[0] * params[1] return params[0] * params[1]
elif op == 'divide': elif op == "divide":
return params[0] / params[1] return params[0] / params[1]
else: else:
raise ValueError('Unknown operator') raise ValueError("Unknown operator")
class CalculatorImpl(calculator_capnp.Calculator.Server): class CalculatorImpl(calculator_capnp.Calculator.Server):
@ -109,7 +111,9 @@ class CalculatorImpl(calculator_capnp.Calculator.Server):
"Implementation of the Calculator Cap'n Proto interface." "Implementation of the Calculator Cap'n Proto interface."
def evaluate(self, expression, _context, **kwargs): def evaluate(self, expression, _context, **kwargs):
return evaluate_impl(expression).then(lambda value: setattr(_context.results, 'value', ValueImpl(value))) return evaluate_impl(expression).then(
lambda value: setattr(_context.results, "value", ValueImpl(value))
)
def defFunction(self, paramCount, body, _context, **kwargs): def defFunction(self, paramCount, body, _context, **kwargs):
return FunctionImpl(paramCount, body) return FunctionImpl(paramCount, body)
@ -119,9 +123,11 @@ class CalculatorImpl(calculator_capnp.Calculator.Server):
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(usage='''Runs the server bound to the\ parser = argparse.ArgumentParser(
usage="""Runs the server bound to the\
given address/port ADDRESS may be '*' to bind to all local addresses.\ given address/port ADDRESS may be '*' to bind to all local addresses.\
:PORT may be omitted to choose a port automatically. ''') :PORT may be omitted to choose a port automatically. """
)
parser.add_argument("address", help="ADDRESS[:PORT]") parser.add_argument("address", help="ADDRESS[:PORT]")
@ -137,5 +143,5 @@ def main():
time.sleep(0.001) time.sleep(0.001)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View file

@ -12,8 +12,10 @@ capnp.create_event_loop(threaded=True)
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(usage='Connects to the Example thread server \ parser = argparse.ArgumentParser(
at the given address and does some RPCs') usage="Connects to the Example thread server \
at the given address and does some RPCs"
)
parser.add_argument("host", help="HOST:PORT") parser.add_argument("host", help="HOST:PORT")
return parser.parse_args() return parser.parse_args()
@ -21,10 +23,10 @@ at the given address and does some RPCs')
class StatusSubscriber(thread_capnp.Example.StatusSubscriber.Server): class StatusSubscriber(thread_capnp.Example.StatusSubscriber.Server):
'''An implementation of the StatusSubscriber interface''' """An implementation of the StatusSubscriber interface"""
def status(self, value, **kwargs): def status(self, value, **kwargs):
print('status: {}'.format(time.time())) print("status: {}".format(time.time()))
def start_status_thread(host): def start_status_thread(host):
@ -44,14 +46,14 @@ def main(host):
status_thread.daemon = True status_thread.daemon = True
status_thread.start() status_thread.start()
print('main: {}'.format(time.time())) print("main: {}".format(time.time()))
cap.longRunning().wait() cap.longRunning().wait()
print('main: {}'.format(time.time())) print("main: {}".format(time.time()))
cap.longRunning().wait() cap.longRunning().wait()
print('main: {}'.format(time.time())) print("main: {}".format(time.time()))
cap.longRunning().wait() cap.longRunning().wait()
print('main: {}'.format(time.time())) print("main: {}".format(time.time()))
if __name__ == '__main__': if __name__ == "__main__":
main(parse_args().host) main(parse_args().host)

View file

@ -12,18 +12,23 @@ class ExampleImpl(thread_capnp.Example.Server):
"Implementation of the Example threading Cap'n Proto interface." "Implementation of the Example threading Cap'n Proto interface."
def subscribeStatus(self, subscriber, **kwargs): def subscribeStatus(self, subscriber, **kwargs):
return capnp.getTimer().after_delay(10**9) \ return (
.then(lambda: subscriber.status(True)) \ capnp.getTimer()
.after_delay(10 ** 9)
.then(lambda: subscriber.status(True))
.then(lambda _: self.subscribeStatus(subscriber)) .then(lambda _: self.subscribeStatus(subscriber))
)
def longRunning(self, **kwargs): def longRunning(self, **kwargs):
return capnp.getTimer().after_delay(1 * 10**9) return capnp.getTimer().after_delay(1 * 10 ** 9)
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(usage='''Runs the server bound to the\ parser = argparse.ArgumentParser(
usage="""Runs the server bound to the\
given address/port ADDRESS may be '*' to bind to all local addresses.\ given address/port ADDRESS may be '*' to bind to all local addresses.\
:PORT may be omitted to choose a port automatically. ''') :PORT may be omitted to choose a port automatically. """
)
parser.add_argument("address", help="ADDRESS[:PORT]") parser.add_argument("address", help="ADDRESS[:PORT]")
@ -39,5 +44,5 @@ def main():
time.sleep(0.001) time.sleep(0.001)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View file

@ -11,7 +11,12 @@ def parse_args():
parser.add_argument("command") parser.add_argument("command")
parser.add_argument("schema_file") parser.add_argument("schema_file")
parser.add_argument("struct_name") parser.add_argument("struct_name")
parser.add_argument("-d", "--defaults", help="include default values in json output", action="store_true") parser.add_argument(
"-d",
"--defaults",
help="include default values in json output",
action="store_true",
)
return parser.parse_args() return parser.parse_args()
@ -41,9 +46,11 @@ def main():
command = args.command command = args.command
kwargs = vars(args) kwargs = vars(args)
del kwargs['command'] del kwargs["command"]
globals()[command](**kwargs) # hacky way to get defined functions, and call function with name=command globals()[command](
**kwargs
) # hacky way to get defined functions, and call function with name=command
main() main()

View file

@ -3,7 +3,10 @@ import os
import sys import sys
import capnp import capnp
capnp.add_import_hook([os.getcwd(), "/usr/local/include/"]) # change this to be auto-detected?
capnp.add_import_hook(
[os.getcwd(), "/usr/local/include/"]
) # change this to be auto-detected?
import test_capnp # noqa: E402 import test_capnp # noqa: E402
@ -20,7 +23,7 @@ def encode(name):
print(message.to_bytes()) print(message.to_bytes())
if sys.argv[1] == 'decode': if sys.argv[1] == "decode":
decode(sys.argv[2]) decode(sys.argv[2])
else: else:
encode(sys.argv[2]) encode(sys.argv[2])

160
setup.py
View file

@ -1,7 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
''' """
pycapnp distutils setup.py pycapnp distutils setup.py
''' """
import glob import glob
import os import os
@ -23,15 +23,15 @@ _this_dir = os.path.dirname(__file__)
MAJOR = 1 MAJOR = 1
MINOR = 1 MINOR = 1
MICRO = 0 MICRO = 0
TAG = '' TAG = ""
VERSION = '%d.%d.%d%s' % (MAJOR, MINOR, MICRO, TAG) VERSION = "%d.%d.%d%s" % (MAJOR, MINOR, MICRO, TAG)
# Write version info # Write version info
def write_version_py(filename=None): def write_version_py(filename=None):
''' """
Generate pycapnp version Generate pycapnp version
''' """
cnt = """\ cnt = """\
from .lib.capnp import _CAPNP_VERSION_MAJOR as LIBCAPNP_VERSION_MAJOR # noqa: F401 from .lib.capnp import _CAPNP_VERSION_MAJOR as LIBCAPNP_VERSION_MAJOR # noqa: F401
from .lib.capnp import _CAPNP_VERSION_MINOR as LIBCAPNP_VERSION_MINOR # noqa: F401 from .lib.capnp import _CAPNP_VERSION_MINOR as LIBCAPNP_VERSION_MINOR # noqa: F401
@ -42,10 +42,9 @@ version = '%s'
short_version = '%s' short_version = '%s'
""" """
if not filename: if not filename:
filename = os.path.join( filename = os.path.join(os.path.dirname(__file__), "capnp", "version.py")
os.path.dirname(__file__), 'capnp', 'version.py')
a = open(filename, 'w') a = open(filename, "w")
try: try:
a.write(cnt % (VERSION, VERSION)) a.write(cnt % (VERSION, VERSION))
finally: finally:
@ -55,30 +54,31 @@ short_version = '%s'
write_version_py() write_version_py()
# Try to use README.md and CHANGELOG.md as description and changelog # Try to use README.md and CHANGELOG.md as description and changelog
with open('README.md', encoding='utf-8') as f: with open("README.md", encoding="utf-8") as f:
long_description = f.read() long_description = f.read()
with open('CHANGELOG.md', encoding='utf-8') as f: with open("CHANGELOG.md", encoding="utf-8") as f:
changelog = f.read() changelog = f.read()
changelog = '\nChangelog\n=============\n' + changelog changelog = "\nChangelog\n=============\n" + changelog
long_description += changelog long_description += changelog
class clean(_clean): class clean(_clean):
''' """
Clean command, invoked with `python setup.py clean` Clean command, invoked with `python setup.py clean`
''' """
def run(self): def run(self):
_clean.run(self) _clean.run(self)
for x in [ for x in [
os.path.join('capnp', 'lib', 'capnp.cpp'), os.path.join("capnp", "lib", "capnp.cpp"),
os.path.join('capnp', 'lib', 'capnp.h'), os.path.join("capnp", "lib", "capnp.h"),
os.path.join('capnp', 'version.py'), os.path.join("capnp", "version.py"),
'build', "build",
'build32', "build32",
'build64', "build64",
'bundled' "bundled",
] + glob.glob(os.path.join('capnp', '*.capnp')): ] + glob.glob(os.path.join("capnp", "*.capnp")):
print('removing %s' % x) print("removing %s" % x)
try: try:
os.remove(x) os.remove(x)
except OSError: except OSError:
@ -109,9 +109,10 @@ from Cython.Distutils import build_ext as build_ext_c # noqa: E402
class build_libcapnp_ext(build_ext_c): class build_libcapnp_ext(build_ext_c):
''' """
Build capnproto library Build capnproto library
''' """
def build_extension(self, ext): def build_extension(self, ext):
build_ext_c.build_extension(self, ext) build_ext_c.build_extension(self, ext)
@ -126,12 +127,16 @@ class build_libcapnp_ext(build_ext_c):
if capnp_executable: if capnp_executable:
capnp_dir = os.path.dirname(capnp_executable) capnp_dir = os.path.dirname(capnp_executable)
self.include_dirs += [os.path.join(capnp_dir, "..", "include")] self.include_dirs += [os.path.join(capnp_dir, "..", "include")]
self.library_dirs += [os.path.join(capnp_dir, "..", "lib{}".format(8 * struct.calcsize("P")))] self.library_dirs += [
os.path.join(
capnp_dir, "..", "lib{}".format(8 * struct.calcsize("P"))
)
]
self.library_dirs += [os.path.join(capnp_dir, "..", "lib")] self.library_dirs += [os.path.join(capnp_dir, "..", "lib")]
# Look for capnproto using pkg-config (and minimum version) # Look for capnproto using pkg-config (and minimum version)
try: try:
if pkgconfig.installed('capnp', '>= 0.8.0'): if pkgconfig.installed("capnp", ">= 0.8.0"):
need_build = False need_build = False
else: else:
need_build = True need_build = True
@ -149,14 +154,16 @@ class build_libcapnp_ext(build_ext_c):
bundle_dir = os.path.join(_this_dir, "bundled") bundle_dir = os.path.join(_this_dir, "bundled")
if not os.path.exists(bundle_dir): if not os.path.exists(bundle_dir):
os.mkdir(bundle_dir) os.mkdir(bundle_dir)
build_dir = os.path.join(_this_dir, "build{}".format(8 * struct.calcsize("P"))) build_dir = os.path.join(
_this_dir, "build{}".format(8 * struct.calcsize("P"))
)
if not os.path.exists(build_dir): if not os.path.exists(build_dir):
os.mkdir(build_dir) os.mkdir(build_dir)
# Check if we've already built capnproto # Check if we've already built capnproto
capnp_bin = os.path.join(build_dir, 'bin', 'capnp') capnp_bin = os.path.join(build_dir, "bin", "capnp")
if os.name == 'nt': if os.name == "nt":
capnp_bin = os.path.join(build_dir, 'bin', 'capnp.exe') capnp_bin = os.path.join(build_dir, "bin", "capnp.exe")
if not os.path.exists(capnp_bin): if not os.path.exists(capnp_bin):
# Not built, fetch and build # Not built, fetch and build
@ -165,12 +172,14 @@ class build_libcapnp_ext(build_ext_c):
else: else:
print("capnproto already built at {}".format(build_dir)) print("capnproto already built at {}".format(build_dir))
self.include_dirs += [os.path.join(build_dir, 'include')] self.include_dirs += [os.path.join(build_dir, "include")]
self.library_dirs += [os.path.join(build_dir, 'lib{}'.format(8 * struct.calcsize("P")))] self.library_dirs += [
self.library_dirs += [os.path.join(build_dir, 'lib')] os.path.join(build_dir, "lib{}".format(8 * struct.calcsize("P")))
]
self.library_dirs += [os.path.join(build_dir, "lib")]
# Copy .capnp files from source # Copy .capnp files from source
src_glob = glob.glob(os.path.join(build_dir, 'include', 'capnp', '*.capnp')) src_glob = glob.glob(os.path.join(build_dir, "include", "capnp", "*.capnp"))
dst_dir = os.path.join(self.build_lib, "capnp") dst_dir = os.path.join(self.build_lib, "capnp")
for file in src_glob: for file in src_glob:
print("copying {} -> {}".format(file, dst_dir)) print("copying {} -> {}".format(file, dst_dir))
@ -179,64 +188,71 @@ class build_libcapnp_ext(build_ext_c):
return build_ext_c.run(self) return build_ext_c.run(self)
extra_compile_args = ['--std=c++14'] extra_compile_args = ["--std=c++14"]
extra_link_args = [] extra_link_args = []
if os.name == 'nt': if os.name == "nt":
extra_compile_args = ['/std:c++14', '/MD'] extra_compile_args = ["/std:c++14", "/MD"]
extra_link_args = ['/MANIFEST'] extra_link_args = ["/MANIFEST"]
import Cython.Build # noqa: E402 import Cython.Build # noqa: E402
import Cython # noqa: E402 import Cython # noqa: E402
extensions = [Extension( extensions = [
'*', ['capnp/helpers/capabilityHelper.cpp', 'capnp/lib/*.pyx'], Extension(
extra_compile_args=extra_compile_args, "*",
extra_link_args=extra_link_args, ["capnp/helpers/capabilityHelper.cpp", "capnp/lib/*.pyx"],
language='c++', extra_compile_args=extra_compile_args,
)] extra_link_args=extra_link_args,
language="c++",
)
]
setup( setup(
name="pycapnp", name="pycapnp",
packages=["capnp"], packages=["capnp"],
version=VERSION, version=VERSION,
package_data={ package_data={
'capnp': [ "capnp": [
'*.pxd', '*.h', '*.capnp', 'helpers/*.pxd', 'helpers/*.h', "*.pxd",
'includes/*.pxd', 'lib/*.pxd', 'lib/*.py', 'lib/*.pyx', 'templates/*' "*.h",
"*.capnp",
"helpers/*.pxd",
"helpers/*.h",
"includes/*.pxd",
"lib/*.pxd",
"lib/*.py",
"lib/*.pyx",
"templates/*",
] ]
}, },
ext_modules=Cython.Build.cythonize(extensions), ext_modules=Cython.Build.cythonize(extensions),
cmdclass={ cmdclass={"clean": clean, "build_ext": build_libcapnp_ext},
'clean': clean,
'build_ext': build_libcapnp_ext
},
install_requires=[], install_requires=[],
entry_points={ entry_points={"console_scripts": ["capnpc-cython = capnp._gen:main"]},
"console_scripts": ["capnpc-cython = capnp._gen:main"]
},
# PyPi info # PyPi info
description="A cython wrapping of the C++ Cap'n Proto library", description="A cython wrapping of the C++ Cap'n Proto library",
long_description=long_description, long_description=long_description,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
license='BSD', license="BSD",
# (setup.py only supports 1 author...) # (setup.py only supports 1 author...)
author="Jacob Alexander", # <- Current maintainer; Original author -> Jason Paryani author="Jacob Alexander", # <- Current maintainer; Original author -> Jason Paryani
author_email="haata@kiibohd.com", author_email="haata@kiibohd.com",
url='https://github.com/capnproto/pycapnp', url="https://github.com/capnproto/pycapnp",
download_url='https://github.com/haata/pycapnp/archive/v%s.zip' % VERSION, download_url="https://github.com/haata/pycapnp/archive/v%s.zip" % VERSION,
keywords=['capnp', 'capnproto', "Cap'n Proto", 'pycapnp'], keywords=["capnp", "capnproto", "Cap'n Proto", "pycapnp"],
classifiers=[ classifiers=[
'Development Status :: 5 - Production/Stable', "Development Status :: 5 - Production/Stable",
'Intended Audience :: Developers', "Intended Audience :: Developers",
'License :: OSI Approved :: BSD License', "License :: OSI Approved :: BSD License",
'Operating System :: MacOS :: MacOS X', "Operating System :: MacOS :: MacOS X",
'Operating System :: Microsoft :: Windows :: Windows 10', "Operating System :: Microsoft :: Windows :: Windows 10",
'Operating System :: POSIX', "Operating System :: POSIX",
'Programming Language :: C++', "Programming Language :: C++",
'Programming Language :: Cython', "Programming Language :: Cython",
'Programming Language :: Python :: 3.7', "Programming Language :: Python :: 3.7",
'Programming Language :: Python :: 3.8', "Programming Language :: Python :: 3.8",
'Programming Language :: Python :: 3.9', "Programming Language :: Python :: 3.9",
'Programming Language :: Python :: Implementation :: PyPy', "Programming Language :: Python :: Implementation :: PyPy",
'Topic :: Communications'], "Topic :: Communications",
],
) )

View file

@ -16,17 +16,17 @@ class Server(capability.TestInterface.Server):
return str(i * 5 + extra + self.val) return str(i * 5 + extra + self.val)
def buz(self, i, **kwargs): def buz(self, i, **kwargs):
return i.host + '_test' return i.host + "_test"
def bam(self, i, **kwargs): def bam(self, i, **kwargs):
return str(i) + '_test', i return str(i) + "_test", i
class PipelineServer(capability.TestPipeline.Server): class PipelineServer(capability.TestPipeline.Server):
def getCap(self, n, inCap, _context, **kwargs): def getCap(self, n, inCap, _context, **kwargs):
def _then(response): def _then(response):
_results = _context.results _results = _context.results
_results.s = response.x + '_foo' _results.s = response.x + "_foo"
_results.outBox.cap = Server(100) _results.outBox.cap = Server(100)
return inCap.foo(i=n).then(_then) return inCap.foo(i=n).then(_then)
@ -35,13 +35,13 @@ class PipelineServer(capability.TestPipeline.Server):
def test_client(): def test_client():
client = capability.TestInterface._new_client(Server()) client = capability.TestInterface._new_client(Server())
req = client._request('foo') req = client._request("foo")
req.i = 5 req.i = 5
remote = req.send() remote = req.send()
response = remote.wait() response = remote.wait()
assert response.x == '26' assert response.x == "26"
req = client.foo_request() req = client.foo_request()
req.i = 5 req.i = 5
@ -49,7 +49,7 @@ def test_client():
remote = req.send() remote = req.send()
response = remote.wait() response = remote.wait()
assert response.x == '26' assert response.x == "26"
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
client.foo2_request() client.foo2_request()
@ -57,7 +57,7 @@ def test_client():
req = client.foo_request() req = client.foo_request()
with pytest.raises(Exception): with pytest.raises(Exception):
req.i = 'foo' req.i = "foo"
req = client.foo_request() req = client.foo_request()
@ -68,45 +68,45 @@ def test_client():
def test_simple_client(): def test_simple_client():
client = capability.TestInterface._new_client(Server()) client = capability.TestInterface._new_client(Server())
remote = client._send('foo', i=5) remote = client._send("foo", i=5)
response = remote.wait() response = remote.wait()
assert response.x == '26' assert response.x == "26"
remote = client.foo(i=5) remote = client.foo(i=5)
response = remote.wait() response = remote.wait()
assert response.x == '26' assert response.x == "26"
remote = client.foo(i=5, j=True) remote = client.foo(i=5, j=True)
response = remote.wait() response = remote.wait()
assert response.x == '27' assert response.x == "27"
remote = client.foo(5) remote = client.foo(5)
response = remote.wait() response = remote.wait()
assert response.x == '26' assert response.x == "26"
remote = client.foo(5, True) remote = client.foo(5, True)
response = remote.wait() response = remote.wait()
assert response.x == '27' assert response.x == "27"
remote = client.foo(5, j=True) remote = client.foo(5, j=True)
response = remote.wait() response = remote.wait()
assert response.x == '27' assert response.x == "27"
remote = client.buz(capability.TestSturdyRefHostId.new_message(host='localhost')) remote = client.buz(capability.TestSturdyRefHostId.new_message(host="localhost"))
response = remote.wait() response = remote.wait()
assert response.x == 'localhost_test' assert response.x == "localhost_test"
remote = client.bam(i=5) remote = client.bam(i=5)
response = remote.wait() response = remote.wait()
assert response.x == '5_test' assert response.x == "5_test"
assert response.i == 5 assert response.i == 5
with pytest.raises(Exception): with pytest.raises(Exception):
@ -116,7 +116,7 @@ def test_simple_client():
remote = client.foo(5, True, 100) remote = client.foo(5, True, 100)
with pytest.raises(Exception): with pytest.raises(Exception):
remote = client.foo(i='foo') remote = client.foo(i="foo")
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
remote = client.foo2(i=5) remote = client.foo2(i=5)
@ -135,10 +135,10 @@ def test_pipeline():
pipelinePromise = outCap.foo(i=10) pipelinePromise = outCap.foo(i=10)
response = pipelinePromise.wait() response = pipelinePromise.wait()
assert response.x == '150' assert response.x == "150"
response = remote.wait() response = remote.wait()
assert response.s == '26_foo' assert response.s == "26_foo"
class BadServer(capability.TestInterface.Server): class BadServer(capability.TestInterface.Server):
@ -155,7 +155,7 @@ class BadServer(capability.TestInterface.Server):
def test_exception_client(): def test_exception_client():
client = capability.TestInterface._new_client(BadServer()) client = capability.TestInterface._new_client(BadServer())
remote = client._send('foo', i=5) remote = client._send("foo", i=5)
with pytest.raises(capnp.KjException): with pytest.raises(capnp.KjException):
remote.wait() remote.wait()
@ -164,11 +164,11 @@ class BadPipelineServer(capability.TestPipeline.Server):
def getCap(self, n, inCap, _context, **kwargs): def getCap(self, n, inCap, _context, **kwargs):
def _then(response): def _then(response):
_results = _context.results _results = _context.results
_results.s = response.x + '_foo' _results.s = response.x + "_foo"
_results.outBox.cap = Server(100) _results.outBox.cap = Server(100)
def _error(error): def _error(error):
raise Exception('test was a success') raise Exception("test was a success")
return inCap.foo(i=n).then(_then, _error) return inCap.foo(i=n).then(_then, _error)
@ -182,7 +182,7 @@ def test_exception_chain():
try: try:
remote.wait() remote.wait()
except Exception as e: except Exception as e:
assert 'test was a success' in str(e) assert "test was a success" in str(e)
def test_pipeline_exception(): def test_pipeline_exception():
@ -226,7 +226,7 @@ class TailCaller(capability.TestTailCaller.Server):
def foo(self, i, callee, _context, **kwargs): def foo(self, i, callee, _context, **kwargs):
self.count += 1 self.count += 1
tail = callee.foo_request(i=i, t='from TailCaller') tail = callee.foo_request(i=i, t="from TailCaller")
return _context.tail_call(tail) return _context.tail_call(tail)
@ -275,7 +275,7 @@ def test_tail_call():
def test_cancel(): def test_cancel():
client = capability.TestInterface._new_client(Server()) client = capability.TestInterface._new_client(Server())
req = client._request('foo') req = client._request("foo")
req.i = 5 req.i = 5
remote = req.send() remote = req.send()
@ -292,17 +292,17 @@ def test_timer():
def set_timer_var(): def set_timer_var():
global test_timer_var global test_timer_var
test_timer_var = True test_timer_var = True
capnp.getTimer().after_delay(1).then(set_timer_var).wait() capnp.getTimer().after_delay(1).then(set_timer_var).wait()
assert test_timer_var is True assert test_timer_var is True
test_timer_var = False test_timer_var = False
promise = capnp.Promise(0).then( promise = (
lambda x: time.sleep(.1) capnp.Promise(0)
).then( .then(lambda x: time.sleep(0.1))
lambda x: time.sleep(.1) .then(lambda x: time.sleep(0.1))
).then( .then(lambda x: set_timer_var())
lambda x: set_timer_var()
) )
canceller = capnp.getTimer().after_delay(1).then(lambda: promise.cancel()) canceller = capnp.getTimer().after_delay(1).then(lambda: promise.cancel())
@ -317,7 +317,7 @@ def test_timer():
def test_double_send(): def test_double_send():
client = capability.TestInterface._new_client(Server()) client = capability.TestInterface._new_client(Server())
req = client._request('foo') req = client._request("foo")
req.i = 5 req.i = 5
req.send() req.send()
@ -362,19 +362,20 @@ def test_inheritance():
remote = client.foo(i=5) remote = client.foo(i=5)
response = remote.wait() response = remote.wait()
assert response.x == '26' assert response.x == "26"
class PassedCapTest(capability.TestPassedCap.Server): class PassedCapTest(capability.TestPassedCap.Server):
def foo(self, cap, _context, **kwargs): def foo(self, cap, _context, **kwargs):
def set_result(res): def set_result(res):
_context.results.x = res.x _context.results.x = res.x
return cap.foo(5).then(set_result) return cap.foo(5).then(set_result)
def test_null_cap(): def test_null_cap():
client = capability.TestPassedCap._new_client(PassedCapTest()) client = capability.TestPassedCap._new_client(PassedCapTest())
assert client.foo(Server()).wait().x == '26' assert client.foo(Server()).wait().x == "26"
with pytest.raises(capnp.KjException): with pytest.raises(capnp.KjException):
client.foo().wait() client.foo().wait()
@ -387,14 +388,14 @@ class StructArgTest(capability.TestStructArg.Server):
def test_struct_args(): def test_struct_args():
client = capability.TestStructArg._new_client(StructArgTest()) client = capability.TestStructArg._new_client(StructArgTest())
assert client.bar(a='test', b=1).wait().c == 'test1' assert client.bar(a="test", b=1).wait().c == "test1"
with pytest.raises(capnp.KjException): with pytest.raises(capnp.KjException):
assert client.bar('test', 1).wait().c == 'test1' assert client.bar("test", 1).wait().c == "test1"
class GenericTest(capability.TestGeneric.Server): class GenericTest(capability.TestGeneric.Server):
def foo(self, a, **kwargs): def foo(self, a, **kwargs):
return a.as_text() + 'test' return a.as_text() + "test"
def test_generic(): def test_generic():
@ -402,4 +403,4 @@ def test_generic():
obj = capnp._MallocMessageBuilder().get_root_as_any() obj = capnp._MallocMessageBuilder().get_root_as_any()
obj.set_as_text("anypointer_") obj.set_as_text("anypointer_")
assert client.foo(obj).wait().b == 'anypointer_test' assert client.foo(obj).wait().b == "anypointer_test"

View file

@ -7,10 +7,12 @@ this_dir = os.path.dirname(__file__)
# flake8: noqa: E501 # flake8: noqa: E501
@pytest.fixture @pytest.fixture
def capability(): def capability():
capnp.cleanup_global_schema_parser() capnp.cleanup_global_schema_parser()
return capnp.load(os.path.join(this_dir, 'test_capability.capnp')) return capnp.load(os.path.join(this_dir, "test_capability.capnp"))
class Server: class Server:
def __init__(self, val=1): def __init__(self, val=1):
@ -23,26 +25,30 @@ class Server:
context.results.x = str(context.params.i * 5 + extra + self.val) context.results.x = str(context.params.i * 5 + extra + self.val)
def buz_context(self, context): def buz_context(self, context):
context.results.x = context.params.i.host + '_test' context.results.x = context.params.i.host + "_test"
class PipelineServer: class PipelineServer:
def getCap_context(self, context): def getCap_context(self, context):
def _then(response): def _then(response):
context.results.s = response.x + '_foo' context.results.s = response.x + "_foo"
context.results.outBox.cap = capability().TestInterface._new_server(Server(100)) context.results.outBox.cap = capability().TestInterface._new_server(
Server(100)
)
return context.params.inCap.foo(i=context.params.n).then(_then) return context.params.inCap.foo(i=context.params.n).then(_then)
def test_client_context(capability): def test_client_context(capability):
client = capability.TestInterface._new_client(Server()) client = capability.TestInterface._new_client(Server())
req = client._request('foo') req = client._request("foo")
req.i = 5 req.i = 5
remote = req.send() remote = req.send()
response = remote.wait() response = remote.wait()
assert response.x == '26' assert response.x == "26"
req = client.foo_request() req = client.foo_request()
req.i = 5 req.i = 5
@ -50,7 +56,7 @@ def test_client_context(capability):
remote = req.send() remote = req.send()
response = remote.wait() response = remote.wait()
assert response.x == '26' assert response.x == "26"
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
client.foo2_request() client.foo2_request()
@ -58,51 +64,51 @@ def test_client_context(capability):
req = client.foo_request() req = client.foo_request()
with pytest.raises(Exception): with pytest.raises(Exception):
req.i = 'foo' req.i = "foo"
req = client.foo_request() req = client.foo_request()
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
req.baz = 1 req.baz = 1
def test_simple_client_context(capability): def test_simple_client_context(capability):
client = capability.TestInterface._new_client(Server()) client = capability.TestInterface._new_client(Server())
remote = client._send('foo', i=5) remote = client._send("foo", i=5)
response = remote.wait() response = remote.wait()
assert response.x == '26' assert response.x == "26"
remote = client.foo(i=5) remote = client.foo(i=5)
response = remote.wait() response = remote.wait()
assert response.x == '26' assert response.x == "26"
remote = client.foo(i=5, j=True) remote = client.foo(i=5, j=True)
response = remote.wait() response = remote.wait()
assert response.x == '27' assert response.x == "27"
remote = client.foo(5) remote = client.foo(5)
response = remote.wait() response = remote.wait()
assert response.x == '26' assert response.x == "26"
remote = client.foo(5, True) remote = client.foo(5, True)
response = remote.wait() response = remote.wait()
assert response.x == '27' assert response.x == "27"
remote = client.foo(5, j=True) remote = client.foo(5, j=True)
response = remote.wait() response = remote.wait()
assert response.x == '27' assert response.x == "27"
remote = client.buz(capability.TestSturdyRefHostId.new_message(host='localhost')) remote = client.buz(capability.TestSturdyRefHostId.new_message(host="localhost"))
response = remote.wait() response = remote.wait()
assert response.x == 'localhost_test' assert response.x == "localhost_test"
with pytest.raises(Exception): with pytest.raises(Exception):
remote = client.foo(5, 10) remote = client.foo(5, 10)
@ -111,7 +117,7 @@ def test_simple_client_context(capability):
remote = client.foo(5, True, 100) remote = client.foo(5, True, 100)
with pytest.raises(Exception): with pytest.raises(Exception):
remote = client.foo(i='foo') remote = client.foo(i="foo")
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
remote = client.foo2(i=5) remote = client.foo2(i=5)
@ -119,15 +125,16 @@ def test_simple_client_context(capability):
with pytest.raises(Exception): with pytest.raises(Exception):
remote = client.foo(baz=5) remote = client.foo(baz=5)
@pytest.mark.xfail @pytest.mark.xfail
def test_pipeline_context(capability): def test_pipeline_context(capability):
''' """
E capnp.lib.capnp.KjException: capnp/lib/capnp.pyx:61: failed: <class 'Failed'>:Fixture "capability" called directly. Fixtures are not meant to be called directly, E capnp.lib.capnp.KjException: capnp/lib/capnp.pyx:61: failed: <class 'Failed'>:Fixture "capability" called directly. Fixtures are not meant to be called directly,
E but are created automatically when test functions request them as parameters. E but are created automatically when test functions request them as parameters.
E See https://docs.pytest.org/en/latest/fixture.html for more information about fixtures, and E See https://docs.pytest.org/en/latest/fixture.html for more information about fixtures, and
E https://docs.pytest.org/en/latest/deprecations.html#calling-fixtures-directly about how to update your code. E https://docs.pytest.org/en/latest/deprecations.html#calling-fixtures-directly about how to update your code.
E stack: 7f87c1ac6e40 7f87c17c3250 7f87c17be260 7f87c17c49f0 7f87c17c0f50 7f87c17c5540 7f87c17d7bf0 7f87c1acb768 7f87c1aaf185 7f87c1aaf2dc 7f87c1a6da1d 7f87c3895459 7f87c3895713 7f87c38c72eb 7f87c3901409 7f87c38b5767 7f87c38b6e7e 7f87c38fe48d 7f87c38b5767 7f87c38b6e7e 7f87c38fe48d 7f87c38b5767 7f87c38b67d2 7f87c38c71cf 7f87c38fdb77 7f87c38b5767 7f87c38b67d2 7f87c38c71cf 7f87c3901409 7f87c38b6632 7f87c38c71cf 7f87c3901409 E stack: 7f87c1ac6e40 7f87c17c3250 7f87c17be260 7f87c17c49f0 7f87c17c0f50 7f87c17c5540 7f87c17d7bf0 7f87c1acb768 7f87c1aaf185 7f87c1aaf2dc 7f87c1a6da1d 7f87c3895459 7f87c3895713 7f87c38c72eb 7f87c3901409 7f87c38b5767 7f87c38b6e7e 7f87c38fe48d 7f87c38b5767 7f87c38b6e7e 7f87c38fe48d 7f87c38b5767 7f87c38b67d2 7f87c38c71cf 7f87c38fdb77 7f87c38b5767 7f87c38b67d2 7f87c38c71cf 7f87c3901409 7f87c38b6632 7f87c38c71cf 7f87c3901409
''' """
client = capability.TestPipeline._new_client(PipelineServer()) client = capability.TestPipeline._new_client(PipelineServer())
foo_client = capability.TestInterface._new_client(Server()) foo_client = capability.TestInterface._new_client(Server())
@ -137,10 +144,11 @@ def test_pipeline_context(capability):
pipelinePromise = outCap.foo(i=10) pipelinePromise = outCap.foo(i=10)
response = pipelinePromise.wait() response = pipelinePromise.wait()
assert response.x == '150' assert response.x == "150"
response = remote.wait() response = remote.wait()
assert response.s == '26_foo' assert response.s == "26_foo"
class BadServer: class BadServer:
def __init__(self, val=1): def __init__(self, val=1):
@ -148,26 +156,31 @@ class BadServer:
def foo_context(self, context): def foo_context(self, context):
context.results.x = str(context.params.i * 5 + self.val) context.results.x = str(context.params.i * 5 + self.val)
context.results.x2 = 5 # raises exception context.results.x2 = 5 # raises exception
def test_exception_client_context(capability): def test_exception_client_context(capability):
client = capability.TestInterface._new_client(BadServer()) client = capability.TestInterface._new_client(BadServer())
remote = client._send('foo', i=5) remote = client._send("foo", i=5)
with pytest.raises(capnp.KjException): with pytest.raises(capnp.KjException):
remote.wait() remote.wait()
class BadPipelineServer: class BadPipelineServer:
def getCap_context(self, context): def getCap_context(self, context):
def _then(response): def _then(response):
context.results.s = response.x + '_foo' context.results.s = response.x + "_foo"
context.results.outBox.cap = capability().TestInterface._new_server(Server(100)) context.results.outBox.cap = capability().TestInterface._new_server(
Server(100)
)
def _error(error): def _error(error):
raise Exception('test was a success') raise Exception("test was a success")
return context.params.inCap.foo(i=context.params.n).then(_then, _error) return context.params.inCap.foo(i=context.params.n).then(_then, _error)
def test_exception_chain_context(capability): def test_exception_chain_context(capability):
client = capability.TestPipeline._new_client(BadPipelineServer()) client = capability.TestPipeline._new_client(BadPipelineServer())
foo_client = capability.TestInterface._new_client(BadServer()) foo_client = capability.TestInterface._new_client(BadServer())
@ -177,7 +190,8 @@ def test_exception_chain_context(capability):
try: try:
remote.wait() remote.wait()
except Exception as e: except Exception as e:
assert 'test was a success' in str(e) assert "test was a success" in str(e)
def test_pipeline_exception_context(capability): def test_pipeline_exception_context(capability):
client = capability.TestPipeline._new_client(BadPipelineServer()) client = capability.TestPipeline._new_client(BadPipelineServer())
@ -194,6 +208,7 @@ def test_pipeline_exception_context(capability):
with pytest.raises(Exception): with pytest.raises(Exception):
remote.wait() remote.wait()
def test_casting_context(capability): def test_casting_context(capability):
client = capability.TestExtends._new_client(Server()) client = capability.TestExtends._new_client(Server())
client2 = client.upcast(capability.TestInterface) client2 = client.upcast(capability.TestInterface)
@ -202,6 +217,7 @@ def test_casting_context(capability):
with pytest.raises(Exception): with pytest.raises(Exception):
client.upcast(capability.TestPipeline) client.upcast(capability.TestPipeline)
class TailCallOrder: class TailCallOrder:
def __init__(self): def __init__(self):
self.count = -1 self.count = -1
@ -210,6 +226,7 @@ class TailCallOrder:
self.count += 1 self.count += 1
context.results.n = self.count context.results.n = self.count
class TailCaller: class TailCaller:
def __init__(self): def __init__(self):
self.count = 0 self.count = 0
@ -217,9 +234,12 @@ class TailCaller:
def foo_context(self, context): def foo_context(self, context):
self.count += 1 self.count += 1
tail = context.params.callee.foo_request(i=context.params.i, t='from TailCaller') tail = context.params.callee.foo_request(
i=context.params.i, t="from TailCaller"
)
return context.tail_call(tail) return context.tail_call(tail)
class TailCallee: class TailCallee:
def __init__(self): def __init__(self):
self.count = 0 self.count = 0
@ -232,15 +252,16 @@ class TailCallee:
results.t = context.params.t results.t = context.params.t
results.c = capability().TestCallOrder._new_server(TailCallOrder()) results.c = capability().TestCallOrder._new_server(TailCallOrder())
@pytest.mark.xfail @pytest.mark.xfail
def test_tail_call(capability): def test_tail_call(capability):
''' """
E capnp.lib.capnp.KjException: capnp/lib/capnp.pyx:75: failed: <class 'Failed'>:Fixture "capability" called directly. Fixtures are not meant to be called directly, E capnp.lib.capnp.KjException: capnp/lib/capnp.pyx:75: failed: <class 'Failed'>:Fixture "capability" called directly. Fixtures are not meant to be called directly,
E but are created automatically when test functions request them as parameters. E but are created automatically when test functions request them as parameters.
E See https://docs.pytest.org/en/latest/fixture.html for more information about fixtures, and E See https://docs.pytest.org/en/latest/fixture.html for more information about fixtures, and
E https://docs.pytest.org/en/latest/deprecations.html#calling-fixtures-directly about how to update your code. E https://docs.pytest.org/en/latest/deprecations.html#calling-fixtures-directly about how to update your code.
E stack: 7f87c17c5540 7f87c17c51b0 7f87c17c5540 7f87c17d7bf0 7f87c1acb768 7f87c1aaf185 7f87c1aaf2dc 7f87c1a6da1d 7f87c3895459 7f87c3895713 7f87c38c72eb 7f87c3901409 7f87c38b5767 7f87c38b6e7e 7f87c38fe48d 7f87c38b5767 7f87c38b6e7e 7f87c38fe48d 7f87c38b5767 7f87c38b67d2 7f87c38c71cf 7f87c38fdb77 7f87c38b5767 7f87c38b67d2 7f87c38c71cf 7f87c3901409 7f87c38b6632 7f87c38c71cf 7f87c3901409 7f87c38b5767 7f87c38b6e7e 7f87c388ace7 E stack: 7f87c17c5540 7f87c17c51b0 7f87c17c5540 7f87c17d7bf0 7f87c1acb768 7f87c1aaf185 7f87c1aaf2dc 7f87c1a6da1d 7f87c3895459 7f87c3895713 7f87c38c72eb 7f87c3901409 7f87c38b5767 7f87c38b6e7e 7f87c38fe48d 7f87c38b5767 7f87c38b6e7e 7f87c38fe48d 7f87c38b5767 7f87c38b67d2 7f87c38c71cf 7f87c38fdb77 7f87c38b5767 7f87c38b67d2 7f87c38c71cf 7f87c3901409 7f87c38b6632 7f87c38c71cf 7f87c3901409 7f87c38b5767 7f87c38b6e7e 7f87c388ace7
''' """
callee_server = TailCallee() callee_server = TailCallee()
caller_server = TailCaller() caller_server = TailCaller()

View file

@ -7,9 +7,11 @@ this_dir = os.path.dirname(__file__)
# flake8: noqa: E501 # flake8: noqa: E501
@pytest.fixture @pytest.fixture
def capability(): def capability():
return capnp.load(os.path.join(this_dir, 'test_capability.capnp')) return capnp.load(os.path.join(this_dir, "test_capability.capnp"))
class Server: class Server:
def __init__(self, val=1): def __init__(self, val=1):
@ -22,27 +24,29 @@ class Server:
return str(i * 5 + extra + self.val) return str(i * 5 + extra + self.val)
def buz(self, i, **kwargs): def buz(self, i, **kwargs):
return i.host + '_test' return i.host + "_test"
class PipelineServer: class PipelineServer:
def getCap(self, n, inCap, _context, **kwargs): def getCap(self, n, inCap, _context, **kwargs):
def _then(response): def _then(response):
_results = _context.results _results = _context.results
_results.s = response.x + '_foo' _results.s = response.x + "_foo"
_results.outBox.cap = capability().TestInterface._new_server(Server(100)) _results.outBox.cap = capability().TestInterface._new_server(Server(100))
return inCap.foo(i=n).then(_then) return inCap.foo(i=n).then(_then)
def test_client(capability): def test_client(capability):
client = capability.TestInterface._new_client(Server()) client = capability.TestInterface._new_client(Server())
req = client._request('foo') req = client._request("foo")
req.i = 5 req.i = 5
remote = req.send() remote = req.send()
response = remote.wait() response = remote.wait()
assert response.x == '26' assert response.x == "26"
req = client.foo_request() req = client.foo_request()
req.i = 5 req.i = 5
@ -50,7 +54,7 @@ def test_client(capability):
remote = req.send() remote = req.send()
response = remote.wait() response = remote.wait()
assert response.x == '26' assert response.x == "26"
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
client.foo2_request() client.foo2_request()
@ -58,51 +62,51 @@ def test_client(capability):
req = client.foo_request() req = client.foo_request()
with pytest.raises(Exception): with pytest.raises(Exception):
req.i = 'foo' req.i = "foo"
req = client.foo_request() req = client.foo_request()
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
req.baz = 1 req.baz = 1
def test_simple_client(capability): def test_simple_client(capability):
client = capability.TestInterface._new_client(Server()) client = capability.TestInterface._new_client(Server())
remote = client._send('foo', i=5) remote = client._send("foo", i=5)
response = remote.wait() response = remote.wait()
assert response.x == '26' assert response.x == "26"
remote = client.foo(i=5) remote = client.foo(i=5)
response = remote.wait() response = remote.wait()
assert response.x == '26' assert response.x == "26"
remote = client.foo(i=5, j=True) remote = client.foo(i=5, j=True)
response = remote.wait() response = remote.wait()
assert response.x == '27' assert response.x == "27"
remote = client.foo(5) remote = client.foo(5)
response = remote.wait() response = remote.wait()
assert response.x == '26' assert response.x == "26"
remote = client.foo(5, True) remote = client.foo(5, True)
response = remote.wait() response = remote.wait()
assert response.x == '27' assert response.x == "27"
remote = client.foo(5, j=True) remote = client.foo(5, j=True)
response = remote.wait() response = remote.wait()
assert response.x == '27' assert response.x == "27"
remote = client.buz(capability.TestSturdyRefHostId.new_message(host='localhost')) remote = client.buz(capability.TestSturdyRefHostId.new_message(host="localhost"))
response = remote.wait() response = remote.wait()
assert response.x == 'localhost_test' assert response.x == "localhost_test"
with pytest.raises(Exception): with pytest.raises(Exception):
remote = client.foo(5, 10) remote = client.foo(5, 10)
@ -111,7 +115,7 @@ def test_simple_client(capability):
remote = client.foo(5, True, 100) remote = client.foo(5, True, 100)
with pytest.raises(Exception): with pytest.raises(Exception):
remote = client.foo(i='foo') remote = client.foo(i="foo")
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
remote = client.foo2(i=5) remote = client.foo2(i=5)
@ -119,15 +123,16 @@ def test_simple_client(capability):
with pytest.raises(Exception): with pytest.raises(Exception):
remote = client.foo(baz=5) remote = client.foo(baz=5)
@pytest.mark.xfail @pytest.mark.xfail
def test_pipeline(capability): def test_pipeline(capability):
''' """
E capnp.lib.capnp.KjException: capnp/lib/capnp.pyx:61: failed: <class 'Failed'>:Fixture "capability" called directly. Fixtures are not meant to be called directly, E capnp.lib.capnp.KjException: capnp/lib/capnp.pyx:61: failed: <class 'Failed'>:Fixture "capability" called directly. Fixtures are not meant to be called directly,
E but are created automatically when test functions request them as parameters. E but are created automatically when test functions request them as parameters.
E See https://docs.pytest.org/en/latest/fixture.html for more information about fixtures, and E See https://docs.pytest.org/en/latest/fixture.html for more information about fixtures, and
E https://docs.pytest.org/en/latest/deprecations.html#calling-fixtures-directly about how to update your code. E https://docs.pytest.org/en/latest/deprecations.html#calling-fixtures-directly about how to update your code.
E stack: 7f680f7fce40 7f680f4f9250 7f680f4f4260 7f680f4fa9f0 7f680f4f6f50 7f680f4fb540 7f680f50dbf0 7f680f801768 7f680f7e5185 7f680f7e52dc 7f680f7a3a1d 7f68115cb459 7f68115cb713 7f68115fd2eb 7f6811637409 7f68115eb767 7f68115ece7e 7f681163448d 7f68115eb767 7f68115ece7e 7f681163448d 7f68115eb767 7f68115ec7d2 7f68115fd1cf 7f6811633b77 7f68115eb767 7f68115ec7d2 7f68115fd1cf 7f6811637409 7f68115ec632 7f68115fd1cf 7f6811637409 E stack: 7f680f7fce40 7f680f4f9250 7f680f4f4260 7f680f4fa9f0 7f680f4f6f50 7f680f4fb540 7f680f50dbf0 7f680f801768 7f680f7e5185 7f680f7e52dc 7f680f7a3a1d 7f68115cb459 7f68115cb713 7f68115fd2eb 7f6811637409 7f68115eb767 7f68115ece7e 7f681163448d 7f68115eb767 7f68115ece7e 7f681163448d 7f68115eb767 7f68115ec7d2 7f68115fd1cf 7f6811633b77 7f68115eb767 7f68115ec7d2 7f68115fd1cf 7f6811637409 7f68115ec632 7f68115fd1cf 7f6811637409
''' """
client = capability.TestPipeline._new_client(PipelineServer()) client = capability.TestPipeline._new_client(PipelineServer())
foo_client = capability.TestInterface._new_client(Server()) foo_client = capability.TestInterface._new_client(Server())
@ -137,10 +142,11 @@ def test_pipeline(capability):
pipelinePromise = outCap.foo(i=10) pipelinePromise = outCap.foo(i=10)
response = pipelinePromise.wait() response = pipelinePromise.wait()
assert response.x == '150' assert response.x == "150"
response = remote.wait() response = remote.wait()
assert response.s == '26_foo' assert response.s == "26_foo"
class BadServer: class BadServer:
def __init__(self, val=1): def __init__(self, val=1):
@ -150,27 +156,30 @@ class BadServer:
extra = 0 extra = 0
if j: if j:
extra = 1 extra = 1
return str(i * 5 + extra + self.val), 10 # returning too many args return str(i * 5 + extra + self.val), 10 # returning too many args
def test_exception_client(capability): def test_exception_client(capability):
client = capability.TestInterface._new_client(BadServer()) client = capability.TestInterface._new_client(BadServer())
remote = client._send('foo', i=5) remote = client._send("foo", i=5)
with pytest.raises(capnp.KjException): with pytest.raises(capnp.KjException):
remote.wait() remote.wait()
class BadPipelineServer: class BadPipelineServer:
def getCap(self, n, inCap, _context, **kwargs): def getCap(self, n, inCap, _context, **kwargs):
def _then(response): def _then(response):
_results = _context.results _results = _context.results
_results.s = response.x + '_foo' _results.s = response.x + "_foo"
_results.outBox.cap = capability().TestInterface._new_server(Server(100)) _results.outBox.cap = capability().TestInterface._new_server(Server(100))
def _error(error): def _error(error):
raise Exception('test was a success') raise Exception("test was a success")
return inCap.foo(i=n).then(_then, _error) return inCap.foo(i=n).then(_then, _error)
def test_exception_chain(capability): def test_exception_chain(capability):
client = capability.TestPipeline._new_client(BadPipelineServer()) client = capability.TestPipeline._new_client(BadPipelineServer())
foo_client = capability.TestInterface._new_client(BadServer()) foo_client = capability.TestInterface._new_client(BadServer())
@ -180,7 +189,8 @@ def test_exception_chain(capability):
try: try:
remote.wait() remote.wait()
except Exception as e: except Exception as e:
assert 'test was a success' in str(e) assert "test was a success" in str(e)
def test_pipeline_exception(capability): def test_pipeline_exception(capability):
client = capability.TestPipeline._new_client(BadPipelineServer()) client = capability.TestPipeline._new_client(BadPipelineServer())
@ -197,6 +207,7 @@ def test_pipeline_exception(capability):
with pytest.raises(Exception): with pytest.raises(Exception):
remote.wait() remote.wait()
def test_casting(capability): def test_casting(capability):
client = capability.TestExtends._new_client(Server()) client = capability.TestExtends._new_client(Server())
client2 = client.upcast(capability.TestInterface) client2 = client.upcast(capability.TestInterface)
@ -205,6 +216,7 @@ def test_casting(capability):
with pytest.raises(Exception): with pytest.raises(Exception):
client.upcast(capability.TestPipeline) client.upcast(capability.TestPipeline)
class TailCallOrder: class TailCallOrder:
def __init__(self): def __init__(self):
self.count = -1 self.count = -1
@ -213,6 +225,7 @@ class TailCallOrder:
self.count += 1 self.count += 1
return self.count return self.count
class TailCaller: class TailCaller:
def __init__(self): def __init__(self):
self.count = 0 self.count = 0
@ -220,9 +233,10 @@ class TailCaller:
def foo(self, i, callee, _context, **kwargs): def foo(self, i, callee, _context, **kwargs):
self.count += 1 self.count += 1
tail = callee.foo_request(i=i, t='from TailCaller') tail = callee.foo_request(i=i, t="from TailCaller")
return _context.tail_call(tail) return _context.tail_call(tail)
class TailCallee: class TailCallee:
def __init__(self): def __init__(self):
self.count = 0 self.count = 0
@ -235,15 +249,16 @@ class TailCallee:
results.t = t results.t = t
results.c = capability().TestCallOrder._new_server(TailCallOrder()) results.c = capability().TestCallOrder._new_server(TailCallOrder())
@pytest.mark.xfail @pytest.mark.xfail
def test_tail_call(capability): def test_tail_call(capability):
''' """
E capnp.lib.capnp.KjException: capnp/lib/capnp.pyx:104: failed: <class 'Failed'>:Fixture "capability" called directly. Fixtures are not meant to be called directly, E capnp.lib.capnp.KjException: capnp/lib/capnp.pyx:104: failed: <class 'Failed'>:Fixture "capability" called directly. Fixtures are not meant to be called directly,
E but are created automatically when test functions request them as parameters. E but are created automatically when test functions request them as parameters.
E See https://docs.pytest.org/en/latest/fixture.html for more information about fixtures, and E See https://docs.pytest.org/en/latest/fixture.html for more information about fixtures, and
E https://docs.pytest.org/en/latest/deprecations.html#calling-fixtures-directly about how to update your code. E https://docs.pytest.org/en/latest/deprecations.html#calling-fixtures-directly about how to update your code.
E stack: 7f680f4fb540 7f680f4fb1b0 7f680f4fb540 7f680f50dbf0 7f680f801768 7f680f7e5185 7f680f7e52dc 7f680f7a3a1d 7f68115cb459 7f68115cb713 7f68115fd2eb 7f6811637409 7f68115eb767 7f68115ece7e 7f681163448d 7f68115eb767 7f68115ece7e 7f681163448d 7f68115eb767 7f68115ec7d2 7f68115fd1cf 7f6811633b77 7f68115eb767 7f68115ec7d2 7f68115fd1cf 7f6811637409 7f68115ec632 7f68115fd1cf 7f6811637409 7f68115eb767 7f68115ece7e 7f68115c0ce7 E stack: 7f680f4fb540 7f680f4fb1b0 7f680f4fb540 7f680f50dbf0 7f680f801768 7f680f7e5185 7f680f7e52dc 7f680f7a3a1d 7f68115cb459 7f68115cb713 7f68115fd2eb 7f6811637409 7f68115eb767 7f68115ece7e 7f681163448d 7f68115eb767 7f68115ece7e 7f681163448d 7f68115eb767 7f68115ec7d2 7f68115fd1cf 7f6811633b77 7f68115eb767 7f68115ec7d2 7f68115fd1cf 7f6811637409 7f68115ec632 7f68115fd1cf 7f6811637409 7f68115eb767 7f68115ece7e 7f68115c0ce7
''' """
callee_server = TailCallee() callee_server = TailCallee()
caller_server = TailCaller() caller_server = TailCaller()

View file

@ -5,8 +5,8 @@ import subprocess
import sys import sys
import time import time
examples_dir = os.path.join(os.path.dirname(__file__), '..', 'examples') examples_dir = os.path.join(os.path.dirname(__file__), "..", "examples")
hostname = 'localhost' hostname = "localhost"
processes = [] processes = []
@ -19,24 +19,28 @@ def cleanup():
p.kill() p.kill()
def run_subprocesses(address, server, client, wildcard_server=False, ipv4_force=True): # noqa def run_subprocesses(
address, server, client, wildcard_server=False, ipv4_force=True
): # noqa
server_attempt = 0 server_attempt = 0
server_attempts = 2 server_attempts = 2
done = False done = False
addr, port = address.split(':') addr, port = address.split(":")
c_address = address c_address = address
s_address = address s_address = address
while not done: while not done:
assert server_attempt < server_attempts, "Failed {} server attempts".format(server_attempts) assert server_attempt < server_attempts, "Failed {} server attempts".format(
server_attempts
)
server_attempt += 1 server_attempt += 1
# Force ipv4 for tests (known issues on GitHub Actions with IPv6 for some targets) # Force ipv4 for tests (known issues on GitHub Actions with IPv6 for some targets)
if 'unix' not in addr and ipv4_force: if "unix" not in addr and ipv4_force:
addr = socket.gethostbyname(addr) addr = socket.gethostbyname(addr)
c_address = '{}:{}'.format(addr, port) c_address = "{}:{}".format(addr, port)
s_address = c_address s_address = c_address
if wildcard_server: if wildcard_server:
s_address = '*:{}'.format(port) # Use wildcard address for server s_address = "*:{}".format(port) # Use wildcard address for server
print("Forcing ipv4 -> {} => {} {}".format(address, c_address, s_address)) print("Forcing ipv4 -> {} => {} {}".format(address, c_address, s_address))
# Start server # Start server
@ -48,7 +52,7 @@ def run_subprocesses(address, server, client, wildcard_server=False, ipv4_force=
# Loop until we have a socket connection to the server (with timeout) # Loop until we have a socket connection to the server (with timeout)
while True: while True:
try: try:
if 'unix' in address: if "unix" in address:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
result = sock.connect_ex(port) result = sock.connect_ex(port)
if result == 0: if result == 0:
@ -114,22 +118,26 @@ def run_subprocesses(address, server, client, wildcard_server=False, ipv4_force=
def test_async_calculator_example(cleanup): def test_async_calculator_example(cleanup):
address = '{}:36432'.format(hostname) address = "{}:36432".format(hostname)
server = 'async_calculator_server.py' server = "async_calculator_server.py"
client = 'async_calculator_client.py' client = "async_calculator_client.py"
run_subprocesses(address, server, client) run_subprocesses(address, server, client)
@pytest.mark.xfail(reason="Some versions of python don't like to share ports, don't worry if this fails") @pytest.mark.xfail(
reason="Some versions of python don't like to share ports, don't worry if this fails"
)
def test_thread_example(cleanup): def test_thread_example(cleanup):
address = '{}:36433'.format(hostname) address = "{}:36433".format(hostname)
server = 'thread_server.py' server = "thread_server.py"
client = 'thread_client.py' client = "thread_client.py"
run_subprocesses(address, server, client, wildcard_server=True) run_subprocesses(address, server, client, wildcard_server=True)
def test_addressbook_example(cleanup): def test_addressbook_example(cleanup):
proc = subprocess.Popen([sys.executable, os.path.join(examples_dir, 'addressbook.py')]) proc = subprocess.Popen(
[sys.executable, os.path.join(examples_dir, "addressbook.py")]
)
ret = proc.wait() ret = proc.wait()
assert ret == 0 assert ret == 0
@ -139,12 +147,12 @@ def test_addressbook_example(cleanup):
reason=""" reason="""
Asyncio bug with libcapnp timer, likely due to asyncio starving some event loop. Asyncio bug with libcapnp timer, likely due to asyncio starving some event loop.
See https://github.com/capnproto/pycapnp/issues/196 See https://github.com/capnproto/pycapnp/issues/196
""" """,
) )
def test_async_example(cleanup): def test_async_example(cleanup):
address = '{}:36434'.format(hostname) address = "{}:36434".format(hostname)
server = 'async_server.py' server = "async_server.py"
client = 'async_client.py' client = "async_client.py"
run_subprocesses(address, server, client) run_subprocesses(address, server, client)
@ -153,12 +161,12 @@ def test_async_example(cleanup):
reason=""" reason="""
Asyncio bug with libcapnp timer, likely due to asyncio starving some event loop. Asyncio bug with libcapnp timer, likely due to asyncio starving some event loop.
See https://github.com/capnproto/pycapnp/issues/196 See https://github.com/capnproto/pycapnp/issues/196
""" """,
) )
def test_ssl_async_example(cleanup): def test_ssl_async_example(cleanup):
address = '{}:36435'.format(hostname) address = "{}:36435".format(hostname)
server = 'async_ssl_server.py' server = "async_ssl_server.py"
client = 'async_ssl_client.py' client = "async_ssl_client.py"
run_subprocesses(address, server, client, ipv4_force=False) run_subprocesses(address, server, client, ipv4_force=False)
@ -167,17 +175,17 @@ def test_ssl_async_example(cleanup):
reason=""" reason="""
Asyncio bug with libcapnp timer, likely due to asyncio starving some event loop. Asyncio bug with libcapnp timer, likely due to asyncio starving some event loop.
See https://github.com/capnproto/pycapnp/issues/196 See https://github.com/capnproto/pycapnp/issues/196
""" """,
) )
def test_ssl_reconnecting_async_example(cleanup): def test_ssl_reconnecting_async_example(cleanup):
address = '{}:36436'.format(hostname) address = "{}:36436".format(hostname)
server = 'async_ssl_server.py' server = "async_ssl_server.py"
client = 'async_reconnecting_ssl_client.py' client = "async_reconnecting_ssl_client.py"
run_subprocesses(address, server, client, ipv4_force=False) run_subprocesses(address, server, client, ipv4_force=False)
def test_async_ssl_calculator_example(cleanup): def test_async_ssl_calculator_example(cleanup):
address = '{}:36437'.format(hostname) address = "{}:36437".format(hostname)
server = 'async_ssl_calculator_server.py' server = "async_ssl_calculator_server.py"
client = 'async_ssl_calculator_client.py' client = "async_ssl_calculator_client.py"
run_subprocesses(address, server, client, ipv4_force=False) run_subprocesses(address, server, client, ipv4_force=False)

View file

@ -11,7 +11,7 @@ this_dir = os.path.dirname(__file__)
@pytest.fixture @pytest.fixture
def test_capnp(): def test_capnp():
return capnp.load(os.path.join(this_dir, 'test_large_read.capnp')) return capnp.load(os.path.join(this_dir, "test_large_read.capnp"))
def test_large_read(test_capnp): def test_large_read(test_capnp):
@ -19,8 +19,8 @@ def test_large_read(test_capnp):
array = test_capnp.MultiArray.new_message() array = test_capnp.MultiArray.new_message()
row = array.init('rows', 1)[0] row = array.init("rows", 1)[0]
values = row.init('values', 10000) values = row.init("values", 10000)
for i in range(len(values)): for i in range(len(values)):
values[i] = i values[i] = i
@ -66,12 +66,15 @@ def test_large_read_multiple_bytes(test_capnp):
pass pass
with pytest.raises(capnp.KjException): with pytest.raises(capnp.KjException):
data = get_two_adjacent_messages(test_capnp) + b' ' data = get_two_adjacent_messages(test_capnp) + b" "
for m in test_capnp.Msg.read_multiple_bytes(data): for m in test_capnp.Msg.read_multiple_bytes(data):
pass pass
@pytest.mark.skipif(platform.python_implementation() == 'PyPy', reason="PyPy memoryview support is limited") @pytest.mark.skipif(
platform.python_implementation() == "PyPy",
reason="PyPy memoryview support is limited",
)
def test_large_read_mutltiple_bytes_memoryview(test_capnp): def test_large_read_mutltiple_bytes_memoryview(test_capnp):
data = get_two_adjacent_messages(test_capnp) data = get_two_adjacent_messages(test_capnp)
for m in test_capnp.Msg.read_multiple_bytes(memoryview(data)): for m in test_capnp.Msg.read_multiple_bytes(memoryview(data)):
@ -83,6 +86,6 @@ def test_large_read_mutltiple_bytes_memoryview(test_capnp):
pass pass
with pytest.raises(capnp.KjException): with pytest.raises(capnp.KjException):
data = get_two_adjacent_messages(test_capnp) + b' ' data = get_two_adjacent_messages(test_capnp) + b" "
for m in test_capnp.Msg.read_multiple_bytes(memoryview(data)): for m in test_capnp.Msg.read_multiple_bytes(memoryview(data)):
pass pass

View file

@ -8,21 +8,21 @@ this_dir = os.path.dirname(__file__)
@pytest.fixture @pytest.fixture
def addressbook(): def addressbook():
return capnp.load(os.path.join(this_dir, 'addressbook.capnp')) return capnp.load(os.path.join(this_dir, "addressbook.capnp"))
@pytest.fixture @pytest.fixture
def foo(): def foo():
return capnp.load(os.path.join(this_dir, 'foo.capnp')) return capnp.load(os.path.join(this_dir, "foo.capnp"))
@pytest.fixture @pytest.fixture
def bar(): def bar():
return capnp.load(os.path.join(this_dir, 'bar.capnp')) return capnp.load(os.path.join(this_dir, "bar.capnp"))
def test_basic_load(): def test_basic_load():
capnp.load(os.path.join(this_dir, 'addressbook.capnp')) capnp.load(os.path.join(this_dir, "addressbook.capnp"))
def test_constants(addressbook): def test_constants(addressbook):
@ -40,25 +40,25 @@ def test_import(foo, bar):
m2 = capnp._MallocMessageBuilder() m2 = capnp._MallocMessageBuilder()
bar = m2.init_root(bar.Bar) bar = m2.init_root(bar.Bar)
foo.name = 'foo' foo.name = "foo"
bar.foo = foo bar.foo = foo
assert bar.foo.name == 'foo' assert bar.foo.name == "foo"
def test_failed_import(): def test_failed_import():
s = capnp.SchemaParser() s = capnp.SchemaParser()
s2 = capnp.SchemaParser() s2 = capnp.SchemaParser()
foo = s.load(os.path.join(this_dir, 'foo.capnp')) foo = s.load(os.path.join(this_dir, "foo.capnp"))
bar = s2.load(os.path.join(this_dir, 'bar.capnp')) bar = s2.load(os.path.join(this_dir, "bar.capnp"))
m = capnp._MallocMessageBuilder() m = capnp._MallocMessageBuilder()
foo = m.init_root(foo.Foo) foo = m.init_root(foo.Foo)
m2 = capnp._MallocMessageBuilder() m2 = capnp._MallocMessageBuilder()
bar = m2.init_root(bar.Bar) bar = m2.init_root(bar.Bar)
foo.name = 'foo' foo.name = "foo"
with pytest.raises(Exception): with pytest.raises(Exception):
bar.foo = foo bar.foo = foo
@ -86,6 +86,7 @@ def test_add_import_hook():
capnp.cleanup_global_schema_parser() capnp.cleanup_global_schema_parser()
import addressbook_capnp import addressbook_capnp
addressbook_capnp.AddressBook.new_message() addressbook_capnp.AddressBook.new_message()
@ -98,6 +99,7 @@ def test_multiple_add_import_hook():
capnp.cleanup_global_schema_parser() capnp.cleanup_global_schema_parser()
import addressbook_capnp import addressbook_capnp
addressbook_capnp.AddressBook.new_message() addressbook_capnp.AddressBook.new_message()
@ -105,9 +107,9 @@ def test_remove_import_hook():
capnp.add_import_hook([this_dir]) capnp.add_import_hook([this_dir])
capnp.remove_import_hook() capnp.remove_import_hook()
if 'addressbook_capnp' in sys.modules: if "addressbook_capnp" in sys.modules:
# hack to deal with it being imported already # hack to deal with it being imported already
del sys.modules['addressbook_capnp'] del sys.modules["addressbook_capnp"]
with pytest.raises(ImportError): with pytest.raises(ImportError):
import addressbook_capnp # noqa: F401 import addressbook_capnp # noqa: F401

View file

@ -7,22 +7,22 @@ this_dir = os.path.dirname(__file__)
@pytest.fixture @pytest.fixture
def addressbook(): def addressbook():
return capnp.load(os.path.join(this_dir, 'addressbook.capnp')) return capnp.load(os.path.join(this_dir, "addressbook.capnp"))
def test_object_basic(addressbook): def test_object_basic(addressbook):
obj = capnp._MallocMessageBuilder().get_root_as_any() obj = capnp._MallocMessageBuilder().get_root_as_any()
person = obj.as_struct(addressbook.Person) person = obj.as_struct(addressbook.Person)
person.name = 'test' person.name = "test"
person.id = 1000 person.id = 1000
same_person = obj.as_struct(addressbook.Person) same_person = obj.as_struct(addressbook.Person)
assert same_person.name == 'test' assert same_person.name == "test"
assert same_person.id == 1000 assert same_person.id == 1000
obj_r = obj.as_reader() obj_r = obj.as_reader()
same_person = obj_r.as_struct(addressbook.Person) same_person = obj_r.as_struct(addressbook.Person)
assert same_person.name == 'test' assert same_person.name == "test"
assert same_person.id == 1000 assert same_person.id == 1000
@ -31,21 +31,21 @@ def test_object_list(addressbook):
listSchema = capnp._ListSchema(addressbook.Person) listSchema = capnp._ListSchema(addressbook.Person)
people = obj.init_as_list(listSchema, 2) people = obj.init_as_list(listSchema, 2)
person = people[0] person = people[0]
person.name = 'test' person.name = "test"
person.id = 1000 person.id = 1000
person = people[1] person = people[1]
person.name = 'test2' person.name = "test2"
person.id = 1001 person.id = 1001
same_person = obj.as_list(listSchema) same_person = obj.as_list(listSchema)
assert same_person[0].name == 'test' assert same_person[0].name == "test"
assert same_person[0].id == 1000 assert same_person[0].id == 1000
assert same_person[1].name == 'test2' assert same_person[1].name == "test2"
assert same_person[1].id == 1001 assert same_person[1].id == 1001
obj_r = obj.as_reader() obj_r = obj.as_reader()
same_person = obj_r.as_list(listSchema) same_person = obj_r.as_list(listSchema)
assert same_person[0].name == 'test' assert same_person[0].name == "test"
assert same_person[0].id == 1000 assert same_person[0].id == 1000
assert same_person[1].name == 'test2' assert same_person[1].name == "test2"
assert same_person[1].id == 1001 assert same_person[1].id == 1001

View file

@ -16,33 +16,33 @@ else:
@pytest.fixture @pytest.fixture
def addressbook(): def addressbook():
return capnp.load(os.path.join(this_dir, 'addressbook.capnp')) return capnp.load(os.path.join(this_dir, "addressbook.capnp"))
def test_addressbook_message_classes(addressbook): def test_addressbook_message_classes(addressbook):
def writeAddressBook(fd): def writeAddressBook(fd):
message = capnp._MallocMessageBuilder() message = capnp._MallocMessageBuilder()
addressBook = message.init_root(addressbook.AddressBook) addressBook = message.init_root(addressbook.AddressBook)
people = addressBook.init('people', 2) people = addressBook.init("people", 2)
alice = people[0] alice = people[0]
alice.id = 123 alice.id = 123
alice.name = 'Alice' alice.name = "Alice"
alice.email = 'alice@example.com' alice.email = "alice@example.com"
alicePhones = alice.init('phones', 1) alicePhones = alice.init("phones", 1)
alicePhones[0].number = "555-1212" alicePhones[0].number = "555-1212"
alicePhones[0].type = 'mobile' alicePhones[0].type = "mobile"
alice.employment.school = "MIT" alice.employment.school = "MIT"
bob = people[1] bob = people[1]
bob.id = 456 bob.id = 456
bob.name = 'Bob' bob.name = "Bob"
bob.email = 'bob@example.com' bob.email = "bob@example.com"
bobPhones = bob.init('phones', 2) bobPhones = bob.init("phones", 2)
bobPhones[0].number = "555-4567" bobPhones[0].number = "555-4567"
bobPhones[0].type = 'home' bobPhones[0].type = "home"
bobPhones[1].number = "555-7654" bobPhones[1].number = "555-7654"
bobPhones[1].type = 'work' bobPhones[1].type = "work"
bob.employment.unemployed = None bob.employment.unemployed = None
capnp._write_packed_message_to_fd(fd, message) capnp._write_packed_message_to_fd(fd, message)
@ -55,54 +55,54 @@ def test_addressbook_message_classes(addressbook):
alice = people[0] alice = people[0]
assert alice.id == 123 assert alice.id == 123
assert alice.name == 'Alice' assert alice.name == "Alice"
assert alice.email == 'alice@example.com' assert alice.email == "alice@example.com"
alicePhones = alice.phones alicePhones = alice.phones
assert alicePhones[0].number == "555-1212" assert alicePhones[0].number == "555-1212"
assert alicePhones[0].type == 'mobile' assert alicePhones[0].type == "mobile"
assert alice.employment.school == "MIT" assert alice.employment.school == "MIT"
bob = people[1] bob = people[1]
assert bob.id == 456 assert bob.id == 456
assert bob.name == 'Bob' assert bob.name == "Bob"
assert bob.email == 'bob@example.com' assert bob.email == "bob@example.com"
bobPhones = bob.phones bobPhones = bob.phones
assert bobPhones[0].number == "555-4567" assert bobPhones[0].number == "555-4567"
assert bobPhones[0].type == 'home' assert bobPhones[0].type == "home"
assert bobPhones[1].number == "555-7654" assert bobPhones[1].number == "555-7654"
assert bobPhones[1].type == 'work' assert bobPhones[1].type == "work"
assert bob.employment.unemployed is None assert bob.employment.unemployed is None
f = open('example', 'w') f = open("example", "w")
writeAddressBook(f.fileno()) writeAddressBook(f.fileno())
f = open('example', 'r') f = open("example", "r")
printAddressBook(f.fileno()) printAddressBook(f.fileno())
def test_addressbook(addressbook): def test_addressbook(addressbook):
def writeAddressBook(file): def writeAddressBook(file):
addresses = addressbook.AddressBook.new_message() addresses = addressbook.AddressBook.new_message()
people = addresses.init('people', 2) people = addresses.init("people", 2)
alice = people[0] alice = people[0]
alice.id = 123 alice.id = 123
alice.name = 'Alice' alice.name = "Alice"
alice.email = 'alice@example.com' alice.email = "alice@example.com"
alicePhones = alice.init('phones', 1) alicePhones = alice.init("phones", 1)
alicePhones[0].number = "555-1212" alicePhones[0].number = "555-1212"
alicePhones[0].type = 'mobile' alicePhones[0].type = "mobile"
alice.employment.school = "MIT" alice.employment.school = "MIT"
bob = people[1] bob = people[1]
bob.id = 456 bob.id = 456
bob.name = 'Bob' bob.name = "Bob"
bob.email = 'bob@example.com' bob.email = "bob@example.com"
bobPhones = bob.init('phones', 2) bobPhones = bob.init("phones", 2)
bobPhones[0].number = "555-4567" bobPhones[0].number = "555-4567"
bobPhones[0].type = 'home' bobPhones[0].type = "home"
bobPhones[1].number = "555-7654" bobPhones[1].number = "555-7654"
bobPhones[1].type = 'work' bobPhones[1].type = "work"
bob.employment.unemployed = None bob.employment.unemployed = None
addresses.write(file) addresses.write(file)
@ -114,54 +114,54 @@ def test_addressbook(addressbook):
alice = people[0] alice = people[0]
assert alice.id == 123 assert alice.id == 123
assert alice.name == 'Alice' assert alice.name == "Alice"
assert alice.email == 'alice@example.com' assert alice.email == "alice@example.com"
alicePhones = alice.phones alicePhones = alice.phones
assert alicePhones[0].number == "555-1212" assert alicePhones[0].number == "555-1212"
assert alicePhones[0].type == 'mobile' assert alicePhones[0].type == "mobile"
assert alice.employment.school == "MIT" assert alice.employment.school == "MIT"
bob = people[1] bob = people[1]
assert bob.id == 456 assert bob.id == 456
assert bob.name == 'Bob' assert bob.name == "Bob"
assert bob.email == 'bob@example.com' assert bob.email == "bob@example.com"
bobPhones = bob.phones bobPhones = bob.phones
assert bobPhones[0].number == "555-4567" assert bobPhones[0].number == "555-4567"
assert bobPhones[0].type == 'home' assert bobPhones[0].type == "home"
assert bobPhones[1].number == "555-7654" assert bobPhones[1].number == "555-7654"
assert bobPhones[1].type == 'work' assert bobPhones[1].type == "work"
assert bob.employment.unemployed is None assert bob.employment.unemployed is None
f = open('example', 'w') f = open("example", "w")
writeAddressBook(f) writeAddressBook(f)
f = open('example', 'r') f = open("example", "r")
printAddressBook(f) printAddressBook(f)
def test_addressbook_resizable(addressbook): def test_addressbook_resizable(addressbook):
def writeAddressBook(file): def writeAddressBook(file):
addresses = addressbook.AddressBook.new_message() addresses = addressbook.AddressBook.new_message()
people = addresses.init_resizable_list('people') people = addresses.init_resizable_list("people")
alice = people.add() alice = people.add()
alice.id = 123 alice.id = 123
alice.name = 'Alice' alice.name = "Alice"
alice.email = 'alice@example.com' alice.email = "alice@example.com"
alicePhones = alice.init('phones', 1) alicePhones = alice.init("phones", 1)
alicePhones[0].number = "555-1212" alicePhones[0].number = "555-1212"
alicePhones[0].type = 'mobile' alicePhones[0].type = "mobile"
alice.employment.school = "MIT" alice.employment.school = "MIT"
bob = people.add() bob = people.add()
bob.id = 456 bob.id = 456
bob.name = 'Bob' bob.name = "Bob"
bob.email = 'bob@example.com' bob.email = "bob@example.com"
bobPhones = bob.init('phones', 2) bobPhones = bob.init("phones", 2)
bobPhones[0].number = "555-4567" bobPhones[0].number = "555-4567"
bobPhones[0].type = 'home' bobPhones[0].type = "home"
bobPhones[1].number = "555-7654" bobPhones[1].number = "555-7654"
bobPhones[1].type = 'work' bobPhones[1].type = "work"
bob.employment.unemployed = None bob.employment.unemployed = None
people.finish() people.finish()
@ -175,28 +175,28 @@ def test_addressbook_resizable(addressbook):
alice = people[0] alice = people[0]
assert alice.id == 123 assert alice.id == 123
assert alice.name == 'Alice' assert alice.name == "Alice"
assert alice.email == 'alice@example.com' assert alice.email == "alice@example.com"
alicePhones = alice.phones alicePhones = alice.phones
assert alicePhones[0].number == "555-1212" assert alicePhones[0].number == "555-1212"
assert alicePhones[0].type == 'mobile' assert alicePhones[0].type == "mobile"
assert alice.employment.school == "MIT" assert alice.employment.school == "MIT"
bob = people[1] bob = people[1]
assert bob.id == 456 assert bob.id == 456
assert bob.name == 'Bob' assert bob.name == "Bob"
assert bob.email == 'bob@example.com' assert bob.email == "bob@example.com"
bobPhones = bob.phones bobPhones = bob.phones
assert bobPhones[0].number == "555-4567" assert bobPhones[0].number == "555-4567"
assert bobPhones[0].type == 'home' assert bobPhones[0].type == "home"
assert bobPhones[1].number == "555-7654" assert bobPhones[1].number == "555-7654"
assert bobPhones[1].type == 'work' assert bobPhones[1].type == "work"
assert bob.employment.unemployed is None assert bob.employment.unemployed is None
f = open('example', 'w') f = open("example", "w")
writeAddressBook(f) writeAddressBook(f)
f = open('example', 'r') f = open("example", "r")
printAddressBook(f) printAddressBook(f)
@ -206,29 +206,33 @@ def test_addressbook_explicit_fields(addressbook):
address_fields = addressbook.AddressBook.schema.fields address_fields = addressbook.AddressBook.schema.fields
person_fields = addressbook.Person.schema.fields person_fields = addressbook.Person.schema.fields
phone_fields = addressbook.Person.PhoneNumber.schema.fields phone_fields = addressbook.Person.PhoneNumber.schema.fields
people = addresses._init_by_field(address_fields['people'], 2) people = addresses._init_by_field(address_fields["people"], 2)
alice = people[0] alice = people[0]
alice._set_by_field(person_fields['id'], 123) alice._set_by_field(person_fields["id"], 123)
alice._set_by_field(person_fields['name'], 'Alice') alice._set_by_field(person_fields["name"], "Alice")
alice._set_by_field(person_fields['email'], 'alice@example.com') alice._set_by_field(person_fields["email"], "alice@example.com")
alicePhones = alice._init_by_field(person_fields['phones'], 1) alicePhones = alice._init_by_field(person_fields["phones"], 1)
alicePhones[0]._set_by_field(phone_fields['number'], "555-1212") alicePhones[0]._set_by_field(phone_fields["number"], "555-1212")
alicePhones[0]._set_by_field(phone_fields['type'], 'mobile') alicePhones[0]._set_by_field(phone_fields["type"], "mobile")
employment = alice._get_by_field(person_fields['employment']) employment = alice._get_by_field(person_fields["employment"])
employment._set_by_field(addressbook.Person.Employment.schema.fields['school'], "MIT") employment._set_by_field(
addressbook.Person.Employment.schema.fields["school"], "MIT"
)
bob = people[1] bob = people[1]
bob._set_by_field(person_fields['id'], 456) bob._set_by_field(person_fields["id"], 456)
bob._set_by_field(person_fields['name'], 'Bob') bob._set_by_field(person_fields["name"], "Bob")
bob._set_by_field(person_fields['email'], 'bob@example.com') bob._set_by_field(person_fields["email"], "bob@example.com")
bobPhones = bob._init_by_field(person_fields['phones'], 2) bobPhones = bob._init_by_field(person_fields["phones"], 2)
bobPhones[0]._set_by_field(phone_fields['number'], "555-4567") bobPhones[0]._set_by_field(phone_fields["number"], "555-4567")
bobPhones[0]._set_by_field(phone_fields['type'], 'home') bobPhones[0]._set_by_field(phone_fields["type"], "home")
bobPhones[1]._set_by_field(phone_fields['number'], "555-7654") bobPhones[1]._set_by_field(phone_fields["number"], "555-7654")
bobPhones[1]._set_by_field(phone_fields['type'], 'work') bobPhones[1]._set_by_field(phone_fields["type"], "work")
employment = bob._get_by_field(person_fields['employment']) employment = bob._get_by_field(person_fields["employment"])
employment._set_by_field(addressbook.Person.Employment.schema.fields['unemployed'], None) employment._set_by_field(
addressbook.Person.Employment.schema.fields["unemployed"], None
)
addresses.write(file) addresses.write(file)
@ -238,40 +242,45 @@ def test_addressbook_explicit_fields(addressbook):
person_fields = addressbook.Person.schema.fields person_fields = addressbook.Person.schema.fields
phone_fields = addressbook.Person.PhoneNumber.schema.fields phone_fields = addressbook.Person.PhoneNumber.schema.fields
people = addresses._get_by_field(address_fields['people']) people = addresses._get_by_field(address_fields["people"])
alice = people[0] alice = people[0]
assert alice._get_by_field(person_fields['id']) == 123 assert alice._get_by_field(person_fields["id"]) == 123
assert alice._get_by_field(person_fields['name']) == 'Alice' assert alice._get_by_field(person_fields["name"]) == "Alice"
assert alice._get_by_field(person_fields['email']) == 'alice@example.com' assert alice._get_by_field(person_fields["email"]) == "alice@example.com"
alicePhones = alice._get_by_field(person_fields['phones']) alicePhones = alice._get_by_field(person_fields["phones"])
assert alicePhones[0]._get_by_field(phone_fields['number']) == "555-1212" assert alicePhones[0]._get_by_field(phone_fields["number"]) == "555-1212"
assert alicePhones[0]._get_by_field(phone_fields['type']) == 'mobile' assert alicePhones[0]._get_by_field(phone_fields["type"]) == "mobile"
employment = alice._get_by_field(person_fields['employment']) employment = alice._get_by_field(person_fields["employment"])
employment._get_by_field(addressbook.Person.Employment.schema.fields['school']) == "MIT" employment._get_by_field(
addressbook.Person.Employment.schema.fields["school"]
) == "MIT"
bob = people[1] bob = people[1]
assert bob._get_by_field(person_fields['id']) == 456 assert bob._get_by_field(person_fields["id"]) == 456
assert bob._get_by_field(person_fields['name']) == 'Bob' assert bob._get_by_field(person_fields["name"]) == "Bob"
assert bob._get_by_field(person_fields['email']) == 'bob@example.com' assert bob._get_by_field(person_fields["email"]) == "bob@example.com"
bobPhones = bob._get_by_field(person_fields['phones']) bobPhones = bob._get_by_field(person_fields["phones"])
assert bobPhones[0]._get_by_field(phone_fields['number']) == "555-4567" assert bobPhones[0]._get_by_field(phone_fields["number"]) == "555-4567"
assert bobPhones[0]._get_by_field(phone_fields['type']) == 'home' assert bobPhones[0]._get_by_field(phone_fields["type"]) == "home"
assert bobPhones[1]._get_by_field(phone_fields['number']) == "555-7654" assert bobPhones[1]._get_by_field(phone_fields["number"]) == "555-7654"
assert bobPhones[1]._get_by_field(phone_fields['type']) == 'work' assert bobPhones[1]._get_by_field(phone_fields["type"]) == "work"
employment = bob._get_by_field(person_fields['employment']) employment = bob._get_by_field(person_fields["employment"])
employment._get_by_field(addressbook.Person.Employment.schema.fields['unemployed']) is None employment._get_by_field(
addressbook.Person.Employment.schema.fields["unemployed"]
) is None
f = open('example', 'w') f = open("example", "w")
writeAddressBook(f) writeAddressBook(f)
f = open('example', 'r') f = open("example", "r")
printAddressBook(f) printAddressBook(f)
@pytest.fixture @pytest.fixture
def all_types(): def all_types():
return capnp.load(os.path.join(this_dir, 'all_types.capnp')) return capnp.load(os.path.join(this_dir, "all_types.capnp"))
# TODO: These tests should be extended to: # TODO: These tests should be extended to:
# - Read each field in Python and assert that it is equal to the expected value. # - Read each field in Python and assert that it is equal to the expected value.
@ -317,19 +326,24 @@ def init_all_types(builder):
subBuilder.voidList = [None, None, None] subBuilder.voidList = [None, None, None]
subBuilder.boolList = [False, True, False, True, True] subBuilder.boolList = [False, True, False, True, True]
subBuilder.int8List = [12, -34, -0x80, 0x7f] subBuilder.int8List = [12, -34, -0x80, 0x7F]
subBuilder.int16List = [1234, -5678, -0x8000, 0x7fff] subBuilder.int16List = [1234, -5678, -0x8000, 0x7FFF]
subBuilder.int32List = [12345678, -90123456, -0x80000000, 0x7fffffff] subBuilder.int32List = [12345678, -90123456, -0x80000000, 0x7FFFFFFF]
subBuilder.int64List = [123456789012345, -678901234567890, -0x8000000000000000, 0x7fffffffffffffff] subBuilder.int64List = [
subBuilder.uInt8List = [12, 34, 0, 0xff] 123456789012345,
subBuilder.uInt16List = [1234, 5678, 0, 0xffff] -678901234567890,
subBuilder.uInt32List = [12345678, 90123456, 0, 0xffffffff] -0x8000000000000000,
subBuilder.uInt64List = [123456789012345, 678901234567890, 0, 0xffffffffffffffff] 0x7FFFFFFFFFFFFFFF,
]
subBuilder.uInt8List = [12, 34, 0, 0xFF]
subBuilder.uInt16List = [1234, 5678, 0, 0xFFFF]
subBuilder.uInt32List = [12345678, 90123456, 0, 0xFFFFFFFF]
subBuilder.uInt64List = [123456789012345, 678901234567890, 0, 0xFFFFFFFFFFFFFFFF]
subBuilder.float32List = [0, 1234567, 1e37, -1e37, 1e-37, -1e-37] subBuilder.float32List = [0, 1234567, 1e37, -1e37, 1e-37, -1e-37]
subBuilder.float64List = [0, 123456789012345, 1e306, -1e306, 1e-306, -1e-306] subBuilder.float64List = [0, 123456789012345, 1e306, -1e306, 1e-306, -1e-306]
subBuilder.textList = ["quux", "corge", "grault"] subBuilder.textList = ["quux", "corge", "grault"]
subBuilder.dataList = [b"garply", b"waldo", b"fred"] subBuilder.dataList = [b"garply", b"waldo", b"fred"]
listBuilder = subBuilder.init('structList', 3) listBuilder = subBuilder.init("structList", 3)
listBuilder[0].textField = "x structlist 1" listBuilder[0].textField = "x structlist 1"
listBuilder[1].textField = "x structlist 2" listBuilder[1].textField = "x structlist 2"
listBuilder[2].textField = "x structlist 3" listBuilder[2].textField = "x structlist 3"
@ -351,7 +365,7 @@ def init_all_types(builder):
builder.float64List = [7777.75, float("inf"), float("-inf"), float("nan")] builder.float64List = [7777.75, float("inf"), float("-inf"), float("nan")]
builder.textList = ["plugh", "xyzzy", "thud"] builder.textList = ["plugh", "xyzzy", "thud"]
builder.dataList = [b"oops", b"exhausted", b"rfc3092"] builder.dataList = [b"oops", b"exhausted", b"rfc3092"]
listBuilder = builder.init('structList', 3) listBuilder = builder.init("structList", 3)
listBuilder[0].textField = "structlist 1" listBuilder[0].textField = "structlist 1"
listBuilder[1].textField = "structlist 2" listBuilder[1].textField = "structlist 2"
listBuilder[2].textField = "structlist 3" listBuilder[2].textField = "structlist 3"
@ -419,23 +433,30 @@ def check_all_types(reader):
assert subReader.enumField == "baz" assert subReader.enumField == "baz"
# Check that enums are hashable and can be used as keys in dicts # Check that enums are hashable and can be used as keys in dicts
# interchangably with their string version. # interchangably with their string version.
assert hash(subReader.enumField) == hash('baz') assert hash(subReader.enumField) == hash("baz")
assert {subReader.enumField: 17}.get(subReader.enumField) == 17 assert {subReader.enumField: 17}.get(subReader.enumField) == 17
assert {subReader.enumField: 17}.get('baz') == 17 assert {subReader.enumField: 17}.get("baz") == 17
assert {'baz': 17}.get(subReader.enumField) == 17 assert {"baz": 17}.get(subReader.enumField) == 17
check_list(subReader.voidList, [None, None, None]) check_list(subReader.voidList, [None, None, None])
check_list(subReader.boolList, [False, True, False, True, True]) check_list(subReader.boolList, [False, True, False, True, True])
check_list(subReader.int8List, [12, -34, -0x80, 0x7f]) check_list(subReader.int8List, [12, -34, -0x80, 0x7F])
check_list(subReader.int16List, [1234, -5678, -0x8000, 0x7fff]) check_list(subReader.int16List, [1234, -5678, -0x8000, 0x7FFF])
check_list(subReader.int32List, [12345678, -90123456, -0x80000000, 0x7fffffff]) check_list(subReader.int32List, [12345678, -90123456, -0x80000000, 0x7FFFFFFF])
check_list(subReader.int64List, [123456789012345, -678901234567890, -0x8000000000000000, 0x7fffffffffffffff]) check_list(
check_list(subReader.uInt8List, [12, 34, 0, 0xff]) subReader.int64List,
check_list(subReader.uInt16List, [1234, 5678, 0, 0xffff]) [123456789012345, -678901234567890, -0x8000000000000000, 0x7FFFFFFFFFFFFFFF],
check_list(subReader.uInt32List, [12345678, 90123456, 0, 0xffffffff]) )
check_list(subReader.uInt64List, [123456789012345, 678901234567890, 0, 0xffffffffffffffff]) check_list(subReader.uInt8List, [12, 34, 0, 0xFF])
check_list(subReader.uInt16List, [1234, 5678, 0, 0xFFFF])
check_list(subReader.uInt32List, [12345678, 90123456, 0, 0xFFFFFFFF])
check_list(
subReader.uInt64List, [123456789012345, 678901234567890, 0, 0xFFFFFFFFFFFFFFFF]
)
check_list(subReader.float32List, [0.0, 1234567.0, 1e37, -1e37, 1e-37, -1e-37]) check_list(subReader.float32List, [0.0, 1234567.0, 1e37, -1e37, 1e-37, -1e-37])
check_list(subReader.float64List, [0.0, 123456789012345.0, 1e306, -1e306, 1e-306, -1e-306]) check_list(
subReader.float64List, [0.0, 123456789012345.0, 1e306, -1e306, 1e-306, -1e-306]
)
check_list(subReader.textList, ["quux", "corge", "grault"]) check_list(subReader.textList, ["quux", "corge", "grault"])
check_list(subReader.dataList, [b"garply", b"waldo", b"fred"]) check_list(subReader.dataList, [b"garply", b"waldo", b"fred"])
@ -489,29 +510,37 @@ def check_all_types(reader):
def test_build(all_types): def test_build(all_types):
root = all_types.TestAllTypes.new_message() root = all_types.TestAllTypes.new_message()
init_all_types(root) init_all_types(root)
expectedText = open(os.path.join(this_dir, 'all-types.txt'), 'r', encoding='utf8').read() expectedText = open(
assert str(root) + '\n' == expectedText os.path.join(this_dir, "all-types.txt"), "r", encoding="utf8"
).read()
assert str(root) + "\n" == expectedText
def test_build_first_segment_size(all_types): def test_build_first_segment_size(all_types):
root = all_types.TestAllTypes.new_message(1) root = all_types.TestAllTypes.new_message(1)
init_all_types(root) init_all_types(root)
expectedText = open(os.path.join(this_dir, 'all-types.txt'), 'r', encoding='utf8').read() expectedText = open(
assert str(root) + '\n' == expectedText os.path.join(this_dir, "all-types.txt"), "r", encoding="utf8"
).read()
assert str(root) + "\n" == expectedText
root = all_types.TestAllTypes.new_message(1024 * 1024) root = all_types.TestAllTypes.new_message(1024 * 1024)
init_all_types(root) init_all_types(root)
expectedText = open(os.path.join(this_dir, 'all-types.txt'), 'r', encoding='utf8').read() expectedText = open(
assert str(root) + '\n' == expectedText os.path.join(this_dir, "all-types.txt"), "r", encoding="utf8"
).read()
assert str(root) + "\n" == expectedText
def test_binary_read(all_types): def test_binary_read(all_types):
f = open(os.path.join(this_dir, 'all-types.binary'), 'r', encoding='utf8') f = open(os.path.join(this_dir, "all-types.binary"), "r", encoding="utf8")
root = all_types.TestAllTypes.read(f) root = all_types.TestAllTypes.read(f)
check_all_types(root) check_all_types(root)
expectedText = open(os.path.join(this_dir, 'all-types.txt'), 'r', encoding='utf8').read() expectedText = open(
assert str(root) + '\n' == expectedText os.path.join(this_dir, "all-types.txt"), "r", encoding="utf8"
).read()
assert str(root) + "\n" == expectedText
# Test set_root(). # Test set_root().
builder = capnp._MallocMessageBuilder() builder = capnp._MallocMessageBuilder()
@ -524,25 +553,27 @@ def test_binary_read(all_types):
def test_packed_read(all_types): def test_packed_read(all_types):
f = open(os.path.join(this_dir, 'all-types.packed'), 'r', encoding='utf8') f = open(os.path.join(this_dir, "all-types.packed"), "r", encoding="utf8")
root = all_types.TestAllTypes.read_packed(f) root = all_types.TestAllTypes.read_packed(f)
check_all_types(root) check_all_types(root)
expectedText = open(os.path.join(this_dir, 'all-types.txt'), 'r', encoding='utf8').read() expectedText = open(
assert str(root) + '\n' == expectedText os.path.join(this_dir, "all-types.txt"), "r", encoding="utf8"
).read()
assert str(root) + "\n" == expectedText
def test_binary_write(all_types): def test_binary_write(all_types):
root = all_types.TestAllTypes.new_message() root = all_types.TestAllTypes.new_message()
init_all_types(root) init_all_types(root)
root.write(open('example', 'w')) root.write(open("example", "w"))
check_all_types(all_types.TestAllTypes.read(open('example', 'r'))) check_all_types(all_types.TestAllTypes.read(open("example", "r")))
def test_packed_write(all_types): def test_packed_write(all_types):
root = all_types.TestAllTypes.new_message() root = all_types.TestAllTypes.new_message()
init_all_types(root) init_all_types(root)
root.write_packed(open('example', 'w')) root.write_packed(open("example", "w"))
check_all_types(all_types.TestAllTypes.read_packed(open('example', 'r'))) check_all_types(all_types.TestAllTypes.read_packed(open("example", "r")))

View file

@ -1,6 +1,6 @@
''' """
rpc test rpc test
''' """
import pytest import pytest
import capnp import capnp
@ -10,7 +10,6 @@ import test_capability_capnp
class Server(test_capability_capnp.TestInterface.Server): class Server(test_capability_capnp.TestInterface.Server):
def __init__(self, val=100): def __init__(self, val=100):
self.val = val self.val = val
@ -45,4 +44,4 @@ def test_simple_rpc_bootstrap():
remote = cap.foo(i=5) remote = cap.foo(i=5)
response = remote.wait() response = remote.wait()
assert response.x == '125' assert response.x == "125"

View file

@ -6,7 +6,7 @@ import sys # add examples dir to sys.path
import capnp import capnp
examples_dir = os.path.join(os.path.dirname(__file__), '..', 'examples') examples_dir = os.path.join(os.path.dirname(__file__), "..", "examples")
sys.path.append(examples_dir) sys.path.append(examples_dir)
import calculator_client # noqa: E402 import calculator_client # noqa: E402
@ -32,23 +32,31 @@ def test_calculator():
calculator_client.main(read) calculator_client.main(read)
@pytest.mark.xfail(reason="Some versions of python don't like to share ports, don't worry if this fails") @pytest.mark.xfail(
reason="Some versions of python don't like to share ports, don't worry if this fails"
)
def test_calculator_tcp(cleanup): def test_calculator_tcp(cleanup):
address = 'localhost:36431' address = "localhost:36431"
test_examples.run_subprocesses(address, 'calculator_server.py', 'calculator_client.py', wildcard_server=True) test_examples.run_subprocesses(
address, "calculator_server.py", "calculator_client.py", wildcard_server=True
)
@pytest.mark.xfail(reason="Some versions of python don't like to share ports, don't worry if this fails") @pytest.mark.xfail(
@pytest.mark.skipif(os.name == 'nt', reason="socket.AF_UNIX not supported on Windows") reason="Some versions of python don't like to share ports, don't worry if this fails"
)
@pytest.mark.skipif(os.name == "nt", reason="socket.AF_UNIX not supported on Windows")
def test_calculator_unix(cleanup): def test_calculator_unix(cleanup):
path = '/tmp/pycapnp-test' path = "/tmp/pycapnp-test"
try: try:
os.unlink(path) os.unlink(path)
except OSError: except OSError:
pass pass
address = 'unix:' + path address = "unix:" + path
test_examples.run_subprocesses(address, 'calculator_server.py', 'calculator_client.py') test_examples.run_subprocesses(
address, "calculator_server.py", "calculator_client.py"
)
def test_calculator_gc(): def test_calculator_gc():
@ -56,6 +64,7 @@ def test_calculator_gc():
def call(*args, **kwargs): def call(*args, **kwargs):
gc.collect() gc.collect()
return old_evaluate_impl(*args, **kwargs) return old_evaluate_impl(*args, **kwargs)
return call return call
read, write = socket.socketpair() read, write = socket.socketpair()

View file

@ -7,20 +7,20 @@ this_dir = os.path.dirname(__file__)
@pytest.fixture @pytest.fixture
def addressbook(): def addressbook():
return capnp.load(os.path.join(this_dir, 'addressbook.capnp')) return capnp.load(os.path.join(this_dir, "addressbook.capnp"))
@pytest.fixture @pytest.fixture
def annotations(): def annotations():
return capnp.load(os.path.join(this_dir, 'annotations.capnp')) return capnp.load(os.path.join(this_dir, "annotations.capnp"))
def test_basic_schema(addressbook): def test_basic_schema(addressbook):
assert addressbook.Person.schema.fieldnames[0] == 'id' assert addressbook.Person.schema.fieldnames[0] == "id"
def test_list_schema(addressbook): def test_list_schema(addressbook):
peopleField = addressbook.AddressBook.schema.fields['people'] peopleField = addressbook.AddressBook.schema.fields["people"]
personType = peopleField.schema.elementType personType = peopleField.schema.elementType
assert personType.node.id == addressbook.Person.schema.node.id assert personType.node.id == addressbook.Person.schema.node.id
@ -31,20 +31,24 @@ def test_list_schema(addressbook):
def test_annotations(annotations): def test_annotations(annotations):
assert annotations.schema.node.annotations[0].value.text == 'TestFile' assert annotations.schema.node.annotations[0].value.text == "TestFile"
annotation = annotations.TestAnnotationOne.schema.node.annotations[0] annotation = annotations.TestAnnotationOne.schema.node.annotations[0]
assert annotation.value.text == 'Test' assert annotation.value.text == "Test"
annotation = annotations.TestAnnotationTwo.schema.node.annotations[0] annotation = annotations.TestAnnotationTwo.schema.node.annotations[0]
assert annotation.value.struct.as_struct(annotations.AnnotationStruct).test == 100 assert annotation.value.struct.as_struct(annotations.AnnotationStruct).test == 100
annotation = annotations.TestAnnotationThree.schema.node.annotations[0] annotation = annotations.TestAnnotationThree.schema.node.annotations[0]
annotation_list = annotation.value.list.as_list(capnp._ListSchema(annotations.AnnotationStruct)) annotation_list = annotation.value.list.as_list(
capnp._ListSchema(annotations.AnnotationStruct)
)
assert annotation_list[0].test == 100 assert annotation_list[0].test == 100
assert annotation_list[1].test == 101 assert annotation_list[1].test == 101
annotation = annotations.TestAnnotationFour.schema.node.annotations[0] annotation = annotations.TestAnnotationFour.schema.node.annotations[0]
annotation_list = annotation.value.list.as_list(capnp._ListSchema(capnp.types.UInt16)) annotation_list = annotation.value.list.as_list(
capnp._ListSchema(capnp.types.UInt16)
)
assert annotation_list[0] == 200 assert annotation_list[0] == 200
assert annotation_list[1] == 201 assert annotation_list[1] == 201

View file

@ -16,7 +16,7 @@ this_dir = os.path.dirname(__file__)
@pytest.fixture @pytest.fixture
def all_types(): def all_types():
return capnp.load(os.path.join(this_dir, 'all_types.capnp')) return capnp.load(os.path.join(this_dir, "all_types.capnp"))
def test_roundtrip_file(all_types): def test_roundtrip_file(all_types):
@ -51,8 +51,8 @@ def test_roundtrip_bytes(all_types):
@pytest.mark.skipif( @pytest.mark.skipif(
platform.python_implementation() == 'PyPy', platform.python_implementation() == "PyPy",
reason="TODO: Investigate why this works on CPython but fails on PyPy." reason="TODO: Investigate why this works on CPython but fails on PyPy.",
) )
def test_roundtrip_segments(all_types): def test_roundtrip_segments(all_types):
msg = all_types.TestAllTypes.new_message() msg = all_types.TestAllTypes.new_message()
@ -62,7 +62,10 @@ def test_roundtrip_segments(all_types):
test_regression.check_all_types(msg) test_regression.check_all_types(msg)
@pytest.mark.skipif(sys.version_info[0] < 3, reason="mmap doesn't implement the buffer interface under python 2.") @pytest.mark.skipif(
sys.version_info[0] < 3,
reason="mmap doesn't implement the buffer interface under python 2.",
)
def test_roundtrip_bytes_mmap(all_types): def test_roundtrip_bytes_mmap(all_types):
msg = all_types.TestAllTypes.new_message() msg = all_types.TestAllTypes.new_message()
test_regression.init_all_types(msg) test_regression.init_all_types(msg)
@ -78,7 +81,9 @@ def test_roundtrip_bytes_mmap(all_types):
test_regression.check_all_types(msg) test_regression.check_all_types(msg)
@pytest.mark.skipif(sys.version_info[0] < 3, reason="memoryview is a builtin on Python 3") @pytest.mark.skipif(
sys.version_info[0] < 3, reason="memoryview is a builtin on Python 3"
)
def test_roundtrip_bytes_buffer(all_types): def test_roundtrip_bytes_buffer(all_types):
msg = all_types.TestAllTypes.new_message() msg = all_types.TestAllTypes.new_message()
test_regression.init_all_types(msg) test_regression.init_all_types(msg)
@ -95,8 +100,8 @@ def test_roundtrip_bytes_fail(all_types):
@pytest.mark.skipif( @pytest.mark.skipif(
platform.python_implementation() == 'PyPy', platform.python_implementation() == "PyPy",
reason="This works in PyPy 4.0.1 but travisci's version of PyPy has some bug that fails this test." reason="This works in PyPy 4.0.1 but travisci's version of PyPy has some bug that fails this test.",
) )
def test_roundtrip_bytes_packed(all_types): def test_roundtrip_bytes_packed(all_types):
msg = all_types.TestAllTypes.new_message() msg = all_types.TestAllTypes.new_message()
@ -108,7 +113,9 @@ def test_roundtrip_bytes_packed(all_types):
@contextmanager @contextmanager
def _warnings(expected_count=2, expected_text='This message has already been written once.'): def _warnings(
expected_count=2, expected_text="This message has already been written once."
):
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
yield yield
@ -227,10 +234,7 @@ def test_from_bytes_traversal_limit(all_types):
for i in range(0, size): for i in range(0, size):
msg.structList[i].uInt8Field == 0 msg.structList[i].uInt8Field == 0
msg = all_types.TestAllTypes.from_bytes( msg = all_types.TestAllTypes.from_bytes(data, traversal_limit_in_words=2 ** 62)
data,
traversal_limit_in_words=2**62
)
for i in range(0, size): for i in range(0, size):
assert msg.structList[i].uInt8Field == 0 assert msg.structList[i].uInt8Field == 0
@ -247,8 +251,7 @@ def test_from_bytes_packed_traversal_limit(all_types):
msg.structList[i].uInt8Field == 0 msg.structList[i].uInt8Field == 0
msg = all_types.TestAllTypes.from_bytes_packed( msg = all_types.TestAllTypes.from_bytes_packed(
data, data, traversal_limit_in_words=2 ** 62
traversal_limit_in_words=2**62
) )
for i in range(0, size): for i in range(0, size):
assert msg.structList[i].uInt8Field == 0 assert msg.structList[i].uInt8Field == 0

View file

@ -11,17 +11,17 @@ this_dir = os.path.dirname(__file__)
@pytest.fixture @pytest.fixture
def addressbook(): def addressbook():
return capnp.load(os.path.join(this_dir, 'addressbook.capnp')) return capnp.load(os.path.join(this_dir, "addressbook.capnp"))
@pytest.fixture @pytest.fixture
def all_types(): def all_types():
return capnp.load(os.path.join(this_dir, 'all_types.capnp')) return capnp.load(os.path.join(this_dir, "all_types.capnp"))
def test_which_builder(addressbook): def test_which_builder(addressbook):
addresses = addressbook.AddressBook.new_message() addresses = addressbook.AddressBook.new_message()
people = addresses.init('people', 2) people = addresses.init("people", 2)
alice = people[0] alice = people[0]
alice.employment.school = "MIT" alice.employment.school = "MIT"
@ -53,7 +53,7 @@ def test_which_reader(addressbook):
def writeAddressBook(fd): def writeAddressBook(fd):
message = capnp._MallocMessageBuilder() message = capnp._MallocMessageBuilder()
addressBook = message.init_root(addressbook.AddressBook) addressBook = message.init_root(addressbook.AddressBook)
people = addressBook.init('people', 2) people = addressBook.init("people", 2)
alice = people[0] alice = people[0]
alice.employment.school = "MIT" alice.employment.school = "MIT"
@ -89,14 +89,14 @@ def test_which_reader(addressbook):
@pytest.mark.skipif( @pytest.mark.skipif(
capnp.version.LIBCAPNP_VERSION < 5000, capnp.version.LIBCAPNP_VERSION < 5000,
reason="Using ints as enums requires v0.5.0+ of the C++ capnp library" reason="Using ints as enums requires v0.5.0+ of the C++ capnp library",
) )
def test_enum(addressbook): def test_enum(addressbook):
addresses = addressbook.AddressBook.new_message() addresses = addressbook.AddressBook.new_message()
people = addresses.init('people', 2) people = addresses.init("people", 2)
alice = people[0] alice = people[0]
phones = alice.init('phones', 2) phones = alice.init("phones", 2)
assert phones[0].type == phones[1].type assert phones[0].type == phones[1].type
@ -104,7 +104,7 @@ def test_enum(addressbook):
assert phones[0].type != phones[1].type assert phones[0].type != phones[1].type
phones[1].type = 'home' phones[1].type = "home"
assert phones[0].type == phones[1].type assert phones[0].type == phones[1].type
@ -112,12 +112,12 @@ def test_enum(addressbook):
def test_builder_set(addressbook): def test_builder_set(addressbook):
person = addressbook.Person.new_message() person = addressbook.Person.new_message()
person.name = 'test' person.name = "test"
assert person.name == 'test' assert person.name == "test"
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
person.foo = 'test' person.foo = "test"
def test_builder_set_from_list(all_types): def test_builder_set_from_list(all_types):
@ -150,9 +150,9 @@ def test_unicode_str(all_types):
msg = all_types.TestAllTypes.new_message() msg = all_types.TestAllTypes.new_message()
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
msg.textField = u"f\u00e6oo".encode('utf-8') msg.textField = u"f\u00e6oo".encode("utf-8")
assert msg.textField.decode('utf-8') == u"f\u00e6oo" assert msg.textField.decode("utf-8") == u"f\u00e6oo"
else: else:
msg.textField = "f\u00e6oo" msg.textField = "f\u00e6oo"
@ -164,11 +164,13 @@ def test_new_message(all_types):
assert msg.int32Field == 100 assert msg.int32Field == 100
msg = all_types.TestAllTypes.new_message(structField={'int32Field': 100}) msg = all_types.TestAllTypes.new_message(structField={"int32Field": 100})
assert msg.structField.int32Field == 100 assert msg.structField.int32Field == 100
msg = all_types.TestAllTypes.new_message(structList=[{'int32Field': 100}, {'int32Field': 101}]) msg = all_types.TestAllTypes.new_message(
structList=[{"int32Field": 100}, {"int32Field": 101}]
)
assert msg.structList[0].int32Field == 100 assert msg.structList[0].int32Field == 100
assert msg.structList[1].int32Field == 101 assert msg.structList[1].int32Field == 101
@ -177,7 +179,7 @@ def test_new_message(all_types):
assert msg.int32Field == 100 assert msg.int32Field == 100
msg = all_types.TestAllTypes.new_message(**{'int32Field': 100, 'int64Field': 101}) msg = all_types.TestAllTypes.new_message(**{"int32Field": 100, "int64Field": 101})
assert msg.int32Field == 100 assert msg.int32Field == 100
assert msg.int64Field == 101 assert msg.int64Field == 101
@ -186,45 +188,55 @@ def test_new_message(all_types):
def test_set_dict(all_types): def test_set_dict(all_types):
msg = all_types.TestAllTypes.new_message() msg = all_types.TestAllTypes.new_message()
msg.structField = {'int32Field': 100} msg.structField = {"int32Field": 100}
assert msg.structField.int32Field == 100 assert msg.structField.int32Field == 100
msg.init('structList', 2) msg.init("structList", 2)
msg.structList[0] = {'int32Field': 102} msg.structList[0] = {"int32Field": 102}
assert msg.structList[0].int32Field == 102 assert msg.structList[0].int32Field == 102
def test_set_dict_union(addressbook): def test_set_dict_union(addressbook):
person = addressbook.Person.new_message(**{'employment': {'employer': {'name': 'foo'}}}) person = addressbook.Person.new_message(
**{"employment": {"employer": {"name": "foo"}}}
)
assert person.employment.which == addressbook.Person.Employment.employer assert person.employment.which == addressbook.Person.Employment.employer
assert person.employment.employer.name == 'foo' assert person.employment.employer.name == "foo"
def test_union_enum(all_types): def test_union_enum(all_types):
assert all_types.UnionAllTypes.Union.UnionStructField1 == 0 assert all_types.UnionAllTypes.Union.UnionStructField1 == 0
assert all_types.UnionAllTypes.Union.UnionStructField2 == 1 assert all_types.UnionAllTypes.Union.UnionStructField2 == 1
msg = all_types.UnionAllTypes.new_message(**{'unionStructField1': {'textField': "foo"}}) msg = all_types.UnionAllTypes.new_message(
**{"unionStructField1": {"textField": "foo"}}
)
assert msg.which == all_types.UnionAllTypes.Union.UnionStructField1 assert msg.which == all_types.UnionAllTypes.Union.UnionStructField1
assert msg.which == 'unionStructField1' assert msg.which == "unionStructField1"
assert msg.which == 0 assert msg.which == 0
msg = all_types.UnionAllTypes.new_message(**{'unionStructField2': {'textField': "foo"}}) msg = all_types.UnionAllTypes.new_message(
**{"unionStructField2": {"textField": "foo"}}
)
assert msg.which == all_types.UnionAllTypes.Union.UnionStructField2 assert msg.which == all_types.UnionAllTypes.Union.UnionStructField2
assert msg.which == 'unionStructField2' assert msg.which == "unionStructField2"
assert msg.which == 1 assert msg.which == 1
assert all_types.GroupedUnionAllTypes.Union.G1 == 0 assert all_types.GroupedUnionAllTypes.Union.G1 == 0
assert all_types.GroupedUnionAllTypes.Union.G2 == 1 assert all_types.GroupedUnionAllTypes.Union.G2 == 1
msg = all_types.GroupedUnionAllTypes.new_message(**{'g1': {'unionStructField1': {'textField': "foo"}}}) msg = all_types.GroupedUnionAllTypes.new_message(
**{"g1": {"unionStructField1": {"textField": "foo"}}}
)
assert msg.which == all_types.GroupedUnionAllTypes.Union.G1 assert msg.which == all_types.GroupedUnionAllTypes.Union.G1
msg = all_types.GroupedUnionAllTypes.new_message(**{'g2': {'unionStructField2': {'textField': "foo"}}}) msg = all_types.GroupedUnionAllTypes.new_message(
**{"g2": {"unionStructField2": {"textField": "foo"}}}
)
assert msg.which == all_types.GroupedUnionAllTypes.Union.G2 assert msg.which == all_types.GroupedUnionAllTypes.Union.G2
msg = all_types.UnionAllTypes.new_message() msg = all_types.UnionAllTypes.new_message()
@ -236,44 +248,55 @@ def isstr(s):
def test_to_dict_enum(addressbook): def test_to_dict_enum(addressbook):
person = addressbook.Person.new_message(**{'phones': [{'number': '999-9999', 'type': 'mobile'}]}) person = addressbook.Person.new_message(
**{"phones": [{"number": "999-9999", "type": "mobile"}]}
)
field = person.to_dict()['phones'][0]['type'] field = person.to_dict()["phones"][0]["type"]
assert isstr(field) assert isstr(field)
assert field == 'mobile' assert field == "mobile"
def test_explicit_field(addressbook): def test_explicit_field(addressbook):
person = addressbook.Person.new_message(**{'name': 'Test'}) person = addressbook.Person.new_message(**{"name": "Test"})
name_field = addressbook.Person.schema.fields['name'] name_field = addressbook.Person.schema.fields["name"]
assert person.name == person._get_by_field(name_field) assert person.name == person._get_by_field(name_field)
assert person.name == person.as_reader()._get_by_field(name_field) assert person.name == person.as_reader()._get_by_field(name_field)
def test_to_dict_verbose(addressbook): def test_to_dict_verbose(addressbook):
person = addressbook.Person.new_message(**{'name': 'Test'}) person = addressbook.Person.new_message(**{"name": "Test"})
assert person.to_dict(verbose=True)['phones'] == [] assert person.to_dict(verbose=True)["phones"] == []
if sys.version_info >= (2, 7): if sys.version_info >= (2, 7):
assert person.to_dict(verbose=True, ordered=True)['phones'] == [] assert person.to_dict(verbose=True, ordered=True)["phones"] == []
with pytest.raises(KeyError): with pytest.raises(KeyError):
assert person.to_dict()['phones'] == [] assert person.to_dict()["phones"] == []
def test_to_dict_ordered(addressbook): def test_to_dict_ordered(addressbook):
person = addressbook.Person.new_message(**{ person = addressbook.Person.new_message(
'name': 'Alice', **{
'phones': [{'type': 'mobile', 'number': '555-1212'}], "name": "Alice",
'id': 123, "phones": [{"type": "mobile", "number": "555-1212"}],
'employment': {'school': 'MIT'}, 'email': 'alice@example.com' "id": 123,
}) "employment": {"school": "MIT"},
"email": "alice@example.com",
}
)
if sys.version_info >= (2, 7): if sys.version_info >= (2, 7):
assert list(person.to_dict(ordered=True).keys()) == ['id', 'name', 'email', 'phones', 'employment'] assert list(person.to_dict(ordered=True).keys()) == [
"id",
"name",
"email",
"phones",
"employment",
]
else: else:
with pytest.raises(Exception): with pytest.raises(Exception):
person.to_dict(ordered=True) person.to_dict(ordered=True)
@ -281,7 +304,7 @@ def test_to_dict_ordered(addressbook):
def test_nested_list(addressbook): def test_nested_list(addressbook):
struct = addressbook.NestedList.new_message() struct = addressbook.NestedList.new_message()
struct.init('list', 2) struct.init("list", 2)
struct.list.init(0, 1) struct.list.init(0, 1)
struct.list.init(1, 2) struct.list.init(1, 2)

View file

@ -1,6 +1,6 @@
''' """
thread test thread test
''' """
import platform import platform
import socket import socket
@ -16,13 +16,13 @@ import test_capability_capnp
@pytest.mark.skipif( @pytest.mark.skipif(
platform.python_implementation() == 'PyPy', platform.python_implementation() == "PyPy",
reason="pycapnp's GIL handling isn't working properly at the moment for PyPy" reason="pycapnp's GIL handling isn't working properly at the moment for PyPy",
) )
def test_making_event_loop(): def test_making_event_loop():
''' """
Event loop test Event loop test
''' """
capnp.remove_event_loop(True) capnp.remove_event_loop(True)
capnp.create_event_loop() capnp.create_event_loop()
@ -31,13 +31,13 @@ def test_making_event_loop():
@pytest.mark.skipif( @pytest.mark.skipif(
platform.python_implementation() == 'PyPy', platform.python_implementation() == "PyPy",
reason="pycapnp's GIL handling isn't working properly at the moment for PyPy" reason="pycapnp's GIL handling isn't working properly at the moment for PyPy",
) )
def test_making_threaded_event_loop(): def test_making_threaded_event_loop():
''' """
Threaded event loop test Threaded event loop test
''' """
# The following raises a KjException, and if not caught causes an SIGABRT: # The following raises a KjException, and if not caught causes an SIGABRT:
# kj/async.c++:973: failed: expected head == nullptr; EventLoop destroyed with events still in the queue. # kj/async.c++:973: failed: expected head == nullptr; EventLoop destroyed with events still in the queue.
# Memory leak?; head->trace() = kj::_::ForkHub<kj::_::Void> # Memory leak?; head->trace() = kj::_::ForkHub<kj::_::Void>
@ -54,27 +54,28 @@ def test_making_threaded_event_loop():
class Server(test_capability_capnp.TestInterface.Server): class Server(test_capability_capnp.TestInterface.Server):
''' """
Server Server
''' """
def __init__(self, val=100): def __init__(self, val=100):
self.val = val self.val = val
def foo(self, i, j, **kwargs): def foo(self, i, j, **kwargs):
''' """
foo foo
''' """
return str(i * 5 + self.val) return str(i * 5 + self.val)
@pytest.mark.skipif( @pytest.mark.skipif(
platform.python_implementation() == 'PyPy', platform.python_implementation() == "PyPy",
reason="pycapnp's GIL handling isn't working properly at the moment for PyPy" reason="pycapnp's GIL handling isn't working properly at the moment for PyPy",
) )
def test_using_threads(): def test_using_threads():
''' """
Thread test Thread test
''' """
capnp.remove_event_loop(True) capnp.remove_event_loop(True)
capnp.create_event_loop(True) capnp.create_event_loop(True)
@ -94,4 +95,4 @@ def test_using_threads():
remote = cap.foo(i=5) remote = cap.foo(i=5)
response = remote.wait() response = remote.wait()
assert response.x == '125' assert response.x == "125"