Commit 3ff56252 authored by Joshua Haberman's avatar Joshua Haberman

Merge pull request #722 from dano/py2_py3_straddle

Add tox, Python 2.6 compatibility, and many Python 3 compatibility fixes
parents eb65c69e 46969b99
...@@ -55,6 +55,7 @@ src/google/protobuf/util/**/*.pb.h ...@@ -55,6 +55,7 @@ src/google/protobuf/util/**/*.pb.h
*_pb2.py *_pb2.py
python/*.egg python/*.egg
python/.eggs/ python/.eggs/
python/.tox
python/build/ python/build/
python/google/protobuf/compiler/ python/google/protobuf/compiler/
......
...@@ -538,7 +538,6 @@ python_EXTRA_DIST= \ ...@@ -538,7 +538,6 @@ python_EXTRA_DIST= \
python/google/protobuf/text_format.py \ python/google/protobuf/text_format.py \
python/google/protobuf/__init__.py \ python/google/protobuf/__init__.py \
python/google/__init__.py \ python/google/__init__.py \
python/ez_setup.py \
python/setup.py \ python/setup.py \
python/mox.py \ python/mox.py \
python/stubout.py \ python/stubout.py \
......
#!python
# This file was obtained from:
# http://peak.telecommunity.com/dist/ez_setup.py
# on 2011/1/21.
"""Bootstrap setuptools installation
If you want to use setuptools in your package's setup.py, just include this
file in the same directory with it, and add this to the top of your setup.py::
from ez_setup import use_setuptools
use_setuptools()
If you want to require a specific version of setuptools, set a download
mirror, or use an alternate download directory, you can do so by supplying
the appropriate options to ``use_setuptools()``.
This file can also be run as a script to install or upgrade setuptools.
"""
import sys
DEFAULT_VERSION = "0.6c11"
DEFAULT_URL = "http://pypi.python.org/packages/%s/s/setuptools/" % sys.version[:3]
md5_data = {
'setuptools-0.6b1-py2.3.egg': '8822caf901250d848b996b7f25c6e6ca',
'setuptools-0.6b1-py2.4.egg': 'b79a8a403e4502fbb85ee3f1941735cb',
'setuptools-0.6b2-py2.3.egg': '5657759d8a6d8fc44070a9d07272d99b',
'setuptools-0.6b2-py2.4.egg': '4996a8d169d2be661fa32a6e52e4f82a',
'setuptools-0.6b3-py2.3.egg': 'bb31c0fc7399a63579975cad9f5a0618',
'setuptools-0.6b3-py2.4.egg': '38a8c6b3d6ecd22247f179f7da669fac',
'setuptools-0.6b4-py2.3.egg': '62045a24ed4e1ebc77fe039aa4e6f7e5',
'setuptools-0.6b4-py2.4.egg': '4cb2a185d228dacffb2d17f103b3b1c4',
'setuptools-0.6c1-py2.3.egg': 'b3f2b5539d65cb7f74ad79127f1a908c',
'setuptools-0.6c1-py2.4.egg': 'b45adeda0667d2d2ffe14009364f2a4b',
'setuptools-0.6c10-py2.3.egg': 'ce1e2ab5d3a0256456d9fc13800a7090',
'setuptools-0.6c10-py2.4.egg': '57d6d9d6e9b80772c59a53a8433a5dd4',
'setuptools-0.6c10-py2.5.egg': 'de46ac8b1c97c895572e5e8596aeb8c7',
'setuptools-0.6c10-py2.6.egg': '58ea40aef06da02ce641495523a0b7f5',
'setuptools-0.6c11-py2.3.egg': '2baeac6e13d414a9d28e7ba5b5a596de',
'setuptools-0.6c11-py2.4.egg': 'bd639f9b0eac4c42497034dec2ec0c2b',
'setuptools-0.6c11-py2.5.egg': '64c94f3bf7a72a13ec83e0b24f2749b2',
'setuptools-0.6c11-py2.6.egg': 'bfa92100bd772d5a213eedd356d64086',
'setuptools-0.6c2-py2.3.egg': 'f0064bf6aa2b7d0f3ba0b43f20817c27',
'setuptools-0.6c2-py2.4.egg': '616192eec35f47e8ea16cd6a122b7277',
'setuptools-0.6c3-py2.3.egg': 'f181fa125dfe85a259c9cd6f1d7b78fa',
'setuptools-0.6c3-py2.4.egg': 'e0ed74682c998bfb73bf803a50e7b71e',
'setuptools-0.6c3-py2.5.egg': 'abef16fdd61955514841c7c6bd98965e',
'setuptools-0.6c4-py2.3.egg': 'b0b9131acab32022bfac7f44c5d7971f',
'setuptools-0.6c4-py2.4.egg': '2a1f9656d4fbf3c97bf946c0a124e6e2',
'setuptools-0.6c4-py2.5.egg': '8f5a052e32cdb9c72bcf4b5526f28afc',
'setuptools-0.6c5-py2.3.egg': 'ee9fd80965da04f2f3e6b3576e9d8167',
'setuptools-0.6c5-py2.4.egg': 'afe2adf1c01701ee841761f5bcd8aa64',
'setuptools-0.6c5-py2.5.egg': 'a8d3f61494ccaa8714dfed37bccd3d5d',
'setuptools-0.6c6-py2.3.egg': '35686b78116a668847237b69d549ec20',
'setuptools-0.6c6-py2.4.egg': '3c56af57be3225019260a644430065ab',
'setuptools-0.6c6-py2.5.egg': 'b2f8a7520709a5b34f80946de5f02f53',
'setuptools-0.6c7-py2.3.egg': '209fdf9adc3a615e5115b725658e13e2',
'setuptools-0.6c7-py2.4.egg': '5a8f954807d46a0fb67cf1f26c55a82e',
'setuptools-0.6c7-py2.5.egg': '45d2ad28f9750e7434111fde831e8372',
'setuptools-0.6c8-py2.3.egg': '50759d29b349db8cfd807ba8303f1902',
'setuptools-0.6c8-py2.4.egg': 'cba38d74f7d483c06e9daa6070cce6de',
'setuptools-0.6c8-py2.5.egg': '1721747ee329dc150590a58b3e1ac95b',
'setuptools-0.6c9-py2.3.egg': 'a83c4020414807b496e4cfbe08507c03',
'setuptools-0.6c9-py2.4.egg': '260a2be2e5388d66bdaee06abec6342a',
'setuptools-0.6c9-py2.5.egg': 'fe67c3e5a17b12c0e7c541b7ea43a8e6',
'setuptools-0.6c9-py2.6.egg': 'ca37b1ff16fa2ede6e19383e7b59245a',
}
import sys, os
try: from hashlib import md5
except ImportError: from md5 import md5
def _validate_md5(egg_name, data):
if egg_name in md5_data:
digest = md5(data).hexdigest()
if digest != md5_data[egg_name]:
print >>sys.stderr, (
"md5 validation of %s failed! (Possible download problem?)"
% egg_name
)
sys.exit(2)
return data
def use_setuptools(
version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir,
download_delay=15
):
"""Automatically find/download setuptools and make it available on sys.path
`version` should be a valid setuptools version number that is available
as an egg for download under the `download_base` URL (which should end with
a '/'). `to_dir` is the directory where setuptools will be downloaded, if
it is not already available. If `download_delay` is specified, it should
be the number of seconds that will be paused before initiating a download,
should one be required. If an older version of setuptools is installed,
this routine will print a message to ``sys.stderr`` and raise SystemExit in
an attempt to abort the calling script.
"""
was_imported = 'pkg_resources' in sys.modules or 'setuptools' in sys.modules
def do_download():
egg = download_setuptools(version, download_base, to_dir, download_delay)
sys.path.insert(0, egg)
import setuptools; setuptools.bootstrap_install_from = egg
try:
import pkg_resources
except ImportError:
return do_download()
try:
return do_download()
pkg_resources.require("setuptools>="+version); return
except pkg_resources.VersionConflict, e:
if was_imported:
print >>sys.stderr, (
"The required version of setuptools (>=%s) is not available, and\n"
"can't be installed while this script is running. Please install\n"
" a more recent version first, using 'easy_install -U setuptools'."
"\n\n(Currently using %r)"
) % (version, e.args[0])
sys.exit(2)
except pkg_resources.DistributionNotFound:
pass
del pkg_resources, sys.modules['pkg_resources'] # reload ok
return do_download()
def download_setuptools(
version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir,
delay = 15
):
"""Download setuptools from a specified location and return its filename
`version` should be a valid setuptools version number that is available
as an egg for download under the `download_base` URL (which should end
with a '/'). `to_dir` is the directory where the egg will be downloaded.
`delay` is the number of seconds to pause before an actual download attempt.
"""
import urllib2, shutil
egg_name = "setuptools-%s-py%s.egg" % (version,sys.version[:3])
url = download_base + egg_name
saveto = os.path.join(to_dir, egg_name)
src = dst = None
if not os.path.exists(saveto): # Avoid repeated downloads
try:
from distutils import log
if delay:
log.warn("""
---------------------------------------------------------------------------
This script requires setuptools version %s to run (even to display
help). I will attempt to download it for you (from
%s), but
you may need to enable firewall access for this script first.
I will start the download in %d seconds.
(Note: if this machine does not have network access, please obtain the file
%s
and place it in this directory before rerunning this script.)
---------------------------------------------------------------------------""",
version, download_base, delay, url
); from time import sleep; sleep(delay)
log.warn("Downloading %s", url)
src = urllib2.urlopen(url)
# Read/write all in one block, so we don't create a corrupt file
# if the download is interrupted.
data = _validate_md5(egg_name, src.read())
dst = open(saveto,"wb"); dst.write(data)
finally:
if src: src.close()
if dst: dst.close()
return os.path.realpath(saveto)
def main(argv, version=DEFAULT_VERSION):
"""Install or upgrade setuptools and EasyInstall"""
try:
import setuptools
except ImportError:
egg = None
try:
egg = download_setuptools(version, delay=0)
sys.path.insert(0,egg)
from setuptools.command.easy_install import main
return main(list(argv)+[egg]) # we're done here
finally:
if egg and os.path.exists(egg):
os.unlink(egg)
else:
if setuptools.__version__ == '0.0.1':
print >>sys.stderr, (
"You have an obsolete version of setuptools installed. Please\n"
"remove it from your system entirely before rerunning this script."
)
sys.exit(2)
req = "setuptools>="+version
import pkg_resources
try:
pkg_resources.require(req)
except pkg_resources.VersionConflict:
try:
from setuptools.command.easy_install import main
except ImportError:
from easy_install import main
main(list(argv)+[download_setuptools(delay=0)])
sys.exit(0) # try to force an exit
else:
if argv:
from setuptools.command.easy_install import main
main(argv)
else:
print "Setuptools version",version,"or greater has been installed."
print '(Run "ez_setup.py -U setuptools" to reinstall or upgrade.)'
def update_md5(filenames):
"""Update our built-in md5 registry"""
import re
for name in filenames:
base = os.path.basename(name)
f = open(name,'rb')
md5_data[base] = md5(f.read()).hexdigest()
f.close()
data = [" %r: %r,\n" % it for it in md5_data.items()]
data.sort()
repl = "".join(data)
import inspect
srcfile = inspect.getsourcefile(sys.modules[__name__])
f = open(srcfile, 'rb'); src = f.read(); f.close()
match = re.search("\nmd5_data = {\n([^}]+)}", src)
if not match:
print >>sys.stderr, "Internal error!"
sys.exit(2)
src = src[:match.start(1)] + repl + src[match.end(1):]
f = open(srcfile,'w')
f.write(src)
f.close()
if __name__=='__main__':
if len(sys.argv)>2 and sys.argv[1]=='--md5update':
update_md5(sys.argv[2:])
else:
main(sys.argv[1:])
...@@ -28,8 +28,6 @@ ...@@ -28,8 +28,6 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Needs to stay compatible with Python 2.5 due to GAE.
#
# Copyright 2007 Google Inc. All Rights Reserved. # Copyright 2007 Google Inc. All Rights Reserved.
__version__ = '3.0.0a4.dev0' __version__ = '3.0.0a4.dev0'
...@@ -28,8 +28,6 @@ ...@@ -28,8 +28,6 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Needs to stay compatible with Python 2.5 due to GAE.
#
# Copyright 2007 Google Inc. All Rights Reserved. # Copyright 2007 Google Inc. All Rights Reserved.
"""Descriptors essentially contain exactly the information found in a .proto """Descriptors essentially contain exactly the information found in a .proto
...@@ -918,5 +916,5 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True, ...@@ -918,5 +916,5 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True,
desc_name = '.'.join(full_message_name) desc_name = '.'.join(full_message_name)
return Descriptor(desc_proto.name, desc_name, None, None, fields, return Descriptor(desc_proto.name, desc_name, None, None, fields,
nested_types.values(), enum_types.values(), [], list(nested_types.values()), list(enum_types.values()), [],
options=desc_proto.options) options=desc_proto.options)
...@@ -149,9 +149,14 @@ import collections ...@@ -149,9 +149,14 @@ import collections
import functools import functools
import re import re
import types import types
import unittest try:
import unittest2 as unittest
except ImportError:
import unittest
import uuid import uuid
import six
ADDR_RE = re.compile(r'\<([a-zA-Z0-9_\-\.]+) object at 0x[a-fA-F0-9]+\>') ADDR_RE = re.compile(r'\<([a-zA-Z0-9_\-\.]+) object at 0x[a-fA-F0-9]+\>')
_SEPARATOR = uuid.uuid1().hex _SEPARATOR = uuid.uuid1().hex
_FIRST_ARG = object() _FIRST_ARG = object()
...@@ -170,13 +175,13 @@ def _StrClass(cls): ...@@ -170,13 +175,13 @@ def _StrClass(cls):
def _NonStringIterable(obj): def _NonStringIterable(obj):
return (isinstance(obj, collections.Iterable) and not return (isinstance(obj, collections.Iterable) and not
isinstance(obj, basestring)) isinstance(obj, six.string_types))
def _FormatParameterList(testcase_params): def _FormatParameterList(testcase_params):
if isinstance(testcase_params, collections.Mapping): if isinstance(testcase_params, collections.Mapping):
return ', '.join('%s=%s' % (argname, _CleanRepr(value)) return ', '.join('%s=%s' % (argname, _CleanRepr(value))
for argname, value in testcase_params.iteritems()) for argname, value in testcase_params.items())
elif _NonStringIterable(testcase_params): elif _NonStringIterable(testcase_params):
return ', '.join(map(_CleanRepr, testcase_params)) return ', '.join(map(_CleanRepr, testcase_params))
else: else:
...@@ -258,7 +263,9 @@ def _ModifyClass(class_object, testcases, naming_type): ...@@ -258,7 +263,9 @@ def _ModifyClass(class_object, testcases, naming_type):
'Cannot add parameters to %s,' 'Cannot add parameters to %s,'
' which already has parameterized methods.' % (class_object,)) ' which already has parameterized methods.' % (class_object,))
class_object._id_suffix = id_suffix = {} class_object._id_suffix = id_suffix = {}
for name, obj in class_object.__dict__.items(): # We change the size of __dict__ while we iterate over it,
# which Python 3.x will complain about, so use copy().
for name, obj in class_object.__dict__.copy().items():
if (name.startswith(unittest.TestLoader.testMethodPrefix) if (name.startswith(unittest.TestLoader.testMethodPrefix)
and isinstance(obj, types.FunctionType)): and isinstance(obj, types.FunctionType)):
delattr(class_object, name) delattr(class_object, name)
...@@ -266,7 +273,7 @@ def _ModifyClass(class_object, testcases, naming_type): ...@@ -266,7 +273,7 @@ def _ModifyClass(class_object, testcases, naming_type):
_UpdateClassDictForParamTestCase( _UpdateClassDictForParamTestCase(
methods, id_suffix, name, methods, id_suffix, name,
_ParameterizedTestIter(obj, testcases, naming_type)) _ParameterizedTestIter(obj, testcases, naming_type))
for name, meth in methods.iteritems(): for name, meth in methods.items():
setattr(class_object, name, meth) setattr(class_object, name, meth)
......
...@@ -28,8 +28,6 @@ ...@@ -28,8 +28,6 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#PY25 compatible for GAE.
#
# Copyright 2009 Google Inc. All Rights Reserved. # Copyright 2009 Google Inc. All Rights Reserved.
"""Code for decoding protocol buffer primitives. """Code for decoding protocol buffer primitives.
...@@ -85,8 +83,12 @@ we repeatedly read a tag, look up the corresponding decoder, and invoke it. ...@@ -85,8 +83,12 @@ we repeatedly read a tag, look up the corresponding decoder, and invoke it.
__author__ = 'kenton@google.com (Kenton Varda)' __author__ = 'kenton@google.com (Kenton Varda)'
import struct import struct
import sys ##PY25
_PY2 = sys.version_info[0] < 3 ##PY25 import six
if six.PY3:
long = int
from google.protobuf.internal import encoder from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format from google.protobuf.internal import wire_format
from google.protobuf import message from google.protobuf import message
...@@ -114,14 +116,11 @@ def _VarintDecoder(mask, result_type): ...@@ -114,14 +116,11 @@ def _VarintDecoder(mask, result_type):
decoder returns a (value, new_pos) pair. decoder returns a (value, new_pos) pair.
""" """
local_ord = ord
py2 = _PY2 ##PY25
##!PY25 py2 = str is bytes
def DecodeVarint(buffer, pos): def DecodeVarint(buffer, pos):
result = 0 result = 0
shift = 0 shift = 0
while 1: while 1:
b = local_ord(buffer[pos]) if py2 else buffer[pos] b = six.indexbytes(buffer, pos)
result |= ((b & 0x7f) << shift) result |= ((b & 0x7f) << shift)
pos += 1 pos += 1
if not (b & 0x80): if not (b & 0x80):
...@@ -137,14 +136,11 @@ def _VarintDecoder(mask, result_type): ...@@ -137,14 +136,11 @@ def _VarintDecoder(mask, result_type):
def _SignedVarintDecoder(mask, result_type): def _SignedVarintDecoder(mask, result_type):
"""Like _VarintDecoder() but decodes signed values.""" """Like _VarintDecoder() but decodes signed values."""
local_ord = ord
py2 = _PY2 ##PY25
##!PY25 py2 = str is bytes
def DecodeVarint(buffer, pos): def DecodeVarint(buffer, pos):
result = 0 result = 0
shift = 0 shift = 0
while 1: while 1:
b = local_ord(buffer[pos]) if py2 else buffer[pos] b = six.indexbytes(buffer, pos)
result |= ((b & 0x7f) << shift) result |= ((b & 0x7f) << shift)
pos += 1 pos += 1
if not (b & 0x80): if not (b & 0x80):
...@@ -183,10 +179,8 @@ def ReadTag(buffer, pos): ...@@ -183,10 +179,8 @@ def ReadTag(buffer, pos):
use that, but not in Python. use that, but not in Python.
""" """
py2 = _PY2 ##PY25
##!PY25 py2 = str is bytes
start = pos start = pos
while (ord(buffer[pos]) if py2 else buffer[pos]) & 0x80: while six.indexbytes(buffer, pos) & 0x80:
pos += 1 pos += 1
pos += 1 pos += 1
return (buffer[start:pos], pos) return (buffer[start:pos], pos)
...@@ -301,7 +295,6 @@ def _FloatDecoder(): ...@@ -301,7 +295,6 @@ def _FloatDecoder():
""" """
local_unpack = struct.unpack local_unpack = struct.unpack
b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1') ##PY25
def InnerDecode(buffer, pos): def InnerDecode(buffer, pos):
# We expect a 32-bit value in little-endian byte order. Bit 1 is the sign # We expect a 32-bit value in little-endian byte order. Bit 1 is the sign
...@@ -312,17 +305,12 @@ def _FloatDecoder(): ...@@ -312,17 +305,12 @@ def _FloatDecoder():
# If this value has all its exponent bits set, then it's non-finite. # If this value has all its exponent bits set, then it's non-finite.
# In Python 2.4, struct.unpack will convert it to a finite 64-bit value. # In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
# To avoid that, we parse it specially. # To avoid that, we parse it specially.
if ((float_bytes[3:4] in b('\x7F\xFF')) ##PY25 if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'):
##!PY25 if ((float_bytes[3:4] in b'\x7F\xFF')
and (float_bytes[2:3] >= b('\x80'))): ##PY25
##!PY25 and (float_bytes[2:3] >= b'\x80')):
# If at least one significand bit is set... # If at least one significand bit is set...
if float_bytes[0:3] != b('\x00\x00\x80'): ##PY25 if float_bytes[0:3] != b'\x00\x00\x80':
##!PY25 if float_bytes[0:3] != b'\x00\x00\x80':
return (_NAN, new_pos) return (_NAN, new_pos)
# If sign bit is set... # If sign bit is set...
if float_bytes[3:4] == b('\xFF'): ##PY25 if float_bytes[3:4] == b'\xFF':
##!PY25 if float_bytes[3:4] == b'\xFF':
return (_NEG_INF, new_pos) return (_NEG_INF, new_pos)
return (_POS_INF, new_pos) return (_POS_INF, new_pos)
...@@ -341,7 +329,6 @@ def _DoubleDecoder(): ...@@ -341,7 +329,6 @@ def _DoubleDecoder():
""" """
local_unpack = struct.unpack local_unpack = struct.unpack
b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1') ##PY25
def InnerDecode(buffer, pos): def InnerDecode(buffer, pos):
# We expect a 64-bit value in little-endian byte order. Bit 1 is the sign # We expect a 64-bit value in little-endian byte order. Bit 1 is the sign
...@@ -352,12 +339,9 @@ def _DoubleDecoder(): ...@@ -352,12 +339,9 @@ def _DoubleDecoder():
# If this value has all its exponent bits set and at least one significand # If this value has all its exponent bits set and at least one significand
# bit set, it's not a number. In Python 2.4, struct.unpack will treat it # bit set, it's not a number. In Python 2.4, struct.unpack will treat it
# as inf or -inf. To avoid that, we treat it specially. # as inf or -inf. To avoid that, we treat it specially.
##!PY25 if ((double_bytes[7:8] in b'\x7F\xFF') if ((double_bytes[7:8] in b'\x7F\xFF')
##!PY25 and (double_bytes[6:7] >= b'\xF0') and (double_bytes[6:7] >= b'\xF0')
##!PY25 and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')): and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')):
if ((double_bytes[7:8] in b('\x7F\xFF')) ##PY25
and (double_bytes[6:7] >= b('\xF0')) ##PY25
and (double_bytes[0:7] != b('\x00\x00\x00\x00\x00\x00\xF0'))): ##PY25
return (_NAN, new_pos) return (_NAN, new_pos)
# Note that we expect someone up-stack to catch struct.error and convert # Note that we expect someone up-stack to catch struct.error and convert
...@@ -480,12 +464,12 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): ...@@ -480,12 +464,12 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
"""Returns a decoder for a string field.""" """Returns a decoder for a string field."""
local_DecodeVarint = _DecodeVarint local_DecodeVarint = _DecodeVarint
local_unicode = unicode local_unicode = six.text_type
def _ConvertToUnicode(byte_str): def _ConvertToUnicode(byte_str):
try: try:
return local_unicode(byte_str, 'utf-8') return local_unicode(byte_str, 'utf-8')
except UnicodeDecodeError, e: except UnicodeDecodeError as e:
# add more information to the error message and re-raise it. # add more information to the error message and re-raise it.
e.reason = '%s in field: %s' % (e, key.full_name) e.reason = '%s in field: %s' % (e, key.full_name)
raise raise
......
...@@ -34,7 +34,10 @@ ...@@ -34,7 +34,10 @@
__author__ = 'matthewtoia@google.com (Matt Toia)' __author__ = 'matthewtoia@google.com (Matt Toia)'
import unittest try:
import unittest2 as unittest
except ImportError:
import unittest
from google.protobuf import descriptor_pb2 from google.protobuf import descriptor_pb2
from google.protobuf.internal import factory_test2_pb2 from google.protobuf.internal import factory_test2_pb2
from google.protobuf import descriptor_database from google.protobuf import descriptor_database
...@@ -48,17 +51,17 @@ class DescriptorDatabaseTest(unittest.TestCase): ...@@ -48,17 +51,17 @@ class DescriptorDatabaseTest(unittest.TestCase):
factory_test2_pb2.DESCRIPTOR.serialized_pb) factory_test2_pb2.DESCRIPTOR.serialized_pb)
db.Add(file_desc_proto) db.Add(file_desc_proto)
self.assertEquals(file_desc_proto, db.FindFileByName( self.assertEqual(file_desc_proto, db.FindFileByName(
'google/protobuf/internal/factory_test2.proto')) 'google/protobuf/internal/factory_test2.proto'))
self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message')) 'google.protobuf.python.internal.Factory2Message'))
self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message.NestedFactory2Message')) 'google.protobuf.python.internal.Factory2Message.NestedFactory2Message'))
self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Enum')) 'google.protobuf.python.internal.Factory2Enum'))
self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum')) 'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum'))
self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.MessageWithNestedEnumOnly.NestedEnum')) 'google.protobuf.python.internal.MessageWithNestedEnumOnly.NestedEnum'))
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -35,9 +35,11 @@ ...@@ -35,9 +35,11 @@
__author__ = 'matthewtoia@google.com (Matt Toia)' __author__ = 'matthewtoia@google.com (Matt Toia)'
import os import os
import unittest try:
import unittest2 as unittest
except ImportError:
import unittest
import unittest
from google.protobuf import unittest_pb2 from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2 from google.protobuf import descriptor_pb2
from google.protobuf.internal import api_implementation from google.protobuf.internal import api_implementation
...@@ -66,15 +68,15 @@ class DescriptorPoolTest(unittest.TestCase): ...@@ -66,15 +68,15 @@ class DescriptorPoolTest(unittest.TestCase):
name1 = 'google/protobuf/internal/factory_test1.proto' name1 = 'google/protobuf/internal/factory_test1.proto'
file_desc1 = self.pool.FindFileByName(name1) file_desc1 = self.pool.FindFileByName(name1)
self.assertIsInstance(file_desc1, descriptor.FileDescriptor) self.assertIsInstance(file_desc1, descriptor.FileDescriptor)
self.assertEquals(name1, file_desc1.name) self.assertEqual(name1, file_desc1.name)
self.assertEquals('google.protobuf.python.internal', file_desc1.package) self.assertEqual('google.protobuf.python.internal', file_desc1.package)
self.assertIn('Factory1Message', file_desc1.message_types_by_name) self.assertIn('Factory1Message', file_desc1.message_types_by_name)
name2 = 'google/protobuf/internal/factory_test2.proto' name2 = 'google/protobuf/internal/factory_test2.proto'
file_desc2 = self.pool.FindFileByName(name2) file_desc2 = self.pool.FindFileByName(name2)
self.assertIsInstance(file_desc2, descriptor.FileDescriptor) self.assertIsInstance(file_desc2, descriptor.FileDescriptor)
self.assertEquals(name2, file_desc2.name) self.assertEqual(name2, file_desc2.name)
self.assertEquals('google.protobuf.python.internal', file_desc2.package) self.assertEqual('google.protobuf.python.internal', file_desc2.package)
self.assertIn('Factory2Message', file_desc2.message_types_by_name) self.assertIn('Factory2Message', file_desc2.message_types_by_name)
def testFindFileByNameFailure(self): def testFindFileByNameFailure(self):
...@@ -85,17 +87,17 @@ class DescriptorPoolTest(unittest.TestCase): ...@@ -85,17 +87,17 @@ class DescriptorPoolTest(unittest.TestCase):
file_desc1 = self.pool.FindFileContainingSymbol( file_desc1 = self.pool.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory1Message') 'google.protobuf.python.internal.Factory1Message')
self.assertIsInstance(file_desc1, descriptor.FileDescriptor) self.assertIsInstance(file_desc1, descriptor.FileDescriptor)
self.assertEquals('google/protobuf/internal/factory_test1.proto', self.assertEqual('google/protobuf/internal/factory_test1.proto',
file_desc1.name) file_desc1.name)
self.assertEquals('google.protobuf.python.internal', file_desc1.package) self.assertEqual('google.protobuf.python.internal', file_desc1.package)
self.assertIn('Factory1Message', file_desc1.message_types_by_name) self.assertIn('Factory1Message', file_desc1.message_types_by_name)
file_desc2 = self.pool.FindFileContainingSymbol( file_desc2 = self.pool.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message') 'google.protobuf.python.internal.Factory2Message')
self.assertIsInstance(file_desc2, descriptor.FileDescriptor) self.assertIsInstance(file_desc2, descriptor.FileDescriptor)
self.assertEquals('google/protobuf/internal/factory_test2.proto', self.assertEqual('google/protobuf/internal/factory_test2.proto',
file_desc2.name) file_desc2.name)
self.assertEquals('google.protobuf.python.internal', file_desc2.package) self.assertEqual('google.protobuf.python.internal', file_desc2.package)
self.assertIn('Factory2Message', file_desc2.message_types_by_name) self.assertIn('Factory2Message', file_desc2.message_types_by_name)
def testFindFileContainingSymbolFailure(self): def testFindFileContainingSymbolFailure(self):
...@@ -106,72 +108,72 @@ class DescriptorPoolTest(unittest.TestCase): ...@@ -106,72 +108,72 @@ class DescriptorPoolTest(unittest.TestCase):
msg1 = self.pool.FindMessageTypeByName( msg1 = self.pool.FindMessageTypeByName(
'google.protobuf.python.internal.Factory1Message') 'google.protobuf.python.internal.Factory1Message')
self.assertIsInstance(msg1, descriptor.Descriptor) self.assertIsInstance(msg1, descriptor.Descriptor)
self.assertEquals('Factory1Message', msg1.name) self.assertEqual('Factory1Message', msg1.name)
self.assertEquals('google.protobuf.python.internal.Factory1Message', self.assertEqual('google.protobuf.python.internal.Factory1Message',
msg1.full_name) msg1.full_name)
self.assertEquals(None, msg1.containing_type) self.assertEqual(None, msg1.containing_type)
nested_msg1 = msg1.nested_types[0] nested_msg1 = msg1.nested_types[0]
self.assertEquals('NestedFactory1Message', nested_msg1.name) self.assertEqual('NestedFactory1Message', nested_msg1.name)
self.assertEquals(msg1, nested_msg1.containing_type) self.assertEqual(msg1, nested_msg1.containing_type)
nested_enum1 = msg1.enum_types[0] nested_enum1 = msg1.enum_types[0]
self.assertEquals('NestedFactory1Enum', nested_enum1.name) self.assertEqual('NestedFactory1Enum', nested_enum1.name)
self.assertEquals(msg1, nested_enum1.containing_type) self.assertEqual(msg1, nested_enum1.containing_type)
self.assertEquals(nested_msg1, msg1.fields_by_name[ self.assertEqual(nested_msg1, msg1.fields_by_name[
'nested_factory_1_message'].message_type) 'nested_factory_1_message'].message_type)
self.assertEquals(nested_enum1, msg1.fields_by_name[ self.assertEqual(nested_enum1, msg1.fields_by_name[
'nested_factory_1_enum'].enum_type) 'nested_factory_1_enum'].enum_type)
msg2 = self.pool.FindMessageTypeByName( msg2 = self.pool.FindMessageTypeByName(
'google.protobuf.python.internal.Factory2Message') 'google.protobuf.python.internal.Factory2Message')
self.assertIsInstance(msg2, descriptor.Descriptor) self.assertIsInstance(msg2, descriptor.Descriptor)
self.assertEquals('Factory2Message', msg2.name) self.assertEqual('Factory2Message', msg2.name)
self.assertEquals('google.protobuf.python.internal.Factory2Message', self.assertEqual('google.protobuf.python.internal.Factory2Message',
msg2.full_name) msg2.full_name)
self.assertIsNone(msg2.containing_type) self.assertIsNone(msg2.containing_type)
nested_msg2 = msg2.nested_types[0] nested_msg2 = msg2.nested_types[0]
self.assertEquals('NestedFactory2Message', nested_msg2.name) self.assertEqual('NestedFactory2Message', nested_msg2.name)
self.assertEquals(msg2, nested_msg2.containing_type) self.assertEqual(msg2, nested_msg2.containing_type)
nested_enum2 = msg2.enum_types[0] nested_enum2 = msg2.enum_types[0]
self.assertEquals('NestedFactory2Enum', nested_enum2.name) self.assertEqual('NestedFactory2Enum', nested_enum2.name)
self.assertEquals(msg2, nested_enum2.containing_type) self.assertEqual(msg2, nested_enum2.containing_type)
self.assertEquals(nested_msg2, msg2.fields_by_name[ self.assertEqual(nested_msg2, msg2.fields_by_name[
'nested_factory_2_message'].message_type) 'nested_factory_2_message'].message_type)
self.assertEquals(nested_enum2, msg2.fields_by_name[ self.assertEqual(nested_enum2, msg2.fields_by_name[
'nested_factory_2_enum'].enum_type) 'nested_factory_2_enum'].enum_type)
self.assertTrue(msg2.fields_by_name['int_with_default'].has_default_value) self.assertTrue(msg2.fields_by_name['int_with_default'].has_default_value)
self.assertEquals( self.assertEqual(
1776, msg2.fields_by_name['int_with_default'].default_value) 1776, msg2.fields_by_name['int_with_default'].default_value)
self.assertTrue( self.assertTrue(
msg2.fields_by_name['double_with_default'].has_default_value) msg2.fields_by_name['double_with_default'].has_default_value)
self.assertEquals( self.assertEqual(
9.99, msg2.fields_by_name['double_with_default'].default_value) 9.99, msg2.fields_by_name['double_with_default'].default_value)
self.assertTrue( self.assertTrue(
msg2.fields_by_name['string_with_default'].has_default_value) msg2.fields_by_name['string_with_default'].has_default_value)
self.assertEquals( self.assertEqual(
'hello world', msg2.fields_by_name['string_with_default'].default_value) 'hello world', msg2.fields_by_name['string_with_default'].default_value)
self.assertTrue(msg2.fields_by_name['bool_with_default'].has_default_value) self.assertTrue(msg2.fields_by_name['bool_with_default'].has_default_value)
self.assertFalse(msg2.fields_by_name['bool_with_default'].default_value) self.assertFalse(msg2.fields_by_name['bool_with_default'].default_value)
self.assertTrue(msg2.fields_by_name['enum_with_default'].has_default_value) self.assertTrue(msg2.fields_by_name['enum_with_default'].has_default_value)
self.assertEquals( self.assertEqual(
1, msg2.fields_by_name['enum_with_default'].default_value) 1, msg2.fields_by_name['enum_with_default'].default_value)
msg3 = self.pool.FindMessageTypeByName( msg3 = self.pool.FindMessageTypeByName(
'google.protobuf.python.internal.Factory2Message.NestedFactory2Message') 'google.protobuf.python.internal.Factory2Message.NestedFactory2Message')
self.assertEquals(nested_msg2, msg3) self.assertEqual(nested_msg2, msg3)
self.assertTrue(msg2.fields_by_name['bytes_with_default'].has_default_value) self.assertTrue(msg2.fields_by_name['bytes_with_default'].has_default_value)
self.assertEquals( self.assertEqual(
b'a\xfb\x00c', b'a\xfb\x00c',
msg2.fields_by_name['bytes_with_default'].default_value) msg2.fields_by_name['bytes_with_default'].default_value)
...@@ -191,29 +193,29 @@ class DescriptorPoolTest(unittest.TestCase): ...@@ -191,29 +193,29 @@ class DescriptorPoolTest(unittest.TestCase):
enum1 = self.pool.FindEnumTypeByName( enum1 = self.pool.FindEnumTypeByName(
'google.protobuf.python.internal.Factory1Enum') 'google.protobuf.python.internal.Factory1Enum')
self.assertIsInstance(enum1, descriptor.EnumDescriptor) self.assertIsInstance(enum1, descriptor.EnumDescriptor)
self.assertEquals(0, enum1.values_by_name['FACTORY_1_VALUE_0'].number) self.assertEqual(0, enum1.values_by_name['FACTORY_1_VALUE_0'].number)
self.assertEquals(1, enum1.values_by_name['FACTORY_1_VALUE_1'].number) self.assertEqual(1, enum1.values_by_name['FACTORY_1_VALUE_1'].number)
nested_enum1 = self.pool.FindEnumTypeByName( nested_enum1 = self.pool.FindEnumTypeByName(
'google.protobuf.python.internal.Factory1Message.NestedFactory1Enum') 'google.protobuf.python.internal.Factory1Message.NestedFactory1Enum')
self.assertIsInstance(nested_enum1, descriptor.EnumDescriptor) self.assertIsInstance(nested_enum1, descriptor.EnumDescriptor)
self.assertEquals( self.assertEqual(
0, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_0'].number) 0, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_0'].number)
self.assertEquals( self.assertEqual(
1, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_1'].number) 1, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_1'].number)
enum2 = self.pool.FindEnumTypeByName( enum2 = self.pool.FindEnumTypeByName(
'google.protobuf.python.internal.Factory2Enum') 'google.protobuf.python.internal.Factory2Enum')
self.assertIsInstance(enum2, descriptor.EnumDescriptor) self.assertIsInstance(enum2, descriptor.EnumDescriptor)
self.assertEquals(0, enum2.values_by_name['FACTORY_2_VALUE_0'].number) self.assertEqual(0, enum2.values_by_name['FACTORY_2_VALUE_0'].number)
self.assertEquals(1, enum2.values_by_name['FACTORY_2_VALUE_1'].number) self.assertEqual(1, enum2.values_by_name['FACTORY_2_VALUE_1'].number)
nested_enum2 = self.pool.FindEnumTypeByName( nested_enum2 = self.pool.FindEnumTypeByName(
'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum') 'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum')
self.assertIsInstance(nested_enum2, descriptor.EnumDescriptor) self.assertIsInstance(nested_enum2, descriptor.EnumDescriptor)
self.assertEquals( self.assertEqual(
0, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_0'].number) 0, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_0'].number)
self.assertEquals( self.assertEqual(
1, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_1'].number) 1, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_1'].number)
def testFindEnumTypeByNameFailure(self): def testFindEnumTypeByNameFailure(self):
...@@ -282,8 +284,8 @@ class ProtoFile(object): ...@@ -282,8 +284,8 @@ class ProtoFile(object):
def CheckFile(self, test, pool): def CheckFile(self, test, pool):
file_desc = pool.FindFileByName(self.name) file_desc = pool.FindFileByName(self.name)
test.assertEquals(self.name, file_desc.name) test.assertEqual(self.name, file_desc.name)
test.assertEquals(self.package, file_desc.package) test.assertEqual(self.package, file_desc.package)
dependencies_names = [f.name for f in file_desc.dependencies] dependencies_names = [f.name for f in file_desc.dependencies]
test.assertEqual(self.dependencies, dependencies_names) test.assertEqual(self.dependencies, dependencies_names)
for name, msg_type in self.messages.items(): for name, msg_type in self.messages.items():
...@@ -438,7 +440,7 @@ class AddDescriptorTest(unittest.TestCase): ...@@ -438,7 +440,7 @@ class AddDescriptorTest(unittest.TestCase):
def _TestMessage(self, prefix): def _TestMessage(self, prefix):
pool = descriptor_pool.DescriptorPool() pool = descriptor_pool.DescriptorPool()
pool.AddDescriptor(unittest_pb2.TestAllTypes.DESCRIPTOR) pool.AddDescriptor(unittest_pb2.TestAllTypes.DESCRIPTOR)
self.assertEquals( self.assertEqual(
'protobuf_unittest.TestAllTypes', 'protobuf_unittest.TestAllTypes',
pool.FindMessageTypeByName( pool.FindMessageTypeByName(
prefix + 'protobuf_unittest.TestAllTypes').full_name) prefix + 'protobuf_unittest.TestAllTypes').full_name)
...@@ -449,18 +451,18 @@ class AddDescriptorTest(unittest.TestCase): ...@@ -449,18 +451,18 @@ class AddDescriptorTest(unittest.TestCase):
prefix + 'protobuf_unittest.TestAllTypes.NestedMessage') prefix + 'protobuf_unittest.TestAllTypes.NestedMessage')
pool.AddDescriptor(unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR) pool.AddDescriptor(unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR)
self.assertEquals( self.assertEqual(
'protobuf_unittest.TestAllTypes.NestedMessage', 'protobuf_unittest.TestAllTypes.NestedMessage',
pool.FindMessageTypeByName( pool.FindMessageTypeByName(
prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').full_name) prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').full_name)
# Files are implicitly also indexed when messages are added. # Files are implicitly also indexed when messages are added.
self.assertEquals( self.assertEqual(
'google/protobuf/unittest.proto', 'google/protobuf/unittest.proto',
pool.FindFileByName( pool.FindFileByName(
'google/protobuf/unittest.proto').name) 'google/protobuf/unittest.proto').name)
self.assertEquals( self.assertEqual(
'google/protobuf/unittest.proto', 'google/protobuf/unittest.proto',
pool.FindFileContainingSymbol( pool.FindFileContainingSymbol(
prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').name) prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').name)
...@@ -472,7 +474,7 @@ class AddDescriptorTest(unittest.TestCase): ...@@ -472,7 +474,7 @@ class AddDescriptorTest(unittest.TestCase):
def _TestEnum(self, prefix): def _TestEnum(self, prefix):
pool = descriptor_pool.DescriptorPool() pool = descriptor_pool.DescriptorPool()
pool.AddEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR) pool.AddEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR)
self.assertEquals( self.assertEqual(
'protobuf_unittest.ForeignEnum', 'protobuf_unittest.ForeignEnum',
pool.FindEnumTypeByName( pool.FindEnumTypeByName(
prefix + 'protobuf_unittest.ForeignEnum').full_name) prefix + 'protobuf_unittest.ForeignEnum').full_name)
...@@ -483,18 +485,18 @@ class AddDescriptorTest(unittest.TestCase): ...@@ -483,18 +485,18 @@ class AddDescriptorTest(unittest.TestCase):
prefix + 'protobuf_unittest.ForeignEnum.NestedEnum') prefix + 'protobuf_unittest.ForeignEnum.NestedEnum')
pool.AddEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR) pool.AddEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR)
self.assertEquals( self.assertEqual(
'protobuf_unittest.TestAllTypes.NestedEnum', 'protobuf_unittest.TestAllTypes.NestedEnum',
pool.FindEnumTypeByName( pool.FindEnumTypeByName(
prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').full_name) prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').full_name)
# Files are implicitly also indexed when enums are added. # Files are implicitly also indexed when enums are added.
self.assertEquals( self.assertEqual(
'google/protobuf/unittest.proto', 'google/protobuf/unittest.proto',
pool.FindFileByName( pool.FindFileByName(
'google/protobuf/unittest.proto').name) 'google/protobuf/unittest.proto').name)
self.assertEquals( self.assertEqual(
'google/protobuf/unittest.proto', 'google/protobuf/unittest.proto',
pool.FindFileContainingSymbol( pool.FindFileContainingSymbol(
prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').name) prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').name)
...@@ -506,7 +508,7 @@ class AddDescriptorTest(unittest.TestCase): ...@@ -506,7 +508,7 @@ class AddDescriptorTest(unittest.TestCase):
def testFile(self): def testFile(self):
pool = descriptor_pool.DescriptorPool() pool = descriptor_pool.DescriptorPool()
pool.AddFileDescriptor(unittest_pb2.DESCRIPTOR) pool.AddFileDescriptor(unittest_pb2.DESCRIPTOR)
self.assertEquals( self.assertEqual(
'google/protobuf/unittest.proto', 'google/protobuf/unittest.proto',
pool.FindFileByName( pool.FindFileByName(
'google/protobuf/unittest.proto').name) 'google/protobuf/unittest.proto').name)
...@@ -518,43 +520,6 @@ class AddDescriptorTest(unittest.TestCase): ...@@ -518,43 +520,6 @@ class AddDescriptorTest(unittest.TestCase):
'protobuf_unittest.TestAllTypes') 'protobuf_unittest.TestAllTypes')
@unittest.skipIf(
api_implementation.Type() != 'cpp',
'default_pool is only supported by the C++ implementation')
class DefaultPoolTest(unittest.TestCase):
def testFindMethods(self):
# pylint: disable=g-import-not-at-top
from google.protobuf.pyext import _message
pool = _message.default_pool
self.assertIs(
pool.FindFileByName('google/protobuf/unittest.proto'),
unittest_pb2.DESCRIPTOR)
self.assertIs(
pool.FindMessageTypeByName('protobuf_unittest.TestAllTypes'),
unittest_pb2.TestAllTypes.DESCRIPTOR)
self.assertIs(
pool.FindFieldByName('protobuf_unittest.TestAllTypes.optional_int32'),
unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32'])
self.assertIs(
pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'),
unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension'])
self.assertIs(
pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'),
unittest_pb2.ForeignEnum.DESCRIPTOR)
self.assertIs(
pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'),
unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field'])
def testAddFileDescriptor(self):
# pylint: disable=g-import-not-at-top
from google.protobuf.pyext import _message
pool = _message.default_pool
file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto')
pool.Add(file_desc)
pool.AddSerializedFile(file_desc.SerializeToString())
TEST1_FILE = ProtoFile( TEST1_FILE = ProtoFile(
'google/protobuf/internal/descriptor_pool_test1.proto', 'google/protobuf/internal/descriptor_pool_test1.proto',
'google.protobuf.python.internal', 'google.protobuf.python.internal',
......
#! /usr/bin/env python #! /usr/bin/python
# #
# Protocol Buffers - Google's data interchange format # Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved. # Copyright 2008 Google Inc. All rights reserved.
...@@ -36,7 +36,6 @@ __author__ = 'robinson@google.com (Will Robinson)' ...@@ -36,7 +36,6 @@ __author__ = 'robinson@google.com (Will Robinson)'
import sys import sys
import unittest
from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_custom_options_pb2
from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_pb2 from google.protobuf import unittest_pb2
...@@ -46,6 +45,11 @@ from google.protobuf import descriptor ...@@ -46,6 +45,11 @@ from google.protobuf import descriptor
from google.protobuf import symbol_database from google.protobuf import symbol_database
from google.protobuf import text_format from google.protobuf import text_format
try:
import unittest2 as unittest
except ImportError:
import unittest
TEST_EMPTY_MESSAGE_DESCRIPTOR_ASCII = """ TEST_EMPTY_MESSAGE_DESCRIPTOR_ASCII = """
name: 'TestEmptyMessage' name: 'TestEmptyMessage'
...@@ -455,7 +459,7 @@ class GeneratedDescriptorTest(unittest.TestCase): ...@@ -455,7 +459,7 @@ class GeneratedDescriptorTest(unittest.TestCase):
# properties of an immutable abc.Mapping. # properties of an immutable abc.Mapping.
self.assertGreater(len(mapping), 0) # Sized self.assertGreater(len(mapping), 0) # Sized
self.assertEqual(len(mapping), len(list(mapping))) # Iterable self.assertEqual(len(mapping), len(list(mapping))) # Iterable
if sys.version_info.major >= 3: if sys.version_info >= (3,):
key, item = next(iter(mapping.items())) key, item = next(iter(mapping.items()))
else: else:
key, item = mapping.items()[0] key, item = mapping.items()[0]
...@@ -464,7 +468,7 @@ class GeneratedDescriptorTest(unittest.TestCase): ...@@ -464,7 +468,7 @@ class GeneratedDescriptorTest(unittest.TestCase):
# keys(), iterkeys() &co # keys(), iterkeys() &co
item = (next(iter(mapping.keys())), next(iter(mapping.values()))) item = (next(iter(mapping.keys())), next(iter(mapping.values())))
self.assertEqual(item, next(iter(mapping.items()))) self.assertEqual(item, next(iter(mapping.items())))
if sys.version_info.major < 3: if sys.version_info < (3,):
def CheckItems(seq, iterator): def CheckItems(seq, iterator):
self.assertEqual(next(iterator), seq[0]) self.assertEqual(next(iterator), seq[0])
self.assertEqual(list(iterator), seq[1:]) self.assertEqual(list(iterator), seq[1:])
...@@ -772,7 +776,7 @@ class MakeDescriptorTest(unittest.TestCase): ...@@ -772,7 +776,7 @@ class MakeDescriptorTest(unittest.TestCase):
reformed_descriptor = descriptor.MakeDescriptor(descriptor_proto) reformed_descriptor = descriptor.MakeDescriptor(descriptor_proto)
options = reformed_descriptor.GetOptions() options = reformed_descriptor.GetOptions()
self.assertEquals(101, self.assertEqual(101,
options.Extensions[unittest_custom_options_pb2.msgopt].i) options.Extensions[unittest_custom_options_pb2.msgopt].i)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -28,8 +28,6 @@ ...@@ -28,8 +28,6 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#PY25 compatible for GAE.
#
# Copyright 2009 Google Inc. All Rights Reserved. # Copyright 2009 Google Inc. All Rights Reserved.
"""Code for encoding protocol message primitives. """Code for encoding protocol message primitives.
...@@ -45,7 +43,7 @@ FieldDescriptor) we construct two functions: a "sizer" and an "encoder". The ...@@ -45,7 +43,7 @@ FieldDescriptor) we construct two functions: a "sizer" and an "encoder". The
sizer takes a value of this field's type and computes its byte size. The sizer takes a value of this field's type and computes its byte size. The
encoder takes a writer function and a value. It encodes the value into byte encoder takes a writer function and a value. It encodes the value into byte
strings and invokes the writer function to write those strings. Typically the strings and invokes the writer function to write those strings. Typically the
writer function is the write() method of a cStringIO. writer function is the write() method of a BytesIO.
We try to do as much work as possible when constructing the writer and the We try to do as much work as possible when constructing the writer and the
sizer rather than when calling them. In particular: sizer rather than when calling them. In particular:
...@@ -71,8 +69,9 @@ sizer rather than when calling them. In particular: ...@@ -71,8 +69,9 @@ sizer rather than when calling them. In particular:
__author__ = 'kenton@google.com (Kenton Varda)' __author__ = 'kenton@google.com (Kenton Varda)'
import struct import struct
import sys ##PY25
_PY2 = sys.version_info[0] < 3 ##PY25 import six
from google.protobuf.internal import wire_format from google.protobuf.internal import wire_format
...@@ -372,16 +371,14 @@ def MapSizer(field_descriptor): ...@@ -372,16 +371,14 @@ def MapSizer(field_descriptor):
def _VarintEncoder(): def _VarintEncoder():
"""Return an encoder for a basic varint value (does not include tag).""" """Return an encoder for a basic varint value (does not include tag)."""
local_chr = _PY2 and chr or (lambda x: bytes((x,))) ##PY25
##!PY25 local_chr = chr if bytes is str else lambda x: bytes((x,))
def EncodeVarint(write, value): def EncodeVarint(write, value):
bits = value & 0x7f bits = value & 0x7f
value >>= 7 value >>= 7
while value: while value:
write(local_chr(0x80|bits)) write(six.int2byte(0x80|bits))
bits = value & 0x7f bits = value & 0x7f
value >>= 7 value >>= 7
return write(local_chr(bits)) return write(six.int2byte(bits))
return EncodeVarint return EncodeVarint
...@@ -390,18 +387,16 @@ def _SignedVarintEncoder(): ...@@ -390,18 +387,16 @@ def _SignedVarintEncoder():
"""Return an encoder for a basic signed varint value (does not include """Return an encoder for a basic signed varint value (does not include
tag).""" tag)."""
local_chr = _PY2 and chr or (lambda x: bytes((x,))) ##PY25
##!PY25 local_chr = chr if bytes is str else lambda x: bytes((x,))
def EncodeSignedVarint(write, value): def EncodeSignedVarint(write, value):
if value < 0: if value < 0:
value += (1 << 64) value += (1 << 64)
bits = value & 0x7f bits = value & 0x7f
value >>= 7 value >>= 7
while value: while value:
write(local_chr(0x80|bits)) write(six.int2byte(0x80|bits))
bits = value & 0x7f bits = value & 0x7f
value >>= 7 value >>= 7
return write(local_chr(bits)) return write(six.int2byte(bits))
return EncodeSignedVarint return EncodeSignedVarint
...@@ -416,8 +411,7 @@ def _VarintBytes(value): ...@@ -416,8 +411,7 @@ def _VarintBytes(value):
pieces = [] pieces = []
_EncodeVarint(pieces.append, value) _EncodeVarint(pieces.append, value)
return "".encode("latin1").join(pieces) ##PY25 return b"".join(pieces)
##!PY25 return b"".join(pieces)
def TagBytes(field_number, wire_type): def TagBytes(field_number, wire_type):
...@@ -555,33 +549,26 @@ def _FloatingPointEncoder(wire_type, format): ...@@ -555,33 +549,26 @@ def _FloatingPointEncoder(wire_type, format):
format: The format string to pass to struct.pack(). format: The format string to pass to struct.pack().
""" """
b = _PY2 and (lambda x:x) or (lambda x:x.encode('latin1')) ##PY25
value_size = struct.calcsize(format) value_size = struct.calcsize(format)
if value_size == 4: if value_size == 4:
def EncodeNonFiniteOrRaise(write, value): def EncodeNonFiniteOrRaise(write, value):
# Remember that the serialized form uses little-endian byte order. # Remember that the serialized form uses little-endian byte order.
if value == _POS_INF: if value == _POS_INF:
write(b('\x00\x00\x80\x7F')) ##PY25 write(b'\x00\x00\x80\x7F')
##!PY25 write(b'\x00\x00\x80\x7F')
elif value == _NEG_INF: elif value == _NEG_INF:
write(b('\x00\x00\x80\xFF')) ##PY25 write(b'\x00\x00\x80\xFF')
##!PY25 write(b'\x00\x00\x80\xFF')
elif value != value: # NaN elif value != value: # NaN
write(b('\x00\x00\xC0\x7F')) ##PY25 write(b'\x00\x00\xC0\x7F')
##!PY25 write(b'\x00\x00\xC0\x7F')
else: else:
raise raise
elif value_size == 8: elif value_size == 8:
def EncodeNonFiniteOrRaise(write, value): def EncodeNonFiniteOrRaise(write, value):
if value == _POS_INF: if value == _POS_INF:
write(b('\x00\x00\x00\x00\x00\x00\xF0\x7F')) ##PY25 write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F')
##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F')
elif value == _NEG_INF: elif value == _NEG_INF:
write(b('\x00\x00\x00\x00\x00\x00\xF0\xFF')) ##PY25 write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF')
##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF')
elif value != value: # NaN elif value != value: # NaN
write(b('\x00\x00\x00\x00\x00\x00\xF8\x7F')) ##PY25 write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F')
##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F')
else: else:
raise raise
else: else:
...@@ -657,10 +644,8 @@ DoubleEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED64, '<d') ...@@ -657,10 +644,8 @@ DoubleEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED64, '<d')
def BoolEncoder(field_number, is_repeated, is_packed): def BoolEncoder(field_number, is_repeated, is_packed):
"""Returns an encoder for a boolean field.""" """Returns an encoder for a boolean field."""
##!PY25 false_byte = b'\x00' false_byte = b'\x00'
##!PY25 true_byte = b'\x01' true_byte = b'\x01'
false_byte = '\x00'.encode('latin1') ##PY25
true_byte = '\x01'.encode('latin1') ##PY25
if is_packed: if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint local_EncodeVarint = _EncodeVarint
...@@ -796,8 +781,7 @@ def MessageSetItemEncoder(field_number): ...@@ -796,8 +781,7 @@ def MessageSetItemEncoder(field_number):
} }
} }
""" """
start_bytes = "".encode("latin1").join([ ##PY25 start_bytes = b"".join([
##!PY25 start_bytes = b"".join([
TagBytes(1, wire_format.WIRETYPE_START_GROUP), TagBytes(1, wire_format.WIRETYPE_START_GROUP),
TagBytes(2, wire_format.WIRETYPE_VARINT), TagBytes(2, wire_format.WIRETYPE_VARINT),
_VarintBytes(field_number), _VarintBytes(field_number),
......
...@@ -41,7 +41,10 @@ further ensures that we can use Python protocol message objects as we expect. ...@@ -41,7 +41,10 @@ further ensures that we can use Python protocol message objects as we expect.
__author__ = 'robinson@google.com (Will Robinson)' __author__ = 'robinson@google.com (Will Robinson)'
import unittest try:
import unittest2 as unittest
except ImportError:
import unittest
from google.protobuf.internal import test_bad_identifiers_pb2 from google.protobuf.internal import test_bad_identifiers_pb2
from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_custom_options_pb2
from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_import_pb2
...@@ -153,7 +156,7 @@ class GeneratorTest(unittest.TestCase): ...@@ -153,7 +156,7 @@ class GeneratorTest(unittest.TestCase):
# extension and for its value to be set to -789. # extension and for its value to be set to -789.
def testNestedTypes(self): def testNestedTypes(self):
self.assertEquals( self.assertEqual(
set(unittest_pb2.TestAllTypes.DESCRIPTOR.nested_types), set(unittest_pb2.TestAllTypes.DESCRIPTOR.nested_types),
set([ set([
unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR, unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR,
...@@ -291,10 +294,10 @@ class GeneratorTest(unittest.TestCase): ...@@ -291,10 +294,10 @@ class GeneratorTest(unittest.TestCase):
self.assertIs(desc.oneofs[0], desc.oneofs_by_name['oneof_field']) self.assertIs(desc.oneofs[0], desc.oneofs_by_name['oneof_field'])
nested_names = set(['oneof_uint32', 'oneof_nested_message', nested_names = set(['oneof_uint32', 'oneof_nested_message',
'oneof_string', 'oneof_bytes']) 'oneof_string', 'oneof_bytes'])
self.assertItemsEqual( self.assertEqual(
nested_names, nested_names,
[field.name for field in desc.oneofs[0].fields]) set([field.name for field in desc.oneofs[0].fields]))
for field_name, field_desc in desc.fields_by_name.iteritems(): for field_name, field_desc in desc.fields_by_name.items():
if field_name in nested_names: if field_name in nested_names:
self.assertIs(desc.oneofs[0], field_desc.containing_oneof) self.assertIs(desc.oneofs[0], field_desc.containing_oneof)
else: else:
...@@ -305,36 +308,36 @@ class SymbolDatabaseRegistrationTest(unittest.TestCase): ...@@ -305,36 +308,36 @@ class SymbolDatabaseRegistrationTest(unittest.TestCase):
"""Checks that messages, enums and files are correctly registered.""" """Checks that messages, enums and files are correctly registered."""
def testGetSymbol(self): def testGetSymbol(self):
self.assertEquals( self.assertEqual(
unittest_pb2.TestAllTypes, symbol_database.Default().GetSymbol( unittest_pb2.TestAllTypes, symbol_database.Default().GetSymbol(
'protobuf_unittest.TestAllTypes')) 'protobuf_unittest.TestAllTypes'))
self.assertEquals( self.assertEqual(
unittest_pb2.TestAllTypes.NestedMessage, unittest_pb2.TestAllTypes.NestedMessage,
symbol_database.Default().GetSymbol( symbol_database.Default().GetSymbol(
'protobuf_unittest.TestAllTypes.NestedMessage')) 'protobuf_unittest.TestAllTypes.NestedMessage'))
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
symbol_database.Default().GetSymbol('protobuf_unittest.NestedMessage') symbol_database.Default().GetSymbol('protobuf_unittest.NestedMessage')
self.assertEquals( self.assertEqual(
unittest_pb2.TestAllTypes.OptionalGroup, unittest_pb2.TestAllTypes.OptionalGroup,
symbol_database.Default().GetSymbol( symbol_database.Default().GetSymbol(
'protobuf_unittest.TestAllTypes.OptionalGroup')) 'protobuf_unittest.TestAllTypes.OptionalGroup'))
self.assertEquals( self.assertEqual(
unittest_pb2.TestAllTypes.RepeatedGroup, unittest_pb2.TestAllTypes.RepeatedGroup,
symbol_database.Default().GetSymbol( symbol_database.Default().GetSymbol(
'protobuf_unittest.TestAllTypes.RepeatedGroup')) 'protobuf_unittest.TestAllTypes.RepeatedGroup'))
def testEnums(self): def testEnums(self):
self.assertEquals( self.assertEqual(
'protobuf_unittest.ForeignEnum', 'protobuf_unittest.ForeignEnum',
symbol_database.Default().pool.FindEnumTypeByName( symbol_database.Default().pool.FindEnumTypeByName(
'protobuf_unittest.ForeignEnum').full_name) 'protobuf_unittest.ForeignEnum').full_name)
self.assertEquals( self.assertEqual(
'protobuf_unittest.TestAllTypes.NestedEnum', 'protobuf_unittest.TestAllTypes.NestedEnum',
symbol_database.Default().pool.FindEnumTypeByName( symbol_database.Default().pool.FindEnumTypeByName(
'protobuf_unittest.TestAllTypes.NestedEnum').full_name) 'protobuf_unittest.TestAllTypes.NestedEnum').full_name)
def testFindFileByName(self): def testFindFileByName(self):
self.assertEquals( self.assertEqual(
'google/protobuf/unittest.proto', 'google/protobuf/unittest.proto',
symbol_database.Default().pool.FindFileByName( symbol_database.Default().pool.FindFileByName(
'google/protobuf/unittest.proto').name) 'google/protobuf/unittest.proto').name)
......
...@@ -34,7 +34,10 @@ ...@@ -34,7 +34,10 @@
__author__ = 'matthewtoia@google.com (Matt Toia)' __author__ = 'matthewtoia@google.com (Matt Toia)'
import unittest try:
import unittest2 as unittest
except ImportError:
import unittest
from google.protobuf import descriptor_pb2 from google.protobuf import descriptor_pb2
from google.protobuf.internal import factory_test1_pb2 from google.protobuf.internal import factory_test1_pb2
from google.protobuf.internal import factory_test2_pb2 from google.protobuf.internal import factory_test2_pb2
...@@ -81,9 +84,9 @@ class MessageFactoryTest(unittest.TestCase): ...@@ -81,9 +84,9 @@ class MessageFactoryTest(unittest.TestCase):
serialized = msg.SerializeToString() serialized = msg.SerializeToString()
converted = factory_test2_pb2.Factory2Message.FromString(serialized) converted = factory_test2_pb2.Factory2Message.FromString(serialized)
reserialized = converted.SerializeToString() reserialized = converted.SerializeToString()
self.assertEquals(serialized, reserialized) self.assertEqual(serialized, reserialized)
result = cls.FromString(reserialized) result = cls.FromString(reserialized)
self.assertEquals(msg, result) self.assertEqual(msg, result)
def testGetPrototype(self): def testGetPrototype(self):
db = descriptor_database.DescriptorDatabase() db = descriptor_database.DescriptorDatabase()
...@@ -93,11 +96,11 @@ class MessageFactoryTest(unittest.TestCase): ...@@ -93,11 +96,11 @@ class MessageFactoryTest(unittest.TestCase):
factory = message_factory.MessageFactory() factory = message_factory.MessageFactory()
cls = factory.GetPrototype(pool.FindMessageTypeByName( cls = factory.GetPrototype(pool.FindMessageTypeByName(
'google.protobuf.python.internal.Factory2Message')) 'google.protobuf.python.internal.Factory2Message'))
self.assertIsNot(cls, factory_test2_pb2.Factory2Message) self.assertFalse(cls is factory_test2_pb2.Factory2Message)
self._ExerciseDynamicClass(cls) self._ExerciseDynamicClass(cls)
cls2 = factory.GetPrototype(pool.FindMessageTypeByName( cls2 = factory.GetPrototype(pool.FindMessageTypeByName(
'google.protobuf.python.internal.Factory2Message')) 'google.protobuf.python.internal.Factory2Message'))
self.assertIs(cls, cls2) self.assertTrue(cls is cls2)
def testGetMessages(self): def testGetMessages(self):
# performed twice because multiple calls with the same input must be allowed # performed twice because multiple calls with the same input must be allowed
...@@ -124,8 +127,8 @@ class MessageFactoryTest(unittest.TestCase): ...@@ -124,8 +127,8 @@ class MessageFactoryTest(unittest.TestCase):
'google.protobuf.python.internal.another_field'] 'google.protobuf.python.internal.another_field']
msg1.Extensions[ext1] = 'test1' msg1.Extensions[ext1] = 'test1'
msg1.Extensions[ext2] = 'test2' msg1.Extensions[ext2] = 'test2'
self.assertEquals('test1', msg1.Extensions[ext1]) self.assertEqual('test1', msg1.Extensions[ext1])
self.assertEquals('test2', msg1.Extensions[ext2]) self.assertEqual('test2', msg1.Extensions[ext2])
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -48,9 +48,16 @@ import math ...@@ -48,9 +48,16 @@ import math
import operator import operator
import pickle import pickle
import sys import sys
import unittest
import unittest import six
if six.PY3:
long = int
try:
import unittest2 as unittest
except ImportError:
import unittest
from google.protobuf.internal import _parameterized from google.protobuf.internal import _parameterized
from google.protobuf import map_unittest_pb2 from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_pb2 from google.protobuf import unittest_pb2
...@@ -320,7 +327,7 @@ class MessageTest(unittest.TestCase): ...@@ -320,7 +327,7 @@ class MessageTest(unittest.TestCase):
def testHighPrecisionFloatPrinting(self, message_module): def testHighPrecisionFloatPrinting(self, message_module):
message = message_module.TestAllTypes() message = message_module.TestAllTypes()
message.optional_double = 0.12345678912345678 message.optional_double = 0.12345678912345678
if sys.version_info.major >= 3: if sys.version_info >= (3,):
self.assertEqual(str(message), 'optional_double: 0.12345678912345678\n') self.assertEqual(str(message), 'optional_double: 0.12345678912345678\n')
else: else:
self.assertEqual(str(message), 'optional_double: 0.123456789123\n') self.assertEqual(str(message), 'optional_double: 0.123456789123\n')
...@@ -439,7 +446,7 @@ class MessageTest(unittest.TestCase): ...@@ -439,7 +446,7 @@ class MessageTest(unittest.TestCase):
message.repeated_nested_message.sort(key=get_bb, reverse=True) message.repeated_nested_message.sort(key=get_bb, reverse=True)
self.assertEqual([k.bb for k in message.repeated_nested_message], self.assertEqual([k.bb for k in message.repeated_nested_message],
[6, 5, 4, 3, 2, 1]) [6, 5, 4, 3, 2, 1])
if sys.version_info.major >= 3: return # No cmp sorting in PY3. if sys.version_info >= (3,): return # No cmp sorting in PY3.
message.repeated_nested_message.sort(sort_function=cmp_bb) message.repeated_nested_message.sort(sort_function=cmp_bb)
self.assertEqual([k.bb for k in message.repeated_nested_message], self.assertEqual([k.bb for k in message.repeated_nested_message],
[1, 2, 3, 4, 5, 6]) [1, 2, 3, 4, 5, 6])
...@@ -458,7 +465,7 @@ class MessageTest(unittest.TestCase): ...@@ -458,7 +465,7 @@ class MessageTest(unittest.TestCase):
self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
message.repeated_int32.sort(key=abs, reverse=True) message.repeated_int32.sort(key=abs, reverse=True)
self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) self.assertEqual(list(message.repeated_int32), [-3, -2, -1])
if sys.version_info.major < 3: # No cmp sorting in PY3. if sys.version_info < (3,): # No cmp sorting in PY3.
abs_cmp = lambda a, b: cmp(abs(a), abs(b)) abs_cmp = lambda a, b: cmp(abs(a), abs(b))
message.repeated_int32.sort(sort_function=abs_cmp) message.repeated_int32.sort(sort_function=abs_cmp)
self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
...@@ -472,7 +479,7 @@ class MessageTest(unittest.TestCase): ...@@ -472,7 +479,7 @@ class MessageTest(unittest.TestCase):
self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
message.repeated_string.sort(key=len, reverse=True) message.repeated_string.sort(key=len, reverse=True)
self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
if sys.version_info.major < 3: # No cmp sorting in PY3. if sys.version_info < (3,): # No cmp sorting in PY3.
len_cmp = lambda a, b: cmp(len(a), len(b)) len_cmp = lambda a, b: cmp(len(a), len(b))
message.repeated_string.sort(sort_function=len_cmp) message.repeated_string.sort(sort_function=len_cmp)
self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
...@@ -495,7 +502,7 @@ class MessageTest(unittest.TestCase): ...@@ -495,7 +502,7 @@ class MessageTest(unittest.TestCase):
m2.repeated_nested_message.add().bb = 2 m2.repeated_nested_message.add().bb = 2
m2.repeated_nested_message.add().bb = 3 m2.repeated_nested_message.add().bb = 3
if sys.version_info.major >= 3: return # No cmp() in PY3. if sys.version_info >= (3,): return # No cmp() in PY3.
# These comparisons should not raise errors. # These comparisons should not raise errors.
_ = m1 < m2 _ = m1 < m2
...@@ -676,7 +683,7 @@ class MessageTest(unittest.TestCase): ...@@ -676,7 +683,7 @@ class MessageTest(unittest.TestCase):
in the value being converted to a Unicode string.""" in the value being converted to a Unicode string."""
m = message_module.TestAllTypes() m = message_module.TestAllTypes()
m.optional_string = str('') m.optional_string = str('')
self.assertTrue(isinstance(m.optional_string, unicode)) self.assertTrue(isinstance(m.optional_string, six.text_type))
# TODO(haberman): why are these tests Google-internal only? # TODO(haberman): why are these tests Google-internal only?
...@@ -1229,7 +1236,7 @@ class Proto3Test(unittest.TestCase): ...@@ -1229,7 +1236,7 @@ class Proto3Test(unittest.TestCase):
self.assertTrue('abc' in msg.map_string_string) self.assertTrue('abc' in msg.map_string_string)
self.assertTrue(888 in msg.map_int32_enum) self.assertTrue(888 in msg.map_int32_enum)
self.assertTrue(isinstance(msg.map_string_string['abc'], unicode)) self.assertTrue(isinstance(msg.map_string_string['abc'], six.text_type))
# Accessing an unset key still throws TypeError of the type of the key # Accessing an unset key still throws TypeError of the type of the key
# is incorrect. # is incorrect.
...@@ -1244,14 +1251,14 @@ class Proto3Test(unittest.TestCase): ...@@ -1244,14 +1251,14 @@ class Proto3Test(unittest.TestCase):
msg = map_unittest_pb2.TestMap() msg = map_unittest_pb2.TestMap()
self.assertIsNone(msg.map_int32_int32.get(5)) self.assertIsNone(msg.map_int32_int32.get(5))
self.assertEquals(10, msg.map_int32_int32.get(5, 10)) self.assertEqual(10, msg.map_int32_int32.get(5, 10))
self.assertIsNone(msg.map_int32_int32.get(5)) self.assertIsNone(msg.map_int32_int32.get(5))
msg.map_int32_int32[5] = 15 msg.map_int32_int32[5] = 15
self.assertEquals(15, msg.map_int32_int32.get(5)) self.assertEqual(15, msg.map_int32_int32.get(5))
self.assertIsNone(msg.map_int32_foreign_message.get(5)) self.assertIsNone(msg.map_int32_foreign_message.get(5))
self.assertEquals(10, msg.map_int32_foreign_message.get(5, 10)) self.assertEqual(10, msg.map_int32_foreign_message.get(5, 10))
submsg = msg.map_int32_foreign_message[5] submsg = msg.map_int32_foreign_message[5]
self.assertIs(submsg, msg.map_int32_foreign_message.get(5)) self.assertIs(submsg, msg.map_int32_foreign_message.get(5))
...@@ -1312,13 +1319,13 @@ class Proto3Test(unittest.TestCase): ...@@ -1312,13 +1319,13 @@ class Proto3Test(unittest.TestCase):
msg.map_string_string[bytes_obj] = bytes_obj msg.map_string_string[bytes_obj] = bytes_obj
(key, value) = msg.map_string_string.items()[0] (key, value) = list(msg.map_string_string.items())[0]
self.assertEqual(key, unicode_obj) self.assertEqual(key, unicode_obj)
self.assertEqual(value, unicode_obj) self.assertEqual(value, unicode_obj)
self.assertTrue(isinstance(key, unicode)) self.assertTrue(isinstance(key, six.text_type))
self.assertTrue(isinstance(value, unicode)) self.assertTrue(isinstance(value, six.text_type))
def testMessageMap(self): def testMessageMap(self):
msg = map_unittest_pb2.TestMap() msg = map_unittest_pb2.TestMap()
...@@ -1503,7 +1510,7 @@ class Proto3Test(unittest.TestCase): ...@@ -1503,7 +1510,7 @@ class Proto3Test(unittest.TestCase):
def testMapIteration(self): def testMapIteration(self):
msg = map_unittest_pb2.TestMap() msg = map_unittest_pb2.TestMap()
for k, v in msg.map_int32_int32.iteritems(): for k, v in msg.map_int32_int32.items():
# Should not be reached. # Should not be reached.
self.assertTrue(False) self.assertTrue(False)
...@@ -1513,7 +1520,7 @@ class Proto3Test(unittest.TestCase): ...@@ -1513,7 +1520,7 @@ class Proto3Test(unittest.TestCase):
self.assertEqual(3, len(msg.map_int32_int32)) self.assertEqual(3, len(msg.map_int32_int32))
matching_dict = {2: 4, 3: 6, 4: 8} matching_dict = {2: 4, 3: 6, 4: 8}
self.assertMapIterEquals(msg.map_int32_int32.iteritems(), matching_dict) self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict)
def testMapIterationClearMessage(self): def testMapIterationClearMessage(self):
# Iterator needs to work even if message and map are deleted. # Iterator needs to work even if message and map are deleted.
...@@ -1523,7 +1530,7 @@ class Proto3Test(unittest.TestCase): ...@@ -1523,7 +1530,7 @@ class Proto3Test(unittest.TestCase):
msg.map_int32_int32[3] = 6 msg.map_int32_int32[3] = 6
msg.map_int32_int32[4] = 8 msg.map_int32_int32[4] = 8
it = msg.map_int32_int32.iteritems() it = msg.map_int32_int32.items()
del msg del msg
matching_dict = {2: 4, 3: 6, 4: 8} matching_dict = {2: 4, 3: 6, 4: 8}
...@@ -1551,7 +1558,7 @@ class Proto3Test(unittest.TestCase): ...@@ -1551,7 +1558,7 @@ class Proto3Test(unittest.TestCase):
msg.ClearField('map_int32_int32') msg.ClearField('map_int32_int32')
matching_dict = {2: 4, 3: 6, 4: 8} matching_dict = {2: 4, 3: 6, 4: 8}
self.assertMapIterEquals(map.iteritems(), matching_dict) self.assertMapIterEquals(map.items(), matching_dict)
def testMapIterValidAfterFieldCleared(self): def testMapIterValidAfterFieldCleared(self):
# Map iterator needs to work even if field is cleared. # Map iterator needs to work even if field is cleared.
...@@ -1563,7 +1570,7 @@ class Proto3Test(unittest.TestCase): ...@@ -1563,7 +1570,7 @@ class Proto3Test(unittest.TestCase):
msg.map_int32_int32[3] = 6 msg.map_int32_int32[3] = 6
msg.map_int32_int32[4] = 8 msg.map_int32_int32[4] = 8
it = msg.map_int32_int32.iteritems() it = msg.map_int32_int32.items()
msg.ClearField('map_int32_int32') msg.ClearField('map_int32_int32')
matching_dict = {2: 4, 3: 6, 4: 8} matching_dict = {2: 4, 3: 6, 4: 8}
......
...@@ -32,8 +32,15 @@ ...@@ -32,8 +32,15 @@
"""Tests for google.protobuf.proto_builder.""" """Tests for google.protobuf.proto_builder."""
import collections try:
import unittest from collections import OrderedDict
except ImportError:
from ordereddict import OrderedDict #PY26
try:
import unittest2 as unittest #PY26
except ImportError:
import unittest
from google.protobuf import descriptor_pb2 from google.protobuf import descriptor_pb2
from google.protobuf import descriptor_pool from google.protobuf import descriptor_pool
...@@ -44,7 +51,7 @@ from google.protobuf import text_format ...@@ -44,7 +51,7 @@ from google.protobuf import text_format
class ProtoBuilderTest(unittest.TestCase): class ProtoBuilderTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.ordered_fields = collections.OrderedDict([ self.ordered_fields = OrderedDict([
('foo', descriptor_pb2.FieldDescriptorProto.TYPE_INT64), ('foo', descriptor_pb2.FieldDescriptorProto.TYPE_INT64),
('bar', descriptor_pb2.FieldDescriptorProto.TYPE_STRING), ('bar', descriptor_pb2.FieldDescriptorProto.TYPE_STRING),
]) ])
......
...@@ -28,8 +28,6 @@ ...@@ -28,8 +28,6 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Keep it Python2.5 compatible for GAE.
#
# Copyright 2007 Google Inc. All Rights Reserved. # Copyright 2007 Google Inc. All Rights Reserved.
# #
# This code is meant to work on Python 2.4 and above only. # This code is meant to work on Python 2.4 and above only.
...@@ -54,21 +52,14 @@ this file*. ...@@ -54,21 +52,14 @@ this file*.
__author__ = 'robinson@google.com (Will Robinson)' __author__ = 'robinson@google.com (Will Robinson)'
from io import BytesIO
import sys import sys
if sys.version_info[0] < 3:
try:
from cStringIO import StringIO as BytesIO
except ImportError:
from StringIO import StringIO as BytesIO
import copy_reg as copyreg
_basestring = basestring
else:
from io import BytesIO
import copyreg
_basestring = str
import struct import struct
import weakref import weakref
import six
import six.moves.copyreg as copyreg
# We use "as" to avoid name collisions with variables. # We use "as" to avoid name collisions with variables.
from google.protobuf.internal import containers from google.protobuf.internal import containers
from google.protobuf.internal import decoder from google.protobuf.internal import decoder
...@@ -281,7 +272,7 @@ def _AttachFieldHelpers(cls, field_descriptor): ...@@ -281,7 +272,7 @@ def _AttachFieldHelpers(cls, field_descriptor):
def _AddClassAttributesForNestedExtensions(descriptor, dictionary): def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
extension_dict = descriptor.extensions_by_name extension_dict = descriptor.extensions_by_name
for extension_name, extension_field in extension_dict.iteritems(): for extension_name, extension_field in extension_dict.items():
assert extension_name not in dictionary assert extension_name not in dictionary
dictionary[extension_name] = extension_field dictionary[extension_name] = extension_field
...@@ -383,7 +374,7 @@ def _ReraiseTypeErrorWithFieldName(message_name, field_name): ...@@ -383,7 +374,7 @@ def _ReraiseTypeErrorWithFieldName(message_name, field_name):
exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name)) exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name))
# re-raise possibly-amended exception with original traceback: # re-raise possibly-amended exception with original traceback:
raise type(exc)(exc, sys.exc_info()[2]) six.reraise(type(exc), exc, sys.exc_info()[2])
def _AddInitMethod(message_descriptor, cls): def _AddInitMethod(message_descriptor, cls):
...@@ -396,7 +387,7 @@ def _AddInitMethod(message_descriptor, cls): ...@@ -396,7 +387,7 @@ def _AddInitMethod(message_descriptor, cls):
enum_type with the same name. If the value is not a string, it's enum_type with the same name. If the value is not a string, it's
returned as-is. (No conversion or bounds-checking is done.) returned as-is. (No conversion or bounds-checking is done.)
""" """
if isinstance(value, _basestring): if isinstance(value, six.string_types):
try: try:
return enum_type.values_by_name[value].number return enum_type.values_by_name[value].number
except KeyError: except KeyError:
...@@ -418,7 +409,7 @@ def _AddInitMethod(message_descriptor, cls): ...@@ -418,7 +409,7 @@ def _AddInitMethod(message_descriptor, cls):
self._is_present_in_parent = False self._is_present_in_parent = False
self._listener = message_listener_mod.NullMessageListener() self._listener = message_listener_mod.NullMessageListener()
self._listener_for_children = _Listener(self) self._listener_for_children = _Listener(self)
for field_name, field_value in kwargs.iteritems(): for field_name, field_value in kwargs.items():
field = _GetFieldByName(message_descriptor, field_name) field = _GetFieldByName(message_descriptor, field_name)
if field is None: if field is None:
raise TypeError("%s() got an unexpected keyword argument '%s'" % raise TypeError("%s() got an unexpected keyword argument '%s'" %
...@@ -675,7 +666,7 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls): ...@@ -675,7 +666,7 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls):
def _AddPropertiesForExtensions(descriptor, cls): def _AddPropertiesForExtensions(descriptor, cls):
"""Adds properties for all fields in this protocol message type.""" """Adds properties for all fields in this protocol message type."""
extension_dict = descriptor.extensions_by_name extension_dict = descriptor.extensions_by_name
for extension_name, extension_field in extension_dict.iteritems(): for extension_name, extension_field in extension_dict.items():
constant_name = extension_name.upper() + "_FIELD_NUMBER" constant_name = extension_name.upper() + "_FIELD_NUMBER"
setattr(cls, constant_name, extension_field.number) setattr(cls, constant_name, extension_field.number)
...@@ -730,7 +721,7 @@ def _AddListFieldsMethod(message_descriptor, cls): ...@@ -730,7 +721,7 @@ def _AddListFieldsMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods().""" """Helper for _AddMessageMethods()."""
def ListFields(self): def ListFields(self):
all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)] all_fields = [item for item in self._fields.items() if _IsPresent(item)]
all_fields.sort(key = lambda item: item[0].number) all_fields.sort(key = lambda item: item[0].number)
return all_fields return all_fields
...@@ -1128,7 +1119,7 @@ def _AddIsInitializedMethod(message_descriptor, cls): ...@@ -1128,7 +1119,7 @@ def _AddIsInitializedMethod(message_descriptor, cls):
# ScalarMaps can't have any initialization errors. # ScalarMaps can't have any initialization errors.
pass pass
elif field.label == _FieldDescriptor.LABEL_REPEATED: elif field.label == _FieldDescriptor.LABEL_REPEATED:
for i in xrange(len(value)): for i in range(len(value)):
element = value[i] element = value[i]
prefix = "%s[%d]." % (name, i) prefix = "%s[%d]." % (name, i)
sub_errors = element.FindInitializationErrors() sub_errors = element.FindInitializationErrors()
...@@ -1158,7 +1149,7 @@ def _AddMergeFromMethod(cls): ...@@ -1158,7 +1149,7 @@ def _AddMergeFromMethod(cls):
fields = self._fields fields = self._fields
for field, value in msg._fields.iteritems(): for field, value in msg._fields.items():
if field.label == LABEL_REPEATED: if field.label == LABEL_REPEATED:
field_value = fields.get(field) field_value = fields.get(field)
if field_value is None: if field_value is None:
......
...@@ -39,8 +39,13 @@ import copy ...@@ -39,8 +39,13 @@ import copy
import gc import gc
import operator import operator
import struct import struct
try:
import unittest2 as unittest
except ImportError:
import unittest
import six
import unittest
from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2 from google.protobuf import unittest_pb2
...@@ -128,10 +133,10 @@ class ReflectionTest(unittest.TestCase): ...@@ -128,10 +133,10 @@ class ReflectionTest(unittest.TestCase):
repeated_bool=[True, False, False], repeated_bool=[True, False, False],
repeated_string=["optional_string"]) repeated_string=["optional_string"])
self.assertEquals([1, 2, 3, 4], list(proto.repeated_int32)) self.assertEqual([1, 2, 3, 4], list(proto.repeated_int32))
self.assertEquals([1.23, 54.321], list(proto.repeated_double)) self.assertEqual([1.23, 54.321], list(proto.repeated_double))
self.assertEquals([True, False, False], list(proto.repeated_bool)) self.assertEqual([True, False, False], list(proto.repeated_bool))
self.assertEquals(["optional_string"], list(proto.repeated_string)) self.assertEqual(["optional_string"], list(proto.repeated_string))
def testRepeatedCompositeConstructor(self): def testRepeatedCompositeConstructor(self):
# Constructor with only repeated composite types should succeed. # Constructor with only repeated composite types should succeed.
...@@ -150,18 +155,18 @@ class ReflectionTest(unittest.TestCase): ...@@ -150,18 +155,18 @@ class ReflectionTest(unittest.TestCase):
unittest_pb2.TestAllTypes.RepeatedGroup(a=1), unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
unittest_pb2.TestAllTypes.RepeatedGroup(a=2)]) unittest_pb2.TestAllTypes.RepeatedGroup(a=2)])
self.assertEquals( self.assertEqual(
[unittest_pb2.TestAllTypes.NestedMessage( [unittest_pb2.TestAllTypes.NestedMessage(
bb=unittest_pb2.TestAllTypes.FOO), bb=unittest_pb2.TestAllTypes.FOO),
unittest_pb2.TestAllTypes.NestedMessage( unittest_pb2.TestAllTypes.NestedMessage(
bb=unittest_pb2.TestAllTypes.BAR)], bb=unittest_pb2.TestAllTypes.BAR)],
list(proto.repeated_nested_message)) list(proto.repeated_nested_message))
self.assertEquals( self.assertEqual(
[unittest_pb2.ForeignMessage(c=-43), [unittest_pb2.ForeignMessage(c=-43),
unittest_pb2.ForeignMessage(c=45324), unittest_pb2.ForeignMessage(c=45324),
unittest_pb2.ForeignMessage(c=12)], unittest_pb2.ForeignMessage(c=12)],
list(proto.repeated_foreign_message)) list(proto.repeated_foreign_message))
self.assertEquals( self.assertEqual(
[unittest_pb2.TestAllTypes.RepeatedGroup(), [unittest_pb2.TestAllTypes.RepeatedGroup(),
unittest_pb2.TestAllTypes.RepeatedGroup(a=1), unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
unittest_pb2.TestAllTypes.RepeatedGroup(a=2)], unittest_pb2.TestAllTypes.RepeatedGroup(a=2)],
...@@ -186,15 +191,15 @@ class ReflectionTest(unittest.TestCase): ...@@ -186,15 +191,15 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(24, proto.optional_int32) self.assertEqual(24, proto.optional_int32)
self.assertEqual('optional_string', proto.optional_string) self.assertEqual('optional_string', proto.optional_string)
self.assertEquals([1.23, 54.321], list(proto.repeated_double)) self.assertEqual([1.23, 54.321], list(proto.repeated_double))
self.assertEquals([True, False, False], list(proto.repeated_bool)) self.assertEqual([True, False, False], list(proto.repeated_bool))
self.assertEquals( self.assertEqual(
[unittest_pb2.TestAllTypes.NestedMessage( [unittest_pb2.TestAllTypes.NestedMessage(
bb=unittest_pb2.TestAllTypes.FOO), bb=unittest_pb2.TestAllTypes.FOO),
unittest_pb2.TestAllTypes.NestedMessage( unittest_pb2.TestAllTypes.NestedMessage(
bb=unittest_pb2.TestAllTypes.BAR)], bb=unittest_pb2.TestAllTypes.BAR)],
list(proto.repeated_nested_message)) list(proto.repeated_nested_message))
self.assertEquals( self.assertEqual(
[unittest_pb2.ForeignMessage(c=-43), [unittest_pb2.ForeignMessage(c=-43),
unittest_pb2.ForeignMessage(c=45324), unittest_pb2.ForeignMessage(c=45324),
unittest_pb2.ForeignMessage(c=12)], unittest_pb2.ForeignMessage(c=12)],
...@@ -222,18 +227,18 @@ class ReflectionTest(unittest.TestCase): ...@@ -222,18 +227,18 @@ class ReflectionTest(unittest.TestCase):
def testConstructorInvalidatesCachedByteSize(self): def testConstructorInvalidatesCachedByteSize(self):
message = unittest_pb2.TestAllTypes(optional_int32 = 12) message = unittest_pb2.TestAllTypes(optional_int32 = 12)
self.assertEquals(2, message.ByteSize()) self.assertEqual(2, message.ByteSize())
message = unittest_pb2.TestAllTypes( message = unittest_pb2.TestAllTypes(
optional_nested_message = unittest_pb2.TestAllTypes.NestedMessage()) optional_nested_message = unittest_pb2.TestAllTypes.NestedMessage())
self.assertEquals(3, message.ByteSize()) self.assertEqual(3, message.ByteSize())
message = unittest_pb2.TestAllTypes(repeated_int32 = [12]) message = unittest_pb2.TestAllTypes(repeated_int32 = [12])
self.assertEquals(3, message.ByteSize()) self.assertEqual(3, message.ByteSize())
message = unittest_pb2.TestAllTypes( message = unittest_pb2.TestAllTypes(
repeated_nested_message = [unittest_pb2.TestAllTypes.NestedMessage()]) repeated_nested_message = [unittest_pb2.TestAllTypes.NestedMessage()])
self.assertEquals(3, message.ByteSize()) self.assertEqual(3, message.ByteSize())
def testSimpleHasBits(self): def testSimpleHasBits(self):
# Test a scalar. # Test a scalar.
...@@ -467,7 +472,7 @@ class ReflectionTest(unittest.TestCase): ...@@ -467,7 +472,7 @@ class ReflectionTest(unittest.TestCase):
proto.repeated_string.extend(['foo', 'bar']) proto.repeated_string.extend(['foo', 'bar'])
proto.repeated_string.extend([]) proto.repeated_string.extend([])
proto.repeated_string.append('baz') proto.repeated_string.append('baz')
proto.repeated_string.extend(str(x) for x in xrange(2)) proto.repeated_string.extend(str(x) for x in range(2))
proto.optional_int32 = 21 proto.optional_int32 = 21
proto.repeated_bool # Access but don't set anything; should not be listed. proto.repeated_bool # Access but don't set anything; should not be listed.
self.assertEqual( self.assertEqual(
...@@ -617,6 +622,10 @@ class ReflectionTest(unittest.TestCase): ...@@ -617,6 +622,10 @@ class ReflectionTest(unittest.TestCase):
TestGetAndDeserialize('optional_int32', 1, int) TestGetAndDeserialize('optional_int32', 1, int)
TestGetAndDeserialize('optional_int32', 1 << 30, int) TestGetAndDeserialize('optional_int32', 1 << 30, int)
TestGetAndDeserialize('optional_uint32', 1 << 30, int) TestGetAndDeserialize('optional_uint32', 1 << 30, int)
try:
integer_64 = long
except NameError: # Python3
integer_64 = int
if struct.calcsize('L') == 4: if struct.calcsize('L') == 4:
# Python only has signed ints, so 32-bit python can't fit an uint32 # Python only has signed ints, so 32-bit python can't fit an uint32
# in an int. # in an int.
...@@ -624,10 +633,10 @@ class ReflectionTest(unittest.TestCase): ...@@ -624,10 +633,10 @@ class ReflectionTest(unittest.TestCase):
else: else:
# 64-bit python can fit uint32 inside an int # 64-bit python can fit uint32 inside an int
TestGetAndDeserialize('optional_uint32', 1 << 31, int) TestGetAndDeserialize('optional_uint32', 1 << 31, int)
TestGetAndDeserialize('optional_int64', 1 << 30, long) TestGetAndDeserialize('optional_int64', 1 << 30, integer_64)
TestGetAndDeserialize('optional_int64', 1 << 60, long) TestGetAndDeserialize('optional_int64', 1 << 60, integer_64)
TestGetAndDeserialize('optional_uint64', 1 << 30, long) TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64)
TestGetAndDeserialize('optional_uint64', 1 << 60, long) TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64)
def testSingleScalarBoundsChecking(self): def testSingleScalarBoundsChecking(self):
def TestMinAndMaxIntegers(field_name, expected_min, expected_max): def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
...@@ -753,18 +762,18 @@ class ReflectionTest(unittest.TestCase): ...@@ -753,18 +762,18 @@ class ReflectionTest(unittest.TestCase):
def testEnum_KeysAndValues(self): def testEnum_KeysAndValues(self):
self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'], self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'],
unittest_pb2.ForeignEnum.keys()) list(unittest_pb2.ForeignEnum.keys()))
self.assertEqual([4, 5, 6], self.assertEqual([4, 5, 6],
unittest_pb2.ForeignEnum.values()) list(unittest_pb2.ForeignEnum.values()))
self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5), self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5),
('FOREIGN_BAZ', 6)], ('FOREIGN_BAZ', 6)],
unittest_pb2.ForeignEnum.items()) list(unittest_pb2.ForeignEnum.items()))
proto = unittest_pb2.TestAllTypes() proto = unittest_pb2.TestAllTypes()
self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], proto.NestedEnum.keys()) self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], list(proto.NestedEnum.keys()))
self.assertEqual([1, 2, 3, -1], proto.NestedEnum.values()) self.assertEqual([1, 2, 3, -1], list(proto.NestedEnum.values()))
self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)], self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)],
proto.NestedEnum.items()) list(proto.NestedEnum.items()))
def testRepeatedScalars(self): def testRepeatedScalars(self):
proto = unittest_pb2.TestAllTypes() proto = unittest_pb2.TestAllTypes()
...@@ -803,7 +812,7 @@ class ReflectionTest(unittest.TestCase): ...@@ -803,7 +812,7 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:]) self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:])
# Test slice assignment with an iterator # Test slice assignment with an iterator
proto.repeated_int32[1:4] = (i for i in xrange(3)) proto.repeated_int32[1:4] = (i for i in range(3))
self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32) self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32)
# Test slice assignment. # Test slice assignment.
...@@ -1006,9 +1015,8 @@ class ReflectionTest(unittest.TestCase): ...@@ -1006,9 +1015,8 @@ class ReflectionTest(unittest.TestCase):
containing_type=None, nested_types=[], enum_types=[], containing_type=None, nested_types=[], enum_types=[],
fields=[foo_field_descriptor], extensions=[], fields=[foo_field_descriptor], extensions=[],
options=descriptor_pb2.MessageOptions()) options=descriptor_pb2.MessageOptions())
class MyProtoClass(message.Message): class MyProtoClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)):
DESCRIPTOR = mydescriptor DESCRIPTOR = mydescriptor
__metaclass__ = reflection.GeneratedProtocolMessageType
myproto_instance = MyProtoClass() myproto_instance = MyProtoClass()
self.assertEqual(0, myproto_instance.foo_field) self.assertEqual(0, myproto_instance.foo_field)
self.assertTrue(not myproto_instance.HasField('foo_field')) self.assertTrue(not myproto_instance.HasField('foo_field'))
...@@ -1048,14 +1056,13 @@ class ReflectionTest(unittest.TestCase): ...@@ -1048,14 +1056,13 @@ class ReflectionTest(unittest.TestCase):
new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED
desc = descriptor.MakeDescriptor(desc_proto) desc = descriptor.MakeDescriptor(desc_proto)
self.assertTrue(desc.fields_by_name.has_key('name')) self.assertTrue('name' in desc.fields_by_name)
self.assertTrue(desc.fields_by_name.has_key('year')) self.assertTrue('year' in desc.fields_by_name)
self.assertTrue(desc.fields_by_name.has_key('automatic')) self.assertTrue('automatic' in desc.fields_by_name)
self.assertTrue(desc.fields_by_name.has_key('price')) self.assertTrue('price' in desc.fields_by_name)
self.assertTrue(desc.fields_by_name.has_key('owners')) self.assertTrue('owners' in desc.fields_by_name)
class CarMessage(message.Message): class CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)):
__metaclass__ = reflection.GeneratedProtocolMessageType
DESCRIPTOR = desc DESCRIPTOR = desc
prius = CarMessage() prius = CarMessage()
...@@ -1173,7 +1180,7 @@ class ReflectionTest(unittest.TestCase): ...@@ -1173,7 +1180,7 @@ class ReflectionTest(unittest.TestCase):
self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number) self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number)
# Make sure extensions haven't been registered into types that shouldn't # Make sure extensions haven't been registered into types that shouldn't
# have any. # have any.
self.assertEquals(0, len(unittest_pb2.TestAllTypes._extensions_by_name)) self.assertEqual(0, len(unittest_pb2.TestAllTypes._extensions_by_name))
# If message A directly contains message B, and # If message A directly contains message B, and
# a.HasField('b') is currently False, then mutating any # a.HasField('b') is currently False, then mutating any
...@@ -1497,18 +1504,18 @@ class ReflectionTest(unittest.TestCase): ...@@ -1497,18 +1504,18 @@ class ReflectionTest(unittest.TestCase):
test_util.SetAllNonLazyFields(proto) test_util.SetAllNonLazyFields(proto)
# Clear the message. # Clear the message.
proto.Clear() proto.Clear()
self.assertEquals(proto.ByteSize(), 0) self.assertEqual(proto.ByteSize(), 0)
empty_proto = unittest_pb2.TestAllTypes() empty_proto = unittest_pb2.TestAllTypes()
self.assertEquals(proto, empty_proto) self.assertEqual(proto, empty_proto)
# Test if extensions which were set are cleared. # Test if extensions which were set are cleared.
proto = unittest_pb2.TestAllExtensions() proto = unittest_pb2.TestAllExtensions()
test_util.SetAllExtensions(proto) test_util.SetAllExtensions(proto)
# Clear the message. # Clear the message.
proto.Clear() proto.Clear()
self.assertEquals(proto.ByteSize(), 0) self.assertEqual(proto.ByteSize(), 0)
empty_proto = unittest_pb2.TestAllExtensions() empty_proto = unittest_pb2.TestAllExtensions()
self.assertEquals(proto, empty_proto) self.assertEqual(proto, empty_proto)
def testDisconnectingBeforeClear(self): def testDisconnectingBeforeClear(self):
proto = unittest_pb2.TestAllTypes() proto = unittest_pb2.TestAllTypes()
...@@ -1661,14 +1668,14 @@ class ReflectionTest(unittest.TestCase): ...@@ -1661,14 +1668,14 @@ class ReflectionTest(unittest.TestCase):
setattr, proto, 'optional_bytes', u'unicode object') setattr, proto, 'optional_bytes', u'unicode object')
# Check that the default value is of python's 'unicode' type. # Check that the default value is of python's 'unicode' type.
self.assertEqual(type(proto.optional_string), unicode) self.assertEqual(type(proto.optional_string), six.text_type)
proto.optional_string = unicode('Testing') proto.optional_string = six.text_type('Testing')
self.assertEqual(proto.optional_string, str('Testing')) self.assertEqual(proto.optional_string, str('Testing'))
# Assign a value of type 'str' which can be encoded in UTF-8. # Assign a value of type 'str' which can be encoded in UTF-8.
proto.optional_string = str('Testing') proto.optional_string = str('Testing')
self.assertEqual(proto.optional_string, unicode('Testing')) self.assertEqual(proto.optional_string, six.text_type('Testing'))
# Try to assign a 'bytes' object which contains non-UTF-8. # Try to assign a 'bytes' object which contains non-UTF-8.
self.assertRaises(ValueError, self.assertRaises(ValueError,
...@@ -1715,7 +1722,7 @@ class ReflectionTest(unittest.TestCase): ...@@ -1715,7 +1722,7 @@ class ReflectionTest(unittest.TestCase):
bytes_read = message2.MergeFromString(raw.item[0].message) bytes_read = message2.MergeFromString(raw.item[0].message)
self.assertEqual(len(raw.item[0].message), bytes_read) self.assertEqual(len(raw.item[0].message), bytes_read)
self.assertEqual(type(message2.str), unicode) self.assertEqual(type(message2.str), six.text_type)
self.assertEqual(message2.str, test_utf8) self.assertEqual(message2.str, test_utf8)
# The pure Python API throws an exception on MergeFromString(), # The pure Python API throws an exception on MergeFromString(),
...@@ -1739,7 +1746,7 @@ class ReflectionTest(unittest.TestCase): ...@@ -1739,7 +1746,7 @@ class ReflectionTest(unittest.TestCase):
def testBytesInTextFormat(self): def testBytesInTextFormat(self):
proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff') proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff')
self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n', self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n',
unicode(proto)) six.text_type(proto))
def testEmptyNestedMessage(self): def testEmptyNestedMessage(self):
proto = unittest_pb2.TestAllTypes() proto = unittest_pb2.TestAllTypes()
...@@ -1792,17 +1799,6 @@ class ReflectionTest(unittest.TestCase): ...@@ -1792,17 +1799,6 @@ class ReflectionTest(unittest.TestCase):
# Just check the default value. # Just check the default value.
self.assertEqual(57, msg.inner.value) self.assertEqual(57, msg.inner.value)
@unittest.skipIf(
api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
'CPPv2-specific test')
def testBadArguments(self):
# Some of these assertions used to segfault.
from google.protobuf.pyext import _message
self.assertRaises(TypeError, _message.default_pool.FindFieldByName, 3)
self.assertRaises(TypeError, _message.default_pool.FindExtensionByName, 42)
self.assertRaises(TypeError,
unittest_pb2.TestAllTypes().__getattribute__, 42)
# Since we had so many tests for protocol buffer equality, we broke these out # Since we had so many tests for protocol buffer equality, we broke these out
# into separate TestCase classes. # into separate TestCase classes.
...@@ -2318,7 +2314,7 @@ class SerializationTest(unittest.TestCase): ...@@ -2318,7 +2314,7 @@ class SerializationTest(unittest.TestCase):
test_util.SetAllFields(first_proto) test_util.SetAllFields(first_proto)
serialized = first_proto.SerializeToString() serialized = first_proto.SerializeToString()
for truncation_point in xrange(len(serialized) + 1): for truncation_point in range(len(serialized) + 1):
try: try:
second_proto = unittest_pb2.TestAllTypes() second_proto = unittest_pb2.TestAllTypes()
unknown_fields = unittest_pb2.TestEmptyMessage() unknown_fields = unittest_pb2.TestEmptyMessage()
...@@ -2478,7 +2474,7 @@ class SerializationTest(unittest.TestCase): ...@@ -2478,7 +2474,7 @@ class SerializationTest(unittest.TestCase):
# Check that the message parsed well. # Check that the message parsed well.
extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
extension1 = extension_message1.message_set_extension extension1 = extension_message1.message_set_extension
self.assertEquals(12345, proto.Extensions[extension1].i) self.assertEqual(12345, proto.Extensions[extension1].i)
def testUnknownFields(self): def testUnknownFields(self):
proto = unittest_pb2.TestAllTypes() proto = unittest_pb2.TestAllTypes()
...@@ -2919,8 +2915,7 @@ class ClassAPITest(unittest.TestCase): ...@@ -2919,8 +2915,7 @@ class ClassAPITest(unittest.TestCase):
msg_descriptor = descriptor.MakeDescriptor( msg_descriptor = descriptor.MakeDescriptor(
file_descriptor.message_type[0]) file_descriptor.message_type[0])
class MessageClass(message.Message): class MessageClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)):
__metaclass__ = reflection.GeneratedProtocolMessageType
DESCRIPTOR = msg_descriptor DESCRIPTOR = msg_descriptor
msg = MessageClass() msg = MessageClass()
msg_str = ( msg_str = (
......
...@@ -34,7 +34,10 @@ ...@@ -34,7 +34,10 @@
__author__ = 'petar@google.com (Petar Petrov)' __author__ = 'petar@google.com (Petar Petrov)'
import unittest try:
import unittest2 as unittest
except ImportError:
import unittest
from google.protobuf import unittest_pb2 from google.protobuf import unittest_pb2
from google.protobuf import service_reflection from google.protobuf import service_reflection
from google.protobuf import service from google.protobuf import service
......
...@@ -32,7 +32,10 @@ ...@@ -32,7 +32,10 @@
"""Tests for google.protobuf.symbol_database.""" """Tests for google.protobuf.symbol_database."""
import unittest try:
import unittest2 as unittest
except ImportError:
import unittest
from google.protobuf import unittest_pb2 from google.protobuf import unittest_pb2
from google.protobuf import symbol_database from google.protobuf import symbol_database
...@@ -64,53 +67,53 @@ class SymbolDatabaseTest(unittest.TestCase): ...@@ -64,53 +67,53 @@ class SymbolDatabaseTest(unittest.TestCase):
messages['protobuf_unittest.TestAllTypes']) messages['protobuf_unittest.TestAllTypes'])
def testGetSymbol(self): def testGetSymbol(self):
self.assertEquals( self.assertEqual(
unittest_pb2.TestAllTypes, self._Database().GetSymbol( unittest_pb2.TestAllTypes, self._Database().GetSymbol(
'protobuf_unittest.TestAllTypes')) 'protobuf_unittest.TestAllTypes'))
self.assertEquals( self.assertEqual(
unittest_pb2.TestAllTypes.NestedMessage, self._Database().GetSymbol( unittest_pb2.TestAllTypes.NestedMessage, self._Database().GetSymbol(
'protobuf_unittest.TestAllTypes.NestedMessage')) 'protobuf_unittest.TestAllTypes.NestedMessage'))
self.assertEquals( self.assertEqual(
unittest_pb2.TestAllTypes.OptionalGroup, self._Database().GetSymbol( unittest_pb2.TestAllTypes.OptionalGroup, self._Database().GetSymbol(
'protobuf_unittest.TestAllTypes.OptionalGroup')) 'protobuf_unittest.TestAllTypes.OptionalGroup'))
self.assertEquals( self.assertEqual(
unittest_pb2.TestAllTypes.RepeatedGroup, self._Database().GetSymbol( unittest_pb2.TestAllTypes.RepeatedGroup, self._Database().GetSymbol(
'protobuf_unittest.TestAllTypes.RepeatedGroup')) 'protobuf_unittest.TestAllTypes.RepeatedGroup'))
def testEnums(self): def testEnums(self):
# Check registration of types in the pool. # Check registration of types in the pool.
self.assertEquals( self.assertEqual(
'protobuf_unittest.ForeignEnum', 'protobuf_unittest.ForeignEnum',
self._Database().pool.FindEnumTypeByName( self._Database().pool.FindEnumTypeByName(
'protobuf_unittest.ForeignEnum').full_name) 'protobuf_unittest.ForeignEnum').full_name)
self.assertEquals( self.assertEqual(
'protobuf_unittest.TestAllTypes.NestedEnum', 'protobuf_unittest.TestAllTypes.NestedEnum',
self._Database().pool.FindEnumTypeByName( self._Database().pool.FindEnumTypeByName(
'protobuf_unittest.TestAllTypes.NestedEnum').full_name) 'protobuf_unittest.TestAllTypes.NestedEnum').full_name)
def testFindMessageTypeByName(self): def testFindMessageTypeByName(self):
self.assertEquals( self.assertEqual(
'protobuf_unittest.TestAllTypes', 'protobuf_unittest.TestAllTypes',
self._Database().pool.FindMessageTypeByName( self._Database().pool.FindMessageTypeByName(
'protobuf_unittest.TestAllTypes').full_name) 'protobuf_unittest.TestAllTypes').full_name)
self.assertEquals( self.assertEqual(
'protobuf_unittest.TestAllTypes.NestedMessage', 'protobuf_unittest.TestAllTypes.NestedMessage',
self._Database().pool.FindMessageTypeByName( self._Database().pool.FindMessageTypeByName(
'protobuf_unittest.TestAllTypes.NestedMessage').full_name) 'protobuf_unittest.TestAllTypes.NestedMessage').full_name)
def testFindFindContainingSymbol(self): def testFindFindContainingSymbol(self):
# Lookup based on either enum or message. # Lookup based on either enum or message.
self.assertEquals( self.assertEqual(
'google/protobuf/unittest.proto', 'google/protobuf/unittest.proto',
self._Database().pool.FindFileContainingSymbol( self._Database().pool.FindFileContainingSymbol(
'protobuf_unittest.TestAllTypes.NestedEnum').name) 'protobuf_unittest.TestAllTypes.NestedEnum').name)
self.assertEquals( self.assertEqual(
'google/protobuf/unittest.proto', 'google/protobuf/unittest.proto',
self._Database().pool.FindFileContainingSymbol( self._Database().pool.FindFileContainingSymbol(
'protobuf_unittest.TestAllTypes').name) 'protobuf_unittest.TestAllTypes').name)
def testFindFileByName(self): def testFindFileByName(self):
self.assertEquals( self.assertEqual(
'google/protobuf/unittest.proto', 'google/protobuf/unittest.proto',
self._Database().pool.FindFileByName( self._Database().pool.FindFileByName(
'google/protobuf/unittest.proto').name) 'google/protobuf/unittest.proto').name)
......
...@@ -32,7 +32,10 @@ ...@@ -32,7 +32,10 @@
"""Tests for google.protobuf.text_encoding.""" """Tests for google.protobuf.text_encoding."""
import unittest try:
import unittest2 as unittest
except ImportError:
import unittest
from google.protobuf import text_encoding from google.protobuf import text_encoding
TEST_VALUES = [ TEST_VALUES = [
...@@ -53,15 +56,15 @@ TEST_VALUES = [ ...@@ -53,15 +56,15 @@ TEST_VALUES = [
class TextEncodingTestCase(unittest.TestCase): class TextEncodingTestCase(unittest.TestCase):
def testCEscape(self): def testCEscape(self):
for escaped, escaped_utf8, unescaped in TEST_VALUES: for escaped, escaped_utf8, unescaped in TEST_VALUES:
self.assertEquals(escaped, self.assertEqual(escaped,
text_encoding.CEscape(unescaped, as_utf8=False)) text_encoding.CEscape(unescaped, as_utf8=False))
self.assertEquals(escaped_utf8, self.assertEqual(escaped_utf8,
text_encoding.CEscape(unescaped, as_utf8=True)) text_encoding.CEscape(unescaped, as_utf8=True))
def testCUnescape(self): def testCUnescape(self):
for escaped, escaped_utf8, unescaped in TEST_VALUES: for escaped, escaped_utf8, unescaped in TEST_VALUES:
self.assertEquals(unescaped, text_encoding.CUnescape(escaped)) self.assertEqual(unescaped, text_encoding.CUnescape(escaped))
self.assertEquals(unescaped, text_encoding.CUnescape(escaped_utf8)) self.assertEqual(unescaped, text_encoding.CUnescape(escaped_utf8))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -35,9 +35,13 @@ ...@@ -35,9 +35,13 @@
__author__ = 'kenton@google.com (Kenton Varda)' __author__ = 'kenton@google.com (Kenton Varda)'
import re import re
import unittest
import unittest import six
try:
import unittest2 as unittest
except ImportError:
import unittest
from google.protobuf.internal import _parameterized from google.protobuf.internal import _parameterized
from google.protobuf import map_unittest_pb2 from google.protobuf import map_unittest_pb2
...@@ -61,7 +65,7 @@ class TextFormatBase(unittest.TestCase): ...@@ -61,7 +65,7 @@ class TextFormatBase(unittest.TestCase):
self.assertMultiLineEqual(text, ''.join(golden_lines)) self.assertMultiLineEqual(text, ''.join(golden_lines))
def CompareToGoldenText(self, text, golden_text): def CompareToGoldenText(self, text, golden_text):
self.assertMultiLineEqual(text, golden_text) self.assertEqual(text, golden_text)
def RemoveRedundantZeros(self, text): def RemoveRedundantZeros(self, text):
# Some platforms print 1e+5 as 1e+005. This is fine, but we need to remove # Some platforms print 1e+5 as 1e+005. This is fine, but we need to remove
...@@ -100,7 +104,7 @@ class TextFormatTest(TextFormatBase): ...@@ -100,7 +104,7 @@ class TextFormatTest(TextFormatBase):
'repeated_string: "\\303\\274\\352\\234\\237"\n') 'repeated_string: "\\303\\274\\352\\234\\237"\n')
def testPrintExoticUnicodeSubclass(self, message_module): def testPrintExoticUnicodeSubclass(self, message_module):
class UnicodeSub(unicode): class UnicodeSub(six.text_type):
pass pass
message = message_module.TestAllTypes() message = message_module.TestAllTypes()
message.repeated_string.append(UnicodeSub(u'\u00fc\ua71f')) message.repeated_string.append(UnicodeSub(u'\u00fc\ua71f'))
...@@ -172,7 +176,7 @@ class TextFormatTest(TextFormatBase): ...@@ -172,7 +176,7 @@ class TextFormatTest(TextFormatBase):
parsed_message = message_module.TestAllTypes() parsed_message = message_module.TestAllTypes()
r = text_format.Parse(wire_text, parsed_message) r = text_format.Parse(wire_text, parsed_message)
self.assertIs(r, parsed_message) self.assertIs(r, parsed_message)
self.assertEquals(message, parsed_message) self.assertEqual(message, parsed_message)
# Test as_utf8 = True. # Test as_utf8 = True.
wire_text = text_format.MessageToString( wire_text = text_format.MessageToString(
...@@ -180,7 +184,7 @@ class TextFormatTest(TextFormatBase): ...@@ -180,7 +184,7 @@ class TextFormatTest(TextFormatBase):
parsed_message = message_module.TestAllTypes() parsed_message = message_module.TestAllTypes()
r = text_format.Parse(wire_text, parsed_message) r = text_format.Parse(wire_text, parsed_message)
self.assertIs(r, parsed_message) self.assertIs(r, parsed_message)
self.assertEquals(message, parsed_message, self.assertEqual(message, parsed_message,
'\n%s != %s' % (message, parsed_message)) '\n%s != %s' % (message, parsed_message))
def testPrintRawUtf8String(self, message_module): def testPrintRawUtf8String(self, message_module):
...@@ -190,7 +194,7 @@ class TextFormatTest(TextFormatBase): ...@@ -190,7 +194,7 @@ class TextFormatTest(TextFormatBase):
self.CompareToGoldenText(text, 'repeated_string: "\303\274\352\234\237"\n') self.CompareToGoldenText(text, 'repeated_string: "\303\274\352\234\237"\n')
parsed_message = message_module.TestAllTypes() parsed_message = message_module.TestAllTypes()
text_format.Parse(text, parsed_message) text_format.Parse(text, parsed_message)
self.assertEquals(message, parsed_message, self.assertEqual(message, parsed_message,
'\n%s != %s' % (message, parsed_message)) '\n%s != %s' % (message, parsed_message))
def testPrintFloatFormat(self, message_module): def testPrintFloatFormat(self, message_module):
...@@ -217,13 +221,13 @@ class TextFormatTest(TextFormatBase): ...@@ -217,13 +221,13 @@ class TextFormatTest(TextFormatBase):
text_message = text_format.MessageToString(message, float_format='.15g') text_message = text_format.MessageToString(message, float_format='.15g')
self.CompareToGoldenText( self.CompareToGoldenText(
self.RemoveRedundantZeros(text_message), self.RemoveRedundantZeros(text_message),
'payload {{\n {}\n {}\n {}\n {}\n}}\n'.format(*formatted_fields)) 'payload {{\n {0}\n {1}\n {2}\n {3}\n}}\n'.format(*formatted_fields))
# as_one_line=True is a separate code branch where float_format is passed. # as_one_line=True is a separate code branch where float_format is passed.
text_message = text_format.MessageToString(message, as_one_line=True, text_message = text_format.MessageToString(message, as_one_line=True,
float_format='.15g') float_format='.15g')
self.CompareToGoldenText( self.CompareToGoldenText(
self.RemoveRedundantZeros(text_message), self.RemoveRedundantZeros(text_message),
'payload {{ {} {} {} {} }}'.format(*formatted_fields)) 'payload {{ {0} {1} {2} {3} }}'.format(*formatted_fields))
def testMessageToString(self, message_module): def testMessageToString(self, message_module):
message = message_module.ForeignMessage() message = message_module.ForeignMessage()
...@@ -286,7 +290,7 @@ class TextFormatTest(TextFormatBase): ...@@ -286,7 +290,7 @@ class TextFormatTest(TextFormatBase):
message = message_module.TestAllTypes() message = message_module.TestAllTypes()
text = '' text = ''
text_format.Parse(text, message) text_format.Parse(text, message)
self.assertEquals(message_module.TestAllTypes(), message) self.assertEqual(message_module.TestAllTypes(), message)
def testParseInvalidUtf8(self, message_module): def testParseInvalidUtf8(self, message_module):
message = message_module.TestAllTypes() message = message_module.TestAllTypes()
...@@ -296,7 +300,7 @@ class TextFormatTest(TextFormatBase): ...@@ -296,7 +300,7 @@ class TextFormatTest(TextFormatBase):
def testParseSingleWord(self, message_module): def testParseSingleWord(self, message_module):
message = message_module.TestAllTypes() message = message_module.TestAllTypes()
text = 'foo' text = 'foo'
self.assertRaisesRegexp( six.assertRaisesRegex(self,
text_format.ParseError, text_format.ParseError,
(r'1:1 : Message type "\w+.TestAllTypes" has no field named ' (r'1:1 : Message type "\w+.TestAllTypes" has no field named '
r'"foo".'), r'"foo".'),
...@@ -305,7 +309,7 @@ class TextFormatTest(TextFormatBase): ...@@ -305,7 +309,7 @@ class TextFormatTest(TextFormatBase):
def testParseUnknownField(self, message_module): def testParseUnknownField(self, message_module):
message = message_module.TestAllTypes() message = message_module.TestAllTypes()
text = 'unknown_field: 8\n' text = 'unknown_field: 8\n'
self.assertRaisesRegexp( six.assertRaisesRegex(self,
text_format.ParseError, text_format.ParseError,
(r'1:1 : Message type "\w+.TestAllTypes" has no field named ' (r'1:1 : Message type "\w+.TestAllTypes" has no field named '
r'"unknown_field".'), r'"unknown_field".'),
...@@ -314,7 +318,7 @@ class TextFormatTest(TextFormatBase): ...@@ -314,7 +318,7 @@ class TextFormatTest(TextFormatBase):
def testParseBadEnumValue(self, message_module): def testParseBadEnumValue(self, message_module):
message = message_module.TestAllTypes() message = message_module.TestAllTypes()
text = 'optional_nested_enum: BARR' text = 'optional_nested_enum: BARR'
self.assertRaisesRegexp( six.assertRaisesRegex(self,
text_format.ParseError, text_format.ParseError,
(r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" ' (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
r'has no value named BARR.'), r'has no value named BARR.'),
...@@ -322,7 +326,7 @@ class TextFormatTest(TextFormatBase): ...@@ -322,7 +326,7 @@ class TextFormatTest(TextFormatBase):
message = message_module.TestAllTypes() message = message_module.TestAllTypes()
text = 'optional_nested_enum: 100' text = 'optional_nested_enum: 100'
self.assertRaisesRegexp( six.assertRaisesRegex(self,
text_format.ParseError, text_format.ParseError,
(r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" ' (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
r'has no value with number 100.'), r'has no value with number 100.'),
...@@ -331,7 +335,7 @@ class TextFormatTest(TextFormatBase): ...@@ -331,7 +335,7 @@ class TextFormatTest(TextFormatBase):
def testParseBadIntValue(self, message_module): def testParseBadIntValue(self, message_module):
message = message_module.TestAllTypes() message = message_module.TestAllTypes()
text = 'optional_int32: bork' text = 'optional_int32: bork'
self.assertRaisesRegexp( six.assertRaisesRegex(self,
text_format.ParseError, text_format.ParseError,
('1:17 : Couldn\'t parse integer: bork'), ('1:17 : Couldn\'t parse integer: bork'),
text_format.Parse, text, message) text_format.Parse, text, message)
...@@ -401,7 +405,7 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase): ...@@ -401,7 +405,7 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase):
message = unittest_pb2.TestAllTypes() message = unittest_pb2.TestAllTypes()
test_util.SetAllFields(message) test_util.SetAllFields(message)
self.assertEquals(message, parsed_message) self.assertEqual(message, parsed_message)
def testPrintAllFields(self): def testPrintAllFields(self):
message = unittest_pb2.TestAllTypes() message = unittest_pb2.TestAllTypes()
...@@ -454,7 +458,7 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase): ...@@ -454,7 +458,7 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase):
message = unittest_pb2.TestAllTypes() message = unittest_pb2.TestAllTypes()
test_util.SetAllFields(message) test_util.SetAllFields(message)
self.assertEquals(message, parsed_message) self.assertEqual(message, parsed_message)
def testPrintMap(self): def testPrintMap(self):
message = map_unittest_pb2.TestMap() message = map_unittest_pb2.TestMap()
...@@ -555,8 +559,8 @@ class Proto2Tests(TextFormatBase): ...@@ -555,8 +559,8 @@ class Proto2Tests(TextFormatBase):
text_format.Parse(text, message) text_format.Parse(text, message)
ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
self.assertEquals(23, message.message_set.Extensions[ext1].i) self.assertEqual(23, message.message_set.Extensions[ext1].i)
self.assertEquals('foo', message.message_set.Extensions[ext2].str) self.assertEqual('foo', message.message_set.Extensions[ext2].str)
def testPrintAllExtensions(self): def testPrintAllExtensions(self):
message = unittest_pb2.TestAllExtensions() message = unittest_pb2.TestAllExtensions()
...@@ -581,7 +585,7 @@ class Proto2Tests(TextFormatBase): ...@@ -581,7 +585,7 @@ class Proto2Tests(TextFormatBase):
message = unittest_pb2.TestAllExtensions() message = unittest_pb2.TestAllExtensions()
test_util.SetAllExtensions(message) test_util.SetAllExtensions(message)
self.assertEquals(message, parsed_message) self.assertEqual(message, parsed_message)
def testParseAllExtensions(self): def testParseAllExtensions(self):
message = unittest_pb2.TestAllExtensions() message = unittest_pb2.TestAllExtensions()
...@@ -595,12 +599,12 @@ class Proto2Tests(TextFormatBase): ...@@ -595,12 +599,12 @@ class Proto2Tests(TextFormatBase):
def testParseBadExtension(self): def testParseBadExtension(self):
message = unittest_pb2.TestAllExtensions() message = unittest_pb2.TestAllExtensions()
text = '[unknown_extension]: 8\n' text = '[unknown_extension]: 8\n'
self.assertRaisesRegexp( six.assertRaisesRegex(self,
text_format.ParseError, text_format.ParseError,
'1:2 : Extension "unknown_extension" not registered.', '1:2 : Extension "unknown_extension" not registered.',
text_format.Parse, text, message) text_format.Parse, text, message)
message = unittest_pb2.TestAllTypes() message = unittest_pb2.TestAllTypes()
self.assertRaisesRegexp( six.assertRaisesRegex(self,
text_format.ParseError, text_format.ParseError,
('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have ' ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have '
'extensions.'), 'extensions.'),
...@@ -619,7 +623,7 @@ class Proto2Tests(TextFormatBase): ...@@ -619,7 +623,7 @@ class Proto2Tests(TextFormatBase):
message = unittest_pb2.TestAllExtensions() message = unittest_pb2.TestAllExtensions()
text = ('[protobuf_unittest.optional_int32_extension]: 42 ' text = ('[protobuf_unittest.optional_int32_extension]: 42 '
'[protobuf_unittest.optional_int32_extension]: 67') '[protobuf_unittest.optional_int32_extension]: 67')
self.assertRaisesRegexp( six.assertRaisesRegex(self,
text_format.ParseError, text_format.ParseError,
('1:96 : Message type "protobuf_unittest.TestAllExtensions" ' ('1:96 : Message type "protobuf_unittest.TestAllExtensions" '
'should not have multiple ' 'should not have multiple '
...@@ -630,7 +634,7 @@ class Proto2Tests(TextFormatBase): ...@@ -630,7 +634,7 @@ class Proto2Tests(TextFormatBase):
message = unittest_pb2.TestAllTypes() message = unittest_pb2.TestAllTypes()
text = ('optional_nested_message { bb: 1 } ' text = ('optional_nested_message { bb: 1 } '
'optional_nested_message { bb: 2 }') 'optional_nested_message { bb: 2 }')
self.assertRaisesRegexp( six.assertRaisesRegex(self,
text_format.ParseError, text_format.ParseError,
('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" ' ('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" '
'should not have multiple "bb" fields.'), 'should not have multiple "bb" fields.'),
...@@ -640,7 +644,7 @@ class Proto2Tests(TextFormatBase): ...@@ -640,7 +644,7 @@ class Proto2Tests(TextFormatBase):
message = unittest_pb2.TestAllTypes() message = unittest_pb2.TestAllTypes()
text = ('optional_int32: 42 ' text = ('optional_int32: 42 '
'optional_int32: 67') 'optional_int32: 67')
self.assertRaisesRegexp( six.assertRaisesRegex(self,
text_format.ParseError, text_format.ParseError,
('1:36 : Message type "protobuf_unittest.TestAllTypes" should not ' ('1:36 : Message type "protobuf_unittest.TestAllTypes" should not '
'have multiple "optional_int32" fields.'), 'have multiple "optional_int32" fields.'),
...@@ -649,11 +653,11 @@ class Proto2Tests(TextFormatBase): ...@@ -649,11 +653,11 @@ class Proto2Tests(TextFormatBase):
def testParseGroupNotClosed(self): def testParseGroupNotClosed(self):
message = unittest_pb2.TestAllTypes() message = unittest_pb2.TestAllTypes()
text = 'RepeatedGroup: <' text = 'RepeatedGroup: <'
self.assertRaisesRegexp( six.assertRaisesRegex(self,
text_format.ParseError, '1:16 : Expected ">".', text_format.ParseError, '1:16 : Expected ">".',
text_format.Parse, text, message) text_format.Parse, text, message)
text = 'RepeatedGroup: {' text = 'RepeatedGroup: {'
self.assertRaisesRegexp( six.assertRaisesRegex(self,
text_format.ParseError, '1:16 : Expected "}".', text_format.ParseError, '1:16 : Expected "}".',
text_format.Parse, text, message) text_format.Parse, text, message)
......
...@@ -28,8 +28,6 @@ ...@@ -28,8 +28,6 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#PY25 compatible for GAE.
#
# Copyright 2008 Google Inc. All Rights Reserved. # Copyright 2008 Google Inc. All Rights Reserved.
"""Provides type checking routines. """Provides type checking routines.
...@@ -49,9 +47,11 @@ TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization ...@@ -49,9 +47,11 @@ TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization
__author__ = 'robinson@google.com (Will Robinson)' __author__ = 'robinson@google.com (Will Robinson)'
import sys ##PY25 import six
if sys.version < '2.6': bytes = str ##PY25
from google.protobuf.internal import api_implementation if six.PY3:
long = int
from google.protobuf.internal import decoder from google.protobuf.internal import decoder
from google.protobuf.internal import encoder from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format from google.protobuf.internal import wire_format
...@@ -117,9 +117,9 @@ class IntValueChecker(object): ...@@ -117,9 +117,9 @@ class IntValueChecker(object):
"""Checker used for integer fields. Performs type-check and range check.""" """Checker used for integer fields. Performs type-check and range check."""
def CheckValue(self, proposed_value): def CheckValue(self, proposed_value):
if not isinstance(proposed_value, (int, long)): if not isinstance(proposed_value, six.integer_types):
message = ('%.1024r has type %s, but expected one of: %s' % message = ('%.1024r has type %s, but expected one of: %s' %
(proposed_value, type(proposed_value), (int, long))) (proposed_value, type(proposed_value), six.integer_types))
raise TypeError(message) raise TypeError(message)
if not self._MIN <= proposed_value <= self._MAX: if not self._MIN <= proposed_value <= self._MAX:
raise ValueError('Value out of range: %d' % proposed_value) raise ValueError('Value out of range: %d' % proposed_value)
...@@ -141,9 +141,9 @@ class EnumValueChecker(object): ...@@ -141,9 +141,9 @@ class EnumValueChecker(object):
self._enum_type = enum_type self._enum_type = enum_type
def CheckValue(self, proposed_value): def CheckValue(self, proposed_value):
if not isinstance(proposed_value, (int, long)): if not isinstance(proposed_value, six.integer_types):
message = ('%.1024r has type %s, but expected one of: %s' % message = ('%.1024r has type %s, but expected one of: %s' %
(proposed_value, type(proposed_value), (int, long))) (proposed_value, type(proposed_value), six.integer_types))
raise TypeError(message) raise TypeError(message)
if proposed_value not in self._enum_type.values_by_number: if proposed_value not in self._enum_type.values_by_number:
raise ValueError('Unknown enum value: %d' % proposed_value) raise ValueError('Unknown enum value: %d' % proposed_value)
...@@ -161,9 +161,9 @@ class UnicodeValueChecker(object): ...@@ -161,9 +161,9 @@ class UnicodeValueChecker(object):
""" """
def CheckValue(self, proposed_value): def CheckValue(self, proposed_value):
if not isinstance(proposed_value, (bytes, unicode)): if not isinstance(proposed_value, (bytes, six.text_type)):
message = ('%.1024r has type %s, but expected one of: %s' % message = ('%.1024r has type %s, but expected one of: %s' %
(proposed_value, type(proposed_value), (bytes, unicode))) (proposed_value, type(proposed_value), (bytes, six.text_type)))
raise TypeError(message) raise TypeError(message)
# If the value is of type 'bytes' make sure that it is valid UTF-8 data. # If the value is of type 'bytes' make sure that it is valid UTF-8 data.
......
...@@ -35,7 +35,11 @@ ...@@ -35,7 +35,11 @@
__author__ = 'bohdank@google.com (Bohdan Koval)' __author__ = 'bohdank@google.com (Bohdan Koval)'
import unittest try:
import unittest2 as unittest
except ImportError:
import unittest
from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2 from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2 from google.protobuf import unittest_proto3_arena_pb2
......
...@@ -34,7 +34,10 @@ ...@@ -34,7 +34,10 @@
__author__ = 'robinson@google.com (Will Robinson)' __author__ = 'robinson@google.com (Will Robinson)'
import unittest try:
import unittest2 as unittest
except ImportError:
import unittest
from google.protobuf import message from google.protobuf import message
from google.protobuf.internal import wire_format from google.protobuf.internal import wire_format
......
...@@ -28,8 +28,6 @@ ...@@ -28,8 +28,6 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#PY25 compatible for GAE.
#
# Copyright 2012 Google Inc. All Rights Reserved. # Copyright 2012 Google Inc. All Rights Reserved.
"""Provides a factory class for generating dynamic messages. """Provides a factory class for generating dynamic messages.
...@@ -43,7 +41,6 @@ my_proto_instance = message_classes['some.proto.package.MessageName']() ...@@ -43,7 +41,6 @@ my_proto_instance = message_classes['some.proto.package.MessageName']()
__author__ = 'matthewtoia@google.com (Matt Toia)' __author__ = 'matthewtoia@google.com (Matt Toia)'
import sys ##PY25
from google.protobuf import descriptor_database from google.protobuf import descriptor_database
from google.protobuf import descriptor_pool from google.protobuf import descriptor_pool
from google.protobuf import message from google.protobuf import message
...@@ -75,8 +72,7 @@ class MessageFactory(object): ...@@ -75,8 +72,7 @@ class MessageFactory(object):
""" """
if descriptor.full_name not in self._classes: if descriptor.full_name not in self._classes:
descriptor_name = descriptor.name descriptor_name = descriptor.name
if sys.version_info[0] < 3: ##PY25 if str is bytes: # PY2
##!PY25 if str is bytes: # PY2
descriptor_name = descriptor.name.encode('ascii', 'ignore') descriptor_name = descriptor.name.encode('ascii', 'ignore')
result_class = reflection.GeneratedProtocolMessageType( result_class = reflection.GeneratedProtocolMessageType(
descriptor_name, descriptor_name,
...@@ -111,7 +107,7 @@ class MessageFactory(object): ...@@ -111,7 +107,7 @@ class MessageFactory(object):
result = {} result = {}
for file_name in files: for file_name in files:
file_desc = self.pool.FindFileByName(file_name) file_desc = self.pool.FindFileByName(file_name)
for name, msg in file_desc.message_types_by_name.iteritems(): for name, msg in file_desc.message_types_by_name.items():
if file_desc.package: if file_desc.package:
full_name = '.'.join([file_desc.package, name]) full_name = '.'.join([file_desc.package, name])
else: else:
...@@ -128,7 +124,7 @@ class MessageFactory(object): ...@@ -128,7 +124,7 @@ class MessageFactory(object):
# ignore the registration if the original was the same, or raise # ignore the registration if the original was the same, or raise
# an error if they were different. # an error if they were different.
for name, extension in file_desc.extensions_by_name.iteritems(): for name, extension in file_desc.extensions_by_name.items():
if extension.containing_type.full_name not in self._classes: if extension.containing_type.full_name not in self._classes:
self.GetPrototype(extension.containing_type) self.GetPrototype(extension.containing_type)
extended_class = self._classes[extension.containing_type.full_name] extended_class = self._classes[extension.containing_type.full_name]
......
...@@ -30,7 +30,10 @@ ...@@ -30,7 +30,10 @@
"""Dynamic Protobuf class creator.""" """Dynamic Protobuf class creator."""
import collections try:
from collections import OrderedDict
except ImportError:
from ordereddict import OrderedDict #PY26
import hashlib import hashlib
import os import os
...@@ -80,7 +83,7 @@ def MakeSimpleProtoClass(fields, full_name, pool=None): ...@@ -80,7 +83,7 @@ def MakeSimpleProtoClass(fields, full_name, pool=None):
# an OrderedDict we keep the order, but otherwise we sort the field to ensure # an OrderedDict we keep the order, but otherwise we sort the field to ensure
# consistent ordering. # consistent ordering.
field_items = fields.items() field_items = fields.items()
if not isinstance(fields, collections.OrderedDict): if not isinstance(fields, OrderedDict):
field_items = sorted(field_items) field_items = sorted(field_items)
# Use a consistent file name that is unlikely to conflict with any imported # Use a consistent file name that is unlikely to conflict with any imported
......
...@@ -27,16 +27,13 @@ ...@@ -27,16 +27,13 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#PY25 compatible for GAE.
#
"""Encoding related utilities.""" """Encoding related utilities."""
import re import re
import sys ##PY25
import six
# Lookup table for utf8 # Lookup table for utf8
_cescape_utf8_to_str = [chr(i) for i in xrange(0, 256)] _cescape_utf8_to_str = [chr(i) for i in range(0, 256)]
_cescape_utf8_to_str[9] = r'\t' # optional escape _cescape_utf8_to_str[9] = r'\t' # optional escape
_cescape_utf8_to_str[10] = r'\n' # optional escape _cescape_utf8_to_str[10] = r'\n' # optional escape
_cescape_utf8_to_str[13] = r'\r' # optional escape _cescape_utf8_to_str[13] = r'\r' # optional escape
...@@ -46,9 +43,9 @@ _cescape_utf8_to_str[34] = r'\"' # necessary escape ...@@ -46,9 +43,9 @@ _cescape_utf8_to_str[34] = r'\"' # necessary escape
_cescape_utf8_to_str[92] = r'\\' # necessary escape _cescape_utf8_to_str[92] = r'\\' # necessary escape
# Lookup table for non-utf8, with necessary escapes at (o >= 127 or o < 32) # Lookup table for non-utf8, with necessary escapes at (o >= 127 or o < 32)
_cescape_byte_to_str = ([r'\%03o' % i for i in xrange(0, 32)] + _cescape_byte_to_str = ([r'\%03o' % i for i in range(0, 32)] +
[chr(i) for i in xrange(32, 127)] + [chr(i) for i in range(32, 127)] +
[r'\%03o' % i for i in xrange(127, 256)]) [r'\%03o' % i for i in range(127, 256)])
_cescape_byte_to_str[9] = r'\t' # optional escape _cescape_byte_to_str[9] = r'\t' # optional escape
_cescape_byte_to_str[10] = r'\n' # optional escape _cescape_byte_to_str[10] = r'\n' # optional escape
_cescape_byte_to_str[13] = r'\r' # optional escape _cescape_byte_to_str[13] = r'\r' # optional escape
...@@ -75,7 +72,7 @@ def CEscape(text, as_utf8): ...@@ -75,7 +72,7 @@ def CEscape(text, as_utf8):
""" """
# PY3 hack: make Ord work for str and bytes: # PY3 hack: make Ord work for str and bytes:
# //platforms/networking/data uses unicode here, hence basestring. # //platforms/networking/data uses unicode here, hence basestring.
Ord = ord if isinstance(text, basestring) else lambda x: x Ord = ord if isinstance(text, six.string_types) else lambda x: x
if as_utf8: if as_utf8:
return ''.join(_cescape_utf8_to_str[Ord(c)] for c in text) return ''.join(_cescape_utf8_to_str[Ord(c)] for c in text)
return ''.join(_cescape_byte_to_str[Ord(c)] for c in text) return ''.join(_cescape_byte_to_str[Ord(c)] for c in text)
...@@ -100,8 +97,7 @@ def CUnescape(text): ...@@ -100,8 +97,7 @@ def CUnescape(text):
# allow single-digit hex escapes (like '\xf'). # allow single-digit hex escapes (like '\xf').
result = _CUNESCAPE_HEX.sub(ReplaceHex, text) result = _CUNESCAPE_HEX.sub(ReplaceHex, text)
if sys.version_info[0] < 3: ##PY25 if str is bytes: # PY2
##!PY25 if str is bytes: # PY2
return result.decode('string_escape') return result.decode('string_escape')
result = ''.join(_cescape_highbit_to_str[ord(c)] for c in result) result = ''.join(_cescape_highbit_to_str[ord(c)] for c in result)
return (result.encode('ascii') # Make it bytes to allow decode. return (result.encode('ascii') # Make it bytes to allow decode.
......
...@@ -28,17 +28,20 @@ ...@@ -28,17 +28,20 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#PY25 compatible for GAE.
#
# Copyright 2007 Google Inc. All Rights Reserved. # Copyright 2007 Google Inc. All Rights Reserved.
"""Contains routines for printing protocol messages in text format.""" """Contains routines for printing protocol messages in text format."""
__author__ = 'kenton@google.com (Kenton Varda)' __author__ = 'kenton@google.com (Kenton Varda)'
import cStringIO import io
import re import re
import six
if six.PY3:
long = int
from google.protobuf.internal import type_checkers from google.protobuf.internal import type_checkers
from google.protobuf import descriptor from google.protobuf import descriptor
from google.protobuf import text_encoding from google.protobuf import text_encoding
...@@ -89,7 +92,7 @@ def MessageToString(message, as_utf8=False, as_one_line=False, ...@@ -89,7 +92,7 @@ def MessageToString(message, as_utf8=False, as_one_line=False,
Returns: Returns:
A string of the text formatted protocol buffer message. A string of the text formatted protocol buffer message.
""" """
out = cStringIO.StringIO() out = io.BytesIO()
PrintMessage(message, out, as_utf8=as_utf8, as_one_line=as_one_line, PrintMessage(message, out, as_utf8=as_utf8, as_one_line=as_one_line,
pointy_brackets=pointy_brackets, pointy_brackets=pointy_brackets,
use_index_order=use_index_order, use_index_order=use_index_order,
...@@ -136,7 +139,6 @@ def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False, ...@@ -136,7 +139,6 @@ def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False,
use_index_order=use_index_order, use_index_order=use_index_order,
float_format=float_format) float_format=float_format)
def PrintField(field, value, out, indent=0, as_utf8=False, as_one_line=False, def PrintField(field, value, out, indent=0, as_utf8=False, as_one_line=False,
pointy_brackets=False, use_index_order=False, float_format=None): pointy_brackets=False, use_index_order=False, float_format=None):
"""Print a single field name/value pair. For repeated fields, the value """Print a single field name/value pair. For repeated fields, the value
...@@ -157,7 +159,11 @@ def PrintField(field, value, out, indent=0, as_utf8=False, as_one_line=False, ...@@ -157,7 +159,11 @@ def PrintField(field, value, out, indent=0, as_utf8=False, as_one_line=False,
# For groups, use the capitalized name. # For groups, use the capitalized name.
out.write(field.message_type.name) out.write(field.message_type.name)
else: else:
out.write(field.name) if isinstance(field.name, six.text_type):
name = field.name.encode('utf-8')
else:
name = field.name
out.write(name)
if field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE: if field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
# The colon is optional in this case, but our cross-language golden files # The colon is optional in this case, but our cross-language golden files
...@@ -211,7 +217,7 @@ def PrintFieldValue(field, value, out, indent=0, as_utf8=False, ...@@ -211,7 +217,7 @@ def PrintFieldValue(field, value, out, indent=0, as_utf8=False,
out.write(str(value)) out.write(str(value))
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING:
out.write('\"') out.write('\"')
if isinstance(value, unicode): if isinstance(value, six.text_type):
out_value = value.encode('utf-8') out_value = value.encode('utf-8')
else: else:
out_value = value out_value = value
...@@ -537,7 +543,7 @@ class _Tokenizer(object): ...@@ -537,7 +543,7 @@ class _Tokenizer(object):
def _PopLine(self): def _PopLine(self):
while len(self._current_line) <= self._column: while len(self._current_line) <= self._column:
try: try:
self._current_line = self._lines.next() self._current_line = next(self._lines)
except StopIteration: except StopIteration:
self._current_line = '' self._current_line = ''
self._more_lines = False self._more_lines = False
...@@ -607,7 +613,7 @@ class _Tokenizer(object): ...@@ -607,7 +613,7 @@ class _Tokenizer(object):
""" """
try: try:
result = ParseInteger(self.token, is_signed=True, is_long=False) result = ParseInteger(self.token, is_signed=True, is_long=False)
except ValueError, e: except ValueError as e:
raise self._ParseError(str(e)) raise self._ParseError(str(e))
self.NextToken() self.NextToken()
return result return result
...@@ -623,7 +629,7 @@ class _Tokenizer(object): ...@@ -623,7 +629,7 @@ class _Tokenizer(object):
""" """
try: try:
result = ParseInteger(self.token, is_signed=False, is_long=False) result = ParseInteger(self.token, is_signed=False, is_long=False)
except ValueError, e: except ValueError as e:
raise self._ParseError(str(e)) raise self._ParseError(str(e))
self.NextToken() self.NextToken()
return result return result
...@@ -639,7 +645,7 @@ class _Tokenizer(object): ...@@ -639,7 +645,7 @@ class _Tokenizer(object):
""" """
try: try:
result = ParseInteger(self.token, is_signed=True, is_long=True) result = ParseInteger(self.token, is_signed=True, is_long=True)
except ValueError, e: except ValueError as e:
raise self._ParseError(str(e)) raise self._ParseError(str(e))
self.NextToken() self.NextToken()
return result return result
...@@ -655,7 +661,7 @@ class _Tokenizer(object): ...@@ -655,7 +661,7 @@ class _Tokenizer(object):
""" """
try: try:
result = ParseInteger(self.token, is_signed=False, is_long=True) result = ParseInteger(self.token, is_signed=False, is_long=True)
except ValueError, e: except ValueError as e:
raise self._ParseError(str(e)) raise self._ParseError(str(e))
self.NextToken() self.NextToken()
return result return result
...@@ -671,7 +677,7 @@ class _Tokenizer(object): ...@@ -671,7 +677,7 @@ class _Tokenizer(object):
""" """
try: try:
result = ParseFloat(self.token) result = ParseFloat(self.token)
except ValueError, e: except ValueError as e:
raise self._ParseError(str(e)) raise self._ParseError(str(e))
self.NextToken() self.NextToken()
return result return result
...@@ -687,7 +693,7 @@ class _Tokenizer(object): ...@@ -687,7 +693,7 @@ class _Tokenizer(object):
""" """
try: try:
result = ParseBool(self.token) result = ParseBool(self.token)
except ValueError, e: except ValueError as e:
raise self._ParseError(str(e)) raise self._ParseError(str(e))
self.NextToken() self.NextToken()
return result return result
...@@ -703,8 +709,8 @@ class _Tokenizer(object): ...@@ -703,8 +709,8 @@ class _Tokenizer(object):
""" """
the_bytes = self.ConsumeByteString() the_bytes = self.ConsumeByteString()
try: try:
return unicode(the_bytes, 'utf-8') return six.text_type(the_bytes, 'utf-8')
except UnicodeDecodeError, e: except UnicodeDecodeError as e:
raise self._StringParseError(e) raise self._StringParseError(e)
def ConsumeByteString(self): def ConsumeByteString(self):
...@@ -719,8 +725,7 @@ class _Tokenizer(object): ...@@ -719,8 +725,7 @@ class _Tokenizer(object):
the_list = [self._ConsumeSingleByteString()] the_list = [self._ConsumeSingleByteString()]
while self.token and self.token[0] in ('\'', '"'): while self.token and self.token[0] in ('\'', '"'):
the_list.append(self._ConsumeSingleByteString()) the_list.append(self._ConsumeSingleByteString())
return ''.encode('latin1').join(the_list) ##PY25 return b''.join(the_list)
##!PY25 return b''.join(the_list)
def _ConsumeSingleByteString(self): def _ConsumeSingleByteString(self):
"""Consume one token of a string literal. """Consume one token of a string literal.
...@@ -741,7 +746,7 @@ class _Tokenizer(object): ...@@ -741,7 +746,7 @@ class _Tokenizer(object):
try: try:
result = text_encoding.CUnescape(text[1:-1]) result = text_encoding.CUnescape(text[1:-1])
except ValueError, e: except ValueError as e:
raise self._ParseError(str(e)) raise self._ParseError(str(e))
self.NextToken() self.NextToken()
return result return result
...@@ -749,7 +754,7 @@ class _Tokenizer(object): ...@@ -749,7 +754,7 @@ class _Tokenizer(object):
def ConsumeEnum(self, field): def ConsumeEnum(self, field):
try: try:
result = ParseEnum(field, self.token) result = ParseEnum(field, self.token)
except ValueError, e: except ValueError as e:
raise self._ParseError(str(e)) raise self._ParseError(str(e))
self.NextToken() self.NextToken()
return result return result
......
...@@ -8,19 +8,7 @@ import sys ...@@ -8,19 +8,7 @@ import sys
# We must use setuptools, not distutils, because we need to use the # We must use setuptools, not distutils, because we need to use the
# namespace_packages option for the "google" package. # namespace_packages option for the "google" package.
try: from setuptools import setup, Extension, find_packages
from setuptools import setup, Extension, find_packages
except ImportError:
try:
from ez_setup import use_setuptools
use_setuptools()
from setuptools import setup, Extension, find_packages
except ImportError:
sys.stderr.write(
"Could not import setuptools; make sure you have setuptools or "
"ez_setup installed.\n"
)
raise
from distutils.command.clean import clean as _clean from distutils.command.clean import clean as _clean
...@@ -79,16 +67,14 @@ def generate_proto(source, require = True): ...@@ -79,16 +67,14 @@ def generate_proto(source, require = True):
if protoc is None: if protoc is None:
sys.stderr.write( sys.stderr.write(
"protoc is not installed nor found in ../src. " "protoc is not installed nor found in ../src. Please compile it "
"Please compile it or install the binary package.\n" "or install the binary package.\n")
)
sys.exit(-1) sys.exit(-1)
protoc_command = [protoc, "-I../src", "-I.", "--python_out=.", source] protoc_command = [ protoc, "-I../src", "-I.", "--python_out=.", source ]
if subprocess.call(protoc_command) != 0: if subprocess.call(protoc_command) != 0:
sys.exit(-1) sys.exit(-1)
def GenerateUnittestProtos(): def GenerateUnittestProtos():
generate_proto("../src/google/protobuf/map_unittest.proto", False) generate_proto("../src/google/protobuf/map_unittest.proto", False)
generate_proto("../src/google/protobuf/unittest.proto", False) generate_proto("../src/google/protobuf/unittest.proto", False)
...@@ -119,13 +105,12 @@ class clean(_clean): ...@@ -119,13 +105,12 @@ class clean(_clean):
for filename in filenames: for filename in filenames:
filepath = os.path.join(dirpath, filename) filepath = os.path.join(dirpath, filename)
if filepath.endswith("_pb2.py") or filepath.endswith(".pyc") or \ if filepath.endswith("_pb2.py") or filepath.endswith(".pyc") or \
filepath.endswith(".so") or filepath.endswith(".o") or \ filepath.endswith(".so") or filepath.endswith(".o") or \
filepath.endswith('google/protobuf/compiler/__init__.py'): filepath.endswith('google/protobuf/compiler/__init__.py'):
os.remove(filepath) os.remove(filepath)
# _clean is an old-style class, so super() doesn't work. # _clean is an old-style class, so super() doesn't work.
_clean.run(self) _clean.run(self)
class build_py(_build_py): class build_py(_build_py):
def run(self): def run(self):
# Generate necessary .proto file if it doesn't exist. # Generate necessary .proto file if it doesn't exist.
...@@ -141,13 +126,7 @@ class build_py(_build_py): ...@@ -141,13 +126,7 @@ class build_py(_build_py):
pass pass
# _build_py is an old-style class, so super() doesn't work. # _build_py is an old-style class, so super() doesn't work.
_build_py.run(self) _build_py.run(self)
# TODO(mrovner): Subclass to run 2to3 on some files only.
# Tracing what https://wiki.python.org/moin/PortingPythonToPy3k's
# "Approach 2" section on how to get 2to3 to run on source files during
# install under Python 3. This class seems like a good place to put logic
# that calls python3's distutils.util.run_2to3 on the subset of the files we
# have in our release that are subject to conversion.
# See code reference in previous code review.
if __name__ == '__main__': if __name__ == '__main__':
ext_module_list = [] ext_module_list = []
...@@ -167,6 +146,12 @@ if __name__ == '__main__': ...@@ -167,6 +146,12 @@ if __name__ == '__main__':
) )
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp'
# Keep this list of dependencies in sync with tox.ini.
install_requires = ['six', 'setuptools']
if sys.version_info <= (2,7):
install_requires.append('ordereddict')
install_requires.append('unittest2')
setup( setup(
name='protobuf', name='protobuf',
version=GetVersion(), version=GetVersion(),
...@@ -177,8 +162,14 @@ if __name__ == '__main__': ...@@ -177,8 +162,14 @@ if __name__ == '__main__':
maintainer_email='protobuf@googlegroups.com', maintainer_email='protobuf@googlegroups.com',
license='New BSD License', license='New BSD License',
classifiers=[ classifiers=[
'Programming Language :: Python :: 2.7', "Programming Language :: Python",
], "Programming Language :: Python :: 2",
"Programming Language :: Python :: 2.6",
"Programming Language :: Python :: 2.7",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.3",
"Programming Language :: Python :: 3.4",
],
namespace_packages=['google'], namespace_packages=['google'],
packages=find_packages( packages=find_packages(
exclude=[ exclude=[
...@@ -190,6 +181,6 @@ if __name__ == '__main__': ...@@ -190,6 +181,6 @@ if __name__ == '__main__':
'clean': clean, 'clean': clean,
'build_py': build_py, 'build_py': build_py,
}, },
install_requires=['setuptools'], install_requires=install_requires,
ext_modules=ext_module_list, ext_modules=ext_module_list,
) )
[tox]
envlist =
# Py3 tests currently fail because of text handling issues,
# So only test py26/py27 for now.
#py{26,27,33,34}-{cpp,python}
py{26,27}-{cpp,python}
[testenv]
usedevelop=true
setenv =
cpp: LD_LIBRARY_PATH={toxinidir}/../src/.libs
cpp: DYLD_LIBRARY_PATH={toxinidir}/../src/.libs
cpp: PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp
commands =
python setup.py -q build_py
python: python setup.py -q build
cpp: python setup.py -q build --cpp_implementation
python: python setup.py -q test -q
cpp: python setup.py -q test -q --cpp_implementation
deps =
# Keep this list of dependencies in sync with setup.py.
six
py26: ordereddict
py26: unittest2
...@@ -111,25 +111,45 @@ build_javanano_oracle7() { ...@@ -111,25 +111,45 @@ build_javanano_oracle7() {
build_javanano build_javanano
} }
internal_install_python_deps() {
sudo pip install tox
# Only install Python2.6 on Linux.
if [ $(uname -s) == "Linux" ]; then
sudo apt-get install -y python-software-properties # for apt-add-repository
sudo apt-add-repository -y ppa:fkrull/deadsnakes
sudo apt-get update -qq
sudo apt-get install -y python2.6 python2.6-dev
fi
}
build_python() { build_python() {
internal_build_cpp internal_build_cpp
internal_install_python_deps
cd python cd python
python setup.py build # Only test Python 2.6 on Linux
python setup.py test if [ $(uname -s) == "Linux" ]; then
python setup.py sdist envlist=py26-python,py27-python
sudo pip install virtualenv && virtualenv /tmp/protoenv && /tmp/protoenv/bin/pip install dist/* else
envlist=py27-python
fi
tox -e $envlist
cd .. cd ..
} }
build_python_cpp() { build_python_cpp() {
internal_build_cpp internal_build_cpp
export LD_LIBRARY_PATH=../src/.libs # for Linux internal_install_python_deps
export LD_LIBRARY_PATH=../src/.libs # for Linux
export DYLD_LIBRARY_PATH=../src/.libs # for OS X export DYLD_LIBRARY_PATH=../src/.libs # for OS X
cd python cd python
python setup.py build --cpp_implementation # Only test Python 2.6 on Linux
python setup.py test --cpp_implementation if [ $(uname -s) == "Linux" ]; then
python setup.py sdist --cpp_implementation envlist=py26-cpp,py27-cpp
sudo pip install virtualenv && virtualenv /tmp/protoenv && /tmp/protoenv/bin/pip install dist/* else
envlist=py27-cpp
fi
tox -e $envlist
cd .. cd ..
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment