Commit ada65567 authored by Jisi Liu

Down integrate from Google internal.

Change-Id: I34d301133eea9c6f3a822c47d1f91e136fd33145
parent 581be246
@@ -48,8 +48,9 @@ Installation
 $ python setup.py build
 $ python setup.py google_test
-If you want to test c++ implementation, run:
-$ python setup.py test --cpp_implementation
+If you want to build/test c++ implementation, run:
+$ python setup.py build --cpp_implementation
+$ python setup.py google_test --cpp_implementation
 If some tests fail, this library may not work correctly on your
 system. Continue at your own risk.
......
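The C++ implementation is now built with `build --cpp_implementation` and exercised with `google_test --cpp_implementation`. A minimal sketch for checking which implementation a finished build actually selected, using the same `api_implementation` accessors the tests in this commit rely on:

```python
# Hedged sketch: confirm which implementation a build selected. Type() and
# Version() are the accessors used by the tests added in this commit.
from google.protobuf.internal import api_implementation

print(api_implementation.Type())     # 'python' or 'cpp'
print(api_implementation.Version())  # 2 when the C++-backed implementation is active
```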
@@ -64,6 +64,9 @@ from google.protobuf import descriptor_database
 from google.protobuf import text_encoding

+_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS
+

 def _NormalizeFullyQualifiedName(name):
   """Remove leading period from fully-qualified type name.
@@ -271,58 +274,81 @@ class DescriptorPool(object):
     file_descriptor = descriptor.FileDescriptor(
         name=file_proto.name,
         package=file_proto.package,
+        syntax=file_proto.syntax,
         options=file_proto.options,
         serialized_pb=file_proto.SerializeToString(),
         dependencies=direct_deps)
-    scope = {}
-    # This loop extracts all the message and enum types from all the
-    # dependencoes of the file_proto. This is necessary to create the
-    # scope of available message types when defining the passed in
-    # file proto.
-    for dependency in built_deps:
-      scope.update(self._ExtractSymbols(
-          dependency.message_types_by_name.values()))
-      scope.update((_PrefixWithDot(enum.full_name), enum)
-                   for enum in dependency.enum_types_by_name.values())
-    for message_type in file_proto.message_type:
-      message_desc = self._ConvertMessageDescriptor(
-          message_type, file_proto.package, file_descriptor, scope)
-      file_descriptor.message_types_by_name[message_desc.name] = message_desc
-    for enum_type in file_proto.enum_type:
-      file_descriptor.enum_types_by_name[enum_type.name] = (
-          self._ConvertEnumDescriptor(enum_type, file_proto.package,
-                                      file_descriptor, None, scope))
-    for index, extension_proto in enumerate(file_proto.extension):
-      extension_desc = self.MakeFieldDescriptor(
-          extension_proto, file_proto.package, index, is_extension=True)
-      extension_desc.containing_type = self._GetTypeFromScope(
-          file_descriptor.package, extension_proto.extendee, scope)
-      self.SetFieldType(extension_proto, extension_desc,
-                        file_descriptor.package, scope)
-      file_descriptor.extensions_by_name[extension_desc.name] = extension_desc
-    for desc_proto in file_proto.message_type:
-      self.SetAllFieldTypes(file_proto.package, desc_proto, scope)
-    if file_proto.package:
-      desc_proto_prefix = _PrefixWithDot(file_proto.package)
-    else:
-      desc_proto_prefix = ''
-    for desc_proto in file_proto.message_type:
-      desc = self._GetTypeFromScope(desc_proto_prefix, desc_proto.name, scope)
-      file_descriptor.message_types_by_name[desc_proto.name] = desc
+    if _USE_C_DESCRIPTORS:
+      # When using C++ descriptors, all objects defined in the file were added
+      # to the C++ database when the FileDescriptor was built above.
+      # Just add them to this descriptor pool.
+      def _AddMessageDescriptor(message_desc):
+        self._descriptors[message_desc.full_name] = message_desc
+        for nested in message_desc.nested_types:
+          _AddMessageDescriptor(nested)
+        for enum_type in message_desc.enum_types:
+          _AddEnumDescriptor(enum_type)
+      def _AddEnumDescriptor(enum_desc):
+        self._enum_descriptors[enum_desc.full_name] = enum_desc
+      for message_type in file_descriptor.message_types_by_name.values():
+        _AddMessageDescriptor(message_type)
+      for enum_type in file_descriptor.enum_types_by_name.values():
+        _AddEnumDescriptor(enum_type)
+    else:
+      scope = {}
+      # This loop extracts all the message and enum types from all the
+      # dependencies of the file_proto. This is necessary to create the
+      # scope of available message types when defining the passed in
+      # file proto.
+      for dependency in built_deps:
+        scope.update(self._ExtractSymbols(
+            dependency.message_types_by_name.values()))
+        scope.update((_PrefixWithDot(enum.full_name), enum)
+                     for enum in dependency.enum_types_by_name.values())
+      for message_type in file_proto.message_type:
+        message_desc = self._ConvertMessageDescriptor(
+            message_type, file_proto.package, file_descriptor, scope,
+            file_proto.syntax)
+        file_descriptor.message_types_by_name[message_desc.name] = (
+            message_desc)
+      for enum_type in file_proto.enum_type:
+        file_descriptor.enum_types_by_name[enum_type.name] = (
+            self._ConvertEnumDescriptor(enum_type, file_proto.package,
+                                        file_descriptor, None, scope))
+      for index, extension_proto in enumerate(file_proto.extension):
+        extension_desc = self.MakeFieldDescriptor(
+            extension_proto, file_proto.package, index, is_extension=True)
+        extension_desc.containing_type = self._GetTypeFromScope(
+            file_descriptor.package, extension_proto.extendee, scope)
+        self.SetFieldType(extension_proto, extension_desc,
+                          file_descriptor.package, scope)
+        file_descriptor.extensions_by_name[extension_desc.name] = (
+            extension_desc)
+      for desc_proto in file_proto.message_type:
+        self.SetAllFieldTypes(file_proto.package, desc_proto, scope)
+      if file_proto.package:
+        desc_proto_prefix = _PrefixWithDot(file_proto.package)
+      else:
+        desc_proto_prefix = ''
+      for desc_proto in file_proto.message_type:
+        desc = self._GetTypeFromScope(
+            desc_proto_prefix, desc_proto.name, scope)
+        file_descriptor.message_types_by_name[desc_proto.name] = desc
     self.Add(file_proto)
     self._file_descriptors[file_proto.name] = file_descriptor
     return self._file_descriptors[file_proto.name]
   def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None,
-                                scope=None):
+                                scope=None, syntax=None):
     """Adds the proto to the pool in the specified package.

     Args:
@@ -349,7 +375,8 @@ class DescriptorPool(object):
       scope = {}

     nested = [
-        self._ConvertMessageDescriptor(nested, desc_name, file_desc, scope)
+        self._ConvertMessageDescriptor(
+            nested, desc_name, file_desc, scope, syntax)
         for nested in desc_proto.nested_type]
     enums = [
         self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope)
@@ -383,7 +410,8 @@ class DescriptorPool(object):
         extension_ranges=extension_ranges,
         file=file_desc,
         serialized_start=None,
-        serialized_end=None)
+        serialized_end=None,
+        syntax=syntax)
     for nested in desc.nested_types:
       nested.containing_type = desc
     for enum in desc.enum_types:
......
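With `_USE_C_DESCRIPTORS` set, descriptors for a file come straight out of the C++ descriptor database; otherwise the pure-Python conversion path runs as before, now threading `file_proto.syntax` through to the message descriptors. A minimal sketch of driving this through the public pool, mirroring the `descriptor_test.py` setUp later in this commit (the file and message names here are invented for illustration):

```python
# Hedged sketch (invented file/message names): register a FileDescriptorProto
# with the default pool, then look the resulting descriptors back up.
from google.protobuf import descriptor_pb2
from google.protobuf import symbol_database

file_proto = descriptor_pb2.FileDescriptorProto(
    name='example/demo.proto', package='example')
message_proto = file_proto.message_type.add(name='Demo')
message_proto.field.add(
    name='id', number=1,
    type=descriptor_pb2.FieldDescriptorProto.TYPE_INT32,
    label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL)

pool = symbol_database.Default().pool
pool.Add(file_proto)

file_desc = pool.FindFileByName('example/demo.proto')
demo_desc = file_desc.message_types_by_name['Demo']
print(demo_desc.full_name)                    # example.Demo
print(demo_desc.fields_by_name['id'].number)  # 1
```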
@@ -50,10 +50,7 @@ namespace python {
 // and
 //   PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2
 #ifdef PYTHON_PROTO2_CPP_IMPL_V1
-#if PY_MAJOR_VERSION >= 3
-#error "PYTHON_PROTO2_CPP_IMPL_V1 is not supported under Python 3."
-#endif
-static int kImplVersion = 1;
+#error "PYTHON_PROTO2_CPP_IMPL_V1 is no longer supported."
 #else
 #ifdef PYTHON_PROTO2_CPP_IMPL_V2
 static int kImplVersion = 2;
@@ -62,14 +59,7 @@ static int kImplVersion = 2;
 static int kImplVersion = 0;
 #else
-// The defaults are set here. Python 3 uses the fast C++ APIv2 by default.
-// Python 2 still uses the Python version by default until some compatibility
-// issues can be worked around.
-#if PY_MAJOR_VERSION >= 3
-static int kImplVersion = 2;
-#else
-static int kImplVersion = 0;
-#endif
+static int kImplVersion = -1;  // -1 means "Unspecified by compiler flags".
 #endif  // PYTHON_PROTO2_PYTHON_IMPL
 #endif  // PYTHON_PROTO2_CPP_IMPL_V2
......
@@ -40,14 +40,33 @@ try:
   # The compile-time constants in the _api_implementation module can be used to
   # switch to a certain implementation of the Python API at build time.
   _api_version = _api_implementation.api_version
-  del _api_implementation
+  _proto_extension_modules_exist_in_build = True
 except ImportError:
-  _api_version = 0
+  _api_version = -1  # Unspecified by compiler flags.
+  _proto_extension_modules_exist_in_build = False
+
+if _api_version == 1:
+  raise ValueError('api_version=1 is no longer supported.')
+if _api_version < 0:  # Still unspecified?
+  try:
+    # The presence of this module in a build allows the proto implementation to
+    # be upgraded merely via build deps rather than a compiler flag or the
+    # runtime environment variable.
+    # pylint: disable=g-import-not-at-top
+    from google.protobuf import _use_fast_cpp_protos
+    # Work around a known issue in the classic bootstrap .par import hook.
+    if not _use_fast_cpp_protos:
+      raise ImportError('_use_fast_cpp_protos import succeeded but was None')
+    del _use_fast_cpp_protos
+    _api_version = 2
+  except ImportError:
+    if _proto_extension_modules_exist_in_build:
+      if sys.version_info[0] >= 3:  # Python 3 defaults to C++ impl v2.
+        _api_version = 2
+      # TODO(b/17427486): Make Python 2 default to C++ impl v2.

 _default_implementation_type = (
-    'python' if _api_version == 0 else 'cpp')
-_default_version_str = (
-    '1' if _api_version <= 1 else '2')
+    'python' if _api_version <= 0 else 'cpp')

 # This environment variable can be used to switch to a certain implementation
 # of the Python API, overriding the compile-time constants in the
@@ -64,13 +83,12 @@ if _implementation_type != 'python':
   # _api_implementation module. Right now only 1 and 2 are valid values. Any other
   # value will be ignored.
   _implementation_version_str = os.getenv(
-      'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION',
-      _default_version_str)
+      'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', '2')

-  if _implementation_version_str not in ('1', '2'):
+  if _implementation_version_str != '2':
     raise ValueError(
-        "unsupported PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION: '" +
-        _implementation_version_str + "' (supported versions: 1, 2)"
+        'unsupported PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION: "' +
+        _implementation_version_str + '" (supported versions: 2)'
         )

   _implementation_version = int(_implementation_version_str)
......
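The selection order is now: an explicit compile-time `api_version`, then the optional `_use_fast_cpp_protos` build dependency, then a Python 3 default of the C++ implementation, and finally the environment variables, which still override everything. A hedged sketch of the runtime override (the variables must be set before the first `google.protobuf` import, as `descriptor_cpp2_test.py` below does, and the assertions assume the C++ extension modules were actually built):

```python
# Hedged sketch of forcing the implementation from the environment.
import os
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp'
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2'

# Must come after the environment variables are set.
from google.protobuf.internal import api_implementation

assert api_implementation.Type() == 'cpp'     # only if the extension was built
assert api_implementation.Version() == 2
```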
#! /usr/bin/python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Test that the api_implementation defaults are what we expect."""
import os
import sys
# Clear environment implementation settings before the google3 imports.
os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION', None)
os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', None)
# pylint: disable=g-import-not-at-top
from google.apputils import basetest
from google.protobuf.internal import api_implementation
class ApiImplementationDefaultTest(basetest.TestCase):

  if sys.version_info.major <= 2:

    def testThatPythonIsTheDefault(self):
      """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail."""
      self.assertEqual('python', api_implementation.Type())

  else:

    def testThatCppApiV2IsTheDefault(self):
      """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail."""
      self.assertEqual('cpp', api_implementation.Type())
      self.assertEqual(2, api_implementation.Version())


if __name__ == '__main__':
  basetest.main()
@@ -41,7 +41,6 @@ are:
 __author__ = 'petar@google.com (Petar Petrov)'


 class BaseContainer(object):

   """Base container class."""
@@ -119,15 +118,23 @@ class RepeatedScalarFieldContainer(BaseContainer):
     self._message_listener.Modified()

   def extend(self, elem_seq):
-    """Extends by appending the given sequence. Similar to list.extend()."""
-    if not elem_seq:
-      return
+    """Extends by appending the given iterable. Similar to list.extend()."""

-    new_values = []
-    for elem in elem_seq:
-      new_values.append(self._type_checker.CheckValue(elem))
-    self._values.extend(new_values)
-    self._message_listener.Modified()
+    if elem_seq is None:
+      return
+    try:
+      elem_seq_iter = iter(elem_seq)
+    except TypeError:
+      if not elem_seq:
+        # silently ignore falsy inputs :-/.
+        # TODO(ptucker): Deprecate this behavior. b/18413862
+        return
+      raise
+
+    new_values = [self._type_checker.CheckValue(elem) for elem in elem_seq_iter]
+    if new_values:
+      self._values.extend(new_values)
+    self._message_listener.Modified()

   def MergeFrom(self, other):
     """Appends the contents of another repeated field of the same type to this
@@ -141,6 +148,12 @@ class RepeatedScalarFieldContainer(BaseContainer):
     self._values.remove(elem)
     self._message_listener.Modified()

+  def pop(self, key=-1):
+    """Removes and returns an item at a given index. Similar to list.pop()."""
+    value = self._values[key]
+    self.__delitem__(key)
+    return value
+
   def __setitem__(self, key, value):
     """Sets the item on the specified position."""
     if isinstance(key, slice):  # PY3
@@ -245,6 +258,12 @@ class RepeatedCompositeFieldContainer(BaseContainer):
     self._values.remove(elem)
     self._message_listener.Modified()

+  def pop(self, key=-1):
+    """Removes and returns an item at a given index. Similar to list.pop()."""
+    value = self._values[key]
+    self.__delitem__(key)
+    return value
+
   def __getslice__(self, start, stop):
     """Retrieves the subset of items from between the specified indices."""
     return self._values[start:stop]
......
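`extend()` now distinguishes a non-iterable argument (a `TypeError`) from a falsy one (still silently ignored, with a TODO to deprecate that), and both repeated containers gain `pop()`. A short sketch of the resulting behavior using `TestAllTypes.repeated_int32` from the unit-test protos referenced throughout this commit:

```python
# Sketch of the new repeated-field container behavior.
from google.protobuf import unittest_pb2

msg = unittest_pb2.TestAllTypes()
msg.repeated_int32.extend([1, 2, 3])   # values are type-checked, then appended
msg.repeated_int32.extend([])          # falsy input is still silently ignored
try:
  msg.repeated_int32.extend(42)        # non-iterable input now raises TypeError
except TypeError:
  pass

assert msg.repeated_int32.pop() == 3   # new: pop() defaults to the last element
assert msg.repeated_int32.pop(0) == 1  # or takes an index, like list.pop()
assert list(msg.repeated_int32) == [2]
```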
@@ -621,9 +621,6 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
       if value is None:
         value = field_dict.setdefault(key, new_default(message))
       while 1:
-        value = field_dict.get(key)
-        if value is None:
-          value = field_dict.setdefault(key, new_default(message))
         # Read length.
         (size, pos) = local_DecodeVarint(buffer, pos)
         new_pos = pos + size
......
@@ -34,12 +34,16 @@
 __author__ = 'robinson@google.com (Will Robinson)'

+import sys
+
 from google.apputils import basetest
 from google.protobuf import unittest_custom_options_pb2
 from google.protobuf import unittest_import_pb2
 from google.protobuf import unittest_pb2
 from google.protobuf import descriptor_pb2
+from google.protobuf.internal import api_implementation
 from google.protobuf import descriptor
+from google.protobuf import symbol_database
 from google.protobuf import text_format
@@ -51,41 +55,28 @@ name: 'TestEmptyMessage'
 class DescriptorTest(basetest.TestCase):

   def setUp(self):
-    self.my_file = descriptor.FileDescriptor(
+    file_proto = descriptor_pb2.FileDescriptorProto(
         name='some/filename/some.proto',
-        package='protobuf_unittest'
-        )
-    self.my_enum = descriptor.EnumDescriptor(
-        name='ForeignEnum',
-        full_name='protobuf_unittest.ForeignEnum',
-        filename=None,
-        file=self.my_file,
-        values=[
-            descriptor.EnumValueDescriptor(name='FOREIGN_FOO', index=0, number=4),
-            descriptor.EnumValueDescriptor(name='FOREIGN_BAR', index=1, number=5),
-            descriptor.EnumValueDescriptor(name='FOREIGN_BAZ', index=2, number=6),
-        ])
-    self.my_message = descriptor.Descriptor(
-        name='NestedMessage',
-        full_name='protobuf_unittest.TestAllTypes.NestedMessage',
-        filename=None,
-        file=self.my_file,
-        containing_type=None,
-        fields=[
-            descriptor.FieldDescriptor(
-                name='bb',
-                full_name='protobuf_unittest.TestAllTypes.NestedMessage.bb',
-                index=0, number=1,
-                type=5, cpp_type=1, label=1,
-                has_default_value=False, default_value=0,
-                message_type=None, enum_type=None, containing_type=None,
-                is_extension=False, extension_scope=None),
-        ],
-        nested_types=[],
-        enum_types=[
-            self.my_enum,
-        ],
-        extensions=[])
+        package='protobuf_unittest')
+    message_proto = file_proto.message_type.add(
+        name='NestedMessage')
+    message_proto.field.add(
+        name='bb',
+        number=1,
+        type=descriptor_pb2.FieldDescriptorProto.TYPE_INT32,
+        label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL)
+    enum_proto = message_proto.enum_type.add(
+        name='ForeignEnum')
+    enum_proto.value.add(name='FOREIGN_FOO', number=4)
+    enum_proto.value.add(name='FOREIGN_BAR', number=5)
+    enum_proto.value.add(name='FOREIGN_BAZ', number=6)
+
+    descriptor_pool = symbol_database.Default().pool
+    descriptor_pool.Add(file_proto)
+    self.my_file = descriptor_pool.FindFileByName(file_proto.name)
+    self.my_message = self.my_file.message_types_by_name[message_proto.name]
+    self.my_enum = self.my_message.enum_types_by_name[enum_proto.name]
+
     self.my_method = descriptor.MethodDescriptor(
         name='Bar',
         full_name='protobuf_unittest.TestService.Bar',
@@ -173,6 +164,11 @@ class DescriptorTest(basetest.TestCase):
     self.assertEqual(unittest_custom_options_pb2.METHODOPT1_VAL2,
                      method_options.Extensions[method_opt1])

+    message_descriptor = (
+        unittest_custom_options_pb2.DummyMessageContainingEnum.DESCRIPTOR)
+    self.assertTrue(file_descriptor.has_options)
+    self.assertFalse(message_descriptor.has_options)
+
   def testDifferentCustomOptionTypes(self):
     kint32min = -2**31
     kint64min = -2**63
@@ -394,6 +390,108 @@ class DescriptorTest(basetest.TestCase):
    self.assertEqual(self.my_file.name, 'some/filename/some.proto')
    self.assertEqual(self.my_file.package, 'protobuf_unittest')
@basetest.unittest.skipIf(
api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
'Immutability of descriptors is only enforced in v2 implementation')
def testImmutableCppDescriptor(self):
message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
with self.assertRaises(AttributeError):
message_descriptor.fields_by_name = None
with self.assertRaises(TypeError):
message_descriptor.fields_by_name['Another'] = None
with self.assertRaises(TypeError):
message_descriptor.fields.append(None)
class GeneratedDescriptorTest(basetest.TestCase):
"""Tests for the properties of descriptors in generated code."""
def CheckMessageDescriptor(self, message_descriptor):
# Basic properties
self.assertEqual(message_descriptor.name, 'TestAllTypes')
self.assertEqual(message_descriptor.full_name,
'protobuf_unittest.TestAllTypes')
# Test equality and hashability
self.assertEqual(message_descriptor, message_descriptor)
self.assertEqual(message_descriptor.fields[0].containing_type,
message_descriptor)
self.assertIn(message_descriptor, [message_descriptor])
self.assertIn(message_descriptor, {message_descriptor: None})
# Test field containers
self.CheckDescriptorSequence(message_descriptor.fields)
self.CheckDescriptorMapping(message_descriptor.fields_by_name)
self.CheckDescriptorMapping(message_descriptor.fields_by_number)
def CheckFieldDescriptor(self, field_descriptor):
# Basic properties
self.assertEqual(field_descriptor.name, 'optional_int32')
self.assertEqual(field_descriptor.full_name,
'protobuf_unittest.TestAllTypes.optional_int32')
self.assertEqual(field_descriptor.containing_type.name, 'TestAllTypes')
# Test equality and hashability
self.assertEqual(field_descriptor, field_descriptor)
self.assertEqual(
field_descriptor.containing_type.fields_by_name['optional_int32'],
field_descriptor)
self.assertIn(field_descriptor, [field_descriptor])
self.assertIn(field_descriptor, {field_descriptor: None})
def CheckDescriptorSequence(self, sequence):
# Verifies that a property like 'messageDescriptor.fields' has all the
# properties of an immutable abc.Sequence.
self.assertGreater(len(sequence), 0) # Sized
self.assertEqual(len(sequence), len(list(sequence))) # Iterable
item = sequence[0]
self.assertEqual(item, sequence[0])
self.assertIn(item, sequence) # Container
self.assertEqual(sequence.index(item), 0)
self.assertEqual(sequence.count(item), 1)
reversed_iterator = reversed(sequence)
self.assertEqual(list(reversed_iterator), list(sequence)[::-1])
self.assertRaises(StopIteration, next, reversed_iterator)
def CheckDescriptorMapping(self, mapping):
# Verifies that a property like 'messageDescriptor.fields' has all the
# properties of an immutable abc.Mapping.
self.assertGreater(len(mapping), 0) # Sized
self.assertEqual(len(mapping), len(list(mapping))) # Iterable
if sys.version_info.major >= 3:
key, item = next(iter(mapping.items()))
else:
key, item = mapping.items()[0]
self.assertIn(key, mapping) # Container
self.assertEqual(mapping.get(key), item)
# keys(), iterkeys() &co
item = (next(iter(mapping.keys())), next(iter(mapping.values())))
self.assertEqual(item, next(iter(mapping.items())))
if sys.version_info.major < 3:
def CheckItems(seq, iterator):
self.assertEqual(next(iterator), seq[0])
self.assertEqual(list(iterator), seq[1:])
CheckItems(mapping.keys(), mapping.iterkeys())
CheckItems(mapping.values(), mapping.itervalues())
CheckItems(mapping.items(), mapping.iteritems())
def testDescriptor(self):
message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
self.CheckMessageDescriptor(message_descriptor)
field_descriptor = message_descriptor.fields_by_name['optional_int32']
self.CheckFieldDescriptor(field_descriptor)
def testCppDescriptorContainer(self):
# Check that the collection is still valid even if the parent disappeared.
enum = unittest_pb2.TestAllTypes.DESCRIPTOR.enum_types_by_name['NestedEnum']
values = enum.values
del enum
self.assertEqual('FOO', values[0].name)
def testCppDescriptorContainer_Iterator(self):
# Same test with the iterator
enum = unittest_pb2.TestAllTypes.DESCRIPTOR.enum_types_by_name['NestedEnum']
values_iter = iter(enum.values)
del enum
self.assertEqual('FOO', next(values_iter).name)
class DescriptorCopyToProtoTest(basetest.TestCase):
  """Tests for CopyTo functions of Descriptor."""
@@ -588,10 +686,12 @@ class DescriptorCopyToProtoTest(basetest.TestCase):
       output_type: '.protobuf_unittest.BarResponse'
     >
     """
-    self._InternalTestCopyToProto(
-        unittest_pb2.TestService.DESCRIPTOR,
-        descriptor_pb2.ServiceDescriptorProto,
-        TEST_SERVICE_ASCII)
+    # TODO(rocking): enable this test after the proto descriptor change is
+    # checked in.
+    #self._InternalTestCopyToProto(
+    #    unittest_pb2.TestService.DESCRIPTOR,
+    #    descriptor_pb2.ServiceDescriptorProto,
+    #    TEST_SERVICE_ASCII)


 class MakeDescriptorTest(basetest.TestCase):
......
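A short sketch of the descriptor container properties that the new `GeneratedDescriptorTest` checks (sequence-style `fields`, mapping-style `fields_by_name` and `fields_by_number`); it should behave the same under either implementation:

```python
# Sketch of the descriptor container properties exercised above.
from google.protobuf import unittest_pb2

md = unittest_pb2.TestAllTypes.DESCRIPTOR
fd = md.fields_by_name['optional_int32']

assert fd.containing_type == md               # fields know their containing type
assert md.fields.count(fd) == 1               # 'fields' behaves like a sequence
assert md.fields[md.fields.index(fd)] == fd
assert md.fields_by_number[fd.number] == fd   # mapping keyed by field number
```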
@@ -219,12 +219,20 @@ def _AttachFieldHelpers(cls, field_descriptor):
   def AddDecoder(wiretype, is_packed):
     tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
-    cls._decoders_by_tag[tag_bytes] = (
-        type_checkers.TYPE_TO_DECODER[field_descriptor.type](
-            field_descriptor.number, is_repeated, is_packed,
-            field_descriptor, field_descriptor._default_constructor),
-        field_descriptor if field_descriptor.containing_oneof is not None
-        else None)
+    decode_type = field_descriptor.type
+    if (decode_type == _FieldDescriptor.TYPE_ENUM and
+        type_checkers.SupportsOpenEnums(field_descriptor)):
+      decode_type = _FieldDescriptor.TYPE_INT32
+
+    oneof_descriptor = None
+    if field_descriptor.containing_oneof is not None:
+      oneof_descriptor = field_descriptor
+
+    field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
+        field_descriptor.number, is_repeated, is_packed,
+        field_descriptor, field_descriptor._default_constructor)
+
+    cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)

   AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
              False)
@@ -296,6 +304,8 @@ def _DefaultValueConstructorForField(field):
     def MakeSubMessageDefault(message):
       result = message_type._concrete_class()
       result._SetListener(message._listener_for_children)
+      if field.containing_oneof:
+        message._UpdateOneofState(field)
       return result
     return MakeSubMessageDefault
@@ -476,6 +486,7 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
   type_checker = type_checkers.GetTypeChecker(field)
   default_value = field.default_value
   valid_values = set()
+  is_proto3 = field.containing_type.syntax == "proto3"

   def getter(self):
     # TODO(protobuf-team): This may be broken since there may not be
@@ -483,15 +494,24 @@
     return self._fields.get(field, default_value)
   getter.__module__ = None
   getter.__doc__ = 'Getter for %s.' % proto_field_name
+
+  clear_when_set_to_default = is_proto3 and not field.containing_oneof
+
   def field_setter(self, new_value):
     # pylint: disable=protected-access
-    self._fields[field] = type_checker.CheckValue(new_value)
+    # Testing the value for truthiness captures all of the proto3 defaults
+    # (0, 0.0, enum 0, and False).
+    new_value = type_checker.CheckValue(new_value)
+    if clear_when_set_to_default and not new_value:
+      self._fields.pop(field, None)
+    else:
+      self._fields[field] = new_value
     # Check _cached_byte_size_dirty inline to improve performance, since scalar
     # setters are called frequently.
     if not self._cached_byte_size_dirty:
       self._Modified()

-  if field.containing_oneof is not None:
+  if field.containing_oneof:
     def setter(self, new_value):
       field_setter(self, new_value)
       self._UpdateOneofState(field)
@@ -624,24 +644,35 @@ def _AddListFieldsMethod(message_descriptor, cls):
   cls.ListFields = ListFields

+_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"'
+_Proto2HasError = 'Protocol message has no non-repeated field "%s"'
+
+
 def _AddHasFieldMethod(message_descriptor, cls):
   """Helper for _AddMessageMethods()."""

-  singular_fields = {}
+  is_proto3 = (message_descriptor.syntax == "proto3")
+  error_msg = _Proto3HasError if is_proto3 else _Proto2HasError
+
+  hassable_fields = {}
   for field in message_descriptor.fields:
-    if field.label != _FieldDescriptor.LABEL_REPEATED:
-      singular_fields[field.name] = field
-  # Fields inside oneofs are never repeated (enforced by the compiler).
-  for field in message_descriptor.oneofs:
-    singular_fields[field.name] = field
+    if field.label == _FieldDescriptor.LABEL_REPEATED:
+      continue
+    # For proto3, only submessages and fields inside a oneof have presence.
+    if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and
+        not field.containing_oneof):
+      continue
+    hassable_fields[field.name] = field
+
+  if not is_proto3:
+    # Fields inside oneofs are never repeated (enforced by the compiler).
+    for oneof in message_descriptor.oneofs:
+      hassable_fields[oneof.name] = oneof

   def HasField(self, field_name):
     try:
-      field = singular_fields[field_name]
+      field = hassable_fields[field_name]
     except KeyError:
-      raise ValueError(
-          'Protocol message has no singular "%s" field.' % field_name)
+      raise ValueError(error_msg % field_name)

     if isinstance(field, descriptor_mod.OneofDescriptor):
       try:
@@ -871,6 +902,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
   local_ReadTag = decoder.ReadTag
   local_SkipField = decoder.SkipField
   decoders_by_tag = cls._decoders_by_tag
+  is_proto3 = message_descriptor.syntax == "proto3"

   def InternalParse(self, buffer, pos, end):
     self._Modified()
@@ -884,9 +916,11 @@
         new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
         if new_pos == -1:
           return pos
-        if not unknown_field_list:
-          unknown_field_list = self._unknown_fields = []
-        unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
+        if not is_proto3:
+          if not unknown_field_list:
+            unknown_field_list = self._unknown_fields = []
+          unknown_field_list.append(
+              (tag_bytes, buffer[value_start_pos:new_pos]))
         pos = new_pos
       else:
         pos = field_decoder(buffer, new_pos, end, self, field_dict)
@@ -1008,6 +1042,8 @@ def _AddMergeFromMethod(cls):
           # Construct a new object to represent this field.
           field_value = field._default_constructor(self)
           fields[field] = field_value
+          if field.containing_oneof:
+            self._UpdateOneofState(field)
         field_value.MergeFrom(value)
       else:
         self._fields[field] = value
@@ -1252,11 +1288,10 @@ class _ExtensionDict(object):
     # It's slightly wasteful to lookup the type checker each time,
     # but we expect this to be a vanishingly uncommon case anyway.
-    type_checker = type_checkers.GetTypeChecker(
-        extension_handle)
+    type_checker = type_checkers.GetTypeChecker(extension_handle)
     # pylint: disable=protected-access
     self._extended_message._fields[extension_handle] = (
         type_checker.CheckValue(value))
     self._extended_message._Modified()

   def _FindExtensionByName(self, name):
......
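Taken together, the `python_message.py` changes give proto3 messages their expected presence semantics: `HasField` only accepts submessage and oneof fields, and setting a scalar to its default value clears it. A behavior sketch against a hypothetical proto3-generated module `demo_pb2` (not part of this commit):

```python
# Hedged behavior sketch; demo_pb2 is a hypothetical proto3-generated module:
#   message Demo { int32 count = 1; Inner inner = 2; oneof kind { string name = 3; } }
msg = demo_pb2.Demo()

msg.HasField('inner')    # allowed: submessage fields keep explicit presence
msg.HasField('name')     # allowed: fields inside a oneof keep presence
try:
  msg.HasField('count')  # plain proto3 scalars have no presence
except ValueError:
  pass

msg.count = 5
msg.count = 0            # setting a proto3 scalar back to its default clears it,
                         # so it is neither serialized nor returned by ListFields()
```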
@@ -1792,6 +1792,27 @@ class ReflectionTest(basetest.TestCase):
     # Just check the default value.
     self.assertEqual(57, msg.inner.value)

+  @basetest.unittest.skipIf(
+      api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
+      'CPPv2-specific test')
+  def testBadArguments(self):
+    # Some of these assertions used to segfault.
+    from google.protobuf.pyext import _message
+    self.assertRaises(TypeError, _message.Message._GetFieldDescriptor, 3)
+    self.assertRaises(TypeError, _message.Message._GetExtensionDescriptor, 42)
+    self.assertRaises(TypeError,
+                      unittest_pb2.TestAllTypes().__getattribute__, 42)
+
+  @basetest.unittest.skipIf(
+      api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
+      'CPPv2-specific test')
+  def testRosyHack(self):
+    from google.protobuf.pyext import _message
+    from google3.gdata.rosy.proto import core_api2_pb2
+    from google3.gdata.rosy.proto import core_pb2
+    self.assertEqual(_message.Message, core_pb2.PageSelection.__base__)
+    self.assertEqual(_message.Message, core_api2_pb2.PageSelection.__base__)
+

 # Since we had so many tests for protocol buffer equality, we broke these out
 # into separate TestCase classes.
......
@@ -59,6 +59,8 @@ from google.protobuf import descriptor
 _FieldDescriptor = descriptor.FieldDescriptor

+def SupportsOpenEnums(field_descriptor):
+  return field_descriptor.containing_type.syntax == "proto3"

 def GetTypeChecker(field):
   """Returns a type checker for a message field of the specified types.
@@ -74,7 +76,11 @@ def GetTypeChecker(field):
       field.type == _FieldDescriptor.TYPE_STRING):
     return UnicodeValueChecker()
   if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
-    return EnumValueChecker(field.enum_type)
+    if SupportsOpenEnums(field):
+      # When open enums are supported, any int32 can be assigned.
+      return _VALUE_CHECKERS[_FieldDescriptor.CPPTYPE_INT32]
+    else:
+      return EnumValueChecker(field.enum_type)
   return _VALUE_CHECKERS[field.cpp_type]
......
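`SupportsOpenEnums` makes proto3 enum fields "open": assignment is checked as a plain int32 rather than against the declared enum values. A sketch of the difference, with `proto2_pb2` and `proto3_pb2` standing in for hypothetical generated modules:

```python
# Hedged sketch; proto2_pb2 / proto3_pb2 are hypothetical generated modules
# whose Msg has an enum field `color` with declared values 0..2.
p3 = proto3_pb2.Msg()
p3.color = 12345    # accepted: proto3 enum fields take any int32 value

p2 = proto2_pb2.Msg()
try:
  p2.color = 12345  # rejected: proto2 enums stay closed to unknown values
except ValueError:
  pass
```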
@@ -38,6 +38,7 @@ __author__ = 'bohdank@google.com (Bohdan Koval)'
 from google.apputils import basetest
 from google.protobuf import unittest_mset_pb2
 from google.protobuf import unittest_pb2
+from google.protobuf import unittest_proto3_arena_pb2
 from google.protobuf.internal import api_implementation
 from google.protobuf.internal import encoder
 from google.protobuf.internal import missing_enum_values_pb2
@@ -45,10 +46,81 @@ from google.protobuf.internal import test_util
 from google.protobuf.internal import type_checkers
class UnknownFieldsTest(basetest.TestCase):
def setUp(self):
self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
self.all_fields = unittest_pb2.TestAllTypes()
test_util.SetAllFields(self.all_fields)
self.all_fields_data = self.all_fields.SerializeToString()
self.empty_message = unittest_pb2.TestEmptyMessage()
self.empty_message.ParseFromString(self.all_fields_data)
def testSerialize(self):
data = self.empty_message.SerializeToString()
# Don't use assertEqual because we don't want to dump raw binary data to
# stdout.
self.assertTrue(data == self.all_fields_data)
def testSerializeProto3(self):
# Verify that proto3 doesn't preserve unknown fields.
message = unittest_proto3_arena_pb2.TestEmptyMessage()
message.ParseFromString(self.all_fields_data)
self.assertEqual(0, len(message.SerializeToString()))
def testByteSize(self):
self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())
def testListFields(self):
# Make sure ListFields doesn't return unknown fields.
self.assertEqual(0, len(self.empty_message.ListFields()))
def testSerializeMessageSetWireFormatUnknownExtension(self):
# Create a message using the message set wire format with an unknown
# message.
raw = unittest_mset_pb2.RawMessageSet()
# Add an unknown extension.
item = raw.item.add()
item.type_id = 1545009
message1 = unittest_mset_pb2.TestMessageSetExtension1()
message1.i = 12345
item.message = message1.SerializeToString()
serialized = raw.SerializeToString()
# Parse message using the message set wire format.
proto = unittest_mset_pb2.TestMessageSet()
proto.MergeFromString(serialized)
# Verify that the unknown extension is serialized unchanged
reserialized = proto.SerializeToString()
new_raw = unittest_mset_pb2.RawMessageSet()
new_raw.MergeFromString(reserialized)
self.assertEqual(raw, new_raw)
# C++ implementation for proto2 does not currently take into account unknown
# fields when checking equality.
#
# TODO(haberman): fix this.
@basetest.unittest.skipIf(
api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
'C++ implementation does not expose unknown fields to Python')
def testEquals(self):
message = unittest_pb2.TestEmptyMessage()
message.ParseFromString(self.all_fields_data)
self.assertEqual(self.empty_message, message)
self.all_fields.ClearField('optional_string')
message.ParseFromString(self.all_fields.SerializeToString())
self.assertNotEqual(self.empty_message, message)
 @basetest.unittest.skipIf(
     api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
     'C++ implementation does not expose unknown fields to Python')
-class UnknownFieldsTest(basetest.TestCase):
+class UnknownFieldsAccessorsTest(basetest.TestCase):

   def setUp(self):
     self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
@@ -98,13 +170,6 @@ class UnknownFieldsTest(basetest.TestCase):
     value = self.GetField('optionalgroup')
     self.assertEqual(self.all_fields.optionalgroup, value)

-  def testSerialize(self):
-    data = self.empty_message.SerializeToString()
-    # Don't use assertEqual because we don't want to dump raw binary data to
-    # stdout.
-    self.assertTrue(data == self.all_fields_data)
-
   def testCopyFrom(self):
     message = unittest_pb2.TestEmptyMessage()
     message.CopyFrom(self.empty_message)
@@ -132,51 +197,12 @@ class UnknownFieldsTest(basetest.TestCase):
     self.empty_message.Clear()
     self.assertEqual(0, len(self.empty_message._unknown_fields))

-  def testByteSize(self):
-    self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())
-
   def testUnknownExtensions(self):
     message = unittest_pb2.TestEmptyMessageWithExtensions()
     message.ParseFromString(self.all_fields_data)
     self.assertEqual(self.empty_message._unknown_fields,
                      message._unknown_fields)

-  def testListFields(self):
-    # Make sure ListFields doesn't return unknown fields.
-    self.assertEqual(0, len(self.empty_message.ListFields()))
-
-  def testSerializeMessageSetWireFormatUnknownExtension(self):
-    # Create a message using the message set wire format with an unknown
-    # message.
-    raw = unittest_mset_pb2.RawMessageSet()
-
-    # Add an unknown extension.
-    item = raw.item.add()
-    item.type_id = 1545009
-    message1 = unittest_mset_pb2.TestMessageSetExtension1()
-    message1.i = 12345
-    item.message = message1.SerializeToString()
-
-    serialized = raw.SerializeToString()
-
-    # Parse message using the message set wire format.
-    proto = unittest_mset_pb2.TestMessageSet()
-    proto.MergeFromString(serialized)
-
-    # Verify that the unknown extension is serialized unchanged
-    reserialized = proto.SerializeToString()
-    new_raw = unittest_mset_pb2.RawMessageSet()
-    new_raw.MergeFromString(reserialized)
-    self.assertEqual(raw, new_raw)
-
-  def testEquals(self):
-    message = unittest_pb2.TestEmptyMessage()
-    message.ParseFromString(self.all_fields_data)
-    self.assertEqual(self.empty_message, message)
-
-    self.all_fields.ClearField('optional_string')
-    message.ParseFromString(self.all_fields.SerializeToString())
-    self.assertNotEqual(self.empty_message, message)
-

 @basetest.unittest.skipIf(
......
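The unknown-field tests are now split: behaviors that hold in both implementations move to the new `UnknownFieldsTest` (including the proto3 check), while the `_unknown_fields` accessors remain in `UnknownFieldsAccessorsTest`, which is still skipped under CPPv2. A condensed sketch of the proto2 vs proto3 difference being tested:

```python
# Condensed sketch of the behavior covered by testSerialize/testSerializeProto3.
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf.internal import test_util

full = unittest_pb2.TestAllTypes()
test_util.SetAllFields(full)
data = full.SerializeToString()

empty2 = unittest_pb2.TestEmptyMessage()
empty2.ParseFromString(data)
assert empty2.SerializeToString() == data     # proto2 keeps unknown fields

empty3 = unittest_proto3_arena_pb2.TestEmptyMessage()
empty3.ParseFromString(data)
assert len(empty3.SerializeToString()) == 0   # proto3 drops them on parse
```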
#! /usr/bin/python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests for google.protobuf.pyext behavior."""
__author__ = 'anuraag@google.com (Anuraag Agrawal)'
import os
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp'
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2'
# We must set the implementation version above before the google3 imports.
# pylint: disable=g-import-not-at-top
from google.apputils import basetest
from google.protobuf.internal import api_implementation
# Run all tests from the original module by putting them in our namespace.
# pylint: disable=wildcard-import
from google.protobuf.internal.descriptor_test import *
class ConfirmCppApi2Test(basetest.TestCase):

  def testImplementationSetting(self):
    self.assertEqual('cpp', api_implementation.Type())
    self.assertEqual(2, api_implementation.Version())


if __name__ == '__main__':
  basetest.main()
@@ -319,6 +319,11 @@ def _MergeField(tokenizer, message, allow_multiple_scalars):
     ParseError: In case of ASCII parsing problems.
   """
   message_descriptor = message.DESCRIPTOR
+  if (hasattr(message_descriptor, 'syntax') and
+      message_descriptor.syntax == 'proto3'):
+    # Proto3 doesn't represent presence so we can't test if multiple
+    # scalars have occurred. We have to allow them.
+    allow_multiple_scalars = True
   if tokenizer.TryConsume('['):
     name = [tokenizer.ConsumeIdentifier()]
     while tokenizer.TryConsume('.'):
......
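Because proto3 fields have no presence bit, the text parser can no longer detect that a scalar was given twice, so duplicates are now accepted for proto3 messages and the last value wins. A hedged sketch, with `demo_pb2` as a hypothetical proto3-generated module:

```python
# Hedged sketch; demo_pb2 is a hypothetical proto3-generated module with a
# message Demo containing a scalar field `count`.
from google.protobuf import text_format

msg = demo_pb2.Demo()
assert msg.DESCRIPTOR.syntax == 'proto3'     # the attribute the parser checks

text_format.Merge('count: 1 count: 2', msg)  # duplicate scalars are accepted
assert msg.count == 2                        # the value parsed last is kept
```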