_StructModuleWhich: Use enum

This commit is contained in:
John Vandenberg 2021-05-31 16:32:46 +08:00
parent 0c904443bd
commit 1795cac230
2 changed files with 19 additions and 7 deletions

View file

@ -22,6 +22,7 @@ from libc.string cimport memcpy
import array
import asyncio
import collections as _collections
import enum as _enum
import inspect as _inspect
import os as _os
import random as _random
@ -3152,8 +3153,12 @@ cdef _new_message(self, kwargs, num_first_segment_words):
return msg
class _StructModuleWhich(object):
pass
class _StructModuleWhich(_enum.Enum):
def __eq__(self, other):
if isinstance(other, int):
return self.value == other
else:
return self.name == other
class _StructModule(object):
@ -3170,17 +3175,19 @@ class _StructModule(object):
if field_schema.discriminantCount == 0:
sub_module = _StructModule(raw_schema, name)
else:
sub_module = _StructModuleWhich()
setattr(sub_module, 'schema', raw_schema)
mapping = []
for union_field in field_schema.fields:
setattr(sub_module, union_field.name, union_field.discriminantValue)
mapping.append((union_field.name, union_field.discriminantValue))
sub_module = _StructModuleWhich("StructModuleWhich", mapping)
setattr(sub_module, 'schema', raw_schema)
setattr(self, name, sub_module)
if schema.union_fields and not schema.non_union_fields:
sub_module = _StructModuleWhich()
mapping = []
for union_field in schema.node.struct.fields:
name = union_field.name
name = name[0].upper() + name[1:]
setattr(sub_module, name, union_field.discriminantValue)
mapping.append((name, union_field.discriminantValue))
sub_module = _StructModuleWhich("StructModuleWhich", mapping)
setattr(self, 'Union', sub_module)
def read(self, file, traversal_limit_in_words=None, nesting_limit=None):

View file

@ -200,8 +200,13 @@ def test_union_enum(all_types):
msg = all_types.UnionAllTypes.new_message(**{'unionStructField1': {'textField': "foo"}})
assert msg.which == all_types.UnionAllTypes.Union.UnionStructField1
assert msg.which == 'unionStructField1'
assert msg.which == 0
msg = all_types.UnionAllTypes.new_message(**{'unionStructField2': {'textField': "foo"}})
assert msg.which == all_types.UnionAllTypes.Union.UnionStructField2
assert msg.which == 'unionStructField2'
assert msg.which == 1
assert all_types.GroupedUnionAllTypes.Union.G1 == 0
assert all_types.GroupedUnionAllTypes.Union.G2 == 1