Commit 32105c29 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Fix validate_and_infer_type protection so all classes are consistent (#3158)

parent 0c4b5917
......@@ -58,11 +58,11 @@ namespace ngraph
const AxisSet& get_ellipsis_mask() const { return m_ellipsis_mask; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
void validate_and_infer_types() override;
private:
/// Helper method to compute output shape
......
......@@ -56,11 +56,11 @@ namespace ngraph
const AxisSet& get_ellipsis_mask() const { return m_ellipsis_mask; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
void validate_and_infer_types() override;
private:
/// Helper method to compute output shape
......
......@@ -74,13 +74,14 @@ namespace ngraph
bool get_use_seed() const { return m_use_seed; }
/// GenerateMask has state.
bool has_state() const override { return true; }
void validate_and_infer_types() override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override
{
}
void validate_and_infer_types() override;
element::Type m_element_type;
// These will be deprecated
Shape m_shape;
......
......@@ -32,7 +32,6 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
void validate_and_infer_types() override;
};
}
......
......@@ -49,6 +49,7 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
double get_alpha() const { return m_alpha; }
double get_beta() const { return m_beta; }
......@@ -57,7 +58,6 @@ namespace ngraph
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
void validate_and_infer_types() override;
double m_alpha;
double m_beta;
......
......@@ -56,12 +56,11 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
/// \return The index of the one-hot axis.
size_t get_one_hot_axis() const { return m_one_hot_axis; }
protected:
void validate_and_infer_types() override;
PartialShape m_shape;
size_t m_one_hot_axis;
};
......
......@@ -43,6 +43,7 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
/// \return The padding-below sizes.
const CoordinateDiff& get_padding_below() const { return m_padding_below; }
/// \return The padding-above sizes.
......@@ -57,7 +58,6 @@ namespace ngraph
virtual std::shared_ptr<Node> get_default_value() const override;
protected:
void validate_and_infer_types() override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
CoordinateDiff m_padding_below;
......
......@@ -81,6 +81,7 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
/// \return The inclusive lower-bound coordinates.
const Coordinate& get_lower_bounds() const { return m_lower_bounds; }
......@@ -91,7 +92,6 @@ namespace ngraph
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
void validate_and_infer_types() override;
Coordinate m_lower_bounds;
Coordinate m_upper_bounds;
......
......@@ -51,9 +51,9 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
protected:
void validate_and_infer_types() override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
......
......@@ -55,6 +55,7 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
/// \return The inclusive lower-bound coordinates.
const Coordinate& get_lower_bounds() const { return m_lower_bounds; }
......@@ -65,7 +66,6 @@ namespace ngraph
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
void validate_and_infer_types() override;
Coordinate m_lower_bounds;
Coordinate m_upper_bounds;
......
......@@ -52,12 +52,12 @@ namespace ngraph
void set_reduction_axis(size_t value);
element::Type get_index_element_type() const;
void set_index_element_type(const element::Type& index_element_type);
void validate_and_infer_types() override;
protected:
size_t m_axis{0};
element::Type m_index_element_type;
void validate_and_infer_types() override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
......
......@@ -78,6 +78,7 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
/// \return The inclusive lower-bound coordinates.
const Coordinate& get_lower_bounds() const { return m_lower_bounds; }
......@@ -86,8 +87,6 @@ namespace ngraph
/// \return The slicing strides.
const Strides& get_strides() const { return m_strides; }
protected:
void validate_and_infer_types() override;
Coordinate m_lower_bounds;
Coordinate m_upper_bounds;
Strides m_strides;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment