Commit d32259c8 authored by Robert Kimball's avatar Robert Kimball Committed by Sang Ik Lee

Add testing for serializing and deserializing graphs (#4105)

* Add serializer/deserializer check to interpreter

* Fix TopK

* Fix GRUCell

* Fix RNNCell. Does anybody test their own code? Apparently not.

* Fix LSTMCell

* Fix MVN

* Fix Select v1

* Fix GroupConvolution

* Fix ScalarConstantLike

* General cleanup

* Revert "General cleanup"

This reverts commit d765d2c2451cf5d3c9a41c4d7d672c278783b0a2.

* Fix op_version_tbl.hpp

* More cleanup

* Fix LSTMSequence

* revert

* Disable INTERPRETER serialize test by default
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent e07bc028
...@@ -63,7 +63,13 @@ runtime::interpreter::INTExecutable::INTExecutable(const shared_ptr<Function>& f ...@@ -63,7 +63,13 @@ runtime::interpreter::INTExecutable::INTExecutable(const shared_ptr<Function>& f
: m_is_compiled{true} : m_is_compiled{true}
, m_performance_counters_enabled{enable_performance_collection} , m_performance_counters_enabled{enable_performance_collection}
{ {
#ifdef INTERPRETER_FORCE_SERIALIZE
// To verify that the serializer works correctly let's just run this graph round-trip
string ser = serialize(function);
m_function = deserialize(ser);
#else
m_function = clone_function(*function); m_function = clone_function(*function);
#endif
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::LikeReplacement>(); pass_manager.register_pass<pass::LikeReplacement>();
pass_manager.register_pass<pass::FusedOpDecomposition>(); pass_manager.register_pass<pass::FusedOpDecomposition>();
......
...@@ -300,13 +300,6 @@ static element::Type read_element_type(json j) ...@@ -300,13 +300,6 @@ static element::Type read_element_type(json j)
return element::Type(bitwidth, is_real, is_signed, is_quantized, c_type_string); return element::Type(bitwidth, is_real, is_signed, is_quantized, c_type_string);
} }
static op::LSTMWeightsFormat read_lstm_weights_format(const json& js)
{
return has_key(js, "weights_format")
? static_cast<op::LSTMWeightsFormat>(js.at("weights_format"))
: op::LSTMWeightsFormat::IFCO;
}
void ngraph::serialize(const string& path, shared_ptr<ngraph::Function> func, size_t indent) void ngraph::serialize(const string& path, shared_ptr<ngraph::Function> func, size_t indent)
{ {
ofstream out(path); ofstream out(path);
...@@ -1518,8 +1511,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -1518,8 +1511,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
} }
case OP_TYPEID::Reshape_v1: case OP_TYPEID::Reshape_v1:
{ {
const auto zero_flag = node_js.at("zero_flag").get<bool>(); const bool special_zero = node_js.at("special_zero").get<bool>();
node = make_shared<op::v1::Reshape>(args[0], args[1], zero_flag); node = make_shared<op::v1::Reshape>(args[0], args[1], special_zero);
break; break;
} }
case OP_TYPEID::DynSlice: case OP_TYPEID::DynSlice:
...@@ -1709,9 +1702,10 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -1709,9 +1702,10 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>(); auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>();
auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>(); auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides = node_js.at("data_dilation_strides").get<vector<size_t>>(); auto data_dilation_strides = node_js.at("data_dilation_strides").get<vector<size_t>>();
auto groups = node_js.at("groups").get<size_t>();
op::PadType pad_type = read_pad_type(node_js); op::PadType pad_type = read_pad_type(node_js);
if (has_key(node_js, "groups"))
{
auto groups = node_js.at("groups").get<size_t>();
node = make_shared<op::GroupConvolution>(args[0], node = make_shared<op::GroupConvolution>(args[0],
args[1], args[1],
window_movement_strides, window_movement_strides,
...@@ -1721,6 +1715,18 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -1721,6 +1715,18 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
data_dilation_strides, data_dilation_strides,
groups, groups,
pad_type); pad_type);
}
else
{
node = make_shared<op::GroupConvolution>(args[0],
args[1],
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
pad_type);
}
break; break;
} }
case OP_TYPEID::GroupConvolutionBackpropData: case OP_TYPEID::GroupConvolutionBackpropData:
...@@ -1791,9 +1797,24 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -1791,9 +1797,24 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
auto hidden_size = node_js.at("hidden_size").get<size_t>(); auto hidden_size = node_js.at("hidden_size").get<size_t>();
auto clip = node_js.at("clip").get<float>(); auto clip = node_js.at("clip").get<float>();
auto activations = node_js.at("activations").get<vector<string>>(); auto activations = node_js.at("activations").get<vector<string>>();
auto activation_alpha = node_js.at("activation_alpha").get<vector<float>>(); auto activation_alpha = node_js.at("activations_alpha").get<vector<float>>();
auto activation_beta = node_js.at("activation_beta").get<vector<float>>(); auto activation_beta = node_js.at("activations_beta").get<vector<float>>();
auto linear_before_reset = node_js.at("linear_before_reset").get<bool>(); auto linear_before_reset = node_js.at("linear_before_reset").get<bool>();
switch (args.size())
{
case 4:
node = make_shared<op::GRUCell>(args[0],
args[1],
args[2],
args[3],
hidden_size,
activations,
activation_alpha,
activation_beta,
clip,
linear_before_reset);
break;
case 5:
node = make_shared<op::GRUCell>(args[0], node = make_shared<op::GRUCell>(args[0],
args[1], args[1],
args[2], args[2],
...@@ -1806,6 +1827,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -1806,6 +1827,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
clip, clip,
linear_before_reset); linear_before_reset);
break; break;
default: throw runtime_error("GRUCell constructor not supported in serializer");
}
break;
} }
case OP_TYPEID::HardSigmoid: case OP_TYPEID::HardSigmoid:
{ {
...@@ -1933,14 +1957,16 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -1933,14 +1957,16 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
case OP_TYPEID::LSTMCell: case OP_TYPEID::LSTMCell:
{ {
auto hidden_size = node_js.at("hidden_size").get<size_t>(); auto hidden_size = node_js.at("hidden_size").get<size_t>();
auto weights_format = read_lstm_weights_format(node_js); auto weights_format = get_or_default<op::LSTMWeightsFormat>(
node_js, "weights_format", op::LSTMWeightsFormat::IFCO);
auto clip = node_js.at("clip").get<float>(); auto clip = node_js.at("clip").get<float>();
auto activations = node_js.at("activations").get<vector<string>>(); auto activations = node_js.at("activations").get<vector<string>>();
auto activations_alpha = node_js.at("activations_alpha").get<vector<float>>(); auto activations_alpha = node_js.at("activations_alpha").get<vector<float>>();
auto activations_beta = node_js.at("activations_beta").get<vector<float>>(); auto activations_beta = node_js.at("activations_beta").get<vector<float>>();
auto input_forget = node_js.at("input_forget").get<bool>(); auto input_forget = node_js.at("input_forget").get<bool>();
if (args.size() == 7) switch (args.size())
{ {
case 7:
node = make_shared<op::LSTMCell>(args[0], node = make_shared<op::LSTMCell>(args[0],
args[1], args[1],
args[2], args[2],
...@@ -1955,9 +1981,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -1955,9 +1981,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
activations_beta, activations_beta,
clip, clip,
input_forget); input_forget);
} break;
if (args.size() == 6) case 6:
{
node = make_shared<op::LSTMCell>(args[0], node = make_shared<op::LSTMCell>(args[0],
args[1], args[1],
args[2], args[2],
...@@ -1971,9 +1996,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -1971,9 +1996,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
activations_beta, activations_beta,
clip, clip,
input_forget); input_forget);
} break;
else case 5:
{
node = make_shared<op::LSTMCell>(args[0], node = make_shared<op::LSTMCell>(args[0],
args[1], args[1],
args[2], args[2],
...@@ -1986,19 +2010,22 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -1986,19 +2010,22 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
activations_beta, activations_beta,
clip, clip,
input_forget); input_forget);
break;
default: throw runtime_error("LSTMCell constructor not supported in serializer");
} }
break; break;
} }
case OP_TYPEID::LSTMSequence: case OP_TYPEID::LSTMSequence:
{ {
auto hidden_size = node_js.at("hidden_size").get<size_t>(); auto hidden_size = node_js.at("hidden_size").get<size_t>();
auto clip = node_js.at("clip").get<float>(); auto clip = node_js.at("clip_threshold").get<float>();
auto activations = node_js.at("activations").get<vector<string>>(); auto activations = node_js.at("activations").get<vector<string>>();
auto activations_alpha = node_js.at("activations_alpha").get<vector<float>>(); auto activations_alpha = node_js.at("activations_alpha").get<vector<float>>();
auto activations_beta = node_js.at("activations_beta").get<vector<float>>(); auto activations_beta = node_js.at("activations_beta").get<vector<float>>();
auto input_forget = node_js.at("input_forget").get<bool>(); auto input_forget = node_js.at("input_forget").get<bool>();
auto direction = node_js.at("direction").get<op::LSTMSequence::direction>(); auto direction = node_js.at("direction").get<op::LSTMSequence::direction>();
auto weights_format = read_lstm_weights_format(node_js); auto weights_format = get_or_default<op::LSTMWeightsFormat>(
node_js, "weights_format", op::LSTMWeightsFormat::IFCO);
if (args.size() == 8) if (args.size() == 8)
{ {
node = make_shared<op::LSTMSequence>(args[0], node = make_shared<op::LSTMSequence>(args[0],
...@@ -2217,9 +2244,16 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -2217,9 +2244,16 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
case OP_TYPEID::MVN: case OP_TYPEID::MVN:
{ {
auto normalize_variance = node_js.at("normalize_variance").get<bool>(); auto normalize_variance = node_js.at("normalize_variance").get<bool>();
auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes")); AxisSet reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
auto eps = node_js.at("eps").get<double>(); auto eps = node_js.at("eps").get<double>();
node = make_shared<op::MVN>(args[0], normalize_variance, normalize_variance, eps); if (reduction_axes.size() > 0)
{
node = make_shared<op::MVN>(args[0], reduction_axes, normalize_variance, eps);
}
else
{
node = make_shared<op::MVN>(args[0], true, normalize_variance, eps);
}
break; break;
} }
case OP_TYPEID::Negative: case OP_TYPEID::Negative:
...@@ -2583,8 +2617,22 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -2583,8 +2617,22 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
auto hidden_size = node_js.at("hidden_size").get<size_t>(); auto hidden_size = node_js.at("hidden_size").get<size_t>();
auto clip = node_js.at("clip").get<float>(); auto clip = node_js.at("clip").get<float>();
auto activations = node_js.at("activations").get<vector<string>>(); auto activations = node_js.at("activations").get<vector<string>>();
auto activation_alpha = node_js.at("activation_alpha").get<vector<float>>(); auto activation_alpha = node_js.at("activations_alpha").get<vector<float>>();
auto activation_beta = node_js.at("activation_beta").get<vector<float>>(); auto activation_beta = node_js.at("activations_beta").get<vector<float>>();
switch (args.size())
{
case 4:
node = make_shared<op::RNNCell>(args[0],
args[1],
args[2],
args[3],
hidden_size,
activations,
activation_alpha,
activation_beta,
clip);
break;
case 5:
node = make_shared<op::RNNCell>(args[0], node = make_shared<op::RNNCell>(args[0],
args[1], args[1],
args[2], args[2],
...@@ -2596,6 +2644,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -2596,6 +2644,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
activation_beta, activation_beta,
clip); clip);
break; break;
default: throw runtime_error("GRUCell constructor not supported in serializer");
}
break;
} }
case OP_TYPEID::ROIPooling: { break; case OP_TYPEID::ROIPooling: { break;
} }
...@@ -2894,23 +2945,26 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -2894,23 +2945,26 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
{ {
auto compute_max = node_js.at("compute_max").get<bool>(); auto compute_max = node_js.at("compute_max").get<bool>();
auto target_type = read_element_type(node_js.at("index_element_type")); auto target_type = read_element_type(node_js.at("index_element_type"));
op::TopKSortType sort = node_js.at("sort").get<op::TopKSortType>();
if (has_key(node_js, "top_k_axis")) if (has_key(node_js, "top_k_axis"))
{ {
auto top_k_axis = node_js.at("top_k_axis").get<size_t>(); auto top_k_axis = node_js.at("top_k_axis").get<size_t>();
if (has_key(node_js, "k")) if (has_key(node_js, "k"))
{ {
auto k = node_js.at("k").get<size_t>(); auto k = node_js.at("k").get<size_t>();
node = make_shared<op::TopK>(args[0], top_k_axis, target_type, k, compute_max); node = make_shared<op::TopK>(
args[0], top_k_axis, target_type, k, compute_max, sort);
} }
else else
{ {
node = make_shared<op::TopK>( node = make_shared<op::TopK>(
args[0], args[1], top_k_axis, target_type, compute_max); args[0], args[1], top_k_axis, target_type, compute_max, sort);
} }
} }
else else
{ {
node = make_shared<op::TopK>(args[0], args[1], args[2], target_type, compute_max); node = make_shared<op::TopK>(
args[0], args[1], args[2], target_type, compute_max, sort);
} }
break; break;
} }
...@@ -2987,8 +3041,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -2987,8 +3041,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
} }
m_node_map[node_name] = node; m_node_map[node_name] = node;
} }
catch (...) catch (exception& err)
{ {
NGRAPH_INFO << err.what();
string node_name; string node_name;
auto it = node_js.find("name"); auto it = node_js.find("name");
if (it != node_js.end()) if (it != node_js.end())
...@@ -3768,7 +3823,10 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -3768,7 +3823,10 @@ json JSONSerializer::serialize_node(const Node& n)
node["padding_below"] = tmp->get_padding_below(); node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above(); node["padding_above"] = tmp->get_padding_above();
node["data_dilation_strides"] = tmp->get_data_dilation_strides(); node["data_dilation_strides"] = tmp->get_data_dilation_strides();
if (!tmp->has_groups_in_filters())
{
node["groups"] = tmp->get_groups(); node["groups"] = tmp->get_groups();
}
node["pad_type"] = tmp->get_pad_type(); node["pad_type"] = tmp->get_pad_type();
break; break;
} }
...@@ -4355,8 +4413,8 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -4355,8 +4413,8 @@ json JSONSerializer::serialize_node(const Node& n)
{ {
auto tmp = static_cast<const op::ScalarConstantLike*>(&n); auto tmp = static_cast<const op::ScalarConstantLike*>(&n);
auto constant = tmp->as_constant(); auto constant = tmp->as_constant();
node["value"] = constant->get_value_strings()[0]; char* p_end;
node["element_type"] = write_element_type(constant->get_element_type()); node["value"] = strtod(constant->get_value_strings()[0].c_str(), &p_end);
break; break;
} }
case OP_TYPEID::ScaleShift: { break; case OP_TYPEID::ScaleShift: { break;
...@@ -4567,6 +4625,17 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -4567,6 +4625,17 @@ json JSONSerializer::serialize_node(const Node& n)
const auto tmp = static_cast<const op::TopK*>(&n); const auto tmp = static_cast<const op::TopK*>(&n);
node["index_element_type"] = write_element_type(tmp->get_index_element_type()); node["index_element_type"] = write_element_type(tmp->get_index_element_type());
node["compute_max"] = tmp->get_compute_max(); node["compute_max"] = tmp->get_compute_max();
node["sort"] = tmp->get_sort();
switch (tmp->inputs().size())
{
case 1:
node["k"] = tmp->get_k();
node["top_k_axis"] = tmp->get_top_k_axis();
break;
case 2: node["top_k_axis"] = tmp->get_top_k_axis(); break;
case 3: break;
default: throw runtime_error("TopK constructor not supported in serializer");
}
break; break;
} }
......
...@@ -83,7 +83,7 @@ static const map<element::Type_t, const TypeInfo>& get_type_info_map() ...@@ -83,7 +83,7 @@ static const map<element::Type_t, const TypeInfo>& get_type_info_map()
{element::Type_t::i16, TypeInfo(16, false, true, false, "int16_t", "i16")}, {element::Type_t::i16, TypeInfo(16, false, true, false, "int16_t", "i16")},
{element::Type_t::i32, TypeInfo(32, false, true, true, "int32_t", "i32")}, {element::Type_t::i32, TypeInfo(32, false, true, true, "int32_t", "i32")},
{element::Type_t::i64, TypeInfo(64, false, true, false, "int64_t", "i64")}, {element::Type_t::i64, TypeInfo(64, false, true, false, "int64_t", "i64")},
{element::Type_t::u1, TypeInfo(1, false, false, false, "uint8_t", "u1")}, {element::Type_t::u1, TypeInfo(1, false, false, false, "uint1_t", "u1")},
{element::Type_t::u8, TypeInfo(8, false, false, true, "uint8_t", "u8")}, {element::Type_t::u8, TypeInfo(8, false, false, true, "uint8_t", "u8")},
{element::Type_t::u16, TypeInfo(16, false, false, false, "uint16_t", "u16")}, {element::Type_t::u16, TypeInfo(16, false, false, false, "uint16_t", "u16")},
{element::Type_t::u32, TypeInfo(32, false, false, false, "uint32_t", "u32")}, {element::Type_t::u32, TypeInfo(32, false, false, false, "uint32_t", "u32")},
......
...@@ -408,7 +408,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_global_lp_pool_p3) ...@@ -408,7 +408,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_global_lp_pool_p3)
Outputs outputs{execute(function, inputs, "${BACKEND_NAME}")}; Outputs outputs{execute(function, inputs, "${BACKEND_NAME}")};
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front())); EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front(), 18));
} }
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_convtranspose_output_shape) NGRAPH_TEST(onnx_${BACKEND_NAME}, model_convtranspose_output_shape)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment