Commit 0a6f5bca authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Add replace_nodes function (#3468)

* Add replace_by_friendly_name function

* Add replace_nodes function

* Nuke replace_by_friendly_name

* Modify replace_nodes to handle parameter replacement
parent 8e7d10df
......@@ -306,3 +306,15 @@ bool Function::is_dynamic() const
}
return false;
}
void Function::replace_parameter(size_t parameter_index, const shared_ptr<op::Parameter>& parameter)
{
NGRAPH_CHECK(parameter_index < m_parameters.size(),
"replace_parameter(): Tried to replace parameter at index ",
parameter_index,
" but the function only has ",
m_parameters.size(),
" parameters.");
replace_node(m_parameters[parameter_index], parameter);
m_parameters[parameter_index] = parameter;
}
......@@ -117,6 +117,16 @@ namespace ngraph
/// \brief Returns true if any of the op's defined in the function contains partial shape
bool is_dynamic() const;
/// \brief Replace the `parameter_index`th parameter of the function with `parameter`.
///
/// All users of the `parameter_index`th parameter are redirected to `parameter`, and the
/// `parameter_index`th entry in the function parameter list is replaced with `parameter`.
///
/// \param parameter_index The index of the parameter to replace.
/// \param parameter The parameter to substitute for the `parameter_index`th parameter.
void replace_parameter(size_t parameter_index,
const std::shared_ptr<op::Parameter>& parameter);
protected:
ResultVector m_results;
ParameterVector m_parameters;
......
......@@ -28,7 +28,6 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/provenance.hpp"
......@@ -139,10 +138,12 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
throw ngraph_error("Result nodes cannot be replaced.");
}
if (target->get_users().empty())
{
throw ngraph_error("replacing an unreachable node");
}
NGRAPH_CHECK(!target->get_users().empty(),
"Attempted to replace unreachable node '",
*target,
"'. Replacement: '",
*replacement,
"'");
// Fix input/output descriptors
NGRAPH_CHECK(target->get_output_size() == replacement->get_output_size());
......@@ -179,6 +180,35 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
target->clear_control_dependents();
}
void ngraph::replace_nodes(
const std::shared_ptr<Function>& f,
const unordered_map<shared_ptr<op::Parameter>, shared_ptr<op::Parameter>>&
parameter_replacement_map,
const unordered_map<shared_ptr<Node>, shared_ptr<Node>>& body_replacement_map)
{
auto& params = f->get_parameters();
for (size_t i = 0; i < params.size(); i++)
{
if (parameter_replacement_map.count(params[i]) != 0 &&
parameter_replacement_map.at(params[i]) != params[i])
{
f->replace_parameter(i, parameter_replacement_map.at(params[i]));
}
}
for (auto& kv : body_replacement_map)
{
auto& k = kv.first;
auto& v = kv.second;
if (k != v)
{
f->replace_node(k, v);
}
}
}
// Check if all paths from X to a result go through Y
bool ngraph::is_post_dominated(Node* X, Node* Y)
{
......
......@@ -214,6 +214,34 @@ namespace ngraph
/// replace_node(N, M);
void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
/// \brief Replace multiple nodes in a function.
/// \param f Function where replacement is taking place.
/// \param parameter_replacement_map A mapping from parameter shared pointers to parameter
/// shared pointers. For each pair (k,v) in the map, parameter
/// k is replaced by parameter v, except if k==v or k is not a
/// parameter bound by f, in which case the pair (k,v) is
/// ignored.
/// \param body_replacement_map A mapping from node shared pointers to node shared pointers.
/// For each pair (k,v) in the map, node k is replaced by node v,
/// except if k==v, the pair (k,v) is ignored.
/// Note that if k is a parameter, its users will be redirected to
/// v, but k will _not_ be replaced in the function's parameter
/// list.
///
/// Limitations:
///
/// - No check is made that the replaced nodes in `parameter_replacement_map` are actually
/// among the bound parameters of `f`. (If a parameter appears in the map that is not
/// bound by `f`, it will be silently ignored.)
/// - If a parameter node appears as a key in both `parameter_replacement_map` _and_ in
/// `body_replacement_map`, behavior is unspecified.
void replace_nodes(
const std::shared_ptr<Function>& f,
const std::unordered_map<std::shared_ptr<op::Parameter>, std::shared_ptr<op::Parameter>>&
parameter_replacement_map,
const std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<Node>>&
body_replacement_map);
NodeVector find_common_args(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
/// Topological sort of nodes needed to compute root_nodes
......
......@@ -73,6 +73,7 @@ set(SRC
pass_shape_relevance.cpp
pattern.cpp
provenance.cpp
replace_node.cpp
reshape_elimination.cpp
reshape_sinking.cpp
shape.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
//
// Graph before (params in [] brackets, constants in () parens, results in {} braces):
//
// [x] [y] [z]
// \ / |
// Add (k) |
// \ / |
// Mul** |
// \ /
// Sub
// |
// {r}
//
// Param substitutions:
//
// [x] -> [x']
//
// Body substitutions:
//
// (k) -> (k')
// [y] -> (k'')
// [z] -> [x'] + **
//
// After replacement:
//
// [x']---------
// | |
// | (k'') | [z] and [y] is still there, but dead
// \ / |
// Add (k') |
// \ / |
// Mul |
// \ /
// Sub ***
// |
// {r}
//
TEST(replace_node, replace_nodes)
{
auto x = make_shared<op::Parameter>(element::f32, Shape{2});
auto y = make_shared<op::Parameter>(element::f32, Shape{2});
auto z = make_shared<op::Parameter>(element::f32, Shape{2});
auto add = x + y;
auto k = make_shared<op::Constant>(element::f32, Shape{2}, vector<float>{1, 2});
auto mul = add * k;
auto sub = mul - z;
auto f = make_shared<Function>(NodeVector{sub}, ParameterVector{x, y, z});
unordered_map<shared_ptr<op::Parameter>, shared_ptr<op::Parameter>> parameter_replacement_map;
auto x_replacement = make_shared<op::Parameter>(element::f32, Shape{2});
parameter_replacement_map[x] = x_replacement;
unordered_map<shared_ptr<Node>, shared_ptr<Node>> body_replacement_map;
auto y_replacement = make_shared<op::Constant>(element::f32, Shape{2}, vector<float>{3, 4});
auto k_replacement = make_shared<op::Constant>(element::f32, Shape{2}, vector<float>{5, 6});
auto z_replacement = x_replacement + mul;
body_replacement_map[y] = y_replacement;
body_replacement_map[k] = k_replacement;
body_replacement_map[z] = z_replacement;
replace_nodes(f, parameter_replacement_map, body_replacement_map);
// Should still have three params.
ASSERT_EQ(f->get_parameters().size(), 3);
// The three params be {x_replacement, y, z}.
ASSERT_EQ(f->get_parameters()[0], x_replacement);
ASSERT_EQ(f->get_parameters()[1], y);
ASSERT_EQ(f->get_parameters()[2], z);
// y, z should be dead.
ASSERT_EQ(y->get_users(true).size(), 0);
ASSERT_EQ(z->get_users(true).size(), 0);
// Should still have one result.
ASSERT_EQ(f->get_results().size(), 1);
// Result node should be sub (unchanged).
ASSERT_EQ(f->get_results()[0]->input(0).get_source_output().get_node_shared_ptr(), sub);
// sub's arguments should be mul (unchanged) and z_replacement.
ASSERT_EQ(sub->input(0).get_source_output().get_node_shared_ptr(), mul);
ASSERT_EQ(sub->input(1).get_source_output().get_node_shared_ptr(), z_replacement);
// mul's arguments should be add (unchanged) and k_replacement.
ASSERT_EQ(mul->input(0).get_source_output().get_node_shared_ptr(), add);
ASSERT_EQ(mul->input(1).get_source_output().get_node_shared_ptr(), k_replacement);
// add's arguments should be x_replacement and y_replacement.
ASSERT_EQ(add->input(0).get_source_output().get_node_shared_ptr(), x_replacement);
ASSERT_EQ(add->input(1).get_source_output().get_node_shared_ptr(), y_replacement);
// z_replacement's arguments should be x_replacement and mul.
ASSERT_EQ(z_replacement->input(0).get_source_output().get_node_shared_ptr(), x_replacement);
ASSERT_EQ(z_replacement->input(1).get_source_output().get_node_shared_ptr(), mul);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment