Commit 733fbd90 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Serialize cleanup (#2792)

* don't emit empty objects

* Add method to set the option to add output shapes to json

* more cleanup
parent 80af2c7f
...@@ -122,6 +122,14 @@ using namespace std; ...@@ -122,6 +122,14 @@ using namespace std;
using json = nlohmann::json; using json = nlohmann::json;
using const_data_callback_t = shared_ptr<Node>(const string&, const element::Type&, const Shape&); using const_data_callback_t = shared_ptr<Node>(const string&, const element::Type&, const Shape&);
static bool s_serialize_output_shapes_enabled =
(std::getenv("NGRAPH_SERIALIZER_OUTPUT_SHAPES") != nullptr);
void ngraph::set_serialize_output_shapes(bool enable)
{
s_serialize_output_shapes_enabled = enable;
}
// This expands the op list in op_tbl.hpp into a list of enumerations that look like this: // This expands the op list in op_tbl.hpp into a list of enumerations that look like this:
// Abs, // Abs,
// Acos, // Acos,
...@@ -430,6 +438,18 @@ static json write(const Function& f, bool binary_constant_data) ...@@ -430,6 +438,18 @@ static json write(const Function& f, bool binary_constant_data)
return function; return function;
} }
template <typename T>
T get_value(nlohmann::json js, const string& key)
{
T rc;
auto it = js.find(key);
if (it != js.end())
{
rc = it->get<T>();
}
return rc;
}
static shared_ptr<ngraph::Function> static shared_ptr<ngraph::Function>
read_function(const json& func_js, read_function(const json& func_js,
unordered_map<string, shared_ptr<Function>>& function_map, unordered_map<string, shared_ptr<Function>>& function_map,
...@@ -446,20 +466,13 @@ static shared_ptr<ngraph::Function> ...@@ -446,20 +466,13 @@ static shared_ptr<ngraph::Function>
try try
{ {
string node_name = node_js.at("name").get<string>(); string node_name = node_js.at("name").get<string>();
string friendly_name;
auto it = node_js.find("friendly_name");
if (it != node_js.end())
{
friendly_name = it->get<string>();
}
string node_op = node_js.at("op").get<string>(); string node_op = node_js.at("op").get<string>();
vector<string> node_inputs = node_js.at("inputs").get<vector<string>>(); string friendly_name = get_value<string>(node_js, "friendly_name");
vector<string> control_deps_inputs = vector<string> node_inputs = get_value<vector<string>>(node_js, "inputs");
get_or_default<vector<string>>(node_js, "control_deps", vector<string>{}); vector<string> control_deps_inputs = get_value<vector<string>>(node_js, "control_deps");
vector<string> node_outputs = node_js.at("outputs").get<vector<string>>(); vector<string> node_outputs = get_value<vector<string>>(node_js, "outputs");
shared_ptr<Node> node; shared_ptr<Node> node;
vector<shared_ptr<Node>> args; vector<shared_ptr<Node>> args;
vector<shared_ptr<Node>> control_deps;
for (const string& name : node_inputs) for (const string& name : node_inputs)
{ {
args.push_back(node_map.at(name)); args.push_back(node_map.at(name));
...@@ -1084,15 +1097,11 @@ static shared_ptr<ngraph::Function> ...@@ -1084,15 +1097,11 @@ static shared_ptr<ngraph::Function>
// This is a legacy field whose functionality is no longer supported. The new // This is a legacy field whose functionality is no longer supported. The new
// behavior is equivalent to interior padding of 0, so we will accept it under // behavior is equivalent to interior padding of 0, so we will accept it under
// those conditions. // those conditions.
auto padding_interior_maybe = node_js.find("padding_interior"); auto padding_interior = get_value<vector<size_t>>(node_js, "padding_interior");
if (padding_interior_maybe != node_js.end()) NGRAPH_CHECK(std::all_of(padding_interior.begin(),
{ padding_interior.end(),
auto padding_interior = padding_interior_maybe->get<vector<size_t>>(); [](size_t s) { return s == 0; }),
NGRAPH_CHECK(std::all_of(padding_interior.begin(), "Legacy padding_interior field must be zero everywhere.");
padding_interior.end(),
[](size_t s) { return s == 0; }),
"Legacy padding_interior field must be zero everywhere.");
}
auto pad_mode = node_js.count("pad_mode") == 0 auto pad_mode = node_js.count("pad_mode") == 0
? op::PadMode::CONSTANT ? op::PadMode::CONSTANT
...@@ -1462,11 +1471,20 @@ static json write(const Node& n, bool binary_constant_data) ...@@ -1462,11 +1471,20 @@ static json write(const Node& n, bool binary_constant_data)
outputs.push_back(output.get_tensor().get_name()); outputs.push_back(output.get_tensor().get_name());
} }
node["inputs"] = inputs; if (!inputs.empty())
node["control_deps"] = control_deps; {
node["outputs"] = outputs; node["inputs"] = inputs;
}
if (!control_deps.empty())
{
node["control_deps"] = control_deps;
}
if (!outputs.empty())
{
node["outputs"] = outputs;
}
if (std::getenv("NGRAPH_SERIALIZER_OUTPUT_SHAPES") != nullptr) if (s_serialize_output_shapes_enabled)
{ {
json output_shapes = json::array(); json output_shapes = json::array();
for (size_t i = 0; i < n.get_output_size(); ++i) for (size_t i = 0; i < n.get_output_size(); ++i)
......
...@@ -55,4 +55,10 @@ namespace ngraph ...@@ -55,4 +55,10 @@ namespace ngraph
/// \brief Deserialize a Function /// \brief Deserialize a Function
/// \param str The json formatted string to deseriailze. /// \param str The json formatted string to deseriailze.
std::shared_ptr<ngraph::Function> deserialize(const std::string& str); std::shared_ptr<ngraph::Function> deserialize(const std::string& str);
/// \brief If enabled adds output shapes to the serialized graph
/// \param enable Set to true to enable or false otherwise
///
/// Option may be enabled by setting the environment variable NGRAPH_SERIALIZER_OUTPUT_SHAPES
void set_serialize_output_shapes(bool enable);
} }
...@@ -204,6 +204,10 @@ TEST(benchmark, serialize) ...@@ -204,6 +204,10 @@ TEST(benchmark, serialize)
shared_ptr<Function> f = ngraph::deserialize(json_string); shared_ptr<Function> f = ngraph::deserialize(json_string);
timer.stop(); timer.stop();
cout << "deserialize took " << timer.get_milliseconds() << "ms\n"; cout << "deserialize took " << timer.get_milliseconds() << "ms\n";
ngraph::set_serialize_output_shapes(true);
ofstream out("test.json");
out << serialize(f, 4);
} }
MATCHER_P2(IsOutputShape, type, shape, "") MATCHER_P2(IsOutputShape, type, shape, "")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment