Commit 563c0c21 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Fix serializer so it can serialize PP BERT large (#4127)

Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 280d97d5
......@@ -114,7 +114,6 @@ public:
json serialize_output(const Output<Node>& output);
json serialize_parameter_vector(const ParameterVector& parameters);
json serialize_output_vector(const OutputVector& output_vector);
json serialize_node_reference(const Node& node);
json serialize_node(const Node& node);
json serialize_axis_set(const AxisSet& axis_set);
json serialize_tensor_iterator_input_description(
......@@ -127,8 +126,6 @@ protected:
bool m_serialize_output_shapes{false};
bool m_binary_constant_data{false};
json m_json_nodes;
set<const Node*> m_nodes_serialized;
queue<const Node*> m_nodes_to_serialize;
};
class JSONDeserializer
......@@ -444,7 +441,7 @@ json JSONSerializer::serialize_parameter_vector(const ParameterVector& parameter
json json_parameters = json::array();
for (auto param : parameters)
{
json_parameters.push_back(serialize_node_reference(*param));
json_parameters.push_back(param->get_name());
}
return json_parameters;
}
......@@ -458,9 +455,16 @@ json JSONSerializer::serialize_function(const Function& f)
// TODO Functions can return multiple results
for (size_t i = 0; i < f.get_output_size(); ++i)
{
function["result"].push_back(serialize_node_reference(*f.get_output_op(i)));
function["result"].push_back(f.get_output_op(i)->get_name());
}
function["ops"] = m_json_nodes;
json nodes;
for (shared_ptr<Node> node : f.get_ordered_ops(true))
{
nodes.push_back(serialize_node(*node));
}
function["ops"] = nodes;
return function;
}
......@@ -2643,6 +2647,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
args[1],
args[2],
read_auto_broadcast(node_js, "auto_broadcast", op::AutoBroadcastType::NUMPY));
break;
}
case OP_TYPEID::Stack:
{
......@@ -3001,36 +3006,11 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
return node;
}
json JSONSerializer::serialize_node_reference(const Node& n)
{
if (m_nodes_serialized.count(&n) != 1)
{
m_nodes_to_serialize.push(&n);
if (m_nodes_to_serialize.size() == 1)
{
// Nothing in the queue
stack<json> serialized_nodes;
while (!m_nodes_to_serialize.empty())
{
const Node* next_node = m_nodes_to_serialize.front();
m_nodes_to_serialize.pop();
serialized_nodes.push(serialize_node(*next_node));
}
while (serialized_nodes.size() > 0)
{
m_json_nodes.push_back(serialized_nodes.top());
serialized_nodes.pop();
}
}
}
return n.get_name();
}
json JSONSerializer::serialize_output(const Output<Node>& output)
{
json result;
auto index = output.get_index();
json json_node_reference = serialize_node_reference(*output.get_node());
json json_node_reference = output.get_node()->get_name();
if (index == 0)
{
result = json_node_reference;
......@@ -3055,7 +3035,6 @@ json JSONSerializer::serialize_output_vector(const OutputVector& output_vector)
json JSONSerializer::serialize_node(const Node& n)
{
m_nodes_serialized.insert(&n);
const NodeTypeInfo& type_info = n.get_type_info();
json jtype_info;
jtype_info["name"] = type_info.name;
......@@ -3082,7 +3061,7 @@ json JSONSerializer::serialize_node(const Node& n)
}
for (auto cdep : n.get_control_dependencies())
{
control_deps.push_back(serialize_node_reference(*cdep));
control_deps.push_back(cdep->get_name());
}
for (auto& output : n.outputs())
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment