Commit 6e6c23ff authored by Robert Kimball, committed by Scott Cyphers

Add method so that element::Type can be cast into enum (#3386)

parent 34ae1ee4
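For context, the change this commit applies across the code base is mechanical: call sites of the form switch (x.get_type_enum()) become switch (x), enabled by a new implicit conversion operator on element::Type. A minimal, self-contained sketch of the pattern, with simplified names and an abbreviated enum rather than the actual ngraph header, might look like this:

#include <string>

namespace sketch
{
    // Abbreviated stand-in for ngraph::element::Type_t (not the full enum).
    enum class Type_t
    {
        undefined,
        boolean,
        f32,
        f64,
        i32
    };

    // Minimal Type-like wrapper; the conversion operator is the point.
    class Type
    {
    public:
        Type() = default;
        explicit Type(Type_t t)
            : m_type{t}
        {
        }
        // Older accessor, kept for source compatibility (deprecated upstream).
        Type_t get_type_enum() const { return m_type; }
        // What this commit adds: lets `switch (element_type)` compile and
        // allows direct comparison against Type_t values.
        operator Type_t() const { return m_type; }

    private:
        Type_t m_type{Type_t::undefined};
    };

    std::string name_of(const Type& type)
    {
        // Before the change this would read: switch (type.get_type_enum())
        switch (type) // the implicit conversion selects the enum value
        {
        case Type_t::boolean: return "boolean";
        case Type_t::f32: return "f32";
        case Type_t::f64: return "f64";
        case Type_t::i32: return "i32";
        case Type_t::undefined: return "undefined";
        }
        return "undefined";
    }
}

The same conversion also covers direct comparisons such as input_type == element::Type_t::f64, which several hunks below rely on.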
......@@ -193,7 +193,7 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
-switch (type.get_type_enum())
+switch (type)
{
case ngraph::element::Type_t::undefined:
case ngraph::element::Type_t::dynamic:
......
......@@ -59,7 +59,7 @@ string op::Constant::convert_value_to_string(size_t index) const
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
-switch (get_element_type().get_type_enum())
+switch (get_element_type())
{
case element::Type_t::boolean: rc = to_string(get_vector<char>()[index]); break;
case element::Type_t::bf16:
......@@ -96,7 +96,7 @@ vector<string> op::Constant::get_value_strings() const
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
-switch (get_element_type().get_type_enum())
+switch (get_element_type())
{
case element::Type_t::boolean:
for (int value : get_vector<char>())
......@@ -292,7 +292,7 @@ bool op::Constant::are_all_data_elements_bitwise_identical() const
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
-switch (get_element_type().get_type_enum())
+switch (get_element_type())
{
case element::Type_t::boolean:
case element::Type_t::i8:
......
......@@ -289,7 +289,7 @@ namespace ngraph
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
-switch (target_type.get_type_enum())
+switch (target_type)
{
case element::Type_t::boolean:
write_buffer<char, T>(target, source, target_element_count);
......
......@@ -219,7 +219,7 @@ void op::Range::validate_and_infer_types()
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
-switch (result_et.get_type_enum())
+switch (result_et)
{
case element::Type_t::bf16: result_shape = infer_output_shape<bfloat16>(this, result_et); break;
case element::Type_t::f16: result_shape = infer_output_shape<float16>(this, result_et); break;
......
......@@ -42,7 +42,7 @@ shared_ptr<Node> op::Max::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::Max::get_default_value() const
{
-switch (get_element_type().get_type_enum())
+switch (get_element_type())
{
case element::Type_t::boolean:
return make_constant_from_string("0", get_element_type(), get_shape());
......
......@@ -42,7 +42,7 @@ shared_ptr<Node> op::Min::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::Min::get_default_value() const
{
-switch (get_element_type().get_type_enum())
+switch (get_element_type())
{
case element::Type_t::boolean:
return make_constant_from_string("1", get_element_type(), get_shape());
......
......@@ -185,7 +185,7 @@ void pass::ConstantFolding::construct_constant_reshape()
std::shared_ptr<Node> replacement;
auto type = constant_match->get_element_type();
-switch (type.get_type_enum())
+switch (type)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false,
......@@ -316,7 +316,7 @@ void pass::ConstantFolding::construct_constant_pad()
std::shared_ptr<Node> replacement;
auto type = constant_match->get_element_type();
-switch (type.get_type_enum())
+switch (type)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_pad_callback");
......@@ -418,7 +418,7 @@ void pass::ConstantFolding::construct_constant_dyn_reshape()
std::shared_ptr<Node> replacement;
auto type = dyn_reshape_match->get_element_type();
-switch (type.get_type_enum())
+switch (type)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false,
......@@ -532,7 +532,7 @@ void pass::ConstantFolding::construct_constant_transpose()
std::shared_ptr<Node> replacement;
auto type = transpose_match->get_element_type();
-switch (type.get_type_enum())
+switch (type)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false,
......@@ -664,7 +664,7 @@ void pass::ConstantFolding::construct_constant_broadcast()
std::shared_ptr<Node> replacement;
auto type = broadcast_match->get_element_type();
-switch (type.get_type_enum())
+switch (type)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false,
......@@ -774,7 +774,7 @@ void pass::ConstantFolding::construct_constant_dyn_broadcast()
std::shared_ptr<Node> replacement;
auto type = dyn_broadcast_match->get_output_element_type(0);
-switch (type.get_type_enum())
+switch (type)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false,
......@@ -1079,7 +1079,7 @@ shared_ptr<op::Constant> fold_constant_binary_helper(const element::Type& et_out
shared_ptr<Node> binary,
NodeExecutorTy func)
{
-switch (et_out.get_type_enum())
+switch (et_out)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_binary_callback");
......@@ -1159,7 +1159,7 @@ void pass::ConstantFolding::construct_constant_binary()
std::shared_ptr<Node> replacement;
auto in_type = a_match->get_output_element_type(0);
auto out_type = binary_match->get_output_element_type(0);
-switch (in_type.get_type_enum())
+switch (in_type)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_binary_callback");
......@@ -1355,7 +1355,7 @@ void pass::ConstantFolding::construct_constant_unary()
std::shared_ptr<Node> replacement;
auto type = constant_match->get_element_type();
-switch (type.get_type_enum())
+switch (type)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_unary_callback");
......@@ -1598,7 +1598,7 @@ shared_ptr<op::Constant> fold_constant_convert_helper0(shared_ptr<op::Constant>
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
-switch (output_element_type.get_type_enum())
+switch (output_element_type)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_convert");
......@@ -1655,7 +1655,7 @@ static shared_ptr<op::Constant> fold_constant_convert(shared_ptr<op::Constant> c
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
-switch (input_element_type.get_type_enum())
+switch (input_element_type)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_convert");
......@@ -1788,7 +1788,7 @@ static shared_ptr<op::Constant> fold_constant_reverse(shared_ptr<op::Constant> c
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
-switch (input_element_type.get_type_enum())
+switch (input_element_type)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_convert");
......@@ -1912,7 +1912,7 @@ static shared_ptr<op::Constant>
{
auto& input_element_type = constant->get_output_element_type(0);
-switch (input_element_type.get_type_enum())
+switch (input_element_type)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false,
......@@ -2113,7 +2113,7 @@ void pass::ConstantFolding::construct_constant_concat()
std::shared_ptr<op::Constant> replacement;
-switch (concat_node->get_output_element_type(0).get_type_enum())
+switch (concat_node->get_output_element_type(0))
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_concat");
......@@ -2199,7 +2199,7 @@ static shared_ptr<op::Constant> fold_constant_gather(const shared_ptr<op::Consta
{
auto indices_type = indices->get_output_element_type(0);
-switch (indices_type.get_type_enum())
+switch (indices_type)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_gather_callback");
......@@ -2255,7 +2255,7 @@ void pass::ConstantFolding::construct_constant_gather()
std::shared_ptr<Node> replacement;
auto data_type = data->get_output_element_type(0);
auto indices_type = indices->get_output_element_type(0);
-switch (data_type.get_type_enum())
+switch (data_type)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_gather_callback");
......@@ -2351,7 +2351,7 @@ void pass::ConstantFolding::construct_constant_slice()
std::shared_ptr<op::Constant> replacement;
-switch (slice->get_output_element_type(0).get_type_enum())
+switch (slice->get_output_element_type(0))
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_slice");
......@@ -2489,7 +2489,7 @@ void pass::ConstantFolding::construct_constant_dyn_slice()
std::shared_ptr<op::Constant> replacement;
-switch (dyn_slice->get_output_element_type(0).get_type_enum())
+switch (dyn_slice->get_output_element_type(0))
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_dyn_slice");
......@@ -2600,7 +2600,7 @@ void pass::ConstantFolding::construct_constant_range()
std::shared_ptr<op::Constant> replacement;
-switch (range->get_output_element_type(0).get_type_enum())
+switch (range->get_output_element_type(0))
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_range_callback");
......@@ -2700,7 +2700,7 @@ void pass::ConstantFolding::construct_constant_select()
std::shared_ptr<op::Constant> replacement;
-switch (select->get_output_element_type(0).get_type_enum())
+switch (select->get_output_element_type(0))
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_select_callback");
......
......@@ -390,7 +390,7 @@ void pass::DynElimination::construct_range()
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
-switch (et.get_type_enum())
+switch (et)
{
case element::Type_t::bf16:
replacement = make_range_replacement<bfloat16>(et, shape, start_arg, step_arg);
......
......@@ -36,7 +36,7 @@ namespace ngraph
auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto count = static_cast<int>(out[0].get_size());
-auto data_type = args[0].get_element_type().get_type_enum();
+auto data_type = args[0].get_element_type();
const ngraph::op::AllReduce* allreduce =
static_cast<const ngraph::op::AllReduce*>(node);
auto reduce_type = allreduce->get_reduce_type();
......
......@@ -33,7 +33,7 @@ namespace ngraph
auto arg_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto count = static_cast<int>(args[0].get_size());
-auto data_type = args[0].get_element_type().get_type_enum();
+auto data_type = args[0].get_element_type();
auto broadcast = static_cast<const ngraph::op::BroadcastDistributed*>(node);
auto root_id = broadcast->get_root_id();
auto functor = [&, count, data_type, arg_buffer_index, root_id](
......
......@@ -1261,7 +1261,7 @@ static void dump_one_kernel_with_type(runtime::cpu::CPU_DebugTracer& debug_trace
const std::string& tensor_name,
const std::string& in_out)
{
-switch (t_attrs.m_type_of_element.get_type_enum())
+switch (t_attrs.m_type_of_element)
{
case element::Type_t::f32:
debug_tracer.dump_one_tensor<float>(kernel_name,
......
......@@ -213,7 +213,7 @@ void runtime::gcpu::GCPUExecutable::generate_calls(const element::Type& type,
const vector<shared_ptr<HostTensor>>& in)
{
stringstream ss;
-switch (type.get_type_enum())
+switch (type)
{
case element::Type_t::boolean: op_engine<char>(op, out, in); break;
case element::Type_t::f32: op_engine<float>(op, out, in); break;
......
......@@ -267,7 +267,7 @@ private:
static_cast<const ngraph::op::AllReduce*>(&node);
reference::allreduce<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
-node.get_input_element_type(0).get_type_enum(),
+node.get_input_element_type(0),
allreduce->get_reduce_type(),
static_cast<int>(shape_size(node.get_input_shape(0))));
break;
......@@ -504,7 +504,7 @@ private:
{
reference::broadcastdistributed<T>(
args[0]->get_data_ptr<T>(),
-node.get_input_element_type(0).get_type_enum(),
+node.get_input_element_type(0),
static_cast<int>(shape_size(node.get_input_shape(0))),
root_id);
auto memSize = static_cast<int>(shape_size(node.get_input_shape(0))) * sizeof(T);
......@@ -514,7 +514,7 @@ private:
{
reference::broadcastdistributed<T>(
out[0]->get_data_ptr<T>(),
-node.get_input_element_type(0).get_type_enum(),
+node.get_input_element_type(0),
static_cast<int>(shape_size(node.get_input_shape(0))),
root_id);
}
......@@ -559,7 +559,7 @@ private:
element::Type type = node.get_element_type();
std::stringstream ss;
size_t element_count = shape_size(node.get_output_shape(0));
-switch (type.get_type_enum())
+switch (type)
{
case element::Type_t::boolean:
reference::convert_to_bool<T>(
......@@ -1300,10 +1300,8 @@ private:
const auto* op = static_cast<const ngraph::op::Recv*>(&node);
int src_id = op->get_src_id();
-reference::recv<T>(args[0]->get_data_ptr<T>(),
-node.get_input_element_type(0).get_type_enum(),
-element_count,
-src_id);
+reference::recv<T>(
+args[0]->get_data_ptr<T>(), node.get_input_element_type(0), element_count, src_id);
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
break;
......@@ -1467,7 +1465,7 @@ private:
int dest_id = op->get_dest_id();
reference::send<T>(args[0]->get_data_ptr<const T>(),
-node.get_input_element_type(0).get_type_enum(),
+node.get_input_element_type(0),
element_count,
dest_id);
......
......@@ -54,7 +54,7 @@ bool runtime::intelgpu::IntelGPULayout::
cldnn::data_types
runtime::intelgpu::IntelGPULayout::get_cldnn_type(const element::Type& element_type)
{
-switch (element_type.get_type_enum())
+switch (element_type)
{
case element::Type_t::i8:
case element::Type_t::boolean: return cldnn::data_types::i8;
......@@ -118,7 +118,7 @@ cldnn::layout runtime::intelgpu::IntelGPULayout::create_cldnn_layout(
const cldnn::tensor tensor = create_cldnn_tensor(element_shape);
cldnn::data_types data_type;
-switch (element_type.get_type_enum())
+switch (element_type)
{
case element::Type_t::i16:
case element::Type_t::u16:
......
......@@ -33,7 +33,7 @@ using namespace ngraph::runtime::intelgpu;
string runtime::intelgpu::get_opencl_type_name(const element::Type& ngraph_type)
{
-switch (ngraph_type.get_type_enum())
+switch (ngraph_type)
{
case element::Type_t::i64: return "long";
case element::Type_t::u64: return "ulong";
......@@ -52,7 +52,7 @@ string runtime::intelgpu::get_opencl_type_name(const element::Type& ngraph_type)
string runtime::intelgpu::get_opencl_type_min_max_value(const element::Type& ngraph_type,
bool is_min)
{
-switch (ngraph_type.get_type_enum())
+switch (ngraph_type)
{
case element::Type_t::f32: return is_min ? "-INFINITY" : "INFINITY";
case element::Type_t::f64: return is_min ? "-INFINITY" : "INFINITY";
......@@ -1839,9 +1839,8 @@ void runtime::intelgpu::do_convert_operation(cldnn::topology& topology,
{
gws = generate_loops(writer, output_shape, true);
-if (((input_type.get_type_enum() == element::Type_t::f64) ||
-(input_type.get_type_enum() == element::Type_t::f32)) &&
-(output_type.get_type_enum() != element::Type_t::boolean))
+if (((input_type == element::Type_t::f64) || (input_type == element::Type_t::f32)) &&
+(output_type != element::Type_t::boolean))
{
// this is the workaround for OpenCL to be same as with CPU floating point operations
writer << input_type_name << " input_var = input0" << access_dims(output_shape) << ";\n"
......
......@@ -215,7 +215,7 @@ void runtime::interpreter::INTExecutable::generate_calls(const element::Type& ty
const vector<shared_ptr<HostTensor>>& in)
{
stringstream ss;
-switch (type.get_type_enum())
+switch (type)
{
case element::Type_t::boolean: op_engine<char>(op, out, in); break;
case element::Type_t::f32: op_engine<float>(op, out, in); break;
......
......@@ -294,7 +294,7 @@ private:
static_cast<const ngraph::op::AllReduce*>(&node);
reference::allreduce<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
-node.get_input_element_type(0).get_type_enum(),
+node.get_input_element_type(0),
allreduce->get_reduce_type(),
static_cast<int>(shape_size(node.get_input_shape(0))));
break;
......@@ -530,7 +530,7 @@ private:
{
reference::broadcastdistributed<T>(
args[0]->get_data_ptr<T>(),
-node.get_input_element_type(0).get_type_enum(),
+node.get_input_element_type(0),
static_cast<int>(shape_size(node.get_input_shape(0))),
root_id);
auto memSize = static_cast<int>(shape_size(node.get_input_shape(0))) * sizeof(T);
......@@ -540,7 +540,7 @@ private:
{
reference::broadcastdistributed<T>(
out[0]->get_data_ptr<T>(),
-node.get_input_element_type(0).get_type_enum(),
+node.get_input_element_type(0),
static_cast<int>(shape_size(node.get_input_shape(0))),
root_id);
}
......@@ -585,7 +585,7 @@ private:
element::Type type = node.get_element_type();
std::stringstream ss;
size_t element_count = shape_size(node.get_output_shape(0));
-switch (type.get_type_enum())
+switch (type)
{
case element::Type_t::boolean:
reference::convert_to_bool<T>(
......@@ -1349,10 +1349,8 @@ private:
const auto* op = static_cast<const ngraph::op::Recv*>(&node);
int src_id = op->get_src_id();
-reference::recv<T>(args[0]->get_data_ptr<T>(),
-node.get_input_element_type(0).get_type_enum(),
-element_count,
-src_id);
+reference::recv<T>(
+args[0]->get_data_ptr<T>(), node.get_input_element_type(0), element_count, src_id);
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
break;
......@@ -1516,7 +1514,7 @@ private:
int dest_id = op->get_dest_id();
reference::send<T>(args[0]->get_data_ptr<const T>(),
-node.get_input_element_type(0).get_type_enum(),
+node.get_input_element_type(0),
element_count,
dest_id);
......
......@@ -26,6 +26,7 @@
#include <string>
#include <vector>
#include "ngraph/deprecated.hpp"
#include "ngraph/except.hpp"
#include "ngraph/ngraph_visibility.hpp"
#include "ngraph/type/bfloat16.hpp"
......@@ -73,7 +74,10 @@ namespace ngraph
const std::string& cname);
~Type() {}
Type& operator=(const Type&) = default;
-Type_t get_type_enum() const { return m_type; }
+NGRAPH_DEPRECATED("Use operator Type_t()") Type_t get_type_enum() const
+{
+return m_type;
+}
const std::string& c_type_string() const;
size_t size() const;
size_t hash() const;
......@@ -119,6 +123,8 @@ namespace ngraph
/// does nothing to dst, and returns false
static bool merge(element::Type& dst, const element::Type& t1, const element::Type& t2);
+// \brief This allows switch(element_type)
+operator Type_t() const { return m_type; }
private:
Type_t m_type{Type_t::undefined};
};
......
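A brief usage note on the header change above: because element::Type now converts implicitly to Type_t, a Type value can also be compared against Type_t constants and passed wherever the enum is expected, for example as a map key, which the test-helper hunk further below does with m_value_comparators. A hedged, self-contained illustration with placeholder names (not the real ngraph or test-framework types):

#include <map>
#include <string>

enum class Type_t { undefined, boolean, f32, f64 }; // abbreviated placeholder enum

struct Type
{
    Type_t m_type;
    // Same idea as the operator added in the header above.
    operator Type_t() const { return m_type; }
};

int main()
{
    // Hypothetical comparator-style registry keyed by the enum.
    std::map<Type_t, std::string> names{{Type_t::f32, "f32"}};
    Type t{Type_t::f32};
    bool is_float = (t == Type_t::f32); // implicit conversion, then enum comparison
    bool known = names.count(t) != 0;   // conversion supplies the map key
    return (is_float && known) ? 0 : 1;
}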
......@@ -85,7 +85,7 @@ void random_init(shared_ptr<runtime::Tensor> tensor)
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
-switch (et.get_type_enum())
+switch (et)
{
case element::Type_t::boolean: init_int_tensor<char>(tensor, 0, 1); break;
case element::Type_t::f32: init_real_tensor<float>(tensor, -1, 1); break;
......
......@@ -39,14 +39,14 @@ void ngraph::test::NgraphTestCase::run(size_t tolerance_bits)
auto result_shape = result_tensor->get_shape();
EXPECT_EQ(expected_shape, result_shape);
-if (m_value_comparators.count(element_type.get_type_enum()) == 0)
+if (m_value_comparators.count(element_type) == 0)
{
NGRAPH_FAIL() << "Please add support for " << element_type
<< " to ngraph::test::NgraphTestCase::run()";
}
else
{
-auto values_match = m_value_comparators.at(element_type.get_type_enum());
+auto values_match = m_value_comparators.at(element_type);
EXPECT_TRUE(values_match(expected_result_constant, result_tensor));
}
......