Commit fccaa49c authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Scott Cyphers

Recursive FusedOpDecomposition pass. (#3073)

* Run fused op decomposition recursively until no more fused ops.

* Update callback member name.
parent fe6454ff
......@@ -13,37 +13,52 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/util/fused_op.hpp"
using namespace std;
using namespace ngraph;
bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Node> node)
pass::FusedOpDecomposition::FusedOpDecomposition(op_query_t callback)
: m_has_direct_support{callback}
{
}
bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
{
bool modified = false;
if (auto fused_op = std::dynamic_pointer_cast<ngraph::op::util::FusedOp>(node))
if (auto fused_op = dynamic_pointer_cast<op::util::FusedOp>(node))
{
if (m_callback && m_callback(*node))
if (m_has_direct_support && m_has_direct_support(*node))
{
// Op supported by backend. Do not decompose
return modified;
}
auto subgraph_outputs = fused_op->decompose_op();
// Run recursively untill no more fused ops
auto subgraph = extract_subgraph(subgraph_outputs, fused_op->get_arguments());
for (auto subgraph_node : subgraph)
{
if (auto nested_fused_op = dynamic_pointer_cast<op::util::FusedOp>(subgraph_node))
{
if (!(m_has_direct_support && m_has_direct_support(*nested_fused_op)))
{
run_on_node(nested_fused_op);
}
}
}
size_t i = 0;
for (auto output_node : subgraph_outputs)
{
for (size_t j = 0; j < output_node->get_outputs().size(); j++, i++)
{
// TODO: Provenance
std::set<ngraph::descriptor::Input*> fop_users{
begin(fused_op->get_outputs().at(i).get_inputs()),
end(fused_op->get_outputs().at(i).get_inputs())};
set<descriptor::Input*> fop_users{begin(fused_op->get_outputs().at(i).get_inputs()),
end(fused_op->get_outputs().at(i).get_inputs())};
for (auto fop_user : fop_users)
{
if (auto goe =
......@@ -52,7 +67,7 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod
if (goe->get_n() == i && !goe->get_output_inputs(0).empty())
{
// Replace GOE users
std::set<ngraph::descriptor::Input*> goe_users{
set<descriptor::Input*> goe_users{
begin(goe->get_outputs().at(0).get_inputs()),
end(goe->get_outputs().at(0).get_inputs())};
for (auto goe_user : goe_users)
......@@ -80,8 +95,3 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod
return modified;
}
pass::FusedOpDecomposition::FusedOpDecomposition(op_query_t callback)
: m_callback{callback}
{
}
......@@ -16,6 +16,9 @@
#pragma once
#include <memory>
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/pass/pass.hpp"
namespace ngraph
......@@ -25,13 +28,24 @@ namespace ngraph
class FusedOpDecomposition : public NodePass
{
public:
/// \brief Function signature type for callback used to check whether provided node
/// is supported by backend.
using op_query_t = std::function<bool(const Node& node)>;
///
/// \brief Constructor for the Fused operation decomposition pass.
///
/// \param[in] callback The function object used to determine whether current backend
/// provide direct support for passed node. Should have signature:
/// bool fn(const Node&)
///
FusedOpDecomposition(op_query_t callback = nullptr);
bool run_on_node(std::shared_ptr<ngraph::Node> node) override;
private:
op_query_t m_callback = nullptr;
/// \brief A function returning whether provided Node is supported by current backend.
/// The returned bool value is used to control whether decompose operator or not.
op_query_t m_has_direct_support = nullptr;
};
}
}
......@@ -1180,7 +1180,6 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
REGISTER_KNOBBED_PASS(RecurrentReshapeElimination, false, ngraph::pass);
REGISTER_KNOBBED_PASS_WITH_ARGS(
CoreFusion, true, ngraph::pass, ngraph::pass::FusionType::ALL_FUSIONS);
REGISTER_KNOBBED_PASS_WITH_ARGS(FusedOpDecomposition, true, ngraph::pass, is_supported);
REGISTER_KNOBBED_PASS(CPUFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUQuantFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUHorizontalFusion, true, runtime::cpu::pass);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment