Commit 2ac8946e authored by Chris Bacon's avatar Chris Bacon

Allow extra fields in wrapper messages, more tests.

parent e305e56c
...@@ -386,7 +386,7 @@ namespace Google.Protobuf.WellKnownTypes ...@@ -386,7 +386,7 @@ namespace Google.Protobuf.WellKnownTypes
} }
[Test] [Test]
public void UnknownFieldInWrapper() public void UnknownFieldInWrapperInt32FastPath()
{ {
var stream = new MemoryStream(); var stream = new MemoryStream();
var output = new CodedOutputStream(stream); var output = new CodedOutputStream(stream);
...@@ -395,19 +395,96 @@ namespace Google.Protobuf.WellKnownTypes ...@@ -395,19 +395,96 @@ namespace Google.Protobuf.WellKnownTypes
var valueTag = WireFormat.MakeTag(Int32Value.ValueFieldNumber, WireFormat.WireType.Varint); var valueTag = WireFormat.MakeTag(Int32Value.ValueFieldNumber, WireFormat.WireType.Varint);
output.WriteTag(wrapperTag); output.WriteTag(wrapperTag);
output.WriteLength(4); // unknownTag + value 5 + valueType + value 6, each 1 byte // Wrapper message is just long enough - 6 bytes - to use the wrapper fast-path.
output.WriteLength(6); // unknownTag + value 5 + valueType, each 1 byte, + value 65536, 3 bytes
output.WriteTag(unknownTag); output.WriteTag(unknownTag);
output.WriteInt32((int) valueTag); // Sneakily "pretend" it's a tag when it's really a value output.WriteInt32((int) valueTag); // Sneakily "pretend" it's a tag when it's really a value
output.WriteTag(valueTag); output.WriteTag(valueTag);
output.WriteInt32(65536);
output.Flush();
Assert.AreEqual(8, stream.Length); // tag (1 byte) + length (1 byte) + message (6 bytes)
stream.Position = 0;
var message = TestWellKnownTypes.Parser.ParseFrom(stream);
Assert.AreEqual(65536, message.Int32Field);
}
[Test]
public void UnknownFieldInWrapperInt32SlowPath()
{
var stream = new MemoryStream();
var output = new CodedOutputStream(stream);
var wrapperTag = WireFormat.MakeTag(TestWellKnownTypes.Int32FieldFieldNumber, WireFormat.WireType.LengthDelimited);
var unknownTag = WireFormat.MakeTag(15, WireFormat.WireType.Varint);
var valueTag = WireFormat.MakeTag(Int32Value.ValueFieldNumber, WireFormat.WireType.Varint);
output.WriteTag(wrapperTag);
// Wrapper message is too short to be used on the wrapper fast-path.
output.WriteLength(4); // unknownTag + value 5 + valueType + value 6, each 1 byte
output.WriteTag(unknownTag);
output.WriteInt32((int)valueTag); // Sneakily "pretend" it's a tag when it's really a value
output.WriteTag(valueTag);
output.WriteInt32(6); output.WriteInt32(6);
output.Flush(); output.Flush();
Assert.Less(stream.Length, 8); // tag (1 byte) + length (1 byte) + message
stream.Position = 0; stream.Position = 0;
var message = TestWellKnownTypes.Parser.ParseFrom(stream); var message = TestWellKnownTypes.Parser.ParseFrom(stream);
Assert.AreEqual(6, message.Int32Field); Assert.AreEqual(6, message.Int32Field);
} }
[Test]
public void UnknownFieldInWrapperInt64FastPath()
{
var stream = new MemoryStream();
var output = new CodedOutputStream(stream);
var wrapperTag = WireFormat.MakeTag(TestWellKnownTypes.Int64FieldFieldNumber, WireFormat.WireType.LengthDelimited);
var unknownTag = WireFormat.MakeTag(15, WireFormat.WireType.Varint);
var valueTag = WireFormat.MakeTag(Int64Value.ValueFieldNumber, WireFormat.WireType.Varint);
output.WriteTag(wrapperTag);
// Wrapper message is just long enough - 10 bytes - to use the wrapper fast-path.
output.WriteLength(11); // unknownTag + value 5 + valueType, each 1 byte, + value 0xfffffffffffff, 8 bytes
output.WriteTag(unknownTag);
output.WriteInt64((int)valueTag); // Sneakily "pretend" it's a tag when it's really a value
output.WriteTag(valueTag);
output.WriteInt64(0xfffffffffffffL);
output.Flush();
Assert.AreEqual(13, stream.Length); // tag (1 byte) + length (1 byte) + message (11 bytes)
stream.Position = 0;
var message = TestWellKnownTypes.Parser.ParseFrom(stream);
Assert.AreEqual(0xfffffffffffffL, message.Int64Field);
}
[Test]
public void UnknownFieldInWrapperInt64SlowPath()
{
var stream = new MemoryStream();
var output = new CodedOutputStream(stream);
var wrapperTag = WireFormat.MakeTag(TestWellKnownTypes.Int64FieldFieldNumber, WireFormat.WireType.LengthDelimited);
var unknownTag = WireFormat.MakeTag(15, WireFormat.WireType.Varint);
var valueTag = WireFormat.MakeTag(Int64Value.ValueFieldNumber, WireFormat.WireType.Varint);
output.WriteTag(wrapperTag);
// Wrapper message is too short to be used on the wrapper fast-path.
output.WriteLength(4); // unknownTag + value 5 + valueType + value 6, each 1 byte
output.WriteTag(unknownTag);
output.WriteInt64((int)valueTag); // Sneakily "pretend" it's a tag when it's really a value
output.WriteTag(valueTag);
output.WriteInt64(6);
output.Flush();
Assert.Less(stream.Length, 12); // tag (1 byte) + length (1 byte) + message
stream.Position = 0;
var message = TestWellKnownTypes.Parser.ParseFrom(stream);
Assert.AreEqual(6L, message.Int64Field);
}
[Test] [Test]
public void ClearWithReflection() public void ClearWithReflection()
{ {
......
...@@ -737,29 +737,76 @@ namespace Google.Protobuf ...@@ -737,29 +737,76 @@ namespace Google.Protobuf
return false; return false;
} }
internal static float? ReadFloatWrapperLittleEndian(CodedInputStream input)
{
// length:1 + tag:1 + value:4 = 6 bytes
if (input.bufferPos + 6 <= input.bufferSize)
{
// The entire wrapper message is already contained in `buffer`.
int length = input.buffer[input.bufferPos];
if (length == 0)
{
input.bufferPos++;
return 0F;
}
// tag:1 + value:4 = length of 5 bytes
// field=1, type=32-bit = tag of 13
if (length != 5 || input.buffer[input.bufferPos + 1] != 13)
{
return ReadFloatWrapperSlow(input);
}
var result = BitConverter.ToSingle(input.buffer, input.bufferPos + 2);
input.bufferPos += 6;
return result;
}
else
{
return ReadFloatWrapperSlow(input);
}
}
internal static float? ReadFloatWrapperSlow(CodedInputStream input)
{
int length = input.ReadLength();
if (length == 0)
{
return 0F;
}
int finalBufferPos = input.totalBytesRetired + input.bufferPos + length;
float result = 0F;
do
{
// field=1, type=32-bit = tag of 13
if (input.ReadTag() == 13)
{
result = input.ReadFloat();
}
else
{
input.SkipLastField();
}
}
while (input.totalBytesRetired + input.bufferPos < finalBufferPos);
return result;
}
internal static double? ReadDoubleWrapperLittleEndian(CodedInputStream input) internal static double? ReadDoubleWrapperLittleEndian(CodedInputStream input)
{ {
// tag:1 + value:8 = 9 bytes
const int expectedLength = 9;
// field=1, type=64-bit = tag of 9
const int expectedTag = 9;
// length:1 + tag:1 + value:8 = 10 bytes // length:1 + tag:1 + value:8 = 10 bytes
if (input.bufferPos + 10 <= input.bufferSize) if (input.bufferPos + 10 <= input.bufferSize)
{ {
// The entire wrapper message is already contained in `buffer`.
int length = input.buffer[input.bufferPos]; int length = input.buffer[input.bufferPos];
if (length == 0) if (length == 0)
{ {
input.bufferPos++; input.bufferPos++;
return 0D; return 0D;
} }
if (length != expectedLength) // tag:1 + value:8 = length of 9 bytes
{
throw InvalidProtocolBufferException.InvalidWrapperMessageLength();
}
// field=1, type=64-bit = tag of 9 // field=1, type=64-bit = tag of 9
if (input.buffer[input.bufferPos + 1] != expectedTag) if (length != 9 || input.buffer[input.bufferPos + 1] != 9)
{ {
throw InvalidProtocolBufferException.InvalidWrapperMessageTag(); return ReadDoubleWrapperSlow(input);
} }
var result = BitConverter.ToDouble(input.buffer, input.bufferPos + 2); var result = BitConverter.ToDouble(input.buffer, input.bufferPos + 2);
input.bufferPos += 10; input.bufferPos += 10;
...@@ -767,50 +814,119 @@ namespace Google.Protobuf ...@@ -767,50 +814,119 @@ namespace Google.Protobuf
} }
else else
{ {
int length = input.ReadLength(); return ReadDoubleWrapperSlow(input);
}
}
internal static double? ReadDoubleWrapperSlow(CodedInputStream input)
{
int length = input.ReadLength();
if (length == 0)
{
return 0D;
}
int finalBufferPos = input.totalBytesRetired + input.bufferPos + length;
double result = 0D;
do
{
// field=1, type=64-bit = tag of 9
if (input.ReadTag() == 9)
{
result = input.ReadDouble();
}
else
{
input.SkipLastField();
}
}
while (input.totalBytesRetired + input.bufferPos < finalBufferPos);
return result;
}
internal static bool? ReadBoolWrapper(CodedInputStream input)
{
return ReadUInt32Wrapper(input) != 0;
}
internal static uint? ReadUInt32Wrapper(CodedInputStream input)
{
// length:1 + tag:1 + value:5(varint32-max) = 7 bytes
if (input.bufferPos + 7 <= input.bufferSize)
{
// The entire wrapper message is already contained in `buffer`.
int pos0 = input.bufferPos;
int length = input.buffer[input.bufferPos++];
if (length == 0) if (length == 0)
{ {
return 0D; return 0;
}
// Length will always fit in a single byte.
if (length >= 128)
{
input.bufferPos = pos0;
return ReadUInt32WrapperSlow(input);
} }
if (length != expectedLength) int finalBufferPos = input.bufferPos + length;
// field=1, type=varint = tag of 8
if (input.buffer[input.bufferPos++] != 8)
{ {
throw InvalidProtocolBufferException.InvalidWrapperMessageLength(); input.bufferPos = pos0;
return ReadUInt32WrapperSlow(input);
} }
if (input.ReadTag() != expectedTag) var result = input.ReadUInt32();
// Verify this message only contained a single field.
if (input.bufferPos != finalBufferPos)
{ {
throw InvalidProtocolBufferException.InvalidWrapperMessageTag(); input.bufferPos = pos0;
return ReadUInt32WrapperSlow(input);
} }
return input.ReadDouble(); return result;
}
else
{
return ReadUInt32WrapperSlow(input);
} }
} }
internal static double? ReadDoubleWrapperBigEndian(CodedInputStream input) private static uint? ReadUInt32WrapperSlow(CodedInputStream input)
{ {
int length = input.ReadLength(); int length = input.ReadLength();
if (length == 0) if (length == 0)
{ {
return 0D; return 0;
}
// tag:1 + value:8 = 9 bytes
if (length != 9)
{
throw InvalidProtocolBufferException.InvalidWrapperMessageLength();
} }
// field=1, type=64-bit = tag of 9 int finalBufferPos = input.totalBytesRetired + input.bufferPos + length;
if (input.ReadTag() != 9) uint result = 0;
do
{ {
throw InvalidProtocolBufferException.InvalidWrapperMessageTag(); // field=1, type=varint = tag of 8
if (input.ReadTag() == 8)
{
result = input.ReadUInt32();
}
else
{
input.SkipLastField();
}
} }
return input.ReadDouble(); while (input.totalBytesRetired + input.bufferPos < finalBufferPos);
return result;
} }
internal static long? ReadInt64Wrapper(CodedInputStream input) internal static int? ReadInt32Wrapper(CodedInputStream input)
{
return (int?)ReadUInt32Wrapper(input);
}
internal static ulong? ReadUInt64Wrapper(CodedInputStream input)
{ {
// field=1, type=varint = tag of 8 // field=1, type=varint = tag of 8
const int expectedTag = 8; const int expectedTag = 8;
// length:1 + tag:1 + value:10(varint64-max) = 12 bytes // length:1 + tag:1 + value:10(varint64-max) = 12 bytes
if (input.bufferPos + 12 <= input.bufferSize) if (input.bufferPos + 12 <= input.bufferSize)
{ {
// The entire wrapper message is already contained in `buffer`.
int pos0 = input.bufferPos;
int length = input.buffer[input.bufferPos++]; int length = input.buffer[input.bufferPos++];
if (length == 0) if (length == 0)
{ {
...@@ -819,43 +935,61 @@ namespace Google.Protobuf ...@@ -819,43 +935,61 @@ namespace Google.Protobuf
// Length will always fit in a single byte. // Length will always fit in a single byte.
if (length >= 128) if (length >= 128)
{ {
throw InvalidProtocolBufferException.InvalidWrapperMessageLength(); input.bufferPos = pos0;
return ReadUInt64WrapperSlow(input);
} }
int finalBufferPos = input.bufferPos + length; int finalBufferPos = input.bufferPos + length;
if (input.buffer[input.bufferPos++] != expectedTag) if (input.buffer[input.bufferPos++] != expectedTag)
{ {
throw InvalidProtocolBufferException.InvalidWrapperMessageTag(); input.bufferPos = pos0;
return ReadUInt64WrapperSlow(input);
} }
var result = input.ReadInt64(); var result = input.ReadUInt64();
// Verify this message only contained a single field. // Verify this message only contained a single field.
if (input.bufferPos != finalBufferPos) if (input.bufferPos != finalBufferPos)
{ {
throw InvalidProtocolBufferException.InvalidWrapperMessageExtraFields(); input.bufferPos = pos0;
return ReadUInt64WrapperSlow(input);
} }
return result; return result;
} }
else else
{ {
int length = input.ReadLength(); return ReadUInt64WrapperSlow(input);
if (length == 0) }
{ }
return 0L;
} internal static ulong? ReadUInt64WrapperSlow(CodedInputStream input)
int finalBufferPos = input.totalBytesRetired + input.bufferPos + length; {
if (input.ReadTag() != expectedTag) // field=1, type=varint = tag of 8
const int expectedTag = 8;
int length = input.ReadLength();
if (length == 0)
{
return 0L;
}
int finalBufferPos = input.totalBytesRetired + input.bufferPos + length;
ulong result = 0L;
do
{
if (input.ReadTag() == expectedTag)
{ {
throw InvalidProtocolBufferException.InvalidWrapperMessageTag(); result = input.ReadUInt64();
} }
var result = input.ReadInt64(); else
// Verify this message only contained a single field.
if (input.totalBytesRetired + input.bufferPos != finalBufferPos)
{ {
throw InvalidProtocolBufferException.InvalidWrapperMessageExtraFields(); input.SkipLastField();
} }
return result;
} }
while (input.totalBytesRetired + input.bufferPos < finalBufferPos);
return result;
} }
internal static long? ReadInt64Wrapper(CodedInputStream input)
{
return (long?)ReadUInt64Wrapper(input);
}
#endregion #endregion
#region Underlying reading primitives #region Underlying reading primitives
......
...@@ -539,18 +539,21 @@ namespace Google.Protobuf ...@@ -539,18 +539,21 @@ namespace Google.Protobuf
{ typeof(ByteString), ForBytes(WireFormat.MakeTag(WrappersReflection.WrapperValueFieldNumber, WireFormat.WireType.LengthDelimited)) } { typeof(ByteString), ForBytes(WireFormat.MakeTag(WrappersReflection.WrapperValueFieldNumber, WireFormat.WireType.LengthDelimited)) }
}; };
private static readonly Dictionary<System.Type, Func<object>> Readers = new Dictionary<System.Type, Func<object>> private static readonly Dictionary<System.Type, object> Readers = new Dictionary<System.Type, object>
{ {
// TODO: Provide more optimized readers. // TODO: Provide more optimized readers.
{ typeof(bool), null }, { typeof(bool), (Func<CodedInputStream, bool?>)CodedInputStream.ReadBoolWrapper },
{ typeof(int), null }, { typeof(int), (Func<CodedInputStream, int?>)CodedInputStream.ReadInt32Wrapper },
{ typeof(long), () => (Func<CodedInputStream, long?>)CodedInputStream.ReadInt64Wrapper }, { typeof(long), (Func<CodedInputStream, long?>)CodedInputStream.ReadInt64Wrapper },
{ typeof(uint), null }, { typeof(uint), (Func<CodedInputStream, uint?>)CodedInputStream.ReadUInt32Wrapper },
{ typeof(ulong), null }, { typeof(ulong), (Func<CodedInputStream, ulong?>)CodedInputStream.ReadUInt64Wrapper },
{ typeof(float), null }, { typeof(float), BitConverter.IsLittleEndian ?
{ typeof(double), () => BitConverter.IsLittleEndian ? (Func<CodedInputStream, float?>)CodedInputStream.ReadFloatWrapperLittleEndian :
(Func<CodedInputStream, float?>)CodedInputStream.ReadFloatWrapperSlow },
{ typeof(double), BitConverter.IsLittleEndian ?
(Func<CodedInputStream, double?>)CodedInputStream.ReadDoubleWrapperLittleEndian : (Func<CodedInputStream, double?>)CodedInputStream.ReadDoubleWrapperLittleEndian :
(Func<CodedInputStream, double?>)CodedInputStream.ReadDoubleWrapperBigEndian }, (Func<CodedInputStream, double?>)CodedInputStream.ReadDoubleWrapperSlow },
// `string` and `ByteString` less performance-sensitive. Do not implement for now.
{ typeof(string), null }, { typeof(string), null },
{ typeof(ByteString), null }, { typeof(ByteString), null },
}; };
...@@ -571,7 +574,7 @@ namespace Google.Protobuf ...@@ -571,7 +574,7 @@ namespace Google.Protobuf
internal static Func<CodedInputStream, T?> GetReader<T>() where T : struct internal static Func<CodedInputStream, T?> GetReader<T>() where T : struct
{ {
Func<object> value; object value;
if (!Readers.TryGetValue(typeof(T), out value)) if (!Readers.TryGetValue(typeof(T), out value))
{ {
throw new InvalidOperationException("Invalid type argument requested for wrapper reader: " + typeof(T)); throw new InvalidOperationException("Invalid type argument requested for wrapper reader: " + typeof(T));
...@@ -583,7 +586,7 @@ namespace Google.Protobuf ...@@ -583,7 +586,7 @@ namespace Google.Protobuf
return input => Read<T>(input, nestedCoded); return input => Read<T>(input, nestedCoded);
} }
// Return optimized read for the wrapper type. // Return optimized read for the wrapper type.
return (Func<CodedInputStream, T?>)value(); return (Func<CodedInputStream, T?>)value;
} }
internal static T Read<T>(CodedInputStream input, FieldCodec<T> codec) internal static T Read<T>(CodedInputStream input, FieldCodec<T> codec)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment