Unverified Commit d3016b24 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Cyphers/mergeser (#3097)

* Serialize nodes by reference

* Most of deserialization

* deserialize node

* review comments
parent 778b6004
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
#include <fstream> #include <fstream>
#include <functional> #include <functional>
#include <queue>
#include <stack>
#include "ngraph/cpio.hpp" #include "ngraph/cpio.hpp"
#include "ngraph/file_util.hpp" #include "ngraph/file_util.hpp"
...@@ -193,13 +195,51 @@ T get_or_default(nlohmann::json& j, const std::string& key, const T& default_val ...@@ -193,13 +195,51 @@ T get_or_default(nlohmann::json& j, const std::string& key, const T& default_val
return j.count(key) != 0 ? j.at(key).get<T>() : default_value; return j.count(key) != 0 ? j.at(key).get<T>() : default_value;
} }
static std::shared_ptr<ngraph::Function> class JSONSerializer
read_function(const json&, {
std::unordered_map<std::string, std::shared_ptr<Function>>&, public:
function<const_data_callback_t>); void set_indent(size_t indent) { m_indent = indent; }
void set_serialize_output_shapes(bool serialize_output_shapes)
{
m_serialize_output_shapes = serialize_output_shapes;
}
void set_binary_constant_data(bool binary_constant_data)
{
m_binary_constant_data = binary_constant_data;
}
json serialize_function(Function& function);
json serialize_node_reference(Node& node);
json serialize_node(Node& node);
protected:
size_t m_indent{0};
bool m_serialize_output_shapes{false};
bool m_binary_constant_data{false};
json m_json_nodes;
set<Node*> m_nodes_serialized;
queue<Node*> m_nodes_to_serialize;
};
// Rebuilds an ngraph graph from its JSON form. Deserialized nodes are recorded
// in m_node_map by name so later references can be resolved.
class JSONDeserializer
{
public:
    // Callback used to obtain Constant tensor data stored outside the JSON
    // (by name, element type, and shape); may be left unset when all data is inline.
    void set_const_data_callback(function<const_data_callback_t> const_data_callback)
    {
        m_const_data_callback = const_data_callback;
    }

    // Reconstruct a Function (its ops, parameters, and results) from JSON.
    shared_ptr<Function> deserialize_function(json& j);

    // Resolve a serialized node reference (a name string) to the node object,
    // which must already have been deserialized.
    shared_ptr<Node> deserialize_node_reference(json& j);

    // Reconstruct a single node from its JSON object and register it by name.
    shared_ptr<Node> deserialize_node(json& j);

protected:
    unordered_map<string, shared_ptr<Node>> m_node_map;         // node name -> node
    unordered_map<string, shared_ptr<Function>> m_function_map; // function name -> function
    function<const_data_callback_t> m_const_data_callback;
};
static json write(const ngraph::Function&, bool binary_constant_data);
static json write(const ngraph::Node&, bool binary_constant_data);
static string static string
serialize(shared_ptr<ngraph::Function> func, size_t indent, bool binary_constant_data); serialize(shared_ptr<ngraph::Function> func, size_t indent, bool binary_constant_data);
...@@ -355,10 +395,15 @@ static void serialize_to_cpio(ostream& out, shared_ptr<ngraph::Function> func, s ...@@ -355,10 +395,15 @@ static void serialize_to_cpio(ostream& out, shared_ptr<ngraph::Function> func, s
} }
#endif #endif
static string serialize(shared_ptr<ngraph::Function> func, size_t indent, bool binary_constant_data) static string serialize(shared_ptr<Function> func, size_t indent, bool binary_constant_data)
{ {
JSONSerializer serializer;
serializer.set_binary_constant_data(binary_constant_data);
serializer.set_indent(indent);
serializer.set_serialize_output_shapes(s_serialize_output_shapes_enabled);
json j; json j;
j.push_back(write(*func, binary_constant_data)); j.push_back(serializer.serialize_function(*func));
string rc; string rc;
if (indent == 0) if (indent == 0)
...@@ -393,12 +438,8 @@ shared_ptr<ngraph::Function> ngraph::deserialize(istream& in) ...@@ -393,12 +438,8 @@ shared_ptr<ngraph::Function> ngraph::deserialize(istream& in)
string jstr(data, size); string jstr(data, size);
delete[] data; delete[] data;
json js = json::parse(jstr); json js = json::parse(jstr);
unordered_map<string, shared_ptr<Function>> function_map; JSONDeserializer deserializer;
for (json func : js) deserializer.set_const_data_callback(
{
shared_ptr<Function> f = read_function(
func,
function_map,
[&](const string& const_name, const element::Type& et, const Shape& shape) { [&](const string& const_name, const element::Type& et, const Shape& shape) {
shared_ptr<Node> const_node; shared_ptr<Node> const_node;
for (const cpio::FileInfo& info : file_info) for (const cpio::FileInfo& info : file_info)
...@@ -414,7 +455,9 @@ shared_ptr<ngraph::Function> ngraph::deserialize(istream& in) ...@@ -414,7 +455,9 @@ shared_ptr<ngraph::Function> ngraph::deserialize(istream& in)
} }
return const_node; return const_node;
}); });
rc = f; for (json func : js)
{
rc = deserializer.deserialize_function(func);
} }
} }
} }
...@@ -440,18 +483,17 @@ shared_ptr<ngraph::Function> ngraph::deserialize(const string& s) ...@@ -440,18 +483,17 @@ shared_ptr<ngraph::Function> ngraph::deserialize(const string& s)
else else
{ {
json js = json::parse(s); json js = json::parse(s);
unordered_map<string, shared_ptr<Function>> function_map; JSONDeserializer deserializer;
for (json func : js) for (json func : js)
{ {
shared_ptr<Function> f = read_function(func, function_map, nullptr); rc = deserializer.deserialize_function(func);
rc = f;
} }
} }
return rc; return rc;
} }
static json write(const Function& f, bool binary_constant_data) json JSONSerializer::serialize_function(Function& f)
{ {
json function; json function;
function["name"] = f.get_name(); function["name"] = f.get_name();
...@@ -459,24 +501,16 @@ static json write(const Function& f, bool binary_constant_data) ...@@ -459,24 +501,16 @@ static json write(const Function& f, bool binary_constant_data)
vector<string> parameter_list; vector<string> parameter_list;
for (auto param : f.get_parameters()) for (auto param : f.get_parameters())
{ {
parameter_list.push_back(param->get_name()); parameter_list.push_back(serialize_node_reference(*param));
} }
function["parameters"] = parameter_list; function["parameters"] = parameter_list;
// TODO Functions can return multiple results // TODO Functions can return multiple results
for (size_t i = 0; i < f.get_output_size(); ++i) for (size_t i = 0; i < f.get_output_size(); ++i)
{ {
function["result"].push_back(f.get_output_op(i)->get_name()); function["result"].push_back(serialize_node_reference(*f.get_output_op(i)));
} }
function["ops"] = m_json_nodes;
Function* pf = const_cast<Function*>(&f);
json nodes;
for (shared_ptr<Node> node : pf->get_ordered_ops(true))
{
nodes.push_back(write(*node, binary_constant_data));
}
function["ops"] = nodes;
return function; return function;
} }
...@@ -492,32 +526,74 @@ T get_value(nlohmann::json js, const string& key) ...@@ -492,32 +526,74 @@ T get_value(nlohmann::json js, const string& key)
return rc; return rc;
} }
// Resolve a serialized node reference — simply the node's name — to the node
// object. Throws std::out_of_range if the name is unknown, which indicates the
// serialized "ops" list was not in dependency order or is inconsistent.
shared_ptr<Node> JSONDeserializer::deserialize_node_reference(json& j)
{
    const string& name = j;
    return m_node_map.at(name);
}
// Reconstruct a Function from JSON: deserialize every op, then resolve the
// recorded result and parameter references into a new Function object.
// The Function is also registered in m_function_map under its name.
shared_ptr<Function> JSONDeserializer::deserialize_function(json& func_js)
{
    string func_name = func_js.at("name").get<string>();
    vector<json> func_parameters = func_js.at("parameters");
    vector<json> func_result = func_js.at("result");
    for (json node_js : func_js.at("ops"))
    {
        deserialize_node(node_js);
    }

    // This handles both graphs w/ `op::Result` and legacy graphs w/o it.
    // If we are dealing w/ a legacy graph, add op::Result for each output node.
    ResultVector result;
    size_t results = 0;
    for (auto& result_ref : func_result)
    {
        auto fr = deserialize_node_reference(result_ref);
        if (auto res = std::dynamic_pointer_cast<op::Result>(fr))
        {
            result.push_back(res);
            // make sure we have `op::Result` on top of all outputs
            results++;
        }
        else
        {
            result.push_back(std::make_shared<op::Result>(fr));
        }
    }

    // Either every output is an op::Result (modern graphs) or none is (legacy);
    // a mixture means the serialized graph is corrupt.
    if (results != 0 && results != func_result.size())
    {
        throw ngraph_error(
            "Graph serialization is inconsistent. Some op::Results appear to be missing");
    }

    std::vector<std::shared_ptr<op::Parameter>> params;
    for (auto& param_ref : func_parameters)
    {
        params.push_back(
            dynamic_pointer_cast<op::Parameter>(deserialize_node_reference(param_ref)));
    }
    shared_ptr<Function> rc{make_shared<Function>(result, params, func_name)};
    m_function_map[func_name] = rc;
    return rc;
}
shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
{
shared_ptr<Node> node;
try try
{ {
string node_name = node_js.at("name").get<string>(); string node_name = node_js.at("name").get<string>();
string node_op = node_js.at("op").get<string>(); string node_op = node_js.at("op").get<string>();
string friendly_name = get_value<string>(node_js, "friendly_name"); string friendly_name = get_value<string>(node_js, "friendly_name");
vector<string> node_inputs = get_value<vector<string>>(node_js, "inputs"); vector<json> node_inputs = get_value<vector<json>>(node_js, "inputs");
vector<string> control_deps_inputs = get_value<vector<string>>(node_js, "control_deps"); vector<json> control_deps_inputs = get_value<vector<json>>(node_js, "control_deps");
vector<string> node_outputs = get_value<vector<string>>(node_js, "outputs"); vector<string> node_outputs = get_value<vector<string>>(node_js, "outputs");
shared_ptr<Node> node;
vector<shared_ptr<Node>> args; vector<shared_ptr<Node>> args;
for (const string& name : node_inputs) for (auto& node_input : node_inputs)
{ {
args.push_back(node_map.at(name)); args.push_back(deserialize_node_reference(node_input));
} }
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8) #if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic push #pragma GCC diagnostic push
...@@ -539,8 +615,7 @@ static shared_ptr<ngraph::Function> ...@@ -539,8 +615,7 @@ static shared_ptr<ngraph::Function>
} }
case OP_TYPEID::Add: case OP_TYPEID::Add:
{ {
node = node = make_shared<op::Add>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
make_shared<op::Add>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::All: case OP_TYPEID::All:
...@@ -556,8 +631,7 @@ static shared_ptr<ngraph::Function> ...@@ -556,8 +631,7 @@ static shared_ptr<ngraph::Function>
} }
case OP_TYPEID::And: case OP_TYPEID::And:
{ {
node = node = make_shared<op::And>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
make_shared<op::And>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::Any: case OP_TYPEID::Any:
...@@ -813,8 +887,7 @@ static shared_ptr<ngraph::Function> ...@@ -813,8 +887,7 @@ static shared_ptr<ngraph::Function>
node_js.at("window_dilation_strides").get<vector<size_t>>(); node_js.at("window_dilation_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>(); auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>();
auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>(); auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides = auto data_dilation_strides = node_js.at("data_dilation_strides").get<vector<size_t>>();
node_js.at("data_dilation_strides").get<vector<size_t>>();
node = make_shared<op::ConvolutionBias>(args[0], node = make_shared<op::ConvolutionBias>(args[0],
args[1], args[1],
...@@ -834,8 +907,7 @@ static shared_ptr<ngraph::Function> ...@@ -834,8 +907,7 @@ static shared_ptr<ngraph::Function>
node_js.at("window_dilation_strides").get<vector<size_t>>(); node_js.at("window_dilation_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>(); auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>();
auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>(); auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides = auto data_dilation_strides = node_js.at("data_dilation_strides").get<vector<size_t>>();
node_js.at("data_dilation_strides").get<vector<size_t>>();
node = make_shared<op::ConvolutionBiasAdd>(args[0], node = make_shared<op::ConvolutionBiasAdd>(args[0],
args[1], args[1],
...@@ -862,8 +934,8 @@ static shared_ptr<ngraph::Function> ...@@ -862,8 +934,8 @@ static shared_ptr<ngraph::Function>
node_js.at("padding_above_forward").get<vector<std::ptrdiff_t>>(); node_js.at("padding_above_forward").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides_forward = auto data_dilation_strides_forward =
node_js.at("data_dilation_strides_forward").get<vector<size_t>>(); node_js.at("data_dilation_strides_forward").get<vector<size_t>>();
node = make_shared<op::ConvolutionBiasBackpropFiltersBias>( node =
args[0], make_shared<op::ConvolutionBiasBackpropFiltersBias>(args[0],
filters_shape, filters_shape,
bias_shape, bias_shape,
args[1], args[1],
...@@ -955,8 +1027,7 @@ static shared_ptr<ngraph::Function> ...@@ -955,8 +1027,7 @@ static shared_ptr<ngraph::Function>
} }
case OP_TYPEID::Equal: case OP_TYPEID::Equal:
{ {
node = node = make_shared<op::Equal>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
make_shared<op::Equal>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::Erf: case OP_TYPEID::Erf:
...@@ -972,8 +1043,8 @@ static shared_ptr<ngraph::Function> ...@@ -972,8 +1043,8 @@ static shared_ptr<ngraph::Function>
case OP_TYPEID::FakeQuantize: case OP_TYPEID::FakeQuantize:
{ {
size_t levels = node_js.at("levels").get<size_t>(); size_t levels = node_js.at("levels").get<size_t>();
node = make_shared<op::FakeQuantize>( node =
args[0], args[1], args[2], args[3], args[4], levels); make_shared<op::FakeQuantize>(args[0], args[1], args[2], args[3], args[4], levels);
break; break;
} }
case OP_TYPEID::Floor: case OP_TYPEID::Floor:
...@@ -998,8 +1069,7 @@ static shared_ptr<ngraph::Function> ...@@ -998,8 +1069,7 @@ static shared_ptr<ngraph::Function>
auto beta = node_js.at("beta").get<double>(); auto beta = node_js.at("beta").get<double>();
auto transA = node_js.at("transA").get<bool>(); auto transA = node_js.at("transA").get<bool>();
auto transB = node_js.at("transB").get<bool>(); auto transB = node_js.at("transB").get<bool>();
node = node = make_shared<op::Gemm>(args[0], args[1], args[2], alpha, beta, transA, transB);
make_shared<op::Gemm>(args[0], args[1], args[2], alpha, beta, transA, transB);
break; break;
} }
case OP_TYPEID::GenerateMask: case OP_TYPEID::GenerateMask:
...@@ -1021,14 +1091,14 @@ static shared_ptr<ngraph::Function> ...@@ -1021,14 +1091,14 @@ static shared_ptr<ngraph::Function>
} }
case OP_TYPEID::Greater: case OP_TYPEID::Greater:
{ {
node = make_shared<op::Greater>( node =
args[0], args[1], read_auto_broadcast(node_js["autob"])); make_shared<op::Greater>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::GreaterEq: case OP_TYPEID::GreaterEq:
{ {
node = make_shared<op::GreaterEq>( node =
args[0], args[1], read_auto_broadcast(node_js["autob"])); make_shared<op::GreaterEq>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::GRN: case OP_TYPEID::GRN:
...@@ -1052,8 +1122,7 @@ static shared_ptr<ngraph::Function> ...@@ -1052,8 +1122,7 @@ static shared_ptr<ngraph::Function>
node_js.at("window_dilation_strides").get<vector<size_t>>(); node_js.at("window_dilation_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>(); auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>();
auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>(); auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides = auto data_dilation_strides = node_js.at("data_dilation_strides").get<vector<size_t>>();
node_js.at("data_dilation_strides").get<vector<size_t>>();
auto groups = node_js.at("groups").get<size_t>(); auto groups = node_js.at("groups").get<size_t>();
op::PadType pad_type = node_js["pad_type"].empty() op::PadType pad_type = node_js["pad_type"].empty()
...@@ -1103,14 +1172,12 @@ static shared_ptr<ngraph::Function> ...@@ -1103,14 +1172,12 @@ static shared_ptr<ngraph::Function>
} }
case OP_TYPEID::Less: case OP_TYPEID::Less:
{ {
node = node = make_shared<op::Less>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
make_shared<op::Less>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::LessEq: case OP_TYPEID::LessEq:
{ {
node = make_shared<op::LessEq>( node = make_shared<op::LessEq>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::Log: case OP_TYPEID::Log:
...@@ -1225,8 +1292,8 @@ static shared_ptr<ngraph::Function> ...@@ -1225,8 +1292,8 @@ static shared_ptr<ngraph::Function>
} }
case OP_TYPEID::Maximum: case OP_TYPEID::Maximum:
{ {
node = make_shared<op::Maximum>( node =
args[0], args[1], read_auto_broadcast(node_js["autob"])); make_shared<op::Maximum>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::Min: case OP_TYPEID::Min:
...@@ -1237,14 +1304,14 @@ static shared_ptr<ngraph::Function> ...@@ -1237,14 +1304,14 @@ static shared_ptr<ngraph::Function>
} }
case OP_TYPEID::Minimum: case OP_TYPEID::Minimum:
{ {
node = make_shared<op::Minimum>( node =
args[0], args[1], read_auto_broadcast(node_js["autob"])); make_shared<op::Minimum>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::Multiply: case OP_TYPEID::Multiply:
{ {
node = make_shared<op::Multiply>( node =
args[0], args[1], read_auto_broadcast(node_js["autob"])); make_shared<op::Multiply>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::MVN: case OP_TYPEID::MVN:
...@@ -1265,14 +1332,14 @@ static shared_ptr<ngraph::Function> ...@@ -1265,14 +1332,14 @@ static shared_ptr<ngraph::Function>
bool across_spatial = node_js.at("across_spatial").get<bool>(); bool across_spatial = node_js.at("across_spatial").get<bool>();
bool channel_shared = node_js.at("channel_shared").get<bool>(); bool channel_shared = node_js.at("channel_shared").get<bool>();
float eps = node_js.at("eps").get<float>(); float eps = node_js.at("eps").get<float>();
node = make_shared<op::Normalize>( node =
args[0], args[1], across_spatial, channel_shared, eps); make_shared<op::Normalize>(args[0], args[1], across_spatial, channel_shared, eps);
break; break;
} }
case OP_TYPEID::NotEqual: case OP_TYPEID::NotEqual:
{ {
node = make_shared<op::NotEqual>( node =
args[0], args[1], read_auto_broadcast(node_js["autob"])); make_shared<op::NotEqual>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::Not: case OP_TYPEID::Not:
...@@ -1310,8 +1377,7 @@ static shared_ptr<ngraph::Function> ...@@ -1310,8 +1377,7 @@ static shared_ptr<ngraph::Function>
? op::PadMode::CONSTANT ? op::PadMode::CONSTANT
: static_cast<op::PadMode>(node_js.at("pad_mode")); : static_cast<op::PadMode>(node_js.at("pad_mode"));
node = node = make_shared<op::Pad>(args[0], args[1], padding_below, padding_above, pad_mode);
make_shared<op::Pad>(args[0], args[1], padding_below, padding_above, pad_mode);
break; break;
} }
case OP_TYPEID::Parameter: case OP_TYPEID::Parameter:
...@@ -1321,8 +1387,7 @@ static shared_ptr<ngraph::Function> ...@@ -1321,8 +1387,7 @@ static shared_ptr<ngraph::Function>
auto element_type = read_element_type(type_node_js.at("element_type")); auto element_type = read_element_type(type_node_js.at("element_type"));
auto shape = type_node_js.at("shape"); auto shape = type_node_js.at("shape");
auto cacheable = get_or_default<bool>(node_js, "cacheable", false); auto cacheable = get_or_default<bool>(node_js, "cacheable", false);
node = node = make_shared<op::Parameter>(element_type, read_partial_shape(shape), cacheable);
make_shared<op::Parameter>(element_type, read_partial_shape(shape), cacheable);
break; break;
} }
case OP_TYPEID::Passthrough: case OP_TYPEID::Passthrough:
...@@ -1343,8 +1408,7 @@ static shared_ptr<ngraph::Function> ...@@ -1343,8 +1408,7 @@ static shared_ptr<ngraph::Function>
} }
case OP_TYPEID::Power: case OP_TYPEID::Power:
{ {
node = node = make_shared<op::Power>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
make_shared<op::Power>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::PRelu: case OP_TYPEID::PRelu:
...@@ -1400,8 +1464,7 @@ static shared_ptr<ngraph::Function> ...@@ -1400,8 +1464,7 @@ static shared_ptr<ngraph::Function>
auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>(); auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>();
auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>(); auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides = node_js["data_dilation_strides"]; auto data_dilation_strides = node_js["data_dilation_strides"];
node = node = make_shared<op::Convolution>(args[0],
make_shared<op::Convolution>(args[0],
args[1], args[1],
window_movement_strides, window_movement_strides,
window_dilation_strides, window_dilation_strides,
...@@ -1473,8 +1536,7 @@ static shared_ptr<ngraph::Function> ...@@ -1473,8 +1536,7 @@ static shared_ptr<ngraph::Function>
{ {
auto batch_axis = node_js.at("batch_axis").get<size_t>(); auto batch_axis = node_js.at("batch_axis").get<size_t>();
auto sequence_axis = node_js.at("sequence_axis").get<size_t>(); auto sequence_axis = node_js.at("sequence_axis").get<size_t>();
node = node = make_shared<op::ReverseSequence>(args[0], args[1], batch_axis, sequence_axis);
make_shared<op::ReverseSequence>(args[0], args[1], batch_axis, sequence_axis);
break; break;
} }
case OP_TYPEID::ScalarConstantLike: case OP_TYPEID::ScalarConstantLike:
...@@ -1584,8 +1646,8 @@ static shared_ptr<ngraph::Function> ...@@ -1584,8 +1646,8 @@ static shared_ptr<ngraph::Function>
} }
case OP_TYPEID::Subtract: case OP_TYPEID::Subtract:
{ {
node = make_shared<op::Subtract>( node =
args[0], args[1], read_auto_broadcast(node_js["autob"])); make_shared<op::Subtract>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break; break;
} }
case OP_TYPEID::Sum: case OP_TYPEID::Sum:
...@@ -1644,9 +1706,9 @@ static shared_ptr<ngraph::Function> ...@@ -1644,9 +1706,9 @@ static shared_ptr<ngraph::Function>
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
#endif #endif
for (const string& name : control_deps_inputs) for (auto& control_dep : control_deps_inputs)
{ {
node->add_control_dependency(node_map.at(name)); node->add_control_dependency(deserialize_node_reference(control_dep));
} }
if (!friendly_name.empty()) if (!friendly_name.empty())
...@@ -1657,7 +1719,7 @@ static shared_ptr<ngraph::Function> ...@@ -1657,7 +1719,7 @@ static shared_ptr<ngraph::Function>
{ {
node->set_friendly_name(node_name); node->set_friendly_name(node_name);
} }
node_map[node_name] = node; m_node_map[node_name] = node;
} }
catch (...) catch (...)
{ {
...@@ -1673,47 +1735,37 @@ static shared_ptr<ngraph::Function> ...@@ -1673,47 +1735,37 @@ static shared_ptr<ngraph::Function>
} }
throw runtime_error("Error parsing json at node '" + node_name + "'"); throw runtime_error("Error parsing json at node '" + node_name + "'");
} }
} return node;
}
// This handles both graphs w/ `op::Result` and legacy graphs w/o it json JSONSerializer::serialize_node_reference(Node& n)
// If we are dealing w/ a legacy graph, add op::Result for each output node {
ResultVector result; if (m_nodes_serialized.count(&n) != 1)
size_t results = 0;
for (auto result_name : func_result)
{ {
auto fr = node_map.at(result_name); m_nodes_to_serialize.push(&n);
if (auto res = std::dynamic_pointer_cast<op::Result>(fr)) if (m_nodes_to_serialize.size() == 1)
{ {
result.push_back(res); // Nothing in the queue
// make sure we have `op::Result` on top of all outputs stack<json> serialized_nodes;
results++; while (!m_nodes_to_serialize.empty())
}
else
{ {
result.push_back(std::make_shared<op::Result>(fr)); Node* next_node = m_nodes_to_serialize.front();
} m_nodes_to_serialize.pop();
serialized_nodes.push(serialize_node(*next_node));
} }
while (serialized_nodes.size() > 0)
if (results != 0 && results != func_result.size())
{ {
throw ngraph_error( m_json_nodes.push_back(serialized_nodes.top());
" Graph serialization is inconsistent. Some op::Results appear to be missing"); serialized_nodes.pop();
} }
std::vector<std::shared_ptr<op::Parameter>> params;
for (auto param_name : func_parameters)
{
params.push_back(dynamic_pointer_cast<op::Parameter>(node_map.at(param_name)));
} }
}
rc = make_shared<Function>(result, params, func_name); return n.get_name();
function_map[func_name] = rc;
return rc;
} }
static json write(const Node& n, bool binary_constant_data) json JSONSerializer::serialize_node(Node& n)
{ {
m_nodes_serialized.insert(&n);
json node; json node;
node["name"] = n.get_name(); node["name"] = n.get_name();
if (n.get_name() != n.get_friendly_name()) if (n.get_name() != n.get_friendly_name())
...@@ -1728,11 +1780,11 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1728,11 +1780,11 @@ static json write(const Node& n, bool binary_constant_data)
for (auto& input : n.inputs()) for (auto& input : n.inputs())
{ {
inputs.push_back(input.get_source_output().get_node()->get_name()); inputs.push_back(serialize_node_reference(*input.get_source_output().get_node()));
} }
for (auto cdep : n.get_control_dependencies()) for (auto cdep : n.get_control_dependencies())
{ {
control_deps.push_back(cdep->get_name()); control_deps.push_back(serialize_node_reference(*cdep));
} }
for (auto& output : n.outputs()) for (auto& output : n.outputs())
{ {
...@@ -2538,6 +2590,5 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -2538,6 +2590,5 @@ static json write(const Node& n, bool binary_constant_data)
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8)) #if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
#endif #endif
return node; return node;
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment