Unverified Commit a33ad828 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Replace NodeMap class with an unordered map. (#2830)

parent 86fcc656
......@@ -34,7 +34,7 @@
using namespace ngraph;
std::shared_ptr<Node> make_zero(const std::shared_ptr<Node>& node)
std::shared_ptr<Node> make_broadcast_zero(const std::shared_ptr<Node>& node)
{
std::shared_ptr<Node> zero = std::make_shared<op::ScalarConstantLike>(node, 0.0);
std::shared_ptr<Node> bzero = std::make_shared<op::BroadcastLike>(zero, node, AxisSet{});
......@@ -49,12 +49,12 @@ NodeVector make_zeros(std::shared_ptr<Node> x)
auto goes = op::get_output_elements(x);
for (size_t i = 0; i < goes.size(); ++i)
{
zeros.push_back(make_zero(goes.at(i)));
zeros.push_back(make_broadcast_zero(goes.at(i)));
}
}
else
{
zeros.push_back(make_zero(x));
zeros.push_back(make_broadcast_zero(x));
}
return zeros;
}
......@@ -187,7 +187,7 @@ void autodiff::Adjoints::add_delta_to_slice(const std::shared_ptr<Node>& x,
auto adjoint_it = m_adjoint_map.find(x.get());
if (m_adjoint_map.end() == adjoint_it)
{
auto zero = make_zero(x);
auto zero = make_broadcast_zero(x);
NodeVector zeros{
std::make_shared<op::ReplaceSlice>(zero, delta, lower_bounds, upper_bounds, strides)};
m_adjoint_map.insert({x.get(), zeros});
......
......@@ -211,34 +211,6 @@ bool ngraph::is_post_dominated(Node* X, Node* Y)
return true;
}
void ngraph::NodeMap::update(std::shared_ptr<ngraph::Node> orig, std::shared_ptr<ngraph::Node> val)
{
if (!exists(orig))
{
throw ngraph_error("Node doesn't exist!");
}
m_node_map[orig] = val;
}
void ngraph::NodeMap::add(std::shared_ptr<ngraph::Node> orig,
std::shared_ptr<ngraph::Node> replacement)
{
if (exists(orig))
{
throw ngraph_error("NodeMap: key already exists");
}
m_node_map[orig] = replacement;
}
std::shared_ptr<ngraph::Node> ngraph::NodeMap::get(std::shared_ptr<ngraph::Node> orig) const
{
if (!exists(orig))
{
throw ngraph_error("NodeMap: key does not exist");
}
return m_node_map.at(orig);
}
std::list<std::shared_ptr<ngraph::Node>>
ngraph::clone_nodes(const std::list<std::shared_ptr<ngraph::Node>>& nodes, NodeMap& node_map)
{
......@@ -246,22 +218,22 @@ std::list<std::shared_ptr<ngraph::Node>>
auto sorted_nodes = topological_sort(nodes, true);
for (auto node : sorted_nodes)
{
if (!node_map.exists(node))
if (node_map.count(node.get()) == 0)
{
// get (already) cloned arguments and clone the node
NodeVector cloned_args;
for (auto arg : node->get_arguments())
{
cloned_args.push_back(node_map.get(arg));
cloned_args.push_back(node_map.at(arg.get()));
}
auto cloned_node = node->copy_with_new_args(cloned_args);
//copy control dependencies
for (auto cdep : node->get_control_dependencies())
{
cloned_node->add_control_dependency(node_map.get(cdep));
cloned_node->add_control_dependency(node_map.at(cdep.get()));
}
node_map.add(node, cloned_node);
node_map[node.get()] = cloned_node;
}
}
......@@ -270,7 +242,7 @@ std::list<std::shared_ptr<ngraph::Node>>
std::list<std::shared_ptr<ngraph::Node>> cloned_nodes;
for (auto node : nodes)
{
cloned_nodes.push_back(node_map.get(node));
cloned_nodes.push_back(node_map.at(node.get()));
}
return cloned_nodes;
}
......@@ -291,7 +263,7 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function&
ResultVector cloned_results;
for (shared_ptr<Node> node : func.get_results())
{
auto result = std::dynamic_pointer_cast<op::Result>(node_map.get(node));
auto result = std::dynamic_pointer_cast<op::Result>(node_map.at(node.get()));
if (!result)
{
throw ngraph_error("Results should be of type op::Result");
......@@ -301,7 +273,7 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function&
std::vector<std::shared_ptr<op::Parameter>> cloned_params;
for (auto param : func.get_parameters())
{
cloned_params.push_back(std::dynamic_pointer_cast<op::Parameter>(node_map.get(param)));
cloned_params.push_back(std::dynamic_pointer_cast<op::Parameter>(node_map.at(param.get())));
}
// create and return cloned function
......
......@@ -233,48 +233,6 @@ namespace ngraph
bool is_equal_to_const_value(std::string const_value, std::shared_ptr<Node> reduce_constant);
// maps original to replacement nodes e.g. for clone utilities
// performs index checking on access
class NodeMap
{
public:
// map original node to replacement node
// throws ngraph_error if key already exists
void add(std::shared_ptr<ngraph::Node> orig, std::shared_ptr<ngraph::Node> replacement);
// get replacement node from original node
// throws ngrah_error if key does not exist
std::shared_ptr<ngraph::Node> get(std::shared_ptr<ngraph::Node> orig) const;
template <typename T>
T dynamic_get(const T& orig)
{
return std::dynamic_pointer_cast<typename T::element_type>(get(orig));
}
// returns true if original node is already mapped
bool exists(std::shared_ptr<ngraph::Node> orig) const
{
return (m_node_map.count(orig) != 0);
}
void update(std::shared_ptr<ngraph::Node> orig, std::shared_ptr<ngraph::Node> val);
const std::unordered_map<std::shared_ptr<ngraph::Node>, std::shared_ptr<ngraph::Node>>&
get_node_map() const
{
return m_node_map;
}
std::unordered_map<std::shared_ptr<ngraph::Node>, std::shared_ptr<ngraph::Node>>&
get_node_map()
{
return m_node_map;
}
private:
std::unordered_map<std::shared_ptr<ngraph::Node>, std::shared_ptr<ngraph::Node>> m_node_map;
};
// input nodes are cloned and returned
// NodeMap input may contain default node mapping i.e. pre-cloned nodes
// NodeMap output (by reference) fully maps input and cloned nodes
......
......@@ -60,6 +60,9 @@ namespace ngraph
size_t i);
const NodeVector& check_single_output_args(const NodeVector& args);
/// Alias useful for cloning
using NodeMap = std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>>;
/// Nodes are the backbone of the graph of Value dataflow. Every node has
/// zero or more nodes as arguments and one value, which is either a tensor
/// or a (possibly empty) tuple of values.
......@@ -140,8 +143,7 @@ namespace ngraph
/// graph against the graph.
bool is_same_op_type(const std::shared_ptr<Node>& node) const
{
Node* n = node.get();
return std::type_index(typeid(*this)) == std::type_index(typeid(*n));
return description() == node->description();
}
/// \brief Marks an input as being relevant or irrelevant to the output shapes of this
......
......@@ -15,7 +15,6 @@
//*****************************************************************************
#include "ngraph/runtime/cpu/op/loop_kernel.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/util.hpp"
......@@ -35,7 +34,7 @@ shared_ptr<Node>
NodeMap nm;
for (size_t i = 0; i < args.size(); i++)
{
nm.add(args.at(i), new_args.at(i));
nm[args.at(i).get()] = new_args.at(i);
}
NodeVector new_node_list;
......@@ -44,17 +43,17 @@ shared_ptr<Node>
NodeVector cur_args;
for (auto a : n->get_arguments())
{
cur_args.push_back(nm.get(a));
cur_args.push_back(nm.at(a.get()));
}
auto new_n = n->copy_with_new_args(cur_args);
nm.add(n, new_n);
nm[n.get()] = new_n;
new_node_list.push_back(new_n);
}
NodeVector new_outputs;
for (auto o : m_output_nodes)
{
new_outputs.push_back(nm.get(o));
new_outputs.push_back(nm.at(o.get()));
}
return std::make_shared<LoopKernel>(new_node_list, new_outputs, new_args);
......
......@@ -18,8 +18,6 @@
using namespace ngraph;
using ReplacementMap = std::map<Node*, std::shared_ptr<Node>>;
std::shared_ptr<Function>
ngraph::specialize_shapes(std::shared_ptr<Function> f,
const std::vector<element::Type>& parameter_element_types,
......@@ -28,7 +26,7 @@ std::shared_ptr<Function>
NGRAPH_CHECK(f->get_parameters().size() == parameter_shapes.size());
NGRAPH_CHECK(f->get_parameters().size() == parameter_element_types.size());
ReplacementMap m;
NodeMap m;
for (size_t i = 0; i < parameter_shapes.size(); i++)
{
......
......@@ -202,7 +202,6 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
// Create a fprop_cache object to store the results of this analysis
FpropCache fprop_cache;
fprop_cache.node_param_map = std::make_shared<NodeMap>();
// Traverse bprop to find all of the nodes in the bprop graph
std::unordered_set<std::shared_ptr<Node>> in_bprop;
......@@ -227,9 +226,8 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
if (in_bprop.count(node) != 0 &&
std::find(bprop_inputs.begin(), bprop_inputs.end(), node) == bprop_inputs.end())
{
fprop_cache.node_param_map->add(
node,
std::make_shared<op::Parameter>(node->get_element_type(), node->get_shape()));
fprop_cache.node_param_map[node.get()] =
std::make_shared<op::Parameter>(node->get_element_type(), node->get_shape());
}
});
......@@ -237,13 +235,13 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
// intermediate parameters from fprop_cache. This breaks connections in the
// bprop graph such that only intermediate values from fprop needed by bprop
// are still connected to the bprop graph as parameters
ngraph::clone_nodes(bprop->get_ops(), *(fprop_cache.node_param_map));
ngraph::clone_nodes(bprop->get_ops(), fprop_cache.node_param_map);
// invert the fprop_cache cloned node map for easy back and for acces.
std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<Node>> inverted_node_map;
for (auto kv : fprop_cache.node_param_map->get_node_map())
std::unordered_map<Node*, Node*> inverted_node_map;
for (auto kv : fprop_cache.node_param_map)
{
inverted_node_map[kv.second] = kv.first;
inverted_node_map[kv.second.get()] = kv.first;
}
// get cloned bprop results
......@@ -251,7 +249,8 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
NodeVector result_nodes;
for (auto node : bprop->get_results())
{
auto result = std::dynamic_pointer_cast<op::Result>(fprop_cache.node_param_map->get(node));
auto result =
std::dynamic_pointer_cast<op::Result>(fprop_cache.node_param_map.at(node.get()));
if (!result)
{
throw ngraph_error("Expected op::Result values for op::Result keys in node_param_map");
......@@ -266,15 +265,15 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
ParameterVector bprop_input_params;
for (auto param : bprop_inputs)
{
bprop_input_params.push_back(
std::dynamic_pointer_cast<op::Parameter>(fprop_cache.node_param_map->get(param)));
bprop_input_params.push_back(std::dynamic_pointer_cast<op::Parameter>(
fprop_cache.node_param_map.at(param.get())));
}
// add the cached fprop nodes as inputs to bprop
for (auto x : fprop_cache.fprop_output_nodes)
{
bprop_input_params.push_back(
std::dynamic_pointer_cast<op::Parameter>(fprop_cache.node_param_map->get(x)));
std::dynamic_pointer_cast<op::Parameter>(fprop_cache.node_param_map.at(x)));
}
return bprop_input_params;
};
......@@ -291,7 +290,7 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
std::find(cloned_bprop_inputs.begin(), cloned_bprop_inputs.end(), pnode) ==
cloned_bprop_inputs.end())
{
fprop_cache.fprop_output_nodes.push_back(inverted_node_map.at(node));
fprop_cache.fprop_output_nodes.push_back(inverted_node_map.at(node.get()));
}
},
false /* no control dependencies */);
......@@ -299,8 +298,9 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
// create the new outputs for fprop and the new fprop function
ResultVector fprop_outputs = fprop->get_results();
for (auto fpir : fprop_cache.fprop_output_nodes)
for (auto fpirn : fprop_cache.fprop_output_nodes)
{
auto fpir = fpirn->shared_from_this();
if (std::dynamic_pointer_cast<op::Result>(fpir))
{
throw ngraph_error("Expected op::Result in fprop->get_results()");
......
......@@ -31,6 +31,7 @@
#include <vector>
#include "ngraph/axis_vector.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node_vector.hpp"
#include "ngraph/shape.hpp"
......@@ -38,7 +39,6 @@ namespace ngraph
{
class Node;
class Function;
class NodeMap;
class stopwatch;
namespace runtime
......@@ -214,8 +214,8 @@ namespace ngraph
{
std::shared_ptr<Function> fprop;
std::shared_ptr<Function> bprop;
std::vector<std::shared_ptr<Node>> fprop_output_nodes;
std::shared_ptr<NodeMap> node_param_map;
std::vector<Node*> fprop_output_nodes;
NodeMap node_param_map;
};
/**
......
......@@ -216,7 +216,7 @@ public:
auto cloneit = clone.begin();
while (origit != orig.end() && cloneit != clone.end())
{
if (*cloneit != nm.get_node_map().at(*origit))
if (*cloneit != nm.at((*origit).get()))
{
return false;
}
......@@ -232,11 +232,11 @@ TEST_F(CloneTest, clone_nodes_full)
auto cloned_nodes = clone_nodes(nodes, node_map);
ASSERT_TRUE(CompareNodeVector(nodes, cloned_nodes, node_map));
ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Parameter>(node_map.get(A)));
ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Parameter>(node_map.get(B)));
ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Parameter>(node_map.get(C)));
ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Add>(node_map.get(AplusB)));
ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Multiply>(node_map.get(AplusBtimesC)));
ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Parameter>(node_map.at(A.get())));
ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Parameter>(node_map.at(B.get())));
ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Parameter>(node_map.at(C.get())));
ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Add>(node_map.at(AplusB.get())));
ASSERT_NE(nullptr, std::dynamic_pointer_cast<op::Multiply>(node_map.at(AplusBtimesC.get())));
auto sorted_nodes = topological_sort(nodes);
auto sorted_cloned_nodes = topological_sort(cloned_nodes);
......@@ -247,13 +247,13 @@ TEST_F(CloneTest, clone_nodes_partial)
{
// map A -> A' prior to clone
auto Aprime = make_shared<op::Parameter>(element::f32, shape);
node_map.add(A, Aprime);
node_map[A.get()] = Aprime;
auto cloned_nodes = clone_nodes(nodes, node_map);
ASSERT_TRUE(CompareNodeVector(nodes, cloned_nodes, node_map));
// ensure A -> A' after clone
ASSERT_EQ(Aprime, node_map.get(A));
ASSERT_EQ(Aprime, node_map.at(A.get()));
}
TEST_F(CloneTest, clone_function_full)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment