Unverified Commit 9f1afc65 authored by Katarzyna Mitrus, committed by GitHub

Add provenance tags to opset transformation passes (#4198)

* Add provenance tag in opset0 downgrade pass

* Add provenance tag in opset1 upgrade pass

* Type name as string

* Change op_cast to return replacement_node instead of bool

* Add provenance tags to all nodes created during downgrade

* Add provenance tags to all nodes created during upgrade

* Comments

* Style apply

* Update ONNX import provenance test function

* Add const statement

* Add upgrade/downgrade provenance tag tests

* Update tests

* Style apply

* Provenance enabled check

* Removed redundant add_tag

* Test for add_provenance_tags above

* Add graph test for provenance tags in transformation pass

* Use EXPECT_TRUE and EXPECT_FALSE instead of EXPECT_EQ

* Return replacement node directly

* Style apply

* Test downgrade provenance tag with ONNX importer

* Update test/onnx/onnx_import_provenance.in.cpp
Co-Authored-By: Michał Karzyński <postrational@users.noreply.github.com>

* Test update

* Update provenance test to check node type occurrence

* Style apply
Co-authored-by: 's avatarMichał Karzyński <postrational@users.noreply.github.com>
Co-authored-by: 's avatarSang Ik Lee <sang.ik.lee@intel.com>
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent 3409bda8
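In short: every op_cast overload now returns its replacement node (or nullptr when no transformation applies) instead of a bool, so the shared op_cast_thunk can attach a provenance tag to everything the pass created. A minimal usage sketch based on the tests in this commit, assuming `function` is a previously imported or constructed shared_ptr<Function>:

// Enable provenance before running a pass; every node the pass creates is then
// tagged with the pass name and the original op type.
ngraph::set_provenance_enabled(true);
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::Opset1Upgrade>();
pass_manager.run_passes(function);
// A v1::Add produced from a v0::Add now carries the tag "<Opset1_Upgrade (v0 Add)>".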
@@ -27,6 +27,7 @@
#include "ngraph/ops.hpp"
#include "ngraph/pass/implicit_broadcast_elimination.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/provenance.hpp"
#include "ngraph/slice_plan.hpp"
#include "ngraph/type.hpp"
#include "ngraph/validation_util.hpp"
@@ -37,17 +38,18 @@ using namespace ngraph;
namespace
{
template <typename OpV0, typename OpV1>
void op_cast_binary_elementwise_node(const shared_ptr<OpV1>& node)
shared_ptr<Node> op_cast_binary_elementwise_node(const shared_ptr<OpV1>& node)
{
const auto input_arg0 = node->input_value(0);
const auto input_arg1 = node->input_value(1);
const auto autob = node->get_autob();
auto replacement_node = make_shared<OpV0>(input_arg0, input_arg1, autob);
replace_node(node, replacement_node);
return replacement_node;
}
template <typename OpV0, typename OpV1>
void op_cast_reduction_node(const shared_ptr<OpV1>& node)
shared_ptr<Node> op_cast_reduction_node(const shared_ptr<OpV1>& node)
{
auto replacement_node = make_shared<OpV0>(node->input_value(0), node->input_value(1));
if (node->get_keep_dims())
@@ -85,17 +87,17 @@ namespace
{
replace_node(node, replacement_node);
}
return replacement_node;
}
// Default is that we did nothing
bool op_cast(shared_ptr<Node> node) { return false; }
bool op_cast(shared_ptr<op::v1::Add> node)
shared_ptr<Node> op_cast(shared_ptr<Node> node) { return nullptr; }
shared_ptr<Node> op_cast(shared_ptr<op::v1::Add> node)
{
op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
}
bool op_cast(shared_ptr<op::v1::AvgPool> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::AvgPool> node)
{
auto const input_arg = node->input_value(0);
const auto ceil_mode = static_cast<bool>(node->get_rounding_type());
@@ -115,10 +117,10 @@ namespace
pad_type,
ceil_mode);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::AvgPoolBackprop> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::AvgPoolBackprop> node)
{
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant());
const auto forward_arg_shape =
@@ -140,10 +142,10 @@ namespace
padding_above,
include_padding_in_avg_computation);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::Broadcast> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::Broadcast> node)
{
auto arg = node->input_value(0);
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant());
@@ -155,10 +157,10 @@ namespace
make_shared<op::v0::Broadcast>(arg, target_shape, node->get_broadcast_axes().second);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::Convolution> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::Convolution> node)
{
const auto data_arg = node->input_value(0);
const auto filters_arg = node->input_value(1);
@@ -173,10 +175,10 @@ namespace
Strides(num_spatial_dims, 1),
node->get_auto_pad());
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::ConvolutionBackpropData> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::ConvolutionBackpropData> node)
{
auto output_shape_node =
as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());
@@ -223,10 +225,10 @@ namespace
node->get_pads_end(),
Strides(num_spatial_dims, 1));
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::ConvolutionBackpropFilters> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::ConvolutionBackpropFilters> node)
{
NGRAPH_CHECK(node->input_value(2).get_node_shared_ptr()->is_constant());
auto filters_shape =
@@ -246,10 +248,10 @@ namespace
node->get_pads_end(),
Strides(num_spatial_dims, 1));
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::Divide> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::Divide> node)
{
const auto input_arg0 = node->input_value(0);
const auto input_arg1 = node->input_value(1);
@@ -257,10 +259,10 @@ namespace
const bool pydiv = node->is_pythondiv();
auto replacement_node = make_shared<op::v0::Divide>(input_arg0, input_arg1, pydiv, autob);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::Reshape> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::Reshape> node)
{
shared_ptr<Node> replacement_node;
@@ -282,16 +284,15 @@ namespace
}
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::Equal> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::Equal> node)
{
op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
}
bool op_cast(shared_ptr<op::v1::Gather> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::Gather> node)
{
auto axis_node = as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());
@@ -309,10 +310,10 @@ namespace
auto replacement_node =
make_shared<op::v0::Gather>(node->input_value(0), node->input_value(1), axis);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::GenerateMask> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::GenerateMask> node)
{
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant());
auto mask_shape =
@@ -327,22 +328,20 @@ namespace
node->input_value(0), mask_shape, et, seed, probability, use_seed);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::Greater> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::Greater> node)
{
op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
}
bool op_cast(shared_ptr<op::v1::GreaterEqual> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::GreaterEqual> node)
{
op_cast_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEqual>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEqual>(node);
}
bool op_cast(shared_ptr<op::v1::GroupConvolution> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::GroupConvolution> node)
{
const auto data_arg = node->input_value(0);
const auto filters_arg = node->input_value(1);
@@ -357,10 +356,10 @@ namespace
Strides(num_spatial_dims, 1),
node->get_auto_pad());
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::GroupConvolutionBackpropData> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::GroupConvolutionBackpropData> node)
{
auto output_shape_input =
as_type_ptr<op::Constant>(node->input_value(2).get_node_shared_ptr());
@@ -434,52 +433,47 @@ namespace
pads_end,
groups);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::Less> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::Less> node)
{
op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
}
bool op_cast(shared_ptr<op::v1::LessEqual> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::LessEqual> node)
{
op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
}
bool op_cast(shared_ptr<op::v1::LogicalAnd> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalAnd> node)
{
op_cast_binary_elementwise_node<op::v0::And, op::v1::LogicalAnd>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::And, op::v1::LogicalAnd>(node);
}
bool op_cast(shared_ptr<op::v1::LogicalNot> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalNot> node)
{
replace_node(node, make_shared<op::v0::Not>(node->input_value(0)));
return true;
auto replacement_node = make_shared<op::v0::Not>(node->input_value(0));
replace_node(node, replacement_node);
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::LogicalOr> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalOr> node)
{
op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
}
bool op_cast(shared_ptr<op::v1::LogicalXor> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::LogicalXor> node)
{
op_cast_binary_elementwise_node<op::v0::Xor, op::v1::LogicalXor>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::Xor, op::v1::LogicalXor>(node);
}
bool op_cast(shared_ptr<op::v1::Maximum> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::Maximum> node)
{
op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
}
bool op_cast(shared_ptr<op::v1::MaxPool> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::MaxPool> node)
{
auto const input_arg = node->input_value(0);
auto ceil_mode = static_cast<bool>(node->get_rounding_type());
@@ -497,10 +491,10 @@ namespace
pad_type,
ceil_mode);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::MaxPoolBackprop> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::MaxPoolBackprop> node)
{
const auto padding_below = node->get_pads_begin();
const auto padding_above = node->get_pads_end();
@@ -532,28 +526,25 @@ namespace
padding_above);
}
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::Minimum> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::Minimum> node)
{
op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
}
bool op_cast(shared_ptr<op::v1::Multiply> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::Multiply> node)
{
op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
}
bool op_cast(shared_ptr<op::v1::NotEqual> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::NotEqual> node)
{
op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
}
bool op_cast(shared_ptr<op::v1::OneHot> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::OneHot> node)
{
const auto indices = node->input_value(0).get_node_shared_ptr();
const auto depth = node->input_value(1).get_node_shared_ptr();
@@ -577,10 +568,10 @@ namespace
auto replacement_node = one_hot * (on_value - off_value) + off_value;
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::Pad> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::Pad> node)
{
const auto pad_arg = node->input_value(0);
const auto pad_value = node->input_value(3);
@@ -588,40 +579,35 @@ namespace
pad_arg, pad_value, node->get_pads_begin(), node->get_pads_end(), node->get_pad_mode());
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::Power> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::Power> node)
{
op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
}
bool op_cast(shared_ptr<op::v1::ReduceMax> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMax> node)
{
op_cast_reduction_node<op::v0::Max, op::v1::ReduceMax>(node);
return true;
return op_cast_reduction_node<op::v0::Max, op::v1::ReduceMax>(node);
}
bool op_cast(shared_ptr<op::v1::ReduceMin> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceMin> node)
{
op_cast_reduction_node<op::v0::Min, op::v1::ReduceMin>(node);
return true;
return op_cast_reduction_node<op::v0::Min, op::v1::ReduceMin>(node);
}
bool op_cast(shared_ptr<op::v1::ReduceProd> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceProd> node)
{
op_cast_reduction_node<op::v0::Product, op::v1::ReduceProd>(node);
return true;
return op_cast_reduction_node<op::v0::Product, op::v1::ReduceProd>(node);
}
bool op_cast(shared_ptr<op::v1::ReduceSum> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::ReduceSum> node)
{
op_cast_reduction_node<op::v0::Sum, op::v1::ReduceSum>(node);
return true;
return op_cast_reduction_node<op::v0::Sum, op::v1::ReduceSum>(node);
}
bool op_cast(shared_ptr<op::v1::Reverse> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::Reverse> node)
{
auto axes_node = node->input_value(1).get_node_shared_ptr();
NGRAPH_CHECK(axes_node->is_constant(),
@@ -648,19 +634,19 @@ namespace
auto replacement_node = make_shared<op::v0::Reverse>(node->input_value(0), axes);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::Select> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::Select> node)
{
ngraph::pass::ImplicitBroadcastElimination().run_on_node(node);
auto replacement_node = make_shared<op::v0::Select>(
node->input_value(0), node->input_value(1), node->input_value(2));
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::StridedSlice> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::StridedSlice> node)
{
auto convert_mask_to_axes = [](const std::vector<int64_t>& mask) {
AxisSet axes{};
@@ -723,10 +709,10 @@ namespace
}
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::Softmax> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::Softmax> node)
{
auto axis = node->get_axis();
auto data = node->input(0);
@@ -735,10 +721,10 @@ namespace
std::iota(std::begin(axes), std::end(axes), axis);
auto replacement_node = make_shared<op::v0::Softmax>(node->input_value(0), axes);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::Split> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::Split> node)
{
const auto num_splits = node->get_num_splits();
@@ -746,16 +732,15 @@ namespace
make_shared<op::v0::Split>(node->input_value(0), node->input_value(1), num_splits);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::Subtract> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::Subtract> node)
{
op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
}
bool op_cast(shared_ptr<op::v1::TopK> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::TopK> node)
{
const auto axis = node->get_axis();
const auto sort_type = node->get_sort_type();
@@ -778,10 +763,10 @@ namespace
// values output will be 0, indices 1
vector<int64_t> output_order{1, 0};
replace_node(node, replacement_node, output_order);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::Transpose> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::Transpose> node)
{
const auto data = node->input_value(0);
@@ -816,10 +801,10 @@ namespace
auto replacement_node = make_shared<op::v0::Reshape>(data, order, out_shape);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v1::VariadicSplit> node)
shared_ptr<Node> op_cast(shared_ptr<op::v1::VariadicSplit> node)
{
const auto split_lengths = node->input_value(2).get_node_shared_ptr();
@@ -835,7 +820,7 @@ namespace
make_shared<op::v0::Split>(node->input_value(0), node->input_value(1), splits_unsigned);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>;
@@ -843,7 +828,18 @@ namespace
template <typename T>
bool op_cast_thunk(shared_ptr<Node> node)
{
return op_cast(as_type_ptr<T>(node));
auto downgraded_node = op_cast(as_type_ptr<T>(node));
if (downgraded_node)
{
if (ngraph::get_provenance_enabled())
{
const std::string provenance_tag =
"<Opset0_Downgrade (v1 " + std::string(node->get_type_name()) + ")>";
downgraded_node->add_provenance_tags_above(node->input_values(), {provenance_tag});
}
return true;
}
return false;
}
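// Note: add_provenance_tags_above tags every node between the replacement and the
// original node's input values (see the add_tags_above test below), so multi-node
// replacements such as the OneHot downgrade, which expands into Multiply, Subtract
// and Add, are tagged in full rather than only at their root.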
DispatchMap& get_dispatch_map()
@@ -22,6 +22,7 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "ngraph/provenance.hpp"
using namespace std;
using namespace ngraph;
@@ -29,29 +30,28 @@ using namespace ngraph;
namespace
{
template <typename OpV0, typename OpV1>
void op_cast_binary_elementwise_node(const shared_ptr<OpV0>& node)
shared_ptr<Node> op_cast_binary_elementwise_node(const shared_ptr<OpV0>& node)
{
const auto autob = node->get_autob();
auto replacement_node =
make_shared<OpV1>(node->input_value(0), node->input_value(1), autob);
replace_node(node, replacement_node);
return replacement_node;
}
// Default is that we did nothing
bool op_cast(shared_ptr<Node> node) { return false; }
bool op_cast(shared_ptr<op::Add> node)
shared_ptr<Node> op_cast(shared_ptr<Node> node) { return nullptr; }
shared_ptr<Node> op_cast(shared_ptr<op::Add> node)
{
op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
}
bool op_cast(shared_ptr<op::And> node)
shared_ptr<Node> op_cast(shared_ptr<op::And> node)
{
op_cast_binary_elementwise_node<op::v0::And, op::v1::LogicalAnd>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::And, op::v1::LogicalAnd>(node);
}
bool op_cast(shared_ptr<op::AvgPool> node)
shared_ptr<Node> op_cast(shared_ptr<op::AvgPool> node)
{
auto rounding_mode =
node->get_ceil_mode() ? op::RoundingType::CEIL : op::RoundingType::FLOOR;
@@ -82,10 +82,10 @@ namespace
}
#endif
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::AvgPoolBackprop> node)
shared_ptr<Node> op_cast(shared_ptr<op::AvgPoolBackprop> node)
{
auto exclude_pad = !node->get_include_padding_in_avg_computation();
auto pads_begin = node->get_padding_below();
@@ -101,10 +101,10 @@ namespace
kernel,
exclude_pad);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::Broadcast> node)
shared_ptr<Node> op_cast(shared_ptr<op::Broadcast> node)
{
auto result_shape = node->get_broadcast_shape();
auto result_shape_node =
@@ -124,11 +124,11 @@ namespace
auto replacement_node = make_shared<op::v1::Broadcast>(
node->input_value(0), result_shape_node->output(0), axes_mapping_node->output(0));
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::BroadcastLike> node) { return false; }
bool op_cast(shared_ptr<op::Convolution> node)
shared_ptr<Node> op_cast(shared_ptr<op::BroadcastLike> node) { return nullptr; }
shared_ptr<Node> op_cast(shared_ptr<op::Convolution> node)
{
auto strides = node->get_window_movement_strides();
auto dilations = node->get_window_dilation_strides();
@@ -154,10 +154,10 @@ namespace
dilations,
auto_pad);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::ConvolutionBackpropData> node)
shared_ptr<Node> op_cast(shared_ptr<op::ConvolutionBackpropData> node)
{
auto data_batch_shape = node->get_data_batch_shape();
auto strides = node->get_window_movement_strides_forward();
@@ -188,10 +188,10 @@ namespace
pads_end,
dilations);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::ConvolutionBackpropFilters> node)
shared_ptr<Node> op_cast(shared_ptr<op::ConvolutionBackpropFilters> node)
{
auto filters_shape = node->get_filters_shape();
auto strides = node->get_window_movement_strides_forward();
@@ -220,35 +220,34 @@ namespace
pads_begin,
pads_end);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::Divide> node)
shared_ptr<Node> op_cast(shared_ptr<op::Divide> node)
{
const auto autob = node->get_autob();
const bool pydiv = node->is_pythondiv();
auto replacement_node =
make_shared<op::v1::Divide>(node->input_value(0), node->input_value(1), pydiv, autob);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::DynReshape> node)
shared_ptr<Node> op_cast(shared_ptr<op::DynReshape> node)
{
auto zero_flag = false;
auto replacement_node =
make_shared<op::v1::Reshape>(node->input_value(0), node->input_value(1), zero_flag);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::Equal> node)
shared_ptr<Node> op_cast(shared_ptr<op::Equal> node)
{
op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
}
bool op_cast(shared_ptr<op::Gather> node)
shared_ptr<Node> op_cast(shared_ptr<op::Gather> node)
{
int64_t axis = node->get_axis();
@@ -256,22 +255,20 @@ namespace
auto replacement_node =
make_shared<op::v1::Gather>(node->input_value(0), node->input_value(1), axis_node);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::Greater> node)
shared_ptr<Node> op_cast(shared_ptr<op::Greater> node)
{
op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
}
bool op_cast(shared_ptr<op::GreaterEq> node)
shared_ptr<Node> op_cast(shared_ptr<op::GreaterEq> node)
{
op_cast_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEqual>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEqual>(node);
}
bool op_cast(shared_ptr<op::v0::GroupConvolution> node)
shared_ptr<Node> op_cast(shared_ptr<op::v0::GroupConvolution> node)
{
auto strides = node->get_window_movement_strides();
auto dilations = node->get_window_dilation_strides();
@@ -324,10 +321,10 @@ namespace
auto_pad);
}
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::v0::GroupConvolutionBackpropData> node)
shared_ptr<Node> op_cast(shared_ptr<op::v0::GroupConvolutionBackpropData> node)
{
auto strides = node->get_window_movement_strides();
auto dilations = node->get_window_dilation_strides();
@@ -364,37 +361,34 @@ namespace
pads_end,
dilations);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::Less> node)
shared_ptr<Node> op_cast(shared_ptr<op::Less> node)
{
op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
}
bool op_cast(shared_ptr<op::LessEq> node)
shared_ptr<Node> op_cast(shared_ptr<op::LessEq> node)
{
op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
}
bool op_cast(shared_ptr<op::Max> node)
shared_ptr<Node> op_cast(shared_ptr<op::Max> node)
{
bool keep_dims = false;
auto replacement_node =
make_shared<op::v1::ReduceMax>(node->input_value(0), node->input_value(1), keep_dims);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::Maximum> node)
shared_ptr<Node> op_cast(shared_ptr<op::Maximum> node)
{
op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
}
bool op_cast(shared_ptr<op::MaxPool> node)
shared_ptr<Node> op_cast(shared_ptr<op::MaxPool> node)
{
auto rounding_type =
node->get_ceil_mode() ? op::RoundingType::CEIL : op::RoundingType::FLOOR;
@@ -418,10 +412,10 @@ namespace
}
#endif
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::MaxPoolBackprop> node)
shared_ptr<Node> op_cast(shared_ptr<op::MaxPoolBackprop> node)
{
auto pads_begin = node->get_padding_below();
auto pads_end = node->get_padding_above();
@@ -445,43 +439,41 @@ namespace
node->input_value(0), node->input_value(1), strides, pads_begin, pads_end, kernel);
}
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::Min> node)
shared_ptr<Node> op_cast(shared_ptr<op::Min> node)
{
bool keep_dims = false;
auto replacement_node =
make_shared<op::v1::ReduceMin>(node->input_value(0), node->input_value(1), keep_dims);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::Minimum> node)
shared_ptr<Node> op_cast(shared_ptr<op::Minimum> node)
{
op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
}
bool op_cast(shared_ptr<op::Multiply> node)
shared_ptr<Node> op_cast(shared_ptr<op::Multiply> node)
{
op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
}
bool op_cast(shared_ptr<op::Not> node)
shared_ptr<Node> op_cast(shared_ptr<op::Not> node)
{
replace_node(node, make_shared<op::v1::LogicalNot>(node->input_value(0)));
return true;
auto replacement_node = make_shared<op::v1::LogicalNot>(node->input_value(0));
replace_node(node, replacement_node);
return replacement_node;
}
bool op_cast(shared_ptr<op::NotEqual> node)
shared_ptr<Node> op_cast(shared_ptr<op::NotEqual> node)
{
op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
}
bool op_cast(shared_ptr<op::OneHot> node)
shared_ptr<Node> op_cast(shared_ptr<op::OneHot> node)
{
const auto indices = node->input_value(0).get_node_shared_ptr();
const auto one_hot_axis = node->get_one_hot_axis();
@@ -499,16 +491,15 @@ namespace
auto replacement_node =
make_shared<op::v1::OneHot>(indices, depth_node, on_value, off_value, one_hot_axis);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::Or> node)
shared_ptr<Node> op_cast(shared_ptr<op::Or> node)
{
op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
}
bool op_cast(shared_ptr<op::Pad> node)
shared_ptr<Node> op_cast(shared_ptr<op::Pad> node)
{
auto padding_below = node->get_padding_below();
auto pads_begin_node =
@@ -524,25 +515,24 @@ namespace
node->get_pad_mode());
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::Power> node)
shared_ptr<Node> op_cast(shared_ptr<op::Power> node)
{
op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
}
bool op_cast(shared_ptr<op::Product> node)
shared_ptr<Node> op_cast(shared_ptr<op::Product> node)
{
bool keep_dims = false;
auto replacement_node =
make_shared<op::v1::ReduceProd>(node->input_value(0), node->input_value(1), keep_dims);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::Reverse> node)
shared_ptr<Node> op_cast(shared_ptr<op::Reverse> node)
{
// creates a Constant node from the v0::Reverse reversed_axes attribute
// and uses it as the second input of v1::Reverse
@@ -551,24 +541,24 @@ namespace
const auto reversed_axes_constant = op::Constant::create(
element::i64, Shape{reversed_axes.size()}, reversed_axes.to_vector());
const auto reverse_v1 = make_shared<op::v1::Reverse>(
const auto replacement_node = make_shared<op::v1::Reverse>(
node->input_value(0), reversed_axes_constant, op::v1::Reverse::Mode::INDEX);
replace_node(node, reverse_v1);
return true;
replace_node(node, replacement_node);
return replacement_node;
}
bool op_cast(shared_ptr<op::Select> node)
shared_ptr<Node> op_cast(shared_ptr<op::Select> node)
{
auto replacement_node = make_shared<op::v1::Select>(node->input_value(0),
node->input_value(1),
node->input_value(2),
op::AutoBroadcastSpec());
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::Softmax> node)
shared_ptr<Node> op_cast(shared_ptr<op::Softmax> node)
{
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant(),
"axes parameter is expected to be a static constant");
@@ -583,10 +573,10 @@ namespace
auto replacement_node =
make_shared<op::v1::Softmax>(node->input_value(0), axes.to_vector()[0]);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::Slice> node)
shared_ptr<Node> op_cast(shared_ptr<op::Slice> node)
{
const auto data = node->input_value(0);
const auto begin = op::Constant::create(
@@ -605,10 +595,10 @@ namespace
vector<int64_t>(input_size, 0));
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::Split> node)
shared_ptr<Node> op_cast(shared_ptr<op::Split> node)
{
const auto& splits_vec = node->get_splits();
const auto first_elem = splits_vec.front();
@@ -634,25 +624,24 @@ namespace
}
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::Subtract> node)
shared_ptr<Node> op_cast(shared_ptr<op::Subtract> node)
{
op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
return true;
return op_cast_binary_elementwise_node<op::v0::Subtract, op::v1::Subtract>(node);
}
bool op_cast(shared_ptr<op::Sum> node)
shared_ptr<Node> op_cast(shared_ptr<op::Sum> node)
{
bool keep_dims = false;
auto replacement_node =
make_shared<op::v1::ReduceSum>(node->input_value(0), node->input_value(1), keep_dims);
replace_node(node, replacement_node);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::TopK> node)
shared_ptr<Node> op_cast(shared_ptr<op::TopK> node)
{
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant(),
"parameter k is expected to be a static constant");
@@ -687,15 +676,15 @@ namespace
// indices output will be 0, values 1
vector<int64_t> output_order{1, 0};
replace_node(node, replacement_node, output_order);
return true;
return replacement_node;
}
bool op_cast(shared_ptr<op::Xor> node)
shared_ptr<Node> op_cast(shared_ptr<op::Xor> node)
{
auto replacement_node = make_shared<op::v1::LogicalXor>(
node->input_value(0), node->input_value(1), node->get_autob());
replace_node(node, replacement_node);
return true;
return replacement_node;
}
using DispatchMap = map<NodeTypeInfo, std::function<bool(shared_ptr<Node> node)>>;
@@ -703,7 +692,18 @@ namespace
template <typename T>
bool op_cast_thunk(shared_ptr<Node> node)
{
return op_cast(as_type_ptr<T>(node));
auto upgraded_node = op_cast(as_type_ptr<T>(node));
if (upgraded_node)
{
if (ngraph::get_provenance_enabled())
{
const std::string provenance_tag =
"<Opset1_Upgrade (v0 " + std::string(node->get_type_name()) + ")>";
upgraded_node->add_provenance_tags_above(node->input_values(), {provenance_tag});
}
return true;
}
return false;
}
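// The get_provenance_enabled() guard ensures the tag string is only built, and the
// subgraph above the replacement only walked, when provenance tracking is on, so
// both passes stay effectively free when the feature is disabled.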
DispatchMap& get_dispatch_map()
ir_version: 4
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
input: "k"
output: "values"
output: "indices"
op_type: "TopK"
name: "TOPK"
}
name: "test_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "k"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "values"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "indices"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 10
}
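The TopK model above is presumably the new onnx/provenance_downgrade_topk.prototxt fixture referenced by the downgrade test below.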
@@ -18,6 +18,10 @@
#include "ngraph/file_util.hpp"
#include "ngraph/frontend/onnx_import/default_opset.hpp"
#include "ngraph/frontend/onnx_import/onnx.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/provenance.hpp"
#include "util/test_control.hpp"
#include "util/type_prop.hpp"
@@ -44,52 +48,73 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_tag_text)
// the NodeToCheck parameter of this template is used to find a node in the whole subgraph
// that a particular unit test is supposed to check against the expected provenance tag
template <typename NodeToCheck>
void test_provenance_tags(const std::string& model_path, const std::string& expected_provenance_tag)
void test_provenance_tags(const std::shared_ptr<Function> function,
const std::string& expected_provenance_tag)
{
const auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, model_path));
int node_count = 0;
for (const auto ng_node : function->get_ordered_ops())
{
if (as_type_ptr<NodeToCheck>(ng_node))
{
++node_count;
const auto tags = ng_node->get_provenance_tags();
ASSERT_EQ(tags.size(), 1) << "There should be exactly one provenance tag set for "
<< ng_node;
EXPECT_EQ(*(tags.cbegin()), expected_provenance_tag);
ASSERT_TRUE(tags.size() > 0) << "Node " << ng_node->get_friendly_name()
<< " should have at least one provenance tag.";
EXPECT_TRUE(tags.find(expected_provenance_tag) != tags.end());
}
}
EXPECT_TRUE(node_count > 0) << "Expected node type does not occur in the graph.";
}
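// The helper now asserts "at least one tag" rather than exactly one: a node touched
// by both the ONNX importer and a transformation pass legitimately carries a tag
// from each source (see provenance_tag_downgrade_pass below).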
NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_only_output)
{
// the Add node in the model does not have a name,
// only its output name should be found in the provenance tags
test_provenance_tags<default_opset::Add>("onnx/provenance_only_outputs.prototxt",
"<ONNX Add (-> output_of_add)>");
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/provenance_only_outputs.prototxt"));
test_provenance_tags<default_opset::Add>(function, "<ONNX Add (-> output_of_add)>");
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_node_name_and_outputs)
{
test_provenance_tags<default_opset::Add>("onnx/provenance_node_name_and_outputs.prototxt",
"<ONNX Add (Add_node -> output_of_add)>");
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/provenance_node_name_and_outputs.prototxt"));
test_provenance_tags<default_opset::Add>(function, "<ONNX Add (Add_node -> output_of_add)>");
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_multiple_outputs_op)
{
test_provenance_tags<default_opset::TopK>("onnx/provenance_multiple_outputs_op.prototxt",
"<ONNX TopK (TOPK -> values, indices)>");
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/provenance_multiple_outputs_op.prototxt"));
test_provenance_tags<default_opset::TopK>(function, "<ONNX TopK (TOPK -> values, indices)>");
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_tagging_constants)
{
test_provenance_tags<default_opset::Constant>("onnx/provenance_input_tags.prototxt",
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/provenance_input_tags.prototxt"));
test_provenance_tags<default_opset::Constant>(function,
"<ONNX Input (initializer_of_A) Shape:{1}>");
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_tagging_parameters)
{
test_provenance_tags<default_opset::Parameter>("onnx/provenance_input_tags.prototxt",
"<ONNX Input (input_B) Shape:{}>");
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/provenance_input_tags.prototxt"));
test_provenance_tags<default_opset::Parameter>(function, "<ONNX Input (input_B) Shape:{}>");
}
NGRAPH_TEST(onnx_${BACKEND_NAME}, provenance_tag_downgrade_pass)
{
set_provenance_enabled(true);
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/provenance_downgrade_topk.prototxt"));
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(function);
test_provenance_tags<op::v0::TopK>(function, "<ONNX TopK (TOPK -> values, indices)>");
test_provenance_tags<op::v0::TopK>(function, "<Opset0_Downgrade (v1 TopK)>");
}
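// Both checks pass on the same v0::TopK node: the downgrade pass appends its tag to
// the node's tag set instead of replacing the importer's "<ONNX ...>" tag.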
@@ -27,6 +27,8 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "ngraph/provenance.hpp"
using namespace std;
@@ -333,6 +335,56 @@ TEST(provenance, add_group_above)
EXPECT_EQ(m1->get_provenance_tags(), (ProvSet{"m1"}));
}
TEST(provenance, add_tags_above)
{
auto x = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto y = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto a = make_shared<op::Add>(x, y);
auto b = make_shared<op::Multiply>(x, y);
auto c = make_shared<op::Subtract>(a, b);
auto d = make_shared<op::Abs>(c);
// Add tags to Subtract and all nodes until Parameters (all above c, until params x, y)
c->add_provenance_tags_above(OutputVector{x, y}, {"tag_above_c - until_params"});
// Add tags to Abs and Subtract (above d, until c inputs)
d->add_provenance_tags_above(c->input_values(), {"tag_above_d - until_c_inputs"});
// Add tags to Abs and all nodes above
d->add_provenance_tags_above(OutputVector{}, {"tag_all_above_d"});
auto x_tags = x->get_provenance_tags();
EXPECT_EQ(x_tags.size(), 1);
EXPECT_TRUE(x_tags.find("tag_all_above_d") != x_tags.end());
auto y_tags = y->get_provenance_tags();
EXPECT_EQ(y_tags.size(), 1);
EXPECT_TRUE(y_tags.find("tag_all_above_d") != y_tags.end());
auto a_tags = a->get_provenance_tags();
EXPECT_EQ(a_tags.size(), 2);
EXPECT_TRUE(a_tags.find("tag_above_c - until_params") != a_tags.end());
EXPECT_FALSE(a_tags.find("tag_above_d - until_c_inputs") != a_tags.end());
EXPECT_TRUE(a_tags.find("tag_all_above_d") != a_tags.end());
auto b_tags = b->get_provenance_tags();
EXPECT_EQ(b_tags.size(), 2);
EXPECT_TRUE(b_tags.find("tag_above_c - until_params") != b_tags.end());
EXPECT_FALSE(b_tags.find("tag_above_d - until_c_inputs") != b_tags.end());
EXPECT_TRUE(b_tags.find("tag_all_above_d") != b_tags.end());
auto c_tags = c->get_provenance_tags();
EXPECT_EQ(c_tags.size(), 3);
EXPECT_TRUE(c_tags.find("tag_above_c - until_params") != c_tags.end());
EXPECT_TRUE(c_tags.find("tag_above_d - until_c_inputs") != c_tags.end());
EXPECT_TRUE(c_tags.find("tag_all_above_d") != c_tags.end());
auto d_tags = d->get_provenance_tags();
EXPECT_EQ(d_tags.size(), 2);
EXPECT_FALSE(d_tags.find("tag_above_c - until_params") != d_tags.end());
EXPECT_TRUE(d_tags.find("tag_above_d - until_c_inputs") != d_tags.end());
EXPECT_TRUE(d_tags.find("tag_all_above_d") != d_tags.end());
}
TEST(provenance, builder)
{
auto p1 = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
@@ -501,3 +553,155 @@ TEST(provenance, scaled_quantize_concat_unsigned)
}
}
}
TEST(provenance, opset1_upgrade_pass_topk)
{
set_provenance_enabled(true);
const size_t axis = 2;
const size_t k = 10;
const auto data = make_shared<op::Parameter>(element::i32, Shape{5, 10, 15});
const auto topk_v0 = make_shared<op::v0::TopK>(data, axis, element::i32, k);
const auto result = make_shared<op::Result>(topk_v0);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto topk_v1 = as_type_ptr<op::v1::TopK>(pass_replacement_node);
const std::string tag = "<Opset1_Upgrade (v0 TopK)>";
auto tag_check = [&tag](std::shared_ptr<ngraph::Node> node) {
auto tags = node->get_provenance_tags();
EXPECT_TRUE(tags.find(tag) != tags.end());
};
traverse_nodes(as_node_vector(topk_v1->outputs()),
tag_check,
false,
as_node_vector(topk_v0->input_values()));
}
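// traverse_nodes walks the subgraph between topk_v1's outputs and the original
// topk_v0 inputs, so tag_check runs on every node the upgrade pass created.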
TEST(provenance, opset0_downgrade_pass_topk)
{
set_provenance_enabled(true);
const auto data = make_shared<op::Parameter>(element::i32, Shape{5, 10, 15});
const int32_t k = 10;
const auto k_node = op::Constant::create(element::i64, Shape{}, {k});
const size_t axis = 2;
const auto mode = op::v1::TopK::Mode::MAX;
const auto sort = op::v1::TopK::SortType::SORT_INDICES;
const auto elem_type = element::i64;
const auto topk_v1 = make_shared<op::v1::TopK>(data, k_node, axis, mode, sort, elem_type);
const auto result = make_shared<op::Result>(topk_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto topk_v0 = as_type_ptr<op::v0::TopK>(pass_replacement_node);
const std::string tag = "<Opset0_Downgrade (v1 TopK)>";
auto tag_check = [&tag](std::shared_ptr<ngraph::Node> node) {
auto tags = node->get_provenance_tags();
EXPECT_TRUE(tags.find(tag) != tags.end());
};
traverse_nodes(as_node_vector(topk_v0->outputs()),
tag_check,
false,
as_node_vector(topk_v1->input_values()));
}
TEST(provenance, opset1_upgrade_pass_graph)
{
set_provenance_enabled(true);
auto x = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto y = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto a = make_shared<op::v0::Add>(x, y);
auto b = make_shared<op::v0::Subtract>(x, y);
auto c = make_shared<op::v0::Abs>(b);
auto d = make_shared<op::v0::Multiply>(a, b);
auto f = make_shared<Function>(d, ParameterVector{x, y});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
for (auto node : f->get_ordered_ops())
{
auto tags = node->get_provenance_tags();
if (as_type_ptr<op::v1::Add>(node))
{
EXPECT_EQ(tags.size(), 1);
EXPECT_TRUE(tags.find("<Opset1_Upgrade (v0 Add)>") != tags.end());
}
else if (as_type_ptr<op::v1::Multiply>(node))
{
EXPECT_EQ(tags.size(), 1);
EXPECT_TRUE(tags.find("<Opset1_Upgrade (v0 Multiply)>") != tags.end());
}
else if (as_type_ptr<op::v1::Subtract>(node))
{
EXPECT_EQ(tags.size(), 1);
EXPECT_TRUE(tags.find("<Opset1_Upgrade (v0 Subtract)>") != tags.end());
}
else if (as_type_ptr<op::v0::Abs>(node))
{
EXPECT_TRUE(tags.empty());
}
}
}
TEST(provenance, opset0_downgrade_pass_graph)
{
set_provenance_enabled(true);
auto x = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto y = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
auto a = make_shared<op::v1::Add>(x, y);
auto b = make_shared<op::v1::Subtract>(x, y);
auto c = make_shared<op::v0::Abs>(b);
auto d = make_shared<op::v1::Multiply>(a, b);
auto f = make_shared<Function>(d, ParameterVector{x, y});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
for (auto node : f->get_ordered_ops())
{
auto tags = node->get_provenance_tags();
if (as_type_ptr<op::v0::Add>(node))
{
EXPECT_EQ(tags.size(), 1);
EXPECT_TRUE(tags.find("<Opset0_Downgrade (v1 Add)>") != tags.end());
}
else if (as_type_ptr<op::v0::Multiply>(node))
{
EXPECT_EQ(tags.size(), 1);
EXPECT_TRUE(tags.find("<Opset0_Downgrade (v1 Multiply)>") != tags.end());
}
else if (as_type_ptr<op::v0::Subtract>(node))
{
EXPECT_EQ(tags.size(), 1);
EXPECT_TRUE(tags.find("<Opset0_Downgrade (v1 Subtract)>") != tags.end());
}
else if (as_type_ptr<op::v0::Abs>(node))
{
EXPECT_TRUE(tags.empty());
}
}
}
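// In both graph tests Abs is untouched by the pass and keeps an empty tag set:
// provenance tags mark only the nodes a transformation actually created.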