Commit 99aa0f9e authored by Feng Xiao's avatar Feng Xiao

Down-integrate from internal code base.

parent 49bc8c09
......@@ -55,14 +55,6 @@ public abstract class GeneratedMessageLite extends AbstractMessageLite
/** For use by generated code only. */
protected UnknownFieldSetLite unknownFields;
protected GeneratedMessageLite() {
unknownFields = UnknownFieldSetLite.getDefaultInstance();
}
protected GeneratedMessageLite(Builder builder) {
unknownFields = builder.unknownFields;
}
public Parser<? extends MessageLite> getParserForType() {
throw new UnsupportedOperationException(
"This is supposed to be overridden by subclasses.");
......@@ -87,7 +79,9 @@ public abstract class GeneratedMessageLite extends AbstractMessageLite
extends AbstractMessageLite.Builder<BuilderType> {
private final MessageType defaultInstance;
private UnknownFieldSetLite unknownFields =
/** For use by generated code only. */
protected UnknownFieldSetLite unknownFields =
UnknownFieldSetLite.getDefaultInstance();
protected Builder(MessageType defaultInstance) {
......@@ -211,16 +205,8 @@ public abstract class GeneratedMessageLite extends AbstractMessageLite
* Represents the set of extensions on this message. For use by generated
* code only.
*/
protected final FieldSet<ExtensionDescriptor> extensions;
protected ExtendableMessage() {
this.extensions = FieldSet.newFieldSet();
}
protected ExtendableMessage(ExtendableBuilder<MessageType, ?> builder) {
this.extensions = builder.buildExtensions();
}
protected FieldSet<ExtensionDescriptor> extensions = FieldSet.newFieldSet();
private void verifyExtensionContainingType(
final GeneratedExtension<MessageType, ?> extension) {
if (extension.getContainingTypeDefaultInstance() !=
......@@ -396,8 +382,10 @@ public abstract class GeneratedMessageLite extends AbstractMessageLite
/**
* Called by the build code path to create a copy of the extensions for
* building the message.
* <p>
* For use by generated code only.
*/
private FieldSet<ExtensionDescriptor> buildExtensions() {
protected final FieldSet<ExtensionDescriptor> buildExtensions() {
extensions.makeImmutable();
extensionsIsMutable = false;
return extensions;
......@@ -472,8 +460,7 @@ public abstract class GeneratedMessageLite extends AbstractMessageLite
// of this dummy clone() implementation makes it go away.
@Override
public BuilderType clone() {
throw new UnsupportedOperationException(
"This is supposed to be overridden by subclasses.");
return super.clone();
}
/** Set the value of an extension. */
......
......@@ -393,6 +393,10 @@ public class Internal {
*/
public static final ByteBuffer EMPTY_BYTE_BUFFER =
ByteBuffer.wrap(EMPTY_BYTE_ARRAY);
/** An empty coded input stream constant used in generated code. */
public static final CodedInputStream EMPTY_CODED_INPUT_STREAM =
CodedInputStream.newInstance(EMPTY_BYTE_ARRAY);
/**
......
......@@ -145,4 +145,17 @@ public class LiteTest extends TestCase {
assertEquals(expected.getOptionalNestedMessage().getBb(),
actual.getOptionalNestedMessage().getBb());
}
public void testClone() {
TestAllTypesLite.Builder expected = TestAllTypesLite.newBuilder()
.setOptionalInt32(123);
assertEquals(
expected.getOptionalInt32(), expected.clone().getOptionalInt32());
TestAllExtensionsLite.Builder expected2 = TestAllExtensionsLite.newBuilder()
.setExtension(UnittestLite.optionalInt32ExtensionLite, 123);
assertEquals(
expected2.getExtension(UnittestLite.optionalInt32ExtensionLite),
expected2.clone().getExtension(UnittestLite.optionalInt32ExtensionLite));
}
}
......@@ -699,7 +699,7 @@ class _Tokenizer(object):
"""
text = self.token
if len(text) < 1 or text[0] not in ('\'', '"'):
raise self._ParseError('Expected string.')
raise self._ParseError('Expected string but found: "%r"' % text)
if len(text) < 2 or text[-1] != text[0]:
raise self._ParseError('String missing ending quote.')
......
......@@ -30,6 +30,10 @@
#include <google/protobuf/arena.h>
#ifdef ADDRESS_SANITIZER
#include <sanitizer/asan_interface.h>
#endif
namespace google {
namespace protobuf {
......@@ -91,6 +95,12 @@ Arena::Block* Arena::NewBlock(void* me, Block* my_last_block, size_t n,
} else {
b->owner = me;
}
#ifdef ADDRESS_SANITIZER
// Poison the rest of the block for ASAN. It was unpoisoned by the underlying
// malloc but it's not yet usable until we return it as part of an allocation.
ASAN_POISON_MEMORY_REGION(
reinterpret_cast<char*>(b) + b->pos, b->size - b->pos);
#endif
return b;
}
......@@ -152,6 +162,9 @@ void* Arena::AllocateAligned(size_t n) {
void* Arena::AllocFromBlock(Block* b, size_t n) {
size_t p = b->pos;
b->pos = p + n;
#ifdef ADDRESS_SANITIZER
ASAN_UNPOISON_MEMORY_REGION(reinterpret_cast<char*>(b) + p, n);
#endif
return reinterpret_cast<char*>(b) + p;
}
......
......@@ -144,7 +144,7 @@ TEST(ArenaTest, InitialBlockTooSmall) {
// Write to the memory we allocated; this should (but is not guaranteed to)
// trigger a check for heap corruption if the object was allocated from the
// initially-provided block.
memset(p, '\0', 128);
memset(p, '\0', 96);
}
TEST(ArenaTest, Parsing) {
......
......@@ -614,7 +614,7 @@ GenerateFieldAccessorDefinitions(io::Printer* printer) {
vars["classname"] = classname_;
printer->Print(
vars,
"inline bool $classname$::has_$oneof_name$() {\n"
"inline bool $classname$::has_$oneof_name$() const {\n"
" return $oneof_name$_case() != $cap_oneof_name$_NOT_SET;\n"
"}\n"
"inline void $classname$::clear_has_$oneof_name$() {\n"
......@@ -975,7 +975,7 @@ GenerateClassDefinition(io::Printer* printer) {
// Generate oneof function declarations
for (int i = 0; i < descriptor_->oneof_decl_count(); i++) {
printer->Print(
"inline bool has_$oneof_name$();\n"
"inline bool has_$oneof_name$() const;\n"
"void clear_$oneof_name$();\n"
"inline void clear_has_$oneof_name$();\n\n",
"oneof_name", descriptor_->oneof_decl(i)->name());
......
......@@ -306,31 +306,28 @@ void ImmutableMessageGenerator::Generate(io::Printer* printer) {
variables["lite"]);
}
printer->Indent();
// Using builder_type, instead of Builder, prevents the Builder class from
// being loaded into PermGen space when the default instance is created.
// This optimizes the PermGen space usage for clients that do not modify
// messages.
printer->Print(
"// Use $classname$.newBuilder() to construct.\n"
"private $classname$($buildertype$ builder) {\n"
" super(builder);\n"
"}\n",
"classname", descriptor_->name(),
"buildertype", builder_type);
printer->Print(
"private $classname$() {\n",
"classname", descriptor_->name());
printer->Indent();
for (int i = 0; i < descriptor_->field_count(); i++) {
if (!descriptor_->field(i)->containing_oneof()) {
field_generators_.get(descriptor_->field(i))
.GenerateInitializationCode(printer);
}
if (HasDescriptorMethods(descriptor_)) {
// Using builder_type, instead of Builder, prevents the Builder class from
// being loaded into PermGen space when the default instance is created.
// This optimizes the PermGen space usage for clients that do not modify
// messages.
printer->Print(
"// Use $classname$.newBuilder() to construct.\n"
"private $classname$($buildertype$ builder) {\n"
" super(builder);\n"
"}\n",
"classname", descriptor_->name(),
"buildertype", builder_type);
printer->Print(
"private $classname$() {\n",
"classname", descriptor_->name());
printer->Indent();
GenerateInitializers(printer);
printer->Outdent();
printer->Print(
"}\n"
"\n");
}
printer->Outdent();
printer->Print(
"}\n"
"\n");
if (HasDescriptorMethods(descriptor_)) {
printer->Print(
......@@ -480,10 +477,37 @@ void ImmutableMessageGenerator::Generate(io::Printer* printer) {
"// @@protoc_insertion_point(class_scope:$full_name$)\n",
"full_name", descriptor_->full_name());
// Carefully initialize the default instance in such a way that it doesn't
// conflict with other initialization.
printer->Print("private static final $classname$ defaultInstance =\n"
" new $classname$();\n"
printer->Print(
"private static final $classname$ defaultInstance;",
"classname", name_resolver_->GetImmutableClassName(descriptor_));
if (HasDescriptorMethods(descriptor_)) {
printer->Print(
"static {\n"
" defaultInstance = new $classname$();\n"
"}\n"
"\n",
"classname", name_resolver_->GetImmutableClassName(descriptor_));
} else {
// LITE_RUNTIME only has one constructor.
printer->Print(
"static {\n"
" try {\n"
" defaultInstance = new $classname$(\n"
" com.google.protobuf.Internal\n"
" .EMPTY_CODED_INPUT_STREAM,\n"
" com.google.protobuf.ExtensionRegistryLite\n"
" .getEmptyRegistry());\n"
" } catch (com.google.protobuf.InvalidProtocolBufferException e) {\n"
" throw new ExceptionInInitializerError(e);\n"
" }\n"
"}\n"
"\n",
"classname", descriptor_->name());
}
printer->Print(
"public static $classname$ getDefaultInstance() {\n"
" return defaultInstance;\n"
"}\n"
......@@ -492,7 +516,7 @@ void ImmutableMessageGenerator::Generate(io::Printer* printer) {
" return defaultInstance;\n"
"}\n"
"\n",
"classname", descriptor_->name());
"classname", name_resolver_->GetImmutableClassName(descriptor_));
// Extensions must be declared after the defaultInstance is initialized
// because the defaultInstance is used by the extension to lazily retrieve
......@@ -773,6 +797,7 @@ void ImmutableMessageGenerator::GenerateBuilder(io::Printer* printer) {
if (HasGeneratedMethods(descriptor_)) {
GenerateIsInitialized(printer, DONT_MEMOIZE);
GenerateBuilderParsingMethods(printer);
}
// oneof
......@@ -1034,10 +1059,33 @@ GenerateCommonBuilderMethods(io::Printer* printer) {
"classname", name_resolver_->GetImmutableClassName(descriptor_));
}
printer->Print(
"public $classname$ buildPartial() {\n"
" $classname$ result = new $classname$(this);\n",
"classname", name_resolver_->GetImmutableClassName(descriptor_));
if (HasDescriptorMethods(descriptor_)) {
printer->Print(
"public $classname$ buildPartial() {\n"
" $classname$ result = new $classname$(this);\n",
"classname", name_resolver_->GetImmutableClassName(descriptor_));
} else {
// LITE_RUNTIME only provides a single message constructor.
printer->Print(
"public $classname$ buildPartial() {\n"
" $classname$ result = null;\n"
" try {\n"
" result = new $classname$(\n"
" com.google.protobuf.Internal\n"
" .EMPTY_CODED_INPUT_STREAM,\n"
" com.google.protobuf.ExtensionRegistryLite\n"
" .getEmptyRegistry());\n"
" } catch (com.google.protobuf.InvalidProtocolBufferException e) {\n"
" throw new RuntimeException(e);\n"
" }\n"
" result.unknownFields = this.unknownFields;\n",
"classname", name_resolver_->GetImmutableClassName(descriptor_));
if (descriptor_->extension_range_count() > 0) {
printer->Print(
" result.extensions = this.buildExtensions();\n");
}
}
printer->Indent();
......@@ -1193,6 +1241,34 @@ GenerateCommonBuilderMethods(io::Printer* printer) {
// ===================================================================
void ImmutableMessageGenerator::
GenerateBuilderParsingMethods(io::Printer* printer) {
if (HasDescriptorMethods(descriptor_)) {
// LITE_RUNTIME implements this at the GeneratedMessageLite level.
printer->Print(
"public Builder mergeFrom(\n"
" com.google.protobuf.CodedInputStream input,\n"
" com.google.protobuf.ExtensionRegistryLite extensionRegistry)\n"
" throws java.io.IOException {\n"
" $classname$ parsedMessage = null;\n"
" try {\n"
" parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry);\n"
" } catch (com.google.protobuf.InvalidProtocolBufferException e) {\n"
" parsedMessage = ($classname$) e.getUnfinishedMessage();\n"
" throw e;\n"
" } finally {\n"
" if (parsedMessage != null) {\n"
" mergeFrom(parsedMessage);\n"
" }\n"
" }\n"
" return this;\n"
"}\n",
"classname", name_resolver_->GetImmutableClassName(descriptor_));
}
}
// ===================================================================
void ImmutableMessageGenerator::GenerateIsInitialized(
io::Printer* printer, UseMemoization useMemoization) {
bool memoization = useMemoization == MEMOIZE;
......@@ -1488,8 +1564,13 @@ GenerateParsingConstructor(io::Printer* printer) {
printer->Indent();
// Initialize all fields to default.
printer->Print(
"this();\n");
if (HasDescriptorMethods(descriptor_)) {
printer->Print(
"this();\n");
} else {
// LITE_RUNTIME only has one constructor.
GenerateInitializers(printer);
}
// Use builder bits to track mutable repeated fields.
int totalBuilderBits = 0;
......@@ -1703,6 +1784,16 @@ void ImmutableMessageGenerator::GenerateParser(io::Printer* printer) {
"classname", descriptor_->name());
}
// ===================================================================
void ImmutableMessageGenerator::GenerateInitializers(io::Printer* printer) {
for (int i = 0; i < descriptor_->field_count(); i++) {
if (!descriptor_->field(i)->containing_oneof()) {
field_generators_.get(descriptor_->field(i))
.GenerateInitializationCode(printer);
}
}
}
} // namespace java
} // namespace compiler
......
......@@ -118,8 +118,10 @@ class ImmutableMessageGenerator : public MessageGenerator {
void GenerateBuilder(io::Printer* printer);
void GenerateCommonBuilderMethods(io::Printer* printer);
void GenerateDescriptorMethods(io::Printer* printer);
void GenerateBuilderParsingMethods(io::Printer* printer);
void GenerateIsInitialized(io::Printer* printer,
UseMemoization useMemoization);
void GenerateInitializers(io::Printer* printer);
void GenerateEqualsAndHashCode(io::Printer* printer);
void GenerateParser(io::Printer* printer);
void GenerateParsingConstructor(io::Printer* printer);
......
......@@ -1632,6 +1632,9 @@ bool Parser::ParseServiceMethod(MethodDescriptorProto* method,
// Parse input type.
DO(Consume("("));
{
if (TryConsume("stream")) {
method->set_client_streaming(true);
}
LocationRecorder location(method_location,
MethodDescriptorProto::kInputTypeFieldNumber);
location.RecordLegacyLocation(
......@@ -1644,6 +1647,9 @@ bool Parser::ParseServiceMethod(MethodDescriptorProto* method,
DO(Consume("returns"));
DO(Consume("("));
{
if (TryConsume("stream")) {
method->set_server_streaming(true);
}
LocationRecorder location(method_location,
MethodDescriptorProto::kOutputTypeFieldNumber);
location.RecordLegacyLocation(
......
......@@ -1843,6 +1843,13 @@ void MethodDescriptor::CopyTo(MethodDescriptorProto* proto) const {
if (&options() != &MethodOptions::default_instance()) {
proto->mutable_options()->CopyFrom(options());
}
if (client_streaming_) {
proto->set_client_streaming(true);
}
if (server_streaming_) {
proto->set_server_streaming(true);
}
}
// DebugString methods ===============================================
......@@ -2395,10 +2402,12 @@ void MethodDescriptor::DebugString(int depth, string *contents,
comment_printer(this, prefix, debug_string_options);
comment_printer.AddPreComment(contents);
strings::SubstituteAndAppend(contents, "$0rpc $1(.$2) returns (.$3)",
strings::SubstituteAndAppend(contents, "$0rpc $1($4.$2) returns ($5.$3)",
prefix, name(),
input_type()->full_name(),
output_type()->full_name());
output_type()->full_name(),
client_streaming() ? "stream " : "",
server_streaming() ? "stream " : "");
string formatted_options;
if (FormatLineOptions(depth, options(), &formatted_options)) {
......@@ -4393,6 +4402,9 @@ void DescriptorBuilder::BuildMethod(const MethodDescriptorProto& proto,
AllocateOptions(proto.options(), result);
}
result->client_streaming_ = proto.client_streaming();
result->server_streaming_ = proto.server_streaming();
AddSymbol(result->full_name(), parent, result->name(),
proto, Symbol(result));
}
......
......@@ -990,6 +990,11 @@ class LIBPROTOBUF_EXPORT MethodDescriptor {
// Gets the type of protocol message which this message produces as output.
const Descriptor* output_type() const;
// Gets whether the client streams multiple requests.
bool client_streaming() const;
// Gets whether the server streams multiple responses.
bool server_streaming() const;
// Get options for this method. These are specified in the .proto file by
// placing lines like "option foo = 1234;" in curly-braces after a method
// declaration. Allowed options are defined by MethodOptions in
......@@ -1031,6 +1036,8 @@ class LIBPROTOBUF_EXPORT MethodDescriptor {
const Descriptor* input_type_;
const Descriptor* output_type_;
const MethodOptions* options_;
bool client_streaming_;
bool server_streaming_;
// IMPORTANT: If you add a new field, make sure to search for all instances
// of Allocate<MethodDescriptor>() and AllocateArray<MethodDescriptor>() in
// descriptor.cc and update them to initialize the field.
......@@ -1623,6 +1630,9 @@ PROTOBUF_DEFINE_ACCESSOR(MethodDescriptor, service, const ServiceDescriptor*)
PROTOBUF_DEFINE_ACCESSOR(MethodDescriptor, input_type, const Descriptor*)
PROTOBUF_DEFINE_ACCESSOR(MethodDescriptor, output_type, const Descriptor*)
PROTOBUF_DEFINE_OPTIONS_ACCESSOR(MethodDescriptor, MethodOptions);
PROTOBUF_DEFINE_ACCESSOR(MethodDescriptor, client_streaming, bool)
PROTOBUF_DEFINE_ACCESSOR(MethodDescriptor, server_streaming, bool)
PROTOBUF_DEFINE_STRING_ACCESSOR(FileDescriptor, name)
PROTOBUF_DEFINE_STRING_ACCESSOR(FileDescriptor, package)
PROTOBUF_DEFINE_ACCESSOR(FileDescriptor, pool, const DescriptorPool*)
......
This diff is collapsed.
......@@ -1590,6 +1590,20 @@ class LIBPROTOBUF_EXPORT MethodDescriptorProto : public ::google::protobuf::Mess
inline ::google::protobuf::MethodOptions* release_options();
inline void set_allocated_options(::google::protobuf::MethodOptions* options);
// optional bool client_streaming = 5 [default = false];
inline bool has_client_streaming() const;
inline void clear_client_streaming();
static const int kClientStreamingFieldNumber = 5;
inline bool client_streaming() const;
inline void set_client_streaming(bool value);
// optional bool server_streaming = 6 [default = false];
inline bool has_server_streaming() const;
inline void clear_server_streaming();
static const int kServerStreamingFieldNumber = 6;
inline bool server_streaming() const;
inline void set_server_streaming(bool value);
// @@protoc_insertion_point(class_scope:google.protobuf.MethodDescriptorProto)
private:
inline void set_has_name();
......@@ -1600,6 +1614,10 @@ class LIBPROTOBUF_EXPORT MethodDescriptorProto : public ::google::protobuf::Mess
inline void clear_has_output_type();
inline void set_has_options();
inline void clear_has_options();
inline void set_has_client_streaming();
inline void clear_has_client_streaming();
inline void set_has_server_streaming();
inline void clear_has_server_streaming();
::google::protobuf::internal::InternalMetadataWithArena _internal_metadata_;
::google::protobuf::uint32 _has_bits_[1];
......@@ -1608,6 +1626,8 @@ class LIBPROTOBUF_EXPORT MethodDescriptorProto : public ::google::protobuf::Mess
::google::protobuf::internal::ArenaStringPtr input_type_;
::google::protobuf::internal::ArenaStringPtr output_type_;
::google::protobuf::MethodOptions* options_;
bool client_streaming_;
bool server_streaming_;
friend void LIBPROTOBUF_EXPORT protobuf_AddDesc_google_2fprotobuf_2fdescriptor_2eproto();
friend void protobuf_AssignDesc_google_2fprotobuf_2fdescriptor_2eproto();
friend void protobuf_ShutdownFile_google_2fprotobuf_2fdescriptor_2eproto();
......@@ -4968,6 +4988,54 @@ inline void MethodDescriptorProto::set_allocated_options(::google::protobuf::Met
// @@protoc_insertion_point(field_set_allocated:google.protobuf.MethodDescriptorProto.options)
}
// optional bool client_streaming = 5 [default = false];
inline bool MethodDescriptorProto::has_client_streaming() const {
return (_has_bits_[0] & 0x00000010u) != 0;
}
inline void MethodDescriptorProto::set_has_client_streaming() {
_has_bits_[0] |= 0x00000010u;
}
inline void MethodDescriptorProto::clear_has_client_streaming() {
_has_bits_[0] &= ~0x00000010u;
}
inline void MethodDescriptorProto::clear_client_streaming() {
client_streaming_ = false;
clear_has_client_streaming();
}
inline bool MethodDescriptorProto::client_streaming() const {
// @@protoc_insertion_point(field_get:google.protobuf.MethodDescriptorProto.client_streaming)
return client_streaming_;
}
inline void MethodDescriptorProto::set_client_streaming(bool value) {
set_has_client_streaming();
client_streaming_ = value;
// @@protoc_insertion_point(field_set:google.protobuf.MethodDescriptorProto.client_streaming)
}
// optional bool server_streaming = 6 [default = false];
inline bool MethodDescriptorProto::has_server_streaming() const {
return (_has_bits_[0] & 0x00000020u) != 0;
}
inline void MethodDescriptorProto::set_has_server_streaming() {
_has_bits_[0] |= 0x00000020u;
}
inline void MethodDescriptorProto::clear_has_server_streaming() {
_has_bits_[0] &= ~0x00000020u;
}
inline void MethodDescriptorProto::clear_server_streaming() {
server_streaming_ = false;
clear_has_server_streaming();
}
inline bool MethodDescriptorProto::server_streaming() const {
// @@protoc_insertion_point(field_get:google.protobuf.MethodDescriptorProto.server_streaming)
return server_streaming_;
}
inline void MethodDescriptorProto::set_server_streaming(bool value) {
set_has_server_streaming();
server_streaming_ = value;
// @@protoc_insertion_point(field_set:google.protobuf.MethodDescriptorProto.server_streaming)
}
// -------------------------------------------------------------------
// FileOptions
......
......@@ -220,6 +220,11 @@ message MethodDescriptorProto {
optional string output_type = 3;
optional MethodOptions options = 4;
// Identifies if client streams multiple client messages
optional bool client_streaming = 5 [default=false];
// Identifies if server streams multiple server messages
optional bool server_streaming = 6 [default=false];
}
......
......@@ -127,8 +127,10 @@ class LIBPROTOBUF_EXPORT ArrayOutputStream : public ZeroCopyOutputStream {
class LIBPROTOBUF_EXPORT StringOutputStream : public ZeroCopyOutputStream {
public:
// Create a StringOutputStream which appends bytes to the given string.
// The string remains property of the caller, but it MUST NOT be accessed
// in any way until the stream is destroyed.
// The string remains property of the caller, but it is mutated in arbitrary
// ways and MUST NOT be accessed in any way until you're done with the
// stream. Either be sure there's no further usage, or (safest) destroy the
// stream before using the contents.
//
// Hint: If you call target->reserve(n) before creating the stream,
// the first call to Next() will return at least n bytes of buffer
......
......@@ -31,10 +31,10 @@
#ifndef GOOGLE_PROTOBUF_MAP_H__
#define GOOGLE_PROTOBUF_MAP_H__
#include <vector>
#include <iterator>
#include <google/protobuf/stubs/hash.h>
#include <google/protobuf/map_type_handler.h>
#include <google/protobuf/stubs/hash.h>
namespace google {
namespace protobuf {
......@@ -48,7 +48,7 @@ namespace internal {
template <typename K, typename V, FieldDescriptor::Type KeyProto,
FieldDescriptor::Type ValueProto, int default_enum_value>
class MapField;
}
} // namespace internal
// This is the class for google::protobuf::Map's internal value_type. Instead of using
// std::pair as value_type, we use this class which provides us more control of
......@@ -62,7 +62,7 @@ class MapPair {
MapPair(const Key& other_first, const T& other_second)
: first(other_first), second(other_second) {}
MapPair(const Key& other_first) : first(other_first), second() {}
explicit MapPair(const Key& other_first) : first(other_first), second() {}
MapPair(const MapPair& other)
: first(other.first), second(other.second) {}
......@@ -82,52 +82,13 @@ class MapPair {
friend class Map<Key, T>;
};
// STL-like iterator implementation for google::protobuf::Map. Users should not refer to
// this class directly; use google::protobuf::Map<Key, T>::iterator instead.
template <typename Key, typename T>
class MapIterator {
public:
typedef MapPair<Key, T> value_type;
typedef value_type* pointer;
typedef value_type& reference;
typedef MapIterator iterator;
// constructor
MapIterator(const typename hash_map<Key, value_type*>::iterator& it)
: it_(it) {}
MapIterator(const MapIterator& other) : it_(other.it_) {}
MapIterator& operator=(const MapIterator& other) {
it_ = other.it_;
return *this;
}
// deferenceable
reference operator*() const { return *it_->second; }
pointer operator->() const { return it_->second; }
// incrementable
iterator& operator++() {
++it_;
return *this;
}
iterator operator++(int) { return iterator(it_++); }
// equality_comparable
bool operator==(const iterator& x) const { return it_ == x.it_; }
bool operator!=(const iterator& x) const { return it_ != x.it_; }
private:
typename hash_map<Key, value_type*>::iterator it_;
friend class Map<Key, T>;
};
// google::protobuf::Map is an associative container type used to store protobuf map
// fields. Its interface is similar to std::unordered_map. Users should use this
// interface directly to visit or change map fields.
template <typename Key, typename T>
class Map {
typedef internal::MapCppTypeHandler<T> ValueTypeHandler;
public:
typedef Key key_type;
typedef T mapped_type;
......@@ -138,9 +99,6 @@ class Map {
typedef value_type& reference;
typedef const value_type& const_reference;
typedef MapIterator<Key, T> iterator;
typedef MapIterator<Key, T> const_iterator;
typedef size_t size_type;
typedef hash<Key> hasher;
......@@ -153,16 +111,70 @@ class Map {
~Map() { clear(); }
// Iterators
class LIBPROTOBUF_EXPORT const_iterator
: public std::iterator<std::forward_iterator_tag, value_type, ptrdiff_t,
const value_type*, const value_type&> {
typedef typename hash_map<Key, value_type*>::const_iterator InnerIt;
public:
const_iterator() {}
explicit const_iterator(const InnerIt& it) : it_(it) {}
const_reference operator*() const { return *it_->second; }
const_pointer operator->() const { return it_->second; }
const_iterator& operator++() {
++it_;
return *this;
}
const_iterator operator++(int) { return const_iterator(it_++); }
friend bool operator==(const const_iterator& a, const const_iterator& b) {
return a.it_ == b.it_;
}
friend bool operator!=(const const_iterator& a, const const_iterator& b) {
return a.it_ != b.it_;
}
private:
InnerIt it_;
};
class LIBPROTOBUF_EXPORT iterator : public std::iterator<std::forward_iterator_tag, value_type> {
typedef typename hash_map<Key, value_type*>::iterator InnerIt;
public:
iterator() {}
explicit iterator(const InnerIt& it) : it_(it) {}
reference operator*() const { return *it_->second; }
pointer operator->() const { return it_->second; }
iterator& operator++() {
++it_;
return *this;
}
iterator operator++(int) { return iterator(it_++); }
// Implicitly convertible to const_iterator.
operator const_iterator() const { return const_iterator(it_); }
friend bool operator==(const iterator& a, const iterator& b) {
return a.it_ == b.it_;
}
friend bool operator!=(const iterator& a, const iterator& b) {
return a.it_ != b.it_;
}
private:
friend class Map;
InnerIt it_;
};
iterator begin() { return iterator(elements_.begin()); }
iterator end() { return iterator(elements_.end()); }
const_iterator begin() const {
return const_iterator(
const_cast<hash_map<Key, value_type*>&>(elements_).begin());
}
const_iterator end() const {
return const_iterator(
const_cast<hash_map<Key, value_type*>&>(elements_).end());
}
const_iterator begin() const { return const_iterator(elements_.begin()); }
const_iterator end() const { return const_iterator(elements_.end()); }
const_iterator cbegin() const { return begin(); }
const_iterator cend() const { return end(); }
......@@ -197,16 +209,30 @@ class Map {
return elements_.count(key);
}
const_iterator find(const key_type& key) const {
// When elements_ is a const instance, find(key) returns a const iterator.
// However, to reduce code complexity, we use MapIterator for Map's both
// const and non-const iterator, which only takes non-const iterator to
// construct.
return const_iterator(
const_cast<hash_map<Key, value_type*>&>(elements_).find(key));
return const_iterator(elements_.find(key));
}
iterator find(const key_type& key) {
return iterator(elements_.find(key));
}
std::pair<const_iterator, const_iterator> equal_range(
const key_type& key) const {
const_iterator it = find(key);
if (it == end()) {
return std::pair<const_iterator, const_iterator>(it, it);
} else {
const_iterator begin = it++;
return std::pair<const_iterator, const_iterator>(begin, it);
}
}
std::pair<iterator, iterator> equal_range(const key_type& key) {
iterator it = find(key);
if (it == end()) {
return std::pair<iterator, iterator>(it, it);
} else {
iterator begin = it++;
return std::pair<iterator, iterator>(begin, it);
}
}
// insert
std::pair<iterator, bool> insert(const value_type& value) {
......@@ -214,8 +240,9 @@ class Map {
if (it != end()) {
return std::pair<iterator, bool>(it, false);
} else {
return elements_.insert(
std::pair<Key, value_type*>(value.first, new value_type(value)));
return std::pair<iterator, bool>(
iterator(elements_.insert(std::pair<Key, value_type*>(
value.first, new value_type(value))).first), true);
}
}
template <class InputIt>
......@@ -258,7 +285,10 @@ class Map {
// Assign
Map& operator=(const Map& other) {
insert(other.begin(), other.end());
if (this != &other) {
clear();
insert(other.begin(), other.end());
}
return *this;
}
......
......@@ -237,6 +237,57 @@ TEST_F(MapImplTest, GetReferenceFromIterator) {
}
}
TEST_F(MapImplTest, IteratorBasic) {
map_[0] = 0;
// Default constructible (per forward iterator requirements).
Map<int, int>::const_iterator cit;
Map<int, int>::iterator it;
it = map_.begin();
cit = it; // Converts to const_iterator
// Can compare between them.
EXPECT_TRUE(it == cit);
EXPECT_FALSE(cit != it);
// Pre increment.
EXPECT_FALSE(it == ++cit);
// Post increment.
EXPECT_FALSE(it++ == cit);
EXPECT_TRUE(it == cit);
}
template <typename T>
bool IsConstHelper(T& /*t*/) { // NOLINT. We want to catch non-const refs here.
return false;
}
template <typename T>
bool IsConstHelper(const T& /*t*/) {
return true;
}
TEST_F(MapImplTest, IteratorConstness) {
map_[0] = 0;
EXPECT_TRUE(IsConstHelper(*map_.cbegin()));
EXPECT_TRUE(IsConstHelper(*const_map_.begin()));
EXPECT_FALSE(IsConstHelper(*map_.begin()));
}
bool IsForwardIteratorHelper(std::forward_iterator_tag /*tag*/) { return true; }
template <typename T>
bool IsForwardIteratorHelper(T /*t*/) {
return false;
}
TEST_F(MapImplTest, IteratorCategory) {
EXPECT_TRUE(IsForwardIteratorHelper(
std::iterator_traits<Map<int, int>::iterator>::iterator_category()));
EXPECT_TRUE(IsForwardIteratorHelper(std::iterator_traits<
Map<int, int>::const_iterator>::iterator_category()));
}
TEST_F(MapImplTest, InsertSingle) {
int32 key = 0;
int32 value1 = 100;
......@@ -433,11 +484,23 @@ TEST_F(MapImplTest, Assigner) {
map_.insert(map.begin(), map.end());
Map<int32, int32> other;
int32 key_other = 123;
int32 value_other = 321;
other[key_other] = value_other;
EXPECT_EQ(1, other.size());
other = map_;
EXPECT_EQ(2, other.size());
EXPECT_EQ(value1, other.at(key1));
EXPECT_EQ(value2, other.at(key2));
EXPECT_TRUE(other.find(key_other) == other.end());
// Self assign
other = other;
EXPECT_EQ(2, other.size());
EXPECT_EQ(value1, other.at(key1));
EXPECT_EQ(value2, other.at(key2));
}
TEST_F(MapImplTest, Rehash) {
......@@ -457,6 +520,30 @@ TEST_F(MapImplTest, Rehash) {
EXPECT_TRUE(map_.empty());
}
TEST_F(MapImplTest, EqualRange) {
int key = 100, key_missing = 101;
map_[key] = 100;
std::pair<google::protobuf::Map<int32, int32>::iterator,
google::protobuf::Map<int32, int32>::iterator> range = map_.equal_range(key);
EXPECT_TRUE(map_.find(key) == range.first);
EXPECT_TRUE(++map_.find(key) == range.second);
range = map_.equal_range(key_missing);
EXPECT_TRUE(map_.end() == range.first);
EXPECT_TRUE(map_.end() == range.second);
std::pair<google::protobuf::Map<int32, int32>::const_iterator,
google::protobuf::Map<int32, int32>::const_iterator> const_range =
const_map_.equal_range(key);
EXPECT_TRUE(const_map_.find(key) == const_range.first);
EXPECT_TRUE(++const_map_.find(key) == const_range.second);
const_range = const_map_.equal_range(key_missing);
EXPECT_TRUE(const_map_.end() == const_range.first);
EXPECT_TRUE(const_map_.end() == const_range.second);
}
// Map Field Reflection Test ========================================
static int Func(int i, int j) {
......@@ -879,15 +966,14 @@ TEST_F(MapFieldReflectionTest, RepeatedFieldRefForRegularFields) {
entry_int32_double.get(), fd_map_int32_double->message_type()->field(1),
Func(key, -2));
entry_string_string->GetReflection()->SetString(
entry_string_string.get(), fd_map_string_string->message_type()->field(0),
StrFunc(key, 1));
entry_string_string.get(),
fd_map_string_string->message_type()->field(0), StrFunc(key, 1));
entry_string_string->GetReflection()->SetString(
entry_string_string.get(), fd_map_string_string->message_type()->field(1),
StrFunc(key, -5));
entry_string_string.get(),
fd_map_string_string->message_type()->field(1), StrFunc(key, -5));
entry_int32_foreign_message->GetReflection()->SetInt32(
entry_int32_foreign_message.get(),
fd_map_int32_foreign_message->message_type()->field(0),
key);
fd_map_int32_foreign_message->message_type()->field(0), key);
Message* value_message =
entry_int32_foreign_message->GetReflection()->MutableMessage(
entry_int32_foreign_message.get(),
......@@ -896,10 +982,10 @@ TEST_F(MapFieldReflectionTest, RepeatedFieldRefForRegularFields) {
value_message, value_message->GetDescriptor()->FindFieldByName("c"),
Func(key, -6));
mmf_int32_int32.Set(i, *entry_int32_int32.get());
mmf_int32_double.Set(i, *entry_int32_double.get());
mmf_string_string.Set(i, *entry_string_string.get());
mmf_int32_foreign_message.Set(i, *entry_int32_foreign_message.get());
mmf_int32_int32.Set(i, *entry_int32_int32);
mmf_int32_double.Set(i, *entry_int32_double);
mmf_string_string.Set(i, *entry_string_string);
mmf_int32_foreign_message.Set(i, *entry_int32_foreign_message);
}
for (int i = 0; i < 10; i++) {
......
......@@ -57,10 +57,14 @@ const int WireFormatLite::kMessageSetMessageTag;
// IBM xlC requires prefixing constants with WireFormatLite::
const int WireFormatLite::kMessageSetItemTagsSize =
io::CodedOutputStream::StaticVarintSize32<WireFormatLite::kMessageSetItemStartTag>::value +
io::CodedOutputStream::StaticVarintSize32<WireFormatLite::kMessageSetItemEndTag>::value +
io::CodedOutputStream::StaticVarintSize32<WireFormatLite::kMessageSetTypeIdTag>::value +
io::CodedOutputStream::StaticVarintSize32<WireFormatLite::kMessageSetMessageTag>::value;
io::CodedOutputStream::StaticVarintSize32<
WireFormatLite::kMessageSetItemStartTag>::value +
io::CodedOutputStream::StaticVarintSize32<
WireFormatLite::kMessageSetItemEndTag>::value +
io::CodedOutputStream::StaticVarintSize32<
WireFormatLite::kMessageSetTypeIdTag>::value +
io::CodedOutputStream::StaticVarintSize32<
WireFormatLite::kMessageSetMessageTag>::value;
const WireFormatLite::CppType
WireFormatLite::kFieldTypeToCppTypeMap[MAX_FIELD_TYPE + 1] = {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment