Commit 76fb19b0, authored by Jayaram Bobba; committed by Scott Cyphers

Enable non-constructor use of shape inference (#2875)

* Enable non-constructor use of shape inference

* Move GPU BatchNormTrainingWithStats shape inference out of the constructor

* Addressed PR feedback
parent 8b091114
......@@ -70,10 +70,6 @@ Function::Function(const std::shared_ptr<Node>& result,
const std::string& name)
: Function(NodeVector{result}, parameters, name)
{
// TODO this does not do anything while infer happens in the constructors
// and it will go away after we add shape during a clone; it is here now
// to assist development between those two stages.
validate_nodes_and_infer_types();
}
void Function::validate_nodes_and_infer_types()
......
......@@ -224,7 +224,7 @@ namespace ngraph
{
for (auto node : subgraph_topological_sort(nodes))
{
node->delayed_validate_and_infer_types();
node->revalidate_and_infer_types();
}
}
......
......@@ -74,11 +74,11 @@ namespace ngraph
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
private:
static constexpr size_t INPUT_GAMMA = 0;
static constexpr size_t INPUT_BETA = 1;
static constexpr size_t INPUT_DATA = 2;
private:
double m_epsilon;
};
......
......@@ -18,6 +18,7 @@
#include <iostream>
#include "ngraph/function.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
using namespace std;
......@@ -46,7 +47,21 @@ void op::DynReshape::validate_and_infer_types()
Rank output_rank = pattern_shape.rank().is_dynamic() ? Rank::dynamic() : pattern_shape[0];
set_input_is_relevant_to_shape(1);
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(output_rank));
if (auto const_shape = dynamic_pointer_cast<op::Constant>(get_argument(1)))
{
// TODO: replace with const_shape->get_shapes_val()
auto out_shape = const_shape->get_vector<int64_t>();
Shape output_shape(shape_size(const_shape->get_shape()));
std::transform(out_shape.begin(),
out_shape.end(),
output_shape.begin(),
[&](const int64_t& v) { return max(v, int64_t(0)); });
set_output_type(0, get_input_element_type(0), output_shape);
}
else
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(output_rank));
}
}
shared_ptr<Node> op::DynReshape::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -28,6 +28,8 @@ op::Result::Result(const shared_ptr<Node>& arg)
: Op("Result", check_single_output_args({arg}))
{
constructor_validate_and_infer_types();
// always borrow the placement conf even the default one
set_placement_index(get_argument(0)->get_placement_index());
}
void op::Result::validate_and_infer_types()
......@@ -35,8 +37,6 @@ void op::Result::validate_and_infer_types()
NODE_VALIDATION_CHECK(
this, get_input_size() == 1, "Argument has ", get_input_size(), " outputs (1 expected).");
// always borrow the placement conf even the default one
set_placement_index(get_argument(0)->get_placement_index());
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}
......
......@@ -38,7 +38,10 @@ void op::util::FusedOp::validate_and_infer_types()
{
for (size_t j = 0; j < output_node->get_output_size(); j++, i++)
{
set_output_size(i + 1);
if (i >= get_output_size())
{
set_output_size(i + 1);
}
set_output_type(
i, output_node->get_output_element_type(j), output_node->get_output_shape(j));
}
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/runtime/gpu/op/batch_norm.hpp"
#include "ngraph/validation_util.hpp"
/// \brief Fused batch-norm training op that additionally exposes the saved
///        batch mean and inverse variance as outputs (GPU backend).
///
/// \param eps   Stabilizing epsilon added to the variance.
/// \param gamma Scale input.
/// \param beta  Shift input.
/// \param input Data input (shape inferred in validate_and_infer_types()).
ngraph::op::gpu::BatchNormTrainingWithStats::BatchNormTrainingWithStats(
    double eps,
    std::shared_ptr<ngraph::Node> gamma,
    std::shared_ptr<ngraph::Node> beta,
    std::shared_ptr<ngraph::Node> input)
    : ngraph::op::BatchNormTraining(eps, gamma, beta, input)
{
    // Shape inference now lives in validate_and_infer_types() so it can be
    // re-run after graph edits; the constructor only triggers it once.
    // (The old code that called set_output_size()/computed channel_shape
    // here was removed by this commit.)
    constructor_validate_and_infer_types();
}
/// \brief Infers element type and shapes for all five outputs.
///
/// Outputs: 0 = normalized batch, 1 = batch mean, 2 = batch variance,
///          3 = saved batch mean, 4 = saved batch inverse variance.
/// NOTE(review): the stale diff-residue lines that referenced the undefined
/// `output_index` and `input` locals have been dropped; outputs 3 and 4 are
/// typed from the inferred channel shape like outputs 1 and 2.
void ngraph::op::gpu::BatchNormTrainingWithStats::validate_and_infer_types()
{
    element::Type result_et;
    PartialShape result_batch_shape;
    PartialShape result_channel_shape;

    // This op extends BatchNormTraining's three outputs with two more
    // (saved mean / saved inverse variance), so fix the output count here.
    set_output_size(5);

    std::tie(result_et, result_batch_shape, result_channel_shape) =
        infer_batch_norm_forward(this,
                                 get_input_element_type(INPUT_DATA),
                                 get_input_element_type(INPUT_GAMMA),
                                 get_input_element_type(INPUT_BETA),
                                 get_input_partial_shape(INPUT_DATA),
                                 get_input_partial_shape(INPUT_GAMMA),
                                 get_input_partial_shape(INPUT_BETA));

    set_output_type(0, result_et, result_batch_shape);
    set_output_type(1, result_et, result_channel_shape);
    set_output_type(2, result_et, result_channel_shape);
    // saved batch mean
    set_output_type(3, result_et, result_channel_shape);
    // saved batch inverse variance
    set_output_type(4, result_et, result_channel_shape);
}
std::shared_ptr<ngraph::Node> ngraph::op::gpu::BatchNormTrainingWithStats::copy_with_new_args(
......
......@@ -38,6 +38,8 @@ namespace ngraph
std::shared_ptr<Node> beta,
std::shared_ptr<Node> input);
void validate_and_infer_types() override;
protected:
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -13969,6 +13969,27 @@ TEST(type_prop, group_conv_invalid_groups)
}
}
TEST(type_prop, function_revalidate_and_infer)
{
    // Build a small graph: Parameter -> DynReshape(constant pattern) -> Relu.
    auto param = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
    auto shape_const = op::Constant::create(element::i64, Shape{6}, {1, 3, 16, 2, 2, 2});
    auto reshape = make_shared<op::DynReshape>(param, shape_const);
    auto act = make_shared<op::Relu>(reshape);
    auto func = make_shared<Function>(act, ParameterVector{param});

    // With a constant pattern, DynReshape should infer a fully static shape.
    EXPECT_EQ(reshape->get_output_element_type(0), element::f32);
    EXPECT_EQ(reshape->get_output_shape(0), (Shape{1, 3, 16, 2, 2, 2}));
    EXPECT_EQ(func->get_output_shape(0), (Shape{1, 3, 16, 2, 2, 2}));

    // Swap in a different constant pattern and re-run shape inference on the
    // whole function; the new shape should propagate to the result.
    auto replacement = op::Constant::create(element::i64, Shape{2}, {32, 12});
    reshape->input(1).replace_source_output(replacement->output(0));
    func->validate_nodes_and_infer_types();
    EXPECT_EQ(reshape->get_output_shape(0), (Shape{32, 12}));
    EXPECT_EQ(func->get_output_shape(0), (Shape{32, 12}));
}
TEST(type_prop, gemm)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{3, 6});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment