Unverified Commit 3f672f08 authored by Scott Cyphers and committed by GitHub

Provenance for builders (#3644)

* Provenance for builders

* Tests

* Tests

* Update src/ngraph/node.hpp
Co-Authored-By: Sayantan Sarkar <sayantan.sarkar@intel.com>

* style
parent 02909e48
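Every hunk below follows the same pattern: the builder constructs its result exactly as before, then calls add_provenance_group_members_above on the node it returns, passing the builder's own input values as the base. Any provenance tag added to the returned node afterwards is then forwarded to the intermediate nodes the builder created along the way. A minimal sketch of the pattern, with a hypothetical builder name and ops that are not part of this commit (assumes the usual ngraph headers and using namespace ngraph):

// Illustrative only; "hypothetical_builder" is not part of this commit.
std::shared_ptr<Node> hypothetical_builder(const Output<Node>& value)
{
    // Intermediate nodes the builder creates on the caller's behalf.
    auto abs_value = std::make_shared<op::Abs>(value);
    auto reduced = std::make_shared<op::Sum>(abs_value, AxisSet{0});
    // Record every node between the result and the base {value} as a provenance
    // group member, so tags added to the returned node later also reach them.
    return reduced->add_provenance_group_members_above({value});
}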
......@@ -172,7 +172,7 @@ namespace ngraph
return_value, final_shape, broadcast_axes);
}
return return_value.get_node_shared_ptr();
return return_value.get_node_shared_ptr()->add_provenance_group_members_above({value});
}
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>
......
......@@ -103,7 +103,7 @@ namespace ngraph
val = std::make_shared<ngraph::op::Broadcast>(val, shape, axes);
}
return val;
return val->add_provenance_group_members_above({});
}
}
}
......@@ -64,7 +64,8 @@ namespace ngraph
values->get_shape(),
vector<float>(shape_size(values->get_shape()), 1.f / p_norm));
return {make_shared<op::Power>(values, inv_p_node)};
return {make_shared<op::Power>(values, inv_p_node)
->add_provenance_group_members_above({value})};
}
}
......@@ -80,7 +81,8 @@ namespace ngraph
shared_ptr<Node> non_zero_values = make_shared<op::Convert>(
make_shared<op::NotEqual>(value, zero_node), value.get_element_type());
return make_shared<op::Sum>(non_zero_values, reduction_axes);
return make_shared<op::Sum>(non_zero_values, reduction_axes)
->add_provenance_group_members_above({value});
}
shared_ptr<Node>
......@@ -94,7 +96,7 @@ namespace ngraph
values->get_shape(),
vector<float>(shape_size(values->get_shape()), bias))};
return values + bias_node;
return (values + bias_node)->add_provenance_group_members_above({value});
}
shared_ptr<Node>
......@@ -107,7 +109,8 @@ namespace ngraph
values->get_shape(),
vector<float>(shape_size(values->get_shape()), bias))};
return {make_shared<op::Sqrt>(values + bias_node)};
return {make_shared<op::Sqrt>(values + bias_node)
->add_provenance_group_members_above({value})};
}
shared_ptr<Node> lp_norm(const Output<Node>& value,
......
......@@ -74,7 +74,8 @@ namespace ngraph
out_shape.push_back(in_shape[order[i]]);
// do the reshaping with the order
return std::make_shared<ngraph::op::Reshape>(value, order, out_shape);
return std::make_shared<ngraph::op::Reshape>(value, order, out_shape)
->add_provenance_group_members_above({value});
}
} // namespace builder
......
......@@ -60,7 +60,8 @@ namespace ngraph
auto zero = make_constant(quant_type, shape, 0);
auto scale = quantization_util::get_scale(min, max, quant_type, true);
return make_shared<op::Quantize>(input, scale, zero, quant_type, axes, round_mode);
return make_shared<op::Quantize>(input, scale, zero, quant_type, axes, round_mode)
->add_provenance_group_members_above({input, min, max});
}
shared_ptr<Node> ScaledDequantize(const Output<Node>& input,
......@@ -89,7 +90,8 @@ namespace ngraph
auto zero = make_constant(quant_type, shape, 0);
auto scale = quantization_util::get_scale(min, max, quant_type);
return make_shared<op::Dequantize>(input, scale, zero, real_type, axes);
return make_shared<op::Dequantize>(input, scale, zero, real_type, axes)
->add_provenance_group_members_above({input, min, max});
}
shared_ptr<Node> ScaledQuantizedConcat(const NodeVector& args,
......@@ -123,8 +125,21 @@ namespace ngraph
AxisSet{},
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
}
return make_shared<op::QuantizedConcat>(rescaled_args, concatenation_axis);
OutputVector base;
for (auto node : args)
{
base.push_back(node);
}
for (auto node : mins)
{
base.push_back(node);
}
for (auto node : maxs)
{
base.push_back(node);
}
return make_shared<op::QuantizedConcat>(rescaled_args, concatenation_axis)
->add_provenance_group_members_above(base);
}
shared_ptr<Node> ScaledQuantizedAvgPool(const Output<Node>& input,
......@@ -141,7 +156,8 @@ namespace ngraph
window_movement_strides,
padding_below,
padding_above,
include_padding_in_avg_computation);
include_padding_in_avg_computation)
->add_provenance_group_members_above({input});
}
shared_ptr<Node> ScaledQuantizedConvolutionBias(const Output<Node>& input,
......@@ -187,7 +203,16 @@ namespace ngraph
padding_above,
data_dilation_strides,
requantization_scale,
with_relu);
with_relu)
->add_provenance_group_members_above({input,
filters,
bias,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output});
}
shared_ptr<Node> ScaledQuantizedConvolutionRelu(const Output<Node>& input,
......@@ -214,7 +239,15 @@ namespace ngraph
padding_below,
padding_above,
data_dilation_strides,
requantization_scale);
requantization_scale)
->add_provenance_group_members_above({input,
filters,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output});
}
shared_ptr<Node> ScaledQuantizedMaxPool(const Output<Node>& input,
......@@ -226,7 +259,8 @@ namespace ngraph
const Output<Node>& /* max */)
{
return make_shared<op::QuantizedMaxPool>(
input, window_shape, window_movement_strides, padding_below, padding_above);
input, window_shape, window_movement_strides, padding_below, padding_above)
->add_provenance_group_members_above({input});
}
shared_ptr<Node> ScaledQuantizedConvolutionBiasAdd(const Output<Node>& input,
......@@ -280,7 +314,19 @@ namespace ngraph
data_dilation_strides,
requantization_scale,
sum_scale,
with_relu);
with_relu)
->add_provenance_group_members_above({input,
filters,
bias,
sum_input,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output,
min_sum_input,
max_sum_input});
}
shared_ptr<Node>
......@@ -340,7 +386,19 @@ namespace ngraph
data_dilation_strides,
requantization_scale,
sum_scale,
with_relu);
with_relu)
->add_provenance_group_members_above({input,
filters,
bias,
sum_input,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output,
min_sum_input,
max_sum_input});
return make_shared<op::Convert>(qconv, element::u8);
}
......@@ -381,7 +439,9 @@ namespace ngraph
bias, bias_scale, zero, element::i32, quantization_axes, round_mode);
}
return make_shared<op::QuantizedDotBias>(
input, filters, mybias, requantization_scale, requantize, with_relu);
input, filters, mybias, requantization_scale, requantize, with_relu)
->add_provenance_group_members_above(
{min_input, max_input, min_filter, max_filter, min_output, max_output});
}
shared_ptr<Node> ScaledQuantizedDot(const Output<Node>& input,
......@@ -406,7 +466,15 @@ namespace ngraph
with_relu ? element::u8 : element::i8,
requantize);
return make_shared<op::QuantizedDot>(
input, filters, requantization_scale, requantize, with_relu);
input, filters, requantization_scale, requantize, with_relu)
->add_provenance_group_members_above({input,
filters,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output});
}
} // namespace builder
} // namespace ngraph
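For the multi-input quantized builders above, the base handed to add_provenance_group_members_above is every value the builder consumed (data, mins, maxs), which is why ScaledQuantizedConcat first flattens its three argument lists into a single OutputVector. A hedged usage sketch, assuming this file's ScaledQuantize builder takes (input, min, max, quant_type, axes, round_mode) as the first hunk of the file suggests:

// Illustrative only; assumes the usual ngraph headers and using namespace ngraph.
auto input = std::make_shared<op::Parameter>(element::f32, Shape{1, 3, 4, 4});
auto min = std::make_shared<op::Parameter>(element::f32, Shape{});
auto max = std::make_shared<op::Parameter>(element::f32, Shape{});
auto q = builder::ScaledQuantize(input, min, max, element::u8, AxisSet{},
                                 op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
q->add_provenance_tag("quantize");
// "quantize" also lands on the zero constant and the scale subgraph the builder
// created internally; input, min, and max keep only their own tags.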
......@@ -63,7 +63,6 @@ namespace ngraph
mybias = make_shared<op::Quantize>(
bias, bias_scale, zero, element::i32, quantization_axes, round_mode);
}
return make_shared<op::QuantizedConvolutionBias>(input,
filter,
mybias,
......@@ -73,7 +72,9 @@ namespace ngraph
padding_above,
data_dilation_strides,
requantization_scale,
false);
false)
->add_provenance_group_members_above(
{input, filter, bias, input_scale, filter_scale, output_scale});
}
}
}
......
......@@ -78,7 +78,16 @@ namespace ngraph
output_zero_point,
output_zero_point.get_element_type(),
axes,
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN)
->add_provenance_group_members_above({input0,
input1,
input0_scale,
input0_zero_point,
input1_scale,
input1_zero_point,
output_scale,
output_zero_point});
}
}
......@@ -86,7 +95,8 @@ namespace ngraph
const Output<Node>& input1)
{
auto output_scale = make_constant(element::f32, Shape{}, 1);
return make_shared<op::QuantizedDot>(input0, input1, output_scale, false, false);
return make_shared<op::QuantizedDot>(input0, input1, output_scale, false, false)
->add_provenance_group_members_above({input0, input1});
}
shared_ptr<Node> QuantizedLinearMatmulInteger(const Output<Node>& input0,
......@@ -129,7 +139,9 @@ namespace ngraph
output_zero_point,
output_zero_point->get_element_type(),
axes,
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN)
->add_provenance_group_members_above(
{input0, input1, input0_zero_point, input1_zero_point});
}
}
}
......
......@@ -26,7 +26,8 @@ namespace ngraph
{
auto abs_a = std::make_shared<op::Abs>(a);
auto abs_b = std::make_shared<op::Abs>(b);
return std::make_shared<op::Maximum>(abs_a, abs_b);
return std::make_shared<op::Maximum>(abs_a, abs_b)
->add_provenance_group_members_above({a, b});
}
std::shared_ptr<Node> get_scale(const Output<Node>& input_min_range,
......@@ -72,7 +73,8 @@ namespace ngraph
auto max_abs_range = max_abs(min_range, max_range);
auto target_range = make_constant(type, shape, range);
return max_abs_range / target_range;
return (max_abs_range / target_range)
->add_provenance_group_members_above({input_min_range, input_max_range});
}
}
}
......
......@@ -72,7 +72,15 @@ namespace ngraph
output_type,
input_axes,
filter_axes,
output_axes);
output_axes)
->add_provenance_group_members_above({input,
filters,
min_input,
max_input,
min_filter,
max_filter,
min_output,
max_output});
}
}
}
......@@ -47,7 +47,7 @@ namespace ngraph
auto x2 = node * node;
auto x2sum = std::make_shared<op::Sum>(x2, reduction_axes);
return std::make_shared<op::Sqrt>(x2sum);
return std::make_shared<op::Sqrt>(x2sum)->add_provenance_group_members_above({node});
}
std::shared_ptr<Node> mean(const Output<Node>& value, const AxisSet& reduction_axes)
......@@ -59,14 +59,15 @@ namespace ngraph
auto divisor = op::Constant::create(et, xsum->get_shape(), {N});
return xsum / divisor;
return (xsum / divisor)->add_provenance_group_members_above({value});
}
std::shared_ptr<Node> std_dev(const Output<Node>& node,
const AxisSet& reduction_axes,
const bool bessel_correction)
{
return std::make_shared<op::Sqrt>(variance(node, reduction_axes, bessel_correction));
return std::make_shared<op::Sqrt>(variance(node, reduction_axes, bessel_correction))
->add_provenance_group_members_above({node});
}
// This currently calculates [E[X^2] - E[X]^2] instead of [E[(X-\mu)^2]]
......@@ -96,16 +97,18 @@ namespace ngraph
const auto& et = value.get_element_type();
auto N = get_num_elements(value.get_shape(), reduction_axes);
std::shared_ptr<Node> result;
if (bessel_correction)
{
auto N1const = op::Constant::create(et, diff.get_shape(), {N - 1});
return diff / N1const;
result = diff / N1const;
}
else
{
auto Nconst = op::Constant::create(et, diff.get_shape(), {N});
return diff / Nconst;
result = diff / Nconst;
}
return result->add_provenance_group_members_above({value});
}
} // namespace builder
......
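For reference, the identity behind the "E[X^2] - E[X]^2" comment in the variance builder above, and the divisor used when bessel_correction is set:

\[ \operatorname{Var}(X) \;=\; E\big[(X-\mu)^2\big] \;=\; E[X^2] - E[X]^2 \]
\[ s^2 \;=\; \frac{1}{N-1}\sum_{i=1}^{N}\big(x_i-\bar{x}\big)^2 \qquad \text{(divide by } N \text{ when bessel\_correction is false)} \]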
......@@ -29,7 +29,8 @@ using namespace std;
shared_ptr<Node> builder::reshape(const Output<Node>& value, const Shape& shape)
{
return make_shared<op::Reshape>(value, get_default_order(value.get_shape().size()), shape);
return make_shared<op::Reshape>(value, get_default_order(value.get_shape().size()), shape)
->add_provenance_group_members_above({value});
}
shared_ptr<Node> builder::reorder_axes(const Output<Node>& value, vector<size_t> axes_order)
......@@ -49,7 +50,8 @@ shared_ptr<Node> builder::reorder_axes(const Output<Node>& value, vector<size_t>
}
auto axis_vector = AxisVector{begin(axes_order), end(axes_order)};
return make_shared<op::Reshape>(value, axis_vector, out_shape);
return make_shared<op::Reshape>(value, axis_vector, out_shape)
->add_provenance_group_members_above({value});
}
shared_ptr<Node> builder::transpose(const Output<Node>& value)
......@@ -73,5 +75,6 @@ shared_ptr<Node> builder::flatten(const Output<Node>& value, int axis)
accumulate(next(begin(data_shape), axis), end(data_shape), 1UL, multiplies<size_t>());
return make_shared<op::Reshape>(
value, get_default_order(data_shape.size()), Shape{first_dim_size, last_dim_size});
value, get_default_order(data_shape.size()), Shape{first_dim_size, last_dim_size})
->add_provenance_group_members_above({value});
}
......@@ -41,7 +41,9 @@ namespace
upper_bounds.at(axis) =
get_valid_array_index(ends.at(index), output.get_shape().at(axis));
}
return std::make_shared<op::Slice>(output, lower_bounds, upper_bounds);
return std::static_pointer_cast<op::Slice>(
std::make_shared<op::Slice>(output, lower_bounds, upper_bounds)
->add_provenance_group_members_above({output}));
}
}
......
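The static_pointer_cast in the hunk above is only needed because add_provenance_group_members_above is declared to return std::shared_ptr<Node> (see node.hpp further down), while this internal helper returns a std::shared_ptr<op::Slice>. An equivalent shape that keeps the typed pointer, shown purely as a sketch (make_tracked_slice is a hypothetical name, not part of the commit):

std::shared_ptr<op::Slice> make_tracked_slice(const Output<Node>& output,
                                              const Coordinate& lower_bounds,
                                              const Coordinate& upper_bounds)
{
    auto slice = std::make_shared<op::Slice>(output, lower_bounds, upper_bounds);
    // Ignore the Node-typed return value; the group is still owned by "slice".
    slice->add_provenance_group_members_above({output});
    return slice;
}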
......@@ -96,7 +96,8 @@ namespace ngraph
convert_sequence, mask_shape, non_sequence_axes);
// mask = sequence_length < sequence
return std::make_shared<T>(broadcast_sequence, broadcast_sequence_lengths);
return std::make_shared<T>(broadcast_sequence, broadcast_sequence_lengths)
->add_provenance_group_members_above({sequence_lengths});
}
}
}
......@@ -318,6 +318,66 @@ void Node::set_placement_index(size_t placement)
m_placement_index = placement;
}
void Node::add_provenance_group_member(const shared_ptr<Node>& node)
{
m_provenance_group.insert(node);
}
shared_ptr<Node> Node::add_provenance_group_members_above(const OutputVector& base)
{
set<Node*> base_set;
for (auto& output : base)
{
base_set.insert(output.get_node());
}
vector<Node*> todo;
for (auto input : inputs())
{
todo.push_back(input.get_source_output().get_node());
}
while (!todo.empty())
{
Node* node = todo.back();
todo.pop_back();
if (base_set.count(node) > 0 || !node->m_provenance_group.empty())
{
continue;
}
add_provenance_group_member(node->shared_from_this());
for (auto input : node->inputs())
{
todo.push_back(input.get_source_output().get_node());
}
base_set.insert(node);
}
return shared_from_this();
}
void Node::add_provenance_tags_above(const OutputVector& base, const std::set<std::string>& tag_set)
{
set<Node*> base_set;
for (auto& output : base)
{
base_set.insert(output.get_node());
}
vector<Node*> todo{this};
while (!todo.empty())
{
Node* node = todo.back();
todo.pop_back();
if (base_set.count(node) > 0)
{
continue;
}
node->add_provenance_tags(tag_set);
for (auto input : node->inputs())
{
todo.push_back(input.get_source_output().get_node());
}
base_set.insert(node);
}
}
const std::set<std::string>& Node::get_provenance_tags() const
{
return m_provenance_tags;
......@@ -326,6 +386,10 @@ const std::set<std::string>& Node::get_provenance_tags() const
void Node::add_provenance_tag(const std::string& tag)
{
m_provenance_tags.insert(tag);
for (auto node : m_provenance_group)
{
node->add_provenance_tag(tag);
}
}
void Node::add_provenance_tags(const std::set<std::string>& tag_set)
......
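The two traversals added to node.cpp above differ in when tags are applied. add_provenance_group_members_above only records membership: starting from this node's inputs, the walk stops at the base set (or at nodes that already own a provenance group) and puts everything in between into m_provenance_group; add_provenance_tag then forwards each future tag to those members. add_provenance_tags_above instead applies tag_set immediately to every node between this node and the base. A minimal sketch of the first behaviour, mirroring the add_group_above test at the bottom of this diff (assumes the usual ngraph headers and using namespace ngraph):

auto p1 = std::make_shared<op::Parameter>(element::i32, Shape{2, 3, 4});
auto p2 = std::make_shared<op::Parameter>(element::i32, Shape{2, 3, 4});
auto a1 = p1 + p2;                 // intermediate node above the base
auto m1 = (a1 * a1)->add_provenance_group_members_above({p1, p2});
m1->add_provenance_tag("m1");      // forwarded to a1; p1 and p2 are untouched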
......@@ -372,7 +372,14 @@ namespace ngraph
const std::set<std::string>& get_provenance_tags() const;
void add_provenance_tag(const std::string& tag);
void add_provenance_tags(const std::set<std::string>& tag_set);
/// \brief Adds tag_set to this node and all intermediate nodes above base
void add_provenance_tags_above(const OutputVector& base,
const std::set<std::string>& tag_set);
void remove_provenance_tag(const std::string& tag);
/// \brief Add node to additional nodes that receive tags
void add_provenance_group_member(const std::shared_ptr<Node>& node);
/// \brief Add all nodes between this node and nodes in base as additional nodes to receive provenance tags.
std::shared_ptr<Node> add_provenance_group_members_above(const OutputVector& base);
// to be used when nodes are replaced
void merge_provenance_tags_from(const std::shared_ptr<const Node>& source);
......@@ -428,6 +435,7 @@ namespace ngraph
NGRAPH_API
static std::atomic<size_t> m_next_instance_id;
std::set<std::string> m_provenance_tags;
std::set<std::shared_ptr<Node>> m_provenance_group;
std::deque<descriptor::Input> m_inputs;
std::deque<descriptor::Output> m_outputs;
std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map;
......
......@@ -36,8 +36,19 @@ bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
// Op supported by backend. Do not decompose
return modified;
}
// Capture the input values as a base for provenance
OutputVector base_input_values;
for (auto input : fused_op->inputs())
{
base_input_values.push_back(input.get_source_output());
}
auto subgraph_outputs = fused_op->decompose_op();
// Transfer the fused op's provenance tags to the newly created ops
auto provenance_tags = fused_op->get_provenance_tags();
for (auto subgraph : subgraph_outputs)
{
subgraph->add_provenance_tags_above(base_input_values, provenance_tags);
}
// Run recursively until there are no more fused ops
auto subgraph = extract_subgraph(subgraph_outputs, fused_op->get_arguments());
for (auto subgraph_node : subgraph)
......@@ -50,7 +61,6 @@ bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
{
for (size_t j = 0; j < output_node->get_outputs().size(); j++, i++)
{
// TODO: Provenance
set<descriptor::Input*> fop_users{begin(fused_op->get_outputs().at(i).get_inputs()),
end(fused_op->get_outputs().at(i).get_inputs())};
for (auto fop_user : fop_users)
......
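The pass above snapshots the fused op's input values before decompose_op() runs, then uses them as the base for add_provenance_tags_above, so the fused op's existing tags land exactly on the nodes the decomposition introduced and not on its original inputs. A standalone sketch of that call's semantics, using an illustrative two-node chain rather than a real decomposition (assumes the usual ngraph headers and using namespace ngraph):

auto p = std::make_shared<op::Parameter>(element::f32, Shape{4});
auto a = std::make_shared<op::Abs>(p);   // stands in for a node produced by decompose_op()
auto b = std::make_shared<op::Sqrt>(a);  // stands in for the replacement output
b->add_provenance_tags_above({p}, {"fused_tag"}); // tags b and a immediately; p is untouched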
......@@ -22,7 +22,11 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/builder/norm.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/provenance.hpp"
using namespace std;
......@@ -367,3 +371,63 @@ TEST(provenance, provenance)
EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_e"}));
}
}
TEST(provenance, add_group_above)
{
auto p1 = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
p1->add_provenance_tag("P1");
auto p2 = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
p2->add_provenance_tag("P2");
auto a1 = p1 + p2;
auto m1 = (a1 * a1)->add_provenance_group_members_above({p1, p2});
m1->add_provenance_tag("m1");
EXPECT_EQ(p1->get_provenance_tags(), (ProvSet{"P1"}));
EXPECT_EQ(p2->get_provenance_tags(), (ProvSet{"P2"}));
EXPECT_EQ(a1->get_provenance_tags(), (ProvSet{"m1"}));
EXPECT_EQ(m1->get_provenance_tags(), (ProvSet{"m1"}));
}
TEST(provenance, builder)
{
auto p1 = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
p1->add_provenance_tag("P1");
auto norm = builder::lp_norm(p1, {0}, 1, 0);
norm->add_provenance_tag("norm");
for (auto node : topological_sort(NodeVector{norm}))
{
if (node == p1)
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"P1"}));
}
else
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"norm"}));
}
}
}
TEST(provenance, fused)
{
auto p1 = make_shared<op::Parameter>(element::f32, PartialShape{2, 3, 4});
p1->add_provenance_tag("P1");
auto g = make_shared<op::Gelu>(p1);
g->add_provenance_tag("G");
auto r = make_shared<op::Result>(g);
auto f = make_shared<Function>(ResultVector{r}, ParameterVector{p1});
pass::Manager manager;
manager.register_pass<pass::FusedOpDecomposition>();
manager.run_passes(f);
traverse_nodes(f, [&](const std::shared_ptr<Node>& node) {
if (node == p1)
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"P1"}));
}
else if (node == r)
{
}
else
{
EXPECT_EQ(node->get_provenance_tags(), (ProvSet{"G"}));
}
});
}