Unverified Commit cf21a361 authored by Scott Cyphers, committed by GitHub

Fix compilation warnings (#4236)

* Fix compilation warnings

* Camels
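
Taken together, the hunks below address a handful of recurring warning sources: an unused this capture in a lambda, a local (broadcastHint) that could be read uninitialized, a pessimizing std::move applied to a value that is already a temporary, and typeid/std::type_index comparisons plus static data members (m_supported_ops, opDispatcher) that are replaced with Node::type_info_t lookups behind function-local static accessors (getSupportedOps(), getOpDispatcher()). The "Camels" commit appears to cover the camelCase names visible in the new identifiers. Short illustrative sketches of these patterns follow the relevant hunks below.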
parent ca955d46
@@ -346,7 +346,7 @@ namespace
unsigned inputCount = f.getType().getNumInputs();
// we find out output values by looking at returned values
// any return should return all outputs of the subgraph
f.walk([this, &outputCount, inputCount](NGReturnOp ret) {
f.walk([&outputCount, inputCount](NGReturnOp ret) {
for (unsigned i = 0; i < ret.getNumOperands(); i++)
{
// annotate instructions defining outputs with the arg idx of the output
@@ -1324,7 +1324,7 @@ namespace
}
attrs.gemmAttrs2d.ldc = attrs.gemmAttrs2d.n;
int broadcastHint;
int broadcastHint = -2;
if (vBias.rank() == 0)
{
// Scalar
@@ -1356,6 +1356,7 @@ namespace
broadcastHint = 0;
}
}
NGRAPH_CHECK(broadcastHint != -2, "Unhandled broadcast");
attrs.gemmAttrs2d.broadcastHint = broadcastHint;
auto int64Ty = rewriter.getIntegerType(64);
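
The two hunks above initialize broadcastHint to a sentinel (-2) and verify that some branch overwrote it, which both silences the uninitialized-variable warning and turns a missed case into a hard error instead of undefined behavior. A minimal stand-alone sketch of the same pattern, with made-up values and assert standing in for NGRAPH_CHECK:

// Sentinel-initialize-then-check: a minimal sketch, not code from this commit.
#include <cassert>

int broadcastHintFor(int biasRank, int columns)
{
    int hint = -2; // sentinel meaning "not decided yet"
    if (biasRank == 0)
    {
        hint = 2; // e.g. scalar bias
    }
    else if (biasRank == 1)
    {
        hint = (columns == 1) ? 1 : 0; // e.g. column vs. row broadcast
    }
    // Every handled case must have replaced the sentinel by now.
    assert(hint != -2 && "Unhandled broadcast");
    return hint;
}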
......
@@ -27,8 +27,6 @@ using namespace ngraph::descriptor;
using namespace ngraph::op;
using namespace ngraph::pass;
#define TI(x) std::type_index(typeid(x))
int MLIRSubgraphExtractionPass::MLIRSubgraph::m_curr_graph_id = 0;
void MLIRSubgraphExtractionPass::MLIRSubgraph::add_inputs(NodeVector& inputs)
@@ -162,11 +160,11 @@ void MLIRSubgraphExtractionPass::build_subgraphs(std::shared_ptr<Function> func)
for (auto it = nodes_ready.begin(); it != nodes_ready.end();)
{
auto node = *it;
if (TI(Result) == TI(*node))
if (is_type<Result>(node))
{
erase_node(it, nodes_ready);
}
else if (TI(Parameter) == TI(*node))
else if (is_type<Parameter>(node))
{
process_successors(node, node_to_size_map, nodes_ready);
erase_node(it, nodes_ready);
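
Throughout this pass the TI(X) == TI(*node) comparisons (RTTI via std::type_index) are replaced with nGraph's is_type<T>(node), and, where the derived pointer is also needed, as_type_ptr<T>(node), so the separate static_cast disappears as well. A rough stand-alone sketch of what such helpers look like; nGraph's real ones differ in detail:

// Rough sketch of type_info-based type tests; illustrative only.
#include <memory>
#include <string>

struct TypeInfo
{
    const char* name;
};

struct Node
{
    virtual ~Node() = default;
    virtual const TypeInfo& get_type_info() const = 0;
};

struct Parameter : Node
{
    static const TypeInfo type_info;
    const TypeInfo& get_type_info() const override { return type_info; }
};
const TypeInfo Parameter::type_info{"Parameter"};

template <typename T>
bool is_type(const std::shared_ptr<Node>& node)
{
    // Compare the node's static type descriptor instead of using typeid().
    return std::string(node->get_type_info().name) == T::type_info.name;
}

template <typename T>
std::shared_ptr<T> as_type_ptr(const std::shared_ptr<Node>& node)
{
    return is_type<T>(node) ? std::static_pointer_cast<T>(node) : nullptr;
}

With as_type_ptr, hunks such as the Divide and Convolution ones below fold the type check and the downcast into a single condition: if (auto div = as_type_ptr<ngraph::op::Divide>(node)) { ... }.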
@@ -209,7 +207,7 @@ void MLIRSubgraphExtractionPass::build_subgraphs(std::shared_ptr<Function> func)
for (auto it = nodes_ready.begin(); it != nodes_ready.end();)
{
auto node = *it;
if (TI(Result) == TI(*node))
if (is_type<Result>(node))
{
erase_node(it, nodes_ready);
}
@@ -236,7 +234,7 @@ void MLIRSubgraphExtractionPass::build_subgraphs(std::shared_ptr<Function> func)
for (auto it = nodes_ready.begin(); it != nodes_ready.end();)
{
auto node = *it;
if (TI(Result) == TI(*node))
if (is_type<Result>(node))
{
erase_node(it, nodes_ready);
}
@@ -264,10 +262,10 @@ void MLIRSubgraphExtractionPass::build_subgraphs(std::shared_ptr<Function> func)
{
MLIRSubgraph& sg = it->second;
auto& nodes = sg.get_nodes();
NodeVector outputs = std::move(get_subgraph_outputs(NodeVector(nodes.begin(), nodes.end()),
{} /*exclusions*/,
false /* ignore unused */,
false /* ignore output duplicates */));
NodeVector outputs = get_subgraph_outputs(NodeVector(nodes.begin(), nodes.end()),
{} /*exclusions*/,
false /* ignore unused */,
false /* ignore output duplicates */);
sg.add_outputs(outputs);
}
}
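
The change above drops a std::move wrapped around the get_subgraph_outputs(...) call. The call already yields a temporary, so the move buys nothing and can prevent copy elision, which is what clang's -Wpessimizing-move (and newer GCC) flags. A tiny, self-contained illustration with hypothetical names:

// Pessimizing move on an already-temporary value: minimal sketch.
#include <string>
#include <utility>
#include <vector>

std::vector<std::string> makeNames()
{
    return {"a", "b"};
}

int main()
{
    // Would warn: "moving a temporary object prevents copy elision".
    // std::vector<std::string> names = std::move(makeNames());

    // Fine: initialize directly from the prvalue, as the hunk above now does.
    std::vector<std::string> names = makeNames();
    return static_cast<int>(names.size());
}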
@@ -420,26 +418,24 @@ void MLIRSubgraphExtractionPass::sanity_check(std::shared_ptr<Function> func, No
}
}
#define TI(x) std::type_index(typeid(x))
bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node)
{
if (TI(Parameter) == TI(*node) || TI(Result) == TI(*node))
if (is_type<Parameter>(node) || is_type<Result>(node))
{
return true;
}
// supported by backend ?
if (m_supported_ops.find(TI(*node)) == m_supported_ops.end())
auto& supportedOps = getSupportedOps();
if (supportedOps.find(node->get_type_info()) == supportedOps.end())
{
return false;
}
// check on invariants expected by MLIR backend
if (TI(ngraph::op::Divide) == TI(*node))
if (auto div = as_type_ptr<ngraph::op::Divide>(node))
{
auto* div = static_cast<ngraph::op::Divide*>(node.get());
if (div->is_pythondiv())
{
// Python specific division rounding is not supported yet.
@@ -450,7 +446,7 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
}
// Dot is 2D only
if (TI(ngraph::op::Dot) == TI(*node))
if (is_type<ngraph::op::Dot>(node))
{
if (node->get_input_shape(0).size() != 2 || node->get_input_shape(1).size() != 2)
{
@@ -462,10 +458,9 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
}
}
if (TI(ngraph::op::Convolution) == TI(*node))
if (auto conv_node = as_type_ptr<ngraph::op::Convolution>(node))
{
// No padding for now
auto conv_node = static_cast<ngraph::op::Convolution*>(node.get());
auto pad_below = conv_node->get_padding_below();
auto pad_above = conv_node->get_padding_above();
auto data_dilation = conv_node->get_data_dilation_strides();
@@ -478,14 +473,13 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
}
// MKLDNN only supports softmax across single axis
if (TI(ngraph::op::Softmax) == TI(*node))
if (auto softmax = as_type_ptr<ngraph::op::Softmax>(node))
{
// Softmax is only supported through callback
if (std::getenv("NGRAPH_MLIR_CALLBACK") == nullptr)
{
return false;
}
auto softmax = static_cast<ngraph::op::Softmax*>(node.get());
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
@@ -493,14 +487,13 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
node->get_input_element_type(0) == element::f32 && softmax->get_axes().size() == 1;
}
if (TI(ngraph::op::AvgPool) == TI(*node))
if (auto avg_pool = as_type_ptr<ngraph::op::AvgPool>(node))
{
// AvgPool is only supported through callback
if (std::getenv("NGRAPH_MLIR_CALLBACK") == nullptr)
{
return false;
}
auto avg_pool = static_cast<ngraph::op::AvgPool*>(node.get());
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
@@ -509,14 +502,13 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
node->get_input_element_type(0) == element::f32;
}
if (TI(ngraph::op::AvgPoolBackprop) == TI(*node))
if (auto avg_pool_backprop = as_type_ptr<ngraph::op::AvgPoolBackprop>(node))
{
// AvgPoolBackprop is only supported through callback
if (std::getenv("NGRAPH_MLIR_CALLBACK") == nullptr)
{
return false;
}
auto avg_pool_backprop = static_cast<ngraph::op::AvgPoolBackprop*>(node.get());
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
@@ -525,14 +517,13 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
node->get_input_element_type(0) == element::f32;
}
if (TI(ngraph::op::MaxPoolBackprop) == TI(*node))
if (auto max_pool_backprop = as_type_ptr<ngraph::op::MaxPoolBackprop>(node))
{
// MaxPoolBackprop is only supported through callback
if (std::getenv("NGRAPH_MLIR_CALLBACK") == nullptr)
{
return false;
}
auto max_pool_backprop = static_cast<ngraph::op::MaxPoolBackprop*>(node.get());
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
@@ -541,14 +532,13 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
node->get_input_element_type(0) == element::f32;
}
if (TI(ngraph::op::MaxPool) == TI(*node))
if (auto max_pool = as_type_ptr<ngraph::op::MaxPool>(node))
{
// MaxPool is only supported through callback
if (std::getenv("NGRAPH_MLIR_CALLBACK") == nullptr)
{
return false;
}
auto max_pool = static_cast<ngraph::op::MaxPool*>(node.get());
auto arg0_shape = node->get_input_shape(0);
auto arg0_rank = arg0_shape.size();
@@ -557,7 +547,7 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
node->get_input_element_type(0) == element::f32;
}
if (TI(ngraph::op::MatMul) == TI(*node))
if (is_type<ngraph::op::MatMul>(node))
{
// MatMul is only supported through callback
if (std::getenv("NGRAPH_MLIR_CALLBACK") == nullptr)
@@ -566,7 +556,7 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
}
}
if (TI(ngraph::op::Gemm) == TI(*node))
if (is_type<ngraph::op::Gemm>(node))
{
// Gemm is only supported through callback
if (std::getenv("NGRAPH_MLIR_CALLBACK") == nullptr)
@@ -584,7 +574,11 @@ void MLIRSubgraphExtractionPass::clean_up()
m_node_to_graph.clear();
}
const std::set<std::type_index> MLIRSubgraphExtractionPass::m_supported_ops{
#define MLIR_OP(OP) TI(ngraph::op::OP),
const std::set<ngraph::Node::type_info_t>& MLIRSubgraphExtractionPass::getSupportedOps()
{
static std::set<Node::type_info_t> supportedOps{
#define MLIR_OP(OP) OP::type_info,
#include "contrib/mlir/core/ops_supported.inc"
};
};
return supportedOps;
}
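
Here the static data member m_supported_ops, a std::set<std::type_index> defined at namespace scope, becomes a function-local static std::set<Node::type_info_t> inside getSupportedOps(). A function-local static is constructed on first use and is thread-safe since C++11, so it avoids static-initialization-order hazards of a namespace-scope definition while also dropping the RTTI-based key. A generic sketch of the construct-on-first-use accessor, with made-up contents:

// Construct-on-first-use accessor: generic sketch, contents are illustrative.
#include <set>
#include <string>

const std::set<std::string>& getSupportedOps()
{
    // Initialized once, on first call; thread-safe since C++11.
    static const std::set<std::string> supportedOps{
        "Add",
        "Dot",
        "Convolution",
    };
    return supportedOps;
}

bool isSupported(const std::string& opName)
{
    const auto& ops = getSupportedOps();
    return ops.find(opName) != ops.end();
}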
@@ -107,6 +107,7 @@ namespace ngraph
void sanity_check(std::shared_ptr<Function> func, NodeVector& ck_nodes);
void clean_up();
static const std::set<ngraph::Node::type_info_t>& getSupportedOps();
private:
using IDGraphMap = std::unordered_map<int, MLIRSubgraph>;
@@ -115,7 +116,6 @@ namespace ngraph
NodeGraphMap m_node_to_graph;
// Mutex over sub-graph IDs
std::mutex m_subgraph_mutex;
static const std::set<std::type_index> m_supported_ops;
};
}
}
@@ -130,12 +130,12 @@ namespace
using TensorToInfoMap = std::unordered_map<descriptor::Tensor*, TensorInfo>;
using MLIRCompOpFunction = std::function<mlir::Operation*(
NgDialectConversionPass& NgDialectObj, const ngraph::Node*)>;
using MLIRCompOpMap = std::unordered_map<std::type_index, MLIRCompOpFunction>;
using MLIRCompOpMap = std::unordered_map<Node::type_info_t, MLIRCompOpFunction>;
// Maps tensor to the value it represents in the IR
// use for MLIR dialect gen
TensorToInfoMap m_tensorToValueMap;
static const MLIRCompOpMap opDispatcher;
static const MLIRCompOpMap& getOpDispatcher();
};
} // end of namespace
@@ -291,9 +291,10 @@ void NgDialectConversionPass::buildNgDialect(mlir::FuncOp function)
m_builder.setInsertionPoint(&region.front(), region.front().begin());
const NodeVector& subGraph = m_compiledKernel->get_node_list();
auto& opDispatcher = getOpDispatcher();
for (auto np : subGraph)
{
auto it = opDispatcher.find(TI(*np));
auto it = opDispatcher.find(np->get_type_info());
if (it == opDispatcher.end())
{
throw unsupported_op{std::string{"The MLIR backend doesn't currently implement the '"} +
@@ -655,10 +656,14 @@ mlir::Operation* NgDialectConversionPass::createGenericOp(const ngraph::Node* ng
.getOperation();
}
const NgDialectConversionPass::MLIRCompOpMap NgDialectConversionPass::opDispatcher{
#define MLIR_OP(OP) {TI(ngraph::op::OP), &NgDialectConversionPass::createOp<ngraph::op::OP>},
const NgDialectConversionPass::MLIRCompOpMap& NgDialectConversionPass::getOpDispatcher()
{
static MLIRCompOpMap opDispatcher{
#define MLIR_OP(OP) {ngraph::op::OP::type_info, &NgDialectConversionPass::createOp<ngraph::op::OP>},
#include "contrib/mlir/core/ops_supported.inc"
};
};
return opDispatcher;
}
void NgDialectConversionPass::createReturn()
{
......
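
The last hunk gives opDispatcher the same treatment: the map moves behind getOpDispatcher() and its key type changes from std::type_index to Node::type_info_t, which works because a hash for type_info_t is available to std::unordered_map. A stand-alone sketch of a descriptor-keyed dispatch table in the same shape; every name below is hypothetical:

// Descriptor-keyed dispatch table: illustrative sketch only.
#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

struct OpTypeInfo
{
    std::string name;
    bool operator==(const OpTypeInfo& other) const { return name == other.name; }
};

// unordered_map needs a hash for its key type; nGraph's type_info_t plays this
// role in the real code.
struct OpTypeInfoHash
{
    std::size_t operator()(const OpTypeInfo& info) const
    {
        return std::hash<std::string>()(info.name);
    }
};

using Handler = std::function<void(const std::string&)>;
using Dispatcher = std::unordered_map<OpTypeInfo, Handler, OpTypeInfoHash>;

const Dispatcher& getDispatcher()
{
    // Same construct-on-first-use shape as getSupportedOps()/getOpDispatcher().
    static const Dispatcher dispatcher{
        {OpTypeInfo{"Add"}, [](const std::string& op) { std::cout << "lower " << op << "\n"; }},
    };
    return dispatcher;
}

int main()
{
    OpTypeInfo key{"Add"};
    auto it = getDispatcher().find(key);
    if (it != getDispatcher().end())
    {
        it->second(key.name);
    }
    return 0;
}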