Commit df44ae44 authored by Jon Skeet's avatar Jon Skeet

More map tests, and various production code improvements.

Generated code in next commit.
parent e36e601a
This diff is collapsed.
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using Google.Protobuf.Collections;
using Google.Protobuf.TestProtos; using Google.Protobuf.TestProtos;
using NUnit.Framework; using NUnit.Framework;
namespace Google.Protobuf namespace Google.Protobuf.Collections
{ {
public class RepeatedFieldTest public class RepeatedFieldTest
{ {
......
#region Copyright notice and license
// Protocol Buffers - Google's data interchange format
// Copyright 2015 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#endregion
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using NUnit.Framework;
namespace Google.Protobuf
{
/// <summary>
/// Helper methods when testing equality. NUnit's Assert.AreEqual and
/// Assert.AreNotEqual methods try to be clever with collections, which can
/// be annoying...
/// </summary>
internal static class EqualityTester
{
public static void AssertEquality<T>(T first, T second) where T : IEquatable<T>
{
Assert.IsTrue(first.Equals(second));
Assert.AreEqual(first.GetHashCode(), second.GetHashCode());
}
public static void AssertInequality<T>(T first, T second) where T : IEquatable<T>
{
Assert.IsFalse(first.Equals(second));
// While this isn't a requirement, the chances of this test failing due to
// coincidence rather than a bug are very small.
Assert.AreNotEqual(first.GetHashCode(), second.GetHashCode());
}
}
}
...@@ -11,6 +11,15 @@ namespace Google.Protobuf ...@@ -11,6 +11,15 @@ namespace Google.Protobuf
/// </summary> /// </summary>
public class GeneratedMessageTest public class GeneratedMessageTest
{ {
[Test]
public void EmptyMessageFieldDistinctFromMissingMessageField()
{
// This demonstrates what we're really interested in...
var message1 = new TestAllTypes { SingleForeignMessage = new ForeignMessage() };
var message2 = new TestAllTypes(); // SingleForeignMessage is null
EqualityTester.AssertInequality(message1, message2);
}
[Test] [Test]
public void DefaultValues() public void DefaultValues()
{ {
......
...@@ -74,8 +74,10 @@ ...@@ -74,8 +74,10 @@
<Compile Include="ByteStringTest.cs" /> <Compile Include="ByteStringTest.cs" />
<Compile Include="CodedInputStreamTest.cs" /> <Compile Include="CodedInputStreamTest.cs" />
<Compile Include="CodedOutputStreamTest.cs" /> <Compile Include="CodedOutputStreamTest.cs" />
<Compile Include="EqualityTester.cs" />
<Compile Include="GeneratedMessageTest.cs" /> <Compile Include="GeneratedMessageTest.cs" />
<Compile Include="RepeatedFieldTest.cs" /> <Compile Include="Collections\MapFieldTest.cs" />
<Compile Include="Collections\RepeatedFieldTest.cs" />
<Compile Include="TestProtos\UnittestImportProto3.cs" /> <Compile Include="TestProtos\UnittestImportProto3.cs" />
<Compile Include="TestProtos\UnittestImportPublicProto3.cs" /> <Compile Include="TestProtos\UnittestImportPublicProto3.cs" />
<Compile Include="TestProtos\UnittestIssues.cs" /> <Compile Include="TestProtos\UnittestIssues.cs" />
...@@ -99,9 +101,7 @@ ...@@ -99,9 +101,7 @@
<ItemGroup> <ItemGroup>
<Service Include="{82A7F48D-3B50-4B1E-B82E-3ADA8210C358}" /> <Service Include="{82A7F48D-3B50-4B1E-B82E-3ADA8210C358}" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup />
<Folder Include="Collections\" />
</ItemGroup>
<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" /> <Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
<!-- To modify your build process, add your task inside one of the targets below and uncomment it. <!-- To modify your build process, add your task inside one of the targets below and uncomment it.
Other similar extension points exist, see Microsoft.Common.targets. Other similar extension points exist, see Microsoft.Common.targets.
......
...@@ -456,14 +456,16 @@ namespace Google.Protobuf ...@@ -456,14 +456,16 @@ namespace Google.Protobuf
} }
/// <summary> /// <summary>
/// Returns true if the next tag is also part of the same unpacked array. /// Peeks at the next tag in the stream. If it matches <paramref name="tag"/>,
/// the tag is consumed and the method returns <c>true</c>; otherwise, the
/// stream is left in the original position and the method returns <c>false</c>.
/// </summary> /// </summary>
private bool ContinueArray(uint currentTag) public bool MaybeConsumeTag(uint tag)
{ {
uint next; uint next;
if (PeekNextTag(out next)) if (PeekNextTag(out next))
{ {
if (next == currentTag) if (next == tag)
{ {
hasNextTag = false; hasNextTag = false;
return true; return true;
...@@ -486,17 +488,7 @@ namespace Google.Protobuf ...@@ -486,17 +488,7 @@ namespace Google.Protobuf
} }
return true; return true;
} }
return MaybeConsumeTag(currentTag);
uint next;
if (PeekNextTag(out next))
{
if (next == currentTag)
{
hasNextTag = false;
return true;
}
}
return false;
} }
/// <summary> /// <summary>
...@@ -512,7 +504,7 @@ namespace Google.Protobuf ...@@ -512,7 +504,7 @@ namespace Google.Protobuf
do do
{ {
list.Add(ReadString()); list.Add(ReadString());
} while (ContinueArray(fieldTag)); } while (MaybeConsumeTag(fieldTag));
} }
public void ReadBytesArray(ICollection<ByteString> list) public void ReadBytesArray(ICollection<ByteString> list)
...@@ -521,7 +513,7 @@ namespace Google.Protobuf ...@@ -521,7 +513,7 @@ namespace Google.Protobuf
do do
{ {
list.Add(ReadBytes()); list.Add(ReadBytes());
} while (ContinueArray(fieldTag)); } while (MaybeConsumeTag(fieldTag));
} }
public void ReadBoolArray(ICollection<bool> list) public void ReadBoolArray(ICollection<bool> list)
...@@ -729,7 +721,7 @@ namespace Google.Protobuf ...@@ -729,7 +721,7 @@ namespace Google.Protobuf
do do
{ {
list.Add((T)(object) ReadEnum()); list.Add((T)(object) ReadEnum());
} while (ContinueArray(fieldTag)); } while (MaybeConsumeTag(fieldTag));
} }
} }
...@@ -742,7 +734,7 @@ namespace Google.Protobuf ...@@ -742,7 +734,7 @@ namespace Google.Protobuf
T message = messageParser.CreateTemplate(); T message = messageParser.CreateTemplate();
ReadMessage(message); ReadMessage(message);
list.Add(message); list.Add(message);
} while (ContinueArray(fieldTag)); } while (MaybeConsumeTag(fieldTag));
} }
#endregion #endregion
......
...@@ -77,8 +77,7 @@ namespace Google.Protobuf.Collections ...@@ -77,8 +77,7 @@ namespace Google.Protobuf.Collections
public void Add(TKey key, TValue value) public void Add(TKey key, TValue value)
{ {
ThrowHelper.ThrowIfNull(key, "key"); // Validation of arguments happens in ContainsKey and the indexer
this.CheckMutable();
if (ContainsKey(key)) if (ContainsKey(key))
{ {
throw new ArgumentException("Key already exists in map", "key"); throw new ArgumentException("Key already exists in map", "key");
...@@ -88,12 +87,14 @@ namespace Google.Protobuf.Collections ...@@ -88,12 +87,14 @@ namespace Google.Protobuf.Collections
public bool ContainsKey(TKey key) public bool ContainsKey(TKey key)
{ {
ThrowHelper.ThrowIfNull(key, "key");
return map.ContainsKey(key); return map.ContainsKey(key);
} }
public bool Remove(TKey key) public bool Remove(TKey key)
{ {
this.CheckMutable(); this.CheckMutable();
ThrowHelper.ThrowIfNull(key, "key");
LinkedListNode<KeyValuePair<TKey, TValue>> node; LinkedListNode<KeyValuePair<TKey, TValue>> node;
if (map.TryGetValue(key, out node)) if (map.TryGetValue(key, out node))
{ {
...@@ -126,6 +127,7 @@ namespace Google.Protobuf.Collections ...@@ -126,6 +127,7 @@ namespace Google.Protobuf.Collections
{ {
get get
{ {
ThrowHelper.ThrowIfNull(key, "key");
TValue value; TValue value;
if (TryGetValue(key, out value)) if (TryGetValue(key, out value))
{ {
...@@ -135,6 +137,11 @@ namespace Google.Protobuf.Collections ...@@ -135,6 +137,11 @@ namespace Google.Protobuf.Collections
} }
set set
{ {
ThrowHelper.ThrowIfNull(key, "key");
if (value == null && (typeof(TValue) == typeof(ByteString) || typeof(TValue) == typeof(string)))
{
ThrowHelper.ThrowIfNull(value, "value");
}
this.CheckMutable(); this.CheckMutable();
LinkedListNode<KeyValuePair<TKey, TValue>> node; LinkedListNode<KeyValuePair<TKey, TValue>> node;
var pair = new KeyValuePair<TKey, TValue>(key, value); var pair = new KeyValuePair<TKey, TValue>(key, value);
...@@ -156,9 +163,10 @@ namespace Google.Protobuf.Collections ...@@ -156,9 +163,10 @@ namespace Google.Protobuf.Collections
public void Add(IDictionary<TKey, TValue> entries) public void Add(IDictionary<TKey, TValue> entries)
{ {
ThrowHelper.ThrowIfNull(entries, "entries");
foreach (var pair in entries) foreach (var pair in entries)
{ {
Add(pair); Add(pair.Key, pair.Value);
} }
} }
...@@ -172,9 +180,8 @@ namespace Google.Protobuf.Collections ...@@ -172,9 +180,8 @@ namespace Google.Protobuf.Collections
return GetEnumerator(); return GetEnumerator();
} }
public void Add(KeyValuePair<TKey, TValue> item) void ICollection<KeyValuePair<TKey, TValue>>.Add(KeyValuePair<TKey, TValue> item)
{ {
this.CheckMutable();
Add(item.Key, item.Value); Add(item.Key, item.Value);
} }
...@@ -185,22 +192,37 @@ namespace Google.Protobuf.Collections ...@@ -185,22 +192,37 @@ namespace Google.Protobuf.Collections
map.Clear(); map.Clear();
} }
public bool Contains(KeyValuePair<TKey, TValue> item) bool ICollection<KeyValuePair<TKey, TValue>>.Contains(KeyValuePair<TKey, TValue> item)
{ {
TValue value; TValue value;
return TryGetValue(item.Key, out value) return TryGetValue(item.Key, out value)
&& EqualityComparer<TValue>.Default.Equals(item.Value, value); && EqualityComparer<TValue>.Default.Equals(item.Value, value);
} }
public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex) void ICollection<KeyValuePair<TKey, TValue>>.CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex)
{ {
list.CopyTo(array, arrayIndex); list.CopyTo(array, arrayIndex);
} }
public bool Remove(KeyValuePair<TKey, TValue> item) bool ICollection<KeyValuePair<TKey, TValue>>.Remove(KeyValuePair<TKey, TValue> item)
{ {
this.CheckMutable(); this.CheckMutable();
return Remove(item.Key); if (item.Key == null)
{
throw new ArgumentException("Key is null", "item");
}
LinkedListNode<KeyValuePair<TKey, TValue>> node;
if (map.TryGetValue(item.Key, out node) &&
EqualityComparer<TValue>.Default.Equals(item.Value, node.Value.Value))
{
map.Remove(item.Key);
node.List.Remove(node);
return true;
}
else
{
return false;
}
} }
public int Count { get { return list.Count; } } public int Count { get { return list.Count; } }
...@@ -239,7 +261,7 @@ namespace Google.Protobuf.Collections ...@@ -239,7 +261,7 @@ namespace Google.Protobuf.Collections
public override int GetHashCode() public override int GetHashCode()
{ {
var valueComparer = EqualityComparer<TValue>.Default; var valueComparer = EqualityComparer<TValue>.Default;
int hash = 0; int hash = 19;
foreach (var pair in list) foreach (var pair in list)
{ {
hash ^= pair.Key.GetHashCode() * 31 + valueComparer.GetHashCode(pair.Value); hash ^= pair.Key.GetHashCode() * 31 + valueComparer.GetHashCode(pair.Value);
...@@ -277,14 +299,26 @@ namespace Google.Protobuf.Collections ...@@ -277,14 +299,26 @@ namespace Google.Protobuf.Collections
return true; return true;
} }
/// <summary>
/// Adds entries to the map from the given stream.
/// </summary>
/// <remarks>
/// It is assumed that the stream is initially positioned after the tag specified by the codec.
/// This method will continue reading entries from the stream until the end is reached, or
/// a different tag is encountered.
/// </remarks>
/// <param name="input">Stream to read from</param>
/// <param name="codec">Codec describing how the key/value pairs are encoded</param>
public void AddEntriesFrom(CodedInputStream input, Codec codec) public void AddEntriesFrom(CodedInputStream input, Codec codec)
{ {
// TODO: Peek at the next tag and see if it's the same. If it is, we can reuse the entry object...
var adapter = new Codec.MessageAdapter(codec); var adapter = new Codec.MessageAdapter(codec);
adapter.Reset(); do
input.ReadMessage(adapter); {
this[adapter.Key] = adapter.Value; adapter.Reset();
} input.ReadMessage(adapter);
this[adapter.Key] = adapter.Value;
} while (input.MaybeConsumeTag(codec.MapTag));
}
public void WriteTo(CodedOutputStream output, Codec codec) public void WriteTo(CodedOutputStream output, Codec codec)
{ {
......
...@@ -405,9 +405,10 @@ void MessageGenerator::GenerateFrameworkMethods(io::Printer* printer) { ...@@ -405,9 +405,10 @@ void MessageGenerator::GenerateFrameworkMethods(io::Printer* printer) {
"}\n\n"); "}\n\n");
// GetHashCode // GetHashCode
// Start with a non-zero value to easily distinguish between null and "empty" messages.
printer->Print( printer->Print(
"public override int GetHashCode() {\n" "public override int GetHashCode() {\n"
" int hash = 0;\n"); " int hash = 17;\n");
printer->Indent(); printer->Indent();
for (int i = 0; i < descriptor_->field_count(); i++) { for (int i = 0; i < descriptor_->field_count(); i++) {
scoped_ptr<FieldGeneratorBase> generator( scoped_ptr<FieldGeneratorBase> generator(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment