Commit 916ae6e9 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Fix serializer so it can serialize PP BERT large (#4127) (#4137)

Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 33ca594a
...@@ -114,7 +114,6 @@ public: ...@@ -114,7 +114,6 @@ public:
json serialize_output(const Output<Node>& output); json serialize_output(const Output<Node>& output);
json serialize_parameter_vector(const ParameterVector& parameters); json serialize_parameter_vector(const ParameterVector& parameters);
json serialize_output_vector(const OutputVector& output_vector); json serialize_output_vector(const OutputVector& output_vector);
json serialize_node_reference(const Node& node);
json serialize_node(const Node& node); json serialize_node(const Node& node);
json serialize_axis_set(const AxisSet& axis_set); json serialize_axis_set(const AxisSet& axis_set);
json serialize_tensor_iterator_input_description( json serialize_tensor_iterator_input_description(
...@@ -127,8 +126,6 @@ protected: ...@@ -127,8 +126,6 @@ protected:
bool m_serialize_output_shapes{false}; bool m_serialize_output_shapes{false};
bool m_binary_constant_data{false}; bool m_binary_constant_data{false};
json m_json_nodes; json m_json_nodes;
set<const Node*> m_nodes_serialized;
queue<const Node*> m_nodes_to_serialize;
}; };
class JSONDeserializer class JSONDeserializer
...@@ -444,7 +441,7 @@ json JSONSerializer::serialize_parameter_vector(const ParameterVector& parameter ...@@ -444,7 +441,7 @@ json JSONSerializer::serialize_parameter_vector(const ParameterVector& parameter
json json_parameters = json::array(); json json_parameters = json::array();
for (auto param : parameters) for (auto param : parameters)
{ {
json_parameters.push_back(serialize_node_reference(*param)); json_parameters.push_back(param->get_name());
} }
return json_parameters; return json_parameters;
} }
...@@ -458,9 +455,16 @@ json JSONSerializer::serialize_function(const Function& f) ...@@ -458,9 +455,16 @@ json JSONSerializer::serialize_function(const Function& f)
// TODO Functions can return multiple results // TODO Functions can return multiple results
for (size_t i = 0; i < f.get_output_size(); ++i) for (size_t i = 0; i < f.get_output_size(); ++i)
{ {
function["result"].push_back(serialize_node_reference(*f.get_output_op(i))); function["result"].push_back(f.get_output_op(i)->get_name());
} }
function["ops"] = m_json_nodes;
json nodes;
for (shared_ptr<Node> node : f.get_ordered_ops(true))
{
nodes.push_back(serialize_node(*node));
}
function["ops"] = nodes;
return function; return function;
} }
...@@ -2996,36 +3000,11 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -2996,36 +3000,11 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
return node; return node;
} }
json JSONSerializer::serialize_node_reference(const Node& n)
{
if (m_nodes_serialized.count(&n) != 1)
{
m_nodes_to_serialize.push(&n);
if (m_nodes_to_serialize.size() == 1)
{
// Nothing in the queue
stack<json> serialized_nodes;
while (!m_nodes_to_serialize.empty())
{
const Node* next_node = m_nodes_to_serialize.front();
m_nodes_to_serialize.pop();
serialized_nodes.push(serialize_node(*next_node));
}
while (serialized_nodes.size() > 0)
{
m_json_nodes.push_back(serialized_nodes.top());
serialized_nodes.pop();
}
}
}
return n.get_name();
}
json JSONSerializer::serialize_output(const Output<Node>& output) json JSONSerializer::serialize_output(const Output<Node>& output)
{ {
json result; json result;
auto index = output.get_index(); auto index = output.get_index();
json json_node_reference = serialize_node_reference(*output.get_node()); json json_node_reference = output.get_node()->get_name();
if (index == 0) if (index == 0)
{ {
result = json_node_reference; result = json_node_reference;
...@@ -3050,7 +3029,6 @@ json JSONSerializer::serialize_output_vector(const OutputVector& output_vector) ...@@ -3050,7 +3029,6 @@ json JSONSerializer::serialize_output_vector(const OutputVector& output_vector)
json JSONSerializer::serialize_node(const Node& n) json JSONSerializer::serialize_node(const Node& n)
{ {
m_nodes_serialized.insert(&n);
const NodeTypeInfo& type_info = n.get_type_info(); const NodeTypeInfo& type_info = n.get_type_info();
json jtype_info; json jtype_info;
jtype_info["name"] = type_info.name; jtype_info["name"] = type_info.name;
...@@ -3077,7 +3055,7 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -3077,7 +3055,7 @@ json JSONSerializer::serialize_node(const Node& n)
} }
for (auto cdep : n.get_control_dependencies()) for (auto cdep : n.get_control_dependencies())
{ {
control_deps.push_back(serialize_node_reference(*cdep)); control_deps.push_back(cdep->get_name());
} }
for (auto& output : n.outputs()) for (auto& output : n.outputs())
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment