Unverified Commit a01170fb authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

deserialize fix to return the right outermost function (#457)

* deserialize fix to return the right outermost function

* fix tests

* more switches to print out/visualize shapes
parent 328ff806
......@@ -66,12 +66,22 @@ std::string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
stringstream ss;
if (node->is_parameter())
{
ss << " " << node->get_name() << " [shape=box color=blue]\n";
ss << " " << node->get_name() << " [shape=box color=blue";
}
else
{
ss << " " << node->get_name() << " [shape=ellipse color=black]\n";
ss << " " << node->get_name() << " [shape=ellipse color=black";
}
ss << " label=\"" << node->get_name();
if (std::getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_SHAPES") != nullptr)
{
ss << " " << vector_to_string(node->get_shape());
}
ss << " \"]\n";
return ss.str();
}
......@@ -88,7 +98,14 @@ void pass::VisualizeTree::render() const
out.close();
stringstream ss;
ss << "dot -Tpng " << tmp_file << " -o " << m_name;
const char* format = std::getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_FORMAT");
if (!format)
{
format = "png";
}
ss << "dot -T" << format << " " << tmp_file << " -o " << m_name;
auto cmd = ss.str();
auto stream = popen(cmd.c_str(), "r");
pclose(stream);
......
......@@ -225,10 +225,7 @@ shared_ptr<ngraph::Function> ngraph::deserialize(const string& s)
for (json func : js)
{
shared_ptr<Function> f = read_function(func, function_map);
if (rc == nullptr)
{
rc = f;
}
rc = f;
}
return rc;
......@@ -772,6 +769,7 @@ static json write(const Node& n)
// TODO Multiple outputs
json inputs = json::array();
json outputs = json::array();
for (const descriptor::Input& input : n.get_inputs())
{
inputs.push_back(input.get_output().get_node()->get_name());
......@@ -780,9 +778,20 @@ static json write(const Node& n)
{
outputs.push_back(n.get_output_tensor(i).get_name());
}
node["inputs"] = inputs;
node["outputs"] = outputs;
if (std::getenv("NGRAPH_SERIALIZER_OUTPUT_SHAPES") != nullptr)
{
json output_shapes = json::array();
for (size_t i = 0; i < n.get_output_size(); ++i)
{
output_shapes.push_back(n.get_output_shape(i));
}
node["output_shapes"] = output_shapes;
}
string node_op = n.description();
if (node_op == "Abs")
{
......
......@@ -65,7 +65,7 @@ TEST(serialize, main)
istringstream in(js);
shared_ptr<Function> sfunc = deserialize(in);
// Now call g on some test vectors.
// Now call h on some test vectors.
auto manager = runtime::Manager::get("INTERPRETER");
auto external = manager->compile(sfunc);
auto backend = manager->allocate_backend();
......@@ -80,13 +80,13 @@ TEST(serialize, main)
auto result = backend->make_primary_tensor_view(element::f32, shape);
cf->call({x, y, z}, {result});
EXPECT_EQ((vector<float>{54, 80, 110, 144}), read_vector<float>(result));
EXPECT_EQ((vector<float>{216, 320, 440, 576}), read_vector<float>(result));
cf->call({y, x, z}, {result});
EXPECT_EQ((vector<float>{54, 80, 110, 144}), read_vector<float>(result));
EXPECT_EQ((vector<float>{216, 320, 440, 576}), read_vector<float>(result));
cf->call({x, z, y}, {result});
EXPECT_EQ((vector<float>{50, 72, 98, 128}), read_vector<float>(result));
EXPECT_EQ((vector<float>{200, 288, 392, 512}), read_vector<float>(result));
}
TEST(serialize, existing_models)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment