Commit 6fa17e75 authored by Jon Skeet

Reimplement the JSON recursion limit by detecting the depth in the tokenizer.

Added a TODO around a possible change to the tokenizer API, changing PushBack(token) into just Rewind() or something similar.
parent 3d257a9d
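
From the caller's side, the limit this commit enforces is the RecursionLimit carried by JsonParser.Settings, as the new exception message suggests. A minimal usage sketch, assuming the Parse<T> convenience method and the Settings(int recursionLimit) constructor; the limit value, the nesting depth, and the choice of Struct are illustrative only:

    using System;
    using System.Text;
    using Google.Protobuf;
    using Google.Protobuf.WellKnownTypes;

    class RecursionLimitDemo
    {
        static void Main()
        {
            // A parser that tolerates at most 10 levels of JSON object nesting.
            var parser = new JsonParser(new JsonParser.Settings(10));

            // Build JSON nested 20 objects deep: { "a": { "a": ... 1 ... } }
            var builder = new StringBuilder();
            for (int i = 0; i < 20; i++)
            {
                builder.Append("{ \"a\": ");
            }
            builder.Append("1");
            builder.Append(new string('}', 20));

            try
            {
                parser.Parse<Struct>(builder.ToString());
            }
            catch (InvalidProtocolBufferException e)
            {
                // "Protocol message had too many levels of nesting. ..."
                Console.WriteLine(e.Message);
            }
        }
    }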
csharp/src/Google.Protobuf.Test/JsonTokenizerTest.cs
@@ -81,6 +81,63 @@ namespace Google.Protobuf
             AssertTokens("'\ud800\\udc00'", JsonToken.Value(expected));
         }
 
+        [Test]
+        public void ObjectDepth()
+        {
+            string json = "{ \"foo\": { \"x\": 1, \"y\": [ 0 ] } }";
+            var tokenizer = new JsonTokenizer(new StringReader(json));
+            // If we had more tests like this, I'd introduce a helper method... but for one test, it's not worth it.
+            Assert.AreEqual(0, tokenizer.ObjectDepth);
+            Assert.AreEqual(JsonToken.StartObject, tokenizer.Next());
+            Assert.AreEqual(1, tokenizer.ObjectDepth);
+            Assert.AreEqual(JsonToken.Name("foo"), tokenizer.Next());
+            Assert.AreEqual(1, tokenizer.ObjectDepth);
+            Assert.AreEqual(JsonToken.StartObject, tokenizer.Next());
+            Assert.AreEqual(2, tokenizer.ObjectDepth);
+            Assert.AreEqual(JsonToken.Name("x"), tokenizer.Next());
+            Assert.AreEqual(2, tokenizer.ObjectDepth);
+            Assert.AreEqual(JsonToken.Value(1), tokenizer.Next());
+            Assert.AreEqual(2, tokenizer.ObjectDepth);
+            Assert.AreEqual(JsonToken.Name("y"), tokenizer.Next());
+            Assert.AreEqual(2, tokenizer.ObjectDepth);
+            Assert.AreEqual(JsonToken.StartArray, tokenizer.Next());
+            Assert.AreEqual(2, tokenizer.ObjectDepth); // Depth hasn't changed in array
+            Assert.AreEqual(JsonToken.Value(0), tokenizer.Next());
+            Assert.AreEqual(2, tokenizer.ObjectDepth);
+            Assert.AreEqual(JsonToken.EndArray, tokenizer.Next());
+            Assert.AreEqual(2, tokenizer.ObjectDepth);
+            Assert.AreEqual(JsonToken.EndObject, tokenizer.Next());
+            Assert.AreEqual(1, tokenizer.ObjectDepth);
+            Assert.AreEqual(JsonToken.EndObject, tokenizer.Next());
+            Assert.AreEqual(0, tokenizer.ObjectDepth);
+            Assert.AreEqual(JsonToken.EndDocument, tokenizer.Next());
+            Assert.AreEqual(0, tokenizer.ObjectDepth);
+        }
+
+        [Test]
+        public void ObjectDepth_WithPushBack()
+        {
+            string json = "{}";
+            var tokenizer = new JsonTokenizer(new StringReader(json));
+            Assert.AreEqual(0, tokenizer.ObjectDepth);
+            var token = tokenizer.Next();
+            Assert.AreEqual(1, tokenizer.ObjectDepth);
+            // When we push back a "start object", we should effectively be back to the previous depth.
+            tokenizer.PushBack(token);
+            Assert.AreEqual(0, tokenizer.ObjectDepth);
+            // Read the same token again, and get back to depth 1
+            token = tokenizer.Next();
+            Assert.AreEqual(1, tokenizer.ObjectDepth);
+            // Now the same in reverse, with EndObject
+            token = tokenizer.Next();
+            Assert.AreEqual(0, tokenizer.ObjectDepth);
+            tokenizer.PushBack(token);
+            Assert.AreEqual(1, tokenizer.ObjectDepth);
+            tokenizer.Next();
+            Assert.AreEqual(0, tokenizer.ObjectDepth);
+        }
+
         [Test]
         [TestCase("embedded tab\t")]
         [TestCase("embedded CR\r")]
...
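
The comment at the top of ObjectDepth() mentions a helper method it deliberately skips; purely as a hypothetical sketch (not in this commit), such a helper might read:

    // Hypothetical test helper: consume one token and assert both the token
    // and the resulting object depth, e.g.
    //   AssertNextTokenAndDepth(tokenizer, JsonToken.StartObject, 1);
    private static void AssertNextTokenAndDepth(JsonTokenizer tokenizer, JsonToken expectedToken, int expectedDepth)
    {
        Assert.AreEqual(expectedToken, tokenizer.Next());
        Assert.AreEqual(expectedDepth, tokenizer.ObjectDepth);
    }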
csharp/src/Google.Protobuf/InvalidProtocolBufferException.cs
@@ -95,6 +95,13 @@ namespace Google.Protobuf
                 "Use CodedInputStream.SetRecursionLimit() to increase the depth limit.");
         }
 
+        internal static InvalidProtocolBufferException JsonRecursionLimitExceeded()
+        {
+            return new InvalidProtocolBufferException(
+                "Protocol message had too many levels of nesting. May be malicious. " +
+                "Use JsonParser.Settings to increase the depth limit.");
+        }
+
         internal static InvalidProtocolBufferException SizeLimitExceeded()
         {
             return new InvalidProtocolBufferException(
...
csharp/src/Google.Protobuf/JsonParser.cs
@@ -69,16 +69,16 @@ namespace Google.Protobuf
         // TODO: Consider introducing a class containing parse state of the parser, tokenizer and depth. That would simplify these handlers
         // and the signatures of various methods.
-        private static readonly Dictionary<string, Action<JsonParser, IMessage, JsonTokenizer, int>>
-            WellKnownTypeHandlers = new Dictionary<string, Action<JsonParser, IMessage, JsonTokenizer, int>>
+        private static readonly Dictionary<string, Action<JsonParser, IMessage, JsonTokenizer>>
+            WellKnownTypeHandlers = new Dictionary<string, Action<JsonParser, IMessage, JsonTokenizer>>
             {
-                { Timestamp.Descriptor.FullName, (parser, message, tokenizer, depth) => MergeTimestamp(message, tokenizer.Next()) },
-                { Duration.Descriptor.FullName, (parser, message, tokenizer, depth) => MergeDuration(message, tokenizer.Next()) },
-                { Value.Descriptor.FullName, (parser, message, tokenizer, depth) => parser.MergeStructValue(message, tokenizer, depth) },
-                { ListValue.Descriptor.FullName, (parser, message, tokenizer, depth) =>
-                    parser.MergeRepeatedField(message, message.Descriptor.Fields[ListValue.ValuesFieldNumber], tokenizer, depth) },
-                { Struct.Descriptor.FullName, (parser, message, tokenizer, depth) => parser.MergeStruct(message, tokenizer, depth) },
-                { FieldMask.Descriptor.FullName, (parser, message, tokenizer, depth) => MergeFieldMask(message, tokenizer.Next()) },
+                { Timestamp.Descriptor.FullName, (parser, message, tokenizer) => MergeTimestamp(message, tokenizer.Next()) },
+                { Duration.Descriptor.FullName, (parser, message, tokenizer) => MergeDuration(message, tokenizer.Next()) },
+                { Value.Descriptor.FullName, (parser, message, tokenizer) => parser.MergeStructValue(message, tokenizer) },
+                { ListValue.Descriptor.FullName, (parser, message, tokenizer) =>
+                    parser.MergeRepeatedField(message, message.Descriptor.Fields[ListValue.ValuesFieldNumber], tokenizer) },
+                { Struct.Descriptor.FullName, (parser, message, tokenizer) => parser.MergeStruct(message, tokenizer) },
+                { FieldMask.Descriptor.FullName, (parser, message, tokenizer) => MergeFieldMask(message, tokenizer.Next()) },
                 { Int32Value.Descriptor.FullName, MergeWrapperField },
                 { Int64Value.Descriptor.FullName, MergeWrapperField },
                 { UInt32Value.Descriptor.FullName, MergeWrapperField },
@@ -91,9 +91,9 @@ namespace Google.Protobuf
         // Convenience method to avoid having to repeat the same code multiple times in the above
         // dictionary initialization.
-        private static void MergeWrapperField(JsonParser parser, IMessage message, JsonTokenizer tokenizer, int depth)
+        private static void MergeWrapperField(JsonParser parser, IMessage message, JsonTokenizer tokenizer)
         {
-            parser.MergeField(message, message.Descriptor.Fields[Wrappers.WrapperValueFieldNumber], tokenizer, depth);
+            parser.MergeField(message, message.Descriptor.Fields[Wrappers.WrapperValueFieldNumber], tokenizer);
         }
 
         /// <summary>
@@ -130,7 +130,7 @@ namespace Google.Protobuf
         internal void Merge(IMessage message, TextReader jsonReader)
         {
             var tokenizer = new JsonTokenizer(jsonReader);
-            Merge(message, tokenizer, 0);
+            Merge(message, tokenizer);
             var lastToken = tokenizer.Next();
             if (lastToken != JsonToken.EndDocument)
             {
@@ -145,19 +145,18 @@ namespace Google.Protobuf
         /// of tokens provided by the tokenizer. This token stream is assumed to be valid JSON, with the
         /// tokenizer performing that validation - but not every token stream is valid "protobuf JSON".
         /// </summary>
-        private void Merge(IMessage message, JsonTokenizer tokenizer, int depth)
+        private void Merge(IMessage message, JsonTokenizer tokenizer)
         {
-            if (depth > settings.RecursionLimit)
+            if (tokenizer.ObjectDepth > settings.RecursionLimit)
             {
-                throw InvalidProtocolBufferException.RecursionLimitExceeded();
+                throw InvalidProtocolBufferException.JsonRecursionLimitExceeded();
             }
-            depth++;
             if (message.Descriptor.IsWellKnownType)
             {
-                Action<JsonParser, IMessage, JsonTokenizer, int> handler;
+                Action<JsonParser, IMessage, JsonTokenizer> handler;
                 if (WellKnownTypeHandlers.TryGetValue(message.Descriptor.FullName, out handler))
                 {
-                    handler(this, message, tokenizer, depth);
+                    handler(this, message, tokenizer);
                     return;
                 }
                 // Well-known types with no special handling continue in the normal way.
@@ -188,7 +187,7 @@ namespace Google.Protobuf
                 FieldDescriptor field;
                 if (jsonFieldMap.TryGetValue(name, out field))
                 {
-                    MergeField(message, field, tokenizer, depth);
+                    MergeField(message, field, tokenizer);
                 }
                 else
                 {
@@ -200,7 +199,7 @@ namespace Google.Protobuf
             }
         }
 
-        private void MergeField(IMessage message, FieldDescriptor field, JsonTokenizer tokenizer, int depth)
+        private void MergeField(IMessage message, FieldDescriptor field, JsonTokenizer tokenizer)
         {
             var token = tokenizer.Next();
             if (token.Type == JsonToken.TokenType.Null)
@@ -214,20 +213,20 @@ namespace Google.Protobuf
             if (field.IsMap)
             {
-                MergeMapField(message, field, tokenizer, depth);
+                MergeMapField(message, field, tokenizer);
             }
             else if (field.IsRepeated)
             {
-                MergeRepeatedField(message, field, tokenizer, depth);
+                MergeRepeatedField(message, field, tokenizer);
             }
             else
             {
-                var value = ParseSingleValue(field, tokenizer, depth);
+                var value = ParseSingleValue(field, tokenizer);
                 field.Accessor.SetValue(message, value);
             }
         }
 
-        private void MergeRepeatedField(IMessage message, FieldDescriptor field, JsonTokenizer tokenizer, int depth)
+        private void MergeRepeatedField(IMessage message, FieldDescriptor field, JsonTokenizer tokenizer)
         {
             var token = tokenizer.Next();
             if (token.Type != JsonToken.TokenType.StartArray)
@@ -244,11 +243,11 @@ namespace Google.Protobuf
                     return;
                 }
                 tokenizer.PushBack(token);
-                list.Add(ParseSingleValue(field, tokenizer, depth));
+                list.Add(ParseSingleValue(field, tokenizer));
             }
         }
 
-        private void MergeMapField(IMessage message, FieldDescriptor field, JsonTokenizer tokenizer, int depth)
+        private void MergeMapField(IMessage message, FieldDescriptor field, JsonTokenizer tokenizer)
         {
             // Map fields are always objects, even if the values are well-known types: ParseSingleValue handles those.
             var token = tokenizer.Next();
@@ -274,13 +273,13 @@ namespace Google.Protobuf
                     return;
                 }
                 object key = ParseMapKey(keyField, token.StringValue);
-                object value = ParseSingleValue(valueField, tokenizer, depth);
+                object value = ParseSingleValue(valueField, tokenizer);
                 // TODO: Null handling
                 dictionary[key] = value;
             }
         }
 
-        private object ParseSingleValue(FieldDescriptor field, JsonTokenizer tokenizer, int depth)
+        private object ParseSingleValue(FieldDescriptor field, JsonTokenizer tokenizer)
         {
             var token = tokenizer.Next();
             if (token.Type == JsonToken.TokenType.Null)
@@ -308,7 +307,7 @@ namespace Google.Protobuf
                 // TODO: Merge the current value in message? (Public API currently doesn't make this relevant as we don't expose merging.)
                 tokenizer.PushBack(token);
                 IMessage subMessage = NewMessageForField(field);
-                Merge(subMessage, tokenizer, depth);
+                Merge(subMessage, tokenizer);
                 return subMessage;
             }
         }
@@ -358,7 +357,7 @@ namespace Google.Protobuf
             return message;
         }
 
-        private void MergeStructValue(IMessage message, JsonTokenizer tokenizer, int depth)
+        private void MergeStructValue(IMessage message, JsonTokenizer tokenizer)
        {
             var firstToken = tokenizer.Next();
             var fields = message.Descriptor.Fields;
@@ -382,7 +381,7 @@ namespace Google.Protobuf
                 var field = fields[Value.StructValueFieldNumber];
                 var structMessage = NewMessageForField(field);
                 tokenizer.PushBack(firstToken);
-                Merge(structMessage, tokenizer, depth);
+                Merge(structMessage, tokenizer);
                 field.Accessor.SetValue(message, structMessage);
                 return;
             }
@@ -391,7 +390,7 @@ namespace Google.Protobuf
                 var field = fields[Value.ListValueFieldNumber];
                 var list = NewMessageForField(field);
                 tokenizer.PushBack(firstToken);
-                Merge(list, tokenizer, depth);
+                Merge(list, tokenizer);
                 field.Accessor.SetValue(message, list);
                 return;
             }
@@ -400,7 +399,7 @@ namespace Google.Protobuf
             }
         }
 
-        private void MergeStruct(IMessage message, JsonTokenizer tokenizer, int depth)
+        private void MergeStruct(IMessage message, JsonTokenizer tokenizer)
         {
             var token = tokenizer.Next();
             if (token.Type != JsonToken.TokenType.StartObject)
@@ -410,7 +409,7 @@ namespace Google.Protobuf
             tokenizer.PushBack(token);
             var field = message.Descriptor.Fields[Struct.FieldsFieldNumber];
-            MergeMapField(message, field, tokenizer, depth);
+            MergeMapField(message, field, tokenizer);
         }
 
         #region Utility methods which don't depend on the state (or settings) of the parser.
...
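
One subtlety worth spelling out: the parser frequently calls tokenizer.PushBack(token) immediately before recursing into Merge, so the depth adjustment PushBack performs (see the tokenizer diff below) is what keeps ObjectDepth honest at the top of Merge. A conceptual walk-through; JsonTokenizer is internal, so this is illustrative rather than user-facing code, and the JSON string is made up:

    var tokenizer = new JsonTokenizer(new StringReader("{ \"a\": { } }"));
    tokenizer.Next();               // StartObject -> ObjectDepth == 1
    tokenizer.Next();               // Name "a"    -> ObjectDepth == 1
    var token = tokenizer.Next();   // StartObject -> ObjectDepth == 2
    tokenizer.PushBack(token);      // un-read it  -> ObjectDepth == 1
    // A recursive Merge(subMessage, tokenizer) now re-reads the '{', sees
    // ObjectDepth climb back to 2, and checks it against the recursion limit.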
csharp/src/Google.Protobuf/JsonTokenizer.cs
@@ -58,6 +58,13 @@ namespace Google.Protobuf
         private readonly PushBackReader reader;
         private JsonToken bufferedToken;
         private State state;
+        private int objectDepth = 0;
+
+        /// <summary>
+        /// Returns the depth of the stack, purely in objects (not collections).
+        /// Informally, this is the number of remaining unclosed '{' characters we have.
+        /// </summary>
+        internal int ObjectDepth { get { return objectDepth; } }
 
         internal JsonTokenizer(TextReader reader)
         {
@@ -66,6 +73,8 @@ namespace Google.Protobuf
             containerStack.Push(ContainerType.Document);
         }
 
+        // TODO: Why do we allow a different token to be pushed back? It might be better to always remember the previous
+        // token returned, and allow a parameterless Rewind() method (which could only be called once, just like the current PushBack).
         internal void PushBack(JsonToken token)
         {
             if (bufferedToken != null)
@@ -73,6 +82,14 @@ namespace Google.Protobuf
                 throw new InvalidOperationException("Can't push back twice");
             }
             bufferedToken = token;
+            if (token.Type == JsonToken.TokenType.StartObject)
+            {
+                objectDepth--;
+            }
+            else if (token.Type == JsonToken.TokenType.EndObject)
+            {
+                objectDepth++;
+            }
         }
 
         /// <summary>
@@ -94,6 +111,14 @@ namespace Google.Protobuf
             {
                 var ret = bufferedToken;
                 bufferedToken = null;
+                if (ret.Type == JsonToken.TokenType.StartObject)
+                {
+                    objectDepth++;
+                }
+                else if (ret.Type == JsonToken.TokenType.EndObject)
+                {
+                    objectDepth--;
+                }
                 return ret;
             }
             if (state == State.ReaderExhausted)
@@ -141,10 +166,12 @@ namespace Google.Protobuf
                     ValidateState(ValueStates, "Invalid state to read an open brace: ");
                     state = State.ObjectStart;
                     containerStack.Push(ContainerType.Object);
+                    objectDepth++;
                     return JsonToken.StartObject;
                 case '}':
                     ValidateState(State.ObjectAfterProperty | State.ObjectStart, "Invalid state to read a close brace: ");
                     PopContainer();
+                    objectDepth--;
                     return JsonToken.EndObject;
                 case '[':
                     ValidateState(ValueStates, "Invalid state to read an open square bracket: ");
...
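
On the TODO about PushBack versus a parameterless Rewind(): one hypothetical shape (explicitly not part of this commit) would have Next() remember the token it just returned, so callers could never push back a token the tokenizer did not produce, and the depth bookkeeping would stay mirrored in one place:

    // Hypothetical sketch: Next() stores the token it returns in lastToken;
    // Rewind() replays it once, undoing the depth change it caused.
    private JsonToken lastToken;
    private bool rewound;

    internal void Rewind()
    {
        if (rewound || lastToken == null)
        {
            throw new InvalidOperationException("Can't rewind twice or before reading a token");
        }
        rewound = true;
        if (lastToken.Type == JsonToken.TokenType.StartObject)
        {
            objectDepth--;
        }
        else if (lastToken.Type == JsonToken.TokenType.EndObject)
        {
            objectDepth++;
        }
    }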