Unverified commit cf21a361 authored by Scott Cyphers, committed by GitHub

Fix compilation warnings (#4236)

* Fix compilation warnings

* Camels
parent ca955d46
@@ -346,7 +346,7 @@ namespace
     unsigned inputCount = f.getType().getNumInputs();
     // we find out output values by looking at returned values
     // any return should return all outputs of the subgraph
-    f.walk([this, &outputCount, inputCount](NGReturnOp ret) {
+    f.walk([&outputCount, inputCount](NGReturnOp ret) {
         for (unsigned i = 0; i < ret.getNumOperands(); i++)
         {
             // annotate instructions defining outputs with the arg idx of the output
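The hunk above drops the `this` capture from the lambda passed to `f.walk`, since the body never uses it; an unused explicit capture is what clang reports under -Wunused-lambda-capture. A minimal standalone sketch of the same fix (the `Walker` type and the values are illustrative, not taken from the nGraph sources):

    // Capturing `this` in a lambda that never touches members triggers
    // clang's -Wunused-lambda-capture; capture only what the body uses.
    #include <cstdio>
    #include <initializer_list>

    struct Walker
    {
        void walk()
        {
            int count = 0;
            // Before: auto visit = [this, &count](int v) { count += v; };  // warns
            auto visit = [&count](int v) { count += v; };                   // clean
            for (int v : {1, 2, 3})
            {
                visit(v);
            }
            std::printf("count = %d\n", count);
        }
    };

    int main()
    {
        Walker{}.walk();
        return 0;
    }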
@@ -1324,7 +1324,7 @@ namespace
             }
             attrs.gemmAttrs2d.ldc = attrs.gemmAttrs2d.n;
-            int broadcastHint;
+            int broadcastHint = -2;
             if (vBias.rank() == 0)
             {
                 // Scalar
@@ -1356,6 +1356,7 @@ namespace
                     broadcastHint = 0;
                 }
             }
+            NGRAPH_CHECK(broadcastHint != -2, "Unhandled broadcast");
             attrs.gemmAttrs2d.broadcastHint = broadcastHint;
             auto int64Ty = rewriter.getIntegerType(64);
......
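In the Gemm-lowering hunks above, `broadcastHint` previously started out uninitialized and was only assigned inside the rank-dependent branches, which compilers flag as a possible use of an uninitialized variable. Initializing it to the sentinel -2 and asserting on it with NGRAPH_CHECK both silences the warning and turns an unhandled case into a hard error. A standalone sketch of the pattern (plain `assert` stands in for NGRAPH_CHECK, and the rank-to-hint mapping is made up for the example):

    #include <cassert>

    int pickBroadcastHint(int biasRank)
    {
        int broadcastHint = -2; // sentinel: no branch has handled this case yet
        if (biasRank == 0)
        {
            broadcastHint = 2; // scalar bias (illustrative value)
        }
        else if (biasRank == 1)
        {
            broadcastHint = 0; // vector bias (illustrative value)
        }
        // An unhandled rank now fails loudly instead of reading an indeterminate value.
        assert(broadcastHint != -2 && "Unhandled broadcast");
        return broadcastHint;
    }

    int main()
    {
        return pickBroadcastHint(0) == 2 ? 0 : 1;
    }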
@@ -27,8 +27,6 @@ using namespace ngraph::descriptor;
 using namespace ngraph::op;
 using namespace ngraph::pass;
-#define TI(x) std::type_index(typeid(x))
 int MLIRSubgraphExtractionPass::MLIRSubgraph::m_curr_graph_id = 0;
 void MLIRSubgraphExtractionPass::MLIRSubgraph::add_inputs(NodeVector& inputs)
@@ -162,11 +160,11 @@ void MLIRSubgraphExtractionPass::build_subgraphs(std::shared_ptr<Function> func)
     for (auto it = nodes_ready.begin(); it != nodes_ready.end();)
     {
         auto node = *it;
-        if (TI(Result) == TI(*node))
+        if (is_type<Result>(node))
         {
             erase_node(it, nodes_ready);
         }
-        else if (TI(Parameter) == TI(*node))
+        else if (is_type<Parameter>(node))
         {
             process_successors(node, node_to_size_map, nodes_ready);
             erase_node(it, nodes_ready);
@@ -209,7 +207,7 @@ void MLIRSubgraphExtractionPass::build_subgraphs(std::shared_ptr<Function> func)
     for (auto it = nodes_ready.begin(); it != nodes_ready.end();)
     {
         auto node = *it;
-        if (TI(Result) == TI(*node))
+        if (is_type<Result>(node))
         {
             erase_node(it, nodes_ready);
         }
@@ -236,7 +234,7 @@ void MLIRSubgraphExtractionPass::build_subgraphs(std::shared_ptr<Function> func)
     for (auto it = nodes_ready.begin(); it != nodes_ready.end();)
     {
         auto node = *it;
-        if (TI(Result) == TI(*node))
+        if (is_type<Result>(node))
         {
             erase_node(it, nodes_ready);
         }
@@ -264,10 +262,10 @@ void MLIRSubgraphExtractionPass::build_subgraphs(std::shared_ptr<Function> func)
     {
         MLIRSubgraph& sg = it->second;
         auto& nodes = sg.get_nodes();
-        NodeVector outputs = std::move(get_subgraph_outputs(NodeVector(nodes.begin(), nodes.end()),
-                                                            {} /*exclusions*/,
-                                                            false /* ignore unused */,
-                                                            false /* ignore output duplicates */));
+        NodeVector outputs = get_subgraph_outputs(NodeVector(nodes.begin(), nodes.end()),
+                                                  {} /*exclusions*/,
+                                                  false /* ignore unused */,
+                                                  false /* ignore output duplicates */);
         sg.add_outputs(outputs);
     }
 }
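Dropping the `std::move` around `get_subgraph_outputs(...)` addresses clang's -Wpessimizing-move: the call already yields a prvalue, so wrapping it in `std::move` cannot help and can only block copy elision. A standalone sketch of the same situation (the function here is illustrative, not the nGraph API):

    #include <vector>

    std::vector<int> makeOutputs()
    {
        return {1, 2, 3};
    }

    int main()
    {
        // Before: std::vector<int> outputs = std::move(makeOutputs());
        //   clang warns: moving the returned temporary prevents copy elision
        //   (-Wpessimizing-move) and buys nothing.
        // After: initialize directly from the call.
        std::vector<int> outputs = makeOutputs();
        return outputs.size() == 3 ? 0 : 1;
    }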
@@ -420,26 +418,24 @@ void MLIRSubgraphExtractionPass::sanity_check(std::shared_ptr<Function> func, No
     }
 }
-#define TI(x) std::type_index(typeid(x))
 bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node)
 {
-    if (TI(Parameter) == TI(*node) || TI(Result) == TI(*node))
+    if (is_type<Parameter>(node) || is_type<Result>(node))
     {
         return true;
     }
     // supported by backend ?
-    if (m_supported_ops.find(TI(*node)) == m_supported_ops.end())
+    auto& supportedOps = getSupportedOps();
+    if (supportedOps.find(node->get_type_info()) == supportedOps.end())
     {
         return false;
     }
     // check on invariants expected by MLIR backend
-    if (TI(ngraph::op::Divide) == TI(*node))
+    if (auto div = as_type_ptr<ngraph::op::Divide>(node))
    {
-        auto* div = static_cast<ngraph::op::Divide*>(node.get());
         if (div->is_pythondiv())
         {
             // Python specific division rounding is not supported yet.
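The recurring pattern through the rest of is_supported_mlir_op is the same: the `TI(x)` macro (a `std::type_index(typeid(x))` comparison) followed by a separate `static_cast` is replaced by `is_type<T>`/`as_type_ptr<T>`, which key off `Node::type_info` and fold the type test and the downcast into a single declaration. A standalone sketch of the shape of that idiom, using `std::dynamic_pointer_cast` as a stand-in for `as_type_ptr` (the `Node`/`Divide` classes here are mock-ups, not the nGraph types):

    #include <memory>

    struct Node
    {
        virtual ~Node() = default;
    };

    struct Divide : Node
    {
        bool is_pythondiv() const { return true; }
    };

    bool is_supported(const std::shared_ptr<Node>& node)
    {
        // Before: a typeid comparison plus a separate static_cast.
        // After: one declaration both tests the type and yields the typed pointer.
        if (auto div = std::dynamic_pointer_cast<Divide>(node))
        {
            if (div->is_pythondiv())
            {
                return false; // Python-style division rounding not supported
            }
        }
        return true;
    }

    int main()
    {
        std::shared_ptr<Node> node = std::make_shared<Divide>();
        return is_supported(node) ? 0 : 1; // pythondiv -> rejected -> exit code 1
    }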
@@ -450,7 +446,7 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
     }
     // Dot is 2D only
-    if (TI(ngraph::op::Dot) == TI(*node))
+    if (is_type<ngraph::op::Dot>(node))
     {
         if (node->get_input_shape(0).size() != 2 || node->get_input_shape(1).size() != 2)
         {
@@ -462,10 +458,9 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
         }
     }
-    if (TI(ngraph::op::Convolution) == TI(*node))
+    if (auto conv_node = as_type_ptr<ngraph::op::Convolution>(node))
     {
         // No padding for now
-        auto conv_node = static_cast<ngraph::op::Convolution*>(node.get());
         auto pad_below = conv_node->get_padding_below();
         auto pad_above = conv_node->get_padding_above();
         auto data_dilation = conv_node->get_data_dilation_strides();
@@ -478,14 +473,13 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
     }
     // MKLDNN only supports softmax across single axis
-    if (TI(ngraph::op::Softmax) == TI(*node))
+    if (auto softmax = as_type_ptr<ngraph::op::Softmax>(node))
     {
         // Softmax is only supported through callback
         if (std::getenv("NGRAPH_MLIR_CALLBACK") == nullptr)
         {
             return false;
         }
-        auto softmax = static_cast<ngraph::op::Softmax*>(node.get());
         auto arg0_shape = node->get_input_shape(0);
         auto arg0_rank = arg0_shape.size();
@@ -493,14 +487,13 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
                node->get_input_element_type(0) == element::f32 && softmax->get_axes().size() == 1;
     }
-    if (TI(ngraph::op::AvgPool) == TI(*node))
+    if (auto avg_pool = as_type_ptr<ngraph::op::AvgPool>(node))
     {
         // AvgPool is only supported through callback
         if (std::getenv("NGRAPH_MLIR_CALLBACK") == nullptr)
         {
             return false;
         }
-        auto avg_pool = static_cast<ngraph::op::AvgPool*>(node.get());
         auto arg0_shape = node->get_input_shape(0);
         auto arg0_rank = arg0_shape.size();
@@ -509,14 +502,13 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
                node->get_input_element_type(0) == element::f32;
     }
-    if (TI(ngraph::op::AvgPoolBackprop) == TI(*node))
+    if (auto avg_pool_backprop = as_type_ptr<ngraph::op::AvgPoolBackprop>(node))
     {
         // AvgPoolBackprop is only supported through callback
         if (std::getenv("NGRAPH_MLIR_CALLBACK") == nullptr)
         {
             return false;
         }
-        auto avg_pool_backprop = static_cast<ngraph::op::AvgPoolBackprop*>(node.get());
         auto arg0_shape = node->get_input_shape(0);
         auto arg0_rank = arg0_shape.size();
@@ -525,14 +517,13 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
                node->get_input_element_type(0) == element::f32;
     }
-    if (TI(ngraph::op::MaxPoolBackprop) == TI(*node))
+    if (auto max_pool_backprop = as_type_ptr<ngraph::op::MaxPoolBackprop>(node))
     {
         // MaxPoolBackprop is only supported through callback
         if (std::getenv("NGRAPH_MLIR_CALLBACK") == nullptr)
         {
             return false;
         }
-        auto max_pool_backprop = static_cast<ngraph::op::MaxPoolBackprop*>(node.get());
         auto arg0_shape = node->get_input_shape(0);
         auto arg0_rank = arg0_shape.size();
@@ -541,14 +532,13 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
                node->get_input_element_type(0) == element::f32;
     }
-    if (TI(ngraph::op::MaxPool) == TI(*node))
+    if (auto max_pool = as_type_ptr<ngraph::op::MaxPool>(node))
     {
         // MaxPool is only supported through callback
         if (std::getenv("NGRAPH_MLIR_CALLBACK") == nullptr)
         {
             return false;
         }
-        auto max_pool = static_cast<ngraph::op::MaxPool*>(node.get());
         auto arg0_shape = node->get_input_shape(0);
         auto arg0_rank = arg0_shape.size();
@@ -557,7 +547,7 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
                node->get_input_element_type(0) == element::f32;
     }
-    if (TI(ngraph::op::MatMul) == TI(*node))
+    if (is_type<ngraph::op::MatMul>(node))
     {
         // MatMul is only supported through callback
         if (std::getenv("NGRAPH_MLIR_CALLBACK") == nullptr)
@@ -566,7 +556,7 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
         }
     }
-    if (TI(ngraph::op::Gemm) == TI(*node))
+    if (is_type<ngraph::op::Gemm>(node))
     {
         // Gemm is only supported through callback
         if (std::getenv("NGRAPH_MLIR_CALLBACK") == nullptr)
@@ -584,7 +574,11 @@ void MLIRSubgraphExtractionPass::clean_up()
     m_node_to_graph.clear();
 }
-const std::set<std::type_index> MLIRSubgraphExtractionPass::m_supported_ops{
-#define MLIR_OP(OP) TI(ngraph::op::OP),
+const std::set<ngraph::Node::type_info_t>& MLIRSubgraphExtractionPass::getSupportedOps()
+{
+    static std::set<Node::type_info_t> supportedOps{
+#define MLIR_OP(OP) OP::type_info,
 #include "contrib/mlir/core/ops_supported.inc"
     };
+    return supportedOps;
+}
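The static data member `m_supported_ops` (a `std::set<std::type_index>`) becomes `getSupportedOps()`, which returns a reference to a function-local static set keyed by `Node::type_info_t`. The function-local static is built on first use, so the pass no longer depends on the initialization order of namespace-scope statics, and the typeid-based keys go away. A standalone sketch of the pattern (the `OpInfo` alias and the registered names are made up for the example):

    #include <set>
    #include <string>

    using OpInfo = std::string; // stands in for ngraph::Node::type_info_t

    const std::set<OpInfo>& getSupportedOps()
    {
        // Constructed on first call (thread-safe since C++11); no dependence
        // on static-initialization order across translation units.
        static const std::set<OpInfo> supportedOps{"Add", "Dot", "Softmax"};
        return supportedOps;
    }

    int main()
    {
        const auto& ops = getSupportedOps();
        return ops.count("Dot") == 1 ? 0 : 1;
    }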
@@ -107,6 +107,7 @@ namespace ngraph
             void sanity_check(std::shared_ptr<Function> func, NodeVector& ck_nodes);
             void clean_up();
+            static const std::set<ngraph::Node::type_info_t>& getSupportedOps();
         private:
             using IDGraphMap = std::unordered_map<int, MLIRSubgraph>;
@@ -115,7 +116,6 @@ namespace ngraph
             NodeGraphMap m_node_to_graph;
             // Mutex over sub-graph IDs
             std::mutex m_subgraph_mutex;
-            static const std::set<std::type_index> m_supported_ops;
         };
     }
 }
@@ -130,12 +130,12 @@ namespace
         using TensorToInfoMap = std::unordered_map<descriptor::Tensor*, TensorInfo>;
         using MLIRCompOpFunction = std::function<mlir::Operation*(
             NgDialectConversionPass& NgDialectObj, const ngraph::Node*)>;
-        using MLIRCompOpMap = std::unordered_map<std::type_index, MLIRCompOpFunction>;
+        using MLIRCompOpMap = std::unordered_map<Node::type_info_t, MLIRCompOpFunction>;
         // Maps tensor to the value it represents in the IR
         // use for MLIR dialect gen
         TensorToInfoMap m_tensorToValueMap;
-        static const MLIRCompOpMap opDispatcher;
+        static const MLIRCompOpMap& getOpDispatcher();
     };
 } // end of namespace
@@ -291,9 +291,10 @@ void NgDialectConversionPass::buildNgDialect(mlir::FuncOp function)
     m_builder.setInsertionPoint(&region.front(), region.front().begin());
     const NodeVector& subGraph = m_compiledKernel->get_node_list();
+    auto& opDispatcher = getOpDispatcher();
     for (auto np : subGraph)
     {
-        auto it = opDispatcher.find(TI(*np));
+        auto it = opDispatcher.find(np->get_type_info());
         if (it == opDispatcher.end())
         {
             throw unsupported_op{std::string{"The MLIR backend doesn't currently implement the '"} +
@@ -655,10 +656,14 @@ mlir::Operation* NgDialectConversionPass::createGenericOp(const ngraph::Node* ng
         .getOperation();
 }
-const NgDialectConversionPass::MLIRCompOpMap NgDialectConversionPass::opDispatcher{
-#define MLIR_OP(OP) {TI(ngraph::op::OP), &NgDialectConversionPass::createOp<ngraph::op::OP>},
+const NgDialectConversionPass::MLIRCompOpMap& NgDialectConversionPass::getOpDispatcher()
+{
+    static MLIRCompOpMap opDispatcher{
+#define MLIR_OP(OP) {ngraph::op::OP::type_info, &NgDialectConversionPass::createOp<ngraph::op::OP>},
 #include "contrib/mlir/core/ops_supported.inc"
     };
+    return opDispatcher;
+}
 void NgDialectConversionPass::createReturn()
 {
......
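`getOpDispatcher()` applies the same function-local-static pattern to the op-creation table: the map is now keyed by `Node::type_info_t` rather than `std::type_index`, and `buildNgDialect` looks ops up with `np->get_type_info()`. A standalone sketch of a type-keyed dispatch table of the same shape (the key and handler types are simplified stand-ins for `Node::type_info_t` and `MLIRCompOpFunction`):

    #include <functional>
    #include <string>
    #include <unordered_map>

    using OpKey = std::string;                      // stands in for Node::type_info_t
    using OpHandler = std::function<int(int, int)>; // stands in for MLIRCompOpFunction

    const std::unordered_map<OpKey, OpHandler>& getOpDispatcher()
    {
        static const std::unordered_map<OpKey, OpHandler> opDispatcher{
            {"Add", [](int a, int b) { return a + b; }},
            {"Mul", [](int a, int b) { return a * b; }},
        };
        return opDispatcher;
    }

    int main()
    {
        const auto& dispatcher = getOpDispatcher();
        auto it = dispatcher.find("Add");
        if (it == dispatcher.end())
        {
            return 1; // unsupported op
        }
        return it->second(2, 3) == 5 ? 0 : 1;
    }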