Commit 733fbd90 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Serialize cleanup (#2792)

* don't emit empty objects

* Add method to set the option to add output shapes to json

* more cleanup
parent 80af2c7f
......@@ -122,6 +122,14 @@ using namespace std;
using json = nlohmann::json;
using const_data_callback_t = shared_ptr<Node>(const string&, const element::Type&, const Shape&);
static bool s_serialize_output_shapes_enabled =
(std::getenv("NGRAPH_SERIALIZER_OUTPUT_SHAPES") != nullptr);
void ngraph::set_serialize_output_shapes(bool enable)
{
s_serialize_output_shapes_enabled = enable;
}
// This expands the op list in op_tbl.hpp into a list of enumerations that look like this:
// Abs,
// Acos,
......@@ -430,6 +438,18 @@ static json write(const Function& f, bool binary_constant_data)
return function;
}
template <typename T>
T get_value(nlohmann::json js, const string& key)
{
T rc;
auto it = js.find(key);
if (it != js.end())
{
rc = it->get<T>();
}
return rc;
}
static shared_ptr<ngraph::Function>
read_function(const json& func_js,
unordered_map<string, shared_ptr<Function>>& function_map,
......@@ -446,20 +466,13 @@ static shared_ptr<ngraph::Function>
try
{
string node_name = node_js.at("name").get<string>();
string friendly_name;
auto it = node_js.find("friendly_name");
if (it != node_js.end())
{
friendly_name = it->get<string>();
}
string node_op = node_js.at("op").get<string>();
vector<string> node_inputs = node_js.at("inputs").get<vector<string>>();
vector<string> control_deps_inputs =
get_or_default<vector<string>>(node_js, "control_deps", vector<string>{});
vector<string> node_outputs = node_js.at("outputs").get<vector<string>>();
string friendly_name = get_value<string>(node_js, "friendly_name");
vector<string> node_inputs = get_value<vector<string>>(node_js, "inputs");
vector<string> control_deps_inputs = get_value<vector<string>>(node_js, "control_deps");
vector<string> node_outputs = get_value<vector<string>>(node_js, "outputs");
shared_ptr<Node> node;
vector<shared_ptr<Node>> args;
vector<shared_ptr<Node>> control_deps;
for (const string& name : node_inputs)
{
args.push_back(node_map.at(name));
......@@ -1084,15 +1097,11 @@ static shared_ptr<ngraph::Function>
// This is a legacy field whose functionality is no longer supported. The new
// behavior is equivalent to interior padding of 0, so we will accept it under
// those conditions.
auto padding_interior_maybe = node_js.find("padding_interior");
if (padding_interior_maybe != node_js.end())
{
auto padding_interior = padding_interior_maybe->get<vector<size_t>>();
NGRAPH_CHECK(std::all_of(padding_interior.begin(),
padding_interior.end(),
[](size_t s) { return s == 0; }),
"Legacy padding_interior field must be zero everywhere.");
}
auto padding_interior = get_value<vector<size_t>>(node_js, "padding_interior");
NGRAPH_CHECK(std::all_of(padding_interior.begin(),
padding_interior.end(),
[](size_t s) { return s == 0; }),
"Legacy padding_interior field must be zero everywhere.");
auto pad_mode = node_js.count("pad_mode") == 0
? op::PadMode::CONSTANT
......@@ -1462,11 +1471,20 @@ static json write(const Node& n, bool binary_constant_data)
outputs.push_back(output.get_tensor().get_name());
}
node["inputs"] = inputs;
node["control_deps"] = control_deps;
node["outputs"] = outputs;
if (!inputs.empty())
{
node["inputs"] = inputs;
}
if (!control_deps.empty())
{
node["control_deps"] = control_deps;
}
if (!outputs.empty())
{
node["outputs"] = outputs;
}
if (std::getenv("NGRAPH_SERIALIZER_OUTPUT_SHAPES") != nullptr)
if (s_serialize_output_shapes_enabled)
{
json output_shapes = json::array();
for (size_t i = 0; i < n.get_output_size(); ++i)
......
......@@ -55,4 +55,10 @@ namespace ngraph
/// \brief Deserialize a Function
/// \param str The json formatted string to deseriailze.
std::shared_ptr<ngraph::Function> deserialize(const std::string& str);
/// \brief If enabled adds output shapes to the serialized graph
/// \param enable Set to true to enable or false otherwise
///
/// Option may be enabled by setting the environment variable NGRAPH_SERIALIZER_OUTPUT_SHAPES
void set_serialize_output_shapes(bool enable);
}
......@@ -204,6 +204,10 @@ TEST(benchmark, serialize)
shared_ptr<Function> f = ngraph::deserialize(json_string);
timer.stop();
cout << "deserialize took " << timer.get_milliseconds() << "ms\n";
ngraph::set_serialize_output_shapes(true);
ofstream out("test.json");
out << serialize(f, 4);
}
MATCHER_P2(IsOutputShape, type, shape, "")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment