Commit 3e944aec authored by Anuraag Agrawal's avatar Anuraag Agrawal Committed by Anuraag Agrawal

Add a UTF-8 decoder that uses Unsafe to directly decode a byte buffer.

parent 3c6fd3f7
......@@ -286,6 +286,7 @@ java_EXTRA_DIST=
java/core/src/test/java/com/google/protobuf/CheckUtf8Test.java \
java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java \
java/core/src/test/java/com/google/protobuf/CodedOutputStreamTest.java \
java/core/src/test/java/com/google/protobuf/DecodeUtf8Test.java \
java/core/src/test/java/com/google/protobuf/DeprecatedFieldTest.java \
java/core/src/test/java/com/google/protobuf/DescriptorsTest.java \
java/core/src/test/java/com/google/protobuf/DiscardUnknownFieldsTest.java \
......
......@@ -64,6 +64,14 @@ public abstract class CodedInputStream {
// Integer.MAX_VALUE == 0x7FFFFFF == INT_MAX from limits.h
private static final int DEFAULT_SIZE_LIMIT = Integer.MAX_VALUE;
/**
* Whether to enable our custom UTF-8 decode codepath which does not use {@link StringCoding}.
* Enabled by default, disable by setting
* {@code -Dcom.google.protobuf.enableCustomutf8Decode=false} in JVM args.
*/
private static final boolean ENABLE_CUSTOM_UTF8_DECODE
= !"false".equals(System.getProperty("com.google.protobuf.enableCustomUtf8Decode"));
/** Visible for subclasses. See setRecursionLimit() */
int recursionDepth;
......@@ -825,13 +833,19 @@ public abstract class CodedInputStream {
public String readStringRequireUtf8() throws IOException {
final int size = readRawVarint32();
if (size > 0 && size <= (limit - pos)) {
// TODO(martinrb): We could save a pass by validating while decoding.
if (!Utf8.isValidUtf8(buffer, pos, pos + size)) {
throw InvalidProtocolBufferException.invalidUtf8();
if (ENABLE_CUSTOM_UTF8_DECODE) {
String result = Utf8.decodeUtf8(buffer, pos, size);
pos += size;
return result;
} else {
// TODO(martinrb): We could save a pass by validating while decoding.
if (!Utf8.isValidUtf8(buffer, pos, pos + size)) {
throw InvalidProtocolBufferException.invalidUtf8();
}
final int tempPos = pos;
pos += size;
return new String(buffer, tempPos, size, UTF_8);
}
final int tempPos = pos;
pos += size;
return new String(buffer, tempPos, size, UTF_8);
}
if (size == 0) {
......@@ -1524,6 +1538,8 @@ public abstract class CodedInputStream {
final int size = readRawVarint32();
if (size > 0 && size <= remaining()) {
// TODO(nathanmittler): Is there a way to avoid this copy?
// TODO(anuraaga): It might be possible to share the optimized loop with
// readStringRequireUtf8 by implementing Java replacement logic there.
// The same as readBytes' logic
byte[] bytes = new byte[size];
UnsafeUtil.copyMemory(pos, bytes, 0, size);
......@@ -1544,19 +1560,26 @@ public abstract class CodedInputStream {
@Override
public String readStringRequireUtf8() throws IOException {
final int size = readRawVarint32();
if (size >= 0 && size <= remaining()) {
// TODO(nathanmittler): Is there a way to avoid this copy?
// The same as readBytes' logic
byte[] bytes = new byte[size];
UnsafeUtil.copyMemory(pos, bytes, 0, size);
// TODO(martinrb): We could save a pass by validating while decoding.
if (!Utf8.isValidUtf8(bytes)) {
throw InvalidProtocolBufferException.invalidUtf8();
}
if (size > 0 && size <= remaining()) {
if (ENABLE_CUSTOM_UTF8_DECODE) {
final int bufferPos = bufferPos(pos);
String result = Utf8.decodeUtf8(buffer, bufferPos, size);
pos += size;
return result;
} else {
// TODO(nathanmittler): Is there a way to avoid this copy?
// The same as readBytes' logic
byte[] bytes = new byte[size];
UnsafeUtil.copyMemory(pos, bytes, 0, size);
// TODO(martinrb): We could save a pass by validating while decoding.
if (!Utf8.isValidUtf8(bytes)) {
throw InvalidProtocolBufferException.invalidUtf8();
}
String result = new String(bytes, UTF_8);
pos += size;
return result;
String result = new String(bytes, UTF_8);
pos += size;
return result;
}
}
if (size == 0) {
......@@ -2324,11 +2347,15 @@ public abstract class CodedInputStream {
bytes = readRawBytesSlowPath(size);
tempPos = 0;
}
// TODO(martinrb): We could save a pass by validating while decoding.
if (!Utf8.isValidUtf8(bytes, tempPos, tempPos + size)) {
throw InvalidProtocolBufferException.invalidUtf8();
if (ENABLE_CUSTOM_UTF8_DECODE) {
return Utf8.decodeUtf8(bytes, tempPos, size);
} else {
// TODO(martinrb): We could save a pass by validating while decoding.
if (!Utf8.isValidUtf8(bytes, tempPos, tempPos + size)) {
throw InvalidProtocolBufferException.invalidUtf8();
}
return new String(bytes, tempPos, size, UTF_8);
}
return new String(bytes, tempPos, size, UTF_8);
}
@Override
......@@ -3348,23 +3375,34 @@ public abstract class CodedInputStream {
public String readStringRequireUtf8() throws IOException {
final int size = readRawVarint32();
if (size > 0 && size <= currentByteBufferLimit - currentByteBufferPos) {
byte[] bytes = new byte[size];
UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, size);
if (!Utf8.isValidUtf8(bytes)) {
throw InvalidProtocolBufferException.invalidUtf8();
if (ENABLE_CUSTOM_UTF8_DECODE) {
final int bufferPos = (int) (currentByteBufferPos - currentByteBufferStartPos);
String result = Utf8.decodeUtf8(currentByteBuffer, bufferPos, size);
currentByteBufferPos += size;
return result;
} else {
byte[] bytes = new byte[size];
UnsafeUtil.copyMemory(currentByteBufferPos, bytes, 0, size);
if (!Utf8.isValidUtf8(bytes)) {
throw InvalidProtocolBufferException.invalidUtf8();
}
String result = new String(bytes, UTF_8);
currentByteBufferPos += size;
return result;
}
String result = new String(bytes, UTF_8);
currentByteBufferPos += size;
return result;
}
if (size >= 0 && size <= remaining()) {
byte[] bytes = new byte[size];
readRawBytesTo(bytes, 0, size);
if (!Utf8.isValidUtf8(bytes)) {
throw InvalidProtocolBufferException.invalidUtf8();
if (ENABLE_CUSTOM_UTF8_DECODE) {
return Utf8.decodeUtf8(bytes, 0, size);
} else {
if (!Utf8.isValidUtf8(bytes)) {
throw InvalidProtocolBufferException.invalidUtf8();
}
String result = new String(bytes, UTF_8);
return result;
}
String result = new String(bytes, UTF_8);
return result;
}
if (size == 0) {
......
......@@ -33,7 +33,6 @@ package com.google.protobuf;
import java.lang.reflect.Field;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.logging.Level;
......@@ -72,6 +71,8 @@ final class UnsafeUtil {
private static final long BUFFER_ADDRESS_OFFSET = fieldOffset(bufferAddressField());
private static final long STRING_VALUE_OFFSET = fieldOffset(stringValueField());
private UnsafeUtil() {}
static boolean hasUnsafeArrayOperations() {
......@@ -259,6 +260,26 @@ final class UnsafeUtil {
return MEMORY_ACCESSOR.getLong(buffer, BUFFER_ADDRESS_OFFSET);
}
/**
* Returns a new {@link String} backed by the given {@code chars}. The char array should not
* be mutated any more after calling this function.
*/
static String moveToString(char[] chars) {
if (STRING_VALUE_OFFSET == -1) {
// In the off-chance that this JDK does not implement String as we'd expect, just do a copy.
return new String(chars);
}
final String str;
try {
str = (String) UNSAFE.allocateInstance(String.class);
} catch (InstantiationException e) {
// This should never happen, but return a copy as a fallback just in case.
return new String(chars);
}
putObject(str, STRING_VALUE_OFFSET, chars);
return str;
}
static Object getStaticObject(Field field) {
return MEMORY_ACCESSOR.getStaticObject(field);
}
......@@ -375,7 +396,12 @@ final class UnsafeUtil {
/** Finds the address field within a direct {@link Buffer}. */
private static Field bufferAddressField() {
return field(Buffer.class, "address");
return field(Buffer.class, "address", long.class);
}
/** Finds the value field within a {@link String}. */
private static Field stringValueField() {
return field(String.class, "value", char[].class);
}
/**
......@@ -390,11 +416,14 @@ final class UnsafeUtil {
* Gets the field with the given name within the class, or {@code null} if not found. If found,
* the field is made accessible.
*/
private static Field field(Class<?> clazz, String fieldName) {
private static Field field(Class<?> clazz, String fieldName, Class<?> expectedType) {
Field field;
try {
field = clazz.getDeclaredField(fieldName);
field.setAccessible(true);
if (!field.getType().equals(expectedType)) {
return null;
}
} catch (Throwable t) {
// Failed to access the fields.
field = null;
......
......@@ -34,11 +34,15 @@ import static com.google.protobuf.UnsafeUtil.addressOffset;
import static com.google.protobuf.UnsafeUtil.hasUnsafeArrayOperations;
import static com.google.protobuf.UnsafeUtil.hasUnsafeByteBufferOperations;
import static java.lang.Character.MAX_SURROGATE;
import static java.lang.Character.MIN_HIGH_SURROGATE;
import static java.lang.Character.MIN_LOW_SURROGATE;
import static java.lang.Character.MIN_SUPPLEMENTARY_CODE_POINT;
import static java.lang.Character.MIN_SURROGATE;
import static java.lang.Character.isSurrogatePair;
import static java.lang.Character.toCodePoint;
import java.nio.ByteBuffer;
import java.util.Arrays;
/**
* A set of low-level, high-performance static utility methods related
......@@ -289,7 +293,7 @@ final class Utf8 {
if (Character.MIN_SURROGATE <= c && c <= Character.MAX_SURROGATE) {
// Check that we have a well-formed surrogate pair.
int cp = Character.codePointAt(sequence, i);
if (cp < Character.MIN_SUPPLEMENTARY_CODE_POINT) {
if (cp < MIN_SUPPLEMENTARY_CODE_POINT) {
throw new UnpairedSurrogateException(i, utf16Length);
}
i++;
......@@ -330,6 +334,26 @@ final class Utf8 {
return processor.partialIsValidUtf8(state, buffer, index, limit);
}
/**
* Decodes the given UTF-8 portion of the {@link ByteBuffer} into a {@link String}.
*
* @throws InvalidProtocolBufferException if the input is not valid UTF-8.
*/
static String decodeUtf8(ByteBuffer buffer, int index, int size)
throws InvalidProtocolBufferException {
return processor.decodeUtf8(buffer, index, size);
}
/**
* Decodes the given UTF-8 encoded byte array slice into a {@link String}.
*
* @throws InvalidProtocolBufferException if the input is not valid UTF-8.
*/
static String decodeUtf8(byte[] bytes, int index, int size)
throws InvalidProtocolBufferException {
return processor.decodeUtf8(bytes, index, size);
}
/**
* Encodes the given characters to the target {@link ByteBuffer} using UTF-8 encoding.
*
......@@ -609,6 +633,116 @@ final class Utf8 {
}
}
/**
* Decodes the given byte array slice into a {@link String}.
*
* @throws InvalidProtocolBufferException if the byte array slice is not valid UTF-8.
*/
abstract String decodeUtf8(byte[] bytes, int index, int size)
throws InvalidProtocolBufferException;
/**
* Decodes the given portion of the {@link ByteBuffer} into a {@link String}.
*
* @throws InvalidProtocolBufferException if the portion of the buffer is not valid UTF-8.
*/
final String decodeUtf8(ByteBuffer buffer, int index, int size)
throws InvalidProtocolBufferException {
if (buffer.hasArray()) {
final int offset = buffer.arrayOffset();
return decodeUtf8(buffer.array(), offset + index, size);
} else if (buffer.isDirect()) {
return decodeUtf8Direct(buffer, index, size);
}
return decodeUtf8Default(buffer, index, size);
}
/**
* Decodes direct {@link ByteBuffer} instances into {@link String}.
*/
abstract String decodeUtf8Direct(ByteBuffer buffer, int index, int size)
throws InvalidProtocolBufferException;
/**
* Decodes {@link ByteBuffer} instances using the {@link ByteBuffer} API rather than
* potentially faster approaches.
*/
final String decodeUtf8Default(ByteBuffer buffer, int index, int size)
throws InvalidProtocolBufferException {
// Bitwise OR combines the sign bits so any negative value fails the check.
if ((index | size | buffer.limit() - index - size) < 0) {
throw new ArrayIndexOutOfBoundsException(
String.format("buffer limit=%d, index=%d, limit=%d", buffer.limit(), index, size));
}
int offset = index;
final int limit = offset + size;
// The longest possible resulting String is the same as the number of input bytes, when it is
// all ASCII. For other cases, this over-allocates and we will truncate in the end.
char[] resultArr = new char[size];
int resultPos = 0;
// Optimize for 100% ASCII (Hotspot loves small simple top-level loops like this).
// This simple loop stops when we encounter a byte >= 0x80 (i.e. non-ASCII).
while (offset < limit) {
byte b = buffer.get(offset);
if (!DecodeUtil.isOneByte(b)) {
break;
}
offset++;
DecodeUtil.handleOneByte(b, resultArr, resultPos++);
}
while (offset < limit) {
byte byte1 = buffer.get(offset++);
if (DecodeUtil.isOneByte(byte1)) {
DecodeUtil.handleOneByte(byte1, resultArr, resultPos++);
// It's common for there to be multiple ASCII characters in a run mixed in, so add an
// extra optimized loop to take care of these runs.
while (offset < limit) {
byte b = buffer.get(offset);
if (!DecodeUtil.isOneByte(b)) {
break;
}
offset++;
DecodeUtil.handleOneByte(b, resultArr, resultPos++);
}
} else if (DecodeUtil.isTwoBytes(byte1)) {
if (offset >= limit) {
throw InvalidProtocolBufferException.invalidUtf8();
}
DecodeUtil.handleTwoBytes(
byte1, /* byte2 */ buffer.get(offset++), resultArr, resultPos++);
} else if (DecodeUtil.isThreeBytes(byte1)) {
if (offset >= limit - 1) {
throw InvalidProtocolBufferException.invalidUtf8();
}
DecodeUtil.handleThreeBytes(
byte1,
/* byte2 */ buffer.get(offset++),
/* byte3 */ buffer.get(offset++),
resultArr,
resultPos++);
} else {
if (offset >= limit - 2) {
throw InvalidProtocolBufferException.invalidUtf8();
}
DecodeUtil.handleFourBytes(
byte1,
/* byte2 */ buffer.get(offset++),
/* byte3 */ buffer.get(offset++),
/* byte4 */ buffer.get(offset++),
resultArr,
resultPos++);
// 4-byte case requires two chars.
resultPos++;
}
}
return new String(resultArr, 0, resultPos);
}
/**
* Encodes an input character sequence ({@code in}) to UTF-8 in the target array ({@code out}).
* For a string, this method is similar to
......@@ -850,6 +984,88 @@ final class Utf8 {
return partialIsValidUtf8Default(state, buffer, index, limit);
}
@Override
String decodeUtf8(byte[] bytes, int index, int size) throws InvalidProtocolBufferException {
// Bitwise OR combines the sign bits so any negative value fails the check.
if ((index | size | bytes.length - index - size) < 0) {
throw new ArrayIndexOutOfBoundsException(
String.format("buffer length=%d, index=%d, size=%d", bytes.length, index, size));
}
int offset = index;
final int limit = offset + size;
// The longest possible resulting String is the same as the number of input bytes, when it is
// all ASCII. For other cases, this over-allocates and we will truncate in the end.
char[] resultArr = new char[size];
int resultPos = 0;
// Optimize for 100% ASCII (Hotspot loves small simple top-level loops like this).
// This simple loop stops when we encounter a byte >= 0x80 (i.e. non-ASCII).
while (offset < limit) {
byte b = bytes[offset];
if (!DecodeUtil.isOneByte(b)) {
break;
}
offset++;
DecodeUtil.handleOneByte(b, resultArr, resultPos++);
}
while (offset < limit) {
byte byte1 = bytes[offset++];
if (DecodeUtil.isOneByte(byte1)) {
DecodeUtil.handleOneByte(byte1, resultArr, resultPos++);
// It's common for there to be multiple ASCII characters in a run mixed in, so add an
// extra optimized loop to take care of these runs.
while (offset < limit) {
byte b = bytes[offset];
if (!DecodeUtil.isOneByte(b)) {
break;
}
offset++;
DecodeUtil.handleOneByte(b, resultArr, resultPos++);
}
} else if (DecodeUtil.isTwoBytes(byte1)) {
if (offset >= limit) {
throw InvalidProtocolBufferException.invalidUtf8();
}
DecodeUtil.handleTwoBytes(byte1, /* byte2 */ bytes[offset++], resultArr, resultPos++);
} else if (DecodeUtil.isThreeBytes(byte1)) {
if (offset >= limit - 1) {
throw InvalidProtocolBufferException.invalidUtf8();
}
DecodeUtil.handleThreeBytes(
byte1,
/* byte2 */ bytes[offset++],
/* byte3 */ bytes[offset++],
resultArr,
resultPos++);
} else {
if (offset >= limit - 2) {
throw InvalidProtocolBufferException.invalidUtf8();
}
DecodeUtil.handleFourBytes(
byte1,
/* byte2 */ bytes[offset++],
/* byte3 */ bytes[offset++],
/* byte4 */ bytes[offset++],
resultArr,
resultPos++);
// 4-byte case requires two chars.
resultPos++;
}
}
return new String(resultArr, 0, resultPos);
}
@Override
String decodeUtf8Direct(ByteBuffer buffer, int index, int size)
throws InvalidProtocolBufferException {
// For safe processing, we have to use the ByteBufferAPI.
return decodeUtf8Default(buffer, index, size);
}
@Override
int encodeUtf8(CharSequence in, byte[] out, int offset, int length) {
int utf16Length = in.length();
......@@ -996,6 +1212,7 @@ final class Utf8 {
@Override
int partialIsValidUtf8(int state, byte[] bytes, final int index, final int limit) {
// Bitwise OR combines the sign bits so any negative value fails the check.
if ((index | limit | bytes.length - limit) < 0) {
throw new ArrayIndexOutOfBoundsException(
String.format("Array length=%d, index=%d, limit=%d", bytes.length, index, limit));
......@@ -1091,6 +1308,7 @@ final class Utf8 {
@Override
int partialIsValidUtf8Direct(
final int state, ByteBuffer buffer, final int index, final int limit) {
// Bitwise OR combines the sign bits so any negative value fails the check.
if ((index | limit | buffer.limit() - limit) < 0) {
throw new ArrayIndexOutOfBoundsException(
String.format("buffer limit=%d, index=%d, limit=%d", buffer.limit(), index, limit));
......@@ -1184,6 +1402,163 @@ final class Utf8 {
return partialIsValidUtf8(address, (int) (addressLimit - address));
}
@Override
String decodeUtf8(byte[] bytes, int index, int size) throws InvalidProtocolBufferException {
if ((index | size | bytes.length - index - size) < 0) {
throw new ArrayIndexOutOfBoundsException(
String.format("buffer length=%d, index=%d, size=%d", bytes.length, index, size));
}
int offset = index;
final int limit = offset + size;
// The longest possible resulting String is the same as the number of input bytes, when it is
// all ASCII. For other cases, this over-allocates and we will truncate in the end.
char[] resultArr = new char[size];
int resultPos = 0;
// Optimize for 100% ASCII (Hotspot loves small simple top-level loops like this).
// This simple loop stops when we encounter a byte >= 0x80 (i.e. non-ASCII).
while (offset < limit) {
byte b = UnsafeUtil.getByte(bytes, offset);
if (!DecodeUtil.isOneByte(b)) {
break;
}
offset++;
DecodeUtil.handleOneByte(b, resultArr, resultPos++);
}
while (offset < limit) {
byte byte1 = UnsafeUtil.getByte(bytes, offset++);
if (DecodeUtil.isOneByte(byte1)) {
DecodeUtil.handleOneByte(byte1, resultArr, resultPos++);
// It's common for there to be multiple ASCII characters in a run mixed in, so add an
// extra optimized loop to take care of these runs.
while (offset < limit) {
byte b = UnsafeUtil.getByte(bytes, offset);
if (!DecodeUtil.isOneByte(b)) {
break;
}
offset++;
DecodeUtil.handleOneByte(b, resultArr, resultPos++);
}
} else if (DecodeUtil.isTwoBytes(byte1)) {
if (offset >= limit) {
throw InvalidProtocolBufferException.invalidUtf8();
}
DecodeUtil.handleTwoBytes(
byte1, /* byte2 */ UnsafeUtil.getByte(bytes, offset++), resultArr, resultPos++);
} else if (DecodeUtil.isThreeBytes(byte1)) {
if (offset >= limit - 1) {
throw InvalidProtocolBufferException.invalidUtf8();
}
DecodeUtil.handleThreeBytes(
byte1,
/* byte2 */ UnsafeUtil.getByte(bytes, offset++),
/* byte3 */ UnsafeUtil.getByte(bytes, offset++),
resultArr,
resultPos++);
} else {
if (offset >= limit - 2) {
throw InvalidProtocolBufferException.invalidUtf8();
}
DecodeUtil.handleFourBytes(
byte1,
/* byte2 */ UnsafeUtil.getByte(bytes, offset++),
/* byte3 */ UnsafeUtil.getByte(bytes, offset++),
/* byte4 */ UnsafeUtil.getByte(bytes, offset++),
resultArr,
resultPos++);
// 4-byte case requires two chars.
resultPos++;
}
}
if (resultPos < resultArr.length) {
resultArr = Arrays.copyOf(resultArr, resultPos);
}
return UnsafeUtil.moveToString(resultArr);
}
@Override
String decodeUtf8Direct(ByteBuffer buffer, int index, int size)
throws InvalidProtocolBufferException {
// Bitwise OR combines the sign bits so any negative value fails the check.
if ((index | size | buffer.limit() - index - size) < 0) {
throw new ArrayIndexOutOfBoundsException(
String.format("buffer limit=%d, index=%d, limit=%d", buffer.limit(), index, size));
}
long address = UnsafeUtil.addressOffset(buffer) + index;
final long addressLimit = address + size;
// The longest possible resulting String is the same as the number of input bytes, when it is
// all ASCII. For other cases, this over-allocates and we will truncate in the end.
char[] resultArr = new char[size];
int resultPos = 0;
// Optimize for 100% ASCII (Hotspot loves small simple top-level loops like this).
// This simple loop stops when we encounter a byte >= 0x80 (i.e. non-ASCII).
while (address < addressLimit) {
byte b = UnsafeUtil.getByte(address);
if (!DecodeUtil.isOneByte(b)) {
break;
}
address++;
DecodeUtil.handleOneByte(b, resultArr, resultPos++);
}
while (address < addressLimit) {
byte byte1 = UnsafeUtil.getByte(address++);
if (DecodeUtil.isOneByte(byte1)) {
DecodeUtil.handleOneByte(byte1, resultArr, resultPos++);
// It's common for there to be multiple ASCII characters in a run mixed in, so add an
// extra optimized loop to take care of these runs.
while (address < addressLimit) {
byte b = UnsafeUtil.getByte(address);
if (!DecodeUtil.isOneByte(b)) {
break;
}
address++;
DecodeUtil.handleOneByte(b, resultArr, resultPos++);
}
} else if (DecodeUtil.isTwoBytes(byte1)) {
if (address >= addressLimit) {
throw InvalidProtocolBufferException.invalidUtf8();
}
DecodeUtil.handleTwoBytes(
byte1, /* byte2 */ UnsafeUtil.getByte(address++), resultArr, resultPos++);
} else if (DecodeUtil.isThreeBytes(byte1)) {
if (address >= addressLimit - 1) {
throw InvalidProtocolBufferException.invalidUtf8();
}
DecodeUtil.handleThreeBytes(
byte1,
/* byte2 */ UnsafeUtil.getByte(address++),
/* byte3 */ UnsafeUtil.getByte(address++),
resultArr,
resultPos++);
} else {
if (address >= addressLimit - 2) {
throw InvalidProtocolBufferException.invalidUtf8();
}
DecodeUtil.handleFourBytes(
byte1,
/* byte2 */ UnsafeUtil.getByte(address++),
/* byte3 */ UnsafeUtil.getByte(address++),
/* byte4 */ UnsafeUtil.getByte(address++),
resultArr,
resultPos++);
// 4-byte case requires two chars.
resultPos++;
}
}
if (resultPos < resultArr.length) {
resultArr = Arrays.copyOf(resultArr, resultPos);
}
return UnsafeUtil.moveToString(resultArr);
}
@Override
int encodeUtf8(final CharSequence in, final byte[] out, final int offset, final int length) {
long outIx = offset;
......@@ -1554,5 +1929,112 @@ final class Utf8 {
}
}
/**
* Utility methods for decoding bytes into {@link String}. Callers are responsible for extracting
* bytes (possibly using Unsafe methods), and checking remaining bytes. All other UTF-8 validity
* checks and codepoint conversion happen in this class.
*/
private static class DecodeUtil {
/**
* Returns whether this is a single-byte codepoint (i.e., ASCII) with the form '0XXXXXXX'.
*/
private static boolean isOneByte(byte b) {
return b >= 0;
}
/**
* Returns whether this is a two-byte codepoint with the form '10XXXXXX'.
*/
private static boolean isTwoBytes(byte b) {
return b < (byte) 0xE0;
}
/**
* Returns whether this is a three-byte codepoint with the form '110XXXXX'.
*/
private static boolean isThreeBytes(byte b) {
return b < (byte) 0xF0;
}
private static void handleOneByte(byte byte1, char[] resultArr, int resultPos) {
resultArr[resultPos] = (char) byte1;
}
private static void handleTwoBytes(
byte byte1, byte byte2, char[] resultArr, int resultPos)
throws InvalidProtocolBufferException {
// Simultaneously checks for illegal trailing-byte in leading position (<= '11000000') and
// overlong 2-byte, '11000001'.
if (byte1 < (byte) 0xC2
|| isNotTrailingByte(byte2)) {
throw InvalidProtocolBufferException.invalidUtf8();
}
resultArr[resultPos] = (char) (((byte1 & 0x1F) << 6) | trailingByteValue(byte2));
}
private static void handleThreeBytes(
byte byte1, byte byte2, byte byte3, char[] resultArr, int resultPos)
throws InvalidProtocolBufferException {
if (isNotTrailingByte(byte2)
// overlong? 5 most significant bits must not all be zero
|| (byte1 == (byte) 0xE0 && byte2 < (byte) 0xA0)
// check for illegal surrogate codepoints
|| (byte1 == (byte) 0xED && byte2 >= (byte) 0xA0)
|| isNotTrailingByte(byte3)) {
throw InvalidProtocolBufferException.invalidUtf8();
}
resultArr[resultPos] = (char)
(((byte1 & 0x0F) << 12) | (trailingByteValue(byte2) << 6) | trailingByteValue(byte3));
}
private static void handleFourBytes(
byte byte1, byte byte2, byte byte3, byte byte4, char[] resultArr, int resultPos)
throws InvalidProtocolBufferException{
if (isNotTrailingByte(byte2)
// Check that 1 <= plane <= 16. Tricky optimized form of:
// valid 4-byte leading byte?
// if (byte1 > (byte) 0xF4 ||
// overlong? 4 most significant bits must not all be zero
// byte1 == (byte) 0xF0 && byte2 < (byte) 0x90 ||
// codepoint larger than the highest code point (U+10FFFF)?
// byte1 == (byte) 0xF4 && byte2 > (byte) 0x8F)
|| (((byte1 << 28) + (byte2 - (byte) 0x90)) >> 30) != 0
|| isNotTrailingByte(byte3)
|| isNotTrailingByte(byte4)) {
throw InvalidProtocolBufferException.invalidUtf8();
}
int codepoint = ((byte1 & 0x07) << 18)
| (trailingByteValue(byte2) << 12)
| (trailingByteValue(byte3) << 6)
| trailingByteValue(byte4);
resultArr[resultPos] = DecodeUtil.highSurrogate(codepoint);
resultArr[resultPos + 1] = DecodeUtil.lowSurrogate(codepoint);
}
/**
* Returns whether the byte is not a valid continuation of the form '10XXXXXX'.
*/
private static boolean isNotTrailingByte(byte b) {
return b > (byte) 0xBF;
}
/**
* Returns the actual value of the trailing byte (removes the prefix '10') for composition.
*/
private static int trailingByteValue(byte b) {
return b & 0x3F;
}
private static char highSurrogate(int codePoint) {
return (char) ((MIN_HIGH_SURROGATE - (MIN_SUPPLEMENTARY_CODE_POINT >>> 10))
+ (codePoint >>> 10));
}
private static char lowSurrogate(int codePoint) {
return (char) (MIN_LOW_SURROGATE + (codePoint & 0x3ff));
}
}
private Utf8() {}
}
package com.google.protobuf;
import com.google.protobuf.Utf8.Processor;
import com.google.protobuf.Utf8.SafeProcessor;
import com.google.protobuf.Utf8.UnsafeProcessor;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;
import junit.framework.TestCase;
public class DecodeUtf8Test extends TestCase {
private static Logger logger = Logger.getLogger(DecodeUtf8Test.class.getName());
private static final Processor SAFE_PROCESSOR = new SafeProcessor();
private static final Processor UNSAFE_PROCESSOR = new UnsafeProcessor();
public void testRoundTripAllValidChars() throws Exception {
for (int i = Character.MIN_CODE_POINT; i < Character.MAX_CODE_POINT; i++) {
if (i < Character.MIN_SURROGATE || i > Character.MAX_SURROGATE) {
String str = new String(Character.toChars(i));
assertRoundTrips(str);
}
}
}
// Test all 1, 2, 3 invalid byte combinations. Valid ones would have been covered above.
public void testOneByte() throws Exception {
int valid = 0;
for (int i = Byte.MIN_VALUE; i <= Byte.MAX_VALUE; i++) {
ByteString bs = ByteString.copyFrom(new byte[] { (byte) i });
if (!bs.isValidUtf8()) {
assertInvalid(bs.toByteArray());
} else {
valid++;
}
}
assertEquals(IsValidUtf8TestUtil.EXPECTED_ONE_BYTE_ROUNDTRIPPABLE_COUNT, valid);
}
public void testTwoBytes() throws Exception {
int valid = 0;
for (int i = Byte.MIN_VALUE; i <= Byte.MAX_VALUE; i++) {
for (int j = Byte.MIN_VALUE; j <= Byte.MAX_VALUE; j++) {
ByteString bs = ByteString.copyFrom(new byte[]{(byte) i, (byte) j});
if (!bs.isValidUtf8()) {
assertInvalid(bs.toByteArray());
} else {
valid++;
}
}
}
assertEquals(IsValidUtf8TestUtil.EXPECTED_TWO_BYTE_ROUNDTRIPPABLE_COUNT, valid);
}
public void testThreeBytes() throws Exception {
// Travis' OOM killer doesn't like this test
if (System.getenv("TRAVIS") == null) {
int count = 0;
int valid = 0;
for (int i = Byte.MIN_VALUE; i <= Byte.MAX_VALUE; i++) {
for (int j = Byte.MIN_VALUE; j <= Byte.MAX_VALUE; j++) {
for (int k = Byte.MIN_VALUE; k <= Byte.MAX_VALUE; k++) {
byte[] bytes = new byte[]{(byte) i, (byte) j, (byte) k};
ByteString bs = ByteString.copyFrom(bytes);
if (!bs.isValidUtf8()) {
assertInvalid(bytes);
} else {
valid++;
}
count++;
if (count % 1000000L == 0) {
logger.info("Processed " + (count / 1000000L) + " million characters");
}
}
}
}
assertEquals(IsValidUtf8TestUtil.EXPECTED_THREE_BYTE_ROUNDTRIPPABLE_COUNT, valid);
}
}
/**
* Tests that round tripping of a sample of four byte permutations work.
*/
public void testInvalid_4BytesSamples() throws Exception {
// Bad trailing bytes
assertInvalid(0xF0, 0xA4, 0xAD, 0x7F);
assertInvalid(0xF0, 0xA4, 0xAD, 0xC0);
// Special cases for byte2
assertInvalid(0xF0, 0x8F, 0xAD, 0xA2);
assertInvalid(0xF4, 0x90, 0xAD, 0xA2);
}
public void testRealStrings() throws Exception {
// English
assertRoundTrips("The quick brown fox jumps over the lazy dog");
// German
assertRoundTrips("Quizdeltagerne spiste jordb\u00e6r med fl\u00f8de, mens cirkusklovnen");
// Japanese
assertRoundTrips(
"\u3044\u308d\u306f\u306b\u307b\u3078\u3068\u3061\u308a\u306c\u308b\u3092");
// Hebrew
assertRoundTrips(
"\u05d3\u05d2 \u05e1\u05e7\u05e8\u05df \u05e9\u05d8 \u05d1\u05d9\u05dd "
+ "\u05de\u05d0\u05d5\u05db\u05d6\u05d1 \u05d5\u05dc\u05e4\u05ea\u05e2"
+ " \u05de\u05e6\u05d0 \u05dc\u05d5 \u05d7\u05d1\u05e8\u05d4 "
+ "\u05d0\u05d9\u05da \u05d4\u05e7\u05dc\u05d9\u05d8\u05d4");
// Thai
assertRoundTrips(
" \u0e08\u0e07\u0e1d\u0e48\u0e32\u0e1f\u0e31\u0e19\u0e1e\u0e31\u0e12"
+ "\u0e19\u0e32\u0e27\u0e34\u0e0a\u0e32\u0e01\u0e32\u0e23");
// Chinese
assertRoundTrips(
"\u8fd4\u56de\u94fe\u4e2d\u7684\u4e0b\u4e00\u4e2a\u4ee3\u7406\u9879\u9009\u62e9\u5668");
// Chinese with 4-byte chars
assertRoundTrips("\uD841\uDF0E\uD841\uDF31\uD841\uDF79\uD843\uDC53\uD843\uDC78"
+ "\uD843\uDC96\uD843\uDCCF\uD843\uDCD5\uD843\uDD15\uD843\uDD7C\uD843\uDD7F"
+ "\uD843\uDE0E\uD843\uDE0F\uD843\uDE77\uD843\uDE9D\uD843\uDEA2");
// Mixed
assertRoundTrips(
"The quick brown \u3044\u308d\u306f\u306b\u307b\u3078\u8fd4\u56de\u94fe"
+ "\u4e2d\u7684\u4e0b\u4e00");
}
public void testOverlong() throws Exception {
assertInvalid(0xc0, 0xaf);
assertInvalid(0xe0, 0x80, 0xaf);
assertInvalid(0xf0, 0x80, 0x80, 0xaf);
// Max overlong
assertInvalid(0xc1, 0xbf);
assertInvalid(0xe0, 0x9f, 0xbf);
assertInvalid(0xf0 ,0x8f, 0xbf, 0xbf);
// null overlong
assertInvalid(0xc0, 0x80);
assertInvalid(0xe0, 0x80, 0x80);
assertInvalid(0xf0, 0x80, 0x80, 0x80);
}
public void testIllegalCodepoints() throws Exception {
// Single surrogate
assertInvalid(0xed, 0xa0, 0x80);
assertInvalid(0xed, 0xad, 0xbf);
assertInvalid(0xed, 0xae, 0x80);
assertInvalid(0xed, 0xaf, 0xbf);
assertInvalid(0xed, 0xb0, 0x80);
assertInvalid(0xed, 0xbe, 0x80);
assertInvalid(0xed, 0xbf, 0xbf);
// Paired surrogates
assertInvalid(0xed, 0xa0, 0x80, 0xed, 0xb0, 0x80);
assertInvalid(0xed, 0xa0, 0x80, 0xed, 0xbf, 0xbf);
assertInvalid(0xed, 0xad, 0xbf, 0xed, 0xb0, 0x80);
assertInvalid(0xed, 0xad, 0xbf, 0xed, 0xbf, 0xbf);
assertInvalid(0xed, 0xae, 0x80, 0xed, 0xb0, 0x80);
assertInvalid(0xed, 0xae, 0x80, 0xed, 0xbf, 0xbf);
assertInvalid(0xed, 0xaf, 0xbf, 0xed, 0xb0, 0x80);
assertInvalid(0xed, 0xaf, 0xbf, 0xed, 0xbf, 0xbf);
}
public void testBufferSlice() throws Exception {
String str = "The quick brown fox jumps over the lazy dog";
assertRoundTrips(str, 10, 4);
assertRoundTrips(str, str.length(), 0);
}
public void testInvalidBufferSlice() throws Exception {
byte[] bytes = "The quick brown fox jumps over the lazy dog".getBytes(Internal.UTF_8);
assertInvalidSlice(bytes, bytes.length - 3, 4);
assertInvalidSlice(bytes, bytes.length, 1);
assertInvalidSlice(bytes, bytes.length + 1, 0);
assertInvalidSlice(bytes, 0, bytes.length + 1);
}
private void assertInvalid(int... bytesAsInt) throws Exception {
byte[] bytes = new byte[bytesAsInt.length];
for (int i = 0; i < bytesAsInt.length; i++) {
bytes[i] = (byte) bytesAsInt[i];
}
assertInvalid(bytes);
}
private void assertInvalid(byte[] bytes) throws Exception {
try {
UNSAFE_PROCESSOR.decodeUtf8(bytes, 0, bytes.length);
fail();
} catch (InvalidProtocolBufferException e) {
// Expected.
}
try {
SAFE_PROCESSOR.decodeUtf8(bytes, 0, bytes.length);
fail();
} catch (InvalidProtocolBufferException e) {
// Expected.
}
ByteBuffer direct = ByteBuffer.allocateDirect(bytes.length);
direct.put(bytes);
direct.flip();
try {
UNSAFE_PROCESSOR.decodeUtf8(direct, 0, bytes.length);
fail();
} catch (InvalidProtocolBufferException e) {
// Expected.
}
try {
SAFE_PROCESSOR.decodeUtf8(direct, 0, bytes.length);
fail();
} catch (InvalidProtocolBufferException e) {
// Expected.
}
ByteBuffer heap = ByteBuffer.allocate(bytes.length);
heap.put(bytes);
heap.flip();
try {
UNSAFE_PROCESSOR.decodeUtf8(heap, 0, bytes.length);
fail();
} catch (InvalidProtocolBufferException e) {
// Expected.
}
try {
SAFE_PROCESSOR.decodeUtf8(heap, 0, bytes.length);
fail();
} catch (InvalidProtocolBufferException e) {
// Expected.
}
}
private void assertInvalidSlice(byte[] bytes, int index, int size) throws Exception {
try {
UNSAFE_PROCESSOR.decodeUtf8(bytes, index, size);
fail();
} catch (ArrayIndexOutOfBoundsException e) {
// Expected.
}
try {
SAFE_PROCESSOR.decodeUtf8(bytes, index, size);
fail();
} catch (ArrayIndexOutOfBoundsException e) {
// Expected.
}
ByteBuffer direct = ByteBuffer.allocateDirect(bytes.length);
direct.put(bytes);
direct.flip();
try {
UNSAFE_PROCESSOR.decodeUtf8(direct, index, size);
fail();
} catch (ArrayIndexOutOfBoundsException e) {
// Expected.
}
try {
SAFE_PROCESSOR.decodeUtf8(direct, index, size);
fail();
} catch (ArrayIndexOutOfBoundsException e) {
// Expected.
}
ByteBuffer heap = ByteBuffer.allocate(bytes.length);
heap.put(bytes);
heap.flip();
try {
UNSAFE_PROCESSOR.decodeUtf8(heap, index, size);
fail();
} catch (ArrayIndexOutOfBoundsException e) {
// Expected.
}
try {
SAFE_PROCESSOR.decodeUtf8(heap, index, size);
fail();
} catch (ArrayIndexOutOfBoundsException e) {
// Expected.
}
}
private void assertRoundTrips(String str) throws Exception {
assertRoundTrips(str, 0, -1);
}
private void assertRoundTrips(String str, int index, int size) throws Exception {
byte[] bytes = str.getBytes(Internal.UTF_8);
if (size == -1) {
size = bytes.length;
}
assertDecode(new String(bytes, index, size, Internal.UTF_8),
UNSAFE_PROCESSOR.decodeUtf8(bytes, index, size));
assertDecode(new String(bytes, index, size, Internal.UTF_8),
SAFE_PROCESSOR.decodeUtf8(bytes, index, size));
ByteBuffer direct = ByteBuffer.allocateDirect(bytes.length);
direct.put(bytes);
direct.flip();
assertDecode(new String(bytes, index, size, Internal.UTF_8),
UNSAFE_PROCESSOR.decodeUtf8(direct, index, size));
assertDecode(new String(bytes, index, size, Internal.UTF_8),
SAFE_PROCESSOR.decodeUtf8(direct, index, size));
ByteBuffer heap = ByteBuffer.allocate(bytes.length);
heap.put(bytes);
heap.flip();
assertDecode(new String(bytes, index, size, Internal.UTF_8),
UNSAFE_PROCESSOR.decodeUtf8(heap, index, size));
assertDecode(new String(bytes, index, size, Internal.UTF_8),
SAFE_PROCESSOR.decodeUtf8(heap, index, size));
}
private void assertDecode(String expected, String actual) {
if (!expected.equals(actual)) {
fail("Failure: Expected (" + codepoints(expected) + ") Actual (" + codepoints(actual) + ")");
}
}
private List<String> codepoints(String str) {
List<String> codepoints = new ArrayList<String>();
for (int i = 0; i < str.length(); i++) {
codepoints.add(Long.toHexString(str.charAt(i)));
}
return codepoints;
}
}
......@@ -273,6 +273,15 @@ final class IsValidUtf8TestUtil {
assertEquals(isRoundTrippable, Utf8.isValidUtf8(bytes));
assertEquals(isRoundTrippable, Utf8.isValidUtf8(bytes, 0, numBytes));
try {
assertEquals(s, Utf8.decodeUtf8(bytes, 0, numBytes));
} catch (InvalidProtocolBufferException e) {
if (isRoundTrippable) {
System.out.println("Could not decode utf-8");
outputFailure(byteChar, bytes, bytesReencoded);
}
}
// Test partial sequences.
// Partition numBytes into three segments (not necessarily non-empty).
int i = rnd.nextInt(numBytes);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment