Commit 1533b97f authored by Gleb Kazantaev's avatar Gleb Kazantaev Committed by Scott Cyphers

Constant folding with shape inference; Updated specialize_function (#3802)

* Added shape inference for entire graph to constant folding pass

* Updated specialize_function

* FusedOp MVN fix

* style

* Fixed issue with export symbols

* Fixed code style
parent 0bd90a78
...@@ -36,15 +36,6 @@ op::MVN::MVN(const Output<Node>& data, bool across_channels, bool normalize_vari ...@@ -36,15 +36,6 @@ op::MVN::MVN(const Output<Node>& data, bool across_channels, bool normalize_vari
, m_normalize_variance{normalize_variance} , m_normalize_variance{normalize_variance}
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
// if m_across_channels is true we should calculate mean and variance per batch
// else we calculate these per channel
m_reduction_axes.insert(0);
size_t start_axis = m_across_channels ? 1 : 2;
for (size_t i = start_axis; i < data.get_shape().size(); ++i)
{
m_reduction_axes.insert(i);
}
} }
op::MVN::MVN(const Output<Node>& data, AxisSet reduction_axes, bool normalize_variance, double eps) op::MVN::MVN(const Output<Node>& data, AxisSet reduction_axes, bool normalize_variance, double eps)
...@@ -57,6 +48,24 @@ op::MVN::MVN(const Output<Node>& data, AxisSet reduction_axes, bool normalize_va ...@@ -57,6 +48,24 @@ op::MVN::MVN(const Output<Node>& data, AxisSet reduction_axes, bool normalize_va
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
void op::MVN::pre_validate_and_infer_types()
{
// if m_across_channels is true we should calculate mean and variance per batch
// else we calculate these per channel
if (m_reduction_axes.empty())
{
auto data = input_value(0);
AxisSet reduction_axes;
reduction_axes.insert(0);
size_t start_axis = m_across_channels ? 1 : 2;
for (size_t i = start_axis; i < data.get_shape().size(); ++i)
{
reduction_axes.insert(i);
}
set_reduction_axes(reduction_axes);
}
}
NodeVector op::MVN::decompose_op() const NodeVector op::MVN::decompose_op() const
{ {
auto data = input_value(0); auto data = input_value(0);
......
...@@ -63,12 +63,15 @@ namespace ngraph ...@@ -63,12 +63,15 @@ namespace ngraph
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
double get_eps() const { return m_eps; } double get_eps() const { return m_eps; }
bool get_normalize_variance() const { return m_normalize_variance; } bool get_normalize_variance() const { return m_normalize_variance; }
AxisSet get_reduction_axes() const { return m_reduction_axes; } AxisSet get_reduction_axes() const { return m_reduction_axes; }
void set_reduction_axes(AxisSet axes) { m_reduction_axes = axes; }
private: private:
double m_eps; double m_eps;
bool m_across_channels; bool m_across_channels;
......
...@@ -62,6 +62,8 @@ public: ...@@ -62,6 +62,8 @@ public:
: GraphRewrite() : GraphRewrite()
{ {
m_cfmap = cfmap; m_cfmap = cfmap;
m_enable_shape_inference = true;
construct_constant_reshape(); construct_constant_reshape();
construct_constant_broadcast(); construct_constant_broadcast();
construct_constant_dyn_broadcast(); construct_constant_dyn_broadcast();
......
...@@ -80,6 +80,10 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f) ...@@ -80,6 +80,10 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
m_matchers.clear(); m_matchers.clear();
for (auto node : f->get_ordered_ops()) for (auto node : f->get_ordered_ops())
{ {
if (m_enable_shape_inference)
{
node->revalidate_and_infer_types();
}
for (auto& closure : matchers_to_run) for (auto& closure : matchers_to_run)
{ {
if (is_dyn_func && closure.property[PassProperty::REQUIRE_STATIC_SHAPE]) if (is_dyn_func && closure.property[PassProperty::REQUIRE_STATIC_SHAPE])
......
...@@ -70,6 +70,7 @@ public: ...@@ -70,6 +70,7 @@ public:
protected: protected:
bool is_enabled(const std::shared_ptr<pattern::Matcher>& m) const; bool is_enabled(const std::shared_ptr<pattern::Matcher>& m) const;
bool m_enable_shape_inference = false;
private: private:
struct MatchClosure struct MatchClosure
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/specialize_function.hpp" #include "ngraph/specialize_function.hpp"
#include <pass/constant_folding.hpp>
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -24,6 +25,17 @@ std::shared_ptr<Function> ...@@ -24,6 +25,17 @@ std::shared_ptr<Function>
const std::vector<element::Type>& parameter_element_types, const std::vector<element::Type>& parameter_element_types,
const std::vector<PartialShape>& parameter_shapes, const std::vector<PartialShape>& parameter_shapes,
const std::vector<void*>& parameter_values) const std::vector<void*>& parameter_values)
{
return specialize_function(
f, parameter_element_types, parameter_shapes, parameter_values, false);
}
std::shared_ptr<Function>
ngraph::specialize_function(std::shared_ptr<Function> f,
const std::vector<element::Type>& parameter_element_types,
const std::vector<PartialShape>& parameter_shapes,
const std::vector<void*>& parameter_values,
bool constant_folding)
{ {
NGRAPH_CHECK(f->get_parameters().size() == parameter_shapes.size()); NGRAPH_CHECK(f->get_parameters().size() == parameter_shapes.size());
NGRAPH_CHECK(f->get_parameters().size() == parameter_element_types.size()); NGRAPH_CHECK(f->get_parameters().size() == parameter_element_types.size());
...@@ -33,8 +45,6 @@ std::shared_ptr<Function> ...@@ -33,8 +45,6 @@ std::shared_ptr<Function>
for (size_t i = 0; i < parameter_shapes.size(); i++) for (size_t i = 0; i < parameter_shapes.size(); i++)
{ {
NGRAPH_CHECK(
parameter_shapes[i].refines(f->get_parameters()[i]->get_output_partial_shape(0)));
NGRAPH_CHECK(f->get_parameters()[i]->get_element_type().is_dynamic() || NGRAPH_CHECK(f->get_parameters()[i]->get_element_type().is_dynamic() ||
parameter_element_types[i] == f->get_parameters()[i]->get_element_type()); parameter_element_types[i] == f->get_parameters()[i]->get_element_type());
...@@ -65,11 +75,13 @@ std::shared_ptr<Function> ...@@ -65,11 +75,13 @@ std::shared_ptr<Function>
new_args.push_back(output.for_node(m[output.get_node()])); new_args.push_back(output.for_node(m[output.get_node()]));
} }
m[old_node.get()] = old_node->copy_with_new_inputs(new_args); m[old_node.get()] = old_node->copy_with_new_inputs(new_args);
m[old_node.get()]->set_friendly_name(old_node->get_friendly_name());
} }
ParameterVector new_parameters = f->get_parameters(); ParameterVector new_parameters = f->get_parameters();
for (size_t i = 0; i < new_parameters.size(); i++) for (size_t i = 0; i < new_parameters.size(); i++)
{ {
auto name = new_parameters[i]->get_friendly_name();
new_parameters[i] = as_type_ptr<op::Parameter>(m[new_parameters[i].get()]); new_parameters[i] = as_type_ptr<op::Parameter>(m[new_parameters[i].get()]);
// If the replacement for a Parameter is not itself a Parameter, we must have replaced it // If the replacement for a Parameter is not itself a Parameter, we must have replaced it
...@@ -80,13 +92,21 @@ std::shared_ptr<Function> ...@@ -80,13 +92,21 @@ std::shared_ptr<Function>
new_parameters[i] = new_parameters[i] =
std::make_shared<op::Parameter>(parameter_element_types[i], parameter_shapes[i]); std::make_shared<op::Parameter>(parameter_element_types[i], parameter_shapes[i]);
} }
new_parameters[i]->set_friendly_name(name);
} }
ResultVector new_results = f->get_results(); ResultVector new_results = f->get_results();
for (size_t i = 0; i < new_results.size(); i++) for (size_t i = 0; i < new_results.size(); i++)
{ {
auto name = new_results[i]->get_friendly_name();
new_results[i] = std::static_pointer_cast<op::Result>(m[new_results[i].get()]); new_results[i] = std::static_pointer_cast<op::Result>(m[new_results[i].get()]);
new_results[i]->set_friendly_name(name);
} }
return std::make_shared<Function>(new_results, new_parameters); auto function = std::make_shared<Function>(new_results, new_parameters);
if (constant_folding)
{
ngraph::pass::ConstantFolding().run_on_function(function);
}
return function;
} }
...@@ -108,4 +108,95 @@ namespace ngraph ...@@ -108,4 +108,95 @@ namespace ngraph
const std::vector<element::Type>& parameter_element_types, const std::vector<element::Type>& parameter_element_types,
const std::vector<PartialShape>& parameter_shapes, const std::vector<PartialShape>& parameter_shapes,
const std::vector<void*>& parameter_values); const std::vector<void*>& parameter_values);
/// \brief Creates a "specialized" clone of a function. The partial shapes and element types of
/// the function's parameters may be narrowed to more specific shapes and element types,
/// and constant values may optionally be substituted for any or all of the parameters.
/// \param f The function to be cloned.
/// \param parameter_element_types The new parameter element types to substitute. Length must
/// be equal to the number of parameters of f.
/// \param parameter_shapes The new parameter shapes to substitute. Length must be equal to the
/// number of parameters of f.
/// \param parameter_values Parameter values to substitute. Length must be equal to the number
/// of parameters of f, with nullptr indicating that no substitution is to be made for
/// the corresponding parameter.
/// \param constant_folding If flag is true, constant propagation is applied
/// \return A clone of f, with the parameter element types, shapes, and values specialized.
/// \throws CheckFailure if parameter_element_types or parameter_shapes is not valid
/// (see details).
/// \throws NodeValidationError if node validation fails as the clone is being constructed.
///
/// Creates a "specialized" clone of an nGraph Function.
///
/// For example, suppose that a function f has three parameters with partial shapes:
///
/// ```
/// param0: ?
/// param1: {1,?,3}
/// param2: {?,?,4}
/// ```
///
/// Shape specialization would allow us to create a clone of f where the shapes are (for
/// example):
///
/// ```
/// param0: {1,2}
/// param1: {1,5,3}
/// param2: {3,?,4}
/// ```
///
/// But not (for example):
///
/// ```
/// param1: {1,5,3,4} // rank doesn't match {1,?,3}
/// param1: {2,?,3} // the "2" doesn't match the "1"
/// param1: {?,?,3} // the new shape is too relaxed: it doesn't require 1 for the first dim
/// ```
///
/// Note that validation errors can potentially occur during cloning. For example:
///
/// ```
/// n = Parameter{shape=?}
/// m = Parameter{shape=?}
/// x = n + m
/// f = Function(x,{n,m})
/// ```
///
/// If we specialize n to the shape `{1,2,3}` and m to the shape `{4,5,6}`, cloning will fail
/// because when we reconstruct the new x node, it will see that the shapes are inconsistent
/// for elementwise add.
///
/// Specialization of element types is also possible: `element::dynamic` can be specialized
/// to a concrete element type or left dynamic; but a concrete element type can only be
/// specialized to itself (e.g., specialization does not allow you to change `element::i32`
/// to `element::i64`).
///
/// Finally, it is possible to specialize parameter values. If the ith element of
/// `parameter_values` is not `nullptr`, and fully static element type and shape has been
/// specified for the ith parameter, a `Constant` node will be created and substituted for the
/// ith parameter, with its data drawn from `parameter_values[i]`. Note that the Parameter node
/// remains (in order to maintain the arity of the function), but will no longer have any
/// users.
///
/// It is required that:
/// 1. The length of parameter_element_types, parameter_shapes, and parameter_values is the
/// same as the number of f's parameters.
/// 2. Each shape in parameter_shapes is a refinement of the shape of the corresponding
/// parameter of f. Roughly speaking, a shape s1 is said to "refine" s2 if s1 can be
/// obtained from s2 by filling in s2's question marks. See PartialShape::refines for
/// more details.
/// 3. For all i, either the element type of fp_i is dynamic, or fp_i is the same as
/// parameter_element_types[i]. (Here fp_i is the ith parameter of f.)
/// 4. For all i where parameter_values[i] != nullptr and parameter_element_types[i] is
/// static and parameter_shapes[i] is static, parameter_values[i] points to a buffer from
/// which a Constant node with element type parameter_element_types[i] and shape
/// parameter_shapes[i] can be created.
///
/// TODO(amprocte): convert this to a pass.
std::shared_ptr<Function>
specialize_function(std::shared_ptr<Function> f,
const std::vector<element::Type>& parameter_element_types,
const std::vector<PartialShape>& parameter_shapes,
const std::vector<void*>& parameter_values,
bool constant_folding);
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment