Unverified Commit 3f672f08 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Provenance for builders (#3644)

* Provenance for builders

* Tests

* Tests

* Update src/ngraph/node.hpp
Co-Authored-By: 's avatarSayantan Sarkar <sayantan.sarkar@intel.com>

* style
parent 02909e48
...@@ -172,7 +172,7 @@ namespace ngraph ...@@ -172,7 +172,7 @@ namespace ngraph
return_value, final_shape, broadcast_axes); return_value, final_shape, broadcast_axes);
} }
return return_value.get_node_shared_ptr(); return return_value.get_node_shared_ptr()->add_provenance_group_members_above({value});
} }
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>
......
...@@ -103,7 +103,7 @@ namespace ngraph ...@@ -103,7 +103,7 @@ namespace ngraph
val = std::make_shared<ngraph::op::Broadcast>(val, shape, axes); val = std::make_shared<ngraph::op::Broadcast>(val, shape, axes);
} }
return val; return val->add_provenance_group_members_above({});
} }
} }
} }
...@@ -64,7 +64,8 @@ namespace ngraph ...@@ -64,7 +64,8 @@ namespace ngraph
values->get_shape(), values->get_shape(),
vector<float>(shape_size(values->get_shape()), 1.f / p_norm)); vector<float>(shape_size(values->get_shape()), 1.f / p_norm));
return {make_shared<op::Power>(values, inv_p_node)}; return {make_shared<op::Power>(values, inv_p_node)
->add_provenance_group_members_above({value})};
} }
} }
...@@ -80,7 +81,8 @@ namespace ngraph ...@@ -80,7 +81,8 @@ namespace ngraph
shared_ptr<Node> non_zero_values = make_shared<op::Convert>( shared_ptr<Node> non_zero_values = make_shared<op::Convert>(
make_shared<op::NotEqual>(value, zero_node), value.get_element_type()); make_shared<op::NotEqual>(value, zero_node), value.get_element_type());
return make_shared<op::Sum>(non_zero_values, reduction_axes); return make_shared<op::Sum>(non_zero_values, reduction_axes)
->add_provenance_group_members_above({value});
} }
shared_ptr<Node> shared_ptr<Node>
...@@ -94,7 +96,7 @@ namespace ngraph ...@@ -94,7 +96,7 @@ namespace ngraph
values->get_shape(), values->get_shape(),
vector<float>(shape_size(values->get_shape()), bias))}; vector<float>(shape_size(values->get_shape()), bias))};
return values + bias_node; return (values + bias_node)->add_provenance_group_members_above({value});
} }
shared_ptr<Node> shared_ptr<Node>
...@@ -107,7 +109,8 @@ namespace ngraph ...@@ -107,7 +109,8 @@ namespace ngraph
values->get_shape(), values->get_shape(),
vector<float>(shape_size(values->get_shape()), bias))}; vector<float>(shape_size(values->get_shape()), bias))};
return {make_shared<op::Sqrt>(values + bias_node)}; return {make_shared<op::Sqrt>(values + bias_node)
->add_provenance_group_members_above({value})};
} }
shared_ptr<Node> lp_norm(const Output<Node>& value, shared_ptr<Node> lp_norm(const Output<Node>& value,
......
...@@ -74,7 +74,8 @@ namespace ngraph ...@@ -74,7 +74,8 @@ namespace ngraph
out_shape.push_back(in_shape[order[i]]); out_shape.push_back(in_shape[order[i]]);
// do the reshaping with the order // do the reshaping with the order
return std::make_shared<ngraph::op::Reshape>(value, order, out_shape); return std::make_shared<ngraph::op::Reshape>(value, order, out_shape)
->add_provenance_group_members_above({value});
} }
} // namespace builder } // namespace builder
......
...@@ -60,7 +60,8 @@ namespace ngraph ...@@ -60,7 +60,8 @@ namespace ngraph
auto zero = make_constant(quant_type, shape, 0); auto zero = make_constant(quant_type, shape, 0);
auto scale = quantization_util::get_scale(min, max, quant_type, true); auto scale = quantization_util::get_scale(min, max, quant_type, true);
return make_shared<op::Quantize>(input, scale, zero, quant_type, axes, round_mode); return make_shared<op::Quantize>(input, scale, zero, quant_type, axes, round_mode)
->add_provenance_group_members_above({input, min, max});
} }
shared_ptr<Node> ScaledDequantize(const Output<Node>& input, shared_ptr<Node> ScaledDequantize(const Output<Node>& input,
...@@ -89,7 +90,8 @@ namespace ngraph ...@@ -89,7 +90,8 @@ namespace ngraph
auto zero = make_constant(quant_type, shape, 0); auto zero = make_constant(quant_type, shape, 0);
auto scale = quantization_util::get_scale(min, max, quant_type); auto scale = quantization_util::get_scale(min, max, quant_type);
return make_shared<op::Dequantize>(input, scale, zero, real_type, axes); return make_shared<op::Dequantize>(input, scale, zero, real_type, axes)
->add_provenance_group_members_above({input, min, max});
} }
shared_ptr<Node> ScaledQuantizedConcat(const NodeVector& args, shared_ptr<Node> ScaledQuantizedConcat(const NodeVector& args,
...@@ -123,8 +125,21 @@ namespace ngraph ...@@ -123,8 +125,21 @@ namespace ngraph
AxisSet{}, AxisSet{},
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN); op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
} }
OutputVector base;
return make_shared<op::QuantizedConcat>(rescaled_args, concatenation_axis); for (auto node : args)
{
base.push_back(node);
}
for (auto node : mins)
{
base.push_back(node);
}
for (auto node : maxs)
{
base.push_back(node);
}
return make_shared<op::QuantizedConcat>(rescaled_args, concatenation_axis)
->add_provenance_group_members_above(base);
} }
shared_ptr<Node> ScaledQuantizedAvgPool(const Output<Node>& input, shared_ptr<Node> ScaledQuantizedAvgPool(const Output<Node>& input,
...@@ -141,7 +156,8 @@ namespace ngraph ...@@ -141,7 +156,8 @@ namespace ngraph
window_movement_strides, window_movement_strides,
padding_below, padding_below,
padding_above, padding_above,
include_padding_in_avg_computation); include_padding_in_avg_computation)
->add_provenance_group_members_above({input});
} }
shared_ptr<Node> ScaledQuantizedConvolutionBias(const Output<Node>& input, shared_ptr<Node> ScaledQuantizedConvolutionBias(const Output<Node>& input,
...@@ -187,7 +203,16 @@ namespace ngraph ...@@ -187,7 +203,16 @@ namespace ngraph
padding_above, padding_above,
data_dilation_strides, data_dilation_strides,
requantization_scale, requantization_scale,
with_relu); with_relu)
->add_provenance_group_members_above({input,
filters,
bias,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output});
} }
shared_ptr<Node> ScaledQuantizedConvolutionRelu(const Output<Node>& input, shared_ptr<Node> ScaledQuantizedConvolutionRelu(const Output<Node>& input,
...@@ -214,7 +239,15 @@ namespace ngraph ...@@ -214,7 +239,15 @@ namespace ngraph
padding_below, padding_below,
padding_above, padding_above,
data_dilation_strides, data_dilation_strides,
requantization_scale); requantization_scale)
->add_provenance_group_members_above({input,
filters,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output});
} }
shared_ptr<Node> ScaledQuantizedMaxPool(const Output<Node>& input, shared_ptr<Node> ScaledQuantizedMaxPool(const Output<Node>& input,
...@@ -226,7 +259,8 @@ namespace ngraph ...@@ -226,7 +259,8 @@ namespace ngraph
const Output<Node>& /* max */) const Output<Node>& /* max */)
{ {
return make_shared<op::QuantizedMaxPool>( return make_shared<op::QuantizedMaxPool>(
input, window_shape, window_movement_strides, padding_below, padding_above); input, window_shape, window_movement_strides, padding_below, padding_above)
->add_provenance_group_members_above({input});
} }
shared_ptr<Node> ScaledQuantizedConvolutionBiasAdd(const Output<Node>& input, shared_ptr<Node> ScaledQuantizedConvolutionBiasAdd(const Output<Node>& input,
...@@ -280,7 +314,19 @@ namespace ngraph ...@@ -280,7 +314,19 @@ namespace ngraph
data_dilation_strides, data_dilation_strides,
requantization_scale, requantization_scale,
sum_scale, sum_scale,
with_relu); with_relu)
->add_provenance_group_members_above({input,
filters,
bias,
sum_input,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output,
min_sum_input,
max_sum_input});
} }
shared_ptr<Node> shared_ptr<Node>
...@@ -340,7 +386,19 @@ namespace ngraph ...@@ -340,7 +386,19 @@ namespace ngraph
data_dilation_strides, data_dilation_strides,
requantization_scale, requantization_scale,
sum_scale, sum_scale,
with_relu); with_relu)
->add_provenance_group_members_above({input,
filters,
bias,
sum_input,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output,
min_sum_input,
max_sum_input});
return make_shared<op::Convert>(qconv, element::u8); return make_shared<op::Convert>(qconv, element::u8);
} }
...@@ -381,7 +439,9 @@ namespace ngraph ...@@ -381,7 +439,9 @@ namespace ngraph
bias, bias_scale, zero, element::i32, quantization_axes, round_mode); bias, bias_scale, zero, element::i32, quantization_axes, round_mode);
} }
return make_shared<op::QuantizedDotBias>( return make_shared<op::QuantizedDotBias>(
input, filters, mybias, requantization_scale, requantize, with_relu); input, filters, mybias, requantization_scale, requantize, with_relu)
->add_provenance_group_members_above(
{min_input, max_input, min_filter, max_filter, min_output, max_output});
} }
shared_ptr<Node> ScaledQuantizedDot(const Output<Node>& input, shared_ptr<Node> ScaledQuantizedDot(const Output<Node>& input,
...@@ -406,7 +466,15 @@ namespace ngraph ...@@ -406,7 +466,15 @@ namespace ngraph
with_relu ? element::u8 : element::i8, with_relu ? element::u8 : element::i8,
requantize); requantize);
return make_shared<op::QuantizedDot>( return make_shared<op::QuantizedDot>(
input, filters, requantization_scale, requantize, with_relu); input, filters, requantization_scale, requantize, with_relu)
->add_provenance_group_members_above({input,
filters,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output});
} }
} // namespace builder } // namespace builder
} // namespace ngraph } // namespace ngraph
...@@ -63,7 +63,6 @@ namespace ngraph ...@@ -63,7 +63,6 @@ namespace ngraph
mybias = make_shared<op::Quantize>( mybias = make_shared<op::Quantize>(
bias, bias_scale, zero, element::i32, quantization_axes, round_mode); bias, bias_scale, zero, element::i32, quantization_axes, round_mode);
} }
return make_shared<op::QuantizedConvolutionBias>(input, return make_shared<op::QuantizedConvolutionBias>(input,
filter, filter,
mybias, mybias,
...@@ -73,7 +72,9 @@ namespace ngraph ...@@ -73,7 +72,9 @@ namespace ngraph
padding_above, padding_above,
data_dilation_strides, data_dilation_strides,
requantization_scale, requantization_scale,
false); false)
->add_provenance_group_members_above(
{input, filter, bias, input_scale, filter_scale, output_scale});
} }
} }
} }
......
...@@ -73,12 +73,21 @@ namespace ngraph ...@@ -73,12 +73,21 @@ namespace ngraph
auto dot = make_shared<op::Dot>(dq_input0, dq_input1, 1); auto dot = make_shared<op::Dot>(dq_input0, dq_input1, 1);
return make_shared<op::Quantize>( return make_shared<op::Quantize>(
dot, dot,
output_scale, output_scale,
output_zero_point, output_zero_point,
output_zero_point.get_element_type(), output_zero_point.get_element_type(),
axes, axes,
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN); op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN)
->add_provenance_group_members_above({input0,
input1,
input0_scale,
input0_zero_point,
input1_scale,
input1_scale,
input1_zero_point,
output_scale,
output_zero_point});
} }
} }
...@@ -86,7 +95,8 @@ namespace ngraph ...@@ -86,7 +95,8 @@ namespace ngraph
const Output<Node>& input1) const Output<Node>& input1)
{ {
auto output_scale = make_constant(element::f32, Shape{}, 1); auto output_scale = make_constant(element::f32, Shape{}, 1);
return make_shared<op::QuantizedDot>(input0, input1, output_scale, false, false); return make_shared<op::QuantizedDot>(input0, input1, output_scale, false, false)
->add_provenance_group_members_above({input0, input1});
} }
shared_ptr<Node> QuantizedLinearMatmulInteger(const Output<Node>& input0, shared_ptr<Node> QuantizedLinearMatmulInteger(const Output<Node>& input0,
...@@ -124,12 +134,14 @@ namespace ngraph ...@@ -124,12 +134,14 @@ namespace ngraph
const auto dot = make_shared<op::Dot>(dq_input0, dq_input1, 1); const auto dot = make_shared<op::Dot>(dq_input0, dq_input1, 1);
return make_shared<op::Quantize>( return make_shared<op::Quantize>(
dot, dot,
output_scale, output_scale,
output_zero_point, output_zero_point,
output_zero_point->get_element_type(), output_zero_point->get_element_type(),
axes, axes,
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN); op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN)
->add_provenance_group_members_above(
{input0, input1, input0_zero_point, input1_zero_point});
} }
} }
} }
......
...@@ -26,7 +26,8 @@ namespace ngraph ...@@ -26,7 +26,8 @@ namespace ngraph
{ {
auto abs_a = std::make_shared<op::Abs>(a); auto abs_a = std::make_shared<op::Abs>(a);
auto abs_b = std::make_shared<op::Abs>(b); auto abs_b = std::make_shared<op::Abs>(b);
return std::make_shared<op::Maximum>(abs_a, abs_b); return std::make_shared<op::Maximum>(abs_a, abs_b)
->add_provenance_group_members_above({a, b});
} }
std::shared_ptr<Node> get_scale(const Output<Node>& input_min_range, std::shared_ptr<Node> get_scale(const Output<Node>& input_min_range,
...@@ -72,7 +73,8 @@ namespace ngraph ...@@ -72,7 +73,8 @@ namespace ngraph
auto max_abs_range = max_abs(min_range, max_range); auto max_abs_range = max_abs(min_range, max_range);
auto target_range = make_constant(type, shape, range); auto target_range = make_constant(type, shape, range);
return max_abs_range / target_range; return (max_abs_range / target_range)
->add_provenance_group_members_above({input_min_range, input_max_range});
} }
} }
} }
......
...@@ -56,23 +56,31 @@ namespace ngraph ...@@ -56,23 +56,31 @@ namespace ngraph
auto filter_zero_point = op::Constant::create(filters.get_element_type(), Shape{}, {0}); auto filter_zero_point = op::Constant::create(filters.get_element_type(), Shape{}, {0});
return make_shared<op::QuantizedConvolution>( return make_shared<op::QuantizedConvolution>(
input, input,
filters, filters,
window_movement_strides, window_movement_strides,
window_dilation_strides, window_dilation_strides,
padding_below, padding_below,
padding_above, padding_above,
data_dilation_strides, data_dilation_strides,
input_scale, input_scale,
input_zero_point, input_zero_point,
filter_scale, filter_scale,
filter_zero_point, filter_zero_point,
output_scale, output_scale,
filter_zero_point, // output type will be same as filter filter_zero_point, // output type will be same as filter
output_type, output_type,
input_axes, input_axes,
filter_axes, filter_axes,
output_axes); output_axes)
->add_provenance_group_members_above({input,
filters,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output});
} }
} }
} }
...@@ -47,7 +47,7 @@ namespace ngraph ...@@ -47,7 +47,7 @@ namespace ngraph
auto x2 = node * node; auto x2 = node * node;
auto x2sum = std::make_shared<op::Sum>(x2, reduction_axes); auto x2sum = std::make_shared<op::Sum>(x2, reduction_axes);
return std::make_shared<op::Sqrt>(x2sum); return std::make_shared<op::Sqrt>(x2sum)->add_provenance_group_members_above({node});
} }
std::shared_ptr<Node> mean(const Output<Node>& value, const AxisSet& reduction_axes) std::shared_ptr<Node> mean(const Output<Node>& value, const AxisSet& reduction_axes)
...@@ -59,14 +59,15 @@ namespace ngraph ...@@ -59,14 +59,15 @@ namespace ngraph
auto divisor = op::Constant::create(et, xsum->get_shape(), {N}); auto divisor = op::Constant::create(et, xsum->get_shape(), {N});
return xsum / divisor; return (xsum / divisor)->add_provenance_group_members_above({value});
} }
std::shared_ptr<Node> std_dev(const Output<Node>& node, std::shared_ptr<Node> std_dev(const Output<Node>& node,
const AxisSet& reduction_axes, const AxisSet& reduction_axes,
const bool bessel_correction) const bool bessel_correction)
{ {
return std::make_shared<op::Sqrt>(variance(node, reduction_axes, bessel_correction)); return std::make_shared<op::Sqrt>(variance(node, reduction_axes, bessel_correction))
->add_provenance_group_members_above({node});
} }
// This currently calculates [E[X^2] - E[X]^2] instead of [E[(X-\mu)^2]] // This currently calculates [E[X^2] - E[X]^2] instead of [E[(X-\mu)^2]]
...@@ -96,16 +97,18 @@ namespace ngraph ...@@ -96,16 +97,18 @@ namespace ngraph
const auto& et = value.get_element_type(); const auto& et = value.get_element_type();
auto N = get_num_elements(value.get_shape(), reduction_axes); auto N = get_num_elements(value.get_shape(), reduction_axes);
std::shared_ptr<Node> result;
if (bessel_correction) if (bessel_correction)
{ {
auto N1const = op::Constant::create(et, diff.get_shape(), {N - 1}); auto N1const = op::Constant::create(et, diff.get_shape(), {N - 1});
return diff / N1const; result = diff / N1const;
} }
else else
{ {
auto Nconst = op::Constant::create(et, diff.get_shape(), {N}); auto Nconst = op::Constant::create(et, diff.get_shape(), {N});
return diff / Nconst; result = diff / Nconst;
} }
return result->add_provenance_group_members_above({value});
} }
} // namespace builder } // namespace builder
......
...@@ -29,7 +29,8 @@ using namespace std; ...@@ -29,7 +29,8 @@ using namespace std;
shared_ptr<Node> builder::reshape(const Output<Node>& value, const Shape& shape) shared_ptr<Node> builder::reshape(const Output<Node>& value, const Shape& shape)
{ {
return make_shared<op::Reshape>(value, get_default_order(value.get_shape().size()), shape); return make_shared<op::Reshape>(value, get_default_order(value.get_shape().size()), shape)
->add_provenance_group_members_above({value});
} }
shared_ptr<Node> builder::reorder_axes(const Output<Node>& value, vector<size_t> axes_order) shared_ptr<Node> builder::reorder_axes(const Output<Node>& value, vector<size_t> axes_order)
...@@ -49,7 +50,8 @@ shared_ptr<Node> builder::reorder_axes(const Output<Node>& value, vector<size_t> ...@@ -49,7 +50,8 @@ shared_ptr<Node> builder::reorder_axes(const Output<Node>& value, vector<size_t>
} }
auto axis_vector = AxisVector{begin(axes_order), end(axes_order)}; auto axis_vector = AxisVector{begin(axes_order), end(axes_order)};
return make_shared<op::Reshape>(value, axis_vector, out_shape); return make_shared<op::Reshape>(value, axis_vector, out_shape)
->add_provenance_group_members_above({value});
} }
shared_ptr<Node> builder::transpose(const Output<Node>& value) shared_ptr<Node> builder::transpose(const Output<Node>& value)
...@@ -73,5 +75,6 @@ shared_ptr<Node> builder::flatten(const Output<Node>& value, int axis) ...@@ -73,5 +75,6 @@ shared_ptr<Node> builder::flatten(const Output<Node>& value, int axis)
accumulate(next(begin(data_shape), axis), end(data_shape), 1UL, multiplies<size_t>()); accumulate(next(begin(data_shape), axis), end(data_shape), 1UL, multiplies<size_t>());
return make_shared<op::Reshape>( return make_shared<op::Reshape>(
value, get_default_order(data_shape.size()), Shape{first_dim_size, last_dim_size}); value, get_default_order(data_shape.size()), Shape{first_dim_size, last_dim_size})
->add_provenance_group_members_above({value});
} }
...@@ -41,7 +41,9 @@ namespace ...@@ -41,7 +41,9 @@ namespace
upper_bounds.at(axis) = upper_bounds.at(axis) =
get_valid_array_index(ends.at(index), output.get_shape().at(axis)); get_valid_array_index(ends.at(index), output.get_shape().at(axis));
} }
return std::make_shared<op::Slice>(output, lower_bounds, upper_bounds); return std::static_pointer_cast<op::Slice>(
std::make_shared<op::Slice>(output, lower_bounds, upper_bounds)
->add_provenance_group_members_above({output}));
} }
} }
......
...@@ -96,7 +96,8 @@ namespace ngraph ...@@ -96,7 +96,8 @@ namespace ngraph
convert_sequence, mask_shape, non_sequence_axes); convert_sequence, mask_shape, non_sequence_axes);
// mask = sequence_length < sequence // mask = sequence_length < sequence
return std::make_shared<T>(broadcast_sequence, broadcast_sequence_lengths); return std::make_shared<T>(broadcast_sequence, broadcast_sequence_lengths)
->add_provenance_group_members_above({sequence_lengths});
} }
} }
} }
...@@ -318,6 +318,66 @@ void Node::set_placement_index(size_t placement) ...@@ -318,6 +318,66 @@ void Node::set_placement_index(size_t placement)
m_placement_index = placement; m_placement_index = placement;
} }
void Node::add_provenance_group_member(const shared_ptr<Node>& node)
{
m_provenance_group.insert(node);
}
shared_ptr<Node> Node::add_provenance_group_members_above(const OutputVector& base)
{
set<Node*> base_set;
for (auto& output : base)
{
base_set.insert(output.get_node());
}
vector<Node*> todo;
for (auto input : inputs())
{
todo.push_back(input.get_source_output().get_node());
}
while (!todo.empty())
{
Node* node = todo.back();
todo.pop_back();
if (base_set.count(node) > 0 || !node->m_provenance_group.empty())
{
continue;
}
add_provenance_group_member(node->shared_from_this());
for (auto input : node->inputs())
{
todo.push_back(input.get_source_output().get_node());
}
base_set.insert(node);
}
return shared_from_this();
}
void Node::add_provenance_tags_above(const OutputVector& base, const std::set<std::string>& tag_set)
{
set<Node*> base_set;
for (auto& output : base)
{
base_set.insert(output.get_node());
}
vector<Node*> todo{this};
while (!todo.empty())
{
Node* node = todo.back();
todo.pop_back();
if (base_set.count(node) > 0)
{
continue;
}
node->add_provenance_tags(tag_set);
for (auto input : node->inputs())
{
todo.push_back(input.get_source_output().get_node());
}
base_set.insert(node);
}
}
const std::set<std::string>& Node::get_provenance_tags() const const std::set<std::string>& Node::get_provenance_tags() const
{ {
return m_provenance_tags; return m_provenance_tags;
...@@ -326,6 +386,10 @@ const std::set<std::string>& Node::get_provenance_tags() const ...@@ -326,6 +386,10 @@ const std::set<std::string>& Node::get_provenance_tags() const
void Node::add_provenance_tag(const std::string& tag) void Node::add_provenance_tag(const std::string& tag)
{ {
m_provenance_tags.insert(tag); m_provenance_tags.insert(tag);
for (auto node : m_provenance_group)
{
node->add_provenance_tag(tag);
}
} }
void Node::add_provenance_tags(const std::set<std::string>& tag_set) void Node::add_provenance_tags(const std::set<std::string>& tag_set)
......
...@@ -372,7 +372,14 @@ namespace ngraph ...@@ -372,7 +372,14 @@ namespace ngraph
const std::set<std::string>& get_provenance_tags() const; const std::set<std::string>& get_provenance_tags() const;
void add_provenance_tag(const std::string& tag); void add_provenance_tag(const std::string& tag);
void add_provenance_tags(const std::set<std::string>& tag_set); void add_provenance_tags(const std::set<std::string>& tag_set);
/// \brief Adds tag_set to this node and all intermediate nodes above base
void add_provenance_tags_above(const OutputVector& base,
const std::set<std::string>& tag_set);
void remove_provenance_tag(const std::string& tag); void remove_provenance_tag(const std::string& tag);
/// \brief Add node to additional nodes that receive tags
void add_provenance_group_member(const std::shared_ptr<Node>& node);
/// \brief Add all nodes between this node and nodes in base as additional nodes to receive provenance tags.
std::shared_ptr<Node> add_provenance_group_members_above(const OutputVector& base);
// to be used when nodes are replaced // to be used when nodes are replaced
void merge_provenance_tags_from(const std::shared_ptr<const Node>& source); void merge_provenance_tags_from(const std::shared_ptr<const Node>& source);
...@@ -428,6 +435,7 @@ namespace ngraph ...@@ -428,6 +435,7 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static std::atomic<size_t> m_next_instance_id; static std::atomic<size_t> m_next_instance_id;
std::set<std::string> m_provenance_tags; std::set<std::string> m_provenance_tags;
std::set<std::shared_ptr<Node>> m_provenance_group;
std::deque<descriptor::Input> m_inputs; std::deque<descriptor::Input> m_inputs;
std::deque<descriptor::Output> m_outputs; std::deque<descriptor::Output> m_outputs;
std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map; std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map;
......
...@@ -36,8 +36,19 @@ bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node) ...@@ -36,8 +36,19 @@ bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
// Op supported by backend. Do not decompose // Op supported by backend. Do not decompose
return modified; return modified;
} }
// Capture the input values as a base for provenance
OutputVector base_input_values;
for (auto input : fused_op->inputs())
{
base_input_values.push_back(input.get_source_output());
}
auto subgraph_outputs = fused_op->decompose_op(); auto subgraph_outputs = fused_op->decompose_op();
// Transfer the new provenance tags to the newly created ops
auto provenance_tags = fused_op->get_provenance_tags();
for (auto subgraph : subgraph_outputs)
{
subgraph->add_provenance_tags_above(base_input_values, provenance_tags);
}
// Run recursively untill no more fused ops // Run recursively untill no more fused ops
auto subgraph = extract_subgraph(subgraph_outputs, fused_op->get_arguments()); auto subgraph = extract_subgraph(subgraph_outputs, fused_op->get_arguments());
for (auto subgraph_node : subgraph) for (auto subgraph_node : subgraph)
...@@ -50,7 +61,6 @@ bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node) ...@@ -50,7 +61,6 @@ bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
{ {
for (size_t j = 0; j < output_node->get_outputs().size(); j++, i++) for (size_t j = 0; j < output_node->get_outputs().size(); j++, i++)
{ {
// TODO: Provenance
set<descriptor::Input*> fop_users{begin(fused_op->get_outputs().at(i).get_inputs()), set<descriptor::Input*> fop_users{begin(fused_op->get_outputs().at(i).get_inputs()),
end(fused_op->get_outputs().at(i).get_inputs())}; end(fused_op->get_outputs().at(i).get_inputs())};
for (auto fop_user : fop_users) for (auto fop_user : fop_users)
......
...@@ -22,7 +22,11 @@ ...@@ -22,7 +22,11 @@
#include "gmock/gmock.h" #include "gmock/gmock.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/builder/norm.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/provenance.hpp" #include "ngraph/provenance.hpp"
using namespace std; using namespace std;
...@@ -367,3 +371,63 @@ TEST(provenance, provenance) ...@@ -367,3 +371,63 @@ TEST(provenance, provenance)
EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_e"})); EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_e"}));
} }
} }
TEST(provenance, add_group_above)
{
auto p1 = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
p1->add_provenance_tag("P1");
auto p2 = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
p2->add_provenance_tag("P2");
auto a1 = p1 + p2;
auto m1 = (a1 * a1)->add_provenance_group_members_above({p1, p2});
m1->add_provenance_tag("m1");
EXPECT_EQ(p1->get_provenance_tags(), (ProvSet{"P1"}));
EXPECT_EQ(p2->get_provenance_tags(), (ProvSet{"P2"}));
EXPECT_EQ(a1->get_provenance_tags(), (ProvSet{"m1"}));
EXPECT_EQ(m1->get_provenance_tags(), (ProvSet{"m1"}));
}
TEST(provenance, builder)
{
auto p1 = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
p1->add_provenance_tag("P1");
auto norm = builder::lp_norm(p1, {0}, 1, 0);
norm->add_provenance_tag("norm");
for (auto node : topological_sort(NodeVector{norm}))
{
if (node == p1)
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"P1"}));
}
else
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"norm"}));
}
}
}
TEST(provenance, fused)
{
auto p1 = make_shared<op::Parameter>(element::f32, PartialShape{2, 3, 4});
p1->add_provenance_tag("P1");
auto g = make_shared<op::Gelu>(p1);
g->add_provenance_tag("G");
auto r = make_shared<op::Result>(g);
auto f = make_shared<Function>(ResultVector{r}, ParameterVector{p1});
pass::Manager manager;
manager.register_pass<pass::FusedOpDecomposition>();
manager.run_passes(f);
traverse_nodes(f, [&](const std::shared_ptr<Node>& node) {
if (node == p1)
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"P1"}));
}
else if (node == r)
{
}
else
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"G"}));
}
});
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment