Commit e7f88ff1 authored by Jon Skeet's avatar Jon Skeet

Skip groups properly.

Now the generated code doesn't need to check for end group tags, as it will skip whole groups at a time.
Currently it will ignore extraneous end group tags, which may or may not be a good thing.
Renamed ConsumeLastField to SkipLastField as it felt more natural.
Removed WireFormat.IsEndGroupTag as it's no longer useful.

This mostly fixes issue 688.

(Generated code changes coming in next commit.)
parent ad8a889d
......@@ -442,5 +442,92 @@ namespace Google.Protobuf
var input = new CodedInputStream(new byte[] { 0 });
Assert.Throws<InvalidProtocolBufferException>(() => input.ReadTag());
}
[Test]
public void SkipGroup()
{
// Create an output stream with a group in:
// Field 1: string "field 1"
// Field 2: group containing:
// Field 1: fixed int32 value 100
// Field 2: string "ignore me"
// Field 3: nested group containing
// Field 1: fixed int64 value 1000
// Field 3: string "field 3"
var stream = new MemoryStream();
var output = new CodedOutputStream(stream);
output.WriteTag(1, WireFormat.WireType.LengthDelimited);
output.WriteString("field 1");
// The outer group...
output.WriteTag(2, WireFormat.WireType.StartGroup);
output.WriteTag(1, WireFormat.WireType.Fixed32);
output.WriteFixed32(100);
output.WriteTag(2, WireFormat.WireType.LengthDelimited);
output.WriteString("ignore me");
// The nested group...
output.WriteTag(3, WireFormat.WireType.StartGroup);
output.WriteTag(1, WireFormat.WireType.Fixed64);
output.WriteFixed64(1000);
// Note: Not sure the field number is relevant for end group...
output.WriteTag(3, WireFormat.WireType.EndGroup);
// End the outer group
output.WriteTag(2, WireFormat.WireType.EndGroup);
output.WriteTag(3, WireFormat.WireType.LengthDelimited);
output.WriteString("field 3");
output.Flush();
stream.Position = 0;
// Now act like a generated client
var input = new CodedInputStream(stream);
Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited), input.ReadTag());
Assert.AreEqual("field 1", input.ReadString());
Assert.AreEqual(WireFormat.MakeTag(2, WireFormat.WireType.StartGroup), input.ReadTag());
input.SkipLastField(); // Should consume the whole group, including the nested one.
Assert.AreEqual(WireFormat.MakeTag(3, WireFormat.WireType.LengthDelimited), input.ReadTag());
Assert.AreEqual("field 3", input.ReadString());
}
[Test]
public void EndOfStreamReachedWhileSkippingGroup()
{
var stream = new MemoryStream();
var output = new CodedOutputStream(stream);
output.WriteTag(1, WireFormat.WireType.StartGroup);
output.WriteTag(2, WireFormat.WireType.StartGroup);
output.WriteTag(2, WireFormat.WireType.EndGroup);
output.Flush();
stream.Position = 0;
// Now act like a generated client
var input = new CodedInputStream(stream);
input.ReadTag();
Assert.Throws<InvalidProtocolBufferException>(() => input.SkipLastField());
}
[Test]
public void RecursionLimitAppliedWhileSkippingGroup()
{
var stream = new MemoryStream();
var output = new CodedOutputStream(stream);
for (int i = 0; i < CodedInputStream.DefaultRecursionLimit + 1; i++)
{
output.WriteTag(1, WireFormat.WireType.StartGroup);
}
for (int i = 0; i < CodedInputStream.DefaultRecursionLimit + 1; i++)
{
output.WriteTag(1, WireFormat.WireType.EndGroup);
}
output.Flush();
stream.Position = 0;
// Now act like a generated client
var input = new CodedInputStream(stream);
Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.StartGroup), input.ReadTag());
Assert.Throws<InvalidProtocolBufferException>(() => input.SkipLastField());
}
}
}
\ No newline at end of file
......@@ -236,17 +236,16 @@ namespace Google.Protobuf
#region Validation
/// <summary>
/// Verifies that the last call to ReadTag() returned the given tag value.
/// This is used to verify that a nested group ended with the correct
/// end tag.
/// Verifies that the last call to ReadTag() returned tag 0 - in other words,
/// we've reached the end of the stream when we expected to.
/// </summary>
/// <exception cref="InvalidProtocolBufferException">The last
/// <exception cref="InvalidProtocolBufferException">The
/// tag read was not the one specified</exception>
internal void CheckLastTagWas(uint value)
internal void CheckReadEndOfStreamTag()
{
if (lastTag != value)
if (lastTag != 0)
{
throw InvalidProtocolBufferException.InvalidEndTag();
throw InvalidProtocolBufferException.MoreDataAvailable();
}
}
#endregion
......@@ -275,6 +274,11 @@ namespace Google.Protobuf
/// <summary>
/// Reads a field tag, returning the tag of 0 for "end of stream".
/// </summary>
/// <remarks>
/// If this method returns 0, it doesn't necessarily mean the end of all
/// the data in this CodedInputStream; it may be the end of the logical stream
/// for an embedded message, for example.
/// </remarks>
/// <returns>The next field tag, or 0 for end of stream. (0 is never a valid tag.)</returns>
public uint ReadTag()
{
......@@ -329,22 +333,24 @@ namespace Google.Protobuf
}
/// <summary>
/// Consumes the data for the field with the tag we've just read.
/// Skips the data for the field with the tag we've just read.
/// This should be called directly after <see cref="ReadTag"/>, when
/// the caller wishes to skip an unknown field.
/// </summary>
public void ConsumeLastField()
public void SkipLastField()
{
if (lastTag == 0)
{
throw new InvalidOperationException("ConsumeLastField cannot be called at the end of a stream");
throw new InvalidOperationException("SkipLastField cannot be called at the end of a stream");
}
switch (WireFormat.GetTagWireType(lastTag))
{
case WireFormat.WireType.StartGroup:
ConsumeGroup();
break;
case WireFormat.WireType.EndGroup:
// TODO: Work out how to skip them instead? See issue 688.
throw new InvalidProtocolBufferException("Group tags not supported by proto3 C# implementation");
// Just ignore; there's no data following the tag.
break;
case WireFormat.WireType.Fixed32:
ReadFixed32();
break;
......@@ -361,6 +367,29 @@ namespace Google.Protobuf
}
}
private void ConsumeGroup()
{
// Note: Currently we expect this to be the way that groups are read. We could put the recursion
// depth changes into the ReadTag method instead, potentially...
recursionDepth++;
if (recursionDepth >= recursionLimit)
{
throw InvalidProtocolBufferException.RecursionLimitExceeded();
}
uint tag;
do
{
tag = ReadTag();
if (tag == 0)
{
throw InvalidProtocolBufferException.TruncatedMessage();
}
// This recursion will allow us to handle nested groups.
SkipLastField();
} while (WireFormat.GetTagWireType(tag) != WireFormat.WireType.EndGroup);
recursionDepth--;
}
/// <summary>
/// Reads a double field from the stream.
/// </summary>
......@@ -475,7 +504,7 @@ namespace Google.Protobuf
int oldLimit = PushLimit(length);
++recursionDepth;
builder.MergeFrom(this);
CheckLastTagWas(0);
CheckReadEndOfStreamTag();
// Check that we've read exactly as much data as expected.
if (!ReachedLimit)
{
......
......@@ -637,10 +637,9 @@ namespace Google.Protobuf.Collections
{
Value = codec.valueCodec.Read(input);
}
else if (WireFormat.IsEndGroupTag(tag))
else
{
// TODO(jonskeet): Do we need this? (Given that we don't support groups...)
return;
input.SkipLastField();
}
}
}
......
......@@ -304,12 +304,13 @@ namespace Google.Protobuf
{
value = codec.Read(input);
}
if (WireFormat.IsEndGroupTag(tag))
else
{
break;
input.SkipLastField();
}
}
input.CheckLastTagWas(0);
input.CheckReadEndOfStreamTag();
input.PopLimit(oldLimit);
return value;
......
......@@ -50,7 +50,7 @@ namespace Google.Protobuf
Preconditions.CheckNotNull(data, "data");
CodedInputStream input = new CodedInputStream(data);
message.MergeFrom(input);
input.CheckLastTagWas(0);
input.CheckReadEndOfStreamTag();
}
/// <summary>
......@@ -64,7 +64,7 @@ namespace Google.Protobuf
Preconditions.CheckNotNull(data, "data");
CodedInputStream input = data.CreateCodedInput();
message.MergeFrom(input);
input.CheckLastTagWas(0);
input.CheckReadEndOfStreamTag();
}
/// <summary>
......@@ -78,7 +78,7 @@ namespace Google.Protobuf
Preconditions.CheckNotNull(input, "input");
CodedInputStream codedInput = new CodedInputStream(input);
message.MergeFrom(codedInput);
codedInput.CheckLastTagWas(0);
codedInput.CheckReadEndOfStreamTag();
}
/// <summary>
......
......@@ -98,16 +98,6 @@ namespace Google.Protobuf
return (WireType) (tag & TagTypeMask);
}
/// <summary>
/// Determines whether the given tag is an end group tag.
/// </summary>
/// <param name="tag">The tag to check.</param>
/// <returns><c>true</c> if the given tag is an end group tag; <c>false</c> otherwise.</returns>
public static bool IsEndGroupTag(uint tag)
{
return (WireType) (tag & TagTypeMask) == WireType.EndGroup;
}
/// <summary>
/// Given a tag value, determines the field number (the upper 29 bits).
/// </summary>
......
......@@ -423,10 +423,7 @@ void MessageGenerator::GenerateMergingMethods(io::Printer* printer) {
printer->Indent();
printer->Print(
"default:\n"
" if (pb::WireFormat.IsEndGroupTag(tag)) {\n"
" return;\n"
" }\n"
" input.ConsumeLastField();\n" // We're not storing the data, but we still need to consume it.
" input.SkipLastField();\n" // We're not storing the data, but we still need to consume it.
" break;\n");
for (int i = 0; i < fields_by_number().size(); i++) {
const FieldDescriptor* field = fields_by_number()[i];
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment