Unverified Commit b5542083 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Assorted cleanups migrated from GOE removal work (#4263)

parent 311b06d2
......@@ -195,6 +195,36 @@ void ngraph::replace_node(std::shared_ptr<Node> target,
target->clear_control_dependents();
}
void ngraph::replace_node(const std::shared_ptr<Node>& target,
const OutputVector& replacement_values)
{
if (target->is_output())
{
throw ngraph_error("Result nodes cannot be replaced.");
}
NGRAPH_CHECK(!target->get_users().empty(), "Attempted to replace unreachable node '", *target);
NGRAPH_CHECK(target->get_output_size() == replacement_values.size());
unordered_set<shared_ptr<Node>> replacement_nodes;
// For each of target's output O with replacement output O_rep:
// For each O's connected downstream input I:
// Change I's connected upstream output to O_rep
for (size_t i = 0; i < target->get_output_size(); i++)
{
auto& replacement_value = replacement_values.at(i);
auto replacement_node = replacement_value.get_node_shared_ptr();
if (replacement_nodes.find(replacement_node) == replacement_nodes.end())
{
replacement_node->add_node_control_dependents(target);
target->transfer_provenance_tags(replacement_node);
replacement_nodes.insert(replacement_node);
}
target->output(i).replace(replacement_values.at(i));
}
target->clear_control_dependents();
}
void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement)
{
auto default_output_order = vector<int64_t>(target->get_output_size());
......@@ -314,6 +344,70 @@ std::list<std::shared_ptr<ngraph::Node>>
return cloned_nodes;
}
std::list<std::shared_ptr<ngraph::Node>>
ngraph::clone_nodes(const std::list<std::shared_ptr<ngraph::Node>>& nodes,
RawNodeOutputMap& output_map)
{
// for each node in topological order
auto sorted_nodes = topological_sort(nodes, true);
std::list<shared_ptr<Node>> cloned_nodes;
for (auto node : sorted_nodes)
{
auto node_outputs = node->outputs();
for (auto value : node_outputs)
{
if (output_map.count(value) == 0)
{
// We need this node cloned
// get (already) cloned arguments and clone the node
OutputVector cloned_args;
for (auto value : node->input_values())
{
cloned_args.push_back(output_map.at(value));
}
NodeVector cloned_dependencies;
for (auto& dependency : node->get_control_dependencies())
{
for (auto dependency_value : dependency->outputs())
{
shared_ptr<Node> dependent =
output_map.at(dependency_value).get_node_shared_ptr();
if (find(cloned_dependencies.begin(),
cloned_dependencies.end(),
dependent) == cloned_dependencies.end())
{
cloned_dependencies.push_back(dependent);
}
}
}
auto cloned_node = node->copy_with_new_inputs(cloned_args, cloned_dependencies);
cloned_nodes.push_back(cloned_node);
if (node->get_friendly_name() != node->get_name())
{
// There is a friendly name for this node so copy it
cloned_node->set_friendly_name(node->get_friendly_name());
}
for (auto tag : node->get_provenance_tags())
{
cloned_node->add_provenance_tag(tag);
}
cloned_node->set_op_annotations(node->get_op_annotations());
for (auto cloned_value : cloned_node->outputs())
{
auto original_value = node_outputs.at(cloned_value.get_index());
if (output_map.count(original_value) == 0)
{
output_map[original_value] = cloned_value;
}
}
break;
}
}
}
return cloned_nodes;
}
std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function& func)
{
NodeMap nm;
......@@ -505,7 +599,7 @@ bool ngraph::is_zero(const Output<Node>& reduce_constant)
return result_bool;
}
bool ngraph::is_one(std::shared_ptr<Node> reduce_constant)
bool ngraph::is_one(const Output<Node>& reduce_constant)
{
auto result_bool = is_equal_to_const_value("1", reduce_constant);
return result_bool;
......
......@@ -217,6 +217,12 @@ namespace ngraph
void replace_node(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement,
const std::vector<int64_t>& output_order);
/// Replace target.outputs[i] with replacement_values[i] and transfer control dependents and
/// provenance from target to the node(s) in replacement_values.
NGRAPH_API
void replace_node(const std::shared_ptr<Node>& target, const OutputVector& replacement_values);
NGRAPH_API
void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
......@@ -391,6 +397,13 @@ namespace ngraph
std::list<std::shared_ptr<ngraph::Node>>
clone_nodes(const std::list<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map);
// input nodes are cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
// NodeMap output (by reference) fully maps input and cloned nodes
std::list<std::shared_ptr<ngraph::Node>>
clone_nodes(const std::list<std::shared_ptr<ngraph::Node>>& nodes,
RawNodeOutputMap& node_map);
// input function is cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
// NodeMap output (by reference) fully maps input and cloned function ops
......@@ -429,7 +442,7 @@ namespace ngraph
// or a node that belongs to args
NodeVector extract_subgraph(const NodeVector& results, const NodeVector& args);
bool is_one(std::shared_ptr<Node> reduce_constant);
bool is_one(const Output<Node>& reduce_constant);
bool compare_constants(const std::shared_ptr<Node>& n1, const std::shared_ptr<Node>& n2);
......
......@@ -484,6 +484,29 @@ void Node::merge_provenance_tags_from(const std::shared_ptr<const Node>& source)
}
}
void Node::transfer_provenance_tags(const shared_ptr<Node>& replacement)
{
auto common_args = ngraph::find_common_args(shared_from_this(), replacement);
std::set<string> removed_subgraph_tags;
auto set_replacement_prov = [&removed_subgraph_tags](std::shared_ptr<Node> node) {
for (auto tag : node->get_provenance_tags())
{
removed_subgraph_tags.insert(tag);
}
};
traverse_nodes({shared_from_this()}, set_replacement_prov, false, common_args);
replacement->add_provenance_tags(removed_subgraph_tags);
auto set_prov_new_nodes = [&removed_subgraph_tags](std::shared_ptr<Node> node) {
node->add_provenance_tags(removed_subgraph_tags);
};
traverse_nodes({replacement}, set_prov_new_nodes, false, common_args);
}
std::shared_ptr<Node> Node::get_argument(size_t index) const
{
NGRAPH_CHECK(
......@@ -557,6 +580,12 @@ void Node::add_node_control_dependents(std::shared_ptr<Node> source_node)
}
}
void Node::transfer_control_dependents(std::shared_ptr<Node> replacement)
{
replacement->add_node_control_dependents(shared_from_this());
clear_control_dependents();
}
void Node::remove_control_dependency(std::shared_ptr<Node> node)
{
{
......@@ -991,6 +1020,18 @@ namespace ngraph
<< "):" << input.get_element_type()
<< input.get_partial_shape();
}
void Output<Node>::replace(const Output<Node>& replacement)
{
for (auto& input : get_target_inputs())
{
// GOEs are used as handles in passes
if (!is_type<op::GetOutputElement>(input.get_node()))
{
input.replace_source_output(replacement);
}
}
}
}
Input<Node> Node::input(size_t input_index)
......
......@@ -313,6 +313,9 @@ namespace ngraph
/// This node becomes a dependent of every node dependent on source_node
void add_node_control_dependents(std::shared_ptr<Node> source_node);
/// This node's control dependencies are replaced by replacement
void transfer_control_dependents(std::shared_ptr<Node> replacement);
/// Returns the number of outputs from the node.
size_t get_output_size() const;
......@@ -459,6 +462,9 @@ namespace ngraph
// to be used when nodes are replaced
void merge_provenance_tags_from(const std::shared_ptr<const Node>& source);
/// Transfer provenance tags to replacement
void transfer_provenance_tags(const std::shared_ptr<Node>& replacement);
/// Get all the nodes that uses the current node
NodeVector get_users(bool check_is_used = false) const;
......@@ -797,6 +803,9 @@ namespace ngraph
// TODO(amprocte): Investigate whether this really ought to be public.
void remove_target_input(const Input<Node>& target_input) const;
/// \brief Replace all users of this value with replacement
void replace(const Output<Node>& replacement);
bool operator==(const Output& other) const
{
return m_node == other.m_node && m_index == other.m_index;
......@@ -972,6 +981,40 @@ namespace ngraph
&(target_input.get_node()->m_inputs.at(target_input.get_index())));
}
// Like an Output but with a Node* instead of a shared_ptr<Node>
struct RawNodeOutput
{
RawNodeOutput(const Output<Node>& value)
: node(value.get_node())
, index(value.get_index())
{
}
RawNodeOutput(const RawNodeOutput&) = default;
RawNodeOutput() = default;
Node* node;
size_t index{0};
operator Output<Node>() { return Output<Node>(node->shared_from_this(), index); }
bool operator==(const RawNodeOutput& other) const
{
return node == other.node && index == other.index;
}
bool operator!=(const RawNodeOutput& other) const { return !(*this == other); }
bool operator<(const RawNodeOutput& other) const
{
return node < other.node || (node == other.node && index < other.index);
}
bool operator>(const RawNodeOutput& other) const
{
return node > other.node || (node == other.node && index > other.index);
}
bool operator<=(const RawNodeOutput& other) const { return !(*this > other); }
bool operator>=(const RawNodeOutput& other) const { return !(*this < other); }
};
using RawNodeOutputMap = std::map<RawNodeOutput, Output<Node>>;
class NodeValidationFailure : public CheckFailure
{
public:
......
......@@ -540,7 +540,7 @@ namespace
const auto delta = node->input_value(1);
shared_ptr<Node> replacement_node;
if (node->get_inputs().size() == 3)
if (node->get_input_size() == 3)
{
const auto result_forward = node->input_value(2);
replacement_node = make_shared<op::v0::MaxPoolBackprop>(arg_forward,
......
......@@ -409,7 +409,7 @@ namespace
auto kernel = node->get_window_shape();
shared_ptr<Node> replacement_node;
if (node->get_inputs().size() == 3)
if (node->get_input_size() == 3)
{
replacement_node = make_shared<op::v1::MaxPoolBackprop>(node->input_value(0),
node->input_value(1),
......
......@@ -206,15 +206,12 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
FpropCache fprop_cache;
// Traverse bprop to find all of the nodes in the bprop graph
std::unordered_set<std::shared_ptr<Node>> in_bprop;
std::set<Output<Node>> in_bprop;
ngraph::traverse_nodes(bprop,
[&in_bprop](std::shared_ptr<Node> node) {
if (node->get_output_size() == 1)
for (auto value : node->outputs())
{
if (in_bprop.count(node) == 0)
{
in_bprop.insert(node);
}
in_bprop.insert(value);
}
},
false /* no control dependencies */);
......@@ -222,14 +219,22 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
// Traverse fprop to make a map that stores parameters with the same
// shape and element type as the nodes in fprop iff they are in bprop
// and aren't inputs to bprop
auto bprop_inputs = bprop->get_parameters();
vector<Output<Node>> bprop_inputs;
for (auto param : bprop->get_parameters())
{
bprop_inputs.push_back(param);
}
ngraph::traverse_nodes(
fprop, [&fprop_cache, &in_bprop, &bprop_inputs](std::shared_ptr<Node> node) {
if (in_bprop.count(node) != 0 &&
std::find(bprop_inputs.begin(), bprop_inputs.end(), node) == bprop_inputs.end())
for (auto value : node->outputs())
{
fprop_cache.node_param_map[node.get()] =
std::make_shared<op::Parameter>(node->get_element_type(), node->get_shape());
if (in_bprop.count(value) != 0 &&
std::find(bprop_inputs.begin(), bprop_inputs.end(), value) ==
bprop_inputs.end())
{
fprop_cache.node_param_map[value] = std::make_shared<op::Parameter>(
value.get_element_type(), value.get_shape());
}
}
});
......@@ -240,10 +245,10 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
ngraph::clone_nodes(bprop->get_ops(), fprop_cache.node_param_map);
// invert the fprop_cache cloned node map for easy back and for acces.
std::unordered_map<Node*, Node*> inverted_node_map;
std::map<Output<Node>, RawNodeOutput> inverted_node_map;
for (auto kv : fprop_cache.node_param_map)
{
inverted_node_map[kv.second.get()] = kv.first;
inverted_node_map[kv.second] = kv.first;
}
// get cloned bprop results
......@@ -251,7 +256,8 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
NodeVector result_nodes;
for (auto node : bprop->get_results())
{
auto result = as_type_ptr<op::Result>(fprop_cache.node_param_map.at(node.get()));
auto result = as_type_ptr<op::Result>(
fprop_cache.node_param_map.at(Output<Node>(node)).get_node_shared_ptr());
if (!result)
{
throw ngraph_error("Expected op::Result values for op::Result keys in node_param_map");
......@@ -266,15 +272,15 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
ParameterVector bprop_input_params;
for (auto param : bprop_inputs)
{
bprop_input_params.push_back(
as_type_ptr<op::Parameter>(fprop_cache.node_param_map.at(param.get())));
bprop_input_params.push_back(as_type_ptr<op::Parameter>(
fprop_cache.node_param_map.at(Output<Node>(param)).get_node_shared_ptr()));
}
// add the cached fprop nodes as inputs to bprop
for (auto x : fprop_cache.fprop_output_nodes)
{
bprop_input_params.push_back(
as_type_ptr<op::Parameter>(fprop_cache.node_param_map.at(x)));
as_type_ptr<op::Parameter>(fprop_cache.node_param_map.at(x).get_node_shared_ptr()));
}
return bprop_input_params;
};
......@@ -287,11 +293,11 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
result_nodes,
[&cloned_bprop_inputs, &fprop_cache, &inverted_node_map](std::shared_ptr<Node> node) {
auto pnode = as_type_ptr<op::Parameter>(node);
if (pnode != nullptr &&
if (pnode &&
std::find(cloned_bprop_inputs.begin(), cloned_bprop_inputs.end(), pnode) ==
cloned_bprop_inputs.end())
{
fprop_cache.fprop_output_nodes.push_back(inverted_node_map.at(node.get()));
fprop_cache.fprop_output_nodes.push_back(inverted_node_map.at(Output<Node>(node)));
}
},
false /* no control dependencies */);
......@@ -301,12 +307,11 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
for (auto fpirn : fprop_cache.fprop_output_nodes)
{
auto fpir = fpirn->shared_from_this();
if (as_type_ptr<op::Result>(fpir))
if (as_type_ptr<op::Result>(fpirn.node->shared_from_this()))
{
throw ngraph_error("Expected op::Result in fprop->get_results()");
throw ngraph_error("Unexpected op::Result in fprop->get_results()");
}
fprop_outputs.push_back(std::make_shared<op::Result>(fpir));
fprop_outputs.push_back(std::make_shared<op::Result>(fpirn));
}
fprop_cache.fprop = std::make_shared<Function>(fprop_outputs, fprop->get_parameters());
......
......@@ -223,8 +223,8 @@ namespace ngraph
{
std::shared_ptr<Function> fprop;
std::shared_ptr<Function> bprop;
std::vector<Node*> fprop_output_nodes;
NodeMap node_param_map;
std::vector<RawNodeOutput> fprop_output_nodes;
RawNodeOutputMap node_param_map;
};
//
......
......@@ -184,9 +184,10 @@ namespace ngraph
std::vector<std::shared_ptr<runtime::Tensor>> mod_df_input_args = df_input_args;
// add cached nodes to both modified f output and modified f' input arguments
for (auto node : fprop_cache.fprop_output_nodes)
for (auto weak_value : fprop_cache.fprop_output_nodes)
{
auto tv = backend->create_tensor(node->get_element_type(), node->get_shape());
Output<Node> value(weak_value);
auto tv = backend->create_tensor(value.get_element_type(), value.get_shape());
mod_f_output_args.push_back(tv);
mod_df_input_args.push_back(tv);
}
......
......@@ -101,7 +101,7 @@ size_t count_ops_of_type(std::shared_ptr<ngraph::Function> f)
size_t count = 0;
for (auto op : f->get_ops())
{
if (ngraph::as_type_ptr<T>(op))
if (ngraph::is_type<T>(op))
{
count++;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment