Commit 507e6d0a authored by Pruthvi's avatar Pruthvi Committed by Robert Kimball

helpers for function and node to determine the dynamic shape (#2753)

* - add method to determine if the function contains nodes with dynamic shape
- add virtual method to node to determine if the node contains any partial shape.

* - add support to reshape op to determine if the reshape node as dynamic attributes

* Address PR comments

* - check all the inputs to determine if the node has dynamic_shape
- addressed Adam comments

* - add support to abort graph match if the nodes are dynamic

* Addressed PR comments

* - remove strict check on dynmic shape in the matcher
parent 127e0ded
...@@ -224,3 +224,19 @@ void Function::set_placement(size_t placement) ...@@ -224,3 +224,19 @@ void Function::set_placement(size_t placement)
{ {
m_placement = placement; m_placement = placement;
} }
// TODO(pthoreho) this will be expensive, since we will be traversing all the nodes in
// the graph, figure out if their is a way to cache the result and invalidate/update
// the result if the function is modified
bool Function::is_dynamic() const
{
auto list_of_nodes = this->get_ops();
for (auto& node : list_of_nodes)
{
if (node->get_output_partial_shape(0).is_dynamic())
{
return true;
}
}
return false;
}
...@@ -105,6 +105,9 @@ namespace ngraph ...@@ -105,6 +105,9 @@ namespace ngraph
size_t get_placement() const; size_t get_placement() const;
void set_placement(size_t placement); void set_placement(size_t placement);
/// \brief Returns true if any of the op's defined in the function contains partial shape
bool is_dynamic() const;
protected: protected:
ResultVector m_results; ResultVector m_results;
ParameterVector m_parameters; ParameterVector m_parameters;
......
...@@ -524,3 +524,18 @@ void Node::validate_and_infer_elementwise_logical() ...@@ -524,3 +524,18 @@ void Node::validate_and_infer_elementwise_logical()
set_output_type(0, element::boolean, args_pshape); set_output_type(0, element::boolean, args_pshape);
} }
// default implementation for the node to check if it contains partial shape
// we will override this method, for the Op's which depends on additional shape
// attribute to determine if node contains partial shape or not
bool Node::is_dynamic() const
{
for (size_t i = 0; i < get_input_size(); i++)
{
if (get_input_partial_shape(i).is_dynamic())
{
return true;
}
}
return false;
}
...@@ -181,6 +181,7 @@ namespace ngraph ...@@ -181,6 +181,7 @@ namespace ngraph
virtual bool is_null() const { return false; } virtual bool is_null() const { return false; }
virtual bool is_op() const { return false; } virtual bool is_op() const { return false; }
virtual bool is_commutative() { return false; } virtual bool is_commutative() { return false; }
virtual bool is_dynamic() const;
size_t get_instance_id() const { return m_instance_id; } size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&); friend std::ostream& operator<<(std::ostream&, const Node&);
virtual std::ostream& write_short_description(std::ostream&) const; virtual std::ostream& write_short_description(std::ostream&) const;
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <iostream> #include <iostream>
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
using namespace std; using namespace std;
......
...@@ -7481,3 +7481,21 @@ TEST(${BACKEND_NAME}, batch_mat_mul_forward) ...@@ -7481,3 +7481,21 @@ TEST(${BACKEND_NAME}, batch_mat_mul_forward)
#undef BACKEND_TEST_${BACKEND_NAME} #undef BACKEND_TEST_${BACKEND_NAME}
#endif #endif
// clang-format on // clang-format on
NGRAPH_TEST(${BACKEND_NAME}, validate_function_for_dynamic_shape)
{
auto make_function = [&](bool dynmaic_shape) {
auto param1_shape =
dynmaic_shape ? PartialShape{Dimension::dynamic(), 2, 3} : Shape{5, 4, 2};
auto param2_shape = dynmaic_shape ? PartialShape::dynamic() : Shape{5, 2, 3};
auto param_1 = std::make_shared<op::Parameter>(element::f32, param1_shape);
auto param_2 = std::make_shared<op::Parameter>(element::f32, param2_shape);
auto batch_dot = make_shared<op::BatchMatMul>(param_1, param_2);
auto f = make_shared<Function>(NodeVector{batch_dot}, ParameterVector{param_1, param_2});
return f;
};
EXPECT_EQ(true, make_function(true)->is_dynamic());
EXPECT_EQ(false, make_function(false)->is_dynamic());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment