Commit c1680ce3 authored by fenglei's avatar fenglei

Merge branch 'tfl/send_recv_op' of github.com:NervanaSystems/ngraph into tfl/send_recv_op

parents c643cb5e 603cbdab
......@@ -24,8 +24,8 @@ NGRAPH_OP(ConvolutionBiasBackpropFiltersBias, ngraph::op)
NGRAPH_OP(DepthToSpace, ngraph::op)
NGRAPH_OP(Elu, ngraph::op)
NGRAPH_OP(FakeQuantize, ngraph::op)
NGRAPH_OP(GRN, ngraph::op)
NGRAPH_OP(Gemm, ngraph::op)
NGRAPH_OP(GRN, ngraph::op)
NGRAPH_OP(GroupConvolution, ngraph::op)
NGRAPH_OP(GroupConvolutionTranspose, ngraph::op)
NGRAPH_OP(HardSigmoid, ngraph::op)
......@@ -35,9 +35,9 @@ NGRAPH_OP(MVN, ngraph::op)
NGRAPH_OP(Normalize, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op)
NGRAPH_OP(ScaleShift, ngraph::op)
NGRAPH_OP(SpaceToDepth, ngraph::op)
NGRAPH_OP(ShuffleChannels, ngraph::op)
NGRAPH_OP(SpaceToDepth, ngraph::op)
NGRAPH_OP(Split, ngraph::op)
NGRAPH_OP(SquaredDifference, ngraph::op)
NGRAPH_OP(Squeeze, ngraph::op)
NGRAPH_OP(Split, ngraph::op)
NGRAPH_OP(Unsqueeze, ngraph::op)
......@@ -81,11 +81,12 @@ NGRAPH_OP(Cos, ngraph::op)
NGRAPH_OP(Cosh, ngraph::op)
NGRAPH_OP(Dequantize, ngraph::op)
NGRAPH_OP(Divide, ngraph::op)
NGRAPH_OP(DynBroadcast, ngraph::op)
NGRAPH_OP(Dot, ngraph::op)
NGRAPH_OP(DynBroadcast, ngraph::op)
NGRAPH_OP(DynPad, ngraph::op)
NGRAPH_OP(DynReshape, ngraph::op)
NGRAPH_OP(DynSlice, ngraph::op)
NGRAPH_OP(EmbeddingLookup, ngraph::op)
NGRAPH_OP(Equal, ngraph::op)
NGRAPH_OP(Erf, ngraph::op)
NGRAPH_OP(Exp, ngraph::op)
......@@ -119,13 +120,13 @@ NGRAPH_OP(Power, ngraph::op)
NGRAPH_OP(Product, ngraph::op)
NGRAPH_OP(Quantize, ngraph::op)
NGRAPH_OP(QuantizedAvgPool, ngraph::op)
NGRAPH_OP(QuantizedConvolution, ngraph::op)
NGRAPH_OP(QuantizedConvolutionBias, ngraph::op)
NGRAPH_OP(QuantizedConvolutionBiasAdd, ngraph::op)
NGRAPH_OP(QuantizedConvolutionBiasSignedAdd, ngraph::op)
NGRAPH_OP(QuantizedConvolutionRelu, ngraph::op)
NGRAPH_OP(QuantizedConvolution, ngraph::op)
NGRAPH_OP(QuantizedDotBias, ngraph::op)
NGRAPH_OP(QuantizedDot, ngraph::op)
NGRAPH_OP(QuantizedDotBias, ngraph::op)
NGRAPH_OP(QuantizedMaxPool, ngraph::op)
NGRAPH_OP(Recv, ngraph::op)
NGRAPH_OP(Range, ngraph::op)
......@@ -155,7 +156,6 @@ NGRAPH_OP(Subtract, ngraph::op)
NGRAPH_OP(Sum, ngraph::op)
NGRAPH_OP(Tan, ngraph::op)
NGRAPH_OP(Tanh, ngraph::op)
NGRAPH_OP(TopK, ngraph::op)
NGRAPH_OP(Tile, ngraph::op)
NGRAPH_OP(TopK, ngraph::op)
NGRAPH_OP(Transpose, ngraph::op)
NGRAPH_OP(EmbeddingLookup, ngraph::op)
......@@ -13,36 +13,51 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/util/fused_op.hpp"
using namespace std;
using namespace ngraph;
bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Node> node)
pass::FusedOpDecomposition::FusedOpDecomposition(op_query_t callback)
: m_has_direct_support{callback}
{
}
bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
{
bool modified = false;
if (auto fused_op = std::dynamic_pointer_cast<ngraph::op::util::FusedOp>(node))
if (auto fused_op = dynamic_pointer_cast<op::util::FusedOp>(node))
{
if (m_callback && m_callback(*node))
if (m_has_direct_support && m_has_direct_support(*node))
{
// Op supported by backend. Do not decompose
return modified;
}
auto subgraph_outputs = fused_op->decompose_op();
// Run recursively untill no more fused ops
auto subgraph = extract_subgraph(subgraph_outputs, fused_op->get_arguments());
for (auto subgraph_node : subgraph)
{
if (auto nested_fused_op = dynamic_pointer_cast<op::util::FusedOp>(subgraph_node))
{
if (!(m_has_direct_support && m_has_direct_support(*nested_fused_op)))
{
run_on_node(nested_fused_op);
}
}
}
size_t i = 0;
for (auto output_node : subgraph_outputs)
{
for (size_t j = 0; j < output_node->get_outputs().size(); j++, i++)
{
// TODO: Provenance
std::set<ngraph::descriptor::Input*> fop_users{
begin(fused_op->get_outputs().at(i).get_inputs()),
set<descriptor::Input*> fop_users{begin(fused_op->get_outputs().at(i).get_inputs()),
end(fused_op->get_outputs().at(i).get_inputs())};
for (auto fop_user : fop_users)
{
......@@ -52,7 +67,7 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod
if (goe->get_n() == i && !goe->get_output_inputs(0).empty())
{
// Replace GOE users
std::set<ngraph::descriptor::Input*> goe_users{
set<descriptor::Input*> goe_users{
begin(goe->get_outputs().at(0).get_inputs()),
end(goe->get_outputs().at(0).get_inputs())};
for (auto goe_user : goe_users)
......@@ -80,8 +95,3 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod
return modified;
}
pass::FusedOpDecomposition::FusedOpDecomposition(op_query_t callback)
: m_callback{callback}
{
}
......@@ -16,6 +16,9 @@
#pragma once
#include <memory>
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/pass/pass.hpp"
namespace ngraph
......@@ -25,13 +28,24 @@ namespace ngraph
class FusedOpDecomposition : public NodePass
{
public:
/// \brief Function signature type for callback used to check whether provided node
/// is supported by backend.
using op_query_t = std::function<bool(const Node& node)>;
///
/// \brief Constructor for the Fused operation decomposition pass.
///
/// \param[in] callback The function object used to determine whether current backend
/// provide direct support for passed node. Should have signature:
/// bool fn(const Node&)
///
FusedOpDecomposition(op_query_t callback = nullptr);
bool run_on_node(std::shared_ptr<ngraph::Node> node) override;
private:
op_query_t m_callback = nullptr;
/// \brief A function returning whether provided Node is supported by current backend.
/// The returned bool value is used to control whether decompose operator or not.
op_query_t m_has_direct_support = nullptr;
};
}
}
......@@ -1180,7 +1180,6 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
REGISTER_KNOBBED_PASS(RecurrentReshapeElimination, false, ngraph::pass);
REGISTER_KNOBBED_PASS_WITH_ARGS(
CoreFusion, true, ngraph::pass, ngraph::pass::FusionType::ALL_FUSIONS);
REGISTER_KNOBBED_PASS_WITH_ARGS(FusedOpDecomposition, true, ngraph::pass, is_supported);
REGISTER_KNOBBED_PASS(CPUFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUQuantFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUHorizontalFusion, true, runtime::cpu::pass);
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment