Unverified Commit 8ed67b44 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Master edition of Provenance for builders (#3647)

* Provenance for builders (#3644)

* Provenance for builders

* Tests

* Tests

* Update src/ngraph/node.hpp
Co-Authored-By: 's avatarSayantan Sarkar <sayantan.sarkar@intel.com>

* style

Finish migration

* Migration to master

* Handle op constructors that construct ops

* Update provenance for topk

* Add missing input (#3691)

* Add missing input

* If builder does nothing, don't dive into the inputs

* Reverse test

* simplify, add test

* Add test

* add test
parent af2b137b
...@@ -172,7 +172,7 @@ namespace ngraph ...@@ -172,7 +172,7 @@ namespace ngraph
return_value, final_shape, broadcast_axes); return_value, final_shape, broadcast_axes);
} }
return return_value.get_node_shared_ptr(); return return_value.get_node_shared_ptr()->add_provenance_group_members_above({value});
} }
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>
......
...@@ -51,7 +51,8 @@ namespace ngraph ...@@ -51,7 +51,8 @@ namespace ngraph
auto zero = make_constant(quant_type, shape, 0); auto zero = make_constant(quant_type, shape, 0);
auto scale = quantization_utils::get_scale(min, max, quant_type); auto scale = quantization_utils::get_scale(min, max, quant_type);
return make_shared<op::Dequantize>(input, scale, zero, real_type, axes); return make_shared<op::Dequantize>(input, scale, zero, real_type, axes)
->add_provenance_group_members_above({input, min, max});
} }
} }
} }
...@@ -103,7 +103,7 @@ namespace ngraph ...@@ -103,7 +103,7 @@ namespace ngraph
val = std::make_shared<ngraph::op::Broadcast>(val, shape, axes); val = std::make_shared<ngraph::op::Broadcast>(val, shape, axes);
} }
return val; return val->add_provenance_group_members_above({});
} }
} }
} }
...@@ -80,7 +80,9 @@ NodeVector builder::MatmulFactory::make_matmul_op() ...@@ -80,7 +80,9 @@ NodeVector builder::MatmulFactory::make_matmul_op()
// Multiply two tensors where both of them has rank lower equal 2. // Multiply two tensors where both of them has rank lower equal 2.
if (left_rank <= 2 && right_rank <= 2) if (left_rank <= 2 && right_rank <= 2)
{ {
return {make_dot(left, right).get_node_shared_ptr()}; return {make_dot(left, right)
.get_node_shared_ptr()
->add_provenance_group_members_above(m_inputs)};
} }
// Second case: // Second case:
...@@ -135,7 +137,7 @@ NodeVector builder::MatmulFactory::make_matmul_op() ...@@ -135,7 +137,7 @@ NodeVector builder::MatmulFactory::make_matmul_op()
if (left_shape.size() <= 3 && right_shape.size() <= 3) if (left_shape.size() <= 3 && right_shape.size() <= 3)
{ {
return {result}; return {result->add_provenance_group_members_above(m_inputs)};
} }
// Expand result _stack of matrices_ axes to get expected result shape. // Expand result _stack of matrices_ axes to get expected result shape.
else else
...@@ -144,7 +146,8 @@ NodeVector builder::MatmulFactory::make_matmul_op() ...@@ -144,7 +146,8 @@ NodeVector builder::MatmulFactory::make_matmul_op()
Shape result_shape(next(begin(shape)), end(shape)); Shape result_shape(next(begin(shape)), end(shape));
result_shape.insert( result_shape.insert(
begin(result_shape), begin(left_shape), next(begin(left_shape), left_shape.size() - 2)); begin(result_shape), begin(left_shape), next(begin(left_shape), left_shape.size() - 2));
return {make_shared<op::Reshape>(result, get_default_order(shape.size()), result_shape)}; return {make_shared<op::Reshape>(result, get_default_order(shape.size()), result_shape)
->add_provenance_group_members_above(m_inputs)};
} }
} }
......
...@@ -65,7 +65,8 @@ namespace ngraph ...@@ -65,7 +65,8 @@ namespace ngraph
values->get_shape(), values->get_shape(),
vector<float>(shape_size(values->get_shape()), 1.f / p_norm)); vector<float>(shape_size(values->get_shape()), 1.f / p_norm));
return {make_shared<op::Power>(values, inv_p_node)}; return {make_shared<op::Power>(values, inv_p_node)
->add_provenance_group_members_above({value})};
} }
} }
...@@ -81,7 +82,8 @@ namespace ngraph ...@@ -81,7 +82,8 @@ namespace ngraph
shared_ptr<Node> non_zero_values = make_shared<op::Convert>( shared_ptr<Node> non_zero_values = make_shared<op::Convert>(
make_shared<op::NotEqual>(value, zero_node), value.get_element_type()); make_shared<op::NotEqual>(value, zero_node), value.get_element_type());
return make_shared<op::Sum>(non_zero_values, reduction_axes); return make_shared<op::Sum>(non_zero_values, reduction_axes)
->add_provenance_group_members_above({value});
} }
shared_ptr<Node> shared_ptr<Node>
...@@ -95,7 +97,7 @@ namespace ngraph ...@@ -95,7 +97,7 @@ namespace ngraph
values->get_shape(), values->get_shape(),
vector<float>(shape_size(values->get_shape()), bias))}; vector<float>(shape_size(values->get_shape()), bias))};
return values + bias_node; return (values + bias_node)->add_provenance_group_members_above({value});
} }
shared_ptr<Node> l2_norm(const Output<Node>& value, shared_ptr<Node> l2_norm(const Output<Node>& value,
...@@ -109,16 +111,18 @@ namespace ngraph ...@@ -109,16 +111,18 @@ namespace ngraph
op::Constant::create(values->get_element_type(), op::Constant::create(values->get_element_type(),
values->get_shape(), values->get_shape(),
vector<float>(shape_size(values->get_shape()), bias))}; vector<float>(shape_size(values->get_shape()), bias))};
shared_ptr<Node> result;
switch (bias_mode) switch (bias_mode)
{ {
case BiasMode::MAX: case BiasMode::MAX:
{ {
return {make_shared<op::Sqrt>(make_shared<op::Maximum>(values, bias_node))}; result = make_shared<op::Sqrt>(make_shared<op::Maximum>(values, bias_node));
break;
} }
case BiasMode::ADD: case BiasMode::ADD:
default: { return {make_shared<op::Sqrt>(values + bias_node)}; default: result = make_shared<op::Sqrt>(values + bias_node);
}
} }
return result->add_provenance_group_members_above({value});
} }
shared_ptr<Node> lp_norm(const Output<Node>& value, shared_ptr<Node> lp_norm(const Output<Node>& value,
......
...@@ -74,7 +74,8 @@ namespace ngraph ...@@ -74,7 +74,8 @@ namespace ngraph
out_shape.push_back(in_shape[order[i]]); out_shape.push_back(in_shape[order[i]]);
// do the reshaping with the order // do the reshaping with the order
return std::make_shared<ngraph::op::Reshape>(value, order, out_shape); return std::make_shared<ngraph::op::Reshape>(value, order, out_shape)
->add_provenance_group_members_above({value});
} }
} // namespace builder } // namespace builder
......
...@@ -62,7 +62,6 @@ namespace ngraph ...@@ -62,7 +62,6 @@ namespace ngraph
mybias = make_shared<op::Quantize>( mybias = make_shared<op::Quantize>(
bias, bias_scale, zero, element::i32, quantization_axes, round_mode); bias, bias_scale, zero, element::i32, quantization_axes, round_mode);
} }
return make_shared<op::QuantizedConvolutionBias>(input, return make_shared<op::QuantizedConvolutionBias>(input,
filter, filter,
mybias, mybias,
...@@ -72,7 +71,9 @@ namespace ngraph ...@@ -72,7 +71,9 @@ namespace ngraph
padding_above, padding_above,
data_dilation_strides, data_dilation_strides,
requantization_scale, requantization_scale,
false); false)
->add_provenance_group_members_above(
{input, filter, bias, input_scale, filter_scale, output_scale});
} }
} }
} }
......
...@@ -26,7 +26,8 @@ namespace ngraph ...@@ -26,7 +26,8 @@ namespace ngraph
{ {
auto abs_a = std::make_shared<op::Abs>(a); auto abs_a = std::make_shared<op::Abs>(a);
auto abs_b = std::make_shared<op::Abs>(b); auto abs_b = std::make_shared<op::Abs>(b);
return std::make_shared<op::Maximum>(abs_a, abs_b); return std::make_shared<op::Maximum>(abs_a, abs_b)
->add_provenance_group_members_above({a, b});
} }
std::shared_ptr<Node> get_scale(const Output<Node>& input_min_range, std::shared_ptr<Node> get_scale(const Output<Node>& input_min_range,
...@@ -72,7 +73,8 @@ namespace ngraph ...@@ -72,7 +73,8 @@ namespace ngraph
auto max_abs_range = max_abs(min_range, max_range); auto max_abs_range = max_abs(min_range, max_range);
auto target_range = make_constant(type, shape, range); auto target_range = make_constant(type, shape, range);
return max_abs_range / target_range; return (max_abs_range / target_range)
->add_provenance_group_members_above({input_min_range, input_max_range});
} }
std::shared_ptr<Node> get_bias_scale(Output<Node> min_input, std::shared_ptr<Node> get_bias_scale(Output<Node> min_input,
......
...@@ -52,7 +52,8 @@ namespace ngraph ...@@ -52,7 +52,8 @@ namespace ngraph
auto zero = make_constant(quant_type, shape, 0); auto zero = make_constant(quant_type, shape, 0);
auto scale = quantization_utils::get_scale(min, max, quant_type, true); auto scale = quantization_utils::get_scale(min, max, quant_type, true);
return make_shared<op::Quantize>(input, scale, zero, quant_type, axes, round_mode); return make_shared<op::Quantize>(input, scale, zero, quant_type, axes, round_mode)
->add_provenance_group_members_above({input, min, max});
} }
} }
} }
...@@ -58,8 +58,17 @@ namespace ngraph ...@@ -58,8 +58,17 @@ namespace ngraph
AxisSet{}, AxisSet{},
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN); op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
} }
OutputVector base = as_output_vector(args);
return make_shared<op::Concat>(rescaled_args, concatenation_axis); for (auto node : mins)
{
base.push_back(node);
};
for (auto node : maxs)
{
base.push_back(node);
};
return make_shared<op::Concat>(rescaled_args, concatenation_axis)
->add_provenance_group_members_above(base);
} }
} }
} }
...@@ -55,23 +55,31 @@ namespace ngraph ...@@ -55,23 +55,31 @@ namespace ngraph
auto filter_zero_point = op::Constant::create(filters.get_element_type(), Shape{}, {0}); auto filter_zero_point = op::Constant::create(filters.get_element_type(), Shape{}, {0});
return make_shared<op::QuantizedConvolution>( return make_shared<op::QuantizedConvolution>(
input, input,
filters, filters,
window_movement_strides, window_movement_strides,
window_dilation_strides, window_dilation_strides,
padding_below, padding_below,
padding_above, padding_above,
data_dilation_strides, data_dilation_strides,
input_scale, input_scale,
input_zero_point, input_zero_point,
filter_scale, filter_scale,
filter_zero_point, filter_zero_point,
output_scale, output_scale,
filter_zero_point, // output type will be same as filter filter_zero_point, // output type will be same as filter
output_type, output_type,
input_axes, input_axes,
filter_axes, filter_axes,
output_axes); output_axes)
->add_provenance_group_members_above({input,
filters,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output});
} }
shared_ptr<Node> QuantizedConvolutionBiasBuilder(const Output<Node>& input, shared_ptr<Node> QuantizedConvolutionBiasBuilder(const Output<Node>& input,
...@@ -121,7 +129,16 @@ namespace ngraph ...@@ -121,7 +129,16 @@ namespace ngraph
padding_above, padding_above,
data_dilation_strides, data_dilation_strides,
requantization_scale, requantization_scale,
with_relu); with_relu)
->add_provenance_group_members_above({input,
filters,
bias,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output});
} }
shared_ptr<Node> QuantizedConvolutionReluBuilder(const Output<Node>& input, shared_ptr<Node> QuantizedConvolutionReluBuilder(const Output<Node>& input,
...@@ -152,7 +169,15 @@ namespace ngraph ...@@ -152,7 +169,15 @@ namespace ngraph
padding_below, padding_below,
padding_above, padding_above,
data_dilation_strides, data_dilation_strides,
requantization_scale); requantization_scale)
->add_provenance_group_members_above({input,
filters,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output});
} }
shared_ptr<Node> QuantizedConvolutionBiasAddBuilder(const Output<Node>& input, shared_ptr<Node> QuantizedConvolutionBiasAddBuilder(const Output<Node>& input,
...@@ -210,7 +235,19 @@ namespace ngraph ...@@ -210,7 +235,19 @@ namespace ngraph
data_dilation_strides, data_dilation_strides,
requantization_scale, requantization_scale,
sum_scale, sum_scale,
with_relu); with_relu)
->add_provenance_group_members_above({input,
filters,
bias,
sum_input,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output,
min_sum_input,
max_sum_input});
} }
shared_ptr<Node> shared_ptr<Node>
...@@ -275,7 +312,19 @@ namespace ngraph ...@@ -275,7 +312,19 @@ namespace ngraph
requantization_scale, requantization_scale,
sum_scale, sum_scale,
with_relu); with_relu);
return make_shared<op::Convert>(qconv, element::u8); return make_shared<op::Convert>(qconv, element::u8)
} ->add_provenance_group_members_above({input,
filters,
bias,
sum_input,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output,
min_sum_input,
max_sum_input});
};
} }
} }
...@@ -62,7 +62,8 @@ namespace ngraph ...@@ -62,7 +62,8 @@ namespace ngraph
output_type, output_type,
input0_axes, input0_axes,
input1_axes, input1_axes,
output_axes); output_axes)
->add_provenance_group_members_above({input0, input1});
} }
shared_ptr<Node> QuantizedDotBiasBuilder(const Output<Node>& input, shared_ptr<Node> QuantizedDotBiasBuilder(const Output<Node>& input,
...@@ -102,7 +103,16 @@ namespace ngraph ...@@ -102,7 +103,16 @@ namespace ngraph
bias, bias_scale, zero, element::i32, quantization_axes, round_mode); bias, bias_scale, zero, element::i32, quantization_axes, round_mode);
} }
return make_shared<op::QuantizedDotBias>( return make_shared<op::QuantizedDotBias>(
input, filters, mybias, requantization_scale, requantize, with_relu); input, filters, mybias, requantization_scale, requantize, with_relu)
->add_provenance_group_members_above({input,
filters,
bias,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output});
} }
} }
} }
...@@ -47,7 +47,7 @@ namespace ngraph ...@@ -47,7 +47,7 @@ namespace ngraph
auto x2 = node * node; auto x2 = node * node;
auto x2sum = std::make_shared<op::Sum>(x2, reduction_axes); auto x2sum = std::make_shared<op::Sum>(x2, reduction_axes);
return std::make_shared<op::Sqrt>(x2sum); return std::make_shared<op::Sqrt>(x2sum)->add_provenance_group_members_above({node});
} }
std::shared_ptr<Node> mean(const Output<Node>& value, const AxisSet& reduction_axes) std::shared_ptr<Node> mean(const Output<Node>& value, const AxisSet& reduction_axes)
...@@ -59,14 +59,15 @@ namespace ngraph ...@@ -59,14 +59,15 @@ namespace ngraph
auto divisor = op::Constant::create(et, xsum->get_shape(), {N}); auto divisor = op::Constant::create(et, xsum->get_shape(), {N});
return xsum / divisor; return (xsum / divisor)->add_provenance_group_members_above({value});
} }
std::shared_ptr<Node> std_dev(const Output<Node>& node, std::shared_ptr<Node> std_dev(const Output<Node>& node,
const AxisSet& reduction_axes, const AxisSet& reduction_axes,
const bool bessel_correction) const bool bessel_correction)
{ {
return std::make_shared<op::Sqrt>(variance(node, reduction_axes, bessel_correction)); return std::make_shared<op::Sqrt>(variance(node, reduction_axes, bessel_correction))
->add_provenance_group_members_above({node});
} }
// This currently calculates [E[X^2] - E[X]^2] instead of [E[(X-\mu)^2]] // This currently calculates [E[X^2] - E[X]^2] instead of [E[(X-\mu)^2]]
...@@ -96,16 +97,18 @@ namespace ngraph ...@@ -96,16 +97,18 @@ namespace ngraph
const auto& et = value.get_element_type(); const auto& et = value.get_element_type();
auto N = get_num_elements(value.get_shape(), reduction_axes); auto N = get_num_elements(value.get_shape(), reduction_axes);
std::shared_ptr<Node> result;
if (bessel_correction) if (bessel_correction)
{ {
auto N1const = op::Constant::create(et, diff.get_shape(), {N - 1}); auto N1const = op::Constant::create(et, diff.get_shape(), {N - 1});
return diff / N1const; result = diff / N1const;
} }
else else
{ {
auto Nconst = op::Constant::create(et, diff.get_shape(), {N}); auto Nconst = op::Constant::create(et, diff.get_shape(), {N});
return diff / Nconst; result = diff / Nconst;
} }
return result->add_provenance_group_members_above({value});
} }
} // namespace builder } // namespace builder
......
...@@ -35,7 +35,8 @@ using namespace std; ...@@ -35,7 +35,8 @@ using namespace std;
shared_ptr<Node> builder::reshape(const Output<Node>& value, const Shape& shape) shared_ptr<Node> builder::reshape(const Output<Node>& value, const Shape& shape)
{ {
return make_shared<op::Reshape>(value, get_default_order(value.get_shape().size()), shape); return make_shared<op::Reshape>(value, get_default_order(value.get_shape().size()), shape)
->add_provenance_group_members_above({value});
} }
shared_ptr<Node> builder::reorder_axes(const Output<Node>& value, vector<size_t> axes_order) shared_ptr<Node> builder::reorder_axes(const Output<Node>& value, vector<size_t> axes_order)
...@@ -55,7 +56,8 @@ shared_ptr<Node> builder::reorder_axes(const Output<Node>& value, vector<size_t> ...@@ -55,7 +56,8 @@ shared_ptr<Node> builder::reorder_axes(const Output<Node>& value, vector<size_t>
} }
auto axis_vector = AxisVector{begin(axes_order), end(axes_order)}; auto axis_vector = AxisVector{begin(axes_order), end(axes_order)};
return make_shared<op::Reshape>(value, axis_vector, out_shape); return make_shared<op::Reshape>(value, axis_vector, out_shape)
->add_provenance_group_members_above({value});
} }
shared_ptr<Node> builder::transpose(const Output<Node>& value) shared_ptr<Node> builder::transpose(const Output<Node>& value)
...@@ -80,7 +82,8 @@ shared_ptr<Node> builder::flatten(const Output<Node>& value, int axis) ...@@ -80,7 +82,8 @@ shared_ptr<Node> builder::flatten(const Output<Node>& value, int axis)
accumulate(next(begin(data_shape), axis), end(data_shape), 1UL, multiplies<size_t>()); accumulate(next(begin(data_shape), axis), end(data_shape), 1UL, multiplies<size_t>());
return make_shared<op::Reshape>( return make_shared<op::Reshape>(
value, get_default_order(data_shape.size()), Shape{first_dim_size, last_dim_size}); value, get_default_order(data_shape.size()), Shape{first_dim_size, last_dim_size})
->add_provenance_group_members_above({value});
} }
// Dynamic version of "flatten". // Dynamic version of "flatten".
...@@ -118,7 +121,8 @@ shared_ptr<Node> builder::flatten(const Output<Node>& value, const Output<Node>& ...@@ -118,7 +121,8 @@ shared_ptr<Node> builder::flatten(const Output<Node>& value, const Output<Node>&
auto flattened_dims = make_shared<op::Concat>(NodeVector{row_dims_prod, col_dims_prod}, 0); auto flattened_dims = make_shared<op::Concat>(NodeVector{row_dims_prod, col_dims_prod}, 0);
// result := DynReshape(value, flattened_dims) // result := DynReshape(value, flattened_dims)
return make_shared<op::DynReshape>(value, flattened_dims); return make_shared<op::DynReshape>(value, flattened_dims)
->add_provenance_group_members_above({value});
} }
shared_ptr<Node> builder::squeeze(const Output<Node>& value, vector<size_t> axes) shared_ptr<Node> builder::squeeze(const Output<Node>& value, vector<size_t> axes)
...@@ -166,5 +170,6 @@ shared_ptr<Node> builder::expand_dims(const Output<Node>& value, size_t axis) ...@@ -166,5 +170,6 @@ shared_ptr<Node> builder::expand_dims(const Output<Node>& value, size_t axis)
advance(empty_axis_it, axis); advance(empty_axis_it, axis);
output_shape.insert(empty_axis_it, 1); output_shape.insert(empty_axis_it, 1);
return make_shared<op::Reshape>( return make_shared<op::Reshape>(
value, get_default_order(value.get_shape().size()), output_shape); value, get_default_order(value.get_shape().size()), output_shape)
->add_provenance_group_members_above({value});
} }
...@@ -41,7 +41,9 @@ namespace ...@@ -41,7 +41,9 @@ namespace
upper_bounds.at(axis) = upper_bounds.at(axis) =
get_valid_array_index(ends.at(index), output.get_shape().at(axis)); get_valid_array_index(ends.at(index), output.get_shape().at(axis));
} }
return std::make_shared<op::Slice>(output, lower_bounds, upper_bounds); return std::static_pointer_cast<op::Slice>(
std::make_shared<op::Slice>(output, lower_bounds, upper_bounds)
->add_provenance_group_members_above({output}));
} }
} }
......
...@@ -96,7 +96,8 @@ namespace ngraph ...@@ -96,7 +96,8 @@ namespace ngraph
convert_sequence, mask_shape, non_sequence_axes); convert_sequence, mask_shape, non_sequence_axes);
// mask = sequence_length < sequence // mask = sequence_length < sequence
return std::make_shared<T>(broadcast_sequence, broadcast_sequence_lengths); return std::make_shared<T>(broadcast_sequence, broadcast_sequence_lengths)
->add_provenance_group_members_above({sequence_lengths});
} }
} }
} }
...@@ -323,21 +323,111 @@ void Node::set_placement_index(size_t placement) ...@@ -323,21 +323,111 @@ void Node::set_placement_index(size_t placement)
m_placement_index = placement; m_placement_index = placement;
} }
const std::set<std::string>& Node::get_provenance_tags() const void Node::add_provenance_group_member(const shared_ptr<Node>& node)
{ {
return m_provenance_tags; m_provenance_group.insert(node);
} }
void Node::add_provenance_tag(const std::string& tag) void Node::remove_provenance_group_member(const shared_ptr<Node>& node)
{ {
m_provenance_tags.insert(tag); m_provenance_group.erase(node);
} }
void Node::add_provenance_tags(const std::set<std::string>& tag_set) void Node::replace_provenance_group_member(const shared_ptr<Node>& current_node,
const shared_ptr<Node>& replacement_node)
{ {
for (auto tag : tag_set) // Catch up with the current state of the group
replacement_node->add_provenance_tags(get_provenance_tags());
if (current_node != nullptr)
{ {
add_provenance_tag(tag); remove_provenance_group_member(current_node);
// Catch up with what was added to the current node
replacement_node->add_provenance_tags(current_node->get_provenance_tags());
}
add_provenance_group_member(replacement_node);
}
const set<shared_ptr<Node>>& Node::get_provenance_group_members() const
{
return m_provenance_group;
}
shared_ptr<Node> Node::add_provenance_group_members_above(const OutputVector& base)
{
set<Node*> base_set;
for (auto& output : base)
{
Node* node = output.get_node();
if (node == this)
{
// A builder did nothing
return shared_from_this();
}
base_set.insert(node);
}
vector<Node*> todo;
for (auto value : input_values())
{
todo.push_back(value.get_node());
}
while (!todo.empty())
{
Node* node = todo.back();
todo.pop_back();
if (base_set.count(node) > 0)
{
continue;
}
add_provenance_group_member(node->shared_from_this());
for (auto value : node->input_values())
{
if (0 == node->m_provenance_group.count(value.get_node_shared_ptr()))
{
todo.push_back(value.get_node());
}
}
base_set.insert(node);
}
return shared_from_this();
}
void Node::add_provenance_tags_above(const OutputVector& base,
const std::unordered_set<std::string>& tag_set)
{
set<Node*> base_set;
for (auto& output : base)
{
base_set.insert(output.get_node());
}
vector<Node*> todo{this};
while (!todo.empty())
{
Node* node = todo.back();
todo.pop_back();
if (base_set.count(node) > 0)
{
continue;
}
node->add_provenance_tags(tag_set);
for (auto value : node->input_values())
{
todo.push_back(value.get_node());
}
base_set.insert(node);
}
}
const std::unordered_set<std::string>& Node::get_provenance_tags() const
{
return m_provenance_tags;
}
void Node::add_provenance_tag(const std::string& tag)
{
m_provenance_tags.insert(tag);
for (auto node : m_provenance_group)
{
node->add_provenance_tag(tag);
} }
} }
......
...@@ -409,10 +409,33 @@ namespace ngraph ...@@ -409,10 +409,33 @@ namespace ngraph
/// Set device placement /// Set device placement
void set_placement_index(size_t placement); void set_placement_index(size_t placement);
const std::set<std::string>& get_provenance_tags() const; const std::unordered_set<std::string>& get_provenance_tags() const;
void add_provenance_tag(const std::string& tag); void add_provenance_tag(const std::string& tag);
void add_provenance_tags(const std::set<std::string>& tag_set); template <typename T>
void add_provenance_tags(T tag_set)
{
for (auto tag : tag_set)
{
add_provenance_tag(tag);
}
}
/// \brief Adds tag_set to this node and all intermediate nodes above base
void add_provenance_tags_above(const OutputVector& base,
const std::unordered_set<std::string>& tag_set);
void remove_provenance_tag(const std::string& tag); void remove_provenance_tag(const std::string& tag);
/// \brief Add node to additional nodes that receive tags
void add_provenance_group_member(const std::shared_ptr<Node>& node);
/// \brief Remove node to additional nodes that receive tags
void remove_provenance_group_member(const std::shared_ptr<Node>& node);
/// \brief Replace current_node with replacement_node and transfer tags
void replace_provenance_group_member(const std::shared_ptr<Node>& current_node,
const std::shared_ptr<Node>& replacement_node);
/// \return Provenance group nodes
const std::set<std::shared_ptr<Node>>& get_provenance_group_members() const;
/// \brief Add all nodes between this node and nodes in base as additional nodes to receive
/// provenance tags.
std::shared_ptr<Node> add_provenance_group_members_above(const OutputVector& base);
// to be used when nodes are replaced // to be used when nodes are replaced
void merge_provenance_tags_from(const std::shared_ptr<const Node>& source); void merge_provenance_tags_from(const std::shared_ptr<const Node>& source);
...@@ -474,7 +497,8 @@ namespace ngraph ...@@ -474,7 +497,8 @@ namespace ngraph
std::string m_unique_name; std::string m_unique_name;
NGRAPH_API NGRAPH_API
static std::atomic<size_t> m_next_instance_id; static std::atomic<size_t> m_next_instance_id;
std::set<std::string> m_provenance_tags; std::unordered_set<std::string> m_provenance_tags;
std::set<std::shared_ptr<Node>> m_provenance_group;
std::deque<descriptor::Input> m_inputs; std::deque<descriptor::Input> m_inputs;
std::deque<descriptor::Output> m_outputs; std::deque<descriptor::Output> m_outputs;
std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map; std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map;
......
...@@ -54,6 +54,10 @@ op::GenerateMask::GenerateMask(const Output<Node>& training, ...@@ -54,6 +54,10 @@ op::GenerateMask::GenerateMask(const Output<Node>& training,
make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{use_seed})); make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{use_seed}));
set_argument(3, make_shared<op::Constant>(element::u64, Shape{}, std::vector<uint64_t>{seed})); set_argument(3, make_shared<op::Constant>(element::u64, Shape{}, std::vector<uint64_t>{seed}));
set_argument(4, make_shared<op::Constant>(element::f64, Shape{}, std::vector<double>{prob})); set_argument(4, make_shared<op::Constant>(element::f64, Shape{}, std::vector<double>{prob}));
add_provenance_group_member(input_value(1).get_node_shared_ptr());
add_provenance_group_member(input_value(2).get_node_shared_ptr());
add_provenance_group_member(input_value(3).get_node_shared_ptr());
add_provenance_group_member(input_value(4).get_node_shared_ptr());
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,6 +26,7 @@ constexpr NodeTypeInfo op::LRN::type_info; ...@@ -26,6 +26,7 @@ constexpr NodeTypeInfo op::LRN::type_info;
op::LRN::LRN(const Output<Node>& arg, double alpha, double beta, double bias, size_t size) op::LRN::LRN(const Output<Node>& arg, double alpha, double beta, double bias, size_t size)
: LRN(arg, op::Constant::create(element::i64, Shape{1}, {1}), alpha, beta, bias, size) : LRN(arg, op::Constant::create(element::i64, Shape{1}, {1}), alpha, beta, bias, size)
{ {
add_provenance_group_member(input_value(1).get_node_shared_ptr());
} }
op::LRN::LRN(const Output<Node>& arg, op::LRN::LRN(const Output<Node>& arg,
......
...@@ -38,6 +38,7 @@ op::TopK::TopK(const Output<Node>& arg, ...@@ -38,6 +38,7 @@ op::TopK::TopK(const Output<Node>& arg,
, m_compute_max(compute_max) , m_compute_max(compute_max)
, m_sort(sort) , m_sort(sort)
{ {
add_provenance_group_member(input_value(1).get_node_shared_ptr());
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -72,8 +73,10 @@ size_t op::TopK::get_k() const ...@@ -72,8 +73,10 @@ size_t op::TopK::get_k() const
void op::TopK::set_k(size_t k) void op::TopK::set_k(size_t k)
{ {
this->input(1).replace_source_output( shared_ptr<Node> current_const =
op::Constant::create(element::i64, Shape{1}, {k})->output(0)); get_input_size() == 1 ? nullptr : input_value(1).get_node_shared_ptr();
auto replacement_const = op::Constant::create(element::i64, Shape{1}, {k})->output(0);
replace_provenance_group_member(current_const, replacement_const.get_node_shared_ptr());
} }
void op::TopK::validate_and_infer_types() void op::TopK::validate_and_infer_types()
......
...@@ -31,6 +31,7 @@ op::util::ArithmeticReduction::ArithmeticReduction(const Output<Node>& arg, ...@@ -31,6 +31,7 @@ op::util::ArithmeticReduction::ArithmeticReduction(const Output<Node>& arg,
element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector()) element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector())
->output(0)}) ->output(0)})
{ {
add_provenance_group_member(input_value(1).get_node_shared_ptr());
} }
op::util::ArithmeticReduction::ArithmeticReduction(const Output<Node>& arg, op::util::ArithmeticReduction::ArithmeticReduction(const Output<Node>& arg,
......
...@@ -30,6 +30,7 @@ op::util::LogicalReduction::LogicalReduction(const Output<Node>& arg, const Axis ...@@ -30,6 +30,7 @@ op::util::LogicalReduction::LogicalReduction(const Output<Node>& arg, const Axis
element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector()) element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector())
->output(0)}) ->output(0)})
{ {
add_provenance_group_member(input_value(1).get_node_shared_ptr());
} }
op::util::LogicalReduction::LogicalReduction(const Output<Node>& arg, op::util::LogicalReduction::LogicalReduction(const Output<Node>& arg,
......
...@@ -36,8 +36,19 @@ bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node) ...@@ -36,8 +36,19 @@ bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
// Op supported by backend. Do not decompose // Op supported by backend. Do not decompose
return modified; return modified;
} }
// Capture the input values as a base for provenance
OutputVector base_input_values;
for (auto value : node->input_values())
{
base_input_values.push_back(value);
}
auto subgraph_outputs = node->decompose_op(); auto subgraph_outputs = node->decompose_op();
// Transfer the new provenance tags to the newly created ops
auto provenance_tags = node->get_provenance_tags();
for (auto subgraph : subgraph_outputs)
{
subgraph->add_provenance_tags_above(base_input_values, provenance_tags);
}
// Run recursively untill no more fused ops // Run recursively untill no more fused ops
auto subgraph = extract_subgraph(subgraph_outputs, node->get_arguments()); auto subgraph = extract_subgraph(subgraph_outputs, node->get_arguments());
for (auto subgraph_node : subgraph) for (auto subgraph_node : subgraph)
......
...@@ -22,14 +22,18 @@ ...@@ -22,14 +22,18 @@
#include "gmock/gmock.h" #include "gmock/gmock.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/builder/norm.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/provenance.hpp" #include "ngraph/provenance.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
using ::testing::Return; using ::testing::Return;
using ProvSet = std::set<std::string>; using ProvSet = std::unordered_set<std::string>;
TEST(provenance, provenance) TEST(provenance, provenance)
{ {
...@@ -313,3 +317,127 @@ TEST(provenance, provenance) ...@@ -313,3 +317,127 @@ TEST(provenance, provenance)
EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_c", "tag_e"})); EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_c", "tag_e"}));
} }
} }
TEST(provenance, add_group_above)
{
auto p1 = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
p1->add_provenance_tag("P1");
auto p2 = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
p2->add_provenance_tag("P2");
auto a1 = p1 + p2;
auto m1 = (a1 * a1)->add_provenance_group_members_above({p1, p2});
m1->add_provenance_tag("m1");
EXPECT_EQ(p1->get_provenance_tags(), (ProvSet{"P1"}));
EXPECT_EQ(p2->get_provenance_tags(), (ProvSet{"P2"}));
EXPECT_EQ(a1->get_provenance_tags(), (ProvSet{"m1"}));
EXPECT_EQ(m1->get_provenance_tags(), (ProvSet{"m1"}));
}
TEST(provenance, builder)
{
auto p1 = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
p1->add_provenance_tag("P1");
auto norm = builder::lp_norm(p1, {0}, 1, 0);
norm->add_provenance_tag("norm");
for (auto node : topological_sort(NodeVector{norm}))
{
if (node == p1)
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"P1"}));
}
else
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"norm"}));
}
}
}
TEST(provenance, fused)
{
auto p1 = make_shared<op::Parameter>(element::f32, PartialShape{2, 3, 4});
p1->add_provenance_tag("P1");
auto g = make_shared<op::Gelu>(p1);
g->add_provenance_tag("G");
auto r = make_shared<op::Result>(g);
auto f = make_shared<Function>(ResultVector{r}, ParameterVector{p1});
pass::Manager manager;
manager.register_pass<pass::FusedOpDecomposition>();
manager.run_passes(f);
traverse_nodes(f, [&](const std::shared_ptr<Node>& node) {
if (node == p1)
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"P1"}));
}
else if (node == r)
{
}
else
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"G"}));
}
});
}
TEST(provenance, topk_setk)
{
auto p1 = make_shared<op::Parameter>(element::f32, PartialShape{20, 3, 4});
p1->add_provenance_tag("P1");
auto tk = make_shared<op::TopK>(p1, 0, element::i32, 10);
tk->add_provenance_tag("TK");
auto tkc0 = tk->input_value(1).get_node_shared_ptr();
tkc0->add_provenance_tag("TKC0");
for (auto node : topological_sort(NodeVector{tk}))
{
if (node == p1)
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"P1"}));
}
else if (node == tkc0)
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"TK", "TKC0"}));
}
else
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"TK"}));
}
}
tk->set_k(5);
auto tkc1 = tk->input_value(1).get_node_shared_ptr();
tkc1->add_provenance_tag("TKC1");
for (auto node : topological_sort(NodeVector{tk}))
{
if (node == p1)
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"P1"}));
}
else if (node == tkc1)
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"TK", "TKC0", "TKC1"}));
}
else
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"TK"}));
}
}
}
TEST(provenance, empty_group)
{
auto p1 = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
p1->add_provenance_tag("P1");
auto abs = make_shared<op::Abs>(p1);
// Make sure group is empty
abs->add_provenance_group_members_above({abs});
abs->add_provenance_tag("abs");
for (auto node : topological_sort(NodeVector{abs}))
{
if (node == p1)
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"P1"}));
}
else
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"abs"}));
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment