Unverified commit 6de51cae, authored by Feng Xiao, committed by GitHub

Merge pull request #3824 from anuraaga/dev_rag

[Java] Add a UTF-8 decoder that uses Unsafe to directly decode a byte buffer.
parents da89eb25 3e944aec
...@@ -286,6 +286,7 @@ java_EXTRA_DIST= ...@@ -286,6 +286,7 @@ java_EXTRA_DIST=
java/core/src/test/java/com/google/protobuf/CheckUtf8Test.java \ java/core/src/test/java/com/google/protobuf/CheckUtf8Test.java \
java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java \ java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java \
java/core/src/test/java/com/google/protobuf/CodedOutputStreamTest.java \ java/core/src/test/java/com/google/protobuf/CodedOutputStreamTest.java \
java/core/src/test/java/com/google/protobuf/DecodeUtf8Test.java \
java/core/src/test/java/com/google/protobuf/DeprecatedFieldTest.java \ java/core/src/test/java/com/google/protobuf/DeprecatedFieldTest.java \
java/core/src/test/java/com/google/protobuf/DescriptorsTest.java \ java/core/src/test/java/com/google/protobuf/DescriptorsTest.java \
java/core/src/test/java/com/google/protobuf/DiscardUnknownFieldsTest.java \ java/core/src/test/java/com/google/protobuf/DiscardUnknownFieldsTest.java \
......
...@@ -64,6 +64,14 @@ public abstract class CodedInputStream { ...@@ -64,6 +64,14 @@ public abstract class CodedInputStream {
// Integer.MAX_VALUE == 0x7FFFFFFF == INT_MAX from limits.h // Integer.MAX_VALUE == 0x7FFFFFFF == INT_MAX from limits.h
private static final int DEFAULT_SIZE_LIMIT = Integer.MAX_VALUE; private static final int DEFAULT_SIZE_LIMIT = Integer.MAX_VALUE;
/**
 * Whether to enable our custom UTF-8 decode codepath which does not use {@link StringCoding}.
 *
 * <p>Enabled by default. To disable it, pass
 * {@code -Dcom.google.protobuf.enableCustomUtf8Decode=false} in the JVM args. The property name
 * must be spelled exactly as shown, and only the exact value {@code "false"} disables the
 * codepath; an unset property or any other value leaves it enabled.
 *
 * <p>The property is read once, when this constant is initialized.
 */
private static final boolean ENABLE_CUSTOM_UTF8_DECODE
= !"false".equals(System.getProperty("com.google.protobuf.enableCustomUtf8Decode"));
/** Visible for subclasses. See setRecursionLimit() */ /** Visible for subclasses. See setRecursionLimit() */
int recursionDepth; int recursionDepth;
...@@ -825,13 +833,19 @@ public abstract class CodedInputStream { ...@@ -825,13 +833,19 @@ public abstract class CodedInputStream {
public String readStringRequireUtf8() throws IOException { public String readStringRequireUtf8() throws IOException {
final int size = readRawVarint32(); final int size = readRawVarint32();
if (size > 0 && size <= (limit - pos)) { if (size > 0 && size <= (limit - pos)) {
// TODO(martinrb): We could save a pass by validating while decoding. if (ENABLE_CUSTOM_UTF8_DECODE) {
if (!Utf8.isValidUtf8(buffer, pos, pos + size)) { String result = Utf8.decodeUtf8(buffer, pos, size);
throw InvalidProtocolBufferException.invalidUtf8(); pos += size;
return result;
} else {
// TODO(martinrb): We could save a pass by validating while decoding.
if (!Utf8.isValidUtf8(buffer, pos, pos + size)) {
throw InvalidProtocolBufferException.invalidUtf8();
}
final int tempPos = pos;
pos += size;
return new String(buffer, tempPos, size, UTF_8);
} }
final int tempPos = pos;
pos += size;
return new String(buffer, tempPos, size, UTF_8);
} }
if (size == 0) { if (size == 0) {
...@@ -1524,6 +1538,8 @@ public abstract class CodedInputStream { ...@@ -1524,6 +1538,8 @@ public abstract class CodedInputStream {
final int size = readRawVarint32(); final int size = readRawVarint32();
if (size > 0 && size <= remaining()) { if (size > 0 && size <= remaining()) {
// TODO(nathanmittler): Is there a way to avoid this copy? // TODO(nathanmittler): Is there a way to avoid this copy?
// TODO(anuraaga): It might be possible to share the optimized loop with
// readStringRequireUtf8 by implementing Java replacement logic there.
// The same as readBytes' logic // The same as readBytes' logic
byte[] bytes = new byte[size]; byte[] bytes = new byte[size];
UnsafeUtil.copyMemory(pos, bytes, 0, size); UnsafeUtil.copyMemory(pos, bytes, 0, size);
...@@ -1544,19 +1560,26 @@ public abstract class CodedInputStream { ...@@ -1544,19 +1560,26 @@ public abstract class CodedInputStream {
@Override @Override
public String readStringRequireUtf8() throws IOException { public String readStringRequireUtf8() throws IOException {
final int size = readRawVarint32(); final int size = readRawVarint32();
if (size >= 0 && size <= remaining()) { if (size > 0 && size <= remaining()) {
// TODO(nathanmittler): Is there a way to avoid this copy? if (ENABLE_CUSTOM_UTF8_DECODE) {
// The same as readBytes' logic final int bufferPos = bufferPos(pos);
byte[] bytes = new byte[size]; String result = Utf8.decodeUtf8(buffer, bufferPos, size);
UnsafeUtil.copyMemory(pos, bytes, 0, size); pos += size;
// TODO(martinrb): We could save a pass by validating while decoding. return result;
if (!Utf8.isValidUtf8(bytes)) { } else {
throw InvalidProtocolBufferException.invalidUtf8(); // TODO(nathanmittler): Is there a way to avoid this copy?
} // The same as readBytes' logic
byte[] bytes = new byte[size];
UnsafeUtil.copyMemory(pos, bytes, 0, size);
// TODO(martinrb): We could save a pass by validating while decoding.
if (!Utf8.isValidUtf8(bytes)) {
throw InvalidProtocolBufferException.invalidUtf8();
}
String result = new String(bytes, UTF_8); String result = new String(bytes, UTF_8);
pos += size; pos += size;
return result; return result;
}
} }
if (size == 0) { if (size == 0) {
...@@ -2324,11 +2347,15 @@ public abstract class CodedInputStream { ...@@ -2324,11 +2347,15 @@ public abstract class CodedInputStream {
bytes = readRawBytesSlowPath(size); bytes = readRawBytesSlowPath(size);
tempPos = 0; tempPos = 0;
} }
// TODO(martinrb): We could save a pass by validating while decoding. if (ENABLE_CUSTOM_UTF8_DECODE) {
if (!Utf8.isValidUtf8(bytes, tempPos, tempPos + size)) { return Utf8.decodeUtf8(bytes, tempPos, size);
throw InvalidProtocolBufferException.invalidUtf8(); } else {
// TODO(martinrb): We could save a pass by validating while decoding.
if (!Utf8.isValidUtf8(bytes, tempPos, tempPos + size)) {
throw InvalidProtocolBufferException.invalidUtf8();
}
return new String(bytes, tempPos, size, UTF_8);
} }
return new String(bytes, tempPos, size, UTF_8);
} }
@Override @Override
...@@ -3348,23 +3375,34 @@ public abstract class CodedInputStream { ...@@ -3348,23 +3375,34 @@ public abstract class CodedInputStream {
public String readStringRequireUtf8() throws IOException { public String readStringRequireUtf8() throws IOException {
final int size = readRawVarint32(); final int size = readRawVarint32();
if (size > 0 && size <= currentByteBufferLimit - currentByteBufferPos) { if (size > 0 && size <= currentByteBufferLimit - currentByteBufferPos) {
byte[] bytes = new byte[size]; if (ENABLE_CUSTOM_UTF8_DECODE) {
UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, size); final int bufferPos = (int) (currentByteBufferPos - currentByteBufferStartPos);
if (!Utf8.isValidUtf8(bytes)) { String result = Utf8.decodeUtf8(currentByteBuffer, bufferPos, size);
throw InvalidProtocolBufferException.invalidUtf8(); currentByteBufferPos += size;
return result;
} else {
byte[] bytes = new byte[size];
UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, size);
if (!Utf8.isValidUtf8(bytes)) {
throw InvalidProtocolBufferException.invalidUtf8();
}
String result = new String(bytes, UTF_8);
currentByteBufferPos += size;
return result;
} }
String result = new String(bytes, UTF_8);
currentByteBufferPos += size;
return result;
} }
if (size >= 0 && size <= remaining()) { if (size >= 0 && size <= remaining()) {
byte[] bytes = new byte[size]; byte[] bytes = new byte[size];
readRawBytesTo(bytes, 0, size); readRawBytesTo(bytes, 0, size);
if (!Utf8.isValidUtf8(bytes)) { if (ENABLE_CUSTOM_UTF8_DECODE) {
throw InvalidProtocolBufferException.invalidUtf8(); return Utf8.decodeUtf8(bytes, 0, size);
} else {
if (!Utf8.isValidUtf8(bytes)) {
throw InvalidProtocolBufferException.invalidUtf8();
}
String result = new String(bytes, UTF_8);
return result;
} }
String result = new String(bytes, UTF_8);
return result;
} }
if (size == 0) { if (size == 0) {
......
...@@ -33,7 +33,6 @@ package com.google.protobuf; ...@@ -33,7 +33,6 @@ package com.google.protobuf;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.nio.Buffer; import java.nio.Buffer;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.security.AccessController; import java.security.AccessController;
import java.security.PrivilegedExceptionAction; import java.security.PrivilegedExceptionAction;
import java.util.logging.Level; import java.util.logging.Level;
...@@ -72,6 +71,8 @@ final class UnsafeUtil { ...@@ -72,6 +71,8 @@ final class UnsafeUtil {
private static final long BUFFER_ADDRESS_OFFSET = fieldOffset(bufferAddressField()); private static final long BUFFER_ADDRESS_OFFSET = fieldOffset(bufferAddressField());
private static final long STRING_VALUE_OFFSET = fieldOffset(stringValueField());
private UnsafeUtil() {} private UnsafeUtil() {}
static boolean hasUnsafeArrayOperations() { static boolean hasUnsafeArrayOperations() {
...@@ -259,6 +260,26 @@ final class UnsafeUtil { ...@@ -259,6 +260,26 @@ final class UnsafeUtil {
return MEMORY_ACCESSOR.getLong(buffer, BUFFER_ADDRESS_OFFSET); return MEMORY_ACCESSOR.getLong(buffer, BUFFER_ADDRESS_OFFSET);
} }
/**
 * Wraps {@code chars} in a {@link String} without copying when possible.
 *
 * <p>On the fast path a {@code String} instance is obtained from
 * {@code UNSAFE.allocateInstance} and {@code chars} is stored directly at
 * {@code STRING_VALUE_OFFSET}, so the returned string shares the caller's array. Callers must
 * therefore not modify {@code chars} after this call.
 *
 * <p>When {@code STRING_VALUE_OFFSET} is {@code -1}, or the instance cannot be allocated, the
 * method falls back to {@code new String(chars)}, which copies the array.
 *
 * <p>NOTE(review): only the field at {@code STRING_VALUE_OFFSET} is written on the fast path;
 * any other instance fields of the allocated {@code String} keep their default values. Confirm
 * this is valid for every JDK this code is expected to run on.
 *
 * @param chars the characters that make up the resulting string
 * @return a string with the contents of {@code chars}
 */
static String moveToString(char[] chars) {
  // Only attempt the zero-copy path when the offset of String's char[] field is known.
  if (STRING_VALUE_OFFSET != -1) {
    try {
      final String wrapped = (String) UNSAFE.allocateInstance(String.class);
      putObject(wrapped, STRING_VALUE_OFFSET, chars);
      return wrapped;
    } catch (InstantiationException e) {
      // This should never happen; fall through to the copying fallback just in case.
    }
  }
  // Either this JDK does not lay out String as expected, or allocation failed: do a copy.
  return new String(chars);
}
static Object getStaticObject(Field field) { static Object getStaticObject(Field field) {
return MEMORY_ACCESSOR.getStaticObject(field); return MEMORY_ACCESSOR.getStaticObject(field);
} }
...@@ -375,7 +396,12 @@ final class UnsafeUtil { ...@@ -375,7 +396,12 @@ final class UnsafeUtil {
/** Finds the address field within a direct {@link Buffer}. */ /** Finds the address field within a direct {@link Buffer}. */
private static Field bufferAddressField() { private static Field bufferAddressField() {
return field(Buffer.class, "address"); return field(Buffer.class, "address", long.class);
}
/** Finds the value field within a {@link String}. */
private static Field stringValueField() {
return field(String.class, "value", char[].class);
} }
/** /**
...@@ -390,11 +416,14 @@ final class UnsafeUtil { ...@@ -390,11 +416,14 @@ final class UnsafeUtil {
* Gets the field with the given name within the class, or {@code null} if not found. If found, * Gets the field with the given name within the class, or {@code null} if not found. If found,
* the field is made accessible. * the field is made accessible.
*/ */
private static Field field(Class<?> clazz, String fieldName) { private static Field field(Class<?> clazz, String fieldName, Class<?> expectedType) {
Field field; Field field;
try { try {
field = clazz.getDeclaredField(fieldName); field = clazz.getDeclaredField(fieldName);
field.setAccessible(true); field.setAccessible(true);
if (!field.getType().equals(expectedType)) {
return null;
}
} catch (Throwable t) { } catch (Throwable t) {
// Failed to access the fields. // Failed to access the fields.
field = null; field = null;
......
...@@ -273,6 +273,15 @@ final class IsValidUtf8TestUtil { ...@@ -273,6 +273,15 @@ final class IsValidUtf8TestUtil {
assertEquals(isRoundTrippable, Utf8.isValidUtf8(bytes)); assertEquals(isRoundTrippable, Utf8.isValidUtf8(bytes));
assertEquals(isRoundTrippable, Utf8.isValidUtf8(bytes, 0, numBytes)); assertEquals(isRoundTrippable, Utf8.isValidUtf8(bytes, 0, numBytes));
try {
assertEquals(s, Utf8.decodeUtf8(bytes, 0, numBytes));
} catch (InvalidProtocolBufferException e) {
if (isRoundTrippable) {
System.out.println("Could not decode utf-8");
outputFailure(byteChar, bytes, bytesReencoded);
}
}
// Test partial sequences. // Test partial sequences.
// Partition numBytes into three segments (not necessarily non-empty). // Partition numBytes into three segments (not necessarily non-empty).
int i = rnd.nextInt(numBytes); int i = rnd.nextInt(numBytes);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment