Commit 2178b93b authored by Jon Skeet's avatar Jon Skeet

Fix bug when reading many messages - size guard was triggered

parent 60fb63e3
...@@ -3,6 +3,7 @@ using System.Collections.Generic; ...@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.IO; using System.IO;
using NUnit.Framework; using NUnit.Framework;
using NestedMessage = Google.ProtocolBuffers.TestProtos.TestAllTypes.Types.NestedMessage; using NestedMessage = Google.ProtocolBuffers.TestProtos.TestAllTypes.Types.NestedMessage;
using Google.ProtocolBuffers.TestProtos;
namespace Google.ProtocolBuffers { namespace Google.ProtocolBuffers {
[TestFixture] [TestFixture]
...@@ -19,5 +20,31 @@ namespace Google.ProtocolBuffers { ...@@ -19,5 +20,31 @@ namespace Google.ProtocolBuffers {
Assert.AreEqual(1500, messages[1].Bb); Assert.AreEqual(1500, messages[1].Bb);
Assert.IsFalse(messages[2].HasBb); Assert.IsFalse(messages[2].HasBb);
} }
[Test]
public void ManyMessagesShouldNotTriggerSizeAlert() {
int messageSize = TestUtil.GetAllSet().SerializedSize;
// Enough messages to trigger the alert unless we've reset the size
// Note that currently we need to make this big enough to copy two whole buffers,
// as otherwise when we refill the buffer the second type, the alert triggers instantly.
int correctCount = (CodedInputStream.BufferSize * 2) / messageSize + 1;
using (MemoryStream stream = new MemoryStream()) {
MessageStreamWriter<TestAllTypes> writer = new MessageStreamWriter<TestAllTypes>(stream);
for (int i = 0; i < correctCount; i++) {
writer.Write(TestUtil.GetAllSet());
}
writer.Flush();
stream.Position = 0;
int count = 0;
foreach (var message in MessageStreamIterator<TestAllTypes>.FromStreamProvider(() => stream)
.WithSizeLimit(CodedInputStream.BufferSize * 2)) {
count++;
TestUtil.AssertAllFieldsSet(message);
}
Assert.AreEqual(correctCount, count);
}
}
} }
} }
...@@ -61,9 +61,9 @@ namespace Google.ProtocolBuffers { ...@@ -61,9 +61,9 @@ namespace Google.ProtocolBuffers {
private readonly Stream input; private readonly Stream input;
private uint lastTag = 0; private uint lastTag = 0;
const int DefaultRecursionLimit = 64; internal const int DefaultRecursionLimit = 64;
const int DefaultSizeLimit = 64 << 20; // 64MB internal const int DefaultSizeLimit = 64 << 20; // 64MB
const int BufferSize = 4096; internal const int BufferSize = 4096;
/// <summary> /// <summary>
/// The total number of bytes read before the current buffer. The /// The total number of bytes read before the current buffer. The
...@@ -741,7 +741,7 @@ namespace Google.ProtocolBuffers { ...@@ -741,7 +741,7 @@ namespace Google.ProtocolBuffers {
/// Read one byte from the input. /// Read one byte from the input.
/// </summary> /// </summary>
/// <exception cref="InvalidProtocolBufferException"> /// <exception cref="InvalidProtocolBufferException">
/// he end of the stream or the current limit was reached /// the end of the stream or the current limit was reached
/// </exception> /// </exception>
public byte ReadRawByte() { public byte ReadRawByte() {
if (bufferPos == bufferSize) { if (bufferPos == bufferSize) {
......
...@@ -18,6 +18,7 @@ namespace Google.ProtocolBuffers { ...@@ -18,6 +18,7 @@ namespace Google.ProtocolBuffers {
private readonly StreamProvider streamProvider; private readonly StreamProvider streamProvider;
private readonly ExtensionRegistry extensionRegistry; private readonly ExtensionRegistry extensionRegistry;
private readonly int sizeLimit;
/// <summary> /// <summary>
/// Delegate created via reflection trickery (once per type) to create a builder /// Delegate created via reflection trickery (once per type) to create a builder
...@@ -103,17 +104,22 @@ namespace Google.ProtocolBuffers { ...@@ -103,17 +104,22 @@ namespace Google.ProtocolBuffers {
TBuilder builder = builderBuilder(); TBuilder builder = builderBuilder();
input.ReadMessage(builder, registry); input.ReadMessage(builder, registry);
return builder.Build(); return builder.Build();
} }
#pragma warning restore 0414 #pragma warning restore 0414
private static readonly uint ExpectedTag = WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited); private static readonly uint ExpectedTag = WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited);
private MessageStreamIterator(StreamProvider streamProvider, ExtensionRegistry extensionRegistry) { private MessageStreamIterator(StreamProvider streamProvider, ExtensionRegistry extensionRegistry, int sizeLimit) {
if (messageReader == null) { if (messageReader == null) {
throw typeInitializationException; throw typeInitializationException;
} }
this.streamProvider = streamProvider; this.streamProvider = streamProvider;
this.extensionRegistry = extensionRegistry; this.extensionRegistry = extensionRegistry;
this.sizeLimit = sizeLimit;
}
private MessageStreamIterator(StreamProvider streamProvider, ExtensionRegistry extensionRegistry)
: this (streamProvider, extensionRegistry, CodedInputStream.DefaultSizeLimit) {
} }
/// <summary> /// <summary>
...@@ -121,7 +127,16 @@ namespace Google.ProtocolBuffers { ...@@ -121,7 +127,16 @@ namespace Google.ProtocolBuffers {
/// but the specified extension registry. /// but the specified extension registry.
/// </summary> /// </summary>
public MessageStreamIterator<TMessage> WithExtensionRegistry(ExtensionRegistry newRegistry) { public MessageStreamIterator<TMessage> WithExtensionRegistry(ExtensionRegistry newRegistry) {
return new MessageStreamIterator<TMessage>(streamProvider, newRegistry); return new MessageStreamIterator<TMessage>(streamProvider, newRegistry, sizeLimit);
}
/// <summary>
/// Creates a new instance which uses the same stream provider and extension registry as this one,
/// but with the specified size limit. Note that this must be big enough for the largest message
/// and the tag and size preceding it.
/// </summary>
public MessageStreamIterator<TMessage> WithSizeLimit(int newSizeLimit) {
return new MessageStreamIterator<TMessage>(streamProvider, extensionRegistry, newSizeLimit);
} }
public static MessageStreamIterator<TMessage> FromFile(string file) { public static MessageStreamIterator<TMessage> FromFile(string file) {
...@@ -135,12 +150,14 @@ namespace Google.ProtocolBuffers { ...@@ -135,12 +150,14 @@ namespace Google.ProtocolBuffers {
public IEnumerator<TMessage> GetEnumerator() { public IEnumerator<TMessage> GetEnumerator() {
using (Stream stream = streamProvider()) { using (Stream stream = streamProvider()) {
CodedInputStream input = CodedInputStream.CreateInstance(stream); CodedInputStream input = CodedInputStream.CreateInstance(stream);
input.SetSizeLimit(sizeLimit);
uint tag; uint tag;
while ((tag = input.ReadTag()) != 0) { while ((tag = input.ReadTag()) != 0) {
if (tag != ExpectedTag) { if (tag != ExpectedTag) {
throw InvalidProtocolBufferException.InvalidMessageStreamTag(); throw InvalidProtocolBufferException.InvalidMessageStreamTag();
} }
yield return messageReader(input, extensionRegistry); yield return messageReader(input, extensionRegistry);
input.ResetSizeCounter();
} }
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment