Commit ef58667f authored by Nagy Mostafa's avatar Nagy Mostafa Committed by Scott Cyphers

[MLIR] Fixes for cpu_fusion.validate_fuse_gru_inputs (#3511)

* WIP

* Fix incorrect CK output adjustment

* Bug fix and enroce sanity check

* Change cycle search depth, and fix sanity check

* cpu_fusion.validate_fuse_gru_inputs passes.

* Fix as_single_output to be able to always create a GOE

* minor fix. style-apply

* Clean up debug msgs

* Switch to backward cycle check

* Enable failing test

* PR fixes

* Address feedback: Add fwd cycle checks. Make cycle checking depth configurable
parent a1f3202c
...@@ -45,9 +45,6 @@ using namespace ngraph::pass; ...@@ -45,9 +45,6 @@ using namespace ngraph::pass;
#define TI(x) std::type_index(typeid(x)) #define TI(x) std::type_index(typeid(x))
// Maximum depth to check for cycles. If exceeded, we conservatively assume a cycle.
#define MAX_CYCLE_DEPTH 100
int MLIRSubgraphExtractionPass::MLIRSubgraph::m_curr_graph_id = 0; int MLIRSubgraphExtractionPass::MLIRSubgraph::m_curr_graph_id = 0;
template <typename T> template <typename T>
...@@ -99,6 +96,15 @@ void MLIRSubgraphExtractionPass::MLIRSubgraph::merge(MLIRSubgraph& sg2) ...@@ -99,6 +96,15 @@ void MLIRSubgraphExtractionPass::MLIRSubgraph::merge(MLIRSubgraph& sg2)
m_pass.m_id_to_graph.erase(sg2.get_id()); m_pass.m_id_to_graph.erase(sg2.get_id());
} }
MLIRSubgraphExtractionPass::MLIRSubgraphExtractionPass()
: m_max_cycle_depth(20)
{
if (char* max_cycle_depth = std::getenv("NGRAPH_MLIR_MAX_CYCLE_DEPTH"))
{
m_max_cycle_depth = std::stoi(max_cycle_depth);
}
}
// The sub-graph construction algorithm is as follows // The sub-graph construction algorithm is as follows
// For each node, check its predecessors, if // For each node, check its predecessors, if
// - all predecessors in sub-graphs belong to the same sub-graph (graph ID), then extend the // - all predecessors in sub-graphs belong to the same sub-graph (graph ID), then extend the
...@@ -118,7 +124,19 @@ void MLIRSubgraphExtractionPass::MLIRSubgraph::merge(MLIRSubgraph& sg2) ...@@ -118,7 +124,19 @@ void MLIRSubgraphExtractionPass::MLIRSubgraph::merge(MLIRSubgraph& sg2)
// - CK will internally have lists record graph nodes, and graph output nodes. // - CK will internally have lists record graph nodes, and graph output nodes.
bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func) bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
{ {
NGRAPH_DEBUG << "[CK Extract] Construct sub-graphs" << std::endl; build_subgraphs(func);
auto ck_nodes = build_ck_nodes(func);
#ifdef NGRAPH_DEBUG_ENABLE
sanity_check(func, ck_nodes);
#endif
return true;
}
void MLIRSubgraphExtractionPass::build_subgraphs(std::shared_ptr<Function> func)
{
NGRAPH_DEBUG << "[CK Extract] Construct sub-graphs";
for (auto op : func->get_ordered_ops()) for (auto op : func->get_ordered_ops())
{ {
NodeVector inputs; NodeVector inputs;
...@@ -133,7 +151,7 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func) ...@@ -133,7 +151,7 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
continue; continue;
} }
NGRAPH_DEBUG << "[CK Extract] Processing " << *op << std::endl; NGRAPH_DEBUG << "[CK Extract] Processing " << *op;
// supported op // supported op
for (auto pred : op->get_arguments()) for (auto pred : op->get_arguments())
{ {
...@@ -151,30 +169,32 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func) ...@@ -151,30 +169,32 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
} }
if (subgraph_ids.size() == 0) if (subgraph_ids.size() == 0)
{ {
NGRAPH_DEBUG << "[CK Extract] Start new sub-graph " << std::endl;
// we couldn't find any predecessor sub-graphs to extend with this node // we couldn't find any predecessor sub-graphs to extend with this node
// create a new sub-graph // create a new sub-graph
MLIRSubgraph sg = MLIRSubgraph::create(this); MLIRSubgraph sg = MLIRSubgraph::create(this);
sg.add_inputs(inputs); sg.add_inputs(inputs);
sg.add_node(op); sg.add_node(op);
add_subgraph(sg); add_subgraph(sg);
NGRAPH_DEBUG << " [CK Extract] Start new sub-graph " << sg.get_id();
} }
else else
{ {
// we have sub-graphs. // we have sub-graphs.
// check if adding this node to the sub-graph will create a cycle in the DAG // check if adding this node to the sub-graph will create a cycle in the DAG
NGRAPH_DEBUG << "[CK Extract] Extending sub-graph. Check for cycles " << std::endl; NGRAPH_DEBUG << " [CK Extract] Extending sub-graph. Check for cycles";
if (!check_cycles(op, subgraph_ids)) if (!check_cycles(op, subgraph_ids))
{ {
NGRAPH_DEBUG << "[CK Extract] Merging subgraphs"; NGRAPH_DEBUG << " [CK Extract] Merging subgraphs ";
// merge sub-graphs if needed // merge sub-graphs if needed
std::unordered_set<int>::iterator it = subgraph_ids.begin(); std::unordered_set<int>::iterator it = subgraph_ids.begin();
int sg_id = *it; int sg_id = *it;
MLIRSubgraph& first_subgraph = get_subgraph(sg_id); MLIRSubgraph& first_subgraph = get_subgraph(sg_id);
NGRAPH_CHECK(first_subgraph.get_id() == sg_id); NGRAPH_CHECK(first_subgraph.get_id() == sg_id);
NGRAPH_DEBUG << " Graph ID: " << sg_id;
while (++it != subgraph_ids.end()) while (++it != subgraph_ids.end())
{ {
sg_id = *it; sg_id = *it;
NGRAPH_DEBUG << " Graph ID: " << sg_id;
MLIRSubgraph& subgraph = get_subgraph(sg_id); MLIRSubgraph& subgraph = get_subgraph(sg_id);
NGRAPH_CHECK(subgraph.get_id() == sg_id); NGRAPH_CHECK(subgraph.get_id() == sg_id);
first_subgraph.merge(subgraph); first_subgraph.merge(subgraph);
...@@ -187,7 +207,8 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func) ...@@ -187,7 +207,8 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
{ {
// we have a cycle, start a new sub-graph // we have a cycle, start a new sub-graph
MLIRSubgraph sg = MLIRSubgraph::create(this); MLIRSubgraph sg = MLIRSubgraph::create(this);
NGRAPH_DEBUG << "[CK Extract] Cycle found. Start a new subgraph"; NGRAPH_DEBUG << " [CK Extract] Cycle found. Start a new subgraph "
<< sg.get_id();
// use all predecessors as graph inputs // use all predecessors as graph inputs
NodeVector inputs = op->get_arguments(); NodeVector inputs = op->get_arguments();
sg.add_inputs(inputs); sg.add_inputs(inputs);
...@@ -195,10 +216,10 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func) ...@@ -195,10 +216,10 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
add_subgraph(sg); add_subgraph(sg);
} }
} }
NGRAPH_DEBUG << "[CK Extract] Node Processed " << *op << std::endl; NGRAPH_DEBUG << "[CK Extract] Node Processed " << *op;
} }
NGRAPH_DEBUG << "[CK Extract] Get subgraphs output nodes" << std::endl; NGRAPH_DEBUG << "[CK Extract] Get subgraphs output nodes";
// get output nodes for each sub-graph. Do this before attaching CK nodes since we will // get output nodes for each sub-graph. Do this before attaching CK nodes since we will
// remove output edges from the sub-graphs. // remove output edges from the sub-graphs.
for (IDGraphMap::iterator it = m_id_to_graph.begin(); it != m_id_to_graph.end(); it++) for (IDGraphMap::iterator it = m_id_to_graph.begin(); it != m_id_to_graph.end(); it++)
...@@ -211,8 +232,12 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func) ...@@ -211,8 +232,12 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
false /* ignore output duplicates */)); false /* ignore output duplicates */));
sg.add_outputs(outputs); sg.add_outputs(outputs);
} }
}
NGRAPH_DEBUG << "[CK Extract] Construct CK nodes" << std::endl; ngraph::NodeVector MLIRSubgraphExtractionPass::build_ck_nodes(std::shared_ptr<Function> func)
{
NodeVector ck_nodes;
NGRAPH_DEBUG << "[CK Extract] Construct CK nodes";
// attach CK node to each sub-graph. // attach CK node to each sub-graph.
for (auto it : m_id_to_graph) for (auto it : m_id_to_graph)
{ {
...@@ -228,27 +253,40 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func) ...@@ -228,27 +253,40 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
NodeVector nodes_vector(nodes_list.begin(), nodes_list.end()); NodeVector nodes_vector(nodes_list.begin(), nodes_list.end());
auto ck = std::make_shared<CompiledKernel>(nodes_vector, outputs_vector, inputs_vector); auto ck = std::make_shared<CompiledKernel>(nodes_vector, outputs_vector, inputs_vector);
NGRAPH_DEBUG << "[CK Extract] Graph ID = " << sg.get_id() << std::endl; ck_nodes.push_back(ck);
NGRAPH_DEBUG << "[CK Extract] Graph Nodes: " << std::endl;
NGRAPH_DEBUG << "[CK Extract] Graph ID = " << sg.get_id();
NGRAPH_DEBUG << " [CK Extract] Graph Nodes: ";
for (auto node : nodes) for (auto node : nodes)
{ {
NGRAPH_DEBUG << "[CK Extract] " << *node << std::endl; NGRAPH_DEBUG << " [CK Extract] " << *node;
} }
NGRAPH_DEBUG << "[CK Extract] Input Nodes: " << std::endl; NGRAPH_DEBUG << " [CK Extract] Input Nodes: ";
for (auto node : inputs) for (auto node : inputs)
{ {
NGRAPH_DEBUG << "[CK Extract] " << *node << std::endl; NGRAPH_DEBUG << " [CK Extract] " << *node;
} }
NGRAPH_DEBUG << "[CK Extract] Output Nodes: " << std::endl; NGRAPH_DEBUG << " [CK Extract] Output Nodes: ";
for (auto node : outputs) for (auto node : outputs)
{ {
NGRAPH_DEBUG << "[CK Extract] " << *node << std::endl; NGRAPH_DEBUG << " [CK Extract] " << *node;
;
} }
NGRAPH_DEBUG << " [CK Extract] CK Node = " << *ck;
}
// Connect CompiledKernel to output nodes by replacing the output descriptors of the output
// Do this after all CK nodes are constructed since they add new edges in the graph (CK inputs)
for (auto& node : ck_nodes)
{
auto ck = std::static_pointer_cast<CompiledKernel>(node);
auto& outputs_vector = ck->get_kernel_outputs();
auto& node_list = ck->get_node_list();
std::unordered_set<std::shared_ptr<Node>> node_set(node_list.begin(), node_list.end());
// Connect CompiledKernel to output nodes by replacing the output descriptors of the output
// nodes.
for (size_t i = 0, end = outputs_vector.size(); i < end; ++i) for (size_t i = 0, end = outputs_vector.size(); i < end; ++i)
{ {
auto& output_descs = outputs_vector[i]->get_outputs(); auto& output_descs = outputs_vector[i]->get_outputs();
...@@ -260,15 +298,88 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func) ...@@ -260,15 +298,88 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
for (descriptor::Input* in_desc : input_descs) for (descriptor::Input* in_desc : input_descs)
{ {
if (nodes.find(in_desc->get_node()) == nodes.end()) if (node_set.find(in_desc->get_node()) == node_set.end())
{ {
in_desc->replace_output(ck, i); in_desc->replace_output(ck, i);
} }
} }
} }
} }
for (auto& node : ck_nodes)
{
auto ck = std::static_pointer_cast<CompiledKernel>(node);
if (ck->get_output_size() > 1)
{
for (auto& old_output : ck->outputs())
{
auto inputs = old_output.get_target_inputs();
auto goe_node = old_output.as_single_output_node(false);
auto new_output = goe_node->output(0);
for (auto& input : inputs)
{
input.replace_source_output(new_output);
}
}
}
}
return true; return ck_nodes;
}
// Do a sanity check on graph invariants
// - no cycles
// - inputs to sub-graph are inputs to CK
// - no outputs out of subgraph for output nodes
void MLIRSubgraphExtractionPass::sanity_check(std::shared_ptr<Function> func, NodeVector& ck_nodes)
{
NodeVector cycles;
bool is_bkwd_cycle;
if (check_for_cycles(func.get(), cycles, is_bkwd_cycle))
{
NGRAPH_CHECK(cycles.size() != 0, "Empty cycle ?");
if (is_bkwd_cycle)
{
NGRAPH_DEBUG << "Backward cycle:";
}
for (auto& node : cycles)
{
NGRAPH_DEBUG << node;
}
NGRAPH_UNREACHABLE("Function contains cycle after subgraph constructions");
}
for (auto& node : ck_nodes)
{
auto ck_node = std::static_pointer_cast<CompiledKernel>(node);
auto& node_list = ck_node->get_node_list();
std::unordered_set<std::shared_ptr<Node>> node_set(node_list.begin(), node_list.end());
// CK output nodes shouldn't have any users outside the sub-graph,
// they are all moved to the CK node instead
for (auto& ck_output : ck_node->get_kernel_outputs())
{
for (auto& user : ck_output->get_users())
{
NGRAPH_CHECK(node_set.find(user) != node_set.end(),
"CK output nodes users should be in the sub-graph");
}
}
// Any input to CK must also have at least one user in the sub-graph body
for (auto& arg : ck_node->get_arguments())
{
bool found = false;
for (auto& user : arg->get_users())
{
found = (node_set.find(user) != node_set.end());
if (found)
{
break;
}
}
NGRAPH_CHECK(found, "CK input is not input to sub-graph");
}
}
} }
#define TI(x) std::type_index(typeid(x)) #define TI(x) std::type_index(typeid(x))
...@@ -378,7 +489,7 @@ bool MLIRSubgraphExtractionPass::check_cycles(std::shared_ptr<Node> node, ...@@ -378,7 +489,7 @@ bool MLIRSubgraphExtractionPass::check_cycles(std::shared_ptr<Node> node,
unsigned depth) unsigned depth)
{ {
// Going too deep, bail out. // Going too deep, bail out.
if (depth >= MAX_CYCLE_DEPTH) if (depth >= m_max_cycle_depth)
return true; return true;
// root node is always inside merged sub-graphs. // root node is always inside merged sub-graphs.
...@@ -398,7 +509,21 @@ bool MLIRSubgraphExtractionPass::check_cycles(std::shared_ptr<Node> node, ...@@ -398,7 +509,21 @@ bool MLIRSubgraphExtractionPass::check_cycles(std::shared_ptr<Node> node,
inside_subgraphs = false; inside_subgraphs = false;
} }
} }
for (auto& input : node->get_arguments()) NodeVector node_inputs;
auto subgraph_id = get_subgraph_id(node);
// If the node is part of a sub-graph, capture all of sub-graph inputs, else only node's input
if (subgraph_id >= 0)
{
auto subgraph = get_subgraph(subgraph_id);
auto subgraph_inputs = subgraph.get_inputs();
node_inputs.insert(node_inputs.end(), subgraph_inputs.begin(), subgraph_inputs.end());
}
else
{
node_inputs = node->get_arguments();
}
for (auto& input : node_inputs)
{ {
if (check_cycles(input, subgraph_ids, inside_subgraphs, ++depth)) if (check_cycles(input, subgraph_ids, inside_subgraphs, ++depth))
return true; return true;
......
...@@ -83,7 +83,7 @@ namespace ngraph ...@@ -83,7 +83,7 @@ namespace ngraph
friend class MLIRSubgraph; friend class MLIRSubgraph;
public: public:
MLIRSubgraphExtractionPass() {} MLIRSubgraphExtractionPass();
bool run_on_function(std::shared_ptr<Function> func) override; bool run_on_function(std::shared_ptr<Function> func) override;
/// Checks if an ngraph node is supported by MLIR backend /// Checks if an ngraph node is supported by MLIR backend
bool is_supported_mlir_op(std::shared_ptr<Node> node); bool is_supported_mlir_op(std::shared_ptr<Node> node);
...@@ -123,7 +123,10 @@ namespace ngraph ...@@ -123,7 +123,10 @@ namespace ngraph
unsigned depth = 0); unsigned depth = 0);
private: private:
static const std::set<std::type_index> m_supported_ops; void build_subgraphs(std::shared_ptr<Function> func);
NodeVector build_ck_nodes(std::shared_ptr<Function> func);
void sanity_check(std::shared_ptr<Function> func, NodeVector& ck_nodes);
private: private:
using IDGraphMap = std::unordered_map<int, MLIRSubgraph>; using IDGraphMap = std::unordered_map<int, MLIRSubgraph>;
...@@ -132,6 +135,10 @@ namespace ngraph ...@@ -132,6 +135,10 @@ namespace ngraph
NodeGraphMap m_node_to_graph; NodeGraphMap m_node_to_graph;
// Mutex over sub-graph IDs // Mutex over sub-graph IDs
std::mutex m_subgraph_mutex; std::mutex m_subgraph_mutex;
static const std::set<std::type_index> m_supported_ops;
// Maximum depth to check for cycles during merging of sub-graphs.
// If exceeded, we conservatively assume a cycle.
int m_max_cycle_depth;
}; };
} }
} }
...@@ -653,3 +653,93 @@ std::vector<Output<Node>> ngraph::get_outputs_to(Node& src, Node& dst) ...@@ -653,3 +653,93 @@ std::vector<Output<Node>> ngraph::get_outputs_to(Node& src, Node& dst)
return result; return result;
} }
static bool check_for_cycles_bkwd(std::shared_ptr<ngraph::Node> node,
std::deque<std::shared_ptr<ngraph::Node>>& path,
std::unordered_set<std::shared_ptr<ngraph::Node>>& path_set,
ngraph::NodeVector& cycle_nodes)
{
path.push_back(node);
path_set.insert(node);
for (auto& input : node->inputs())
{
auto arg = input.get_source_output().get_node_shared_ptr();
if (path_set.find(arg) != path_set.end())
{
for (auto it : path)
{
cycle_nodes.push_back(it);
}
// last node
cycle_nodes.push_back(arg);
return true;
}
if (check_for_cycles_bkwd(arg, path, path_set, cycle_nodes))
{
return true;
}
}
path_set.erase(path.back());
path.pop_back();
return false;
}
static bool check_for_cycles_fwd(std::shared_ptr<ngraph::Node> node,
std::deque<std::shared_ptr<ngraph::Node>>& path,
std::unordered_set<std::shared_ptr<ngraph::Node>>& path_set,
ngraph::NodeVector& cycle_nodes)
{
path.push_back(node);
path_set.insert(node);
for (auto& arg : node->get_users())
{
if (path_set.find(arg) != path_set.end())
{
for (auto it : path)
{
cycle_nodes.push_back(it);
}
// last node
cycle_nodes.push_back(arg);
return true;
}
if (check_for_cycles_fwd(arg, path, path_set, cycle_nodes))
{
return true;
}
}
path_set.erase(path.back());
path.pop_back();
return false;
}
bool ngraph::check_for_cycles(const ngraph::Function* func,
ngraph::NodeVector& cycle_nodes,
bool& is_bkwd_cycle)
{
for (auto res : func->get_results())
{
std::deque<std::shared_ptr<Node>> path;
// mirror of path stack for faster cycle check
std::unordered_set<std::shared_ptr<Node>> path_set;
if (check_for_cycles_bkwd(res, path, path_set, cycle_nodes))
{
is_bkwd_cycle = true;
return true;
};
}
for (auto param : func->get_parameters())
{
std::deque<std::shared_ptr<Node>> path;
// mirror of path stack for faster cycle check
std::unordered_set<std::shared_ptr<Node>> path_set;
if (check_for_cycles_fwd(param, path, path_set, cycle_nodes))
{
is_bkwd_cycle = false;
return true;
};
}
// no cycles
return false;
}
...@@ -299,7 +299,7 @@ namespace ngraph ...@@ -299,7 +299,7 @@ namespace ngraph
for (size_t i = 0; i < arg_count; ++i) for (size_t i = 0; i < arg_count; ++i)
{ {
Node* dep = node->input(arg_count - i - 1).get_source_output().get_node(); Node* dep = node->input(arg_count - i - 1).get_source_output().get_node();
if (nodes_done.count(dep) == 0) if (nodes_done.count(dep) == 0 && nodes_to_emit.count(node) != 0)
{ {
can_add = false; can_add = false;
nodes_to_do.push(dep); nodes_to_do.push(dep);
...@@ -425,4 +425,11 @@ namespace ngraph ...@@ -425,4 +425,11 @@ namespace ngraph
/// \return A vector containing a handle for each output of src that is connected to an input /// \return A vector containing a handle for each output of src that is connected to an input
/// of `dst`. /// of `dst`.
std::vector<Output<Node>> get_outputs_to(Node& src, Node& dst); std::vector<Output<Node>> get_outputs_to(Node& src, Node& dst);
/// Checks the func for graph cycles starting from results going backwards, then from parameters
/// going forward.
/// It returns true if a cycle is found and the first cycle encountered.
bool check_for_cycles(const ngraph::Function* func,
ngraph::NodeVector& cycle_nodes,
bool& is_bkwd_cycle);
} }
...@@ -87,7 +87,7 @@ std::shared_ptr<Node> Node::copy_with_new_inputs(const OutputVector& inputs) con ...@@ -87,7 +87,7 @@ std::shared_ptr<Node> Node::copy_with_new_inputs(const OutputVector& inputs) con
return copy_with_new_inputs(inputs, get_control_dependencies()); return copy_with_new_inputs(inputs, get_control_dependencies());
} }
std::shared_ptr<Node> Node::get_output_as_single_output_node(size_t i) std::shared_ptr<Node> Node::get_output_as_single_output_node(size_t i, bool for_get_output_element)
{ {
for (auto in : output(i).get_target_inputs()) for (auto in : output(i).get_target_inputs())
{ {
...@@ -96,7 +96,7 @@ std::shared_ptr<Node> Node::get_output_as_single_output_node(size_t i) ...@@ -96,7 +96,7 @@ std::shared_ptr<Node> Node::get_output_as_single_output_node(size_t i)
return in.get_node()->shared_from_this(); return in.get_node()->shared_from_this();
} }
} }
return get_output_element(output(i), true); return get_output_element(output(i), for_get_output_element);
} }
std::shared_ptr<Node> std::shared_ptr<Node>
......
...@@ -281,7 +281,8 @@ namespace ngraph ...@@ -281,7 +281,8 @@ namespace ngraph
/// Returns the partial shape for output i /// Returns the partial shape for output i
const PartialShape& get_output_partial_shape(size_t i) const; const PartialShape& get_output_partial_shape(size_t i) const;
std::shared_ptr<Node> get_output_as_single_output_node(size_t i); std::shared_ptr<Node> get_output_as_single_output_node(size_t i,
bool for_get_output_element = true);
/// Checks that there is exactly one output and returns its shape /// Checks that there is exactly one output and returns its shape
// TODO: deprecate in favor of node->output(0).get_shape() with a suitable check in the // TODO: deprecate in favor of node->output(0).get_shape() with a suitable check in the
...@@ -565,9 +566,10 @@ namespace ngraph ...@@ -565,9 +566,10 @@ namespace ngraph
std::shared_ptr<NodeType> get_node_shared_ptr() const { return m_node; } std::shared_ptr<NodeType> get_node_shared_ptr() const { return m_node; }
/// \return A useable shared pointer to this output. If index 0, the node, /// \return A useable shared pointer to this output. If index 0, the node,
/// otherwise find or create a GOE. /// otherwise find or create a GOE.
std::shared_ptr<Node> as_single_output_node() const NGRAPH_DEPRECATED("Transitional.") std::shared_ptr<Node> as_single_output_node(bool for_get_output_element = true) const
NGRAPH_DEPRECATED("Transitional.")
{ {
return m_node->get_output_as_single_output_node(m_index); return m_node->get_output_as_single_output_node(m_index, for_get_output_element);
} }
/// \return The index of the output referred to by this output handle. /// \return The index of the output referred to by this output handle.
......
...@@ -3888,7 +3888,7 @@ TEST(cpu_fusion, rnn_fusion_2rnn_layer_3lstm_cell) ...@@ -3888,7 +3888,7 @@ TEST(cpu_fusion, rnn_fusion_2rnn_layer_3lstm_cell)
} }
} }
TEST(cpu_fusion, MLIR_DISABLE_TEST(validate_fuse_gru_inputs)) TEST(cpu_fusion, validate_fuse_gru_inputs)
{ {
const std::string file_name("mxnet/gru_debug.json"); const std::string file_name("mxnet/gru_debug.json");
auto cpu_func = make_function_from_file(file_name); auto cpu_func = make_function_from_file(file_name);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment