Unverified Commit 9f1afc65 authored by Katarzyna Mitrus's avatar Katarzyna Mitrus Committed by GitHub

Add provenance tags to opset transformation passes (#4198)

* Add provenance tag in opset0 downgrade pass

* Add provenance tag in opset1 upgrade pass

* Type name as string

* Change op_cast to return replacement_node instead of bool

* Add provenance tags to all nodes created while downgrade

* Add provenance tags to all nodes created while upgrade

* Comments

* Style apply

* Update ONNX import provenance test function

* Add const statement

* Add upgrade/downgrade provenance tag tests

* Update tests

* Style apply

* Provenance enabled check

* Removed redundant add_tag

* Test for add_provenance_tags above

* Add graph test for provenance tags in transformation pass

* Use EXPECT_TRUE and EXPECT_FALSE instead of EXPECT_EQ

* Return replacement node directly

* Style apply

* Test downgrade provenance tag with ONNX importer

* Update test/onnx/onnx_import_provenance.in.cpp
Co-Authored-By: 's avatarMichał Karzyński <postrational@users.noreply.github.com>

* Test update

* Update provenance test to check node type occurence

* Style apply
Co-authored-by: 's avatarMichał Karzyński <postrational@users.noreply.github.com>
Co-authored-by: 's avatarSang Ik Lee <sang.ik.lee@intel.com>
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 3409bda8
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include "ngraph/ops.hpp" #include "ngraph/ops.hpp"
#include "ngraph/pass/implicit_broadcast_elimination.hpp" #include "ngraph/pass/implicit_broadcast_elimination.hpp"
#include "ngraph/pass/opset0_downgrade.hpp" #include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/provenance.hpp"
#include "ngraph/slice_plan.hpp" #include "ngraph/slice_plan.hpp"
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
#include "ngraph/validation_util.hpp" #include "ngraph/validation_util.hpp"
...@@ -37,17 +38,18 @@ using namespace ngraph; ...@@ -37,17 +38,18 @@ using namespace ngraph;
namespace namespace
{ {
template <typename OpV0, typename OpV1> template <typename OpV0, typename OpV1>
void op_cast_binary_elementwise_node(const shared_ptr<OpV1>& node) shared_ptr<Node> op_cast_binary_elementwise_node(const shared_ptr<OpV1>& node)
{ {
const auto input_arg0 = node->input_value(0); const auto input_arg0 = node->input_value(0);
const auto input_arg1 = node->input_value(1); const auto input_arg1 = node->input_value(1);
const auto autob = node->get_autob(); const auto autob = node->get_autob();
auto replacement_node = make_shared<OpV0>(input_arg0, input_arg1, autob); auto replacement_node = make_shared<OpV0>(input_arg0, input_arg1, autob);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return replacement_node;
} }
template <typename OpV0, typename OpV1> template <typename OpV0, typename OpV1>
void op_cast_reduction_node(const shared_ptr<OpV1>& node) shared_ptr<Node> op_cast_reduction_node(const shared_ptr<OpV1>& node)
{ {
auto replacement_node = make_shared<OpV0>(node->input_value(0), node->input_value(1)); auto replacement_node = make_shared<OpV0>(node->input_value(0), node->input_value(1));
if (node->get_keep_dims()) if (node->get_keep_dims())
...@@ -85,17 +87,17 @@ namespace ...@@ -85,17 +87,17 @@ namespace
{ {
replace_node(node, replacement_node); replace_node(node, replacement_node);
} }
return replacement_node;
} }
// Default is that we did nothing // Default is that we did nothing
bool op_cast(shared_ptr<Node> node) { return false; } shared_ptr<Node> op_cast(shared_ptr<Node> node) { return nullptr; }
bool op_cast(shared_ptr<op::v1::Add> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::Add> node)
{ {
op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node); return op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
return true;
} }
bool op_cast(shared_ptr<op::v1::AvgPool> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::AvgPool> node)
{ {
auto const input_arg = node->input_value(0); auto const input_arg = node->input_value(0);
const auto ceil_mode = static_cast<bool>(node->get_rounding_type()); const auto ceil_mode = static_cast<bool>(node->get_rounding_type());
...@@ -115,10 +117,10 @@ namespace ...@@ -115,10 +117,10 @@ namespace
pad_type, pad_type,
ceil_mode); ceil_mode);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::AvgPoolBackprop> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::AvgPoolBackprop> node)
{ {
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant()); NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant());
const auto forward_arg_shape = const auto forward_arg_shape =
...@@ -140,10 +142,10 @@ namespace ...@@ -140,10 +142,10 @@ namespace
padding_above, padding_above,
include_padding_in_avg_computation); include_padding_in_avg_computation);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::Broadcast> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::Broadcast> node)
{ {
auto arg = node->input_value(0); auto arg = node->input_value(0);
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant()); NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant());
...@@ -155,10 +157,10 @@ namespace ...@@ -155,10 +157,10 @@ namespace
make_shared<op::v0::Broadcast>(arg, target_shape, node->get_broadcast_axes().second); make_shared<op::v0::Broadcast>(arg, target_shape, node->get_broadcast_axes().second);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::Convolution> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::Convolution> node)
{ {
const auto data_arg = node->input_value(0); const auto data_arg = node->input_value(0);
const auto filters_arg = node->input_value(1); const auto filters_arg = node->input_value(1);
...@@ -173,10 +175,10 @@ namespace ...@@ -173,10 +175,10 @@ namespace
Strides(num_spatial_dims, 1), Strides(num_spatial_dims, 1),
node->get_auto_pad()); node->get_auto_pad());
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::ConvolutionBackpropData> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::ConvolutionBackpropData> node)
{ {
auto output_shape_node = auto output_shape_node =
as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr()); as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());
...@@ -223,10 +225,10 @@ namespace ...@@ -223,10 +225,10 @@ namespace
node->get_pads_end(), node->get_pads_end(),
Strides(num_spatial_dims, 1)); Strides(num_spatial_dims, 1));
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::ConvolutionBackpropFilters> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::ConvolutionBackpropFilters> node)
{ {
NGRAPH_CHECK(node->input_value(2).get_node_shared_ptr()->is_constant()); NGRAPH_CHECK(node->input_value(2).get_node_shared_ptr()->is_constant());
auto filters_shape = auto filters_shape =
...@@ -246,10 +248,10 @@ namespace ...@@ -246,10 +248,10 @@ namespace
node->get_pads_end(), node->get_pads_end(),
Strides(num_spatial_dims, 1)); Strides(num_spatial_dims, 1));
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::Divide> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::Divide> node)
{ {
const auto input_arg0 = node->input_value(0); const auto input_arg0 = node->input_value(0);
const auto input_arg1 = node->input_value(1); const auto input_arg1 = node->input_value(1);
...@@ -257,10 +259,10 @@ namespace ...@@ -257,10 +259,10 @@ namespace
const bool pydiv = node->is_pythondiv(); const bool pydiv = node->is_pythondiv();
auto replacement_node = make_shared<op::v0::Divide>(input_arg0, input_arg1, pydiv, autob); auto replacement_node = make_shared<op::v0::Divide>(input_arg0, input_arg1, pydiv, autob);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::Reshape> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::Reshape> node)
{ {
shared_ptr<Node> replacement_node; shared_ptr<Node> replacement_node;
...@@ -282,16 +284,15 @@ namespace ...@@ -282,16 +284,15 @@ namespace
} }
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::Equal> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::Equal> node)
{ {
op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node); return op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
return true;
} }
bool op_cast(shared_ptr<op::v1::Gather> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::Gather> node)
{ {
auto axis_node = as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr()); auto axis_node = as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());
...@@ -309,10 +310,10 @@ namespace ...@@ -309,10 +310,10 @@ namespace
auto replacement_node = auto replacement_node =
make_shared<op::v0::Gather>(node->input_value(0), node->input_value(1), axis); make_shared<op::v0::Gather>(node->input_value(0), node->input_value(1), axis);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::GenerateMask> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::GenerateMask> node)
{ {
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant()); NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant());
auto mask_shape = auto mask_shape =
...@@ -327,22 +328,20 @@ namespace ...@@ -327,22 +328,20 @@ namespace
node->input_value(0), mask_shape, et, seed, probability, use_seed); node->input_value(0), mask_shape, et, seed, probability, use_seed);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::Greater> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::Greater> node)
{ {
op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node); return op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
return true;
} }
bool op_cast(shared_ptr<op::v1::GreaterEqual> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::GreaterEqual> node)
{ {
op_cast_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEqual>(node); return op_cast_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEqual>(node);
return true;
} }
bool op_cast(shared_ptr<op::v1::GroupConvolution> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::GroupConvolution> node)
{ {
const auto data_arg = node->input_value(0); const auto data_arg = node->input_value(0);
const auto filters_arg = node->input_value(1); const auto filters_arg = node->input_value(1);
...@@ -357,10 +356,10 @@ namespace ...@@ -357,10 +356,10 @@ namespace
Strides(num_spatial_dims, 1), Strides(num_spatial_dims, 1),
node->get_auto_pad()); node->get_auto_pad());
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::GroupConvolutionBackpropData> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::GroupConvolutionBackpropData> node)
{ {
auto output_shape_input = auto output_shape_input =
as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr()); as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());
...@@ -434,52 +433,47 @@ namespace ...@@ -434,52 +433,47 @@ namespace
pads_end, pads_end,
groups); groups);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::Less> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::Less> node)
{ {
op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node); return op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
return true;
} }
bool op_cast(shared_ptr<op::v1::LessEqual> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::LessEqual> node)
{ {
op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node); return op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
return true;
} }
bool op_cast(shared_ptr<op::v1::LogicalAnd> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalAnd> node)
{ {
op_cast_binary_elementwise_node<op::v0::And, op::v1::LogicalAnd>(node); return op_cast_binary_elementwise_node<op::v0::And, op::v1::LogicalAnd>(node);
return true;
} }
bool op_cast(shared_ptr<op::v1::LogicalNot> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalNot> node)
{ {
replace_node(node, make_shared<op::v0::Not>(node->input_value(0))); auto replacement_node = make_shared<op::v0::Not>(node->input_value(0));
return true; replace_node(node, replacement_node);
return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::LogicalOr> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalOr> node)
{ {
op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node); return op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
return true;
} }
bool op_cast(shared_ptr<op::v1::LogicalXor> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalXor> node)
{ {
op_cast_binary_elementwise_node<op::v0::Xor, op::v1::LogicalXor>(node); return op_cast_binary_elementwise_node<op::v0::Xor, op::v1::LogicalXor>(node);
return true;
} }
bool op_cast(shared_ptr<op::v1::Maximum> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::Maximum> node)
{ {
op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node); return op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
return true;
} }
bool op_cast(shared_ptr<op::v1::MaxPool> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::MaxPool> node)
{ {
auto const input_arg = node->input_value(0); auto const input_arg = node->input_value(0);
auto ceil_mode = static_cast<bool>(node->get_rounding_type()); auto ceil_mode = static_cast<bool>(node->get_rounding_type());
...@@ -497,10 +491,10 @@ namespace ...@@ -497,10 +491,10 @@ namespace
pad_type, pad_type,
ceil_mode); ceil_mode);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::MaxPoolBackprop> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::MaxPoolBackprop> node)
{ {
const auto padding_below = node->get_pads_begin(); const auto padding_below = node->get_pads_begin();
const auto padding_above = node->get_pads_end(); const auto padding_above = node->get_pads_end();
...@@ -532,28 +526,25 @@ namespace ...@@ -532,28 +526,25 @@ namespace
padding_above); padding_above);
} }
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::Minimum> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::Minimum> node)
{ {
op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node); return op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
return true;
} }
bool op_cast(shared_ptr<op::v1::Multiply> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::Multiply> node)
{ {
op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node); return op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
return true;
} }
bool op_cast(shared_ptr<op::v1::NotEqual> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::NotEqual> node)
{ {
op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node); return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
return true;
} }
bool op_cast(shared_ptr<op::v1::OneHot> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::OneHot> node)
{ {
const auto indices = node->input_value(0).get_node_shared_ptr(); const auto indices = node->input_value(0).get_node_shared_ptr();
const auto depth = node->input_value(1).get_node_shared_ptr(); const auto depth = node->input_value(1).get_node_shared_ptr();
...@@ -577,10 +568,10 @@ namespace ...@@ -577,10 +568,10 @@ namespace
auto replacement_node = one_hot * (on_value - off_value) + off_value; auto replacement_node = one_hot * (on_value - off_value) + off_value;
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::Pad> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::Pad> node)
{ {
const auto pad_arg = node->input_value(0); const auto pad_arg = node->input_value(0);
const auto pad_value = node->input_value(3); const auto pad_value = node->input_value(3);
...@@ -588,40 +579,35 @@ namespace ...@@ -588,40 +579,35 @@ namespace
pad_arg, pad_value, node->get_pads_begin(), node->get_pads_end(), node->get_pad_mode()); pad_arg, pad_value, node->get_pads_begin(), node->get_pads_end(), node->get_pad_mode());
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::Power> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::Power> node)
{ {
op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node); return op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
return true;
} }
bool op_cast(shared_ptr<op::v1::ReduceMax> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMax> node)
{ {
op_cast_reduction_node<op::v0::Max, op::v1::ReduceMax>(node); return op_cast_reduction_node<op::v0::Max, op::v1::ReduceMax>(node);
return true;
} }
bool op_cast(shared_ptr<op::v1::ReduceMin> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMin> node)
{ {
op_cast_reduction_node<op::v0::Min, op::v1::ReduceMin>(node); return op_cast_reduction_node<op::v0::Min, op::v1::ReduceMin>(node);
return true;
} }
bool op_cast(shared_ptr<op::v1::ReduceProd> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceProd> node)
{ {
op_cast_reduction_node<op::v0::Product, op::v1::ReduceProd>(node); return op_cast_reduction_node<op::v0::Product, op::v1::ReduceProd>(node);
return true;
} }
bool op_cast(shared_ptr<op::v1::ReduceSum> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceSum> node)
{ {
op_cast_reduction_node<op::v0::Sum, op::v1::ReduceSum>(node); return op_cast_reduction_node<op::v0::Sum, op::v1::ReduceSum>(node);
return true;
} }
bool op_cast(shared_ptr<op::v1::Reverse> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::Reverse> node)
{ {
auto axes_node = node->input_value(1).get_node_shared_ptr(); auto axes_node = node->input_value(1).get_node_shared_ptr();
NGRAPH_CHECK(axes_node->is_constant(), NGRAPH_CHECK(axes_node->is_constant(),
...@@ -648,19 +634,19 @@ namespace ...@@ -648,19 +634,19 @@ namespace
auto replacement_node = make_shared<op::v0::Reverse>(node->input_value(0), axes); auto replacement_node = make_shared<op::v0::Reverse>(node->input_value(0), axes);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::Select> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::Select> node)
{ {
ngraph::pass::ImplicitBroadcastElimination().run_on_node(node); ngraph::pass::ImplicitBroadcastElimination().run_on_node(node);
auto replacement_node = make_shared<op::v0::Select>( auto replacement_node = make_shared<op::v0::Select>(
node->input_value(0), node->input_value(1), node->input_value(2)); node->input_value(0), node->input_value(1), node->input_value(2));
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::StridedSlice> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::StridedSlice> node)
{ {
auto convert_mask_to_axes = [](const std::vector<int64_t>& mask) { auto convert_mask_to_axes = [](const std::vector<int64_t>& mask) {
AxisSet axes{}; AxisSet axes{};
...@@ -723,10 +709,10 @@ namespace ...@@ -723,10 +709,10 @@ namespace
} }
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::Softmax> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::Softmax> node)
{ {
auto axis = node->get_axis(); auto axis = node->get_axis();
auto data = node->input(0); auto data = node->input(0);
...@@ -735,10 +721,10 @@ namespace ...@@ -735,10 +721,10 @@ namespace
std::iota(std::begin(axes), std::end(axes), axis); std::iota(std::begin(axes), std::end(axes), axis);
auto replacement_node = make_shared<op::v0::Softmax>(node->input_value(0), axes); auto replacement_node = make_shared<op::v0::Softmax>(node->input_value(0), axes);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::Split> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::Split> node)
{ {
const auto num_splits = node->get_num_splits(); const auto num_splits = node->get_num_splits();
...@@ -746,16 +732,15 @@ namespace ...@@ -746,16 +732,15 @@ namespace
make_shared<op::v0::Split>(node->input_value(0), node->input_value(1), num_splits); make_shared<op::v0::Split>(node->input_value(0), node->input_value(1), num_splits);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::Subtract> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::Subtract> node)
{ {
op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node); return op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
return true;
} }
bool op_cast(shared_ptr<op::v1::TopK> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::TopK> node)
{ {
const auto axis = node->get_axis(); const auto axis = node->get_axis();
const auto sort_type = node->get_sort_type(); const auto sort_type = node->get_sort_type();
...@@ -778,10 +763,10 @@ namespace ...@@ -778,10 +763,10 @@ namespace
// values output will be 0, indices 1 // values output will be 0, indices 1
vector<int64_t> output_order{1, 0}; vector<int64_t> output_order{1, 0};
replace_node(node, replacement_node, output_order); replace_node(node, replacement_node, output_order);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::Transpose> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::Transpose> node)
{ {
const auto data = node->input_value(0); const auto data = node->input_value(0);
...@@ -816,10 +801,10 @@ namespace ...@@ -816,10 +801,10 @@ namespace
auto replacement_node = make_shared<op::v0::Reshape>(data, order, out_shape); auto replacement_node = make_shared<op::v0::Reshape>(data, order, out_shape);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v1::VariadicSplit> node) shared_ptr<Node> op_cast(shared_ptr<op::v1::VariadicSplit> node)
{ {
const auto split_lengths = node->input_value(2).get_node_shared_ptr(); const auto split_lengths = node->input_value(2).get_node_shared_ptr();
...@@ -835,7 +820,7 @@ namespace ...@@ -835,7 +820,7 @@ namespace
make_shared<op::v0::Split>(node->input_value(0), node->input_value(1), splits_unsigned); make_shared<op::v0::Split>(node->input_value(0), node->input_value(1), splits_unsigned);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>; using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>;
...@@ -843,7 +828,18 @@ namespace ...@@ -843,7 +828,18 @@ namespace
template <typename T> template <typename T>
bool op_cast_thunk(shared_ptr<Node> node) bool op_cast_thunk(shared_ptr<Node> node)
{ {
return op_cast(as_type_ptr<T>(node)); auto downgraded_node = op_cast(as_type_ptr<T>(node));
if (downgraded_node)
{
if (ngraph::get_provenance_enabled())
{
const std::string provenance_tag =
"<Opset0_Downgrade (v1 " + std::string(node->get_type_name()) + ")>";
downgraded_node->add_provenance_tags_above(node->input_values(), {provenance_tag});
}
return true;
}
return false;
} }
DispatchMap& get_dispatch_map() DispatchMap& get_dispatch_map()
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/ops.hpp" #include "ngraph/ops.hpp"
#include "ngraph/pass/opset1_upgrade.hpp" #include "ngraph/pass/opset1_upgrade.hpp"
#include "ngraph/provenance.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -29,29 +30,28 @@ using namespace ngraph; ...@@ -29,29 +30,28 @@ using namespace ngraph;
namespace namespace
{ {
template <typename OpV0, typename OpV1> template <typename OpV0, typename OpV1>
void op_cast_binary_elementwise_node(const shared_ptr<OpV0>& node) shared_ptr<Node> op_cast_binary_elementwise_node(const shared_ptr<OpV0>& node)
{ {
const auto autob = node->get_autob(); const auto autob = node->get_autob();
auto replacement_node = auto replacement_node =
make_shared<OpV1>(node->input_value(0), node->input_value(1), autob); make_shared<OpV1>(node->input_value(0), node->input_value(1), autob);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return replacement_node;
} }
// Default is that we didn nothing // Default is that we didn nothing
bool op_cast(shared_ptr<Node> node) { return false; } shared_ptr<Node> op_cast(shared_ptr<Node> node) { return nullptr; }
bool op_cast(shared_ptr<op::Add> node) shared_ptr<Node> op_cast(shared_ptr<op::Add> node)
{ {
op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node); return op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
return true;
} }
bool op_cast(shared_ptr<op::And> node) shared_ptr<Node> op_cast(shared_ptr<op::And> node)
{ {
op_cast_binary_elementwise_node<op::v0::And, op::v1::LogicalAnd>(node); return op_cast_binary_elementwise_node<op::v0::And, op::v1::LogicalAnd>(node);
return true;
} }
bool op_cast(shared_ptr<op::AvgPool> node) shared_ptr<Node> op_cast(shared_ptr<op::AvgPool> node)
{ {
auto rounding_mode = auto rounding_mode =
node->get_ceil_mode() ? op::RoundingType::CEIL : op::RoundingType::FLOOR; node->get_ceil_mode() ? op::RoundingType::CEIL : op::RoundingType::FLOOR;
...@@ -82,10 +82,10 @@ namespace ...@@ -82,10 +82,10 @@ namespace
} }
#endif #endif
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::AvgPoolBackprop> node) shared_ptr<Node> op_cast(shared_ptr<op::AvgPoolBackprop> node)
{ {
auto exclude_pad = !node->get_include_padding_in_avg_computation(); auto exclude_pad = !node->get_include_padding_in_avg_computation();
auto pads_begin = node->get_padding_below(); auto pads_begin = node->get_padding_below();
...@@ -101,10 +101,10 @@ namespace ...@@ -101,10 +101,10 @@ namespace
kernel, kernel,
exclude_pad); exclude_pad);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::Broadcast> node) shared_ptr<Node> op_cast(shared_ptr<op::Broadcast> node)
{ {
auto result_shape = node->get_broadcast_shape(); auto result_shape = node->get_broadcast_shape();
auto result_shape_node = auto result_shape_node =
...@@ -124,11 +124,11 @@ namespace ...@@ -124,11 +124,11 @@ namespace
auto replacement_node = make_shared<op::v1::Broadcast>( auto replacement_node = make_shared<op::v1::Broadcast>(
node->input_value(0), result_shape_node->output(0), axes_mapping_node->output(0)); node->input_value(0), result_shape_node->output(0), axes_mapping_node->output(0));
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::BroadcastLike> node) { return false; } shared_ptr<Node> op_cast(shared_ptr<op::BroadcastLike> node) { return nullptr; }
bool op_cast(shared_ptr<op::Convolution> node) shared_ptr<Node> op_cast(shared_ptr<op::Convolution> node)
{ {
auto strides = node->get_window_movement_strides(); auto strides = node->get_window_movement_strides();
auto dilations = node->get_window_dilation_strides(); auto dilations = node->get_window_dilation_strides();
...@@ -154,10 +154,10 @@ namespace ...@@ -154,10 +154,10 @@ namespace
dilations, dilations,
auto_pad); auto_pad);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::ConvolutionBackpropData> node) shared_ptr<Node> op_cast(shared_ptr<op::ConvolutionBackpropData> node)
{ {
auto data_batch_shape = node->get_data_batch_shape(); auto data_batch_shape = node->get_data_batch_shape();
auto strides = node->get_window_movement_strides_forward(); auto strides = node->get_window_movement_strides_forward();
...@@ -188,10 +188,10 @@ namespace ...@@ -188,10 +188,10 @@ namespace
pads_end, pads_end,
dilations); dilations);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::ConvolutionBackpropFilters> node) shared_ptr<Node> op_cast(shared_ptr<op::ConvolutionBackpropFilters> node)
{ {
auto filters_shape = node->get_filters_shape(); auto filters_shape = node->get_filters_shape();
auto strides = node->get_window_movement_strides_forward(); auto strides = node->get_window_movement_strides_forward();
...@@ -220,35 +220,34 @@ namespace ...@@ -220,35 +220,34 @@ namespace
pads_begin, pads_begin,
pads_end); pads_end);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::Divide> node) shared_ptr<Node> op_cast(shared_ptr<op::Divide> node)
{ {
const auto autob = node->get_autob(); const auto autob = node->get_autob();
const bool pydiv = node->is_pythondiv(); const bool pydiv = node->is_pythondiv();
auto replacement_node = auto replacement_node =
make_shared<op::v1::Divide>(node->input_value(0), node->input_value(1), pydiv, autob); make_shared<op::v1::Divide>(node->input_value(0), node->input_value(1), pydiv, autob);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::DynReshape> node) shared_ptr<Node> op_cast(shared_ptr<op::DynReshape> node)
{ {
auto zero_flag = false; auto zero_flag = false;
auto replacement_node = auto replacement_node =
make_shared<op::v1::Reshape>(node->input_value(0), node->input_value(1), zero_flag); make_shared<op::v1::Reshape>(node->input_value(0), node->input_value(1), zero_flag);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::Equal> node) shared_ptr<Node> op_cast(shared_ptr<op::Equal> node)
{ {
op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node); return op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
return true;
} }
bool op_cast(shared_ptr<op::Gather> node) shared_ptr<Node> op_cast(shared_ptr<op::Gather> node)
{ {
int64_t axis = node->get_axis(); int64_t axis = node->get_axis();
...@@ -256,22 +255,20 @@ namespace ...@@ -256,22 +255,20 @@ namespace
auto replacement_node = auto replacement_node =
make_shared<op::v1::Gather>(node->input_value(0), node->input_value(1), axis_node); make_shared<op::v1::Gather>(node->input_value(0), node->input_value(1), axis_node);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::Greater> node) shared_ptr<Node> op_cast(shared_ptr<op::Greater> node)
{ {
op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node); return op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
return true;
} }
bool op_cast(shared_ptr<op::GreaterEq> node) shared_ptr<Node> op_cast(shared_ptr<op::GreaterEq> node)
{ {
op_cast_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEqual>(node); return op_cast_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEqual>(node);
return true;
} }
bool op_cast(shared_ptr<op::v0::GroupConvolution> node) shared_ptr<Node> op_cast(shared_ptr<op::v0::GroupConvolution> node)
{ {
auto strides = node->get_window_movement_strides(); auto strides = node->get_window_movement_strides();
auto dilations = node->get_window_dilation_strides(); auto dilations = node->get_window_dilation_strides();
...@@ -324,10 +321,10 @@ namespace ...@@ -324,10 +321,10 @@ namespace
auto_pad); auto_pad);
} }
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::v0::GroupConvolutionBackpropData> node) shared_ptr<Node> op_cast(shared_ptr<op::v0::GroupConvolutionBackpropData> node)
{ {
auto strides = node->get_window_movement_strides(); auto strides = node->get_window_movement_strides();
auto dilations = node->get_window_dilation_strides(); auto dilations = node->get_window_dilation_strides();
...@@ -364,37 +361,34 @@ namespace ...@@ -364,37 +361,34 @@ namespace
pads_end, pads_end,
dilations); dilations);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::Less> node) shared_ptr<Node> op_cast(shared_ptr<op::Less> node)
{ {
op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node); return op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
return true;
} }
bool op_cast(shared_ptr<op::LessEq> node) shared_ptr<Node> op_cast(shared_ptr<op::LessEq> node)
{ {
op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node); return op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
return true;
} }
bool op_cast(shared_ptr<op::Max> node) shared_ptr<Node> op_cast(shared_ptr<op::Max> node)
{ {
bool keep_dims = false; bool keep_dims = false;
auto replacement_node = auto replacement_node =
make_shared<op::v1::ReduceMax>(node->input_value(0), node->input_value(1), keep_dims); make_shared<op::v1::ReduceMax>(node->input_value(0), node->input_value(1), keep_dims);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::Maximum> node) shared_ptr<Node> op_cast(shared_ptr<op::Maximum> node)
{ {
op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node); return op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
return true;
} }
bool op_cast(shared_ptr<op::MaxPool> node) shared_ptr<Node> op_cast(shared_ptr<op::MaxPool> node)
{ {
auto rounding_type = auto rounding_type =
node->get_ceil_mode() ? op::RoundingType::CEIL : op::RoundingType::FLOOR; node->get_ceil_mode() ? op::RoundingType::CEIL : op::RoundingType::FLOOR;
...@@ -418,10 +412,10 @@ namespace ...@@ -418,10 +412,10 @@ namespace
} }
#endif #endif
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::MaxPoolBackprop> node) shared_ptr<Node> op_cast(shared_ptr<op::MaxPoolBackprop> node)
{ {
auto pads_begin = node->get_padding_below(); auto pads_begin = node->get_padding_below();
auto pads_end = node->get_padding_above(); auto pads_end = node->get_padding_above();
...@@ -445,43 +439,41 @@ namespace ...@@ -445,43 +439,41 @@ namespace
node->input_value(0), node->input_value(1), strides, pads_begin, pads_end, kernel); node->input_value(0), node->input_value(1), strides, pads_begin, pads_end, kernel);
} }
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::Min> node) shared_ptr<Node> op_cast(shared_ptr<op::Min> node)
{ {
bool keep_dims = false; bool keep_dims = false;
auto replacement_node = auto replacement_node =
make_shared<op::v1::ReduceMin>(node->input_value(0), node->input_value(1), keep_dims); make_shared<op::v1::ReduceMin>(node->input_value(0), node->input_value(1), keep_dims);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::Minimum> node) shared_ptr<Node> op_cast(shared_ptr<op::Minimum> node)
{ {
op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node); return op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
return true;
} }
bool op_cast(shared_ptr<op::Multiply> node) shared_ptr<Node> op_cast(shared_ptr<op::Multiply> node)
{ {
op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node); return op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
return true;
} }
bool op_cast(shared_ptr<op::Not> node) shared_ptr<Node> op_cast(shared_ptr<op::Not> node)
{ {
replace_node(node, make_shared<op::v1::LogicalNot>(node->input_value(0))); auto replacement_node = make_shared<op::v1::LogicalNot>(node->input_value(0));
return true; replace_node(node, replacement_node);
return replacement_node;
} }
bool op_cast(shared_ptr<op::NotEqual> node) shared_ptr<Node> op_cast(shared_ptr<op::NotEqual> node)
{ {
op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node); return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
return true;
} }
bool op_cast(shared_ptr<op::OneHot> node) shared_ptr<Node> op_cast(shared_ptr<op::OneHot> node)
{ {
const auto indices = node->input_value(0).get_node_shared_ptr(); const auto indices = node->input_value(0).get_node_shared_ptr();
const auto one_hot_axis = node->get_one_hot_axis(); const auto one_hot_axis = node->get_one_hot_axis();
...@@ -499,16 +491,15 @@ namespace ...@@ -499,16 +491,15 @@ namespace
auto replacement_node = auto replacement_node =
make_shared<op::v1::OneHot>(indices, depth_node, on_value, off_value, one_hot_axis); make_shared<op::v1::OneHot>(indices, depth_node, on_value, off_value, one_hot_axis);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::Or> node) shared_ptr<Node> op_cast(shared_ptr<op::Or> node)
{ {
op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node); return op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
return true;
} }
bool op_cast(shared_ptr<op::Pad> node) shared_ptr<Node> op_cast(shared_ptr<op::Pad> node)
{ {
auto padding_below = node->get_padding_below(); auto padding_below = node->get_padding_below();
auto pads_begin_node = auto pads_begin_node =
...@@ -524,25 +515,24 @@ namespace ...@@ -524,25 +515,24 @@ namespace
node->get_pad_mode()); node->get_pad_mode());
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::Power> node) shared_ptr<Node> op_cast(shared_ptr<op::Power> node)
{ {
op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node); return op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
return true;
} }
bool op_cast(shared_ptr<op::Product> node) shared_ptr<Node> op_cast(shared_ptr<op::Product> node)
{ {
bool keep_dims = false; bool keep_dims = false;
auto replacement_node = auto replacement_node =
make_shared<op::v1::ReduceProd>(node->input_value(0), node->input_value(1), keep_dims); make_shared<op::v1::ReduceProd>(node->input_value(0), node->input_value(1), keep_dims);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::Reverse> node) shared_ptr<Node> op_cast(shared_ptr<op::Reverse> node)
{ {
// creates a Constant node from the v0::Reverse reversed_axes attribute // creates a Constant node from the v0::Reverse reversed_axes attribute
// and uses it as the second input of v1::Reverse // and uses it as the second input of v1::Reverse
...@@ -551,24 +541,24 @@ namespace ...@@ -551,24 +541,24 @@ namespace
const auto reversed_axes_constant = op::Constant::create( const auto reversed_axes_constant = op::Constant::create(
element::i64, Shape{reversed_axes.size()}, reversed_axes.to_vector()); element::i64, Shape{reversed_axes.size()}, reversed_axes.to_vector());
const auto reverse_v1 = make_shared<op::v1::Reverse>( const auto replacement_node = make_shared<op::v1::Reverse>(
node->input_value(0), reversed_axes_constant, op::v1::Reverse::Mode::INDEX); node->input_value(0), reversed_axes_constant, op::v1::Reverse::Mode::INDEX);
replace_node(node, reverse_v1); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::Select> node) shared_ptr<Node> op_cast(shared_ptr<op::Select> node)
{ {
auto replacement_node = make_shared<op::v1::Select>(node->input_value(0), auto replacement_node = make_shared<op::v1::Select>(node->input_value(0),
node->input_value(1), node->input_value(1),
node->input_value(2), node->input_value(2),
op::AutoBroadcastSpec()); op::AutoBroadcastSpec());
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::Softmax> node) shared_ptr<Node> op_cast(shared_ptr<op::Softmax> node)
{ {
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant(), NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant(),
"axes parameter is expected to be a static constant"); "axes parameter is expected to be a static constant");
...@@ -583,10 +573,10 @@ namespace ...@@ -583,10 +573,10 @@ namespace
auto replacement_node = auto replacement_node =
make_shared<op::v1::Softmax>(node->input_value(0), axes.to_vector()[0]); make_shared<op::v1::Softmax>(node->input_value(0), axes.to_vector()[0]);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::Slice> node) shared_ptr<Node> op_cast(shared_ptr<op::Slice> node)
{ {
const auto data = node->input_value(0); const auto data = node->input_value(0);
const auto begin = op::Constant::create( const auto begin = op::Constant::create(
...@@ -605,10 +595,10 @@ namespace ...@@ -605,10 +595,10 @@ namespace
vector<int64_t>(input_size, 0)); vector<int64_t>(input_size, 0));
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::Split> node) shared_ptr<Node> op_cast(shared_ptr<op::Split> node)
{ {
const auto& splits_vec = node->get_splits(); const auto& splits_vec = node->get_splits();
const auto first_elem = splits_vec.front(); const auto first_elem = splits_vec.front();
...@@ -634,25 +624,24 @@ namespace ...@@ -634,25 +624,24 @@ namespace
} }
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::Subtract> node) shared_ptr<Node> op_cast(shared_ptr<op::Subtract> node)
{ {
op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node); return op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
return true;
} }
bool op_cast(shared_ptr<op::Sum> node) shared_ptr<Node> op_cast(shared_ptr<op::Sum> node)
{ {
bool keep_dims = false; bool keep_dims = false;
auto replacement_node = auto replacement_node =
make_shared<op::v1::ReduceSum>(node->input_value(0), node->input_value(1), keep_dims); make_shared<op::v1::ReduceSum>(node->input_value(0), node->input_value(1), keep_dims);
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::TopK> node) shared_ptr<Node> op_cast(shared_ptr<op::TopK> node)
{ {
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant(), NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant(),
"parameter k is expected to be a static constant"); "parameter k is expected to be a static constant");
...@@ -687,15 +676,15 @@ namespace ...@@ -687,15 +676,15 @@ namespace
// indices output will be 0, values 1 // indices output will be 0, values 1
vector<int64_t> output_order{1, 0}; vector<int64_t> output_order{1, 0};
replace_node(node, replacement_node, output_order); replace_node(node, replacement_node, output_order);
return true; return replacement_node;
} }
bool op_cast(shared_ptr<op::Xor> node) shared_ptr<Node> op_cast(shared_ptr<op::Xor> node)
{ {
auto replacement_node = make_shared<op::v1::LogicalXor>( auto replacement_node = make_shared<op::v1::LogicalXor>(
node->input_value(0), node->input_value(1), node->get_autob()); node->input_value(0), node->input_value(1), node->get_autob());
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return replacement_node;
} }
using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>; using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>;
...@@ -703,7 +692,18 @@ namespace ...@@ -703,7 +692,18 @@ namespace
template <typename T> template <typename T>
bool op_cast_thunk(shared_ptr<Node> node) bool op_cast_thunk(shared_ptr<Node> node)
{ {
return op_cast(as_type_ptr<T>(node)); auto upgraded_node = op_cast(as_type_ptr<T>(node));
if (upgraded_node)
{
if (ngraph::get_provenance_enabled())
{
const std::string provenance_tag =
"<Opset1_Upgrade (v0 " + std::string(node->get_type_name()) + ")>";
upgraded_node->add_provenance_tags_above(node->input_values(), {provenance_tag});
}
return true;
}
return false;
} }
DispatchMap& get_dispatch_map() DispatchMap& get_dispatch_map()
......
ir_version: 4
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
input: "k"
output: "values"
output: "indices"
op_type: "TopK"
name: "TOPK"
}
name: "test_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "k"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "values"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "indices"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 10
}
...@@ -18,6 +18,10 @@ ...@@ -18,6 +18,10 @@
#include "ngraph/file_util.hpp" #include "ngraph/file_util.hpp"
#include "ngraph/frontend/onnx_import/default_opset.hpp" #include "ngraph/frontend/onnx_import/default_opset.hpp"
#include "ngraph/frontend/onnx_import/onnx.hpp" #include "ngraph/frontend/onnx_import/onnx.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/provenance.hpp"
#include "util/test_control.hpp" #include "util/test_control.hpp"
#include "util/type_prop.hpp" #include "util/type_prop.hpp"
...@@ -44,52 +48,73 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_tag_text) ...@@ -44,52 +48,73 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_tag_text)
// the NodeToCheck parameter of this template is used to find a node in the whole subgraph // the NodeToCheck parameter of this template is used to find a node in the whole subgraph
// that a particular unit test is supposed to check against the expected provenance tag // that a particular unit test is supposed to check against the expected provenance tag
template <typename NodeToCheck> template <typename NodeToCheck>
void test_provenance_tags(const std::string& model_path, const std::string& expected_provenance_tag) void test_provenance_tags(const std::shared_ptr<Function> function,
const std::string& expected_provenance_tag)
{ {
const auto function = int node_count = 0;
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, model_path));
for (const auto ng_node : function->get_ordered_ops()) for (const auto ng_node : function->get_ordered_ops())
{ {
if (as_type_ptr<NodeToCheck>(ng_node)) if (as_type_ptr<NodeToCheck>(ng_node))
{ {
++node_count;
const auto tags = ng_node->get_provenance_tags(); const auto tags = ng_node->get_provenance_tags();
ASSERT_EQ(tags.size(), 1) << "There should be exactly one provenance tag set for " ASSERT_TRUE(tags.size() > 0) << "Node " << ng_node->get_friendly_name()
<< ng_node; << " should have at least one provenance tag.";
EXPECT_TRUE(tags.find(expected_provenance_tag) != tags.end());
EXPECT_EQ(*(tags.cbegin()), expected_provenance_tag);
} }
} }
EXPECT_TRUE(node_count > 0) << "Expected type of node doesn't exist in graph.";
} }
NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_only_output) NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_only_output)
{ {
// the Add node in the model does not have a name, // the Add node in the model does not have a name,
// only its output name should be found in the provenance tags // only its output name should be found in the provenance tags
test_provenance_tags<default_opset::Add>("onnx/provenance_only_outputs.prototxt", const auto function = onnx_import::import_onnx_model(
"<ONNX Add (-> output_of_add)>"); file_util::path_join(SERIALIZED_ZOO, "onnx/provenance_only_outputs.prototxt"));
test_provenance_tags<default_opset::Add>(function, "<ONNX Add (-> output_of_add)>");
} }
NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_node_name_and_outputs) NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_node_name_and_outputs)
{ {
test_provenance_tags<default_opset::Add>("onnx/provenance_node_name_and_outputs.prototxt", const auto function = onnx_import::import_onnx_model(
"<ONNX Add (Add_node -> output_of_add)>"); file_util::path_join(SERIALIZED_ZOO, "onnx/provenance_node_name_and_outputs.prototxt"));
test_provenance_tags<default_opset::Add>(function, "<ONNX Add (Add_node -> output_of_add)>");
} }
NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_multiple_outputs_op) NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_multiple_outputs_op)
{ {
test_provenance_tags<default_opset::TopK>("onnx/provenance_multiple_outputs_op.prototxt", const auto function = onnx_import::import_onnx_model(
"<ONNX TopK (TOPK -> values, indices)>"); file_util::path_join(SERIALIZED_ZOO, "onnx/provenance_multiple_outputs_op.prototxt"));
test_provenance_tags<default_opset::TopK>(function, "<ONNX TopK (TOPK -> values, indices)>");
} }
NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_tagging_constants) NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_tagging_constants)
{ {
test_provenance_tags<default_opset::Constant>("onnx/provenance_input_tags.prototxt", const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/provenance_input_tags.prototxt"));
test_provenance_tags<default_opset::Constant>(function,
"<ONNX Input (initializer_of_A) Shape:{1}>"); "<ONNX Input (initializer_of_A) Shape:{1}>");
} }
NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_tagging_parameters) NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_tagging_parameters)
{ {
test_provenance_tags<default_opset::Parameter>("onnx/provenance_input_tags.prototxt", const auto function = onnx_import::import_onnx_model(
"<ONNX Input (input_B) Shape:{}>"); file_util::path_join(SERIALIZED_ZOO, "onnx/provenance_input_tags.prototxt"));
test_provenance_tags<default_opset::Parameter>(function, "<ONNX Input (input_B) Shape:{}>");
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_tag_downgrade_pass)
{
set_provenance_enabled(true);
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/provenance_downgrade_topk.prototxt"));
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(function);
test_provenance_tags<op::v0::TopK>(function, "<ONNX TopK (TOPK -> values, indices)>");
test_provenance_tags<op::v0::TopK>(function, "<Opset0_Downgrade (v1 TopK)>");
} }
...@@ -27,6 +27,8 @@ ...@@ -27,6 +27,8 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/fused_op_decomposition.hpp" #include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "ngraph/provenance.hpp" #include "ngraph/provenance.hpp"
using namespace std; using namespace std;
...@@ -333,6 +335,56 @@ TEST(provenance, add_group_above) ...@@ -333,6 +335,56 @@ TEST(provenance, add_group_above)
EXPECT_EQ(m1->get_provenance_tags(), (ProvSet{"m1"})); EXPECT_EQ(m1->get_provenance_tags(), (ProvSet{"m1"}));
} }
TEST(provenance, add_tags_above)
{
auto x = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto y = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto a = make_shared<op::Add>(x, y);
auto b = make_shared<op::Multiply>(x, y);
auto c = make_shared<op::Subtract>(a, b);
auto d = make_shared<op::Abs>(c);
// Add tags to Subtract and all nodes until Parameters (all above c, until params x, y)
c->add_provenance_tags_above(OutputVector{x, y}, {"tag_above_c - until_params"});
// Add tags to Abs and Subtract (above d, until c inputs)
d->add_provenance_tags_above(c->input_values(), {"tag_above_d - until_c_inputs"});
// Add tags to Abs and all nodes above
d->add_provenance_tags_above(OutputVector{}, {"tag_all_above_d"});
auto x_tags = x->get_provenance_tags();
EXPECT_EQ(x_tags.size(), 1);
EXPECT_TRUE(x_tags.find("tag_all_above_d") != x_tags.end());
auto y_tags = y->get_provenance_tags();
EXPECT_EQ(y_tags.size(), 1);
EXPECT_TRUE(y_tags.find("tag_all_above_d") != y_tags.end());
auto a_tags = a->get_provenance_tags();
EXPECT_EQ(a_tags.size(), 2);
EXPECT_TRUE(a_tags.find("tag_above_c - until_params") != a_tags.end());
EXPECT_FALSE(a_tags.find("tag_above_d - until_c_inputs") != a_tags.end());
EXPECT_TRUE(a_tags.find("tag_all_above_d") != a_tags.end());
auto b_tags = b->get_provenance_tags();
EXPECT_EQ(b_tags.size(), 2);
EXPECT_TRUE(b_tags.find("tag_above_c - until_params") != b_tags.end());
EXPECT_FALSE(b_tags.find("tag_above_d - until_c_inputs") != b_tags.end());
EXPECT_TRUE(b_tags.find("tag_all_above_d") != b_tags.end());
auto c_tags = c->get_provenance_tags();
EXPECT_EQ(c_tags.size(), 3);
EXPECT_TRUE(c_tags.find("tag_above_c - until_params") != c_tags.end());
EXPECT_TRUE(c_tags.find("tag_above_d - until_c_inputs") != c_tags.end());
EXPECT_TRUE(c_tags.find("tag_all_above_d") != c_tags.end());
auto d_tags = d->get_provenance_tags();
EXPECT_EQ(d_tags.size(), 2);
EXPECT_FALSE(d_tags.find("tag_above_c - until_params") != d_tags.end());
EXPECT_TRUE(d_tags.find("tag_above_d - until_c_inputs") != d_tags.end());
EXPECT_TRUE(d_tags.find("tag_all_above_d") != d_tags.end());
}
TEST(provenance, builder) TEST(provenance, builder)
{ {
auto p1 = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4}); auto p1 = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
...@@ -501,3 +553,155 @@ TEST(provenance, scaled_quantize_concat_unsigned) ...@@ -501,3 +553,155 @@ TEST(provenance, scaled_quantize_concat_unsigned)
} }
} }
} }
TEST(provenance, opset1_upgrade_pass_topk)
{
set_provenance_enabled(true);
const size_t axis = 2;
const size_t k = 10;
const auto data = make_shared<op::Parameter>(element::i32, Shape{5, 10, 15});
const auto topk_v0 = make_shared<op::v0::TopK>(data, axis, element::i32, k);
const auto result = make_shared<op::Result>(topk_v0);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto topk_v1 = as_type_ptr<op::v1::TopK>(pass_replacement_node);
const std::string tag = "<Opset1_Upgrade (v0 TopK)>";
auto tag_check = [&tag](std::shared_ptr<ngraph::Node> node) {
auto tags = node->get_provenance_tags();
EXPECT_TRUE(tags.find(tag) != tags.end());
};
traverse_nodes(as_node_vector(topk_v1->outputs()),
tag_check,
false,
as_node_vector(topk_v0->input_values()));
}
TEST(provenance, opset0_downgrade_pass_topk)
{
set_provenance_enabled(true);
const auto data = make_shared<op::Parameter>(element::i32, Shape{5, 10, 15});
const int32_t k = 10;
const auto k_node = op::Constant::create(element::i64, Shape{}, {k});
const size_t axis = 2;
const auto mode = op::v1::TopK::Mode::MAX;
const auto sort = op::v1::TopK::SortType::SORT_INDICES;
const auto elem_type = element::i64;
const auto topk_v1 = make_shared<op::v1::TopK>(data, k_node, axis, mode, sort, elem_type);
const auto result = make_shared<op::Result>(topk_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto topk_v0 = as_type_ptr<op::v0::TopK>(pass_replacement_node);
const std::string tag = "<Opset0_Downgrade (v1 TopK)>";
auto tag_check = [&tag](std::shared_ptr<ngraph::Node> node) {
auto tags = node->get_provenance_tags();
EXPECT_TRUE(tags.find(tag) != tags.end());
};
traverse_nodes(as_node_vector(topk_v0->outputs()),
tag_check,
false,
as_node_vector(topk_v1->input_values()));
}
TEST(provenance, opset1_upgrade_pass_graph)
{
set_provenance_enabled(true);
auto x = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto y = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto a = make_shared<op::v0::Add>(x, y);
auto b = make_shared<op::v0::Subtract>(x, y);
auto c = make_shared<op::v0::Abs>(b);
auto d = make_shared<op::v0::Multiply>(a, b);
auto f = make_shared<Function>(d, ParameterVector{x, y});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
for (auto node : f->get_ordered_ops())
{
auto tags = node->get_provenance_tags();
if (as_type_ptr<op::v1::Add>(node))
{
EXPECT_EQ(tags.size(), 1);
EXPECT_TRUE(tags.find("<Opset1_Upgrade (v0 Add)>") != tags.end());
}
else if (as_type_ptr<op::v1::Multiply>(node))
{
EXPECT_EQ(tags.size(), 1);
EXPECT_TRUE(tags.find("<Opset1_Upgrade (v0 Multiply)>") != tags.end());
}
else if (as_type_ptr<op::v1::Subtract>(node))
{
EXPECT_EQ(tags.size(), 1);
EXPECT_TRUE(tags.find("<Opset1_Upgrade (v0 Subtract)>") != tags.end());
}
else if (as_type_ptr<op::v0::Abs>(node))
{
EXPECT_TRUE(tags.empty());
}
}
}
TEST(provenance, opset0_downgrade_pass_graph)
{
set_provenance_enabled(true);
auto x = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto y = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto a = make_shared<op::v1::Add>(x, y);
auto b = make_shared<op::v1::Subtract>(x, y);
auto c = make_shared<op::v0::Abs>(b);
auto d = make_shared<op::v1::Multiply>(a, b);
auto f = make_shared<Function>(d, ParameterVector{x, y});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
for (auto node : f->get_ordered_ops())
{
auto tags = node->get_provenance_tags();
if (as_type_ptr<op::v0::Add>(node))
{
EXPECT_EQ(tags.size(), 1);
EXPECT_TRUE(tags.find("<Opset0_Downgrade (v1 Add)>") != tags.end());
}
else if (as_type_ptr<op::v0::Multiply>(node))
{
EXPECT_EQ(tags.size(), 1);
EXPECT_TRUE(tags.find("<Opset0_Downgrade (v1 Multiply)>") != tags.end());
}
else if (as_type_ptr<op::v0::Subtract>(node))
{
EXPECT_EQ(tags.size(), 1);
EXPECT_TRUE(tags.find("<Opset0_Downgrade (v1 Subtract)>") != tags.end());
}
else if (as_type_ptr<op::v0::Abs>(node))
{
EXPECT_TRUE(tags.empty());
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment