Commit 0f625c0f authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

Update LRN for partial shape/type stuff (#1965)

parent 42bec7e7
......@@ -28,8 +28,17 @@ op::LRN::LRN(const std::shared_ptr<Node>& arg, double alpha, double beta, double
, m_size(nsize)
{
constructor_validate_and_infer_types();
NODE_VALIDATION_ASSERT(this, arg->get_shape().size() >= 3)
<< "Argument must have rank >= 3 (argument shape: " << arg->get_shape() << ").";
}
void op::LRN::validate_and_infer_types()
{
UnaryElementwiseArithmetic::validate_and_infer_types();
const PartialShape& input_shape = get_input_partial_shape(0);
NODE_VALIDATION_ASSERT(
this, input_shape.rank().is_dynamic() || static_cast<size_t>(input_shape.rank()) >= 3)
<< "Argument must have rank >= 3 (argument shape: " << input_shape << ").";
}
shared_ptr<Node> op::LRN::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -57,6 +57,7 @@ namespace ngraph
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
void validate_and_infer_types() override;
double m_alpha;
double m_beta;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment