Commit d32259c8 authored by Robert Kimball, committed by Sang Ik Lee

Add testing for serializing and deserializing graphs (#4105)

* Add serializer/deserializer check to interpreter

* Fix TopK

* Fix GRUCell

* Fix RNNCell. Does anybody test their own code? Apparently not.

* Fix LSTMCell

* Fix MVN

* Fix Select v1

* Fix GroupConvolution

* Fix ScalarConstantLike

* General cleanup

* Revert "General cleanup"

This reverts commit d765d2c2451cf5d3c9a41c4d7d672c278783b0a2.

* Fix op_version_tbl.hpp

* More cleanup

* Fix LSTMSequence

* revert

* Disable INTERPRETER serialize test by default
Co-authored-by: Scott Cyphers <diyessi@users.noreply.github.com>
parent e07bc028
@@ -63,7 +63,13 @@ runtime::interpreter::INTExecutable::INTExecutable(const shared_ptr<Function>& f
: m_is_compiled{true}
, m_performance_counters_enabled{enable_performance_collection}
{
#ifdef INTERPRETER_FORCE_SERIALIZE
// To verify that the serializer works correctly let's just run this graph round-trip
string ser = serialize(function);
m_function = deserialize(ser);
#else
m_function = clone_function(*function);
#endif
pass::Manager pass_manager;
pass_manager.register_pass<pass::LikeReplacement>();
pass_manager.register_pass<pass::FusedOpDecomposition>();
......
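The guarded block above runs every compiled graph through a serialize/deserialize round trip. The same check can be written as a standalone unit test; a minimal sketch, assuming gtest and the public serializer API (the test name and toy graph are illustrative only, not part of this commit):

#include <gtest/gtest.h>
#include "ngraph/ngraph.hpp"
#include "ngraph/serializer.hpp"

using namespace ngraph;

TEST(serialize, interpreter_round_trip_sketch)
{
    // Build a trivial graph, dump it to JSON, read it back, and make sure
    // the round trip preserves the structure.
    auto A = std::make_shared<op::Parameter>(element::f32, Shape{2, 2});
    auto B = std::make_shared<op::Parameter>(element::f32, Shape{2, 2});
    auto f = std::make_shared<Function>(A + B, ParameterVector{A, B});

    std::string js = serialize(f, 4);              // same call the executable now makes
    std::shared_ptr<Function> g = deserialize(js);

    EXPECT_EQ(f->get_ops().size(), g->get_ops().size());
}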
@@ -300,13 +300,6 @@ static element::Type read_element_type(json j)
return element::Type(bitwidth, is_real, is_signed, is_quantized, c_type_string);
}
static op::LSTMWeightsFormat read_lstm_weights_format(const json& js)
{
return has_key(js, "weights_format")
? static_cast<op::LSTMWeightsFormat>(js.at("weights_format"))
: op::LSTMWeightsFormat::IFCO;
}
void ngraph::serialize(const string& path, shared_ptr<ngraph::Function> func, size_t indent)
{
ofstream out(path);
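The removed read_lstm_weights_format helper is superseded by the generic get_or_default used in the LSTMCell and LSTMSequence cases further down. A helper of that shape presumably reduces to something like this sketch (not copied from serializer.cpp; json::count stands in for the file's has_key to keep it self-contained):

#include <string>
#include "nlohmann/json.hpp"

// Return js[key] converted to T, or default_value when the key is absent, so
// JSON written before an attribute existed still deserializes cleanly.
template <typename T>
T get_or_default(const nlohmann::json& js, const std::string& key, const T& default_value)
{
    return js.count(key) ? js.at(key).get<T>() : default_value;
}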
@@ -1518,8 +1511,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::Reshape_v1:
{
const auto zero_flag = node_js.at("zero_flag").get<bool>();
node = make_shared<op::v1::Reshape>(args[0], args[1], zero_flag);
const bool special_zero = node_js.at("special_zero").get<bool>();
node = make_shared<op::v1::Reshape>(args[0], args[1], special_zero);
break;
}
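// The read key now matches the "special_zero" attribute of op::v1::Reshape
// (and, presumably, the key the serializer writes for it), so v1 Reshape
// nodes survive the round trip.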
case OP_TYPEID::DynSlice:
@@ -1709,18 +1702,31 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>();
auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides = node_js.at("data_dilation_strides").get<vector<size_t>>();
auto groups = node_js.at("groups").get<size_t>();
op::PadType pad_type = read_pad_type(node_js);
node = make_shared<op::GroupConvolution>(args[0],
args[1],
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
groups,
pad_type);
if (has_key(node_js, "groups"))
{
auto groups = node_js.at("groups").get<size_t>();
node = make_shared<op::GroupConvolution>(args[0],
args[1],
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
groups,
pad_type);
}
else
{
node = make_shared<op::GroupConvolution>(args[0],
args[1],
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
pad_type);
}
break;
}
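// A serialized GroupConvolution may omit "groups": in that case the group
// count is already encoded in the filters' shape (an extra leading groups
// dimension) and the groups-less constructor recovers it from there; see the
// has_groups_in_filters() check in the serializer change further down.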
case OP_TYPEID::GroupConvolutionBackpropData:
@@ -1791,20 +1797,38 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
auto hidden_size = node_js.at("hidden_size").get<size_t>();
auto clip = node_js.at("clip").get<float>();
auto activations = node_js.at("activations").get<vector<string>>();
auto activation_alpha = node_js.at("activation_alpha").get<vector<float>>();
auto activation_beta = node_js.at("activation_beta").get<vector<float>>();
auto activation_alpha = node_js.at("activations_alpha").get<vector<float>>();
auto activation_beta = node_js.at("activations_beta").get<vector<float>>();
auto linear_before_reset = node_js.at("linear_before_reset").get<bool>();
node = make_shared<op::GRUCell>(args[0],
args[1],
args[2],
args[3],
hidden_size,
args[4],
activations,
activation_alpha,
activation_beta,
clip,
linear_before_reset);
switch (args.size())
{
case 4:
node = make_shared<op::GRUCell>(args[0],
args[1],
args[2],
args[3],
hidden_size,
activations,
activation_alpha,
activation_beta,
clip,
linear_before_reset);
break;
case 5:
node = make_shared<op::GRUCell>(args[0],
args[1],
args[2],
args[3],
hidden_size,
args[4],
activations,
activation_alpha,
activation_beta,
clip,
linear_before_reset);
break;
default: throw runtime_error("GRUCell constructor not supported in serializer");
}
break;
}
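// The two arms select between GRUCell graphs serialized with and without the
// optional bias input (args[4]), and the attribute keys are renamed to the
// plural "activations_alpha"/"activations_beta" to match what the serializer
// writes. The same two fixes are applied to RNNCell further down.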
case OP_TYPEID::HardSigmoid:
@@ -1933,14 +1957,16 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
case OP_TYPEID::LSTMCell:
{
auto hidden_size = node_js.at("hidden_size").get<size_t>();
auto weights_format = read_lstm_weights_format(node_js);
auto weights_format = get_or_default<op::LSTMWeightsFormat>(
node_js, "weights_format", op::LSTMWeightsFormat::IFCO);
auto clip = node_js.at("clip").get<float>();
auto activations = node_js.at("activations").get<vector<string>>();
auto activations_alpha = node_js.at("activations_alpha").get<vector<float>>();
auto activations_beta = node_js.at("activations_beta").get<vector<float>>();
auto input_forget = node_js.at("input_forget").get<bool>();
if (args.size() == 7)
switch (args.size())
{
case 7:
node = make_shared<op::LSTMCell>(args[0],
args[1],
args[2],
@@ -1955,9 +1981,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
activations_beta,
clip,
input_forget);
}
if (args.size() == 6)
{
break;
case 6:
node = make_shared<op::LSTMCell>(args[0],
args[1],
args[2],
@@ -1971,9 +1996,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
activations_beta,
clip,
input_forget);
}
else
{
break;
case 5:
node = make_shared<op::LSTMCell>(args[0],
args[1],
args[2],
@@ -1986,19 +2010,22 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
activations_beta,
clip,
input_forget);
break;
default: throw runtime_error("LSTMCell constructor not supported in serializer");
}
break;
}
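// LSTMCell: "weights_format" now falls back to IFCO via get_or_default so
// older JSON without that attribute still loads, and the if/else chain
// becomes a switch over the 5-, 6- and 7-input constructor forms (the extra
// inputs being the optional bias and peephole weights), with an explicit
// error for anything else. Previously a 7-input node was rebuilt by the
// trailing else branch because the second test was an "if" rather than an
// "else if".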
case OP_TYPEID::LSTMSequence:
{
auto hidden_size = node_js.at("hidden_size").get<size_t>();
auto clip = node_js.at("clip").get<float>();
auto clip = node_js.at("clip_threshold").get<float>();
auto activations = node_js.at("activations").get<vector<string>>();
auto activations_alpha = node_js.at("activations_alpha").get<vector<float>>();
auto activations_beta = node_js.at("activations_beta").get<vector<float>>();
auto input_forget = node_js.at("input_forget").get<bool>();
auto direction = node_js.at("direction").get<op::LSTMSequence::direction>();
auto weights_format = read_lstm_weights_format(node_js);
auto weights_format = get_or_default<op::LSTMWeightsFormat>(
node_js, "weights_format", op::LSTMWeightsFormat::IFCO);
if (args.size() == 8)
{
node = make_shared<op::LSTMSequence>(args[0],
@@ -2217,9 +2244,16 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
case OP_TYPEID::MVN:
{
auto normalize_variance = node_js.at("normalize_variance").get<bool>();
auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
AxisSet reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
auto eps = node_js.at("eps").get<double>();
node = make_shared<op::MVN>(args[0], normalize_variance, normalize_variance, eps);
if (reduction_axes.size() > 0)
{
node = make_shared<op::MVN>(args[0], reduction_axes, normalize_variance, eps);
}
else
{
node = make_shared<op::MVN>(args[0], true, normalize_variance, eps);
}
break;
}
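// MVN: an empty reduction_axes set falls back to the boolean
// (across_channels) constructor. The old call also passed normalize_variance
// in the across_channels slot and ignored the serialized reduction axes
// entirely.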
case OP_TYPEID::Negative:
@@ -2583,18 +2617,35 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
auto hidden_size = node_js.at("hidden_size").get<size_t>();
auto clip = node_js.at("clip").get<float>();
auto activations = node_js.at("activations").get<vector<string>>();
auto activation_alpha = node_js.at("activation_alpha").get<vector<float>>();
auto activation_beta = node_js.at("activation_beta").get<vector<float>>();
node = make_shared<op::RNNCell>(args[0],
args[1],
args[2],
args[3],
args[4],
hidden_size,
activations,
activation_alpha,
activation_beta,
clip);
auto activation_alpha = node_js.at("activations_alpha").get<vector<float>>();
auto activation_beta = node_js.at("activations_beta").get<vector<float>>();
switch (args.size())
{
case 4:
node = make_shared<op::RNNCell>(args[0],
args[1],
args[2],
args[3],
hidden_size,
activations,
activation_alpha,
activation_beta,
clip);
break;
case 5:
node = make_shared<op::RNNCell>(args[0],
args[1],
args[2],
args[3],
args[4],
hidden_size,
activations,
activation_alpha,
activation_beta,
clip);
break;
default: throw runtime_error("GRUCell constructor not supported in serializer");
}
break;
}
case OP_TYPEID::ROIPooling: { break;
@@ -2894,23 +2945,26 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
{
auto compute_max = node_js.at("compute_max").get<bool>();
auto target_type = read_element_type(node_js.at("index_element_type"));
op::TopKSortType sort = node_js.at("sort").get<op::TopKSortType>();
if (has_key(node_js, "top_k_axis"))
{
auto top_k_axis = node_js.at("top_k_axis").get<size_t>();
if (has_key(node_js, "k"))
{
auto k = node_js.at("k").get<size_t>();
node = make_shared<op::TopK>(args[0], top_k_axis, target_type, k, compute_max);
node = make_shared<op::TopK>(
args[0], top_k_axis, target_type, k, compute_max, sort);
}
else
{
node = make_shared<op::TopK>(
args[0], args[1], top_k_axis, target_type, compute_max);
args[0], args[1], top_k_axis, target_type, compute_max, sort);
}
}
else
{
node = make_shared<op::TopK>(args[0], args[1], args[2], target_type, compute_max);
node = make_shared<op::TopK>(
args[0], args[1], args[2], target_type, compute_max, sort);
}
break;
}
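// TopK: the sort mode is now restored for all three serialized forms. One
// input means both k and top_k_axis were stored as attributes, two inputs
// means k arrives as an input, and three inputs means the axis does too;
// compare the switch on inputs().size() in the serializer change below.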
@@ -2987,8 +3041,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
m_node_map[node_name] = node;
}
catch (...)
catch (exception& err)
{
NGRAPH_INFO << err.what();
string node_name;
auto it = node_js.find("name");
if (it != node_js.end())
@@ -3768,7 +3823,10 @@ json JSONSerializer::serialize_node(const Node& n)
node["padding_below"] = tmp->get_padding_below();
node["padding_above"] = tmp->get_padding_above();
node["data_dilation_strides"] = tmp->get_data_dilation_strides();
node["groups"] = tmp->get_groups();
if (!tmp->has_groups_in_filters())
{
node["groups"] = tmp->get_groups();
}
node["pad_type"] = tmp->get_pad_type();
break;
}
@@ -4355,8 +4413,8 @@ json JSONSerializer::serialize_node(const Node& n)
{
auto tmp = static_cast<const op::ScalarConstantLike*>(&n);
auto constant = tmp->as_constant();
node["value"] = constant->get_value_strings()[0];
node["element_type"] = write_element_type(constant->get_element_type());
char* p_end;
node["value"] = strtod(constant->get_value_strings()[0].c_str(), &p_end);
break;
}
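// ScalarConstantLike: the value is now written as a JSON number (via strtod)
// rather than as a string, presumably so the matching deserialize case can
// read it back with a plain numeric get<>().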
case OP_TYPEID::ScaleShift: { break;
@@ -4567,6 +4625,17 @@ json JSONSerializer::serialize_node(const Node& n)
const auto tmp = static_cast<const op::TopK*>(&n);
node["index_element_type"] = write_element_type(tmp->get_index_element_type());
node["compute_max"] = tmp->get_compute_max();
node["sort"] = tmp->get_sort();
switch (tmp->inputs().size())
{
case 1:
node["k"] = tmp->get_k();
node["top_k_axis"] = tmp->get_top_k_axis();
break;
case 2: node["top_k_axis"] = tmp->get_top_k_axis(); break;
case 3: break;
default: throw runtime_error("TopK constructor not supported in serializer");
}
break;
}
......
@@ -83,7 +83,7 @@ static const map<element::Type_t, const TypeInfo>& get_type_info_map()
{element::Type_t::i16, TypeInfo(16, false, true, false, "int16_t", "i16")},
{element::Type_t::i32, TypeInfo(32, false, true, true, "int32_t", "i32")},
{element::Type_t::i64, TypeInfo(64, false, true, false, "int64_t", "i64")},
{element::Type_t::u1, TypeInfo(1, false, false, false, "uint8_t", "u1")},
{element::Type_t::u1, TypeInfo(1, false, false, false, "uint1_t", "u1")},
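// u1 now reports a distinct c-type tag ("uint1_t") instead of sharing
// "uint8_t" with u8, so an element type rebuilt from its serialized fields
// (see read_element_type above) presumably can no longer collide with u8.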
{element::Type_t::u8, TypeInfo(8, false, false, true, "uint8_t", "u8")},
{element::Type_t::u16, TypeInfo(16, false, false, false, "uint16_t", "u16")},
{element::Type_t::u32, TypeInfo(32, false, false, false, "uint32_t", "u32")},
......
@@ -408,7 +408,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_global_lp_pool_p3)
Outputs outputs{execute(function, inputs, "${BACKEND_NAME}")};
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front()));
EXPECT_TRUE(test::all_close_f(expected_outputs.front(), outputs.front(), 18));
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_convtranspose_output_shape)
......