Commit 507e6d0a authored by Pruthvi's avatar Pruthvi Committed by Robert Kimball

helpers for function and node to determine the dynamic shape (#2753)

* - add method to determine if the function contains nodes with dynamic shape
- add virtual method to node to determine if the node contains any partial shape.

* - add support to reshape op to determine if the reshape node as dynamic attributes

* Address PR comments

* - check all the inputs to determine if the node has dynamic_shape
- addressed Adam comments

* - add support to abort graph match if the nodes are dynamic

* Addressed PR comments

* - remove strict check on dynmic shape in the matcher
parent 127e0ded
......@@ -224,3 +224,19 @@ void Function::set_placement(size_t placement)
{
m_placement = placement;
}
// TODO(pthoreho) this will be expensive, since we will be traversing all the nodes in
// the graph, figure out if their is a way to cache the result and invalidate/update
// the result if the function is modified
bool Function::is_dynamic() const
{
auto list_of_nodes = this->get_ops();
for (auto& node : list_of_nodes)
{
if (node->get_output_partial_shape(0).is_dynamic())
{
return true;
}
}
return false;
}
......@@ -105,6 +105,9 @@ namespace ngraph
size_t get_placement() const;
void set_placement(size_t placement);
/// \brief Returns true if any of the op's defined in the function contains partial shape
bool is_dynamic() const;
protected:
ResultVector m_results;
ParameterVector m_parameters;
......
......@@ -524,3 +524,18 @@ void Node::validate_and_infer_elementwise_logical()
set_output_type(0, element::boolean, args_pshape);
}
// default implementation for the node to check if it contains partial shape
// we will override this method, for the Op's which depends on additional shape
// attribute to determine if node contains partial shape or not
bool Node::is_dynamic() const
{
for (size_t i = 0; i < get_input_size(); i++)
{
if (get_input_partial_shape(i).is_dynamic())
{
return true;
}
}
return false;
}
......@@ -181,6 +181,7 @@ namespace ngraph
virtual bool is_null() const { return false; }
virtual bool is_op() const { return false; }
virtual bool is_commutative() { return false; }
virtual bool is_dynamic() const;
size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&);
virtual std::ostream& write_short_description(std::ostream&) const;
......
......@@ -18,6 +18,7 @@
#include <iostream>
#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/reshape.hpp"
using namespace std;
......
......@@ -7481,3 +7481,21 @@ TEST(${BACKEND_NAME}, batch_mat_mul_forward)
#undef BACKEND_TEST_${BACKEND_NAME}
#endif
// clang-format on
NGRAPH_TEST(${BACKEND_NAME}, validate_function_for_dynamic_shape)
{
auto make_function = [&](bool dynmaic_shape) {
auto param1_shape =
dynmaic_shape ? PartialShape{Dimension::dynamic(), 2, 3} : Shape{5, 4, 2};
auto param2_shape = dynmaic_shape ? PartialShape::dynamic() : Shape{5, 2, 3};
auto param_1 = std::make_shared<op::Parameter>(element::f32, param1_shape);
auto param_2 = std::make_shared<op::Parameter>(element::f32, param2_shape);
auto batch_dot = make_shared<op::BatchMatMul>(param_1, param_2);
auto f = make_shared<Function>(NodeVector{batch_dot}, ParameterVector{param_1, param_2});
return f;
};
EXPECT_EQ(true, make_function(true)->is_dynamic());
EXPECT_EQ(false, make_function(false)->is_dynamic());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment