Commit 79b311c1 authored by Max Cai's avatar Max Cai

Correctness: floating point equality using bits instead of ==.

Special values for float and double make it inaccurate to test the equality with ==.
The main Java library uses the standard Object.equals() implementation for all fields,
which for floating point fields means Float.equals() or Double.equals(). They define
equality as bitwise equality, with all NaN representations normalized to the same bit
sequence (and therefore equal to each other). This test checks that the nano
implementation complies with Object.equals(), so NaN == NaN and +0.0 != -0.0.

Change-Id: I97bb4a3687223d8a212c70cd736436b9dd80c1d7
parent 1b1735ce
......@@ -2886,13 +2886,6 @@ public class NanoTest extends TestCase {
TestAllTypesNano.BAR,
TestAllTypesNano.BAZ
};
// We set the _nan fields to something other than nan, because equality
// is defined for nan such that Float.NaN != Float.NaN, which makes any
// instance of TestAllTypesNano unequal to any other instance unless
// these fields are set. This is also the behavior of the regular java
// generator when the value of a field is NaN.
message.defaultFloatNan = 1.0f;
message.defaultDoubleNan = 1.0;
return message;
}
......@@ -2915,7 +2908,6 @@ public class NanoTest extends TestCase {
TestAllTypesNano.BAR,
TestAllTypesNano.BAZ
};
message.defaultFloatNan = 1.0f;
return message;
}
......@@ -2924,8 +2916,7 @@ public class NanoTest extends TestCase {
.setOptionalInt32(5)
.setOptionalString("Hello")
.setOptionalBytes(new byte[] {1, 2, 3})
.setOptionalNestedEnum(TestNanoAccessors.BAR)
.setDefaultFloatNan(1.0f);
.setOptionalNestedEnum(TestNanoAccessors.BAR);
message.optionalNestedMessage = new TestNanoAccessors.NestedMessage().setBb(27);
message.repeatedInt32 = new int[] { 5, 6, 7, 8 };
message.repeatedString = new String[] { "One", "Two" };
......@@ -2973,6 +2964,126 @@ public class NanoTest extends TestCase {
return message;
}
public void testEqualsWithSpecialFloatingPointValues() throws Exception {
// Checks that the nano implementation complies with Object.equals() when treating
// floating point numbers, i.e. NaN == NaN and +0.0 != -0.0.
// This test assumes that the generated equals() implementations are symmetric, so
// there will only be one direction for each equality check.
TestAllTypesNano m1 = new TestAllTypesNano();
m1.optionalFloat = Float.NaN;
m1.optionalDouble = Double.NaN;
TestAllTypesNano m2 = new TestAllTypesNano();
m2.optionalFloat = Float.NaN;
m2.optionalDouble = Double.NaN;
assertTrue(m1.equals(m2));
assertTrue(m1.equals(
MessageNano.mergeFrom(new TestAllTypesNano(), MessageNano.toByteArray(m1))));
m1.optionalFloat = +0f;
m2.optionalFloat = -0f;
assertFalse(m1.equals(m2));
m1.optionalFloat = -0f;
m1.optionalDouble = +0d;
m2.optionalDouble = -0d;
assertFalse(m1.equals(m2));
m1.optionalDouble = -0d;
assertTrue(m1.equals(m2));
assertFalse(m1.equals(new TestAllTypesNano())); // -0 does not equals() the default +0
assertTrue(m1.equals(
MessageNano.mergeFrom(new TestAllTypesNano(), MessageNano.toByteArray(m1))));
// -------
TestAllTypesNanoHas m3 = new TestAllTypesNanoHas();
m3.optionalFloat = Float.NaN;
m3.hasOptionalFloat = true;
m3.optionalDouble = Double.NaN;
m3.hasOptionalDouble = true;
TestAllTypesNanoHas m4 = new TestAllTypesNanoHas();
m4.optionalFloat = Float.NaN;
m4.hasOptionalFloat = true;
m4.optionalDouble = Double.NaN;
m4.hasOptionalDouble = true;
assertTrue(m3.equals(m4));
assertTrue(m3.equals(
MessageNano.mergeFrom(new TestAllTypesNanoHas(), MessageNano.toByteArray(m3))));
m3.optionalFloat = +0f;
m4.optionalFloat = -0f;
assertFalse(m3.equals(m4));
m3.optionalFloat = -0f;
m3.optionalDouble = +0d;
m4.optionalDouble = -0d;
assertFalse(m3.equals(m4));
m3.optionalDouble = -0d;
m3.hasOptionalFloat = false; // -0 does not equals() the default +0,
m3.hasOptionalDouble = false; // so these incorrect 'has' flags should be disregarded.
assertTrue(m3.equals(m4)); // note: m4 has the 'has' flags set.
assertFalse(m3.equals(new TestAllTypesNanoHas())); // note: the new message has +0 defaults
assertTrue(m3.equals(
MessageNano.mergeFrom(new TestAllTypesNanoHas(), MessageNano.toByteArray(m3))));
// note: the deserialized message has the 'has' flags set.
// -------
TestNanoAccessors m5 = new TestNanoAccessors();
m5.setOptionalFloat(Float.NaN);
m5.setOptionalDouble(Double.NaN);
TestNanoAccessors m6 = new TestNanoAccessors();
m6.setOptionalFloat(Float.NaN);
m6.setOptionalDouble(Double.NaN);
assertTrue(m5.equals(m6));
assertTrue(m5.equals(
MessageNano.mergeFrom(new TestNanoAccessors(), MessageNano.toByteArray(m6))));
m5.setOptionalFloat(+0f);
m6.setOptionalFloat(-0f);
assertFalse(m5.equals(m6));
m5.setOptionalFloat(-0f);
m5.setOptionalDouble(+0d);
m6.setOptionalDouble(-0d);
assertFalse(m5.equals(m6));
m5.setOptionalDouble(-0d);
assertTrue(m5.equals(m6));
assertFalse(m5.equals(new TestNanoAccessors()));
assertTrue(m5.equals(
MessageNano.mergeFrom(new TestNanoAccessors(), MessageNano.toByteArray(m6))));
// -------
NanoReferenceTypes.TestAllTypesNano m7 = new NanoReferenceTypes.TestAllTypesNano();
m7.optionalFloat = Float.NaN;
m7.optionalDouble = Double.NaN;
NanoReferenceTypes.TestAllTypesNano m8 = new NanoReferenceTypes.TestAllTypesNano();
m8.optionalFloat = Float.NaN;
m8.optionalDouble = Double.NaN;
assertTrue(m7.equals(m8));
assertTrue(m7.equals(MessageNano.mergeFrom(
new NanoReferenceTypes.TestAllTypesNano(), MessageNano.toByteArray(m7))));
m7.optionalFloat = +0f;
m8.optionalFloat = -0f;
assertFalse(m7.equals(m8));
m7.optionalFloat = -0f;
m7.optionalDouble = +0d;
m8.optionalDouble = -0d;
assertFalse(m7.equals(m8));
m7.optionalDouble = -0d;
assertTrue(m7.equals(m8));
assertFalse(m7.equals(new NanoReferenceTypes.TestAllTypesNano()));
assertTrue(m7.equals(MessageNano.mergeFrom(
new NanoReferenceTypes.TestAllTypesNano(), MessageNano.toByteArray(m7))));
}
public void testNullRepeatedFields() throws Exception {
// Check that serialization after explicitly setting a repeated field
// to null doesn't NPE.
......
......@@ -175,38 +175,6 @@ int FixedSize(FieldDescriptor::Type type) {
return -1;
}
// Returns true if the field has a default value equal to NaN.
bool IsDefaultNaN(const FieldDescriptor* field) {
switch (field->type()) {
case FieldDescriptor::TYPE_INT32 : return false;
case FieldDescriptor::TYPE_UINT32 : return false;
case FieldDescriptor::TYPE_SINT32 : return false;
case FieldDescriptor::TYPE_FIXED32 : return false;
case FieldDescriptor::TYPE_SFIXED32: return false;
case FieldDescriptor::TYPE_INT64 : return false;
case FieldDescriptor::TYPE_UINT64 : return false;
case FieldDescriptor::TYPE_SINT64 : return false;
case FieldDescriptor::TYPE_FIXED64 : return false;
case FieldDescriptor::TYPE_SFIXED64: return false;
case FieldDescriptor::TYPE_FLOAT :
return isnan(field->default_value_float());
case FieldDescriptor::TYPE_DOUBLE :
return isnan(field->default_value_double());
case FieldDescriptor::TYPE_BOOL : return false;
case FieldDescriptor::TYPE_STRING : return false;
case FieldDescriptor::TYPE_BYTES : return false;
case FieldDescriptor::TYPE_ENUM : return false;
case FieldDescriptor::TYPE_GROUP : return false;
case FieldDescriptor::TYPE_MESSAGE : return false;
// No default because we want the compiler to complain if any new
// types are added.
}
GOOGLE_LOG(FATAL) << "Can't get here.";
return false;
}
// Return true if the type is a that has variable length
// for instance String's.
bool IsVariableLenType(JavaType type) {
......@@ -384,15 +352,21 @@ GenerateSerializationConditional(io::Printer* printer) const {
printer->Print(variables_,
"if (");
}
if (IsArrayType(GetJavaType(descriptor_))) {
JavaType java_type = GetJavaType(descriptor_);
if (IsArrayType(java_type)) {
printer->Print(variables_,
"!java.util.Arrays.equals(this.$name$, $default$)) {\n");
} else if (IsReferenceType(GetJavaType(descriptor_))) {
} else if (IsReferenceType(java_type)) {
printer->Print(variables_,
"!this.$name$.equals($default$)) {\n");
} else if (IsDefaultNaN(descriptor_)) {
} else if (java_type == JAVATYPE_FLOAT) {
printer->Print(variables_,
"java.lang.Float.floatToIntBits(this.$name$)\n"
" != java.lang.Float.floatToIntBits($default$)) {\n");
} else if (java_type == JAVATYPE_DOUBLE) {
printer->Print(variables_,
"!$capitalized_type$.isNaN(this.$name$)) {\n");
"java.lang.Double.doubleToLongBits(this.$name$)\n"
" != java.lang.Double.doubleToLongBits($default$)) {\n");
} else {
printer->Print(variables_,
"this.$name$ != $default$) {\n");
......@@ -464,6 +438,36 @@ GenerateEqualsCode(io::Printer* printer) const {
printer->Print(") {\n"
" return false;\n"
"}\n");
} else if (java_type == JAVATYPE_FLOAT) {
printer->Print(variables_,
"{\n"
" int bits = java.lang.Float.floatToIntBits(this.$name$);\n"
" if (bits != java.lang.Float.floatToIntBits(other.$name$)");
if (params_.generate_has()) {
printer->Print(variables_,
"\n"
" || (bits == java.lang.Float.floatToIntBits($default$)\n"
" && this.has$capitalized_name$ != other.has$capitalized_name$)");
}
printer->Print(") {\n"
" return false;\n"
" }\n"
"}\n");
} else if (java_type == JAVATYPE_DOUBLE) {
printer->Print(variables_,
"{\n"
" long bits = java.lang.Double.doubleToLongBits(this.$name$);\n"
" if (bits != java.lang.Double.doubleToLongBits(other.$name$)");
if (params_.generate_has()) {
printer->Print(variables_,
"\n"
" || (bits == java.lang.Double.doubleToLongBits($default$)\n"
" && this.has$capitalized_name$ != other.has$capitalized_name$)");
}
printer->Print(") {\n"
" return false;\n"
" }\n"
"}\n");
} else {
printer->Print(variables_,
"if (this.$name$ != other.$name$");
......@@ -623,12 +627,26 @@ GenerateSerializedSizeCode(io::Printer* printer) const {
void AccessorPrimitiveFieldGenerator::
GenerateEqualsCode(io::Printer* printer) const {
switch (GetJavaType(descriptor_)) {
// For all Java primitive types below, the hash codes match the
// results of BoxedType.valueOf(primitiveValue).hashCode().
case JAVATYPE_INT:
case JAVATYPE_LONG:
// For all Java primitive types below, the equality checks match the
// results of BoxedType.valueOf(primitiveValue).equals(otherValue).
case JAVATYPE_FLOAT:
printer->Print(variables_,
"if ($different_has$\n"
" || java.lang.Float.floatToIntBits($name$_)\n"
" != java.lang.Float.floatToIntBits(other.$name$_)) {\n"
" return false;\n"
"}\n");
break;
case JAVATYPE_DOUBLE:
printer->Print(variables_,
"if ($different_has$\n"
" || java.lang.Double.doubleToLongBits($name$_)\n"
" != java.lang.Double.doubleToLongBits(other.$name$_)) {\n"
" return false;\n"
"}\n");
break;
case JAVATYPE_INT:
case JAVATYPE_LONG:
case JAVATYPE_BOOLEAN:
printer->Print(variables_,
"if ($different_has$\n"
......
......@@ -49,6 +49,8 @@ message TestNanoAccessors {
// Singular
optional int32 optional_int32 = 1;
optional float optional_float = 11;
optional double optional_double = 12;
optional string optional_string = 14;
optional bytes optional_bytes = 15;
......
......@@ -49,6 +49,8 @@ message TestAllTypesNanoHas {
// Singular
optional int32 optional_int32 = 1;
optional float optional_float = 11;
optional double optional_double = 12;
optional string optional_string = 14;
optional bytes optional_bytes = 15;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment