Commit c8988ca9 authored by Michał Karzyński's avatar Michał Karzyński Committed by Scott Cyphers

Helper in Constant to allow casting values to a different type (#4000)

* Helper in Constant to allow casting values to a different type

Simplify logic needed to extract values from a Constant node, when
the expected data type is specified only as integral or floating point.

* Review comment

* Review comment
Co-Authored-By: 's avatarTomasz Socha <tomasz.socha@intel.com>

* Style apply
parent f803feb7
...@@ -240,6 +240,87 @@ namespace ngraph ...@@ -240,6 +240,87 @@ namespace ngraph
return rc; return rc;
} }
/// \brief Return the Constant's value as a vector cast to type T
///
/// \tparam T Type to which data vector's entries will be cast.
/// \return Constant's data vector.
template <typename T>
std::vector<T> cast_vector() const
{
auto source_type = get_element_type();
switch (source_type)
{
case element::Type_t::boolean:
{
auto vector = get_vector<char>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::bf16:
{
auto vector = get_vector<bfloat16>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::f16:
{
auto vector = get_vector<float16>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::f32:
{
auto vector = get_vector<float>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::f64:
{
auto vector = get_vector<double>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::i8:
{
auto vector = get_vector<int8_t>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::i16:
{
auto vector = get_vector<int16_t>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::i32:
{
auto vector = get_vector<int32_t>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::i64:
{
auto vector = get_vector<int64_t>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::u8:
{
auto vector = get_vector<uint8_t>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::u16:
{
auto vector = get_vector<uint16_t>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::u32:
{
auto vector = get_vector<uint32_t>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::u64:
{
auto vector = get_vector<uint64_t>();
return std::vector<T>(vector.begin(), vector.end());
}
case element::Type_t::u1:
case element::Type_t::undefined:
case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
}
}
const void* get_data_ptr() const { return (m_data ? m_data->get_ptr() : nullptr); } const void* get_data_ptr() const { return (m_data ? m_data->get_ptr() : nullptr); }
template <typename T> template <typename T>
const T* get_data_ptr() const const T* get_data_ptr() const
......
...@@ -165,10 +165,8 @@ void op::v1::OneHot::validate_and_infer_types() ...@@ -165,10 +165,8 @@ void op::v1::OneHot::validate_and_infer_types()
auto depth_element_type = depth->get_output_element_type(0); auto depth_element_type = depth->get_output_element_type(0);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
depth_element_type == element::i8 || depth_element_type.is_integral(),
depth_element_type == element::i32 || "'depth' input element type must be an integer (got ",
depth_element_type == element::i64,
"'depth' input element type must be i8, i32 or i64 (got ",
depth_element_type, depth_element_type,
")."); ").");
...@@ -179,7 +177,8 @@ void op::v1::OneHot::validate_and_infer_types() ...@@ -179,7 +177,8 @@ void op::v1::OneHot::validate_and_infer_types()
depth->get_shape(), depth->get_shape(),
" elements)."); " elements).");
int64_t depth_val = read_scalar_int_from_constant_node(depth); const auto depth_constant = as_type_ptr<op::Constant>(depth);
int64_t depth_val = depth_constant->cast_vector<int64_t>()[0];
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
depth_val > 0, depth_val > 0,
...@@ -201,38 +200,3 @@ shared_ptr<Node> op::v1::OneHot::copy_with_new_args(const NodeVector& new_args) ...@@ -201,38 +200,3 @@ shared_ptr<Node> op::v1::OneHot::copy_with_new_args(const NodeVector& new_args)
return make_shared<v1::OneHot>( return make_shared<v1::OneHot>(
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), m_axis); new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), m_axis);
} }
size_t op::v1::OneHot::read_scalar_int_from_constant_node(const shared_ptr<Node>& node) const
{
size_t scalar;
auto node_element_type = node->get_output_element_type(0);
const auto constant = as_type_ptr<op::Constant>(node);
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wswitch-enum"
#endif
switch (static_cast<element::Type_t>(node_element_type))
{
case element::Type_t::i8:
scalar = static_cast<size_t>(constant->get_vector<int8_t>()[0]);
break;
case element::Type_t::i32:
scalar = static_cast<size_t>(constant->get_vector<int32_t>()[0]);
break;
case element::Type_t::i64:
scalar = static_cast<size_t>(constant->get_vector<int64_t>()[0]);
break;
default:
NODE_VALIDATION_CHECK(node.get(),
false,
"Expected integer input of element type i8, i32 or i64 (got ",
node_element_type,
").");
}
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
return scalar;
}
...@@ -107,8 +107,6 @@ namespace ngraph ...@@ -107,8 +107,6 @@ namespace ngraph
void set_axis(int64_t axis) { m_axis = axis; } void set_axis(int64_t axis) { m_axis = axis; }
protected: protected:
int64_t m_axis; int64_t m_axis;
size_t read_scalar_int_from_constant_node(const std::shared_ptr<Node>& node) const;
}; };
} }
// default opset version // default opset version
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment