Commit d0f03eec authored by Scott Cyphers's avatar Scott Cyphers Committed by Robert Kimball

Serialization/deserialization of node inputs connected directly to a non-zero-index output (#3102)

parent b9dc7fa9
......@@ -209,17 +209,18 @@ public:
m_binary_constant_data = binary_constant_data;
}
json serialize_function(Function& function);
json serialize_node_reference(Node& node);
json serialize_node(Node& node);
json serialize_function(const Function& function);
json serialize_output(const Output<Node>& output);
json serialize_node_reference(const Node& node);
json serialize_node(const Node& node);
protected:
size_t m_indent{0};
bool m_serialize_output_shapes{false};
bool m_binary_constant_data{false};
json m_json_nodes;
set<Node*> m_nodes_serialized;
queue<Node*> m_nodes_to_serialize;
set<const Node*> m_nodes_serialized;
queue<const Node*> m_nodes_to_serialize;
};
class JSONDeserializer
......@@ -231,6 +232,7 @@ public:
}
shared_ptr<Function> deserialize_function(json& j);
Output<Node> deserialize_output(json& j);
shared_ptr<Node> deserialize_node_reference(json& j);
shared_ptr<Node> deserialize_node(json& j);
......@@ -493,7 +495,7 @@ shared_ptr<ngraph::Function> ngraph::deserialize(const string& s)
return rc;
}
json JSONSerializer::serialize_function(Function& f)
json JSONSerializer::serialize_function(const Function& f)
{
json function;
function["name"] = f.get_name();
......@@ -532,6 +534,27 @@ shared_ptr<Node> JSONDeserializer::deserialize_node_reference(json& j)
return m_node_map.at(name);
}
Output<Node> JSONDeserializer::deserialize_output(json& j)
{
size_t index;
json json_node_reference;
if (j.is_string())
{
json_node_reference = j;
index = 0;
}
else if (j.is_object())
{
json_node_reference = j["node"];
index = j["index"];
}
else
{
throw ngraph_error("Expected string or object an output while deserializing");
}
return Output<Node>(deserialize_node_reference(json_node_reference), index);
}
shared_ptr<Function> JSONDeserializer::deserialize_function(json& func_js)
{
string func_name = func_js.at("name").get<string>();
......@@ -579,6 +602,47 @@ shared_ptr<Function> JSONDeserializer::deserialize_function(json& func_js)
return rc;
}
// This helps with conversions to old-style shared-ptr<Node> and new-style Output&
// arguments to node constructors. Uses of OutputHelper should be replaced with Output
// when all op constructors use the new style arguments.
struct OutputHelper
{
OutputHelper(const Output<Node>& output)
: m_output(output)
{
}
operator shared_ptr<Node>() const
{
return m_output.get_index() == 0
? m_output.get_node_shared_ptr()
: make_shared<op::GetOutputElement>(m_output.get_node_shared_ptr(),
m_output.get_index());
}
operator const Output<Node>&() const { return m_output; }
Output<Node> m_output;
};
// This helps with conversions to old-style shared-ptr<Node> and new-style Output&
// arguments to node constructors. Uses of OutputVectorHelper should be replaced with OutputVector
// when all op constructors use the new style arguments.
struct OutputVectorHelper
{
const OutputHelper& operator[](size_t i) const { return m_vector[i]; }
void push_back(const Output<Node>& output) { m_vector.push_back(output); }
size_t size() const { return m_vector.size(); }
operator vector<shared_ptr<Node>>() const
{
vector<shared_ptr<Node>> result;
for (auto& o : m_vector)
{
result.push_back(o);
}
return result;
}
vector<OutputHelper> m_vector;
};
shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
{
shared_ptr<Node> node;
......@@ -590,10 +654,10 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
vector<json> node_inputs = get_value<vector<json>>(node_js, "inputs");
vector<json> control_deps_inputs = get_value<vector<json>>(node_js, "control_deps");
vector<string> node_outputs = get_value<vector<string>>(node_js, "outputs");
vector<shared_ptr<Node>> args;
OutputVectorHelper args;
for (auto& node_input : node_inputs)
{
args.push_back(deserialize_node_reference(node_input));
args.push_back(deserialize_output(node_input));
}
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic push
......@@ -1738,7 +1802,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
return node;
}
json JSONSerializer::serialize_node_reference(Node& n)
json JSONSerializer::serialize_node_reference(const Node& n)
{
if (m_nodes_serialized.count(&n) != 1)
{
......@@ -1749,7 +1813,7 @@ json JSONSerializer::serialize_node_reference(Node& n)
stack<json> serialized_nodes;
while (!m_nodes_to_serialize.empty())
{
Node* next_node = m_nodes_to_serialize.front();
const Node* next_node = m_nodes_to_serialize.front();
m_nodes_to_serialize.pop();
serialized_nodes.push(serialize_node(*next_node));
}
......@@ -1763,7 +1827,24 @@ json JSONSerializer::serialize_node_reference(Node& n)
return n.get_name();
}
json JSONSerializer::serialize_node(Node& n)
json JSONSerializer::serialize_output(const Output<Node>& output)
{
json result;
auto index = output.get_index();
json json_node_reference = serialize_node_reference(*output.get_node());
if (index == 0)
{
result = json_node_reference;
}
else
{
result["node"] = json_node_reference;
result["index"] = index;
}
return result;
}
json JSONSerializer::serialize_node(const Node& n)
{
m_nodes_serialized.insert(&n);
json node;
......@@ -1780,7 +1861,7 @@ json JSONSerializer::serialize_node(Node& n)
for (auto& input : n.inputs())
{
inputs.push_back(serialize_node_reference(*input.get_source_output().get_node()));
inputs.push_back(serialize_output(input.get_source_output()));
}
for (auto cdep : n.get_control_dependencies())
{
......
......@@ -324,3 +324,19 @@ TEST(serialize, constant_infinity_nan)
EXPECT_NE(str.find(R"(label="C)"), string::npos);
EXPECT_NE(str.find(R"(label="D)"), string::npos);
}
TEST(serialize, non_zero_node_output)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{10});
auto topk = make_shared<op::TopK>(arg, 0, element::i32, 5, true);
auto abs = make_shared<op::Abs>(Output<Node>(topk, 1));
auto result = make_shared<op::Result>(abs);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{arg});
string s = serialize(f);
shared_ptr<Function> g = deserialize(s);
auto g_result = g->get_results().at(0);
auto g_abs = g_result->input(0).get_source_output().get_node_shared_ptr();
auto topk_out = g_abs->input(0).get_source_output();
EXPECT_EQ(topk_out.get_index(), 1);
EXPECT_EQ(topk_out.get_node()->description(), "TopK");
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment