Commit afd396dc authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Allow fused ops to create outputs and set types/shapes in pre_validation (#3839)

* Allow fused ops to create outputs and set types/shapes in pre_validation

* style fix
parent 71b6ef81
......@@ -99,6 +99,11 @@ op::LSTMCell::LSTMCell(const Output<Node>& X,
void op::LSTMCell::pre_validate_and_infer_types()
{
if (is_dynamic())
{
return;
}
const auto& x_pshape = get_input_partial_shape(0);
const auto& w_pshape = get_input_partial_shape(1);
const auto& r_pshape = get_input_partial_shape(2);
......
......@@ -228,4 +228,5 @@ void op::PartialSliceBackprop::pre_validate_and_infer_types()
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type,
").");
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
}
......@@ -77,6 +77,11 @@ op::RNNCell::RNNCell(const Output<Node>& X,
void op::RNNCell::pre_validate_and_infer_types()
{
if (is_dynamic())
{
return;
}
const auto& x_pshape = get_input_partial_shape(0);
const auto& w_pshape = get_input_partial_shape(1);
const auto& r_pshape = get_input_partial_shape(2);
......
......@@ -29,11 +29,6 @@ op::Split::Split(const Output<Node>& data, const int axis, const size_t num_spli
, m_axis{axis}
, m_num_split{num_split}
{
// Create dynamic-typed outputs. Actual shape/type will be computed during shape inference
for (size_t i = 0; i < num_split; i++)
{
set_output_type(i, element::dynamic, PartialShape::dynamic());
}
constructor_validate_and_infer_types();
}
......@@ -41,18 +36,25 @@ op::Split::Split(const Output<Node>& data, const int axis, const std::vector<siz
: FusedOp({data})
, m_split_evenly{false}
, m_axis{axis}
, m_num_split{0}
, m_splits{splits}
{
// Create dynamic-typed outputs. Actual shape/type will be computed during shape inference
for (size_t i = 0; i < splits.size(); i++)
{
set_output_type(i, element::dynamic, PartialShape::dynamic());
}
constructor_validate_and_infer_types();
}
void op::Split::pre_validate_and_infer_types()
{
// Create dynamic-typed outputs. Actual shape/type will be computed during shape inference
for (size_t i = 0; i < std::max(m_splits.size(), m_num_split); i++)
{
set_output_type(i, input(0).get_element_type(), PartialShape::dynamic());
}
if (is_dynamic())
{
return;
}
const auto shape = input(0).get_shape();
m_axis = adjust_axis_value(m_axis, shape.size());
......
......@@ -42,25 +42,13 @@ op::util::FusedOp::FusedOp(const std::string& node_type, const NodeVector& args)
void op::util::FusedOp::validate_and_infer_types()
{
// Bail out if any of the shapes are unknown since fused op decomposition
// typically requires fully-determined static types.
//
// In the absence of decomposition, we will not know how many outputs this
// fused op has, so we conservatively create and set a single output to
// facilitate downstream ops that would like to use this op as an argument.
// Multi-output fused ops (e.g., split) should create these outputs in their
// constructors instead.
for (size_t i = 0; i < get_input_size(); i++)
pre_validate_and_infer_types();
if (!can_decompose_with_partial_shapes() && is_dynamic())
{
if (!get_input_partial_shape(i).is_static())
{
set_output_type(0, element::dynamic, PartialShape::dynamic());
return;
}
return;
}
pre_validate_and_infer_types();
auto subgraph_outputs = decompose_op();
auto subgraph = extract_subgraph(subgraph_outputs, get_arguments());
validate_nodes_and_infer_types(subgraph);
......
......@@ -31,10 +31,23 @@ namespace ngraph
{
public:
bool supports_decompose() const override { return true; }
// Fused op decomposition can be performed in the presence of
// partial shapes
virtual bool can_decompose_with_partial_shapes() { return false; }
// Shape inference that will use fused op decomposition to infer
// shapes and types of output elements. Ops can choose to override
// and provide a more direct implementation.
void validate_and_infer_types() override;
/// Pre and post validation hooks for op-specific actions
// Pre-validation hook that will be invoked before op
// decomposition in validate_and_infer_types().
// Can be used for attribute validation and setting types/shapes
// that can be inferred without requiring op decomposition.
// Can also be used to set shape specialization hints
// (set_input_is_relevant_to_shape())
virtual void pre_validate_and_infer_types() {}
// Post-validation hook that will be invoked after op decomposition
// in validate_and_infer_types().
virtual void post_validate_and_infer_types() {}
void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment