//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "mlir_subgraph_extraction.hpp"

#include <memory>
#include <set>
#include <typeindex>
#include <typeinfo>

#include "ngraph/assertion.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/get_output_element.hpp"

using namespace ngraph::descriptor;
using namespace ngraph::op;
using namespace ngraph::pass;

// Shorthand for `std::type_index(typeid(x))`. It is applied both to type names
// (e.g. `TI(Parameter)`) and to dereferenced node pointers (e.g. `TI(*op)`), so it has to be a
// macro rather than a function. NOTE(review): comparing `TI(*op)` against a concrete op type
// relies on ngraph::Node being polymorphic, so that typeid reports the most-derived type — confirm.
#define TI(x) std::type_index(typeid(x))

bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
{
    // Create a CompiledKernel for all the ops in the function, except Parameters and Results.
    NodeVector ck_ops;
    for (auto op : func->get_ordered_ops())
    {
37 38
        // All ops must be supported by MLIR compiler
        if (!is_supported_mlir_op(op))
nmostafa's avatar
nmostafa committed
39
        {
40
            return false;
nmostafa's avatar
nmostafa committed
41
        }
42

43 44 45 46 47 48 49 50 51 52 53 54 55
        if (TI(Parameter) != TI(*op) && TI(Result) != TI(*op))
        {
            ck_ops.push_back(op);
        }
    }

    NodeVector ck_args;
    for (auto& param : func->get_parameters())
    {
        ck_args.push_back(param);
    }

    NodeVector ck_outputs = std::move(get_subgraph_outputs(ck_ops, {} /*exclusions*/));
56
    if (ck_outputs.size() != 1)
nmostafa's avatar
nmostafa committed
57
    {
58
        return false;
nmostafa's avatar
nmostafa committed
59
    }
60 61 62 63 64 65 66 67 68 69 70 71

    auto ck = std::make_shared<CompiledKernel>(ck_ops, ck_outputs, ck_args);

    // Connect CompiledKernel to output nodes by replacing the output descriptors of the output
    // nodes.
    for (size_t i = 0, end = ck_outputs.size(); i < end; ++i)
    {
        auto& output_descs = ck_outputs[i]->get_outputs();
        NGRAPH_ASSERT(output_descs.size() == 1) << "Unexpected multiple output descriptors";
        auto& out_desc = output_descs[0];

        // 'replace_output' invalidates iterator of the original container. Use a copy instead.
72
        const std::set<descriptor::Input*> input_descs = out_desc.get_inputs();
73

74
        for (descriptor::Input* in_desc : input_descs)
75 76 77 78 79 80 81
        {
            in_desc->replace_output(ck, i);
        }
    }

    return true;
}
82 83 84 85 86 87

#define TI(x) std::type_index(typeid(x))

bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node)
{
    if (TI(Parameter) == TI(*node) || TI(Result) == TI(*node))
nmostafa's avatar
nmostafa committed
88
    {
89
        return true;
nmostafa's avatar
nmostafa committed
90
    }
91 92 93

    // supported by backend ?
    if (m_supported_ops.find(TI(*node)) == m_supported_ops.end())
nmostafa's avatar
nmostafa committed
94
    {
95
        return false;
nmostafa's avatar
nmostafa committed
96
    }
97 98 99 100 101 102 103

    // check on invariants expected by MLIR backend

    // Dot is 2D only
    if (TI(ngraph::op::Dot) == TI(*node))
    {
        if (node->get_input_shape(0).size() != 2 || node->get_input_shape(1).size() != 2)
nmostafa's avatar
nmostafa committed
104
        {
105
            return false;
nmostafa's avatar
nmostafa committed
106
        }
107 108 109 110 111 112 113 114
    }
    return true;
}

// Node types the MLIR backend can compile, built from the X-macro list in
// contrib/mlir/ops_supported.inc. Each `MLIR_OP(OP)` entry in that file expands to
// `TI(ngraph::op::OP),`, one element of this set.
const std::set<std::type_index> MLIRSubgraphExtractionPass::m_supported_ops{
#define MLIR_OP(OP) TI(ngraph::op::OP),
#include "contrib/mlir/ops_supported.inc"
// Keep the helper macro from leaking into the rest of the translation unit. This is harmless
// if the .inc file already undefines it.
#undef MLIR_OP
};