Commit 32105c29 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

Fix validate_and_infer_type protection so all classes are consistent (#3158)

parent 0c4b5917
...@@ -58,11 +58,11 @@ namespace ngraph ...@@ -58,11 +58,11 @@ namespace ngraph
const AxisSet& get_ellipsis_mask() const { return m_ellipsis_mask; } const AxisSet& get_ellipsis_mask() const { return m_ellipsis_mask; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
void validate_and_infer_types() override;
private: private:
/// Helper method to compute output shape /// Helper method to compute output shape
......
...@@ -56,11 +56,11 @@ namespace ngraph ...@@ -56,11 +56,11 @@ namespace ngraph
const AxisSet& get_ellipsis_mask() const { return m_ellipsis_mask; } const AxisSet& get_ellipsis_mask() const { return m_ellipsis_mask; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
void validate_and_infer_types() override;
private: private:
/// Helper method to compute output shape /// Helper method to compute output shape
......
...@@ -74,13 +74,14 @@ namespace ngraph ...@@ -74,13 +74,14 @@ namespace ngraph
bool get_use_seed() const { return m_use_seed; } bool get_use_seed() const { return m_use_seed; }
/// GenerateMask has state. /// GenerateMask has state.
bool has_state() const override { return true; } bool has_state() const override { return true; }
void validate_and_infer_types() override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override const NodeVector& deltas) override
{ {
} }
void validate_and_infer_types() override;
element::Type m_element_type; element::Type m_element_type;
// These will be deprecated // These will be deprecated
Shape m_shape; Shape m_shape;
......
...@@ -32,7 +32,6 @@ namespace ngraph ...@@ -32,7 +32,6 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
protected:
void validate_and_infer_types() override; void validate_and_infer_types() override;
}; };
} }
......
...@@ -49,6 +49,7 @@ namespace ngraph ...@@ -49,6 +49,7 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
double get_alpha() const { return m_alpha; } double get_alpha() const { return m_alpha; }
double get_beta() const { return m_beta; } double get_beta() const { return m_beta; }
...@@ -57,7 +58,6 @@ namespace ngraph ...@@ -57,7 +58,6 @@ namespace ngraph
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
void validate_and_infer_types() override;
double m_alpha; double m_alpha;
double m_beta; double m_beta;
......
...@@ -56,12 +56,11 @@ namespace ngraph ...@@ -56,12 +56,11 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
/// \return The index of the one-hot axis. /// \return The index of the one-hot axis.
size_t get_one_hot_axis() const { return m_one_hot_axis; } size_t get_one_hot_axis() const { return m_one_hot_axis; }
protected: protected:
void validate_and_infer_types() override;
PartialShape m_shape; PartialShape m_shape;
size_t m_one_hot_axis; size_t m_one_hot_axis;
}; };
......
...@@ -43,6 +43,7 @@ namespace ngraph ...@@ -43,6 +43,7 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
/// \return The padding-below sizes. /// \return The padding-below sizes.
const CoordinateDiff& get_padding_below() const { return m_padding_below; } const CoordinateDiff& get_padding_below() const { return m_padding_below; }
/// \return The padding-above sizes. /// \return The padding-above sizes.
...@@ -57,7 +58,6 @@ namespace ngraph ...@@ -57,7 +58,6 @@ namespace ngraph
virtual std::shared_ptr<Node> get_default_value() const override; virtual std::shared_ptr<Node> get_default_value() const override;
protected: protected:
void validate_and_infer_types() override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
CoordinateDiff m_padding_below; CoordinateDiff m_padding_below;
......
...@@ -81,6 +81,7 @@ namespace ngraph ...@@ -81,6 +81,7 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
/// \return The inclusive lower-bound coordinates. /// \return The inclusive lower-bound coordinates.
const Coordinate& get_lower_bounds() const { return m_lower_bounds; } const Coordinate& get_lower_bounds() const { return m_lower_bounds; }
...@@ -91,7 +92,6 @@ namespace ngraph ...@@ -91,7 +92,6 @@ namespace ngraph
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
void validate_and_infer_types() override;
Coordinate m_lower_bounds; Coordinate m_lower_bounds;
Coordinate m_upper_bounds; Coordinate m_upper_bounds;
......
...@@ -51,9 +51,9 @@ namespace ngraph ...@@ -51,9 +51,9 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
protected: protected:
void validate_and_infer_types() override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
}; };
......
...@@ -55,6 +55,7 @@ namespace ngraph ...@@ -55,6 +55,7 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
/// \return The inclusive lower-bound coordinates. /// \return The inclusive lower-bound coordinates.
const Coordinate& get_lower_bounds() const { return m_lower_bounds; } const Coordinate& get_lower_bounds() const { return m_lower_bounds; }
...@@ -65,7 +66,6 @@ namespace ngraph ...@@ -65,7 +66,6 @@ namespace ngraph
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
void validate_and_infer_types() override;
Coordinate m_lower_bounds; Coordinate m_lower_bounds;
Coordinate m_upper_bounds; Coordinate m_upper_bounds;
......
...@@ -52,12 +52,12 @@ namespace ngraph ...@@ -52,12 +52,12 @@ namespace ngraph
void set_reduction_axis(size_t value); void set_reduction_axis(size_t value);
element::Type get_index_element_type() const; element::Type get_index_element_type() const;
void set_index_element_type(const element::Type& index_element_type); void set_index_element_type(const element::Type& index_element_type);
void validate_and_infer_types() override;
protected: protected:
size_t m_axis{0}; size_t m_axis{0};
element::Type m_index_element_type; element::Type m_index_element_type;
void validate_and_infer_types() override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
}; };
......
...@@ -78,6 +78,7 @@ namespace ngraph ...@@ -78,6 +78,7 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
/// \return The inclusive lower-bound coordinates. /// \return The inclusive lower-bound coordinates.
const Coordinate& get_lower_bounds() const { return m_lower_bounds; } const Coordinate& get_lower_bounds() const { return m_lower_bounds; }
...@@ -86,8 +87,6 @@ namespace ngraph ...@@ -86,8 +87,6 @@ namespace ngraph
/// \return The slicing strides. /// \return The slicing strides.
const Strides& get_strides() const { return m_strides; } const Strides& get_strides() const { return m_strides; }
protected: protected:
void validate_and_infer_types() override;
Coordinate m_lower_bounds; Coordinate m_lower_bounds;
Coordinate m_upper_bounds; Coordinate m_upper_bounds;
Strides m_strides; Strides m_strides;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment