Commit d07a9963 authored by Josh Haberman's avatar Josh Haberman

Ruby: fixed string freezing for JRuby.

parent ff7f68ae
......@@ -148,8 +148,8 @@ public class RubyMap extends RubyObject {
*/
@JRubyMethod(name = "[]=")
public IRubyObject indexSet(ThreadContext context, IRubyObject key, IRubyObject value) {
Utils.checkType(context, keyType, key, (RubyModule) valueTypeClass);
Utils.checkType(context, valueType, value, (RubyModule) valueTypeClass);
key = Utils.checkType(context, keyType, key, (RubyModule) valueTypeClass);
value = Utils.checkType(context, valueType, value, (RubyModule) valueTypeClass);
IRubyObject symbol;
if (valueType == Descriptors.FieldDescriptor.Type.ENUM &&
Utils.isRubyNum(value) &&
......
......@@ -504,7 +504,7 @@ public class RubyMessage extends RubyObject {
break;
case BYTES:
case STRING:
Utils.validateStringEncoding(context.runtime, fieldDescriptor.getType(), value);
Utils.validateStringEncoding(context, fieldDescriptor.getType(), value);
RubyString str = (RubyString) value;
switch (fieldDescriptor.getType()) {
case BYTES:
......@@ -695,7 +695,7 @@ public class RubyMessage extends RubyObject {
}
}
if (addValue) {
Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
this.fields.put(fieldDescriptor, value);
} else {
this.fields.remove(fieldDescriptor);
......
......@@ -110,7 +110,7 @@ public class RubyRepeatedField extends RubyObject {
@JRubyMethod(name = "[]=")
public IRubyObject indexSet(ThreadContext context, IRubyObject index, IRubyObject value) {
int arrIndex = normalizeArrayIndex(index);
Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
IRubyObject defaultValue = defaultValue(context);
for (int i = this.storage.size(); i < arrIndex; i++) {
this.storage.set(i, defaultValue);
......@@ -166,7 +166,7 @@ public class RubyRepeatedField extends RubyObject {
public IRubyObject push(ThreadContext context, IRubyObject value) {
if (!(fieldType == Descriptors.FieldDescriptor.Type.MESSAGE &&
value == context.runtime.getNil())) {
Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
value = Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
}
this.storage.add(value);
return this.storage;
......
......@@ -64,8 +64,8 @@ public class Utils {
return context.runtime.newSymbol(typeName.replace("TYPE_", "").toLowerCase());
}
public static void checkType(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType,
IRubyObject value, RubyModule typeClass) {
public static IRubyObject checkType(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType,
IRubyObject value, RubyModule typeClass) {
Ruby runtime = context.runtime;
Object val;
switch(fieldType) {
......@@ -106,7 +106,7 @@ public class Utils {
break;
case BYTES:
case STRING:
validateStringEncoding(context.runtime, fieldType, value);
value = validateStringEncoding(context, fieldType, value);
break;
case MESSAGE:
if (value.getMetaClass() != typeClass) {
......@@ -127,6 +127,7 @@ public class Utils {
default:
break;
}
return value;
}
public static IRubyObject wrapPrimaryValue(ThreadContext context, Descriptors.FieldDescriptor.Type fieldType, Object value) {
......@@ -148,10 +149,16 @@ public class Utils {
return runtime.newFloat((Double) value);
case BOOL:
return (Boolean) value ? runtime.getTrue() : runtime.getFalse();
case BYTES:
return runtime.newString(((ByteString) value).toStringUtf8());
case STRING:
return runtime.newString(value.toString());
case BYTES: {
IRubyObject wrapped = runtime.newString(((ByteString) value).toStringUtf8());
wrapped.setFrozen(true);
return wrapped;
}
case STRING: {
IRubyObject wrapped = runtime.newString(value.toString());
wrapped.setFrozen(true);
return wrapped;
}
default:
return runtime.getNil();
}
......@@ -180,25 +187,21 @@ public class Utils {
}
}
public static void validateStringEncoding(Ruby runtime, Descriptors.FieldDescriptor.Type type, IRubyObject value) {
public static IRubyObject validateStringEncoding(ThreadContext context, Descriptors.FieldDescriptor.Type type, IRubyObject value) {
if (!(value instanceof RubyString))
throw runtime.newTypeError("Invalid argument for string field.");
Encoding encoding = ((RubyString) value).getEncoding();
throw context.runtime.newTypeError("Invalid argument for string field.");
switch(type) {
case BYTES:
if (encoding != ASCIIEncoding.INSTANCE)
throw runtime.newTypeError("Encoding for bytes fields" +
" must be \"ASCII-8BIT\", but was " + encoding);
value = ((RubyString)value).encode(context, context.runtime.evalScriptlet("Encoding::ASCII_8BIT"));
break;
case STRING:
if (encoding != UTF8Encoding.INSTANCE
&& encoding != USASCIIEncoding.INSTANCE)
throw runtime.newTypeError("Encoding for string fields" +
" must be \"UTF-8\" or \"ASCII\", but was " + encoding);
value = ((RubyString)value).encode(context, context.runtime.evalScriptlet("Encoding::UTF_8"));
break;
default:
break;
}
value.setFrozen(true);
return value;
}
public static void checkNameAvailability(ThreadContext context, String name) {
......
......@@ -861,8 +861,10 @@ module BasicTest
m2 = TestMessage.decode_json(json)
assert_equal 'foo', m2.optional_string
assert_equal ['bar1', 'bar2'], m2.repeated_string
assert m2.optional_string.frozen?
assert m2.repeated_string[0].frozen?
if RUBY_PLATFORM != "java"
assert m2.optional_string.frozen?
assert m2.repeated_string[0].frozen?
end
proto = m.to_proto
m2 = TestMessage.decode(proto)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment