Commit 0e2089c7 authored by cyyber's avatar cyyber

Calling Keychecker before checking key in MessageMap

parent 31c54d12
...@@ -549,10 +549,10 @@ class MessageMap(MutableMapping): ...@@ -549,10 +549,10 @@ class MessageMap(MutableMapping):
self._values = {} self._values = {}
def __getitem__(self, key): def __getitem__(self, key):
key = self._key_checker.CheckValue(key)
try: try:
return self._values[key] return self._values[key]
except KeyError: except KeyError:
key = self._key_checker.CheckValue(key)
new_element = self._message_descriptor._concrete_class() new_element = self._message_descriptor._concrete_class()
new_element._SetListener(self._message_listener) new_element._SetListener(self._message_listener)
self._values[key] = new_element self._values[key] = new_element
...@@ -584,12 +584,14 @@ class MessageMap(MutableMapping): ...@@ -584,12 +584,14 @@ class MessageMap(MutableMapping):
return default return default
def __contains__(self, item): def __contains__(self, item):
item = self._key_checker.CheckValue(item)
return item in self._values return item in self._values
def __setitem__(self, key, value): def __setitem__(self, key, value):
raise ValueError('May not set values directly, call my_map[key].foo = 5') raise ValueError('May not set values directly, call my_map[key].foo = 5')
def __delitem__(self, key): def __delitem__(self, key):
key = self._key_checker.CheckValue(key)
del self._values[key] del self._values[key]
self._message_listener.Modified() self._message_listener.Modified()
......
...@@ -1480,12 +1480,8 @@ class Proto3Test(BaseTestCase): ...@@ -1480,12 +1480,8 @@ class Proto3Test(BaseTestCase):
submsg = msg.map_int32_foreign_message[5] submsg = msg.map_int32_foreign_message[5]
self.assertIs(submsg, msg.map_int32_foreign_message.get(5)) self.assertIs(submsg, msg.map_int32_foreign_message.get(5))
# TODO(jieluo): Fix python and cpp extension diff. with self.assertRaises(TypeError):
if api_implementation.Type() == 'cpp': msg.map_int32_foreign_message.get('')
with self.assertRaises(TypeError):
msg.map_int32_foreign_message.get('')
else:
self.assertEqual(None, msg.map_int32_foreign_message.get(''))
def testScalarMap(self): def testScalarMap(self):
msg = map_unittest_pb2.TestMap() msg = map_unittest_pb2.TestMap()
...@@ -1695,12 +1691,8 @@ class Proto3Test(BaseTestCase): ...@@ -1695,12 +1691,8 @@ class Proto3Test(BaseTestCase):
del msg2.map_int32_foreign_message[222] del msg2.map_int32_foreign_message[222]
self.assertFalse(222 in msg2.map_int32_foreign_message) self.assertFalse(222 in msg2.map_int32_foreign_message)
if api_implementation.Type() == 'cpp': with self.assertRaises(TypeError):
with self.assertRaises(TypeError): del msg2.map_int32_foreign_message['']
del msg2.map_int32_foreign_message['']
else:
with self.assertRaises(KeyError):
del msg2.map_int32_foreign_message['']
def testMergeFromBadType(self): def testMergeFromBadType(self):
msg = map_unittest_pb2.TestMap() msg = map_unittest_pb2.TestMap()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment