Unverified Commit d3016b24 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Cyphers/mergeser (#3097)

* Serialize nodes by reference

* Most of deserialization

* deserialize node

* review comments
parent 778b6004
......@@ -16,6 +16,8 @@
#include <fstream>
#include <functional>
#include <queue>
#include <stack>
#include "ngraph/cpio.hpp"
#include "ngraph/file_util.hpp"
......@@ -193,13 +195,51 @@ T get_or_default(nlohmann::json& j, const std::string& key, const T& default_val
return j.count(key) != 0 ? j.at(key).get<T>() : default_value;
}
static std::shared_ptr<ngraph::Function>
read_function(const json&,
std::unordered_map<std::string, std::shared_ptr<Function>>&,
function<const_data_callback_t>);
// Serializes an ngraph Function and its graph of Nodes to JSON.
// Nodes are written exactly once (tracked by pointer in m_nodes_serialized)
// and referred to elsewhere by name; nodes discovered during serialization
// are queued in m_nodes_to_serialize and the finished JSON accumulates in
// m_json_nodes (see serialize_node_reference / serialize_node).
class JSONSerializer
{
public:
// Number of spaces used when pretty-printing the output (0 = compact).
void set_indent(size_t indent) { m_indent = indent; }
// When true, each node's output shapes are also emitted.
// NOTE(review): the flag is only stored here; where it is consumed is not
// visible in this chunk — confirm in serialize_node.
void set_serialize_output_shapes(bool serialize_output_shapes)
{
m_serialize_output_shapes = serialize_output_shapes;
}
// When true, Constant tensor data is stored in binary form (e.g. CPIO)
// instead of inline JSON.
void set_binary_constant_data(bool binary_constant_data)
{
m_binary_constant_data = binary_constant_data;
}
// Serializes a whole Function: its parameters, results, and all ops.
json serialize_function(Function& function);
// Returns a by-name JSON reference to `node`, scheduling the node for
// full serialization if it has not been serialized yet.
json serialize_node_reference(Node& node);
// Serializes one node in full (name, op type, inputs, attributes).
json serialize_node(Node& node);
protected:
size_t m_indent{0};
bool m_serialize_output_shapes{false};
bool m_binary_constant_data{false};
// Accumulates the serialized "ops" array for the current function.
json m_json_nodes;
// Nodes already emitted; membership prevents double serialization.
set<Node*> m_nodes_serialized;
// Work list of nodes referenced but not yet serialized.
queue<Node*> m_nodes_to_serialize;
};
// Reconstructs Functions and Nodes from the JSON produced by JSONSerializer.
// Each deserialized node is recorded in m_node_map under its name so that
// later by-name references in the JSON can be resolved
// (see deserialize_node_reference).
class JSONDeserializer
{
public:
// Callback used to fetch Constant tensor data stored out-of-line
// (e.g. in a CPIO archive), keyed by constant name, element type, and shape.
void set_const_data_callback(function<const_data_callback_t> const_data_callback)
{
m_const_data_callback = const_data_callback;
}
// Rebuilds a Function from its JSON description and registers it in
// m_function_map.
shared_ptr<Function> deserialize_function(json& j);
// Resolves a JSON node reference (a node name) to the corresponding
// already-deserialized Node.
shared_ptr<Node> deserialize_node_reference(json& j);
// Rebuilds a single Node from its JSON description and records it in
// m_node_map.
shared_ptr<Node> deserialize_node(json& j);
protected:
// Name -> node for every node deserialized so far.
unordered_map<string, shared_ptr<Node>> m_node_map;
// Name -> function for every function deserialized so far.
unordered_map<string, shared_ptr<Function>> m_function_map;
function<const_data_callback_t> m_const_data_callback;
};
static json write(const ngraph::Function&, bool binary_constant_data);
static json write(const ngraph::Node&, bool binary_constant_data);
static string
serialize(shared_ptr<ngraph::Function> func, size_t indent, bool binary_constant_data);
......@@ -355,10 +395,15 @@ static void serialize_to_cpio(ostream& out, shared_ptr<ngraph::Function> func, s
}
#endif
static string serialize(shared_ptr<ngraph::Function> func, size_t indent, bool binary_constant_data)
static string serialize(shared_ptr<Function> func, size_t indent, bool binary_constant_data)
{
JSONSerializer serializer;
serializer.set_binary_constant_data(binary_constant_data);
serializer.set_indent(indent);
serializer.set_serialize_output_shapes(s_serialize_output_shapes_enabled);
json j;
j.push_back(write(*func, binary_constant_data));
j.push_back(serializer.serialize_function(*func));
string rc;
if (indent == 0)
......@@ -393,12 +438,8 @@ shared_ptr<ngraph::Function> ngraph::deserialize(istream& in)
string jstr(data, size);
delete[] data;
json js = json::parse(jstr);
unordered_map<string, shared_ptr<Function>> function_map;
for (json func : js)
{
shared_ptr<Function> f = read_function(
func,
function_map,
JSONDeserializer deserializer;
deserializer.set_const_data_callback(
[&](const string& const_name, const element::Type& et, const Shape& shape) {
shared_ptr<Node> const_node;
for (const cpio::FileInfo& info : file_info)
......@@ -414,7 +455,9 @@ shared_ptr<ngraph::Function> ngraph::deserialize(istream& in)
}
return const_node;
});
rc = f;
for (json func : js)
{
rc = deserializer.deserialize_function(func);
}
}
}
......@@ -440,18 +483,17 @@ shared_ptr<ngraph::Function> ngraph::deserialize(const string& s)
else
{
json js = json::parse(s);
unordered_map<string, shared_ptr<Function>> function_map;
JSONDeserializer deserializer;
for (json func : js)
{
shared_ptr<Function> f = read_function(func, function_map, nullptr);
rc = f;
rc = deserializer.deserialize_function(func);
}
}
return rc;
}
static json write(const Function& f, bool binary_constant_data)
json JSONSerializer::serialize_function(Function& f)
{
json function;
function["name"] = f.get_name();
......@@ -459,24 +501,16 @@ static json write(const Function& f, bool binary_constant_data)
vector<string> parameter_list;
for (auto param : f.get_parameters())
{
parameter_list.push_back(param->get_name());
parameter_list.push_back(serialize_node_reference(*param));
}
function["parameters"] = parameter_list;
// TODO Functions can return multiple results
for (size_t i = 0; i < f.get_output_size(); ++i)
{
function["result"].push_back(f.get_output_op(i)->get_name());
function["result"].push_back(serialize_node_reference(*f.get_output_op(i)));
}
Function* pf = const_cast<Function*>(&f);
json nodes;
for (shared_ptr<Node> node : pf->get_ordered_ops(true))
{
nodes.push_back(write(*node, binary_constant_data));
}
function["ops"] = nodes;
function["ops"] = m_json_nodes;
return function;
}
......@@ -492,32 +526,74 @@ T get_value(nlohmann::json js, const string& key)
return rc;
}
static shared_ptr<ngraph::Function>
read_function(const json& func_js,
unordered_map<string, shared_ptr<Function>>& function_map,
function<const_data_callback_t> const_data_callback)
// Resolves a JSON node reference — the referenced node's name stored as a
// JSON string — to the Node previously recorded by deserialize_node.
// Throws std::out_of_range if the name has not been deserialized yet.
// Fix: removed a stray `shared_ptr<ngraph::Function> rc;` declaration that
// was dead diff residue from the removed read_function().
shared_ptr<Node> JSONDeserializer::deserialize_node_reference(json& j)
{
    const string& name = j;
    return m_node_map.at(name);
}
// Builds a Function from its JSON description: deserializes every op listed
// in "ops" (populating m_node_map), resolves the "result" and "parameters"
// node references, wraps legacy outputs in op::Result as needed, and
// registers the finished Function in m_function_map under its name.
// Fix: removed three stale pre-change declarations (duplicate
// `vector<string> func_parameters` / `func_result` and an unused local
// `node_map`) that were diff residue and redeclared the same names.
shared_ptr<Function> JSONDeserializer::deserialize_function(json& func_js)
{
    string func_name = func_js.at("name").get<string>();
    vector<json> func_parameters = func_js.at("parameters");
    vector<json> func_result = func_js.at("result");
    // Deserialize every node up front; deserialize_node fills m_node_map so
    // the parameter/result references below can be resolved.
    for (json node_js : func_js.at("ops"))
    {
        deserialize_node(node_js);
    }

    // This handles both graphs w/ `op::Result` and legacy graphs w/o it.
    // If we are dealing w/ a legacy graph, add op::Result for each output node.
    ResultVector result;
    size_t results = 0;
    for (auto& result_ref : func_result)
    {
        auto fr = deserialize_node_reference(result_ref);
        if (auto res = std::dynamic_pointer_cast<op::Result>(fr))
        {
            result.push_back(res);
            // make sure we have `op::Result` on top of all outputs
            results++;
        }
        else
        {
            result.push_back(std::make_shared<op::Result>(fr));
        }
    }

    // Either every output is an op::Result (new format) or none are (legacy);
    // a mixture means the serialized graph is corrupt.
    if (results != 0 && results != func_result.size())
    {
        throw ngraph_error(
            "Graph serialization is inconsistent. Some op::Results appear to be missing");
    }

    std::vector<std::shared_ptr<op::Parameter>> params;
    for (auto& param_ref : func_parameters)
    {
        params.push_back(
            dynamic_pointer_cast<op::Parameter>(deserialize_node_reference(param_ref)));
    }
    shared_ptr<Function> rc{make_shared<Function>(result, params, func_name)};
    m_function_map[func_name] = rc;
    return rc;
}
shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
{
shared_ptr<Node> node;
try
{
string node_name = node_js.at("name").get<string>();
string node_op = node_js.at("op").get<string>();
string friendly_name = get_value<string>(node_js, "friendly_name");
vector<string> node_inputs = get_value<vector<string>>(node_js, "inputs");
vector<string> control_deps_inputs = get_value<vector<string>>(node_js, "control_deps");
vector<json> node_inputs = get_value<vector<json>>(node_js, "inputs");
vector<json> control_deps_inputs = get_value<vector<json>>(node_js, "control_deps");
vector<string> node_outputs = get_value<vector<string>>(node_js, "outputs");
shared_ptr<Node> node;
vector<shared_ptr<Node>> args;
for (const string& name : node_inputs)
for (auto& node_input : node_inputs)
{
args.push_back(node_map.at(name));
args.push_back(deserialize_node_reference(node_input));
}
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic push
......@@ -539,8 +615,7 @@ static shared_ptr<ngraph::Function>
}
case OP_TYPEID::Add:
{
node =
make_shared<op::Add>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
node = make_shared<op::Add>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break;
}
case OP_TYPEID::All:
......@@ -556,8 +631,7 @@ static shared_ptr<ngraph::Function>
}
case OP_TYPEID::And:
{
node =
make_shared<op::And>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
node = make_shared<op::And>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break;
}
case OP_TYPEID::Any:
......@@ -813,8 +887,7 @@ static shared_ptr<ngraph::Function>
node_js.at("window_dilation_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>();
auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides =
node_js.at("data_dilation_strides").get<vector<size_t>>();
auto data_dilation_strides = node_js.at("data_dilation_strides").get<vector<size_t>>();
node = make_shared<op::ConvolutionBias>(args[0],
args[1],
......@@ -834,8 +907,7 @@ static shared_ptr<ngraph::Function>
node_js.at("window_dilation_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>();
auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides =
node_js.at("data_dilation_strides").get<vector<size_t>>();
auto data_dilation_strides = node_js.at("data_dilation_strides").get<vector<size_t>>();
node = make_shared<op::ConvolutionBiasAdd>(args[0],
args[1],
......@@ -862,8 +934,8 @@ static shared_ptr<ngraph::Function>
node_js.at("padding_above_forward").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides_forward =
node_js.at("data_dilation_strides_forward").get<vector<size_t>>();
node = make_shared<op::ConvolutionBiasBackpropFiltersBias>(
args[0],
node =
make_shared<op::ConvolutionBiasBackpropFiltersBias>(args[0],
filters_shape,
bias_shape,
args[1],
......@@ -955,8 +1027,7 @@ static shared_ptr<ngraph::Function>
}
case OP_TYPEID::Equal:
{
node =
make_shared<op::Equal>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
node = make_shared<op::Equal>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break;
}
case OP_TYPEID::Erf:
......@@ -972,8 +1043,8 @@ static shared_ptr<ngraph::Function>
case OP_TYPEID::FakeQuantize:
{
size_t levels = node_js.at("levels").get<size_t>();
node = make_shared<op::FakeQuantize>(
args[0], args[1], args[2], args[3], args[4], levels);
node =
make_shared<op::FakeQuantize>(args[0], args[1], args[2], args[3], args[4], levels);
break;
}
case OP_TYPEID::Floor:
......@@ -998,8 +1069,7 @@ static shared_ptr<ngraph::Function>
auto beta = node_js.at("beta").get<double>();
auto transA = node_js.at("transA").get<bool>();
auto transB = node_js.at("transB").get<bool>();
node =
make_shared<op::Gemm>(args[0], args[1], args[2], alpha, beta, transA, transB);
node = make_shared<op::Gemm>(args[0], args[1], args[2], alpha, beta, transA, transB);
break;
}
case OP_TYPEID::GenerateMask:
......@@ -1021,14 +1091,14 @@ static shared_ptr<ngraph::Function>
}
case OP_TYPEID::Greater:
{
node = make_shared<op::Greater>(
args[0], args[1], read_auto_broadcast(node_js["autob"]));
node =
make_shared<op::Greater>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break;
}
case OP_TYPEID::GreaterEq:
{
node = make_shared<op::GreaterEq>(
args[0], args[1], read_auto_broadcast(node_js["autob"]));
node =
make_shared<op::GreaterEq>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break;
}
case OP_TYPEID::GRN:
......@@ -1052,8 +1122,7 @@ static shared_ptr<ngraph::Function>
node_js.at("window_dilation_strides").get<vector<size_t>>();
auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>();
auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides =
node_js.at("data_dilation_strides").get<vector<size_t>>();
auto data_dilation_strides = node_js.at("data_dilation_strides").get<vector<size_t>>();
auto groups = node_js.at("groups").get<size_t>();
op::PadType pad_type = node_js["pad_type"].empty()
......@@ -1103,14 +1172,12 @@ static shared_ptr<ngraph::Function>
}
case OP_TYPEID::Less:
{
node =
make_shared<op::Less>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
node = make_shared<op::Less>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break;
}
case OP_TYPEID::LessEq:
{
node = make_shared<op::LessEq>(
args[0], args[1], read_auto_broadcast(node_js["autob"]));
node = make_shared<op::LessEq>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break;
}
case OP_TYPEID::Log:
......@@ -1225,8 +1292,8 @@ static shared_ptr<ngraph::Function>
}
case OP_TYPEID::Maximum:
{
node = make_shared<op::Maximum>(
args[0], args[1], read_auto_broadcast(node_js["autob"]));
node =
make_shared<op::Maximum>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break;
}
case OP_TYPEID::Min:
......@@ -1237,14 +1304,14 @@ static shared_ptr<ngraph::Function>
}
case OP_TYPEID::Minimum:
{
node = make_shared<op::Minimum>(
args[0], args[1], read_auto_broadcast(node_js["autob"]));
node =
make_shared<op::Minimum>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break;
}
case OP_TYPEID::Multiply:
{
node = make_shared<op::Multiply>(
args[0], args[1], read_auto_broadcast(node_js["autob"]));
node =
make_shared<op::Multiply>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break;
}
case OP_TYPEID::MVN:
......@@ -1265,14 +1332,14 @@ static shared_ptr<ngraph::Function>
bool across_spatial = node_js.at("across_spatial").get<bool>();
bool channel_shared = node_js.at("channel_shared").get<bool>();
float eps = node_js.at("eps").get<float>();
node = make_shared<op::Normalize>(
args[0], args[1], across_spatial, channel_shared, eps);
node =
make_shared<op::Normalize>(args[0], args[1], across_spatial, channel_shared, eps);
break;
}
case OP_TYPEID::NotEqual:
{
node = make_shared<op::NotEqual>(
args[0], args[1], read_auto_broadcast(node_js["autob"]));
node =
make_shared<op::NotEqual>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break;
}
case OP_TYPEID::Not:
......@@ -1310,8 +1377,7 @@ static shared_ptr<ngraph::Function>
? op::PadMode::CONSTANT
: static_cast<op::PadMode>(node_js.at("pad_mode"));
node =
make_shared<op::Pad>(args[0], args[1], padding_below, padding_above, pad_mode);
node = make_shared<op::Pad>(args[0], args[1], padding_below, padding_above, pad_mode);
break;
}
case OP_TYPEID::Parameter:
......@@ -1321,8 +1387,7 @@ static shared_ptr<ngraph::Function>
auto element_type = read_element_type(type_node_js.at("element_type"));
auto shape = type_node_js.at("shape");
auto cacheable = get_or_default<bool>(node_js, "cacheable", false);
node =
make_shared<op::Parameter>(element_type, read_partial_shape(shape), cacheable);
node = make_shared<op::Parameter>(element_type, read_partial_shape(shape), cacheable);
break;
}
case OP_TYPEID::Passthrough:
......@@ -1343,8 +1408,7 @@ static shared_ptr<ngraph::Function>
}
case OP_TYPEID::Power:
{
node =
make_shared<op::Power>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
node = make_shared<op::Power>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break;
}
case OP_TYPEID::PRelu:
......@@ -1400,8 +1464,7 @@ static shared_ptr<ngraph::Function>
auto padding_below = node_js.at("padding_below").get<vector<std::ptrdiff_t>>();
auto padding_above = node_js.at("padding_above").get<vector<std::ptrdiff_t>>();
auto data_dilation_strides = node_js["data_dilation_strides"];
node =
make_shared<op::Convolution>(args[0],
node = make_shared<op::Convolution>(args[0],
args[1],
window_movement_strides,
window_dilation_strides,
......@@ -1473,8 +1536,7 @@ static shared_ptr<ngraph::Function>
{
auto batch_axis = node_js.at("batch_axis").get<size_t>();
auto sequence_axis = node_js.at("sequence_axis").get<size_t>();
node =
make_shared<op::ReverseSequence>(args[0], args[1], batch_axis, sequence_axis);
node = make_shared<op::ReverseSequence>(args[0], args[1], batch_axis, sequence_axis);
break;
}
case OP_TYPEID::ScalarConstantLike:
......@@ -1584,8 +1646,8 @@ static shared_ptr<ngraph::Function>
}
case OP_TYPEID::Subtract:
{
node = make_shared<op::Subtract>(
args[0], args[1], read_auto_broadcast(node_js["autob"]));
node =
make_shared<op::Subtract>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
break;
}
case OP_TYPEID::Sum:
......@@ -1644,9 +1706,9 @@ static shared_ptr<ngraph::Function>
#pragma GCC diagnostic pop
#endif
for (const string& name : control_deps_inputs)
for (auto& control_dep : control_deps_inputs)
{
node->add_control_dependency(node_map.at(name));
node->add_control_dependency(deserialize_node_reference(control_dep));
}
if (!friendly_name.empty())
......@@ -1657,7 +1719,7 @@ static shared_ptr<ngraph::Function>
{
node->set_friendly_name(node_name);
}
node_map[node_name] = node;
m_node_map[node_name] = node;
}
catch (...)
{
......@@ -1673,47 +1735,37 @@ static shared_ptr<ngraph::Function>
}
throw runtime_error("Error parsing json at node '" + node_name + "'");
}
}
return node;
}
// This handles both graphs w/ `op::Result` and legacy graphs w/o it
// If we are dealing w/ a legacy graph, add op::Result for each output node
ResultVector result;
size_t results = 0;
for (auto result_name : func_result)
// Returns the JSON reference for `n` (its name), scheduling `n` for full
// serialization if it has not been serialized yet. When the queue was empty
// before this call, the whole queue is drained here: serialize_node may
// enqueue further nodes (it calls serialize_node_reference on its inputs),
// so this loop handles the transitive closure iteratively instead of
// recursing. The intermediate stack reverses queue order before appending
// to m_json_nodes, so nodes discovered later (inputs of earlier nodes)
// appear earlier in the output.
// Fix: removed interleaved deleted lines from the old read_function()
// (references to undefined `node_map`, `result`, `results`, `func_result`)
// that were diff residue and broke compilation.
json JSONSerializer::serialize_node_reference(Node& n)
{
    if (m_nodes_serialized.count(&n) != 1)
    {
        m_nodes_to_serialize.push(&n);
        if (m_nodes_to_serialize.size() == 1)
        {
            // Nothing was in the queue before this node; drain it here.
            stack<json> serialized_nodes;
            while (!m_nodes_to_serialize.empty())
            {
                Node* next_node = m_nodes_to_serialize.front();
                m_nodes_to_serialize.pop();
                serialized_nodes.push(serialize_node(*next_node));
            }
            while (serialized_nodes.size() > 0)
            {
                m_json_nodes.push_back(serialized_nodes.top());
                serialized_nodes.pop();
            }
        }
    }
    return n.get_name();
}
static json write(const Node& n, bool binary_constant_data)
json JSONSerializer::serialize_node(Node& n)
{
m_nodes_serialized.insert(&n);
json node;
node["name"] = n.get_name();
if (n.get_name() != n.get_friendly_name())
......@@ -1728,11 +1780,11 @@ static json write(const Node& n, bool binary_constant_data)
for (auto& input : n.inputs())
{
inputs.push_back(input.get_source_output().get_node()->get_name());
inputs.push_back(serialize_node_reference(*input.get_source_output().get_node()));
}
for (auto cdep : n.get_control_dependencies())
{
control_deps.push_back(cdep->get_name());
control_deps.push_back(serialize_node_reference(*cdep));
}
for (auto& output : n.outputs())
{
......@@ -2538,6 +2590,5 @@ static json write(const Node& n, bool binary_constant_data)
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop
#endif
return node;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment