Unverified commit 8ed67b44, authored by Scott Cyphers and committed by GitHub

Master edition of Provenance for builders (#3647)

* Provenance for builders (#3644)

* Provenance for builders

* Tests

* Tests

* Update src/ngraph/node.hpp
Co-Authored-By: Sayantan Sarkar <sayantan.sarkar@intel.com>

* style

Finish migration

* Migration to master

* Handle op constructors that construct ops

* Update provenance for topk

* Add missing input (#3691)

* Add missing input

* If builder does nothing, don't dive into the inputs

* Reverse test

* simplify, add test

* Add test

* add test
parent af2b137b
......@@ -172,7 +172,7 @@ namespace ngraph
return_value, final_shape, broadcast_axes);
}
return return_value.get_node_shared_ptr();
return return_value.get_node_shared_ptr()->add_provenance_group_members_above({value});
}
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>
......
......@@ -51,7 +51,8 @@ namespace ngraph
auto zero = make_constant(quant_type, shape, 0);
auto scale = quantization_utils::get_scale(min, max, quant_type);
return make_shared<op::Dequantize>(input, scale, zero, real_type, axes);
return make_shared<op::Dequantize>(input, scale, zero, real_type, axes)
->add_provenance_group_members_above({input, min, max});
}
}
}
......@@ -103,7 +103,7 @@ namespace ngraph
val = std::make_shared<ngraph::op::Broadcast>(val, shape, axes);
}
return val;
return val->add_provenance_group_members_above({});
}
}
}
......@@ -80,7 +80,9 @@ NodeVector builder::MatmulFactory::make_matmul_op()
// Multiply two tensors where both of them has rank lower equal 2.
if (left_rank <= 2 && right_rank <= 2)
{
return {make_dot(left, right).get_node_shared_ptr()};
return {make_dot(left, right)
.get_node_shared_ptr()
->add_provenance_group_members_above(m_inputs)};
}
// Second case:
......@@ -135,7 +137,7 @@ NodeVector builder::MatmulFactory::make_matmul_op()
if (left_shape.size() <= 3 && right_shape.size() <= 3)
{
return {result};
return {result->add_provenance_group_members_above(m_inputs)};
}
// Expand result _stack of matrices_ axes to get expected result shape.
else
......@@ -144,7 +146,8 @@ NodeVector builder::MatmulFactory::make_matmul_op()
Shape result_shape(next(begin(shape)), end(shape));
result_shape.insert(
begin(result_shape), begin(left_shape), next(begin(left_shape), left_shape.size() - 2));
return {make_shared<op::Reshape>(result, get_default_order(shape.size()), result_shape)};
return {make_shared<op::Reshape>(result, get_default_order(shape.size()), result_shape)
->add_provenance_group_members_above(m_inputs)};
}
}
......
......@@ -65,7 +65,8 @@ namespace ngraph
values->get_shape(),
vector<float>(shape_size(values->get_shape()), 1.f / p_norm));
return {make_shared<op::Power>(values, inv_p_node)};
return {make_shared<op::Power>(values, inv_p_node)
->add_provenance_group_members_above({value})};
}
}
......@@ -81,7 +82,8 @@ namespace ngraph
shared_ptr<Node> non_zero_values = make_shared<op::Convert>(
make_shared<op::NotEqual>(value, zero_node), value.get_element_type());
return make_shared<op::Sum>(non_zero_values, reduction_axes);
return make_shared<op::Sum>(non_zero_values, reduction_axes)
->add_provenance_group_members_above({value});
}
shared_ptr<Node>
......@@ -95,7 +97,7 @@ namespace ngraph
values->get_shape(),
vector<float>(shape_size(values->get_shape()), bias))};
return values + bias_node;
return (values + bias_node)->add_provenance_group_members_above({value});
}
shared_ptr<Node> l2_norm(const Output<Node>& value,
......@@ -109,16 +111,18 @@ namespace ngraph
op::Constant::create(values->get_element_type(),
values->get_shape(),
vector<float>(shape_size(values->get_shape()), bias))};
shared_ptr<Node> result;
switch (bias_mode)
{
case BiasMode::MAX:
{
return {make_shared<op::Sqrt>(make_shared<op::Maximum>(values, bias_node))};
result = make_shared<op::Sqrt>(make_shared<op::Maximum>(values, bias_node));
break;
}
case BiasMode::ADD:
default: { return {make_shared<op::Sqrt>(values + bias_node)};
}
default: result = make_shared<op::Sqrt>(values + bias_node);
}
return result->add_provenance_group_members_above({value});
}
shared_ptr<Node> lp_norm(const Output<Node>& value,
......
......@@ -74,7 +74,8 @@ namespace ngraph
out_shape.push_back(in_shape[order[i]]);
// do the reshaping with the order
return std::make_shared<ngraph::op::Reshape>(value, order, out_shape);
return std::make_shared<ngraph::op::Reshape>(value, order, out_shape)
->add_provenance_group_members_above({value});
}
} // namespace builder
......
......@@ -62,7 +62,6 @@ namespace ngraph
mybias = make_shared<op::Quantize>(
bias, bias_scale, zero, element::i32, quantization_axes, round_mode);
}
return make_shared<op::QuantizedConvolutionBias>(input,
filter,
mybias,
......@@ -72,7 +71,9 @@ namespace ngraph
padding_above,
data_dilation_strides,
requantization_scale,
false);
false)
->add_provenance_group_members_above(
{input, filter, bias, input_scale, filter_scale, output_scale});
}
}
}
......
......@@ -26,7 +26,8 @@ namespace ngraph
{
auto abs_a = std::make_shared<op::Abs>(a);
auto abs_b = std::make_shared<op::Abs>(b);
return std::make_shared<op::Maximum>(abs_a, abs_b);
return std::make_shared<op::Maximum>(abs_a, abs_b)
->add_provenance_group_members_above({a, b});
}
std::shared_ptr<Node> get_scale(const Output<Node>& input_min_range,
......@@ -72,7 +73,8 @@ namespace ngraph
auto max_abs_range = max_abs(min_range, max_range);
auto target_range = make_constant(type, shape, range);
return max_abs_range / target_range;
return (max_abs_range / target_range)
->add_provenance_group_members_above({input_min_range, input_max_range});
}
std::shared_ptr<Node> get_bias_scale(Output<Node> min_input,
......
......@@ -52,7 +52,8 @@ namespace ngraph
auto zero = make_constant(quant_type, shape, 0);
auto scale = quantization_utils::get_scale(min, max, quant_type, true);
return make_shared<op::Quantize>(input, scale, zero, quant_type, axes, round_mode);
return make_shared<op::Quantize>(input, scale, zero, quant_type, axes, round_mode)
->add_provenance_group_members_above({input, min, max});
}
}
}
......@@ -58,8 +58,17 @@ namespace ngraph
AxisSet{},
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
}
return make_shared<op::Concat>(rescaled_args, concatenation_axis);
OutputVector base = as_output_vector(args);
for (auto node : mins)
{
base.push_back(node);
};
for (auto node : maxs)
{
base.push_back(node);
};
return make_shared<op::Concat>(rescaled_args, concatenation_axis)
->add_provenance_group_members_above(base);
}
}
}
......@@ -55,23 +55,31 @@ namespace ngraph
auto filter_zero_point = op::Constant::create(filters.get_element_type(), Shape{}, {0});
return make_shared<op::QuantizedConvolution>(
input,
filters,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
input_scale,
input_zero_point,
filter_scale,
filter_zero_point,
output_scale,
filter_zero_point, // output type will be same as filter
output_type,
input_axes,
filter_axes,
output_axes);
input,
filters,
window_movement_strides,
window_dilation_strides,
padding_below,
padding_above,
data_dilation_strides,
input_scale,
input_zero_point,
filter_scale,
filter_zero_point,
output_scale,
filter_zero_point, // output type will be same as filter
output_type,
input_axes,
filter_axes,
output_axes)
->add_provenance_group_members_above({input,
filters,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output});
}
shared_ptr<Node> QuantizedConvolutionBiasBuilder(const Output<Node>& input,
......@@ -121,7 +129,16 @@ namespace ngraph
padding_above,
data_dilation_strides,
requantization_scale,
with_relu);
with_relu)
->add_provenance_group_members_above({input,
filters,
bias,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output});
}
shared_ptr<Node> QuantizedConvolutionReluBuilder(const Output<Node>& input,
......@@ -152,7 +169,15 @@ namespace ngraph
padding_below,
padding_above,
data_dilation_strides,
requantization_scale);
requantization_scale)
->add_provenance_group_members_above({input,
filters,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output});
}
shared_ptr<Node> QuantizedConvolutionBiasAddBuilder(const Output<Node>& input,
......@@ -210,7 +235,19 @@ namespace ngraph
data_dilation_strides,
requantization_scale,
sum_scale,
with_relu);
with_relu)
->add_provenance_group_members_above({input,
filters,
bias,
sum_input,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output,
min_sum_input,
max_sum_input});
}
shared_ptr<Node>
......@@ -275,7 +312,19 @@ namespace ngraph
requantization_scale,
sum_scale,
with_relu);
return make_shared<op::Convert>(qconv, element::u8);
}
return make_shared<op::Convert>(qconv, element::u8)
->add_provenance_group_members_above({input,
filters,
bias,
sum_input,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output,
min_sum_input,
max_sum_input});
};
}
}
......@@ -62,7 +62,8 @@ namespace ngraph
output_type,
input0_axes,
input1_axes,
output_axes);
output_axes)
->add_provenance_group_members_above({input0, input1});
}
shared_ptr<Node> QuantizedDotBiasBuilder(const Output<Node>& input,
......@@ -102,7 +103,16 @@ namespace ngraph
bias, bias_scale, zero, element::i32, quantization_axes, round_mode);
}
return make_shared<op::QuantizedDotBias>(
input, filters, mybias, requantization_scale, requantize, with_relu);
input, filters, mybias, requantization_scale, requantize, with_relu)
->add_provenance_group_members_above({input,
filters,
bias,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output});
}
}
}
......@@ -47,7 +47,7 @@ namespace ngraph
auto x2 = node * node;
auto x2sum = std::make_shared<op::Sum>(x2, reduction_axes);
return std::make_shared<op::Sqrt>(x2sum);
return std::make_shared<op::Sqrt>(x2sum)->add_provenance_group_members_above({node});
}
std::shared_ptr<Node> mean(const Output<Node>& value, const AxisSet& reduction_axes)
......@@ -59,14 +59,15 @@ namespace ngraph
auto divisor = op::Constant::create(et, xsum->get_shape(), {N});
return xsum / divisor;
return (xsum / divisor)->add_provenance_group_members_above({value});
}
std::shared_ptr<Node> std_dev(const Output<Node>& node,
const AxisSet& reduction_axes,
const bool bessel_correction)
{
return std::make_shared<op::Sqrt>(variance(node, reduction_axes, bessel_correction));
return std::make_shared<op::Sqrt>(variance(node, reduction_axes, bessel_correction))
->add_provenance_group_members_above({node});
}
// This currently calculates [E[X^2] - E[X]^2] instead of [E[(X-\mu)^2]]
......@@ -96,16 +97,18 @@ namespace ngraph
const auto& et = value.get_element_type();
auto N = get_num_elements(value.get_shape(), reduction_axes);
std::shared_ptr<Node> result;
if (bessel_correction)
{
auto N1const = op::Constant::create(et, diff.get_shape(), {N - 1});
return diff / N1const;
result = diff / N1const;
}
else
{
auto Nconst = op::Constant::create(et, diff.get_shape(), {N});
return diff / Nconst;
result = diff / Nconst;
}
return result->add_provenance_group_members_above({value});
}
} // namespace builder
......
......@@ -35,7 +35,8 @@ using namespace std;
shared_ptr<Node> builder::reshape(const Output<Node>& value, const Shape& shape)
{
return make_shared<op::Reshape>(value, get_default_order(value.get_shape().size()), shape);
return make_shared<op::Reshape>(value, get_default_order(value.get_shape().size()), shape)
->add_provenance_group_members_above({value});
}
shared_ptr<Node> builder::reorder_axes(const Output<Node>& value, vector<size_t> axes_order)
......@@ -55,7 +56,8 @@ shared_ptr<Node> builder::reorder_axes(const Output<Node>& value, vector<size_t>
}
auto axis_vector = AxisVector{begin(axes_order), end(axes_order)};
return make_shared<op::Reshape>(value, axis_vector, out_shape);
return make_shared<op::Reshape>(value, axis_vector, out_shape)
->add_provenance_group_members_above({value});
}
shared_ptr<Node> builder::transpose(const Output<Node>& value)
......@@ -80,7 +82,8 @@ shared_ptr<Node> builder::flatten(const Output<Node>& value, int axis)
accumulate(next(begin(data_shape), axis), end(data_shape), 1UL, multiplies<size_t>());
return make_shared<op::Reshape>(
value, get_default_order(data_shape.size()), Shape{first_dim_size, last_dim_size});
value, get_default_order(data_shape.size()), Shape{first_dim_size, last_dim_size})
->add_provenance_group_members_above({value});
}
// Dynamic version of "flatten".
......@@ -118,7 +121,8 @@ shared_ptr<Node> builder::flatten(const Output<Node>& value, const Output<Node>&
auto flattened_dims = make_shared<op::Concat>(NodeVector{row_dims_prod, col_dims_prod}, 0);
// result := DynReshape(value, flattened_dims)
return make_shared<op::DynReshape>(value, flattened_dims);
return make_shared<op::DynReshape>(value, flattened_dims)
->add_provenance_group_members_above({value});
}
shared_ptr<Node> builder::squeeze(const Output<Node>& value, vector<size_t> axes)
......@@ -166,5 +170,6 @@ shared_ptr<Node> builder::expand_dims(const Output<Node>& value, size_t axis)
advance(empty_axis_it, axis);
output_shape.insert(empty_axis_it, 1);
return make_shared<op::Reshape>(
value, get_default_order(value.get_shape().size()), output_shape);
value, get_default_order(value.get_shape().size()), output_shape)
->add_provenance_group_members_above({value});
}
......@@ -41,7 +41,9 @@ namespace
upper_bounds.at(axis) =
get_valid_array_index(ends.at(index), output.get_shape().at(axis));
}
return std::make_shared<op::Slice>(output, lower_bounds, upper_bounds);
return std::static_pointer_cast<op::Slice>(
std::make_shared<op::Slice>(output, lower_bounds, upper_bounds)
->add_provenance_group_members_above({output}));
}
}
......
......@@ -96,7 +96,8 @@ namespace ngraph
convert_sequence, mask_shape, non_sequence_axes);
// mask = sequence_length < sequence
return std::make_shared<T>(broadcast_sequence, broadcast_sequence_lengths);
return std::make_shared<T>(broadcast_sequence, broadcast_sequence_lengths)
->add_provenance_group_members_above({sequence_lengths});
}
}
}
......@@ -323,21 +323,111 @@ void Node::set_placement_index(size_t placement)
m_placement_index = placement;
}
const std::set<std::string>& Node::get_provenance_tags() const
void Node::add_provenance_group_member(const shared_ptr<Node>& node)
{
return m_provenance_tags;
m_provenance_group.insert(node);
}
void Node::add_provenance_tag(const std::string& tag)
void Node::remove_provenance_group_member(const shared_ptr<Node>& node)
{
m_provenance_tags.insert(tag);
m_provenance_group.erase(node);
}
void Node::add_provenance_tags(const std::set<std::string>& tag_set)
void Node::replace_provenance_group_member(const shared_ptr<Node>& current_node,
const shared_ptr<Node>& replacement_node)
{
for (auto tag : tag_set)
// Catch up with the current state of the group
replacement_node->add_provenance_tags(get_provenance_tags());
if (current_node != nullptr)
{
add_provenance_tag(tag);
remove_provenance_group_member(current_node);
// Catch up with what was added to the current node
replacement_node->add_provenance_tags(current_node->get_provenance_tags());
}
add_provenance_group_member(replacement_node);
}
// \return The set of nodes grouped with this node for provenance-tag
//         propagation (tags added to this node are forwarded to them).
const set<shared_ptr<Node>>& Node::get_provenance_group_members() const
{
    return m_provenance_group;
}
// Adds every node strictly between this node and the nodes in `base`
// (exclusive of both) to this node's provenance group, so tags applied to
// this node also reach the intermediate nodes a builder created on its
// behalf.
//
// \param base Values delimiting the region; their nodes are not added.
//             If this node itself appears in `base`, the builder created
//             nothing and the group is left untouched.
// \return shared_from_this(), so builders can return the call expression.
shared_ptr<Node> Node::add_provenance_group_members_above(const OutputVector& base)
{
    set<Node*> base_set;
    for (auto& output : base)
    {
        Node* node = output.get_node();
        if (node == this)
        {
            // A builder did nothing
            return shared_from_this();
        }
        base_set.insert(node);
    }
    // Depth-first walk from this node's inputs toward `base`; `base_set`
    // doubles as the visited set once the traversal starts.
    vector<Node*> todo;
    for (auto value : input_values())
    {
        todo.push_back(value.get_node());
    }
    while (!todo.empty())
    {
        Node* node = todo.back();
        todo.pop_back();
        if (base_set.count(node) > 0)
        {
            continue;
        }
        add_provenance_group_member(node->shared_from_this());
        for (auto value : node->input_values())
        {
            // Inputs already in `node`'s own provenance group receive tags
            // through that group, so they are not walked again here.
            if (0 == node->m_provenance_group.count(value.get_node_shared_ptr()))
            {
                todo.push_back(value.get_node());
            }
        }
        base_set.insert(node);
    }
    return shared_from_this();
}
// Applies `tag_set` to this node and to every node reachable from it through
// input edges, stopping at (and excluding) the nodes listed in `base`.
//
// \param base    Values whose nodes terminate the upward traversal.
// \param tag_set Provenance tags to apply to each visited node.
void Node::add_provenance_tags_above(const OutputVector& base,
                                     const std::unordered_set<std::string>& tag_set)
{
    set<Node*> base_set;
    for (auto& output : base)
    {
        base_set.insert(output.get_node());
    }
    // Depth-first traversal starting at this node; `base_set` also serves as
    // the visited set so no node is tagged twice.
    vector<Node*> todo{this};
    while (!todo.empty())
    {
        Node* node = todo.back();
        todo.pop_back();
        if (base_set.count(node) > 0)
        {
            continue;
        }
        node->add_provenance_tags(tag_set);
        for (auto value : node->input_values())
        {
            todo.push_back(value.get_node());
        }
        base_set.insert(node);
    }
}
// \return The provenance tags attached directly to this node.
const std::unordered_set<std::string>& Node::get_provenance_tags() const
{
    return m_provenance_tags;
}
// Attaches `tag` to this node and forwards it to every member of this node's
// provenance group, keeping grouped (builder-created) nodes tagged in sync.
//
// \param tag The provenance tag to add.
void Node::add_provenance_tag(const std::string& tag)
{
    m_provenance_tags.insert(tag);
    // Iterate by const reference: copying a shared_ptr per element would cost
    // an atomic refcount increment/decrement for every group member.
    for (const auto& node : m_provenance_group)
    {
        // NOTE(review): assumes provenance groups are acyclic; a cycle here
        // would recurse without bound -- TODO confirm.
        node->add_provenance_tag(tag);
    }
}
......
......@@ -409,10 +409,33 @@ namespace ngraph
/// Set device placement
void set_placement_index(size_t placement);
const std::set<std::string>& get_provenance_tags() const;
const std::unordered_set<std::string>& get_provenance_tags() const;
void add_provenance_tag(const std::string& tag);
void add_provenance_tags(const std::set<std::string>& tag_set);
template <typename T>
void add_provenance_tags(T tag_set)
{
for (auto tag : tag_set)
{
add_provenance_tag(tag);
}
}
/// \brief Adds tag_set to this node and all intermediate nodes above base
void add_provenance_tags_above(const OutputVector& base,
const std::unordered_set<std::string>& tag_set);
void remove_provenance_tag(const std::string& tag);
/// \brief Add node to additional nodes that receive tags
void add_provenance_group_member(const std::shared_ptr<Node>& node);
/// \brief Remove node to additional nodes that receive tags
void remove_provenance_group_member(const std::shared_ptr<Node>& node);
/// \brief Replace current_node with replacement_node and transfer tags
void replace_provenance_group_member(const std::shared_ptr<Node>& current_node,
const std::shared_ptr<Node>& replacement_node);
/// \return Provenance group nodes
const std::set<std::shared_ptr<Node>>& get_provenance_group_members() const;
/// \brief Add all nodes between this node and nodes in base as additional nodes to receive
/// provenance tags.
std::shared_ptr<Node> add_provenance_group_members_above(const OutputVector& base);
// to be used when nodes are replaced
void merge_provenance_tags_from(const std::shared_ptr<const Node>& source);
......@@ -474,7 +497,8 @@ namespace ngraph
std::string m_unique_name;
NGRAPH_API
static std::atomic<size_t> m_next_instance_id;
std::set<std::string> m_provenance_tags;
std::unordered_set<std::string> m_provenance_tags;
std::set<std::shared_ptr<Node>> m_provenance_group;
std::deque<descriptor::Input> m_inputs;
std::deque<descriptor::Output> m_outputs;
std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map;
......
......@@ -54,6 +54,10 @@ op::GenerateMask::GenerateMask(const Output<Node>& training,
make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{use_seed}));
set_argument(3, make_shared<op::Constant>(element::u64, Shape{}, std::vector<uint64_t>{seed}));
set_argument(4, make_shared<op::Constant>(element::f64, Shape{}, std::vector<double>{prob}));
add_provenance_group_member(input_value(1).get_node_shared_ptr());
add_provenance_group_member(input_value(2).get_node_shared_ptr());
add_provenance_group_member(input_value(3).get_node_shared_ptr());
add_provenance_group_member(input_value(4).get_node_shared_ptr());
constructor_validate_and_infer_types();
}
......
......@@ -26,6 +26,7 @@ constexpr NodeTypeInfo op::LRN::type_info;
// Scalar-attribute constructor: delegates to the full constructor, synthesizing
// the second input as an i64 constant of shape {1} holding {1} (presumably the
// axes input -- TODO confirm against the class declaration).
op::LRN::LRN(const Output<Node>& arg, double alpha, double beta, double bias, size_t size)
    : LRN(arg, op::Constant::create(element::i64, Shape{1}, {1}), alpha, beta, bias, size)
{
    // The synthesized constant was created here, not by the caller, so it
    // joins the provenance group and will receive this node's tags.
    add_provenance_group_member(input_value(1).get_node_shared_ptr());
}
op::LRN::LRN(const Output<Node>& arg,
......
......@@ -38,6 +38,7 @@ op::TopK::TopK(const Output<Node>& arg,
, m_compute_max(compute_max)
, m_sort(sort)
{
add_provenance_group_member(input_value(1).get_node_shared_ptr());
constructor_validate_and_infer_types();
}
......@@ -72,8 +73,10 @@ size_t op::TopK::get_k() const
void op::TopK::set_k(size_t k)
{
this->input(1).replace_source_output(
op::Constant::create(element::i64, Shape{1}, {k})->output(0));
shared_ptr<Node> current_const =
get_input_size() == 1 ? nullptr : input_value(1).get_node_shared_ptr();
auto replacement_const = op::Constant::create(element::i64, Shape{1}, {k})->output(0);
replace_provenance_group_member(current_const, replacement_const.get_node_shared_ptr());
}
void op::TopK::validate_and_infer_types()
......
......@@ -31,6 +31,7 @@ op::util::ArithmeticReduction::ArithmeticReduction(const Output<Node>& arg,
element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector())
->output(0)})
{
add_provenance_group_member(input_value(1).get_node_shared_ptr());
}
op::util::ArithmeticReduction::ArithmeticReduction(const Output<Node>& arg,
......
......@@ -30,6 +30,7 @@ op::util::LogicalReduction::LogicalReduction(const Output<Node>& arg, const Axis
element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector())
->output(0)})
{
add_provenance_group_member(input_value(1).get_node_shared_ptr());
}
op::util::LogicalReduction::LogicalReduction(const Output<Node>& arg,
......
......@@ -36,8 +36,19 @@ bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
// Op supported by backend. Do not decompose
return modified;
}
// Capture the input values as a base for provenance
OutputVector base_input_values;
for (auto value : node->input_values())
{
base_input_values.push_back(value);
}
auto subgraph_outputs = node->decompose_op();
// Transfer the new provenance tags to the newly created ops
auto provenance_tags = node->get_provenance_tags();
for (auto subgraph : subgraph_outputs)
{
subgraph->add_provenance_tags_above(base_input_values, provenance_tags);
}
// Run recursively until no more fused ops
auto subgraph = extract_subgraph(subgraph_outputs, node->get_arguments());
for (auto subgraph_node : subgraph)
......
......@@ -22,14 +22,18 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/builder/norm.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/provenance.hpp"
using namespace std;
using namespace ngraph;
using ::testing::Return;
using ProvSet = std::set<std::string>;
using ProvSet = std::unordered_set<std::string>;
TEST(provenance, provenance)
{
......@@ -313,3 +317,127 @@ TEST(provenance, provenance)
EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_c", "tag_e"}));
}
}
// A group built above {param_a, param_b} should forward tags applied to the
// grouped node onto the intermediate node, but never onto the base nodes.
TEST(provenance, add_group_above)
{
    auto param_a = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
    param_a->add_provenance_tag("P1");
    auto param_b = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
    param_b->add_provenance_tag("P2");
    auto sum = param_a + param_b;
    auto prod = (sum * sum)->add_provenance_group_members_above({param_a, param_b});
    prod->add_provenance_tag("m1");
    // Base nodes keep only their own tags.
    EXPECT_EQ(param_a->get_provenance_tags(), (ProvSet{"P1"}));
    EXPECT_EQ(param_b->get_provenance_tags(), (ProvSet{"P2"}));
    // The intermediate node and the grouped node both carry the group tag.
    EXPECT_EQ(sum->get_provenance_tags(), (ProvSet{"m1"}));
    EXPECT_EQ(prod->get_provenance_tags(), (ProvSet{"m1"}));
}
// A builder-produced subgraph (lp_norm) should relay a tag applied to its
// result onto every node it created, leaving the input parameter untouched.
TEST(provenance, builder)
{
    auto input = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
    input->add_provenance_tag("P1");
    auto norm = builder::lp_norm(input, {0}, 1, 0);
    norm->add_provenance_tag("norm");
    for (const auto& node : topological_sort(NodeVector{norm}))
    {
        ProvSet expected = (node == input) ? ProvSet{"P1"} : ProvSet{"norm"};
        EXPECT_EQ(node->get_provenance_tags(), expected);
    }
}
// Checks that when FusedOpDecomposition replaces a fused op (Gelu), the
// original op's provenance tags are transferred to the decomposed nodes.
TEST(provenance, fused)
{
    auto p1 = make_shared<op::Parameter>(element::f32, PartialShape{2, 3, 4});
    p1->add_provenance_tag("P1");
    auto g = make_shared<op::Gelu>(p1);
    g->add_provenance_tag("G");
    auto r = make_shared<op::Result>(g);
    auto f = make_shared<Function>(ResultVector{r}, ParameterVector{p1});
    pass::Manager manager;
    manager.register_pass<pass::FusedOpDecomposition>();
    manager.run_passes(f);
    traverse_nodes(f, [&](const std::shared_ptr<Node>& node) {
        if (node == p1)
        {
            // The parameter keeps only its own tag.
            EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"P1"}));
        }
        else if (node == r)
        {
            // The Result node's tags are intentionally not checked.
        }
        else
        {
            // Every node produced by the decomposition carries the fused
            // op's tag.
            EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"G"}));
        }
    });
}
// Checks that TopK's synthesized k-constant input participates in the
// provenance group, and that set_k() transfers the accumulated tags from the
// old constant to its replacement.
TEST(provenance, topk_setk)
{
    auto p1 = make_shared<op::Parameter>(element::f32, PartialShape{20, 3, 4});
    p1->add_provenance_tag("P1");
    auto tk = make_shared<op::TopK>(p1, 0, element::i32, 10);
    tk->add_provenance_tag("TK");
    // The constructor created the k constant, so it belongs to tk's group.
    auto tkc0 = tk->input_value(1).get_node_shared_ptr();
    tkc0->add_provenance_tag("TKC0");
    for (auto node : topological_sort(NodeVector{tk}))
    {
        if (node == p1)
        {
            EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"P1"}));
        }
        else if (node == tkc0)
        {
            // Group membership forwarded "TK" onto the constant.
            EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"TK", "TKC0"}));
        }
        else
        {
            EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"TK"}));
        }
    }
    tk->set_k(5);
    // set_k replaced the constant; the new one inherits the old one's tags.
    auto tkc1 = tk->input_value(1).get_node_shared_ptr();
    tkc1->add_provenance_tag("TKC1");
    for (auto node : topological_sort(NodeVector{tk}))
    {
        if (node == p1)
        {
            EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"P1"}));
        }
        else if (node == tkc1)
        {
            EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"TK", "TKC0", "TKC1"}));
        }
        else
        {
            EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"TK"}));
        }
    }
}
// When the node itself is passed as the base, the builder created nothing and
// the provenance group must stay empty: tags stop at the node's real inputs.
TEST(provenance, empty_group)
{
    auto param = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
    param->add_provenance_tag("P1");
    auto abs_node = make_shared<op::Abs>(param);
    // Using the node itself as the base leaves the group empty.
    abs_node->add_provenance_group_members_above({abs_node});
    abs_node->add_provenance_tag("abs");
    for (const auto& node : topological_sort(NodeVector{abs_node}))
    {
        ProvSet expected = (node == param) ? ProvSet{"P1"} : ProvSet{"abs"};
        EXPECT_EQ(node->get_provenance_tags(), expected);
    }
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.