Commit dd66409a authored by Michał Karzyński's avatar Michał Karzyński Committed by Scott Cyphers

Enable multitple int types for OneHot 'depth' param (#3971)

* Enable multitple int types for OneHot 'depth' param

* Move validation logic out of helper

* Make helper more generic
parent 867a8b67
......@@ -162,7 +162,32 @@ void op::v1::OneHot::validate_and_infer_types()
}
m_axis =
ngraph::normalize_axis(this, m_axis, indices_rank + 1, -indices_rank - 1, indices_rank);
int64_t depth_val = as_type_ptr<op::Constant>(depth)->get_vector<int64_t>()[0];
auto depth_element_type = depth->get_output_element_type(0);
NODE_VALIDATION_CHECK(this,
depth_element_type == element::i8 ||
depth_element_type == element::i32 ||
depth_element_type == element::i64,
"'depth' input element type must be i8, i32 or i64 (got ",
depth_element_type,
").");
NODE_VALIDATION_CHECK(this,
is_scalar(depth->get_shape()),
"A scalar input should be provided as 'depth' to OneHot",
" (got ",
depth->get_shape(),
" elements).");
int64_t depth_val = read_scalar_int_from_constant_node(depth);
NODE_VALIDATION_CHECK(this,
depth_val > 0,
"The value of 'depth' must be a positive number.",
" (got ",
depth_val,
").");
out_dims.insert(out_dims.begin() + m_axis, Dimension(depth_val));
result_shape = out_dims;
}
......@@ -176,3 +201,38 @@ shared_ptr<Node> op::v1::OneHot::copy_with_new_args(const NodeVector& new_args)
return make_shared<v1::OneHot>(
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), m_axis);
}
size_t op::v1::OneHot::read_scalar_int_from_constant_node(const shared_ptr<Node>& node) const
{
size_t scalar;
auto node_element_type = node->get_output_element_type(0);
const auto constant = as_type_ptr<op::Constant>(node);
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wswitch-enum"
#endif
switch (static_cast<element::Type_t>(node_element_type))
{
case element::Type_t::i8:
scalar = static_cast<size_t>(constant->get_vector<int8_t>()[0]);
break;
case element::Type_t::i32:
scalar = static_cast<size_t>(constant->get_vector<int32_t>()[0]);
break;
case element::Type_t::i64:
scalar = static_cast<size_t>(constant->get_vector<int64_t>()[0]);
break;
default:
NODE_VALIDATION_CHECK(node.get(),
false,
"Expected integer input of element type i8, i32 or i64 (got ",
node_element_type,
").");
}
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
return scalar;
}
......@@ -107,6 +107,8 @@ namespace ngraph
void set_axis(int64_t axis) { m_axis = axis; }
protected:
int64_t m_axis;
size_t read_scalar_int_from_constant_node(const std::shared_ptr<Node>& node) const;
};
}
// default opset version
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment