Commit 57772592 authored by Feng Xiao's avatar Feng Xiao

Cherry-pick cl/152450543

parent cad0258d
......@@ -67,6 +67,13 @@ size_t MapFieldBase::SpaceUsedExcludingSelfNoLock() const {
}
}
bool MapFieldBase::IsMapValid() const {
// "Acquire" insures the operation after SyncRepeatedFieldWithMap won't get
// executed before state_ is checked.
Atomic32 state = google::protobuf::internal::Acquire_Load(&state_);
return state != STATE_MODIFIED_REPEATED;
}
void MapFieldBase::SetMapDirty() { state_ = STATE_MODIFIED_MAP; }
void MapFieldBase::SetRepeatedDirty() { state_ = STATE_MODIFIED_REPEATED; }
......@@ -359,6 +366,13 @@ void DynamicMapField::SyncMapWithRepeatedFieldNoLock() const {
GOOGLE_LOG(FATAL) << "Can't get here.";
break;
}
// Remove existing map value with same key.
Map<MapKey, MapValueRef>::iterator iter = map->find(map_key);
if (iter != map->end()) {
iter->second.DeleteData();
}
MapValueRef& map_val = (*map)[map_key];
map_val.SetType(val_des->cpp_type());
switch (val_des->cpp_type()) {
......
......@@ -86,6 +86,8 @@ class LIBPROTOBUF_EXPORT MapFieldBase {
virtual bool ContainsMapKey(const MapKey& map_key) const = 0;
virtual bool InsertOrLookupMapValue(
const MapKey& map_key, MapValueRef* val) = 0;
// Insures operations after won't get executed before calling this.
bool IsMapValid() const;
virtual bool DeleteMapValue(const MapKey& map_key) = 0;
virtual bool EqualIterator(const MapIterator& a,
const MapIterator& b) const = 0;
......
......@@ -975,6 +975,11 @@ static int Int(const string& value) {
class MapFieldReflectionTest : public testing::Test {
protected:
typedef FieldDescriptor FD;
int MapSize(const Reflection* reflection, const FieldDescriptor* field,
const Message& message) {
return reflection->MapSize(message, field);
}
};
TEST_F(MapFieldReflectionTest, RegularFields) {
......@@ -1782,6 +1787,50 @@ TEST_F(MapFieldReflectionTest, RepeatedFieldRefMergeFromAndSwap) {
// TODO(teboring): add test for duplicated key
}
TEST_F(MapFieldReflectionTest, MapSizeWithDuplicatedKey) {
// Dynamic Message
{
DynamicMessageFactory factory;
google::protobuf::scoped_ptr<Message> message(
factory.GetPrototype(unittest::TestMap::descriptor())->New());
const Reflection* reflection = message->GetReflection();
const FieldDescriptor* field =
unittest::TestMap::descriptor()->FindFieldByName("map_int32_int32");
Message* entry1 = reflection->AddMessage(message.get(), field);
Message* entry2 = reflection->AddMessage(message.get(), field);
const Reflection* entry_reflection = entry1->GetReflection();
const FieldDescriptor* key_field =
entry1->GetDescriptor()->FindFieldByName("key");
entry_reflection->SetInt32(entry1, key_field, 1);
entry_reflection->SetInt32(entry2, key_field, 1);
EXPECT_EQ(2, reflection->FieldSize(*message, field));
EXPECT_EQ(1, MapSize(reflection, field, *message));
}
// Generated Message
{
unittest::TestMap message;
const Reflection* reflection = message.GetReflection();
const FieldDescriptor* field =
message.GetDescriptor()->FindFieldByName("map_int32_int32");
Message* entry1 = reflection->AddMessage(&message, field);
Message* entry2 = reflection->AddMessage(&message, field);
const Reflection* entry_reflection = entry1->GetReflection();
const FieldDescriptor* key_field =
entry1->GetDescriptor()->FindFieldByName("key");
entry_reflection->SetInt32(entry1, key_field, 1);
entry_reflection->SetInt32(entry2, key_field, 1);
EXPECT_EQ(2, reflection->FieldSize(message, field));
EXPECT_EQ(1, MapSize(reflection, field, message));
}
}
// Generated Message Test ===========================================
TEST(GeneratedMapFieldTest, Accessors) {
......@@ -2689,6 +2738,69 @@ TEST_F(MapFieldInDynamicMessageTest, RecursiveMap) {
ASSERT_TRUE(to->ParseFromString(data));
}
TEST_F(MapFieldInDynamicMessageTest, MapValueReferernceValidAfterSerialize) {
google::protobuf::scoped_ptr<Message> message(map_prototype_->New());
MapReflectionTester reflection_tester(map_descriptor_);
reflection_tester.SetMapFieldsViaMapReflection(message.get());
// Get value reference before serialization, so that we know the value is from
// map.
MapKey map_key;
MapValueRef map_val;
map_key.SetInt32Value(0);
reflection_tester.GetMapValueViaMapReflection(
message.get(), "map_int32_foreign_message", map_key, &map_val);
Message* submsg = map_val.MutableMessageValue();
// In previous implementation, calling SerializeToString will cause syncing
// from map to repeated field, which will invalidate the submsg we previously
// got.
string data;
message->SerializeToString(&data);
const Reflection* submsg_reflection = submsg->GetReflection();
const Descriptor* submsg_desc = submsg->GetDescriptor();
const FieldDescriptor* submsg_field = submsg_desc->FindFieldByName("c");
submsg_reflection->SetInt32(submsg, submsg_field, 128);
message->SerializeToString(&data);
TestMap to;
to.ParseFromString(data);
EXPECT_EQ(128, to.map_int32_foreign_message().at(0).c());
}
TEST_F(MapFieldInDynamicMessageTest, MapEntryReferernceValidAfterSerialize) {
google::protobuf::scoped_ptr<Message> message(map_prototype_->New());
MapReflectionTester reflection_tester(map_descriptor_);
reflection_tester.SetMapFieldsViaReflection(message.get());
// Get map entry before serialization, so that we know the it is from
// repeated field.
Message* map_entry = reflection_tester.GetMapEntryViaReflection(
message.get(), "map_int32_foreign_message", 0);
const Reflection* map_entry_reflection = map_entry->GetReflection();
const Descriptor* map_entry_desc = map_entry->GetDescriptor();
const FieldDescriptor* value_field = map_entry_desc->FindFieldByName("value");
Message* submsg =
map_entry_reflection->MutableMessage(map_entry, value_field);
// In previous implementation, calling SerializeToString will cause syncing
// from repeated field to map, which will invalidate the map_entry we
// previously got.
string data;
message->SerializeToString(&data);
const Reflection* submsg_reflection = submsg->GetReflection();
const Descriptor* submsg_desc = submsg->GetDescriptor();
const FieldDescriptor* submsg_field = submsg_desc->FindFieldByName("c");
submsg_reflection->SetInt32(submsg, submsg_field, 128);
message->SerializeToString(&data);
TestMap to;
to.ParseFromString(data);
EXPECT_EQ(128, to.map_int32_foreign_message().at(0).c());
}
// ReflectionOps Test ===============================================
TEST(ReflectionOpsForMapFieldTest, MapSanityCheck) {
......@@ -2751,6 +2863,20 @@ TEST(ReflectionOpsForMapFieldTest, MapDiscardUnknownFields) {
GetUnknownFields(message).field_count());
}
TEST(ReflectionOpsForMapFieldTest, IsInitialized) {
unittest::TestRequiredMessageMap map_message;
// Add an uninitialized message.
(*map_message.mutable_map_field())[0];
EXPECT_FALSE(ReflectionOps::IsInitialized(map_message));
// Initialize uninitialized message
(*map_message.mutable_map_field())[0].set_a(0);
(*map_message.mutable_map_field())[0].set_b(0);
(*map_message.mutable_map_field())[0].set_c(0);
EXPECT_TRUE(ReflectionOps::IsInitialized(map_message));
}
// Wire Format Test =================================================
TEST(WireFormatForMapFieldTest, ParseMap) {
......@@ -3089,7 +3215,7 @@ TEST(ArenaTest, ParsingAndSerializingNoHeapAllocation) {
}
// Use text format parsing and serializing to test reflection api.
TEST(ArenaTest, RelfectionInTextFormat) {
TEST(ArenaTest, ReflectionInTextFormat) {
Arena arena;
string data;
......
......@@ -744,6 +744,22 @@ void MapReflectionTester::SetMapFieldsViaMapReflection(
sub_foreign_message, foreign_c_, 1);
}
void MapReflectionTester::GetMapValueViaMapReflection(Message* message,
const string& field_name,
const MapKey& map_key,
MapValueRef* map_val) {
const Reflection* reflection = message->GetReflection();
EXPECT_FALSE(reflection->InsertOrLookupMapValue(message, F(field_name),
map_key, map_val));
}
Message* MapReflectionTester::GetMapEntryViaReflection(Message* message,
const string& field_name,
int index) {
const Reflection* reflection = message->GetReflection();
return reflection->MutableRepeatedMessage(message, F(field_name), index);
}
void MapReflectionTester::ClearMapFieldsViaReflection(
Message* message) {
const Reflection* reflection = message->GetReflection();
......
......@@ -106,6 +106,11 @@ class MapReflectionTester {
void ExpectClearViaReflection(const Message& message);
void ExpectClearViaReflectionIterator(Message* message);
void ExpectMapEntryClearViaReflection(Message* message);
void GetMapValueViaMapReflection(Message* message,
const string& field_name,
const MapKey& map_key, MapValueRef* map_val);
Message* GetMapEntryViaReflection(Message* message, const string& field_name,
int index);
private:
const FieldDescriptor* F(const string& name);
......
......@@ -154,6 +154,13 @@ class MapReflectionFriend; // scalar_map_container.h
}
namespace internal {
class ReflectionOps; // reflection_ops.h
class MapKeySorter; // wire_format.cc
class WireFormat; // wire_format.h
class MapFieldReflectionTest; // map_test.cc
}
template<typename T>
class RepeatedField; // repeated_field.h
......@@ -936,6 +943,10 @@ class LIBPROTOBUF_EXPORT Reflection {
template<typename T, typename Enable>
friend class MutableRepeatedFieldRef;
friend class ::google::protobuf::python::MapReflectionFriend;
friend class internal::MapFieldReflectionTest;
friend class internal::MapKeySorter;
friend class internal::WireFormat;
friend class internal::ReflectionOps;
// Special version for specialized implementations of string. We can't call
// MutableRawRepeatedField directly here because we don't have access to
......
......@@ -38,6 +38,7 @@
#include <google/protobuf/reflection_ops.h>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/map_field.h>
#include <google/protobuf/unknown_field_set.h>
#include <google/protobuf/stubs/strutil.h>
......@@ -158,6 +159,27 @@ bool ReflectionOps::IsInitialized(const Message& message) {
const FieldDescriptor* field = fields[i];
if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
if (field->is_map()) {
const FieldDescriptor* value_field = field->message_type()->field(1);
if (value_field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
MapFieldBase* map_field =
reflection->MapData(const_cast<Message*>(&message), field);
if (map_field->IsMapValid()) {
MapIterator iter(const_cast<Message*>(&message), field);
MapIterator end(const_cast<Message*>(&message), field);
for (map_field->MapBegin(&iter), map_field->MapEnd(&end);
iter != end; ++iter) {
if (!iter.GetValueRef().GetMessageValue().IsInitialized()) {
return false;
}
}
continue;
}
} else {
continue;
}
}
if (field->is_repeated()) {
int size = reflection->FieldSize(message, field);
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment