Commit 9bdc8488 authored by Jon Skeet's avatar Jon Skeet

Validate that end-group tags match their corresponding start-group tags

This detects:
- An end-group tag with the wrong field number (doesn't match the start-group field)
- An end-group tag with no preceding start-group tag

Fixes issue #688.
parent e35e2480
...@@ -469,6 +469,52 @@ namespace Google.Protobuf ...@@ -469,6 +469,52 @@ namespace Google.Protobuf
Assert.AreEqual("field 3", input.ReadString()); Assert.AreEqual("field 3", input.ReadString());
} }
[Test]
public void SkipGroup_WrongEndGroupTag()
{
// Create an output stream with:
// Field 1: string "field 1"
// Start group 2
// Field 3: fixed int32
// End group 4 (should give an error)
var stream = new MemoryStream();
var output = new CodedOutputStream(stream);
output.WriteTag(1, WireFormat.WireType.LengthDelimited);
output.WriteString("field 1");
// The outer group...
output.WriteTag(2, WireFormat.WireType.StartGroup);
output.WriteTag(3, WireFormat.WireType.Fixed32);
output.WriteFixed32(100);
output.WriteTag(4, WireFormat.WireType.EndGroup);
output.Flush();
stream.Position = 0;
// Now act like a generated client
var input = new CodedInputStream(stream);
Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited), input.ReadTag());
Assert.AreEqual("field 1", input.ReadString());
Assert.AreEqual(WireFormat.MakeTag(2, WireFormat.WireType.StartGroup), input.ReadTag());
Assert.Throws<InvalidProtocolBufferException>(input.SkipLastField);
}
[Test]
public void RogueEndGroupTag()
{
// If we have an end-group tag without a leading start-group tag, generated
// code will just call SkipLastField... so that should fail.
var stream = new MemoryStream();
var output = new CodedOutputStream(stream);
output.WriteTag(1, WireFormat.WireType.EndGroup);
output.Flush();
stream.Position = 0;
var input = new CodedInputStream(stream);
Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.EndGroup), input.ReadTag());
Assert.Throws<InvalidProtocolBufferException>(input.SkipLastField);
}
[Test] [Test]
public void EndOfStreamReachedWhileSkippingGroup() public void EndOfStreamReachedWhileSkippingGroup()
{ {
...@@ -484,7 +530,7 @@ namespace Google.Protobuf ...@@ -484,7 +530,7 @@ namespace Google.Protobuf
// Now act like a generated client // Now act like a generated client
var input = new CodedInputStream(stream); var input = new CodedInputStream(stream);
input.ReadTag(); input.ReadTag();
Assert.Throws<InvalidProtocolBufferException>(() => input.SkipLastField()); Assert.Throws<InvalidProtocolBufferException>(input.SkipLastField);
} }
[Test] [Test]
...@@ -506,7 +552,7 @@ namespace Google.Protobuf ...@@ -506,7 +552,7 @@ namespace Google.Protobuf
// Now act like a generated client // Now act like a generated client
var input = new CodedInputStream(stream); var input = new CodedInputStream(stream);
Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.StartGroup), input.ReadTag()); Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.StartGroup), input.ReadTag());
Assert.Throws<InvalidProtocolBufferException>(() => input.SkipLastField()); Assert.Throws<InvalidProtocolBufferException>(input.SkipLastField);
} }
[Test] [Test]
......
...@@ -679,21 +679,20 @@ namespace Google.Protobuf ...@@ -679,21 +679,20 @@ namespace Google.Protobuf
/// for details; we may want to change this. /// for details; we may want to change this.
/// </summary> /// </summary>
[Test] [Test]
public void ExtraEndGroupSkipped() public void ExtraEndGroupThrows()
{ {
var message = SampleMessages.CreateFullTestAllTypes(); var message = SampleMessages.CreateFullTestAllTypes();
var stream = new MemoryStream(); var stream = new MemoryStream();
var output = new CodedOutputStream(stream); var output = new CodedOutputStream(stream);
output.WriteTag(100, WireFormat.WireType.EndGroup);
output.WriteTag(TestAllTypes.SingleFixed32FieldNumber, WireFormat.WireType.Fixed32); output.WriteTag(TestAllTypes.SingleFixed32FieldNumber, WireFormat.WireType.Fixed32);
output.WriteFixed32(123); output.WriteFixed32(123);
output.WriteTag(100, WireFormat.WireType.EndGroup);
output.Flush(); output.Flush();
stream.Position = 0; stream.Position = 0;
var parsed = TestAllTypes.Parser.ParseFrom(stream); Assert.Throws<InvalidProtocolBufferException>(() => TestAllTypes.Parser.ParseFrom(stream));
Assert.AreEqual(new TestAllTypes { SingleFixed32 = 123 }, parsed);
} }
[Test] [Test]
......
...@@ -349,6 +349,14 @@ namespace Google.Protobuf ...@@ -349,6 +349,14 @@ namespace Google.Protobuf
/// This should be called directly after <see cref="ReadTag"/>, when /// This should be called directly after <see cref="ReadTag"/>, when
/// the caller wishes to skip an unknown field. /// the caller wishes to skip an unknown field.
/// </summary> /// </summary>
/// <remarks>
/// This method throws <see cref="InvalidProtocolBufferException"/> if the last-read tag was an end-group tag.
/// If a caller wishes to skip a group, they should skip the whole group, by calling this method after reading the
/// start-group tag. This behavior allows callers to call this method on any field they don't understand, correctly
/// resulting in an error if an end-group tag has not been paired with an earlier start-group tag.
/// </remarks>
/// <exception cref="InvalidProtocolBufferException">The last tag was an end-group tag</exception>
/// <exception cref="InvalidOperationException">The last read operation read to the end of the logical stream</exception>
public void SkipLastField() public void SkipLastField()
{ {
if (lastTag == 0) if (lastTag == 0)
...@@ -358,11 +366,11 @@ namespace Google.Protobuf ...@@ -358,11 +366,11 @@ namespace Google.Protobuf
switch (WireFormat.GetTagWireType(lastTag)) switch (WireFormat.GetTagWireType(lastTag))
{ {
case WireFormat.WireType.StartGroup: case WireFormat.WireType.StartGroup:
SkipGroup(); SkipGroup(lastTag);
break; break;
case WireFormat.WireType.EndGroup: case WireFormat.WireType.EndGroup:
// Just ignore; there's no data following the tag. throw new InvalidProtocolBufferException(
break; "SkipLastField called on an end-group tag, indicating that the corresponding start-group was missing");
case WireFormat.WireType.Fixed32: case WireFormat.WireType.Fixed32:
ReadFixed32(); ReadFixed32();
break; break;
...@@ -379,7 +387,7 @@ namespace Google.Protobuf ...@@ -379,7 +387,7 @@ namespace Google.Protobuf
} }
} }
private void SkipGroup() private void SkipGroup(uint startGroupTag)
{ {
// Note: Currently we expect this to be the way that groups are read. We could put the recursion // Note: Currently we expect this to be the way that groups are read. We could put the recursion
// depth changes into the ReadTag method instead, potentially... // depth changes into the ReadTag method instead, potentially...
...@@ -389,16 +397,28 @@ namespace Google.Protobuf ...@@ -389,16 +397,28 @@ namespace Google.Protobuf
throw InvalidProtocolBufferException.RecursionLimitExceeded(); throw InvalidProtocolBufferException.RecursionLimitExceeded();
} }
uint tag; uint tag;
do while (true)
{ {
tag = ReadTag(); tag = ReadTag();
if (tag == 0) if (tag == 0)
{ {
throw InvalidProtocolBufferException.TruncatedMessage(); throw InvalidProtocolBufferException.TruncatedMessage();
} }
// Can't call SkipLastField for this case- that would throw.
if (WireFormat.GetTagWireType(tag) == WireFormat.WireType.EndGroup)
{
break;
}
// This recursion will allow us to handle nested groups. // This recursion will allow us to handle nested groups.
SkipLastField(); SkipLastField();
} while (WireFormat.GetTagWireType(tag) != WireFormat.WireType.EndGroup); }
int startField = WireFormat.GetTagFieldNumber(startGroupTag);
int endField = WireFormat.GetTagFieldNumber(tag);
if (startField != endField)
{
throw new InvalidProtocolBufferException(
$"Mismatched end-group tag. Started with field {startField}; ended with field {endField}");
}
recursionDepth--; recursionDepth--;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment