Commit f3c75580 authored by Joshua Humphries's avatar Joshua Humphries

throw IOException instead of InvalidProtocolBufferException when appropriate

parent 9a5d892e
...@@ -36,15 +36,13 @@ import com.google.protobuf.Descriptors.EnumValueDescriptor; ...@@ -36,15 +36,13 @@ import com.google.protobuf.Descriptors.EnumValueDescriptor;
import com.google.protobuf.Descriptors.FieldDescriptor; import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FileDescriptor; import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.Descriptors.OneofDescriptor; import com.google.protobuf.Descriptors.OneofDescriptor;
import com.google.protobuf.GeneratedMessageLite.ExtendableMessage;
import com.google.protobuf.GeneratedMessageLite.GeneratedExtension;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectStreamException; import java.io.ObjectStreamException;
import java.io.Serializable; import java.io.Serializable;
import java.lang.reflect.InvocationTargetException; import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.Iterator; import java.util.Iterator;
...@@ -276,6 +274,60 @@ public abstract class GeneratedMessage extends AbstractMessage ...@@ -276,6 +274,60 @@ public abstract class GeneratedMessage extends AbstractMessage
return unknownFields.mergeFieldFrom(tag, input); return unknownFields.mergeFieldFrom(tag, input);
} }
protected static <M extends Message> M parseWithIOException(Parser<M> parser, InputStream input)
throws IOException {
try {
return parser.parseFrom(input);
} catch (InvalidProtocolBufferException e) {
throw e.unwrapIOException();
}
}
protected static <M extends Message> M parseWithIOException(Parser<M> parser, InputStream input,
ExtensionRegistryLite extensions) throws IOException {
try {
return parser.parseFrom(input, extensions);
} catch (InvalidProtocolBufferException e) {
throw e.unwrapIOException();
}
}
protected static <M extends Message> M parseWithIOException(Parser<M> parser,
CodedInputStream input) throws IOException {
try {
return parser.parseFrom(input);
} catch (InvalidProtocolBufferException e) {
throw e.unwrapIOException();
}
}
protected static <M extends Message> M parseWithIOException(Parser<M> parser,
CodedInputStream input, ExtensionRegistryLite extensions) throws IOException {
try {
return parser.parseFrom(input, extensions);
} catch (InvalidProtocolBufferException e) {
throw e.unwrapIOException();
}
}
protected static <M extends Message> M parseDelimitedWithIOException(Parser<M> parser,
InputStream input) throws IOException {
try {
return parser.parseDelimitedFrom(input);
} catch (InvalidProtocolBufferException e) {
throw e.unwrapIOException();
}
}
protected static <M extends Message> M parseDelimitedWithIOException(Parser<M> parser,
InputStream input, ExtensionRegistryLite extensions) throws IOException {
try {
return parser.parseDelimitedFrom(input, extensions);
} catch (InvalidProtocolBufferException e) {
throw e.unwrapIOException();
}
}
@Override @Override
public void writeTo(final CodedOutputStream output) throws IOException { public void writeTo(final CodedOutputStream output) throws IOException {
MessageReflection.writeMessageTo(this, getAllFieldsRaw(), output, false); MessageReflection.writeMessageTo(this, getAllFieldsRaw(), output, false);
...@@ -667,7 +719,7 @@ public abstract class GeneratedMessage extends AbstractMessage ...@@ -667,7 +719,7 @@ public abstract class GeneratedMessage extends AbstractMessage
"No map fields found in " + getClass().getName()); "No map fields found in " + getClass().getName());
} }
/** Like {@link internalGetMapField} but return a mutable version. */ /** Like {@link #internalGetMapField} but return a mutable version. */
@SuppressWarnings({"unused", "rawtypes"}) @SuppressWarnings({"unused", "rawtypes"})
protected MapField internalGetMutableMapField(int fieldNumber) { protected MapField internalGetMutableMapField(int fieldNumber) {
// Note that we can't use descriptor names here because this method will // Note that we can't use descriptor names here because this method will
......
...@@ -46,6 +46,10 @@ public class InvalidProtocolBufferException extends IOException { ...@@ -46,6 +46,10 @@ public class InvalidProtocolBufferException extends IOException {
super(description); super(description);
} }
public InvalidProtocolBufferException(IOException e) {
super(e.getMessage(), e);
}
/** /**
* Attaches an unfinished message to the exception to support best-effort * Attaches an unfinished message to the exception to support best-effort
* parsing in {@code Parser} interface. * parsing in {@code Parser} interface.
...@@ -66,6 +70,14 @@ public class InvalidProtocolBufferException extends IOException { ...@@ -66,6 +70,14 @@ public class InvalidProtocolBufferException extends IOException {
return unfinishedMessage; return unfinishedMessage;
} }
/**
* Unwraps the underlying {@link IOException} if this exception was caused by an I/O
* problem. Otherwise, returns {@code this}.
*/
public IOException unwrapIOException() {
return getCause() instanceof IOException ? (IOException) getCause() : this;
}
static InvalidProtocolBufferException truncatedMessage() { static InvalidProtocolBufferException truncatedMessage() {
return new InvalidProtocolBufferException( return new InvalidProtocolBufferException(
"While parsing a protocol message, the input ended unexpectedly " + "While parsing a protocol message, the input ended unexpectedly " +
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
package com.google.protobuf; package com.google.protobuf;
import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
/** /**
...@@ -37,9 +38,20 @@ import java.io.InputStream; ...@@ -37,9 +38,20 @@ import java.io.InputStream;
* *
* The implementation should be stateless and thread-safe. * The implementation should be stateless and thread-safe.
* *
* <p>All methods may throw {@link InvalidProtocolBufferException}. In the event of invalid data,
* like an encoding error, the cause of the thrown exception will be {@code null}. However, if an
* I/O problem occurs, an exception is thrown with an {@link IOException} cause.
*
* @author liujisi@google.com (Pherl Liu) * @author liujisi@google.com (Pherl Liu)
*/ */
public interface Parser<MessageType> { public interface Parser<MessageType> {
// NB(jh): Other parts of the protobuf API that parse messages distinguish between an I/O problem
// (like failure reading bytes from a socket) and invalid data (encoding error) via the type of
// thrown exception. But it would be source-incompatible to make the methods in this interface do
// so since they were originally spec'ed to only throw InvalidProtocolBufferException. So callers
// must inspect the cause of the exception to distinguish these two cases.
/** /**
* Parses a message of {@code MessageType} from the input. * Parses a message of {@code MessageType} from the input.
* *
......
package com.google.protobuf;
import com.google.protobuf.DescriptorProtos.DescriptorProto;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/**
* Tests the exceptions thrown when parsing from a stream. The methods on the {@link Parser}
* interface are specified to only throw {@link InvalidProtocolBufferException}. But we really want
* to distinguish between invalid protos vs. actual I/O errors (like failures reading from a
* socket, etc.). So, when we're not using the parser directly, an {@link IOException} should be
* thrown where appropriate, instead of always an {@link InvalidProtocolBufferException}.
*
* @author jh@squareup.com (Joshua Humphries)
*/
public class ParseExceptionsTest {
private interface ParseTester {
DescriptorProto parse(InputStream in) throws IOException;
}
private byte serializedProto[];
private void setup() {
serializedProto = DescriptorProto.getDescriptor().toProto().toByteArray();
}
private void setupDelimited() {
ByteArrayOutputStream bos = new ByteArrayOutputStream();
try {
DescriptorProto.getDescriptor().toProto().writeDelimitedTo(bos);
} catch (IOException e) {
fail("Exception not expected: " + e);
}
serializedProto = bos.toByteArray();
}
@Test public void message_parseFrom_InputStream() {
setup();
verifyExceptions(new ParseTester() {
public DescriptorProto parse(InputStream in) throws IOException {
return DescriptorProto.parseFrom(in);
}
});
}
@Test public void message_parseFrom_InputStreamAndExtensionRegistry() {
setup();
verifyExceptions(new ParseTester() {
public DescriptorProto parse(InputStream in) throws IOException {
return DescriptorProto.parseFrom(in, ExtensionRegistry.newInstance());
}
});
}
@Test public void message_parseFrom_CodedInputStream() {
setup();
verifyExceptions(new ParseTester() {
public DescriptorProto parse(InputStream in) throws IOException {
return DescriptorProto.parseFrom(CodedInputStream.newInstance(in));
}
});
}
@Test public void message_parseFrom_CodedInputStreamAndExtensionRegistry() {
setup();
verifyExceptions(new ParseTester() {
public DescriptorProto parse(InputStream in) throws IOException {
return DescriptorProto.parseFrom(CodedInputStream.newInstance(in),
ExtensionRegistry.newInstance());
}
});
}
@Test public void message_parseDelimitedFrom_InputStream() {
setupDelimited();
verifyExceptions(new ParseTester() {
public DescriptorProto parse(InputStream in) throws IOException {
return DescriptorProto.parseDelimitedFrom(in);
}
});
}
@Test public void message_parseDelimitedFrom_InputStreamAndExtensionRegistry() {
setupDelimited();
verifyExceptions(new ParseTester() {
public DescriptorProto parse(InputStream in) throws IOException {
return DescriptorProto.parseDelimitedFrom(in, ExtensionRegistry.newInstance());
}
});
}
@Test public void messageBuilder_mergeFrom_InputStream() {
setup();
verifyExceptions(new ParseTester() {
public DescriptorProto parse(InputStream in) throws IOException {
return DescriptorProto.newBuilder().mergeFrom(in).build();
}
});
}
@Test public void messageBuilder_mergeFrom_InputStreamAndExtensionRegistry() {
setup();
verifyExceptions(new ParseTester() {
public DescriptorProto parse(InputStream in) throws IOException {
return DescriptorProto.newBuilder().mergeFrom(in, ExtensionRegistry.newInstance()).build();
}
});
}
@Test public void messageBuilder_mergeFrom_CodedInputStream() {
setup();
verifyExceptions(new ParseTester() {
public DescriptorProto parse(InputStream in) throws IOException {
return DescriptorProto.newBuilder().mergeFrom(CodedInputStream.newInstance(in)).build();
}
});
}
@Test public void messageBuilder_mergeFrom_CodedInputStreamAndExtensionRegistry() {
setup();
verifyExceptions(new ParseTester() {
public DescriptorProto parse(InputStream in) throws IOException {
return DescriptorProto.newBuilder()
.mergeFrom(CodedInputStream.newInstance(in), ExtensionRegistry.newInstance()).build();
}
});
}
@Test public void messageBuilder_mergeDelimitedFrom_InputStream() {
setupDelimited();
verifyExceptions(new ParseTester() {
public DescriptorProto parse(InputStream in) throws IOException {
DescriptorProto.Builder builder = DescriptorProto.newBuilder();
builder.mergeDelimitedFrom(in);
return builder.build();
}
});
}
@Test public void messageBuilder_mergeDelimitedFrom_InputStreamAndExtensionRegistry() {
setupDelimited();
verifyExceptions(new ParseTester() {
public DescriptorProto parse(InputStream in) throws IOException {
DescriptorProto.Builder builder = DescriptorProto.newBuilder();
builder.mergeDelimitedFrom(in, ExtensionRegistry.newInstance());
return builder.build();
}
});
}
private void verifyExceptions(ParseTester parseTester) {
// No exception
try {
assertEquals(DescriptorProto.getDescriptor().toProto(),
parseTester.parse(new ByteArrayInputStream(serializedProto)));
} catch (IOException e) {
fail("No exception expected: " + e);
}
// IOException
try {
// using a "broken" stream that will throw part-way through reading the message
parseTester.parse(broken(new ByteArrayInputStream(serializedProto)));
fail("IOException expected but not thrown");
} catch (IOException e) {
assertFalse(e instanceof InvalidProtocolBufferException);
}
// InvalidProtocolBufferException
try {
// make the serialized proto invalid
for (int i = 0; i < 50; i++) {
serializedProto[i] = -1;
}
parseTester.parse(new ByteArrayInputStream(serializedProto));
fail("InvalidProtocolBufferException expected but not thrown");
} catch (IOException e) {
assertTrue(e instanceof InvalidProtocolBufferException);
}
}
private InputStream broken(InputStream i) {
return new FilterInputStream(i) {
int count = 0;
@Override public int read() throws IOException {
if (count++ >= 50) {
throw new IOException("I'm broken!");
}
return super.read();
}
@Override public int read(byte b[], int off, int len) throws IOException {
if ((count += len) >= 50) {
throw new IOException("I'm broken!");
}
return super.read(b, off, len);
}
};
}
}
...@@ -664,34 +664,34 @@ GenerateParseFromMethods(io::Printer* printer) { ...@@ -664,34 +664,34 @@ GenerateParseFromMethods(io::Printer* printer) {
"}\n" "}\n"
"public static $classname$ parseFrom(java.io.InputStream input)\n" "public static $classname$ parseFrom(java.io.InputStream input)\n"
" throws java.io.IOException {\n" " throws java.io.IOException {\n"
" return PARSER.parseFrom(input);\n" " return parseWithIOException(PARSER, input);"
"}\n" "}\n"
"public static $classname$ parseFrom(\n" "public static $classname$ parseFrom(\n"
" java.io.InputStream input,\n" " java.io.InputStream input,\n"
" com.google.protobuf.ExtensionRegistryLite extensionRegistry)\n" " com.google.protobuf.ExtensionRegistryLite extensionRegistry)\n"
" throws java.io.IOException {\n" " throws java.io.IOException {\n"
" return PARSER.parseFrom(input, extensionRegistry);\n" " return parseWithIOException(PARSER, input, extensionRegistry);"
"}\n" "}\n"
"public static $classname$ parseDelimitedFrom(java.io.InputStream input)\n" "public static $classname$ parseDelimitedFrom(java.io.InputStream input)\n"
" throws java.io.IOException {\n" " throws java.io.IOException {\n"
" return PARSER.parseDelimitedFrom(input);\n" " return parseDelimitedWithIOException(PARSER, input);"
"}\n" "}\n"
"public static $classname$ parseDelimitedFrom(\n" "public static $classname$ parseDelimitedFrom(\n"
" java.io.InputStream input,\n" " java.io.InputStream input,\n"
" com.google.protobuf.ExtensionRegistryLite extensionRegistry)\n" " com.google.protobuf.ExtensionRegistryLite extensionRegistry)\n"
" throws java.io.IOException {\n" " throws java.io.IOException {\n"
" return PARSER.parseDelimitedFrom(input, extensionRegistry);\n" " return parseDelimitedWithIOException(PARSER, input, extensionRegistry);"
"}\n" "}\n"
"public static $classname$ parseFrom(\n" "public static $classname$ parseFrom(\n"
" com.google.protobuf.CodedInputStream input)\n" " com.google.protobuf.CodedInputStream input)\n"
" throws java.io.IOException {\n" " throws java.io.IOException {\n"
" return PARSER.parseFrom(input);\n" " return parseWithIOException(PARSER, input);"
"}\n" "}\n"
"public static $classname$ parseFrom(\n" "public static $classname$ parseFrom(\n"
" com.google.protobuf.CodedInputStream input,\n" " com.google.protobuf.CodedInputStream input,\n"
" com.google.protobuf.ExtensionRegistryLite extensionRegistry)\n" " com.google.protobuf.ExtensionRegistryLite extensionRegistry)\n"
" throws java.io.IOException {\n" " throws java.io.IOException {\n"
" return PARSER.parseFrom(input, extensionRegistry);\n" " return parseWithIOException(PARSER, input, extensionRegistry);"
"}\n" "}\n"
"\n", "\n",
"classname", name_resolver_->GetImmutableClassName(descriptor_)); "classname", name_resolver_->GetImmutableClassName(descriptor_));
...@@ -1217,9 +1217,8 @@ GenerateParsingConstructor(io::Printer* printer) { ...@@ -1217,9 +1217,8 @@ GenerateParsingConstructor(io::Printer* printer) {
"} catch (com.google.protobuf.InvalidProtocolBufferException e) {\n" "} catch (com.google.protobuf.InvalidProtocolBufferException e) {\n"
" throw new RuntimeException(e.setUnfinishedMessage(this));\n" " throw new RuntimeException(e.setUnfinishedMessage(this));\n"
"} catch (java.io.IOException e) {\n" "} catch (java.io.IOException e) {\n"
" throw new RuntimeException(\n" " throw new RuntimeException(new com.google.protobuf.InvalidProtocolBufferException(e)\n"
" new com.google.protobuf.InvalidProtocolBufferException(\n" " .setUnfinishedMessage(this));\n"
" e.getMessage()).setUnfinishedMessage(this));\n"
"} finally {\n"); "} finally {\n");
printer->Indent(); printer->Indent();
......
...@@ -538,7 +538,7 @@ GenerateBuilderParsingMethods(io::Printer* printer) { ...@@ -538,7 +538,7 @@ GenerateBuilderParsingMethods(io::Printer* printer) {
" parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry);\n" " parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry);\n"
" } catch (com.google.protobuf.InvalidProtocolBufferException e) {\n" " } catch (com.google.protobuf.InvalidProtocolBufferException e) {\n"
" parsedMessage = ($classname$) e.getUnfinishedMessage();\n" " parsedMessage = ($classname$) e.getUnfinishedMessage();\n"
" throw e;\n" " throw e.unwrapIOException();\n"
" } finally {\n" " } finally {\n"
" if (parsedMessage != null) {\n" " if (parsedMessage != null) {\n"
" mergeFrom(parsedMessage);\n" " mergeFrom(parsedMessage);\n"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment