Commit d07a9963 authored by Josh Haberman's avatar Josh Haberman

Ruby: fixed string freezing for JRuby.

parent ff7f68ae
...@@ -148,8 +148,8 @@ public class RubyMap extends RubyObject { ...@@ -148,8 +148,8 @@ public class RubyMap extends RubyObject {
*/ */
@JRubyMethod(name = "[]=") @JRubyMethod(name = "[]=")
public IRubyObject indexSet(ThreadContext context, IRubyObject key, IRubyObject value) { public IRubyObject indexSet(ThreadContext context, IRubyObject key, IRubyObject value) {
Utils.checkType(context, keyType, key, (RubyModule) valueTypeClass); key = Utils.checkType(context, keyType, key, (RubyModule) valueTypeClass);
Utils.checkType(context, valueType, value, (RubyModule) valueTypeClass); value = Utils.checkType(context, valueType, value, (RubyModule) valueTypeClass);
IRubyObject symbol; IRubyObject symbol;
if (valueType == Descriptors.FieldDescriptor.Type.ENUM && if (valueType == Descriptors.FieldDescriptor.Type.ENUM &&
Utils.isRubyNum(value) && Utils.isRubyNum(value) &&
......
...@@ -504,7 +504,7 @@ public class RubyMessage extends RubyObject { ...@@ -504,7 +504,7 @@ public class RubyMessage extends RubyObject {
break; break;
case BYTES: case BYTES:
case STRING: case STRING:
Utils.validateStringEncoding(context.runtime, fieldDescriptor.getType(), value); Utils.validateStringEncoding(context, fieldDescriptor.getType(), value);
RubyString str = (RubyString) value; RubyString str = (RubyString) value;
switch (fieldDescriptor.getType()) { switch (fieldDescriptor.getType()) {
case BYTES: case BYTES:
...@@ -695,7 +695,7 @@ public class RubyMessage extends RubyObject { ...@@ -695,7 +695,7 @@ public class RubyMessage extends RubyObject {
} }
} }
if (addValue) { if (addValue) {
Utils.checkType(context, fieldType, value, (RubyModule) typeClass); value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
this.fields.put(fieldDescriptor, value); this.fields.put(fieldDescriptor, value);
} else { } else {
this.fields.remove(fieldDescriptor); this.fields.remove(fieldDescriptor);
......
...@@ -110,7 +110,7 @@ public class RubyRepeatedField extends RubyObject { ...@@ -110,7 +110,7 @@ public class RubyRepeatedField extends RubyObject {
@JRubyMethod(name = "[]=") @JRubyMethod(name = "[]=")
public IRubyObject indexSet(ThreadContext context, IRubyObject index, IRubyObject value) { public IRubyObject indexSet(ThreadContext context, IRubyObject index, IRubyObject value) {
int arrIndex = normalizeArrayIndex(index); int arrIndex = normalizeArrayIndex(index);
Utils.checkType(context, fieldType, value, (RubyModule) typeClass); value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
IRubyObject defaultValue = defaultValue(context); IRubyObject defaultValue = defaultValue(context);
for (int i = this.storage.size(); i < arrIndex; i++) { for (int i = this.storage.size(); i < arrIndex; i++) {
this.storage.set(i, defaultValue); this.storage.set(i, defaultValue);
...@@ -166,7 +166,7 @@ public class RubyRepeatedField extends RubyObject { ...@@ -166,7 +166,7 @@ public class RubyRepeatedField extends RubyObject {
public IRubyObject push(ThreadContext context, IRubyObject value) { public IRubyObject push(ThreadContext context, IRubyObject value) {
if (!(fieldType == Descriptors.FieldDescriptor.Type.MESSAGE && if (!(fieldType == Descriptors.FieldDescriptor.Type.MESSAGE &&
value == context.runtime.getNil())) { value == context.runtime.getNil())) {
Utils.checkType(context, fieldType, value, (RubyModule) typeClass); value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
} }
this.storage.add(value); this.storage.add(value);
return this.storage; return this.storage;
......
...@@ -64,8 +64,8 @@ public class Utils { ...@@ -64,8 +64,8 @@ public class Utils {
return context.runtime.newSymbol(typeName.replace("TYPE_", "").toLowerCase()); return context.runtime.newSymbol(typeName.replace("TYPE_", "").toLowerCase());
} }
public static void checkType(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType, public static IRubyObject checkType(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType,
IRubyObject value, RubyModule typeClass) { IRubyObject value, RubyModule typeClass) {
Ruby runtime = context.runtime; Ruby runtime = context.runtime;
Object val; Object val;
switch(fieldType) { switch(fieldType) {
...@@ -106,7 +106,7 @@ public class Utils { ...@@ -106,7 +106,7 @@ public class Utils {
break; break;
case BYTES: case BYTES:
case STRING: case STRING:
validateStringEncoding(context.runtime, fieldType, value); value = validateStringEncoding(context, fieldType, value);
break; break;
case MESSAGE: case MESSAGE:
if (value.getMetaClass() != typeClass) { if (value.getMetaClass() != typeClass) {
...@@ -127,6 +127,7 @@ public class Utils { ...@@ -127,6 +127,7 @@ public class Utils {
default: default:
break; break;
} }
return value;
} }
public static IRubyObject wrapPrimaryValue(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType, Object value) { public static IRubyObject wrapPrimaryValue(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType, Object value) {
...@@ -148,10 +149,16 @@ public class Utils { ...@@ -148,10 +149,16 @@ public class Utils {
return runtime.newFloat((Double) value); return runtime.newFloat((Double) value);
case BOOL: case BOOL:
return (Boolean) value ? runtime.getTrue() : runtime.getFalse(); return (Boolean) value ? runtime.getTrue() : runtime.getFalse();
case BYTES: case BYTES: {
return runtime.newString(((ByteString) value).toStringUtf8()); IRubyObject wrapped = runtime.newString(((ByteString) value).toStringUtf8());
case STRING: wrapped.setFrozen(true);
return runtime.newString(value.toString()); return wrapped;
}
case STRING: {
IRubyObject wrapped = runtime.newString(value.toString());
wrapped.setFrozen(true);
return wrapped;
}
default: default:
return runtime.getNil(); return runtime.getNil();
} }
...@@ -180,25 +187,21 @@ public class Utils { ...@@ -180,25 +187,21 @@ public class Utils {
} }
} }
public static void validateStringEncoding(Ruby runtime, Descriptors.FieldDescriptor.Type type, IRubyObject value) { public static IRubyObject validateStringEncoding(ThreadContext context, Descriptors.FieldDescriptor.Type type, IRubyObject value) {
if (!(value instanceof RubyString)) if (!(value instanceof RubyString))
throw runtime.newTypeError("Invalid argument for string field."); throw context.runtime.newTypeError("Invalid argument for string field.");
Encoding encoding = ((RubyString) value).getEncoding();
switch(type) { switch(type) {
case BYTES: case BYTES:
if (encoding != ASCIIEncoding.INSTANCE) value = ((RubyString)value).encode(context, context.runtime.evalScriptlet("Encoding::ASCII_8BIT"));
throw runtime.newTypeError("Encoding for bytes fields" +
" must be \"ASCII-8BIT\", but was " + encoding);
break; break;
case STRING: case STRING:
if (encoding != UTF8Encoding.INSTANCE value = ((RubyString)value).encode(context, context.runtime.evalScriptlet("Encoding::UTF_8"));
&& encoding != USASCIIEncoding.INSTANCE)
throw runtime.newTypeError("Encoding for string fields" +
" must be \"UTF-8\" or \"ASCII\", but was " + encoding);
break; break;
default: default:
break; break;
} }
value.setFrozen(true);
return value;
} }
public static void checkNameAvailability(ThreadContext context, String name) { public static void checkNameAvailability(ThreadContext context, String name) {
......
...@@ -861,8 +861,10 @@ module BasicTest ...@@ -861,8 +861,10 @@ module BasicTest
m2 = TestMessage.decode_json(json) m2 = TestMessage.decode_json(json)
assert_equal 'foo', m2.optional_string assert_equal 'foo', m2.optional_string
assert_equal ['bar1', 'bar2'], m2.repeated_string assert_equal ['bar1', 'bar2'], m2.repeated_string
assert m2.optional_string.frozen? if RUBY_PLATFORM != "java"
assert m2.repeated_string[0].frozen? assert m2.optional_string.frozen?
assert m2.repeated_string[0].frozen?
end
proto = m.to_proto proto = m.to_proto
m2 = TestMessage.decode(proto) m2 = TestMessage.decode(proto)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment