mirror of
https://github.com/capnproto/pycapnp.git
synced 2025-03-04 08:24:43 +01:00
_StructModuleWhich: Use enum
This commit is contained in:
parent
0c904443bd
commit
1795cac230
2 changed files with 19 additions and 7 deletions
|
@ -22,6 +22,7 @@ from libc.string cimport memcpy
|
|||
import array
|
||||
import asyncio
|
||||
import collections as _collections
|
||||
import enum as _enum
|
||||
import inspect as _inspect
|
||||
import os as _os
|
||||
import random as _random
|
||||
|
@ -3152,8 +3153,12 @@ cdef _new_message(self, kwargs, num_first_segment_words):
|
|||
return msg
|
||||
|
||||
|
||||
class _StructModuleWhich(object):
|
||||
pass
|
||||
class _StructModuleWhich(_enum.Enum):
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, int):
|
||||
return self.value == other
|
||||
else:
|
||||
return self.name == other
|
||||
|
||||
|
||||
class _StructModule(object):
|
||||
|
@ -3170,17 +3175,19 @@ class _StructModule(object):
|
|||
if field_schema.discriminantCount == 0:
|
||||
sub_module = _StructModule(raw_schema, name)
|
||||
else:
|
||||
sub_module = _StructModuleWhich()
|
||||
setattr(sub_module, 'schema', raw_schema)
|
||||
mapping = []
|
||||
for union_field in field_schema.fields:
|
||||
setattr(sub_module, union_field.name, union_field.discriminantValue)
|
||||
mapping.append((union_field.name, union_field.discriminantValue))
|
||||
sub_module = _StructModuleWhich("StructModuleWhich", mapping)
|
||||
setattr(sub_module, 'schema', raw_schema)
|
||||
setattr(self, name, sub_module)
|
||||
if schema.union_fields and not schema.non_union_fields:
|
||||
sub_module = _StructModuleWhich()
|
||||
mapping = []
|
||||
for union_field in schema.node.struct.fields:
|
||||
name = union_field.name
|
||||
name = name[0].upper() + name[1:]
|
||||
setattr(sub_module, name, union_field.discriminantValue)
|
||||
mapping.append((name, union_field.discriminantValue))
|
||||
sub_module = _StructModuleWhich("StructModuleWhich", mapping)
|
||||
setattr(self, 'Union', sub_module)
|
||||
|
||||
def read(self, file, traversal_limit_in_words=None, nesting_limit=None):
|
||||
|
|
|
@ -200,8 +200,13 @@ def test_union_enum(all_types):
|
|||
|
||||
msg = all_types.UnionAllTypes.new_message(**{'unionStructField1': {'textField': "foo"}})
|
||||
assert msg.which == all_types.UnionAllTypes.Union.UnionStructField1
|
||||
assert msg.which == 'unionStructField1'
|
||||
assert msg.which == 0
|
||||
|
||||
msg = all_types.UnionAllTypes.new_message(**{'unionStructField2': {'textField': "foo"}})
|
||||
assert msg.which == all_types.UnionAllTypes.Union.UnionStructField2
|
||||
assert msg.which == 'unionStructField2'
|
||||
assert msg.which == 1
|
||||
|
||||
assert all_types.GroupedUnionAllTypes.Union.G1 == 0
|
||||
assert all_types.GroupedUnionAllTypes.Union.G2 == 1
|
||||
|
|
Loading…
Add table
Reference in a new issue