Commit afd396dc authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Allow fused ops to create outputs and set types/shapes in pre_validation (#3839)

* Allow fused ops to create outputs and set types/shapes in pre_validation

* style fix
parent 71b6ef81
...@@ -99,6 +99,11 @@ op::LSTMCell::LSTMCell(const Output<Node>& X, ...@@ -99,6 +99,11 @@ op::LSTMCell::LSTMCell(const Output<Node>& X,
void op::LSTMCell::pre_validate_and_infer_types() void op::LSTMCell::pre_validate_and_infer_types()
{ {
if (is_dynamic())
{
return;
}
const auto& x_pshape = get_input_partial_shape(0); const auto& x_pshape = get_input_partial_shape(0);
const auto& w_pshape = get_input_partial_shape(1); const auto& w_pshape = get_input_partial_shape(1);
const auto& r_pshape = get_input_partial_shape(2); const auto& r_pshape = get_input_partial_shape(2);
......
...@@ -228,4 +228,5 @@ void op::PartialSliceBackprop::pre_validate_and_infer_types() ...@@ -228,4 +228,5 @@ void op::PartialSliceBackprop::pre_validate_and_infer_types()
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ", "Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type, input_element_type,
")."); ").");
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
} }
...@@ -77,6 +77,11 @@ op::RNNCell::RNNCell(const Output<Node>& X, ...@@ -77,6 +77,11 @@ op::RNNCell::RNNCell(const Output<Node>& X,
void op::RNNCell::pre_validate_and_infer_types() void op::RNNCell::pre_validate_and_infer_types()
{ {
if (is_dynamic())
{
return;
}
const auto& x_pshape = get_input_partial_shape(0); const auto& x_pshape = get_input_partial_shape(0);
const auto& w_pshape = get_input_partial_shape(1); const auto& w_pshape = get_input_partial_shape(1);
const auto& r_pshape = get_input_partial_shape(2); const auto& r_pshape = get_input_partial_shape(2);
......
...@@ -29,11 +29,6 @@ op::Split::Split(const Output<Node>& data, const int axis, const size_t num_spli ...@@ -29,11 +29,6 @@ op::Split::Split(const Output<Node>& data, const int axis, const size_t num_spli
, m_axis{axis} , m_axis{axis}
, m_num_split{num_split} , m_num_split{num_split}
{ {
// Create dynamic-typed outputs. Actual shape/type will be computed during shape inference
for (size_t i = 0; i < num_split; i++)
{
set_output_type(i, element::dynamic, PartialShape::dynamic());
}
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -41,18 +36,25 @@ op::Split::Split(const Output<Node>& data, const int axis, const std::vector<siz ...@@ -41,18 +36,25 @@ op::Split::Split(const Output<Node>& data, const int axis, const std::vector<siz
: FusedOp({data}) : FusedOp({data})
, m_split_evenly{false} , m_split_evenly{false}
, m_axis{axis} , m_axis{axis}
, m_num_split{0}
, m_splits{splits} , m_splits{splits}
{ {
// Create dynamic-typed outputs. Actual shape/type will be computed during shape inference
for (size_t i = 0; i < splits.size(); i++)
{
set_output_type(i, element::dynamic, PartialShape::dynamic());
}
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
void op::Split::pre_validate_and_infer_types() void op::Split::pre_validate_and_infer_types()
{ {
// Create dynamic-typed outputs. Actual shape/type will be computed during shape inference
for (size_t i = 0; i < std::max(m_splits.size(), m_num_split); i++)
{
set_output_type(i, input(0).get_element_type(), PartialShape::dynamic());
}
if (is_dynamic())
{
return;
}
const auto shape = input(0).get_shape(); const auto shape = input(0).get_shape();
m_axis = adjust_axis_value(m_axis, shape.size()); m_axis = adjust_axis_value(m_axis, shape.size());
......
...@@ -42,24 +42,12 @@ op::util::FusedOp::FusedOp(const std::string& node_type, const NodeVector& args) ...@@ -42,24 +42,12 @@ op::util::FusedOp::FusedOp(const std::string& node_type, const NodeVector& args)
void op::util::FusedOp::validate_and_infer_types() void op::util::FusedOp::validate_and_infer_types()
{ {
// Bail out if any of the shapes are unknown since fused op decomposition pre_validate_and_infer_types();
// typically requires fully-determined static types.
// if (!can_decompose_with_partial_shapes() && is_dynamic())
// In the absence of decomposition, we will not know how many outputs this
// fused op has, so we conservatively create and set a single output to
// facilitate downstream ops that would like to use this op as an argument.
// Multi-output fused ops (e.g., split) should create these outputs in their
// constructors instead.
for (size_t i = 0; i < get_input_size(); i++)
{
if (!get_input_partial_shape(i).is_static())
{ {
set_output_type(0, element::dynamic, PartialShape::dynamic());
return; return;
} }
}
pre_validate_and_infer_types();
auto subgraph_outputs = decompose_op(); auto subgraph_outputs = decompose_op();
auto subgraph = extract_subgraph(subgraph_outputs, get_arguments()); auto subgraph = extract_subgraph(subgraph_outputs, get_arguments());
......
...@@ -31,10 +31,23 @@ namespace ngraph ...@@ -31,10 +31,23 @@ namespace ngraph
{ {
public: public:
bool supports_decompose() const override { return true; } bool supports_decompose() const override { return true; }
// Fused op decomposition can be performed in the presence of
// partial shapes
virtual bool can_decompose_with_partial_shapes() { return false; }
// Shape inference that will use fused op decomposition to infer
// shapes and types of output elements. Ops can choose to override
// and provide a more direct implementation.
void validate_and_infer_types() override; void validate_and_infer_types() override;
/// Pre and post validation hooks for op-specific actions // Pre-validation hook that will be invoked before op
// decomposition in validate_and_infer_types().
// Can be used for attribute validation and setting types/shapes
// that can be inferred without requiring op decomposition.
// Can also be used to set shape specialization hints
// (set_input_is_relevant_to_shape())
virtual void pre_validate_and_infer_types() {} virtual void pre_validate_and_infer_types() {}
// Post-validation hook that will be invoked after op decomposition
// in validate_and_infer_types().
virtual void post_validate_and_infer_types() {} virtual void post_validate_and_infer_types() {}
void generate_adjoints(autodiff::Adjoints& adjoints, void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment