Unverified Commit 0e8f69e5 authored by Jan Tattermusch's avatar Jan Tattermusch Committed by GitHub

enforce recursion depth checking for unknown fields (#7210)

parent d3141015
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
using System; using System;
using System.IO; using System.IO;
using Google.Protobuf.TestProtos; using Google.Protobuf.TestProtos;
using Proto2 = Google.Protobuf.TestProtos.Proto2;
using NUnit.Framework; using NUnit.Framework;
namespace Google.Protobuf namespace Google.Protobuf
...@@ -337,6 +338,66 @@ namespace Google.Protobuf ...@@ -337,6 +338,66 @@ namespace Google.Protobuf
CodedInputStream input = CodedInputStream.CreateWithLimits(new MemoryStream(atRecursiveLimit.ToByteArray()), 1000000, CodedInputStream.DefaultRecursionLimit - 1); CodedInputStream input = CodedInputStream.CreateWithLimits(new MemoryStream(atRecursiveLimit.ToByteArray()), 1000000, CodedInputStream.DefaultRecursionLimit - 1);
Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(input)); Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(input));
} }
private static byte[] MakeMaliciousRecursionUnknownFieldsPayload(int recursionDepth)
{
// generate recursively nested groups that will be parsed as unknown fields
int unknownFieldNumber = 14; // an unused field number
MemoryStream ms = new MemoryStream();
CodedOutputStream output = new CodedOutputStream(ms);
for (int i = 0; i < recursionDepth; i++)
{
output.WriteTag(WireFormat.MakeTag(unknownFieldNumber, WireFormat.WireType.StartGroup));
}
for (int i = 0; i < recursionDepth; i++)
{
output.WriteTag(WireFormat.MakeTag(unknownFieldNumber, WireFormat.WireType.EndGroup));
}
output.Flush();
return ms.ToArray();
}
[Test]
public void MaliciousRecursion_UnknownFields()
{
byte[] payloadAtRecursiveLimit = MakeMaliciousRecursionUnknownFieldsPayload(CodedInputStream.DefaultRecursionLimit);
byte[] payloadBeyondRecursiveLimit = MakeMaliciousRecursionUnknownFieldsPayload(CodedInputStream.DefaultRecursionLimit + 1);
Assert.DoesNotThrow(() => TestRecursiveMessage.Parser.ParseFrom(payloadAtRecursiveLimit));
Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(payloadBeyondRecursiveLimit));
}
[Test]
public void ReadGroup_WrongEndGroupTag()
{
int groupFieldNumber = Proto2.TestAllTypes.OptionalGroupFieldNumber;
// write Proto2.TestAllTypes with "optional_group" set, but use wrong EndGroup closing tag
MemoryStream ms = new MemoryStream();
CodedOutputStream output = new CodedOutputStream(ms);
output.WriteTag(WireFormat.MakeTag(groupFieldNumber, WireFormat.WireType.StartGroup));
output.WriteGroup(new Proto2.TestAllTypes.Types.OptionalGroup { A = 12345 });
// end group with different field number
output.WriteTag(WireFormat.MakeTag(groupFieldNumber + 1, WireFormat.WireType.EndGroup));
output.Flush();
var payload = ms.ToArray();
Assert.Throws<InvalidProtocolBufferException>(() => Proto2.TestAllTypes.Parser.ParseFrom(payload));
}
[Test]
public void ReadGroup_UnknownFields_WrongEndGroupTag()
{
MemoryStream ms = new MemoryStream();
CodedOutputStream output = new CodedOutputStream(ms);
output.WriteTag(WireFormat.MakeTag(14, WireFormat.WireType.StartGroup));
// end group with different field number
output.WriteTag(WireFormat.MakeTag(15, WireFormat.WireType.EndGroup));
output.Flush();
var payload = ms.ToArray();
Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(payload));
}
[Test] [Test]
public void SizeLimit() public void SizeLimit()
...@@ -735,4 +796,4 @@ namespace Google.Protobuf ...@@ -735,4 +796,4 @@ namespace Google.Protobuf
} }
} }
} }
} }
\ No newline at end of file
...@@ -307,10 +307,17 @@ namespace Google.Protobuf ...@@ -307,10 +307,17 @@ namespace Google.Protobuf
throw InvalidProtocolBufferException.MoreDataAvailable(); throw InvalidProtocolBufferException.MoreDataAvailable();
} }
} }
#endregion
internal void CheckLastTagWas(uint expectedTag)
{
if (lastTag != expectedTag) {
throw InvalidProtocolBufferException.InvalidEndTag();
}
}
#endregion
#region Reading of tags etc #region Reading of tags etc
/// <summary> /// <summary>
/// Peeks at the next field tag. This is like calling <see cref="ReadTag"/>, but the /// Peeks at the next field tag. This is like calling <see cref="ReadTag"/>, but the
/// tag is not consumed. (So a subsequent call to <see cref="ReadTag"/> will return the /// tag is not consumed. (So a subsequent call to <see cref="ReadTag"/> will return the
...@@ -636,7 +643,27 @@ namespace Google.Protobuf ...@@ -636,7 +643,27 @@ namespace Google.Protobuf
throw InvalidProtocolBufferException.RecursionLimitExceeded(); throw InvalidProtocolBufferException.RecursionLimitExceeded();
} }
++recursionDepth; ++recursionDepth;
uint tag = lastTag;
int fieldNumber = WireFormat.GetTagFieldNumber(tag);
builder.MergeFrom(this); builder.MergeFrom(this);
CheckLastTagWas(WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup));
--recursionDepth;
}
/// <summary>
/// Reads an embedded group unknown field from the stream.
/// </summary>
internal void ReadGroup(int fieldNumber, UnknownFieldSet set)
{
if (recursionDepth >= recursionLimit)
{
throw InvalidProtocolBufferException.RecursionLimitExceeded();
}
++recursionDepth;
set.MergeGroupFrom(this);
CheckLastTagWas(WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup));
--recursionDepth; --recursionDepth;
} }
......
...@@ -215,12 +215,8 @@ namespace Google.Protobuf ...@@ -215,12 +215,8 @@ namespace Google.Protobuf
} }
case WireFormat.WireType.StartGroup: case WireFormat.WireType.StartGroup:
{ {
uint endTag = WireFormat.MakeTag(number, WireFormat.WireType.EndGroup);
UnknownFieldSet set = new UnknownFieldSet(); UnknownFieldSet set = new UnknownFieldSet();
while (input.ReadTag() != endTag) input.ReadGroup(number, set);
{
set.MergeFieldFrom(input);
}
GetOrAddField(number).AddGroup(set); GetOrAddField(number).AddGroup(set);
return true; return true;
} }
...@@ -233,6 +229,22 @@ namespace Google.Protobuf ...@@ -233,6 +229,22 @@ namespace Google.Protobuf
} }
} }
internal void MergeGroupFrom(CodedInputStream input)
{
while (true)
{
uint tag = input.ReadTag();
if (tag == 0)
{
break;
}
if (!MergeFieldFrom(input))
{
break;
}
}
}
/// <summary> /// <summary>
/// Create a new UnknownFieldSet if unknownFields is null. /// Create a new UnknownFieldSet if unknownFields is null.
/// Parse a single field from <paramref name="input"/> and merge it /// Parse a single field from <paramref name="input"/> and merge it
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment