Commit e7f88ff1 authored by Jon Skeet

Skip groups properly.

Now the generated code doesn't need to check for end group tags, as it will skip whole groups at a time.
Currently it will ignore extraneous end group tags, which may or may not be a good thing.
Renamed ConsumeLastField to SkipLastField as it felt more natural.
Removed WireFormat.IsEndGroupTag as it's no longer useful.

This mostly fixes issue 688.

(Generated code changes coming in next commit.)
parent ad8a889d
...@@ -442,5 +442,92 @@ namespace Google.Protobuf ...@@ -442,5 +442,92 @@ namespace Google.Protobuf
var input = new CodedInputStream(new byte[] { 0 }); var input = new CodedInputStream(new byte[] { 0 });
Assert.Throws<InvalidProtocolBufferException>(() => input.ReadTag()); Assert.Throws<InvalidProtocolBufferException>(() => input.ReadTag());
} }
[Test]
public void SkipGroup()
{
    // Wire layout produced below:
    //   field 1: string "field 1"
    //   field 2: group {
    //       field 1: fixed32 100
    //       field 2: string "ignore me"
    //       field 3: group { field 1: fixed64 1000 }
    //   }
    //   field 3: string "field 3"
    var buffer = new MemoryStream();
    var writer = new CodedOutputStream(buffer);

    // Leading string field.
    writer.WriteTag(1, WireFormat.WireType.LengthDelimited);
    writer.WriteString("field 1");

    // Outer group opens, followed by its two scalar members.
    writer.WriteTag(2, WireFormat.WireType.StartGroup);
    writer.WriteTag(1, WireFormat.WireType.Fixed32);
    writer.WriteFixed32(100);
    writer.WriteTag(2, WireFormat.WireType.LengthDelimited);
    writer.WriteString("ignore me");

    // Inner group: opened and closed with field number 3.
    writer.WriteTag(3, WireFormat.WireType.StartGroup);
    writer.WriteTag(1, WireFormat.WireType.Fixed64);
    writer.WriteFixed64(1000);
    writer.WriteTag(3, WireFormat.WireType.EndGroup);

    // Outer group closes with field number 2.
    writer.WriteTag(2, WireFormat.WireType.EndGroup);

    // Trailing string field that must still be readable after the skip.
    writer.WriteTag(3, WireFormat.WireType.LengthDelimited);
    writer.WriteString("field 3");
    writer.Flush();
    buffer.Position = 0;

    // Tags the reader is expected to surface, in order.
    uint leadingStringTag = WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited);
    uint outerGroupTag = WireFormat.MakeTag(2, WireFormat.WireType.StartGroup);
    uint trailingStringTag = WireFormat.MakeTag(3, WireFormat.WireType.LengthDelimited);

    // Read the stream back the way generated parsing code would.
    var reader = new CodedInputStream(buffer);
    Assert.AreEqual(leadingStringTag, reader.ReadTag());
    Assert.AreEqual("field 1", reader.ReadString());

    Assert.AreEqual(outerGroupTag, reader.ReadTag());
    // A single skip has to swallow the outer group and the group nested inside it.
    reader.SkipLastField();

    Assert.AreEqual(trailingStringTag, reader.ReadTag());
    Assert.AreEqual("field 3", reader.ReadString());
}
[Test]
public void EndOfStreamReachedWhileSkippingGroup()
{
    // Write a group for field 1 that is never closed: it contains only a
    // complete (empty) nested group for field 2, then the data stops.
    var buffer = new MemoryStream();
    var writer = new CodedOutputStream(buffer);
    writer.WriteTag(1, WireFormat.WireType.StartGroup);
    writer.WriteTag(2, WireFormat.WireType.StartGroup);
    writer.WriteTag(2, WireFormat.WireType.EndGroup);
    writer.Flush();
    buffer.Position = 0;

    // Read the stream back the way generated parsing code would:
    // read the opening tag, then ask for the field to be skipped.
    var reader = new CodedInputStream(buffer);
    reader.ReadTag();

    // Running out of data before the outer group is closed must be reported.
    TestDelegate skipOuterGroup = () => reader.SkipLastField();
    Assert.Throws<InvalidProtocolBufferException>(skipOuterGroup);
}
[Test]
public void RecursionLimitAppliedWhileSkippingGroup()
{
    // Nest one more group than the default recursion limit.
    int nestingDepth = CodedInputStream.DefaultRecursionLimit + 1;

    var buffer = new MemoryStream();
    var writer = new CodedOutputStream(buffer);

    // All the opening tags first...
    int opened = 0;
    while (opened < nestingDepth)
    {
        writer.WriteTag(1, WireFormat.WireType.StartGroup);
        opened++;
    }

    // ...then the same number of closing tags.
    int closed = 0;
    while (closed < nestingDepth)
    {
        writer.WriteTag(1, WireFormat.WireType.EndGroup);
        closed++;
    }
    writer.Flush();
    buffer.Position = 0;

    // Read the stream back the way generated parsing code would.
    var reader = new CodedInputStream(buffer);
    uint outermostGroupTag = WireFormat.MakeTag(1, WireFormat.WireType.StartGroup);
    Assert.AreEqual(outermostGroupTag, reader.ReadTag());

    // Skipping the outermost group has to give up instead of recursing all the way down.
    TestDelegate skipOutermostGroup = () => reader.SkipLastField();
    Assert.Throws<InvalidProtocolBufferException>(skipOutermostGroup);
}
} }
} }
\ No newline at end of file
...@@ -236,17 +236,16 @@ namespace Google.Protobuf ...@@ -236,17 +236,16 @@ namespace Google.Protobuf
#region Validation #region Validation
/// <summary> /// <summary>
/// Verifies that the last call to ReadTag() returned the given tag value. /// Verifies that the last call to ReadTag() returned tag 0 - in other words,
/// This is used to verify that a nested group ended with the correct /// we've reached the end of the stream when we expected to.
/// end tag.
/// </summary> /// </summary>
/// <exception cref="InvalidProtocolBufferException">The last /// <exception cref="InvalidProtocolBufferException">The
/// tag read was not the one specified</exception> /// tag read was not the one specified</exception>
internal void CheckLastTagWas(uint value) internal void CheckReadEndOfStreamTag()
{ {
if (lastTag != value) if (lastTag != 0)
{ {
throw InvalidProtocolBufferException.InvalidEndTag(); throw InvalidProtocolBufferException.MoreDataAvailable();
} }
} }
#endregion #endregion
...@@ -275,6 +274,11 @@ namespace Google.Protobuf ...@@ -275,6 +274,11 @@ namespace Google.Protobuf
/// <summary> /// <summary>
/// Reads a field tag, returning the tag of 0 for "end of stream". /// Reads a field tag, returning the tag of 0 for "end of stream".
/// </summary> /// </summary>
/// <remarks>
/// If this method returns 0, it doesn't necessarily mean the end of all
/// the data in this CodedInputStream; it may be the end of the logical stream
/// for an embedded message, for example.
/// </remarks>
/// <returns>The next field tag, or 0 for end of stream. (0 is never a valid tag.)</returns> /// <returns>The next field tag, or 0 for end of stream. (0 is never a valid tag.)</returns>
public uint ReadTag() public uint ReadTag()
{ {
...@@ -329,22 +333,24 @@ namespace Google.Protobuf ...@@ -329,22 +333,24 @@ namespace Google.Protobuf
} }
/// <summary> /// <summary>
/// Consumes the data for the field with the tag we've just read. /// Skips the data for the field with the tag we've just read.
/// This should be called directly after <see cref="ReadTag"/>, when /// This should be called directly after <see cref="ReadTag"/>, when
/// the caller wishes to skip an unknown field. /// the caller wishes to skip an unknown field.
/// </summary> /// </summary>
public void ConsumeLastField() public void SkipLastField()
{ {
if (lastTag == 0) if (lastTag == 0)
{ {
throw new InvalidOperationException("ConsumeLastField cannot be called at the end of a stream"); throw new InvalidOperationException("SkipLastField cannot be called at the end of a stream");
} }
switch (WireFormat.GetTagWireType(lastTag)) switch (WireFormat.GetTagWireType(lastTag))
{ {
case WireFormat.WireType.StartGroup: case WireFormat.WireType.StartGroup:
ConsumeGroup();
break;
case WireFormat.WireType.EndGroup: case WireFormat.WireType.EndGroup:
// TODO: Work out how to skip them instead? See issue 688. // Just ignore; there's no data following the tag.
throw new InvalidProtocolBufferException("Group tags not supported by proto3 C# implementation"); break;
case WireFormat.WireType.Fixed32: case WireFormat.WireType.Fixed32:
ReadFixed32(); ReadFixed32();
break; break;
...@@ -361,6 +367,29 @@ namespace Google.Protobuf ...@@ -361,6 +367,29 @@ namespace Google.Protobuf
} }
} }
/// <summary>
/// Skips the remainder of a group whose start-group tag has just been read,
/// including any groups nested inside it.
/// </summary>
/// <remarks>
/// Called from <see cref="SkipLastField"/> when the last tag read has the
/// start-group wire type, so <c>lastTag</c> still holds that tag on entry.
/// </remarks>
/// <exception cref="InvalidProtocolBufferException">The recursion limit was
/// reached, the stream ended before the group was closed, or the group was
/// closed by an end-group tag for a different field number.</exception>
private void ConsumeGroup()
{
    // Remember which field opened the group so the closing tag can be checked.
    uint startGroupTag = lastTag;

    // Note: Currently we expect this to be the way that groups are read. We could put the recursion
    // depth changes into the ReadTag method instead, potentially...
    recursionDepth++;
    if (recursionDepth >= recursionLimit)
    {
        throw InvalidProtocolBufferException.RecursionLimitExceeded();
    }

    uint tag;
    while (true)
    {
        tag = ReadTag();
        if (tag == 0)
        {
            // Ran out of data before the group was closed.
            throw InvalidProtocolBufferException.TruncatedMessage();
        }
        if (WireFormat.GetTagWireType(tag) == WireFormat.WireType.EndGroup)
        {
            // An end-group tag carries no payload, so there is nothing further to skip.
            break;
        }
        // This recursion will allow us to handle nested groups.
        SkipLastField();
    }

    // The end-group tag must close the same field that opened the group; otherwise
    // the data is malformed (e.g. interleaved or mismatched groups).
    int startField = WireFormat.GetTagFieldNumber(startGroupTag);
    int endField = WireFormat.GetTagFieldNumber(tag);
    if (startField != endField)
    {
        throw new InvalidProtocolBufferException(
            "Mismatched end-group tag. Started with field " + startField + "; ended with field " + endField);
    }

    recursionDepth--;
}
/// <summary> /// <summary>
/// Reads a double field from the stream. /// Reads a double field from the stream.
/// </summary> /// </summary>
...@@ -475,7 +504,7 @@ namespace Google.Protobuf ...@@ -475,7 +504,7 @@ namespace Google.Protobuf
int oldLimit = PushLimit(length); int oldLimit = PushLimit(length);
++recursionDepth; ++recursionDepth;
builder.MergeFrom(this); builder.MergeFrom(this);
CheckLastTagWas(0); CheckReadEndOfStreamTag();
// Check that we've read exactly as much data as expected. // Check that we've read exactly as much data as expected.
if (!ReachedLimit) if (!ReachedLimit)
{ {
......
...@@ -637,10 +637,9 @@ namespace Google.Protobuf.Collections ...@@ -637,10 +637,9 @@ namespace Google.Protobuf.Collections
{ {
Value = codec.valueCodec.Read(input); Value = codec.valueCodec.Read(input);
} }
else if (WireFormat.IsEndGroupTag(tag)) else
{ {
// TODO(jonskeet): Do we need this? (Given that we don't support groups...) input.SkipLastField();
return;
} }
} }
} }
......
...@@ -304,12 +304,13 @@ namespace Google.Protobuf ...@@ -304,12 +304,13 @@ namespace Google.Protobuf
{ {
value = codec.Read(input); value = codec.Read(input);
} }
if (WireFormat.IsEndGroupTag(tag)) else
{ {
break; input.SkipLastField();
} }
} }
input.CheckLastTagWas(0); input.CheckReadEndOfStreamTag();
input.PopLimit(oldLimit); input.PopLimit(oldLimit);
return value; return value;
......
...@@ -50,7 +50,7 @@ namespace Google.Protobuf ...@@ -50,7 +50,7 @@ namespace Google.Protobuf
Preconditions.CheckNotNull(data, "data"); Preconditions.CheckNotNull(data, "data");
CodedInputStream input = new CodedInputStream(data); CodedInputStream input = new CodedInputStream(data);
message.MergeFrom(input); message.MergeFrom(input);
input.CheckLastTagWas(0); input.CheckReadEndOfStreamTag();
} }
/// <summary> /// <summary>
...@@ -64,7 +64,7 @@ namespace Google.Protobuf ...@@ -64,7 +64,7 @@ namespace Google.Protobuf
Preconditions.CheckNotNull(data, "data"); Preconditions.CheckNotNull(data, "data");
CodedInputStream input = data.CreateCodedInput(); CodedInputStream input = data.CreateCodedInput();
message.MergeFrom(input); message.MergeFrom(input);
input.CheckLastTagWas(0); input.CheckReadEndOfStreamTag();
} }
/// <summary> /// <summary>
...@@ -78,7 +78,7 @@ namespace Google.Protobuf ...@@ -78,7 +78,7 @@ namespace Google.Protobuf
Preconditions.CheckNotNull(input, "input"); Preconditions.CheckNotNull(input, "input");
CodedInputStream codedInput = new CodedInputStream(input); CodedInputStream codedInput = new CodedInputStream(input);
message.MergeFrom(codedInput); message.MergeFrom(codedInput);
codedInput.CheckLastTagWas(0); codedInput.CheckReadEndOfStreamTag();
} }
/// <summary> /// <summary>
......
...@@ -98,16 +98,6 @@ namespace Google.Protobuf ...@@ -98,16 +98,6 @@ namespace Google.Protobuf
return (WireType) (tag & TagTypeMask); return (WireType) (tag & TagTypeMask);
} }
/// <summary>
/// Determines whether the given tag is an end group tag.
/// </summary>
/// <param name="tag">The tag to check.</param>
/// <returns><c>true</c> if the given tag is an end group tag; <c>false</c> otherwise.</returns>
public static bool IsEndGroupTag(uint tag)
{
return (WireType) (tag & TagTypeMask) == WireType.EndGroup;
}
/// <summary> /// <summary>
/// Given a tag value, determines the field number (the upper 29 bits). /// Given a tag value, determines the field number (the upper 29 bits).
/// </summary> /// </summary>
......
...@@ -423,10 +423,7 @@ void MessageGenerator::GenerateMergingMethods(io::Printer* printer) { ...@@ -423,10 +423,7 @@ void MessageGenerator::GenerateMergingMethods(io::Printer* printer) {
printer->Indent(); printer->Indent();
printer->Print( printer->Print(
"default:\n" "default:\n"
" if (pb::WireFormat.IsEndGroupTag(tag)) {\n" " input.SkipLastField();\n" // We're not storing the data, but we still need to consume it.
" return;\n"
" }\n"
" input.ConsumeLastField();\n" // We're not storing the data, but we still need to consume it.
" break;\n"); " break;\n");
for (int i = 0; i < fields_by_number().size(); i++) { for (int i = 0; i < fields_by_number().size(); i++) {
const FieldDescriptor* field = fields_by_number()[i]; const FieldDescriptor* field = fields_by_number()[i];
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment