Commit 76fb19b0 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Enable non-constructor use of shape inference (#2875)

* Enable non-constructor use of shape inference

* Move GPU BatchNormTrainingWithStats shape inference out of the constructor

* Addressed PR feedback
parent 8b091114
...@@ -70,10 +70,6 @@ Function::Function(const std::shared_ptr<Node>& result, ...@@ -70,10 +70,6 @@ Function::Function(const std::shared_ptr<Node>& result,
const std::string& name) const std::string& name)
: Function(NodeVector{result}, parameters, name) : Function(NodeVector{result}, parameters, name)
{ {
// TODO this does not do anything while infer happens in the constructors
// and it will go away after we add shape during a clone; it is here now
// to assist development between those two stages.
validate_nodes_and_infer_types();
} }
void Function::validate_nodes_and_infer_types() void Function::validate_nodes_and_infer_types()
......
...@@ -224,7 +224,7 @@ namespace ngraph ...@@ -224,7 +224,7 @@ namespace ngraph
{ {
for (auto node : subgraph_topological_sort(nodes)) for (auto node : subgraph_topological_sort(nodes))
{ {
node->delayed_validate_and_infer_types(); node->revalidate_and_infer_types();
} }
} }
......
...@@ -74,11 +74,11 @@ namespace ngraph ...@@ -74,11 +74,11 @@ namespace ngraph
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
private:
static constexpr size_t INPUT_GAMMA = 0; static constexpr size_t INPUT_GAMMA = 0;
static constexpr size_t INPUT_BETA = 1; static constexpr size_t INPUT_BETA = 1;
static constexpr size_t INPUT_DATA = 2; static constexpr size_t INPUT_DATA = 2;
private:
double m_epsilon; double m_epsilon;
}; };
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <iostream> #include <iostream>
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp" #include "ngraph/op/experimental/dyn_reshape.hpp"
using namespace std; using namespace std;
...@@ -46,7 +47,21 @@ void op::DynReshape::validate_and_infer_types() ...@@ -46,7 +47,21 @@ void op::DynReshape::validate_and_infer_types()
Rank output_rank = pattern_shape.rank().is_dynamic() ? Rank::dynamic() : pattern_shape[0]; Rank output_rank = pattern_shape.rank().is_dynamic() ? Rank::dynamic() : pattern_shape[0];
set_input_is_relevant_to_shape(1); set_input_is_relevant_to_shape(1);
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(output_rank)); if (auto const_shape = dynamic_pointer_cast<op::Constant>(get_argument(1)))
{
// TODO: replace with const_shape->get_shapes_val()
auto out_shape = const_shape->get_vector<int64_t>();
Shape output_shape(shape_size(const_shape->get_shape()));
std::transform(out_shape.begin(),
out_shape.end(),
output_shape.begin(),
[&](const int64_t& v) { return max(v, int64_t(0)); });
set_output_type(0, get_input_element_type(0), output_shape);
}
else
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(output_rank));
}
} }
shared_ptr<Node> op::DynReshape::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::DynReshape::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -28,6 +28,8 @@ op::Result::Result(const shared_ptr<Node>& arg) ...@@ -28,6 +28,8 @@ op::Result::Result(const shared_ptr<Node>& arg)
: Op("Result", check_single_output_args({arg})) : Op("Result", check_single_output_args({arg}))
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
// always borrow the placement conf even the default one
set_placement_index(get_argument(0)->get_placement_index());
} }
void op::Result::validate_and_infer_types() void op::Result::validate_and_infer_types()
...@@ -35,8 +37,6 @@ void op::Result::validate_and_infer_types() ...@@ -35,8 +37,6 @@ void op::Result::validate_and_infer_types()
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
this, get_input_size() == 1, "Argument has ", get_input_size(), " outputs (1 expected)."); this, get_input_size() == 1, "Argument has ", get_input_size(), " outputs (1 expected).");
// always borrow the placement conf even the default one
set_placement_index(get_argument(0)->get_placement_index());
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
} }
......
...@@ -38,7 +38,10 @@ void op::util::FusedOp::validate_and_infer_types() ...@@ -38,7 +38,10 @@ void op::util::FusedOp::validate_and_infer_types()
{ {
for (size_t j = 0; j < output_node->get_output_size(); j++, i++) for (size_t j = 0; j < output_node->get_output_size(); j++, i++)
{ {
set_output_size(i + 1); if (i >= get_output_size())
{
set_output_size(i + 1);
}
set_output_type( set_output_type(
i, output_node->get_output_element_type(j), output_node->get_output_shape(j)); i, output_node->get_output_element_type(j), output_node->get_output_shape(j));
} }
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/runtime/gpu/op/batch_norm.hpp" #include "ngraph/runtime/gpu/op/batch_norm.hpp"
#include "ngraph/validation_util.hpp"
ngraph::op::gpu::BatchNormTrainingWithStats::BatchNormTrainingWithStats( ngraph::op::gpu::BatchNormTrainingWithStats::BatchNormTrainingWithStats(
double eps, double eps,
...@@ -23,13 +24,32 @@ ngraph::op::gpu::BatchNormTrainingWithStats::BatchNormTrainingWithStats( ...@@ -23,13 +24,32 @@ ngraph::op::gpu::BatchNormTrainingWithStats::BatchNormTrainingWithStats(
std::shared_ptr<ngraph::Node> input) std::shared_ptr<ngraph::Node> input)
: ngraph::op::BatchNormTraining(eps, gamma, beta, input) : ngraph::op::BatchNormTraining(eps, gamma, beta, input)
{ {
auto output_index = get_output_size(); constructor_validate_and_infer_types();
set_output_size(output_index + 2); }
Shape channel_shape{input->get_shape()[1]};
void ngraph::op::gpu::BatchNormTrainingWithStats::validate_and_infer_types()
{
element::Type result_et;
PartialShape result_batch_shape;
PartialShape result_channel_shape;
set_output_size(5);
std::tie(result_et, result_batch_shape, result_channel_shape) =
infer_batch_norm_forward(this,
get_input_element_type(INPUT_DATA),
get_input_element_type(INPUT_GAMMA),
get_input_element_type(INPUT_BETA),
get_input_partial_shape(INPUT_DATA),
get_input_partial_shape(INPUT_GAMMA),
get_input_partial_shape(INPUT_BETA));
set_output_type(0, result_et, result_batch_shape);
set_output_type(1, result_et, result_channel_shape);
set_output_type(2, result_et, result_channel_shape);
// saved batch mean // saved batch mean
set_output_type(output_index++, input->get_element_type(), channel_shape); set_output_type(3, result_et, result_channel_shape);
// saved batch inverse variance // saved batch inverse variance
set_output_type(output_index++, input->get_element_type(), channel_shape); set_output_type(4, result_et, result_channel_shape);
} }
std::shared_ptr<ngraph::Node> ngraph::op::gpu::BatchNormTrainingWithStats::copy_with_new_args( std::shared_ptr<ngraph::Node> ngraph::op::gpu::BatchNormTrainingWithStats::copy_with_new_args(
......
...@@ -38,6 +38,8 @@ namespace ngraph ...@@ -38,6 +38,8 @@ namespace ngraph
std::shared_ptr<Node> beta, std::shared_ptr<Node> beta,
std::shared_ptr<Node> input); std::shared_ptr<Node> input);
void validate_and_infer_types() override;
protected: protected:
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -13969,6 +13969,27 @@ TEST(type_prop, group_conv_invalid_groups) ...@@ -13969,6 +13969,27 @@ TEST(type_prop, group_conv_invalid_groups)
} }
} }
TEST(type_prop, function_revalidate_and_infer)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto pattern = op::Constant::create(element::i64, Shape{6}, {1, 3, 16, 2, 2, 2});
auto r = make_shared<op::DynReshape>(arg, pattern);
auto relu = make_shared<op::Relu>(r);
auto f = make_shared<Function>(relu, ParameterVector{arg});
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_EQ(r->get_output_shape(0), (Shape{1, 3, 16, 2, 2, 2}));
EXPECT_EQ(f->get_output_shape(0), (Shape{1, 3, 16, 2, 2, 2}));
auto new_pattern = op::Constant::create(element::i64, Shape{2}, {32, 12});
r->input(1).replace_source_output(new_pattern->output(0));
f->validate_nodes_and_infer_types();
EXPECT_EQ(r->get_output_shape(0), (Shape{32, 12}));
EXPECT_EQ(f->get_output_shape(0), (Shape{32, 12}));
}
TEST(type_prop, gemm) TEST(type_prop, gemm)
{ {
auto A = make_shared<op::Parameter>(element::f32, Shape{3, 6}); auto A = make_shared<op::Parameter>(element::f32, Shape{3, 6});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment