Unverified Commit 5803d20f authored by Michał Karzyński's avatar Michał Karzyński Committed by GitHub

Update OneHot to use v1 ops only (#4333)

Co-authored-by: 's avatarSang Ik Lee <sang.ik.lee@intel.com>
parent 7f98942d
......@@ -18,9 +18,7 @@
#include <memory>
#include "default_opset.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "onehot.hpp"
#include "utils/common.hpp"
#include "utils/reshape.hpp"
namespace ngraph
......@@ -38,13 +36,15 @@ namespace ngraph
std::make_shared<default_opset::Convert>(inputs.at(0), element::i64);
auto depth = reshape::interpret_as_scalar(inputs.at(1));
// Rank 1 tensor containing exactly two elements: [off_value, on_value]
auto values = inputs.at(2);
std::shared_ptr<ngraph::Node> off_value =
reshape::interpret_as_scalar(std::make_shared<ngraph::opset0::Slice>(
values, Coordinate{0}, Coordinate{1}));
std::shared_ptr<ngraph::Node> on_value =
reshape::interpret_as_scalar(std::make_shared<ngraph::opset0::Slice>(
values, Coordinate{1}, Coordinate{2}));
auto split_axis = default_opset::Constant::create(element::i64, {}, {0});
auto off_on_values =
std::make_shared<default_opset::Split>(values, split_axis, 2);
auto off_value =
reshape::interpret_as_scalar(get_output_element(off_on_values, 0ul));
auto on_value =
reshape::interpret_as_scalar(get_output_element(off_on_values, 1ul));
auto axis = node.get_attribute_value<std::int64_t>("axis", -1);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment