Commit 885f959b authored by Max Cai's avatar Max Cai Committed by Gerrit Code Review

Merge "Implement hashCode() and equals() behind a generator option."

parents 461d4ac8 56a37328
......@@ -31,6 +31,7 @@
package com.google.protobuf.nano;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
/**
* The classes contained within are used internally by the Protocol Buffer
......@@ -40,7 +41,10 @@ import java.io.UnsupportedEncodingException;
*
* @author kenton@google.com (Kenton Varda)
*/
public class InternalNano {
public final class InternalNano {
private InternalNano() {}
/**
* Helper called by generated code to construct default values for string
* fields.
......@@ -69,7 +73,7 @@ public class InternalNano {
* converts from the generated string to the string we actually want. The
* generated code calls this automatically.
*/
public static final String stringDefaultValue(String bytes) {
public static String stringDefaultValue(String bytes) {
try {
return new String(bytes.getBytes("ISO-8859-1"), "UTF-8");
} catch (UnsupportedEncodingException e) {
......@@ -88,7 +92,7 @@ public class InternalNano {
* In this case we only need the second of the two hacks -- allowing us to
* embed raw bytes as a string literal with ISO-8859-1 encoding.
*/
public static final byte[] bytesDefaultValue(String bytes) {
public static byte[] bytesDefaultValue(String bytes) {
try {
return bytes.getBytes("ISO-8859-1");
} catch (UnsupportedEncodingException e) {
......@@ -103,11 +107,215 @@ public class InternalNano {
* Helper function to convert a string into UTF-8 while turning the
* UnsupportedEncodingException to a RuntimeException.
*/
public static final byte[] copyFromUtf8(final String text) {
public static byte[] copyFromUtf8(final String text) {
try {
return text.getBytes("UTF-8");
} catch (UnsupportedEncodingException e) {
throw new RuntimeException("UTF-8 not supported?");
}
}
/**
* Checks repeated int field equality; null-value and 0-length fields are
* considered equal.
*/
public static boolean equals(int[] field1, int[] field2) {
if (field1 == null || field1.length == 0) {
return field2 == null || field2.length == 0;
} else {
return Arrays.equals(field1, field2);
}
}
/**
* Checks repeated long field equality; null-value and 0-length fields are
* considered equal.
*/
public static boolean equals(long[] field1, long[] field2) {
if (field1 == null || field1.length == 0) {
return field2 == null || field2.length == 0;
} else {
return Arrays.equals(field1, field2);
}
}
/**
* Checks repeated float field equality; null-value and 0-length fields are
* considered equal.
*/
public static boolean equals(float[] field1, float[] field2) {
if (field1 == null || field1.length == 0) {
return field2 == null || field2.length == 0;
} else {
return Arrays.equals(field1, field2);
}
}
/**
* Checks repeated double field equality; null-value and 0-length fields are
* considered equal.
*/
public static boolean equals(double[] field1, double[] field2) {
if (field1 == null || field1.length == 0) {
return field2 == null || field2.length == 0;
} else {
return Arrays.equals(field1, field2);
}
}
/**
* Checks repeated boolean field equality; null-value and 0-length fields are
* considered equal.
*/
public static boolean equals(boolean[] field1, boolean[] field2) {
if (field1 == null || field1.length == 0) {
return field2 == null || field2.length == 0;
} else {
return Arrays.equals(field1, field2);
}
}
/**
* Checks repeated bytes field equality. Only non-null elements are tested.
* Returns true if the two fields have the same sequence of non-null
* elements. Null-value fields and fields of any length with only null
* elements are considered equal.
*/
public static boolean equals(byte[][] field1, byte[][] field2) {
int index1 = 0;
int length1 = field1 == null ? 0 : field1.length;
int index2 = 0;
int length2 = field2 == null ? 0 : field2.length;
while (true) {
while (index1 < length1 && field1[index1] == null) {
index1++;
}
while (index2 < length2 && field2[index2] == null) {
index2++;
}
boolean atEndOf1 = index1 >= length1;
boolean atEndOf2 = index2 >= length2;
if (atEndOf1 && atEndOf2) {
// no more non-null elements to test in both arrays
return true;
} else if (atEndOf1 != atEndOf2) {
// one of the arrays have extra non-null elements
return false;
} else if (!Arrays.equals(field1[index1], field2[index2])) {
// element mismatch
return false;
}
index1++;
index2++;
}
}
/**
* Checks repeated string/message field equality. Only non-null elements are
* tested. Returns true if the two fields have the same sequence of non-null
* elements. Null-value fields and fields of any length with only null
* elements are considered equal.
*/
public static boolean equals(Object[] field1, Object[] field2) {
int index1 = 0;
int length1 = field1 == null ? 0 : field1.length;
int index2 = 0;
int length2 = field2 == null ? 0 : field2.length;
while (true) {
while (index1 < length1 && field1[index1] == null) {
index1++;
}
while (index2 < length2 && field2[index2] == null) {
index2++;
}
boolean atEndOf1 = index1 >= length1;
boolean atEndOf2 = index2 >= length2;
if (atEndOf1 && atEndOf2) {
// no more non-null elements to test in both arrays
return true;
} else if (atEndOf1 != atEndOf2) {
// one of the arrays have extra non-null elements
return false;
} else if (!field1[index1].equals(field2[index2])) {
// element mismatch
return false;
}
index1++;
index2++;
}
}
/**
* Computes the hash code of a repeated int field. Null-value and 0-length
* fields have the same hash code.
*/
public static int hashCode(int[] field) {
return field == null || field.length == 0 ? 0 : Arrays.hashCode(field);
}
/**
* Computes the hash code of a repeated long field. Null-value and 0-length
* fields have the same hash code.
*/
public static int hashCode(long[] field) {
return field == null || field.length == 0 ? 0 : Arrays.hashCode(field);
}
/**
* Computes the hash code of a repeated float field. Null-value and 0-length
* fields have the same hash code.
*/
public static int hashCode(float[] field) {
return field == null || field.length == 0 ? 0 : Arrays.hashCode(field);
}
/**
* Computes the hash code of a repeated double field. Null-value and 0-length
* fields have the same hash code.
*/
public static int hashCode(double[] field) {
return field == null || field.length == 0 ? 0 : Arrays.hashCode(field);
}
/**
* Computes the hash code of a repeated boolean field. Null-value and 0-length
* fields have the same hash code.
*/
public static int hashCode(boolean[] field) {
return field == null || field.length == 0 ? 0 : Arrays.hashCode(field);
}
/**
* Computes the hash code of a repeated bytes field. Only the sequence of all
* non-null elements are used in the computation. Null-value fields and fields
* of any length with only null elements have the same hash code.
*/
public static int hashCode(byte[][] field) {
int result = 0;
for (int i = 0, size = field == null ? 0 : field.length; i < size; i++) {
byte[] element = field[i];
if (element != null) {
result = 31 * result + Arrays.hashCode(element);
}
}
return result;
}
/**
* Computes the hash code of a repeated string/message field. Only the
* sequence of all non-null elements are used in the computation. Null-value
* fields and fields of any length with only null elements have the same hash
* code.
*/
public static int hashCode(Object[] field) {
int result = 0;
for (int i = 0, size = field == null ? 0 : field.length; i < size; i++) {
Object element = field[i];
if (element != null) {
result = 31 * result + element.hashCode();
}
}
return result;
}
}
......@@ -30,6 +30,8 @@
package com.google.protobuf.nano;
import java.util.Arrays;
/**
* Stores unknown fields. These might be extensions or fields that the generated API doesn't
* know about yet.
......@@ -37,6 +39,7 @@ package com.google.protobuf.nano;
* @author bduff@google.com (Brian Duff)
*/
public final class UnknownFieldData {
final int tag;
final byte[] bytes;
......@@ -44,4 +47,25 @@ public final class UnknownFieldData {
this.tag = tag;
this.bytes = bytes;
}
@Override
public boolean equals(Object o) {
if (o == this) {
return true;
}
if (!(o instanceof UnknownFieldData)) {
return false;
}
UnknownFieldData other = (UnknownFieldData) o;
return tag == other.tag && Arrays.equals(bytes, other.bytes);
}
@Override
public int hashCode() {
int result = 17;
result = 31 * result + tag;
result = 31 * result + Arrays.hashCode(bytes);
return result;
}
}
......@@ -60,6 +60,7 @@ import junit.framework.TestCase;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
/**
......@@ -2684,6 +2685,224 @@ public class NanoTest extends TestCase {
assertHasWireData(message, false);
}
public void testHashCodeEquals() throws Exception {
// Complete equality:
TestAllTypesNano a = createMessageForHashCodeEqualsTest();
TestAllTypesNano aEquivalent = createMessageForHashCodeEqualsTest();
// Null and empty array for repeated fields equality:
TestAllTypesNano b = createMessageForHashCodeEqualsTest();
b.repeatedBool = null;
b.repeatedFloat = new float[0];
TestAllTypesNano bEquivalent = createMessageForHashCodeEqualsTest();
bEquivalent.repeatedBool = new boolean[0];
bEquivalent.repeatedFloat = null;
// Ref-element-type repeated fields use non-null subsequence equality:
TestAllTypesNano c = createMessageForHashCodeEqualsTest();
c.repeatedString = null;
c.repeatedStringPiece = new String[] {null, "one", null, "two"};
c.repeatedBytes = new byte[][] {{3, 4}, null};
TestAllTypesNano cEquivalent = createMessageForHashCodeEqualsTest();
cEquivalent.repeatedString = new String[3];
cEquivalent.repeatedStringPiece = new String[] {"one", "two", null};
cEquivalent.repeatedBytes = new byte[][] {{3, 4}};
// Complete equality for messages with has fields:
TestAllTypesNanoHas d = createMessageWithHasForHashCodeEqualsTest();
TestAllTypesNanoHas dEquivalent = createMessageWithHasForHashCodeEqualsTest();
// If has-fields exist, fields with the same default values but
// different has-field values are different.
TestAllTypesNanoHas e = createMessageWithHasForHashCodeEqualsTest();
e.optionalInt32++; // make different from d
e.hasDefaultString = false;
TestAllTypesNanoHas eDifferent = createMessageWithHasForHashCodeEqualsTest();
eDifferent.optionalInt32 = e.optionalInt32;
eDifferent.hasDefaultString = true;
// Complete equality for messages with accessors:
TestNanoAccessors f = createMessageWithAccessorsForHashCodeEqualsTest();
TestNanoAccessors fEquivalent = createMessageWithAccessorsForHashCodeEqualsTest();
System.out.println("equals: " + f.equals(fEquivalent));
System.out.println("hashCode: " + f.hashCode() + " vs " + fEquivalent.hashCode());
// If using accessors, explicitly setting a field to its default value
// should make the message different.
TestNanoAccessors g = createMessageWithAccessorsForHashCodeEqualsTest();
g.setOptionalInt32(g.getOptionalInt32() + 1); // make different from f
g.clearDefaultString();
TestNanoAccessors gDifferent = createMessageWithAccessorsForHashCodeEqualsTest();
gDifferent.setOptionalInt32(g.getOptionalInt32());
gDifferent.setDefaultString(g.getDefaultString());
// Complete equality for reference typed messages:
NanoReferenceTypes.TestAllTypesNano h = createRefTypedMessageForHashCodeEqualsTest();
NanoReferenceTypes.TestAllTypesNano hEquivalent = createRefTypedMessageForHashCodeEqualsTest();
// Inequality of null and default value for reference typed messages:
NanoReferenceTypes.TestAllTypesNano i = createRefTypedMessageForHashCodeEqualsTest();
i.optionalInt32 = 1; // make different from h
i.optionalFloat = null;
NanoReferenceTypes.TestAllTypesNano iDifferent = createRefTypedMessageForHashCodeEqualsTest();
iDifferent.optionalInt32 = i.optionalInt32;
iDifferent.optionalFloat = 0.0f;
HashMap<MessageNano, String> hashMap = new HashMap<MessageNano, String>();
hashMap.put(a, "a");
hashMap.put(b, "b");
hashMap.put(c, "c");
hashMap.put(d, "d");
hashMap.put(e, "e");
hashMap.put(f, "f");
hashMap.put(g, "g");
hashMap.put(h, "h");
hashMap.put(i, "i");
assertEquals(9, hashMap.size()); // a-i should be different from each other.
assertEquals("a", hashMap.get(a));
assertEquals("a", hashMap.get(aEquivalent));
assertEquals("b", hashMap.get(b));
assertEquals("b", hashMap.get(bEquivalent));
assertEquals("c", hashMap.get(c));
assertEquals("c", hashMap.get(cEquivalent));
assertEquals("d", hashMap.get(d));
assertEquals("d", hashMap.get(dEquivalent));
assertEquals("e", hashMap.get(e));
assertNull(hashMap.get(eDifferent));
assertEquals("f", hashMap.get(f));
assertEquals("f", hashMap.get(fEquivalent));
assertEquals("g", hashMap.get(g));
assertNull(hashMap.get(gDifferent));
assertEquals("h", hashMap.get(h));
assertEquals("h", hashMap.get(hEquivalent));
assertEquals("i", hashMap.get(i));
assertNull(hashMap.get(iDifferent));
}
private TestAllTypesNano createMessageForHashCodeEqualsTest() {
TestAllTypesNano message = new TestAllTypesNano();
message.optionalInt32 = 5;
message.optionalInt64 = 777;
message.optionalFloat = 1.0f;
message.optionalDouble = 2.0;
message.optionalBool = true;
message.optionalString = "Hello";
message.optionalBytes = new byte[] { 1, 2, 3 };
message.optionalNestedMessage = new TestAllTypesNano.NestedMessage();
message.optionalNestedMessage.bb = 27;
message.optionalNestedEnum = TestAllTypesNano.BAR;
message.repeatedInt32 = new int[] { 5, 6, 7, 8 };
message.repeatedInt64 = new long[] { 27L, 28L, 29L };
message.repeatedFloat = new float[] { 5.0f, 6.0f };
message.repeatedDouble = new double[] { 99.1, 22.5 };
message.repeatedBool = new boolean[] { true, false, true };
message.repeatedString = new String[] { "One", "Two" };
message.repeatedBytes = new byte[][] { { 2, 7 }, { 2, 7 } };
message.repeatedNestedMessage = new TestAllTypesNano.NestedMessage[] {
message.optionalNestedMessage,
message.optionalNestedMessage
};
message.repeatedNestedEnum = new int[] {
TestAllTypesNano.BAR,
TestAllTypesNano.BAZ
};
// We set the _nan fields to something other than nan, because equality
// is defined for nan such that Float.NaN != Float.NaN, which makes any
// instance of TestAllTypesNano unequal to any other instance unless
// these fields are set. This is also the behavior of the regular java
// generator when the value of a field is NaN.
message.defaultFloatNan = 1.0f;
message.defaultDoubleNan = 1.0;
return message;
}
private TestAllTypesNanoHas createMessageWithHasForHashCodeEqualsTest() {
TestAllTypesNanoHas message = new TestAllTypesNanoHas();
message.optionalInt32 = 5;
message.optionalString = "Hello";
message.optionalBytes = new byte[] { 1, 2, 3 };
message.optionalNestedMessage = new TestAllTypesNanoHas.NestedMessage();
message.optionalNestedMessage.bb = 27;
message.optionalNestedEnum = TestAllTypesNano.BAR;
message.repeatedInt32 = new int[] { 5, 6, 7, 8 };
message.repeatedString = new String[] { "One", "Two" };
message.repeatedBytes = new byte[][] { { 2, 7 }, { 2, 7 } };
message.repeatedNestedMessage = new TestAllTypesNanoHas.NestedMessage[] {
message.optionalNestedMessage,
message.optionalNestedMessage
};
message.repeatedNestedEnum = new int[] {
TestAllTypesNano.BAR,
TestAllTypesNano.BAZ
};
message.defaultFloatNan = 1.0f;
return message;
}
private TestNanoAccessors createMessageWithAccessorsForHashCodeEqualsTest() {
TestNanoAccessors message = new TestNanoAccessors()
.setOptionalInt32(5)
.setOptionalString("Hello")
.setOptionalBytes(new byte[] {1, 2, 3})
.setOptionalNestedMessage(new TestNanoAccessors.NestedMessage().setBb(27))
.setOptionalNestedEnum(TestNanoAccessors.BAR)
.setDefaultFloatNan(1.0f);
message.repeatedInt32 = new int[] { 5, 6, 7, 8 };
message.repeatedString = new String[] { "One", "Two" };
message.repeatedBytes = new byte[][] { { 2, 7 }, { 2, 7 } };
message.repeatedNestedMessage = new TestNanoAccessors.NestedMessage[] {
message.getOptionalNestedMessage(),
message.getOptionalNestedMessage()
};
message.repeatedNestedEnum = new int[] {
TestAllTypesNano.BAR,
TestAllTypesNano.BAZ
};
return message;
}
private NanoReferenceTypes.TestAllTypesNano createRefTypedMessageForHashCodeEqualsTest() {
NanoReferenceTypes.TestAllTypesNano message = new NanoReferenceTypes.TestAllTypesNano();
message.optionalInt32 = 5;
message.optionalInt64 = 777L;
message.optionalFloat = 1.0f;
message.optionalDouble = 2.0;
message.optionalBool = true;
message.optionalString = "Hello";
message.optionalBytes = new byte[] { 1, 2, 3 };
message.optionalNestedMessage =
new NanoReferenceTypes.TestAllTypesNano.NestedMessage();
message.optionalNestedMessage.foo = 27;
message.optionalNestedEnum = NanoReferenceTypes.TestAllTypesNano.BAR;
message.repeatedInt32 = new int[] { 5, 6, 7, 8 };
message.repeatedInt64 = new long[] { 27L, 28L, 29L };
message.repeatedFloat = new float[] { 5.0f, 6.0f };
message.repeatedDouble = new double[] { 99.1, 22.5 };
message.repeatedBool = new boolean[] { true, false, true };
message.repeatedString = new String[] { "One", "Two" };
message.repeatedBytes = new byte[][] { { 2, 7 }, { 2, 7 } };
message.repeatedNestedMessage =
new NanoReferenceTypes.TestAllTypesNano.NestedMessage[] {
message.optionalNestedMessage,
message.optionalNestedMessage
};
message.repeatedNestedEnum = new int[] {
NanoReferenceTypes.TestAllTypesNano.BAR,
NanoReferenceTypes.TestAllTypesNano.BAZ
};
return message;
}
public void testNullRepeatedFields() throws Exception {
// Check that serialization after explicitly setting a repeated field
// to null doesn't NPE.
......
......@@ -159,8 +159,46 @@ GenerateSerializedSizeCode(io::Printer* printer) const {
}
}
string EnumFieldGenerator::GetBoxedType() const {
return ClassName(params_, descriptor_->enum_type());
void EnumFieldGenerator::GenerateEqualsCode(io::Printer* printer) const {
if (params_.use_reference_types_for_primitives()) {
printer->Print(variables_,
"if (this.$name$ == null) {\n"
" if (other.$name$ != null) {\n"
" return false;\n"
" }\n"
"} else if (!this.$name$.equals(other.$name$)) {\n"
" return false;"
"}\n");
} else {
// We define equality as serialized form equality. If generate_has(),
// then if the field value equals the default value in both messages,
// but one's 'has' field is set and the other's is not, the serialized
// forms are different and we should return false.
printer->Print(variables_,
"if (this.$name$ != other.$name$");
if (params_.generate_has()) {
printer->Print(variables_,
"\n"
" || (this.$name$ == $default$\n"
" && this.has$capitalized_name$ != other.has$capitalized_name$)");
}
printer->Print(") {\n"
" return false;\n"
"}\n");
}
}
void EnumFieldGenerator::GenerateHashCodeCode(io::Printer* printer) const {
printer->Print(
"result = 31 * result + ");
if (params_.use_reference_types_for_primitives()) {
printer->Print(variables_,
"(this.$name$ == null ? 0 : this.$name$)");
} else {
printer->Print(variables_,
"this.$name$");
}
printer->Print(";\n");
}
// ===================================================================
......@@ -227,8 +265,19 @@ GenerateSerializedSizeCode(io::Printer* printer) const {
"}\n");
}
string AccessorEnumFieldGenerator::GetBoxedType() const {
return ClassName(params_, descriptor_->enum_type());
void AccessorEnumFieldGenerator::
GenerateEqualsCode(io::Printer* printer) const {
printer->Print(variables_,
"if ($different_has$\n"
" || $name$_ != other.$name$_) {\n"
" return false;\n"
"}\n");
}
void AccessorEnumFieldGenerator::
GenerateHashCodeCode(io::Printer* printer) const {
printer->Print(variables_,
"result = 31 * result + $name$_;\n");
}
// ===================================================================
......@@ -366,8 +415,20 @@ GenerateSerializedSizeCode(io::Printer* printer) const {
}
}
string RepeatedEnumFieldGenerator::GetBoxedType() const {
return ClassName(params_, descriptor_->enum_type());
void RepeatedEnumFieldGenerator::
GenerateEqualsCode(io::Printer* printer) const {
printer->Print(variables_,
"if (!com.google.protobuf.nano.InternalNano.equals(\n"
" this.$name$, other.$name$)) {\n"
" return false;\n"
"}\n");
}
void RepeatedEnumFieldGenerator::
GenerateHashCodeCode(io::Printer* printer) const {
printer->Print(variables_,
"result = 31 * result\n"
" + com.google.protobuf.nano.InternalNano.hashCode(this.$name$);\n");
}
} // namespace javanano
......
......@@ -55,8 +55,8 @@ class EnumFieldGenerator : public FieldGenerator {
void GenerateMergingCode(io::Printer* printer) const;
void GenerateSerializationCode(io::Printer* printer) const;
void GenerateSerializedSizeCode(io::Printer* printer) const;
string GetBoxedType() const;
void GenerateEqualsCode(io::Printer* printer) const;
void GenerateHashCodeCode(io::Printer* printer) const;
private:
const FieldDescriptor* descriptor_;
......@@ -77,8 +77,8 @@ class AccessorEnumFieldGenerator : public FieldGenerator {
void GenerateMergingCode(io::Printer* printer) const;
void GenerateSerializationCode(io::Printer* printer) const;
void GenerateSerializedSizeCode(io::Printer* printer) const;
string GetBoxedType() const;
void GenerateEqualsCode(io::Printer* printer) const;
void GenerateHashCodeCode(io::Printer* printer) const;
private:
const FieldDescriptor* descriptor_;
......@@ -98,8 +98,8 @@ class RepeatedEnumFieldGenerator : public FieldGenerator {
void GenerateMergingCode(io::Printer* printer) const;
void GenerateSerializationCode(io::Printer* printer) const;
void GenerateSerializedSizeCode(io::Printer* printer) const;
string GetBoxedType() const;
void GenerateEqualsCode(io::Printer* printer) const;
void GenerateHashCodeCode(io::Printer* printer) const;
private:
const FieldDescriptor* descriptor_;
......
......@@ -62,8 +62,8 @@ class FieldGenerator {
virtual void GenerateMergingCode(io::Printer* printer) const = 0;
virtual void GenerateSerializationCode(io::Printer* printer) const = 0;
virtual void GenerateSerializedSizeCode(io::Printer* printer) const = 0;
virtual string GetBoxedType() const = 0;
virtual void GenerateEqualsCode(io::Printer* printer) const = 0;
virtual void GenerateHashCodeCode(io::Printer* printer) const = 0;
protected:
const Params& params_;
......
......@@ -125,6 +125,8 @@ bool JavaNanoGenerator::Generate(const FileDescriptor* file,
} else if (options[i].first == "optional_field_style") {
params.set_optional_field_accessors(options[i].second == "accessors");
params.set_use_reference_types_for_primitives(options[i].second == "reftypes");
} else if (options[i].first == "generate_equals") {
params.set_generate_equals(options[i].second == "true");
} else {
*error = "Ignore unknown javanano generator option: " + options[i].first;
}
......
......@@ -482,7 +482,7 @@ string GenerateGetBit(int bit_index) {
int bit_in_var_index = bit_index % 32;
string mask = kBitMasks[bit_in_var_index];
string result = "((" + var_name + " & " + mask + ") == " + mask + ")";
string result = "((" + var_name + " & " + mask + ") != 0)";
return result;
}
......@@ -504,11 +504,22 @@ string GenerateClearBit(int bit_index) {
return result;
}
string GenerateDifferentBit(int bit_index) {
string var_name = GetBitFieldNameForBit(bit_index);
int bit_in_var_index = bit_index % 32;
string mask = kBitMasks[bit_in_var_index];
string result = "((" + var_name + " & " + mask
+ ") != (other." + var_name + " & " + mask + "))";
return result;
}
void SetBitOperationVariables(const string name,
int bitIndex, map<string, string>* variables) {
(*variables)["get_" + name] = GenerateGetBit(bitIndex);
(*variables)["set_" + name] = GenerateSetBit(bitIndex);
(*variables)["clear_" + name] = GenerateClearBit(bitIndex);
(*variables)["different_" + name] = GenerateDifferentBit(bitIndex);
}
} // namespace javanano
......
......@@ -141,30 +141,37 @@ string DefaultValue(const Params& params, const FieldDescriptor* field);
// Methods for shared bitfields.
// Gets the name of the shared bitfield for the given index.
// Gets the name of the shared bitfield for the given field index.
string GetBitFieldName(int index);
// Gets the name of the shared bitfield for the given bit index.
// Effectively, GetBitFieldName(bit_index / 32)
string GetBitFieldNameForBit(int bit_index);
// Generates the java code for the expression that returns the boolean value
// of the bit of the shared bitfields for the given bit index.
// Example: "((bitField1_ & 0x04) == 0x04)"
// Generates the java code for the expression that returns whether the bit at
// the given bit index is set.
// Example: "((bitField1_ & 0x04000000) != 0)"
string GenerateGetBit(int bit_index);
// Generates the java code for the expression that sets the bit of the shared
// bitfields for the given bit index.
// Example: "bitField1_ = (bitField1_ | 0x04)"
// Generates the java code for the expression that sets the bit at the given
// bit index.
// Example: "bitField1_ |= 0x04000000"
string GenerateSetBit(int bit_index);
// Generates the java code for the expression that clears the bit of the shared
// bitfields for the given bit index.
// Example: "bitField1_ = (bitField1_ & ~0x04)"
// Generates the java code for the expression that clears the bit at the given
// bit index.
// Example: "bitField1_ = (bitField1_ & ~0x04000000)"
string GenerateClearBit(int bit_index);
// Sets the 'get_*', 'set_*' and 'clear_*' variables, where * is the given bit
// field name, to the appropriate Java expressions for the given bit index.
// Generates the java code for the expression that returns whether the bit at
// the given bit index contains different values in the current object and
// another object accessible via the variable 'other'.
// Example: "((bitField1_ & 0x04000000) != (other.bitField1_ & 0x04000000))"
string GenerateDifferentBit(int bit_index);
// Sets the 'get_*', 'set_*', 'clear_*' and 'different_*' variables, where * is
// the given name of the bit, to the appropriate Java expressions for the given
// bit index.
void SetBitOperationVariables(const string name,
int bitIndex, map<string, string>* variables);
......
......@@ -198,6 +198,11 @@ void MessageGenerator::Generate(io::Printer* printer) {
GenerateClear(printer);
if (params_.generate_equals()) {
GenerateEquals(printer);
GenerateHashCode(printer);
}
// If we have an extension range, generate accessors for extensions.
if (params_.store_unknown_fields()
&& descriptor_->extension_range_count() > 0) {
......@@ -326,11 +331,11 @@ void MessageGenerator::GenerateMergeFromMethods(io::Printer* printer) {
if (params_.store_unknown_fields()) {
printer->Print(
"if (unknownFieldData == null) {\n"
" unknownFieldData = \n"
" unknownFieldData =\n"
" new java.util.ArrayList<com.google.protobuf.nano.UnknownFieldData>();\n"
"}\n"
"if (!com.google.protobuf.nano.WireFormatNano.storeUnknownField(unknownFieldData, \n"
" input, tag)) {\n"
"if (!com.google.protobuf.nano.WireFormatNano.storeUnknownField(\n"
" unknownFieldData, input, tag)) {\n"
" return this;\n"
"}\n");
} else {
......@@ -427,6 +432,79 @@ void MessageGenerator::GenerateClear(io::Printer* printer) {
"}\n");
}
void MessageGenerator::GenerateEquals(io::Printer* printer) {
// Don't override if there are no fields. We could generate an
// equals method that compares types, but often empty messages
// are used as namespaces.
if (descriptor_->field_count() == 0 && !params_.store_unknown_fields()) {
return;
}
printer->Print(
"\n"
"@Override\n"
"public boolean equals(Object o) {\n");
printer->Indent();
printer->Print(
"if (o == this) {\n"
" return true;\n"
"}\n"
"if (!(o instanceof $classname$)) {\n"
" return false;\n"
"}\n"
"$classname$ other = ($classname$) o;\n",
"classname", descriptor_->name());
for (int i = 0; i < descriptor_->field_count(); i++) {
const FieldDescriptor* field = descriptor_->field(i);
field_generators_.get(field).GenerateEqualsCode(printer);
}
if (params_.store_unknown_fields()) {
printer->Print(
"if (unknownFieldData == null || unknownFieldData.isEmpty()) {\n"
" return other.unknownFieldData == null || other.unknownFieldData.isEmpty();"
"} else {\n"
" return unknownFieldData.equals(other.unknownFieldData);\n"
"}\n");
} else {
printer->Print(
"return true;\n");
}
printer->Outdent();
printer->Print("}\n");
}
void MessageGenerator::GenerateHashCode(io::Printer* printer) {
if (descriptor_->field_count() == 0 && !params_.store_unknown_fields()) {
return;
}
printer->Print(
"\n"
"@Override\n"
"public int hashCode() {\n");
printer->Indent();
printer->Print("int result = 17;\n");
for (int i = 0; i < descriptor_->field_count(); i++) {
const FieldDescriptor* field = descriptor_->field(i);
field_generators_.get(field).GenerateHashCodeCode(printer);
}
if (params_.store_unknown_fields()) {
printer->Print(
"result = 31 * result + (unknownFieldData == null || unknownFieldData.isEmpty()\n"
" ? 0 : unknownFieldData.hashCode());\n");
}
printer->Print("return result;\n");
printer->Outdent();
printer->Print("}\n");
}
// ===================================================================
} // namespace javanano
......
......@@ -36,9 +36,10 @@
#define GOOGLE_PROTOBUF_COMPILER_JAVA_MESSAGE_H__
#include <string>
#include <google/protobuf/stubs/common.h>
#include <google/protobuf/compiler/javanano/javanano_params.h>
#include <google/protobuf/compiler/javanano/javanano_helpers.h>
#include <google/protobuf/compiler/javanano/javanano_field.h>
#include <google/protobuf/compiler/javanano/javanano_params.h>
#include <google/protobuf/stubs/common.h>
namespace google {
namespace protobuf {
......@@ -76,6 +77,8 @@ class MessageGenerator {
const FieldDescriptor* field);
void GenerateClear(io::Printer* printer);
void GenerateEquals(io::Printer* printer);
void GenerateHashCode(io::Printer* printer);
const Params& params_;
const Descriptor* descriptor_;
......
......@@ -126,8 +126,25 @@ GenerateSerializedSizeCode(io::Printer* printer) const {
"}\n");
}
string MessageFieldGenerator::GetBoxedType() const {
return ClassName(params_, descriptor_->message_type());
void MessageFieldGenerator::
GenerateEqualsCode(io::Printer* printer) const {
printer->Print(variables_,
"if (this.$name$ == null) { \n"
" if (other.$name$ != null) {\n"
" return false;\n"
" }\n"
"} else {\n"
" if (!this.$name$.equals(other.$name$)) {\n"
" return false;\n"
" }\n"
"}\n");
}
void MessageFieldGenerator::
GenerateHashCodeCode(io::Printer* printer) const {
printer->Print(variables_,
"result = 31 * result +\n"
" (this.$name$ == null ? 0 : this.$name$.hashCode());\n");
}
// ===================================================================
......@@ -203,8 +220,22 @@ GenerateSerializedSizeCode(io::Printer* printer) const {
"}\n");
}
string AccessorMessageFieldGenerator::GetBoxedType() const {
return ClassName(params_, descriptor_->message_type());
void AccessorMessageFieldGenerator::
GenerateEqualsCode(io::Printer* printer) const {
printer->Print(variables_,
"if ($name$_ == null) {\n"
" if (other.$name$_ != null) {\n"
" return false;\n"
" }\n"
"} else if (!$name$_.equals(other.$name$_)) {\n"
" return false;\n"
"}\n");
}
void AccessorMessageFieldGenerator::
GenerateHashCodeCode(io::Printer* printer) const {
printer->Print(variables_,
"result = 31 * result + ($name$_ == null ? 0 : $name$_.hashCode());\n");
}
// ===================================================================
......@@ -291,8 +322,20 @@ GenerateSerializedSizeCode(io::Printer* printer) const {
"}\n");
}
string RepeatedMessageFieldGenerator::GetBoxedType() const {
return ClassName(params_, descriptor_->message_type());
void RepeatedMessageFieldGenerator::
GenerateEqualsCode(io::Printer* printer) const {
printer->Print(variables_,
"if (!com.google.protobuf.nano.InternalNano.equals(\n"
" this.$name$, other.$name$)) {\n"
" return false;\n"
"}\n");
}
void RepeatedMessageFieldGenerator::
GenerateHashCodeCode(io::Printer* printer) const {
printer->Print(variables_,
"result = 31 * result\n"
" + com.google.protobuf.nano.InternalNano.hashCode(this.$name$);\n");
}
} // namespace javanano
......
......@@ -55,8 +55,8 @@ class MessageFieldGenerator : public FieldGenerator {
void GenerateMergingCode(io::Printer* printer) const;
void GenerateSerializationCode(io::Printer* printer) const;
void GenerateSerializedSizeCode(io::Printer* printer) const;
string GetBoxedType() const;
void GenerateEqualsCode(io::Printer* printer) const;
void GenerateHashCodeCode(io::Printer* printer) const;
private:
const FieldDescriptor* descriptor_;
......@@ -77,8 +77,8 @@ class AccessorMessageFieldGenerator : public FieldGenerator {
void GenerateMergingCode(io::Printer* printer) const;
void GenerateSerializationCode(io::Printer* printer) const;
void GenerateSerializedSizeCode(io::Printer* printer) const;
string GetBoxedType() const;
void GenerateEqualsCode(io::Printer* printer) const;
void GenerateHashCodeCode(io::Printer* printer) const;
private:
const FieldDescriptor* descriptor_;
......@@ -99,8 +99,8 @@ class RepeatedMessageFieldGenerator : public FieldGenerator {
void GenerateMergingCode(io::Printer* printer) const;
void GenerateSerializationCode(io::Printer* printer) const;
void GenerateSerializedSizeCode(io::Printer* printer) const;
string GetBoxedType() const;
void GenerateEqualsCode(io::Printer* printer) const;
void GenerateHashCodeCode(io::Printer* printer) const;
private:
const FieldDescriptor* descriptor_;
......
......@@ -61,6 +61,7 @@ class Params {
bool java_enum_style_;
bool optional_field_accessors_;
bool use_reference_types_for_primitives_;
bool generate_equals_;
public:
Params(const string & base_name) :
......@@ -71,7 +72,8 @@ class Params {
generate_has_(false),
java_enum_style_(false),
optional_field_accessors_(false),
use_reference_types_for_primitives_(false) {
use_reference_types_for_primitives_(false),
generate_equals_(false) {
}
const string& base_name() const {
......@@ -186,6 +188,13 @@ class Params {
bool use_reference_types_for_primitives() const {
return use_reference_types_for_primitives_;
}
void set_generate_equals(bool value) {
generate_equals_ = value;
}
bool generate_equals() const {
return generate_equals_;
}
};
} // namespace javanano
......
......@@ -404,8 +404,102 @@ GenerateSerializedSizeCode(io::Printer* printer) const {
}
}
string PrimitiveFieldGenerator::GetBoxedType() const {
return BoxedPrimitiveTypeName(GetJavaType(descriptor_));
void PrimitiveFieldGenerator::
GenerateEqualsCode(io::Printer* printer) const {
// We define equality as serialized form equality. If generate_has(),
// then if the field value equals the default value in both messages,
// but one's 'has' field is set and the other's is not, the serialized
// forms are different and we should return false.
JavaType java_type = GetJavaType(descriptor_);
if (java_type == JAVATYPE_BYTES) {
printer->Print(variables_,
"if (!java.util.Arrays.equals(this.$name$, other.$name$)");
if (params_.generate_has()) {
printer->Print(variables_,
"\n"
" || (java.util.Arrays.equals(this.$name$, $default$)\n"
" && this.has$capitalized_name$ != other.has$capitalized_name$)");
}
printer->Print(") {\n"
" return false;\n"
"}\n");
} else if (java_type == JAVATYPE_STRING
|| params_.use_reference_types_for_primitives()) {
printer->Print(variables_,
"if (this.$name$ == null) {\n"
" if (other.$name$ != null) {\n"
" return false;\n"
" }\n"
"} else if (!this.$name$.equals(other.$name$)");
if (params_.generate_has()) {
printer->Print(variables_,
"\n"
" || (this.$name$.equals($default$)\n"
" && this.has$capitalized_name$ != other.has$capitalized_name$)");
}
printer->Print(") {\n"
" return false;\n"
"}\n");
} else {
printer->Print(variables_,
"if (this.$name$ != other.$name$");
if (params_.generate_has()) {
printer->Print(variables_,
"\n"
" || (this.$name$ == $default$\n"
" && this.has$capitalized_name$ != other.has$capitalized_name$)");
}
printer->Print(") {\n"
" return false;\n"
"}\n");
}
}
void PrimitiveFieldGenerator::
GenerateHashCodeCode(io::Printer* printer) const {
JavaType java_type = GetJavaType(descriptor_);
if (java_type == JAVATYPE_BYTES) {
printer->Print(variables_,
"result = 31 * result + java.util.Arrays.hashCode(this.$name$);\n");
} else if (java_type == JAVATYPE_STRING
|| params_.use_reference_types_for_primitives()) {
printer->Print(variables_,
"result = 31 * result\n"
" + (this.$name$ == null ? 0 : this.$name$.hashCode());\n");
} else {
switch (java_type) {
// For all Java primitive types below, the hash codes match the
// results of BoxedType.valueOf(primitiveValue).hashCode().
case JAVATYPE_INT:
printer->Print(variables_,
"result = 31 * result + this.$name$;\n");
break;
case JAVATYPE_LONG:
printer->Print(variables_,
"result = 31 * result\n"
" + (int) (this.$name$ ^ (this.$name$ >>> 32));\n");
break;
case JAVATYPE_FLOAT:
printer->Print(variables_,
"result = 31 * result\n"
" + java.lang.Float.floatToIntBits(this.$name$);\n");
break;
case JAVATYPE_DOUBLE:
printer->Print(variables_,
"{\n"
" long v = java.lang.Double.doubleToLongBits(this.$name$);\n"
" result = 31 * result + (int) (v ^ (v >>> 32));\n"
"}\n");
break;
case JAVATYPE_BOOLEAN:
printer->Print(variables_,
"result = 31 * result + (this.$name$ ? 1231 : 1237);\n");
break;
default:
GOOGLE_LOG(ERROR) << "unknown java type for primitive field";
break;
}
}
}
// ===================================================================
......@@ -483,8 +577,87 @@ GenerateSerializedSizeCode(io::Printer* printer) const {
"}\n");
}
string AccessorPrimitiveFieldGenerator::GetBoxedType() const {
return BoxedPrimitiveTypeName(GetJavaType(descriptor_));
void AccessorPrimitiveFieldGenerator::
GenerateEqualsCode(io::Printer* printer) const {
switch (GetJavaType(descriptor_)) {
// For all Java primitive types below, the hash codes match the
// results of BoxedType.valueOf(primitiveValue).hashCode().
case JAVATYPE_INT:
case JAVATYPE_LONG:
case JAVATYPE_FLOAT:
case JAVATYPE_DOUBLE:
case JAVATYPE_BOOLEAN:
printer->Print(variables_,
"if ($different_has$\n"
" || $name$_ != other.$name$_) {\n"
" return false;\n"
"}\n");
break;
case JAVATYPE_STRING:
// Accessor style would guarantee $name$_ non-null
printer->Print(variables_,
"if ($different_has$\n"
" || !$name$_.equals(other.$name$_)) {\n"
" return false;\n"
"}\n");
break;
case JAVATYPE_BYTES:
// Accessor style would guarantee $name$_ non-null
printer->Print(variables_,
"if ($different_has$\n"
" || !java.util.Arrays.equals($name$_, other.$name$_)) {\n"
" return false;\n"
"}\n");
break;
default:
GOOGLE_LOG(ERROR) << "unknown java type for primitive field";
break;
}
}
void AccessorPrimitiveFieldGenerator::
GenerateHashCodeCode(io::Printer* printer) const {
switch (GetJavaType(descriptor_)) {
// For all Java primitive types below, the hash codes match the
// results of BoxedType.valueOf(primitiveValue).hashCode().
case JAVATYPE_INT:
printer->Print(variables_,
"result = 31 * result + $name$_;\n");
break;
case JAVATYPE_LONG:
printer->Print(variables_,
"result = 31 * result + (int) ($name$_ ^ ($name$_ >>> 32));\n");
break;
case JAVATYPE_FLOAT:
printer->Print(variables_,
"result = 31 * result +\n"
" java.lang.Float.floatToIntBits($name$_);\n");
break;
case JAVATYPE_DOUBLE:
printer->Print(variables_,
"{\n"
" long v = java.lang.Double.doubleToLongBits($name$_);\n"
" result = 31 * result + (int) (v ^ (v >>> 32));\n"
"}\n");
break;
case JAVATYPE_BOOLEAN:
printer->Print(variables_,
"result = 31 * result + ($name$_ ? 1231 : 1237);\n");
break;
case JAVATYPE_STRING:
// Accessor style would guarantee $name$_ non-null
printer->Print(variables_,
"result = 31 * result + $name$_.hashCode();\n");
break;
case JAVATYPE_BYTES:
// Accessor style would guarantee $name$_ non-null
printer->Print(variables_,
"result = 31 * result + java.util.Arrays.hashCode($name$_);\n");
break;
default:
GOOGLE_LOG(ERROR) << "unknown java type for primitive field";
break;
}
}
// ===================================================================
......@@ -629,8 +802,20 @@ GenerateSerializedSizeCode(io::Printer* printer) const {
"}\n");
}
string RepeatedPrimitiveFieldGenerator::GetBoxedType() const {
return BoxedPrimitiveTypeName(GetJavaType(descriptor_));
void RepeatedPrimitiveFieldGenerator::
GenerateEqualsCode(io::Printer* printer) const {
printer->Print(variables_,
"if (!com.google.protobuf.nano.InternalNano.equals(\n"
" this.$name$, other.$name$)) {\n"
" return false;\n"
"}\n");
}
void RepeatedPrimitiveFieldGenerator::
GenerateHashCodeCode(io::Printer* printer) const {
printer->Print(variables_,
"result = 31 * result\n"
" + com.google.protobuf.nano.InternalNano.hashCode(this.$name$);\n");
}
} // namespace javanano
......
......@@ -55,8 +55,8 @@ class PrimitiveFieldGenerator : public FieldGenerator {
void GenerateMergingCode(io::Printer* printer) const;
void GenerateSerializationCode(io::Printer* printer) const;
void GenerateSerializedSizeCode(io::Printer* printer) const;
string GetBoxedType() const;
void GenerateEqualsCode(io::Printer* printer) const;
void GenerateHashCodeCode(io::Printer* printer) const;
private:
void GenerateSerializationConditional(io::Printer* printer) const;
......@@ -79,8 +79,8 @@ class AccessorPrimitiveFieldGenerator : public FieldGenerator {
void GenerateMergingCode(io::Printer* printer) const;
void GenerateSerializationCode(io::Printer* printer) const;
void GenerateSerializedSizeCode(io::Printer* printer) const;
string GetBoxedType() const;
void GenerateEqualsCode(io::Printer* printer) const;
void GenerateHashCodeCode(io::Printer* printer) const;
private:
const FieldDescriptor* descriptor_;
......@@ -100,8 +100,8 @@ class RepeatedPrimitiveFieldGenerator : public FieldGenerator {
void GenerateMergingCode(io::Printer* printer) const;
void GenerateSerializationCode(io::Printer* printer) const;
void GenerateSerializedSizeCode(io::Printer* printer) const;
string GetBoxedType() const;
void GenerateEqualsCode(io::Printer* printer) const;
void GenerateHashCodeCode(io::Printer* printer) const;
private:
void GenerateRepeatedDataSizeCode(io::Printer* printer) const;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment