Commit 0f625c0f authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

Update LRN for partial shape/type stuff (#1965)

parent 42bec7e7
...@@ -28,8 +28,17 @@ op::LRN::LRN(const std::shared_ptr<Node>& arg, double alpha, double beta, double ...@@ -28,8 +28,17 @@ op::LRN::LRN(const std::shared_ptr<Node>& arg, double alpha, double beta, double
, m_size(nsize) , m_size(nsize)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
NODE_VALIDATION_ASSERT(this, arg->get_shape().size() >= 3) }
<< "Argument must have rank >= 3 (argument shape: " << arg->get_shape() << ").";
void op::LRN::validate_and_infer_types()
{
UnaryElementwiseArithmetic::validate_and_infer_types();
const PartialShape& input_shape = get_input_partial_shape(0);
NODE_VALIDATION_ASSERT(
this, input_shape.rank().is_dynamic() || static_cast<size_t>(input_shape.rank()) >= 3)
<< "Argument must have rank >= 3 (argument shape: " << input_shape << ").";
} }
shared_ptr<Node> op::LRN::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::LRN::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -57,6 +57,7 @@ namespace ngraph ...@@ -57,6 +57,7 @@ namespace ngraph
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
void validate_and_infer_types() override;
double m_alpha; double m_alpha;
double m_beta; double m_beta;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment