Unverified Commit 840bf1a3 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge branch 'master' into mlir

parents c1b08b42 a073c39e
...@@ -33,9 +33,26 @@ descriptor::Input::Input(Node* node, size_t index, Output& output) ...@@ -33,9 +33,26 @@ descriptor::Input::Input(Node* node, size_t index, Output& output)
output.add_input(this); output.add_input(this);
} }
descriptor::Input::Input(Node* node, size_t index)
: m_node(node)
, m_index(index)
, m_output(nullptr)
, m_is_relevant_to_shape(false)
, m_is_relevant_to_value(true)
{
}
descriptor::Input::~Input()
{
remove_output();
}
void descriptor::Input::replace_output(Output& new_output) void descriptor::Input::replace_output(Output& new_output)
{ {
m_output->remove_input(this); if (m_output != nullptr)
{
m_output->remove_input(this);
}
new_output.add_input(this); new_output.add_input(this);
m_output = &new_output; m_output = &new_output;
m_src_node = std::shared_ptr<Node>(new_output.get_node()); m_src_node = std::shared_ptr<Node>(new_output.get_node());
...@@ -56,6 +73,16 @@ void descriptor::Input::replace_output(std::shared_ptr<Node> node, size_t i) ...@@ -56,6 +73,16 @@ void descriptor::Input::replace_output(std::shared_ptr<Node> node, size_t i)
replace_output(node->m_outputs.at(i)); replace_output(node->m_outputs.at(i));
} }
void descriptor::Input::remove_output()
{
if (m_output != nullptr)
{
m_output->remove_input(this);
m_src_node = nullptr;
m_output = nullptr;
}
}
std::shared_ptr<Node> descriptor::Input::get_node() const std::shared_ptr<Node> descriptor::Input::get_node() const
{ {
return m_node->shared_from_this(); return m_node->shared_from_this();
......
...@@ -38,6 +38,11 @@ namespace ngraph ...@@ -38,6 +38,11 @@ namespace ngraph
/// \param index The position of this this tensor in all input tensors /// \param index The position of this this tensor in all input tensors
/// \param output The output that supplies a value for this input /// \param output The output that supplies a value for this input
Input(Node* node, size_t index, Output& output); Input(Node* node, size_t index, Output& output);
/// \brief Create an Input that is not connected to an output
/// \param node The node that owns this input
/// \param index The position of this this tensor in all input tensors
Input(Node* node, size_t index);
~Input();
/// \return the node that this is an input of /// \return the node that this is an input of
std::shared_ptr<Node> get_node() const; std::shared_ptr<Node> get_node() const;
...@@ -50,14 +55,20 @@ namespace ngraph ...@@ -50,14 +55,20 @@ namespace ngraph
const Output& get_output() const { return *m_output; } const Output& get_output() const { return *m_output; }
/// \return the connected output /// \return the connected output
Output& get_output() { return *m_output; } Output& get_output() { return *m_output; }
/// \return true if an output is connected to the input.
bool has_output() const { return m_output != nullptr; }
/// \return the tensor of the connected output /// \return the tensor of the connected output
const Tensor& get_tensor() const; const Tensor& get_tensor() const;
/// \return the tensor of the connected output /// \return the tensor of the connected output
Tensor& get_tensor(); Tensor& get_tensor();
/// \brief Replace the current output that supplies a value for this input with output i of node
void replace_output(std::shared_ptr<Node> node, size_t i); void replace_output(std::shared_ptr<Node> node, size_t i);
/// \brief Replace the current output that supplies a value for this input with output
void replace_output(Output& output); void replace_output(Output& output);
/// \brief Remove the output from this input. The node will not be valid until another output is supplied.
void remove_output();
/// \return true if the value of this input is relevant to the output shapes of the /// \return true if the value of this input is relevant to the output shapes of the
/// corresponding node. (Usually this is false.) /// corresponding node. (Usually this is false.)
......
...@@ -31,6 +31,18 @@ descriptor::Tensor::Tensor(const element::Type& element_type, ...@@ -31,6 +31,18 @@ descriptor::Tensor::Tensor(const element::Type& element_type,
{ {
} }
descriptor::Tensor::Tensor(const element::Type& element_type,
const PartialShape& pshape,
Node* node,
size_t node_output_number)
: m_element_type(element_type)
, m_shape(pshape.is_static() ? pshape.to_shape() : Shape{})
, m_partial_shape(pshape)
, m_node(node)
, m_node_output_number(node_output_number)
{
}
void descriptor::Tensor::set_tensor_type(const element::Type& element_type, void descriptor::Tensor::set_tensor_type(const element::Type& element_type,
const PartialShape& pshape) const PartialShape& pshape)
{ {
...@@ -95,6 +107,16 @@ void descriptor::Tensor::set_tensor_layout( ...@@ -95,6 +107,16 @@ void descriptor::Tensor::set_tensor_layout(
m_tensor_layout = tensor_layout; m_tensor_layout = tensor_layout;
} }
const std::string& descriptor::Tensor::get_name() const
{
if (m_name.empty() && m_node != nullptr)
{
const_cast<Tensor*>(this)->m_name =
m_node->get_name() + "_" + to_string(m_node_output_number);
}
return m_name;
}
ostream& operator<<(ostream& out, const descriptor::Tensor& tensor) ostream& operator<<(ostream& out, const descriptor::Tensor& tensor)
{ {
out << "Tensor(" << tensor.get_name() << ")"; out << "Tensor(" << tensor.get_name() << ")";
......
...@@ -45,8 +45,12 @@ namespace ngraph ...@@ -45,8 +45,12 @@ namespace ngraph
Tensor(const element::Type& element_type, Tensor(const element::Type& element_type,
const PartialShape& pshape, const PartialShape& pshape,
const std::string& name); const std::string& name);
Tensor(const element::Type& element_type,
const PartialShape& pshape,
Node* node,
size_t node_output_number);
const std::string& get_name() const { return m_name; } const std::string& get_name() const;
void set_tensor_type(const element::Type& element_type, const PartialShape& pshape); void set_tensor_type(const element::Type& element_type, const PartialShape& pshape);
const element::Type& get_element_type() const { return m_element_type; } const element::Type& get_element_type() const { return m_element_type; }
...@@ -73,6 +77,8 @@ namespace ngraph ...@@ -73,6 +77,8 @@ namespace ngraph
// should refactor so that get_shape returns by value. // should refactor so that get_shape returns by value.
Shape m_shape; Shape m_shape;
PartialShape m_partial_shape; PartialShape m_partial_shape;
Node* m_node{nullptr};
size_t m_node_output_number{0};
std::string m_name; std::string m_name;
std::shared_ptr<layout::TensorLayout> m_tensor_layout; std::shared_ptr<layout::TensorLayout> m_tensor_layout;
......
...@@ -33,21 +33,131 @@ using namespace ngraph; ...@@ -33,21 +33,131 @@ using namespace ngraph;
atomic<size_t> Node::m_next_instance_id(0); atomic<size_t> Node::m_next_instance_id(0);
Node::Node(size_t output_size)
: Node()
{
set_output_size(output_size);
}
Node::Node(const std::string& node_type, const NodeVector& arguments, size_t output_size) Node::Node(const std::string& node_type, const NodeVector& arguments, size_t output_size)
: m_node_type(node_type) : m_node_type(node_type)
, m_instance_id(m_next_instance_id.fetch_add(1))
, m_unique_name(description() + "_" + to_string(m_instance_id))
{ {
// Add this node as a user of each argument. set_arguments(arguments);
size_t i = 0; set_output_size(output_size);
}
Node::Node(const NodeVector& arguments, size_t output_size)
: Node()
{
set_arguments(arguments);
set_output_size(output_size);
}
Node::Node(const OutputVector& arguments, size_t output_size)
: Node()
{
set_arguments(arguments);
set_output_size(output_size);
}
Node::~Node()
{
for (descriptor::Input& input : m_inputs)
{
if (input.has_output())
{
// This test adds 1 to the actual count, so a count of 2 means this input is the only reference to the node.
if (input.get_output().get_node().use_count() == 2)
{
// Don't want to trigger a deep recursive delete
NodeVector nodes{input.get_output().get_node()};
input.remove_output();
safe_delete(nodes, true);
return;
}
input.remove_output();
}
}
}
void Node::safe_delete(NodeVector& nodes, bool recurse)
{
for (auto& input : m_inputs)
{
if (input.has_output())
{
// This test adds 1 to the actual count, so a count of 2 means this input is the only reference to the node.
auto node = input.get_output().get_node();
if (node.use_count() == 2)
{
// Move the node from the input to nodes so we don't trigger a deep recursive delete
nodes.push_back(node);
}
input.remove_output();
}
}
if (recurse)
{
while (nodes.size() > 0)
{
auto node = nodes.back();
nodes.pop_back();
node->safe_delete(nodes, false);
}
}
}
void Node::set_arguments(const NodeVector& arguments)
{
OutputVector outputs;
for (auto arg : arguments) for (auto arg : arguments)
{ {
for (descriptor::Output& output : arg->m_outputs) for (auto& output : arg->outputs())
{ {
m_inputs.emplace_back(this, i++, output); outputs.push_back(output);
} }
} }
set_output_size(output_size); set_arguments(outputs);
}
void Node::set_arguments(const OutputVector& arguments)
{
// Add this node as a user of each argument.
size_t i = 0;
for (auto& output : arguments)
{
auto output_node = output.get_node();
auto& output_descriptor = output_node->get_outputs().at(output.get_index());
m_inputs.emplace_back(this, i++, output_descriptor);
}
}
descriptor::Input& Node::get_input_descriptor(size_t position)
{
while (m_inputs.size() <= position)
{
m_inputs.emplace_back(this, m_inputs.size());
}
return m_inputs.at(position);
}
descriptor::Output& Node::get_output_descriptor(size_t position)
{
while (m_outputs.size() <= position)
{
size_t i = m_outputs.size();
auto tensor_descriptor =
make_shared<descriptor::Tensor>(element::dynamic, PartialShape::dynamic(), this, i);
m_outputs.emplace_back(this, i, tensor_descriptor);
}
return m_outputs.at(position);
}
void Node::set_argument(size_t position, const Output<Node>& argument)
{
auto output_node = argument.get_node();
auto& output_descriptor = output_node->get_output_descriptor(argument.get_index());
get_input_descriptor(position).replace_output(output_descriptor);
} }
// While we are still doing validation and type inference in the constructor, this is true // While we are still doing validation and type inference in the constructor, this is true
...@@ -75,9 +185,8 @@ void Node::set_output_size(size_t n) ...@@ -75,9 +185,8 @@ void Node::set_output_size(size_t n)
NGRAPH_CHECK(n >= m_outputs.size(), "shrinking ", m_outputs.size(), " to ", n); NGRAPH_CHECK(n >= m_outputs.size(), "shrinking ", m_outputs.size(), " to ", n);
for (size_t i = m_outputs.size(); i < n; ++i) for (size_t i = m_outputs.size(); i < n; ++i)
{ {
auto tensor_descriptor = make_shared<descriptor::Tensor>( // create the descriptors
element::dynamic, PartialShape::dynamic(), get_name() + "_" + to_string(i)); get_output_descriptor(i);
m_outputs.emplace_back(this, i, tensor_descriptor);
} }
} }
...@@ -97,7 +206,7 @@ void Node::set_input_is_relevant_to_value(size_t i, bool relevant) ...@@ -97,7 +206,7 @@ void Node::set_input_is_relevant_to_value(size_t i, bool relevant)
void Node::set_output_type(size_t i, const element::Type& element_type, const PartialShape& pshape) void Node::set_output_type(size_t i, const element::Type& element_type, const PartialShape& pshape)
{ {
m_outputs.at(i).get_tensor_ptr()->set_tensor_type(element_type, pshape); get_output_descriptor(i).get_tensor_ptr()->set_tensor_type(element_type, pshape);
} }
std::deque<descriptor::Output>& Node::get_outputs() std::deque<descriptor::Output>& Node::get_outputs()
...@@ -134,13 +243,17 @@ const std::string& Node::get_friendly_name() const ...@@ -134,13 +243,17 @@ const std::string& Node::get_friendly_name() const
{ {
if (m_friendly_name.empty()) if (m_friendly_name.empty())
{ {
return m_unique_name; return get_name();
} }
return m_friendly_name; return m_friendly_name;
} }
const std::string& Node::get_name() const const std::string& Node::get_name() const
{ {
if (m_unique_name.empty())
{
const_cast<Node*>(this)->m_unique_name = description() + "_" + to_string(m_instance_id);
}
return m_unique_name; return m_unique_name;
} }
...@@ -204,14 +317,6 @@ std::shared_ptr<Node> Node::get_argument(size_t index) const ...@@ -204,14 +317,6 @@ std::shared_ptr<Node> Node::get_argument(size_t index) const
return m_inputs.at(index).get_output().get_node(); return m_inputs.at(index).get_output().get_node();
} }
Node::~Node()
{
for (auto& input : m_inputs)
{
input.get_output().remove_input(&input);
}
}
NodeVector Node::get_arguments() const NodeVector Node::get_arguments() const
{ {
NodeVector result; NodeVector result;
...@@ -442,18 +547,6 @@ std::string ngraph::node_validation_failure_loc_string(const Node* node) ...@@ -442,18 +547,6 @@ std::string ngraph::node_validation_failure_loc_string(const Node* node)
return ss.str(); return ss.str();
} }
void ngraph::check_new_args_count(const Node* node, const NodeVector& new_args)
{
NODE_VALIDATION_CHECK(node,
new_args.size() == node->get_arguments().size(),
"copy_with_new_args() expected ",
node->get_arguments().size(),
" argument",
(node->get_arguments().size() == 1 ? "" : "s"),
" but got ",
new_args.size());
}
const std::shared_ptr<Node>& ngraph::check_single_output_arg(const std::shared_ptr<Node>& node, const std::shared_ptr<Node>& ngraph::check_single_output_arg(const std::shared_ptr<Node>& node,
size_t i) size_t i)
{ {
......
...@@ -49,6 +49,7 @@ namespace ngraph ...@@ -49,6 +49,7 @@ namespace ngraph
class Node; class Node;
using NodeVector = std::vector<std::shared_ptr<Node>>; using NodeVector = std::vector<std::shared_ptr<Node>>;
using OutputVector = std::vector<Output<Node>>;
class Function; class Function;
...@@ -108,11 +109,42 @@ namespace ngraph ...@@ -108,11 +109,42 @@ namespace ngraph
void validate_and_infer_elementwise_logical( void validate_and_infer_elementwise_logical(
const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec()); const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec());
/// \brief Construct an unitialized Node
Node() {}
/// \brief Construct an unitialized Node
/// \param output_size Number of outputs for this node
Node(size_t output_size);
/// \brief Constructor for Node subclasses that have metaclasses.
/// \param arguments Output i will connect to input i
/// \param output_size Number of outputs for this node
Node(const OutputVector& arguments, size_t output_size = 1);
/// \brief Construct a node with arguments. Will be deprecated.
Node(const std::string& node_type, const NodeVector& arguments, size_t output_size = 1); Node(const std::string& node_type, const NodeVector& arguments, size_t output_size = 1);
/// \brief Constructor for Node subclasses that have metaclasses. Will be deprecated.
/// \param arguments The 0th output of node i will connect to input i
/// \param output_size Number of outputs for this node
Node(const NodeVector& arguments, size_t output_size = 1);
virtual void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) {} virtual void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) {}
/// \brief Moves nodes that would be deleted from inputs to nodes to avoid stack overflows on deep networks.
void safe_delete(NodeVector& nodes, bool recurse);
public: public:
virtual ~Node(); virtual ~Node();
/// Sets/replaces the arguments with new arguments.
void set_arguments(const NodeVector& arguments);
/// Sets/replaces the arguments with new arguments.
void set_arguments(const OutputVector& arguments);
/// Sets/replaces the arguments with new arguments.
void set_argument(size_t position, const Output<Node>& argument);
/// Sets the number of outputs
void set_output_size(size_t output_size);
void revalidate_and_infer_types() { validate_and_infer_types(); } void revalidate_and_infer_types() { validate_and_infer_types(); }
// Called after transition // Called after transition
void delayed_validate_and_infer_types(); void delayed_validate_and_infer_types();
...@@ -120,8 +152,7 @@ namespace ngraph ...@@ -120,8 +152,7 @@ namespace ngraph
/// \brief Get the string name for the type of the node, such as `Add` or `Multiply`. /// \brief Get the string name for the type of the node, such as `Add` or `Multiply`.
/// The class name, must not contain spaces as it is used for codegen. /// The class name, must not contain spaces as it is used for codegen.
/// \returns A const reference to the node's type name /// \returns A const reference to the node's type name
const std::string& description() const; virtual const std::string& description() const;
/// \brief Get the unique name of the node. /// \brief Get the unique name of the node.
/// \returns A const reference to the node's unique name. /// \returns A const reference to the node's unique name.
const std::string& get_name() const; const std::string& get_name() const;
...@@ -285,10 +316,11 @@ namespace ngraph ...@@ -285,10 +316,11 @@ namespace ngraph
std::unordered_set<descriptor::Tensor*> liveness_new_list; std::unordered_set<descriptor::Tensor*> liveness_new_list;
std::unordered_set<descriptor::Tensor*> liveness_free_list; std::unordered_set<descriptor::Tensor*> liveness_free_list;
// Will be deprecated
virtual NodeVector get_arguments() const; virtual NodeVector get_arguments() const;
// Will be deprecated
std::shared_ptr<Node> get_argument(size_t index) const; std::shared_ptr<Node> get_argument(size_t index) const;
// Will be replaced with an OutputVector version
virtual std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const = 0; virtual std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const = 0;
virtual std::vector<std::shared_ptr<Function>> get_functions() const; virtual std::vector<std::shared_ptr<Function>> get_functions() const;
...@@ -353,16 +385,15 @@ namespace ngraph ...@@ -353,16 +385,15 @@ namespace ngraph
/// \throw std::out_of_range if the node does not have at least `output_index+1` outputs. /// \throw std::out_of_range if the node does not have at least `output_index+1` outputs.
Output<const Node> output(size_t output_index) const; Output<const Node> output(size_t output_index) const;
protected:
void set_output_size(size_t n);
private: private:
std::set<std::shared_ptr<Node>> m_control_dependencies; descriptor::Input& get_input_descriptor(size_t position);
descriptor::Output& get_output_descriptor(size_t position);
std::set<std::shared_ptr<Node>> m_control_dependencies;
const std::string m_node_type; const std::string m_node_type;
size_t m_instance_id; size_t m_instance_id{m_next_instance_id.fetch_add(1)};
std::string m_friendly_name; std::string m_friendly_name;
const std::string m_unique_name; std::string m_unique_name;
static std::atomic<size_t> m_next_instance_id; static std::atomic<size_t> m_next_instance_id;
std::unordered_set<std::string> m_provenance_tags; std::unordered_set<std::string> m_provenance_tags;
std::deque<descriptor::Input> m_inputs; std::deque<descriptor::Input> m_inputs;
...@@ -409,6 +440,11 @@ namespace ngraph ...@@ -409,6 +440,11 @@ namespace ngraph
{ {
return m_node->m_inputs.at(m_index).get_output().get_tensor(); return m_node->m_inputs.at(m_index).get_output().get_tensor();
} }
/// \return A shared pointer to the tensor descriptor for this input.
std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const
{
return m_node->m_inputs.at(m_index).get_output().get_tensor_ptr();
}
/// \return true if this input is relevant to its node's output shapes; else false. /// \return true if this input is relevant to its node's output shapes; else false.
bool get_is_relevant_to_shapes() const bool get_is_relevant_to_shapes() const
{ {
...@@ -490,6 +526,11 @@ namespace ngraph ...@@ -490,6 +526,11 @@ namespace ngraph
{ {
return m_node->m_outputs.at(m_index).get_tensor(); return m_node->m_outputs.at(m_index).get_tensor();
} }
/// \return A shared point to the tensor ptr for this output.
std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const
{
return m_node->m_outputs.at(m_index).get_tensor_ptr();
}
/// \return The element type of the output referred to by this output handle. /// \return The element type of the output referred to by this output handle.
const element::Type& get_element_type() const const element::Type& get_element_type() const
{ {
...@@ -684,9 +725,23 @@ namespace ngraph ...@@ -684,9 +725,23 @@ namespace ngraph
const Node& m_node; const Node& m_node;
bool m_is_short; bool m_is_short;
}; };
}
void check_new_args_count(const Node* node, const NodeVector& new_args);
} // namespace ngraph
#define NODE_VALIDATION_CHECK(node, cond, ...) \ #define NODE_VALIDATION_CHECK(node, cond, ...) \
NGRAPH_CHECK_HELPER(::ngraph::NodeValidationFailure, (node), (cond), ##__VA_ARGS__) NGRAPH_CHECK_HELPER(::ngraph::NodeValidationFailure, (node), (cond), ##__VA_ARGS__)
namespace ngraph
{
template <typename T>
void check_new_args_count(const Node* node, T new_args)
{
NODE_VALIDATION_CHECK(node,
new_args.size() == node->get_arguments().size(),
"copy_with_new_args() expected ",
node->get_arguments().size(),
" argument",
(node->get_arguments().size() == 1 ? "" : "s"),
" but got ",
new_args.size());
}
} // namespace ngraph
...@@ -21,8 +21,14 @@ ...@@ -21,8 +21,14 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Abs::Abs(const shared_ptr<Node>& arg) const string op::Abs::type_name{"Abs"};
: UnaryElementwiseArithmetic("Abs", arg)
op::Abs::Abs()
{
}
op::Abs::Abs(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -29,18 +29,22 @@ namespace ngraph ...@@ -29,18 +29,22 @@ namespace ngraph
class Abs : public util::UnaryElementwiseArithmetic class Abs : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an absolute value operation.
Abs();
/// \brief Constructs an absolute value operation. /// \brief Constructs an absolute value operation.
/// ///
/// \param arg Node that produces the input tensor.<br> /// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]` /// `[d1, ...]`
/// ///
/// Output `[d1, ...]` /// Output `[d1, ...]`
/// ///
Abs(const std::shared_ptr<Node>& arg); Abs(const Output<Node>& arg);
Abs(const op::Abs& other, const NodeVector& new_args);
virtual std::shared_ptr<Node> std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
copy_with_new_args(const NodeVector& new_args) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -32,8 +32,14 @@ ...@@ -32,8 +32,14 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Acos::Acos(const shared_ptr<Node>& arg) const string op::Acos::type_name{"Acos"};
: UnaryElementwiseArithmetic("Acos", arg)
op::Acos::Acos()
{
}
op::Acos::Acos(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -29,17 +29,21 @@ namespace ngraph ...@@ -29,17 +29,21 @@ namespace ngraph
class Acos : public util::UnaryElementwiseArithmetic class Acos : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an arccos operation.
Acos();
/// \brief Constructs an arccos operation. /// \brief Constructs an arccos operation.
/// ///
/// \param arg Node that produces the input tensor.<br> /// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]` /// `[d1, ...]`
/// ///
/// Output `[d1, ...]` /// Output `[d1, ...]`
/// ///
Acos(const std::shared_ptr<Node>& arg); Acos(const Output<Node>& arg);
virtual std::shared_ptr<Node> std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
copy_with_new_args(const NodeVector& new_args) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -19,10 +19,14 @@ ...@@ -19,10 +19,14 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Add::Add(const shared_ptr<Node>& arg0, const string op::Add::type_name{"Add"};
const shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob) op::Add::Add()
: BinaryElementwiseArithmetic("Add", arg0, arg1, autob) {
}
op::Add::Add(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -29,22 +29,27 @@ namespace ngraph ...@@ -29,22 +29,27 @@ namespace ngraph
class Add : public util::BinaryElementwiseArithmetic class Add : public util::BinaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an unitialized addition operation
Add();
/// \brief Constructs an addition operation. /// \brief Constructs an addition operation.
/// ///
/// \param arg0 Node that produces the first input tensor.<br> /// \param arg0 Output that produces the first input tensor.<br>
/// `[d0, ...]` /// `[d0, ...]`
/// \param arg1 Node that produces the second input tensor.<br> /// \param arg1 Output that produces the second input tensor.<br>
/// `[d0, ...]` /// `[d0, ...]`
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
/// ///
/// Output `[d0, ...]` /// Output `[d0, ...]`
/// ///
Add(const std::shared_ptr<Node>& arg0, Add(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
copy_with_new_args(const NodeVector& new_args) const override;
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -19,8 +19,14 @@ ...@@ -19,8 +19,14 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::All::All(const shared_ptr<Node>& arg, const AxisSet& reduction_axes) const string op::All::type_name{"All"};
: LogicalReduction("All", arg, reduction_axes)
op::All::All()
{
}
op::All::All(const Output<Node>& arg, const AxisSet& reduction_axes)
: LogicalReduction(arg, reduction_axes)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -29,14 +29,18 @@ namespace ngraph ...@@ -29,14 +29,18 @@ namespace ngraph
class All : public util::LogicalReduction class All : public util::LogicalReduction
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an "all" reduction operation.
All();
/// \brief Constructs an "all" reduction operation. /// \brief Constructs an "all" reduction operation.
/// ///
/// \param arg The tensor to be reduced. /// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated. /// \param reduction_axes The axis positions (0-based) to be eliminated.
All(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes); All(const Output<Node>& arg, const AxisSet& reduction_axes);
virtual std::shared_ptr<Node> std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
copy_with_new_args(const NodeVector& new_args) const override;
/// \return The default value for All. /// \return The default value for All.
virtual std::shared_ptr<Node> get_default_value() const override virtual std::shared_ptr<Node> get_default_value() const override
......
...@@ -19,8 +19,14 @@ ...@@ -19,8 +19,14 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::AllReduce::type_name{"AllReduce"};
op::AllReduce::AllReduce()
{
}
op::AllReduce::AllReduce(const shared_ptr<Node>& arg) op::AllReduce::AllReduce(const shared_ptr<Node>& arg)
: Op("AllReduce", check_single_output_args({arg})) : Op(check_single_output_args({arg}))
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,12 +26,15 @@ namespace ngraph ...@@ -26,12 +26,15 @@ namespace ngraph
class AllReduce : public Op class AllReduce : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
AllReduce();
AllReduce(const std::shared_ptr<Node>& arg); AllReduce(const std::shared_ptr<Node>& arg);
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
copy_with_new_args(const NodeVector& new_args) const override;
}; };
} }
} }
...@@ -19,10 +19,14 @@ ...@@ -19,10 +19,14 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::And::And(const shared_ptr<Node>& arg0, const string op::And::type_name{"And"};
const shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob) op::And::And()
: BinaryElementwiseLogical("And", arg0, arg1, autob) {
}
op::And::And(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
: BinaryElementwiseLogical(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -29,22 +29,27 @@ namespace ngraph ...@@ -29,22 +29,27 @@ namespace ngraph
class And : public util::BinaryElementwiseLogical class And : public util::BinaryElementwiseLogical
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a logical-and operation.
And();
/// \brief Constructs a logical-and operation. /// \brief Constructs a logical-and operation.
/// ///
/// \param arg0 Node that produces the first input tensor.<br> /// \param arg0 Output that produces the first input tensor.<br>
/// `[d0, ...]` /// `[d0, ...]`
/// \param arg1 Node that produces the second input tensor.<br> /// \param arg1 Output that produces the second input tensor.<br>
/// `[d0, ...]` /// `[d0, ...]`
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
/// ///
/// Output `[d0, ...]` /// Output `[d0, ...]`
/// ///
And(const std::shared_ptr<Node>& arg0, And(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
copy_with_new_args(const NodeVector& new_args) const override;
protected: protected:
virtual bool is_commutative() override { return true; } virtual bool is_commutative() override { return true; }
......
...@@ -19,8 +19,14 @@ ...@@ -19,8 +19,14 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Any::Any(const shared_ptr<Node>& arg, const AxisSet& reduction_axes) const string op::Any::type_name{"Any"};
: LogicalReduction("Any", arg, reduction_axes)
op::Any::Any()
{
}
op::Any::Any(const Output<Node>& arg, const AxisSet& reduction_axes)
: LogicalReduction(arg, reduction_axes)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -29,11 +29,16 @@ namespace ngraph ...@@ -29,11 +29,16 @@ namespace ngraph
class Any : public util::LogicalReduction class Any : public util::LogicalReduction
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an "any" reduction operation.
Any();
/// \brief Constructs an "any" reduction operation. /// \brief Constructs an "any" reduction operation.
/// ///
/// \param arg The tensor to be reduced. /// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated. /// \param reduction_axes The axis positions (0-based) to be eliminated.
Any(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes); Any(const Output<Node>& arg, const AxisSet& reduction_axes);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -19,8 +19,21 @@ ...@@ -19,8 +19,21 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::ArgMax::type_name{"ArgMax"};
op::ArgMax::ArgMax()
{
}
op::ArgMax::ArgMax(const Output<Node>& arg, size_t axis, const element::Type& index_element_type)
: op::util::IndexReduction(arg, axis, index_element_type)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::ArgMax::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ArgMax::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args);
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<ArgMax>(new_args.at(0), m_axis, this->get_element_type()); return make_shared<ArgMax>(new_args.at(0), m_axis, this->get_element_type());
} }
...@@ -28,17 +28,17 @@ namespace ngraph ...@@ -28,17 +28,17 @@ namespace ngraph
class ArgMax : public op::util::IndexReduction class ArgMax : public op::util::IndexReduction
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a ArgMax operation.
ArgMax();
/// \brief Constructs a ArgMax operation. /// \brief Constructs a ArgMax operation.
/// ///
/// \param arg The input tensor /// \param arg The input tensor
/// \param axis The axis along which to compute an index for maximum /// \param axis The axis along which to compute an index for maximum
/// \param index_element_type produce indices. Currently, only int64 or int32 are supported /// \param index_element_type produce indices. Currently, only int64 or int32 are supported
ArgMax(const std::shared_ptr<Node>& arg, ArgMax(const Output<Node>& arg, size_t axis, const element::Type& index_element_type);
size_t axis,
const element::Type& index_element_type)
: IndexReduction("ArgMax", arg, axis, index_element_type)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -19,6 +19,18 @@ ...@@ -19,6 +19,18 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::ArgMin::type_name{"ArgMin"};
op::ArgMin::ArgMin()
{
}
op::ArgMin::ArgMin(const Output<Node>& arg, size_t axis, const element::Type& index_element_type)
: op::util::IndexReduction(arg, axis, index_element_type)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::ArgMin::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ArgMin::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -28,17 +28,18 @@ namespace ngraph ...@@ -28,17 +28,18 @@ namespace ngraph
class ArgMin : public op::util::IndexReduction class ArgMin : public op::util::IndexReduction
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a ArgMin operation.
ArgMin();
/// \brief Constructs a ArgMin operation. /// \brief Constructs a ArgMin operation.
/// ///
/// \param arg The input tensor /// \param arg The input tensor
/// \param axis The axis along which to compute an index for minimum /// \param axis The axis along which to compute an index for minimum
/// \param index_element_type produce indices. Currently, only int64 or int32 are supported /// \param index_element_type produce indices. Currently, only int64 or int32 are supported
ArgMin(const std::shared_ptr<Node>& arg, ArgMin(const Output<Node>& arg, size_t axis, const element::Type& index_element_type);
size_t axis,
const element::Type& index_element_type)
: IndexReduction("ArgMin", arg, axis, index_element_type)
{
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -31,8 +31,14 @@ ...@@ -31,8 +31,14 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Asin::Asin(const shared_ptr<Node>& arg) const string op::Asin::type_name{"Asin"};
: UnaryElementwiseArithmetic("Asin", arg)
op::Asin::Asin()
{
}
op::Asin::Asin(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -29,14 +29,19 @@ namespace ngraph ...@@ -29,14 +29,19 @@ namespace ngraph
class Asin : public util::UnaryElementwiseArithmetic class Asin : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an arcsin operation.
Asin();
/// \brief Constructs an arcsin operation. /// \brief Constructs an arcsin operation.
/// ///
/// \param arg Node that produces the input tensor.<br> /// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]` /// `[d1, ...]`
/// ///
/// Output `[d1, ...]` /// Output `[d1, ...]`
/// ///
Asin(const std::shared_ptr<Node>& arg); Asin(const Output<Node>& arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -30,8 +30,14 @@ ...@@ -30,8 +30,14 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Atan::Atan(const shared_ptr<Node>& arg) const string op::Atan::type_name{"Atan"};
: UnaryElementwiseArithmetic("Atan", arg)
op::Atan::Atan()
{
}
op::Atan::Atan(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -29,14 +29,20 @@ namespace ngraph ...@@ -29,14 +29,20 @@ namespace ngraph
class Atan : public util::UnaryElementwiseArithmetic class Atan : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an arctan operation.
Atan();
/// \brief Constructs an arctan operation. /// \brief Constructs an arctan operation.
/// ///
/// \param arg Node that produces the input tensor.<br> /// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]` /// `[d1, ...]`
/// ///
/// Output `[d1, ...]` /// Output `[d1, ...]`
/// ///
Atan(const std::shared_ptr<Node>& arg); Atan(const Output<Node>& arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -21,14 +21,20 @@ ...@@ -21,14 +21,20 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::AvgPool::AvgPool(const shared_ptr<Node>& arg, const string op::AvgPool::type_name{"AvgPool"};
op::AvgPool::AvgPool()
{
}
op::AvgPool::AvgPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above, const Shape& padding_above,
bool include_padding_in_avg_computation, bool include_padding_in_avg_computation,
const PadType& pad_type) const PadType& pad_type)
: Op("AvgPool", check_single_output_args({arg})) : Op({arg})
, m_window_shape(window_shape) , m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
...@@ -91,18 +97,78 @@ void op::AvgPool::validate_and_infer_types() ...@@ -91,18 +97,78 @@ void op::AvgPool::validate_and_infer_types()
m_include_padding_in_avg_computation)); m_include_padding_in_avg_computation));
} }
op::AvgPool::AvgPool(const shared_ptr<Node>& arg, op::AvgPool::AvgPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides) const Strides& window_movement_strides)
: AvgPool(arg, window_shape, window_movement_strides, Shape(), Shape(), false) : AvgPool(arg, window_shape, window_movement_strides, Shape(), Shape(), false)
{ {
} }
op::AvgPool::AvgPool(const shared_ptr<Node>& arg, const Shape& window_shape) op::AvgPool::AvgPool(const Output<Node>& arg, const Shape& window_shape)
: AvgPool(arg, window_shape, Strides(), Shape(), Shape(), false) : AvgPool(arg, window_shape, Strides(), Shape(), Shape(), false)
{ {
} }
const Shape& op::AvgPool::get_window_shape() const
{
return m_window_shape;
}
void op::AvgPool::set_window_shape(const Shape& window_shape)
{
m_window_shape = window_shape;
}
const Strides& op::AvgPool::get_window_movement_strides() const
{
return m_window_movement_strides;
}
void op::AvgPool::set_window_movement_strides(const Strides& window_movement_strides)
{
m_window_movement_strides = window_movement_strides;
}
const Shape& op::AvgPool::get_padding_below() const
{
return m_padding_below;
}
void op::AvgPool::set_padding_below(const Shape& padding_below)
{
m_padding_below = padding_below;
}
const Shape& op::AvgPool::get_padding_above() const
{
return m_padding_above;
}
void op::AvgPool::set_padding_above(const Shape& padding_above)
{
m_padding_above = padding_above;
}
bool op::AvgPool::get_include_padding_in_avg_computation() const
{
return m_include_padding_in_avg_computation;
}
void op::AvgPool::get_include_padding_in_avg_computation(bool include_padding_in_avg_computation)
{
m_include_padding_in_avg_computation = include_padding_in_avg_computation;
}
const op::PadType& op::AvgPool::get_pad_type() const
{
return m_pad_type;
}
void op::AvgPool::set_pad_type(const op::PadType& pad_type)
{
m_pad_type = pad_type;
}
shared_ptr<Node> op::AvgPool::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::AvgPool::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
...@@ -114,6 +180,12 @@ shared_ptr<Node> op::AvgPool::copy_with_new_args(const NodeVector& new_args) con ...@@ -114,6 +180,12 @@ shared_ptr<Node> op::AvgPool::copy_with_new_args(const NodeVector& new_args) con
m_include_padding_in_avg_computation); m_include_padding_in_avg_computation);
} }
const string op::AvgPoolBackprop::type_name("AvgPoolBackprop");
op::AvgPoolBackprop::AvgPoolBackprop()
{
}
op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape, op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
const shared_ptr<Node>& delta, const shared_ptr<Node>& delta,
const Shape& window_shape, const Shape& window_shape,
...@@ -121,7 +193,7 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape, ...@@ -121,7 +193,7 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above, const Shape& padding_above,
bool include_padding_in_avg_computation) bool include_padding_in_avg_computation)
: Op("AvgPoolBackprop", check_single_output_args({delta})) : Op(check_single_output_args({delta}))
, m_forward_arg_shape(forward_arg_shape) , m_forward_arg_shape(forward_arg_shape)
, m_window_shape(window_shape) , m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
...@@ -166,6 +238,67 @@ void op::AvgPoolBackprop::validate_and_infer_types() ...@@ -166,6 +238,67 @@ void op::AvgPoolBackprop::validate_and_infer_types()
set_output_type(0, get_input_element_type(0), m_forward_arg_shape); set_output_type(0, get_input_element_type(0), m_forward_arg_shape);
} }
const Shape& op::AvgPoolBackprop::get_forward_arg_shape() const
{
return m_forward_arg_shape;
}
void op::AvgPoolBackprop::set_forward_arg_shape(const Shape& forward_arg_shape)
{
m_forward_arg_shape = forward_arg_shape;
}
const Shape& op::AvgPoolBackprop::get_window_shape() const
{
return m_window_shape;
}
void op::AvgPoolBackprop::set_window_shape(const Shape& window_shape)
{
m_window_shape = window_shape;
}
const Strides& op::AvgPoolBackprop::get_window_movement_strides() const
{
return m_window_movement_strides;
}
void op::AvgPoolBackprop::set_window_movement_strides(const Strides& window_movement_strides)
{
m_window_movement_strides = window_movement_strides;
}
const Shape& op::AvgPoolBackprop::get_padding_below() const
{
return m_padding_below;
}
void op::AvgPoolBackprop::set_padding_below(const Shape& padding_below)
{
m_padding_below = padding_below;
}
const Shape& op::AvgPoolBackprop::get_padding_above() const
{
return m_padding_above;
}
void op::AvgPoolBackprop::set_padding_above(const Shape& padding_above)
{
m_padding_above = padding_above;
}
bool op::AvgPoolBackprop::get_include_padding_in_avg_computation() const
{
return m_include_padding_in_avg_computation;
}
void op::AvgPoolBackprop::set_include_padding_in_avg_computation(
bool include_padding_in_avg_computation)
{
m_include_padding_in_avg_computation = include_padding_in_avg_computation;
}
shared_ptr<Node> op::AvgPoolBackprop::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::AvgPoolBackprop::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -29,9 +29,14 @@ namespace ngraph ...@@ -29,9 +29,14 @@ namespace ngraph
class AvgPool : public Op class AvgPool : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a batched average pooling operation.
AvgPool();
/// \brief Constructs a batched average pooling operation. /// \brief Constructs a batched average pooling operation.
/// ///
/// \param arg The node producing the input data batch tensor.<br> /// \param arg The output producing the input data batch tensor.<br>
/// `[d1, dn]` /// `[d1, dn]`
/// \param window_shape The window shape.<br> /// \param window_shape The window shape.<br>
/// `[n]` /// `[n]`
...@@ -44,7 +49,7 @@ namespace ngraph ...@@ -44,7 +49,7 @@ namespace ngraph
/// \param include_padding_in_avg_computation If true then averages include padding /// \param include_padding_in_avg_computation If true then averages include padding
/// elements, each treated as the number zero. If false, padding elements are entirely /// elements, each treated as the number zero. If false, padding elements are entirely
/// ignored when computing averages. /// ignored when computing averages.
AvgPool(const std::shared_ptr<Node>& arg, AvgPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
...@@ -54,23 +59,23 @@ namespace ngraph ...@@ -54,23 +59,23 @@ namespace ngraph
/// \brief Constructs a batched, unpadded average pooling operation (i.e., all padding shapes are set to 0). /// \brief Constructs a batched, unpadded average pooling operation (i.e., all padding shapes are set to 0).
/// ///
/// \param arg The node producing the input data batch tensor.<br> /// \param arg The output producing the input data batch tensor.<br>
/// `[d1, ..., dn]` /// `[d1, ..., dn]`
/// \param window_shape The window shape.<br> /// \param window_shape The window shape.<br>
/// `[n]` /// `[n]`
/// \param window_movement_strides The window movement strides.<br> /// \param window_movement_strides The window movement strides.<br>
/// `[n]` /// `[n]`
AvgPool(const std::shared_ptr<Node>& arg, AvgPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides); const Strides& window_movement_strides);
/// \brief Constructs an unstrided batched convolution operation (i.e., all window movement strides are 1 and all padding shapes are set to 0). /// \brief Constructs an unstrided batched convolution operation (i.e., all window movement strides are 1 and all padding shapes are set to 0).
/// ///
/// \param arg The node producing the input data batch tensor.<br> /// \param arg The output producing the input data batch tensor.<br>
/// `[d1, ..., dn]` /// `[d1, ..., dn]`
/// \param window_shape The window shape.<br> /// \param window_shape The window shape.<br>
/// `[n]` /// `[n]`
AvgPool(const std::shared_ptr<Node>& arg, const Shape& window_shape); AvgPool(const Output<Node>& arg, const Shape& window_shape);
void validate_and_infer_types() override; void validate_and_infer_types() override;
...@@ -81,19 +86,22 @@ namespace ngraph ...@@ -81,19 +86,22 @@ namespace ngraph
const NodeVector& deltas) override; const NodeVector& deltas) override;
/// \return The window shape. /// \return The window shape.
const Shape& get_window_shape() const { return m_window_shape; } const Shape& get_window_shape() const;
void set_window_shape(const Shape& window_shape);
/// \return The window movement strides. /// \return The window movement strides.
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const;
void set_window_movement_strides(const Strides& window_movement_strides);
/// \return The below-padding shape. /// \return The below-padding shape.
const Shape& get_padding_below() const { return m_padding_below; } const Shape& get_padding_below() const;
void set_padding_below(const Shape& padding_below);
/// \return The above-padding shape. /// \return The above-padding shape.
const Shape& get_padding_above() const { return m_padding_above; } const Shape& get_padding_above() const;
bool get_include_padding_in_avg_computation() const void set_padding_above(const Shape& padding_above);
{ bool get_include_padding_in_avg_computation() const;
return m_include_padding_in_avg_computation; void get_include_padding_in_avg_computation(bool include_padding_in_avg_computation);
}
/// \return The pad type for pooling. /// \return The pad type for pooling.
const PadType& get_pad_type() const { return m_pad_type; } const PadType& get_pad_type() const;
void set_pad_type(const PadType& pad_type);
/// \return The default value for AvgPool. /// \return The default value for AvgPool.
virtual std::shared_ptr<Node> get_default_value() const override virtual std::shared_ptr<Node> get_default_value() const override
{ {
...@@ -112,6 +120,9 @@ namespace ngraph ...@@ -112,6 +120,9 @@ namespace ngraph
class AvgPoolBackprop : public Op class AvgPoolBackprop : public Op
{ {
public: public:
static const std::string type_name;
const std::string& description() const override { return type_name; }
AvgPoolBackprop();
AvgPoolBackprop(const Shape& forward_arg_shape, AvgPoolBackprop(const Shape& forward_arg_shape,
const std::shared_ptr<Node>& delta, const std::shared_ptr<Node>& delta,
const Shape& window_shape, const Shape& window_shape,
...@@ -125,15 +136,18 @@ namespace ngraph ...@@ -125,15 +136,18 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
const Shape& get_forward_arg_shape() const { return m_forward_arg_shape; } const Shape& get_forward_arg_shape() const;
const Shape& get_window_shape() const { return m_window_shape; } void set_forward_arg_shape(const Shape& forward_arg_shape);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Shape& get_window_shape() const;
const Shape& get_padding_below() const { return m_padding_below; } void set_window_shape(const Shape& window_shape);
const Shape& get_padding_above() const { return m_padding_above; } const Strides& get_window_movement_strides() const;
bool get_include_padding_in_avg_computation() const void set_window_movement_strides(const Strides& window_movement_strides);
{ const Shape& get_padding_below() const;
return m_include_padding_in_avg_computation; void set_padding_below(const Shape& padding_below);
} const Shape& get_padding_above() const;
void set_padding_above(const Shape& padding_abve);
bool get_include_padding_in_avg_computation() const;
void set_include_padding_in_avg_computation(bool include_padding_in_avg_computation);
protected: protected:
Shape m_forward_arg_shape; Shape m_forward_arg_shape;
......
...@@ -25,6 +25,16 @@ ...@@ -25,6 +25,16 @@
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
op::Op::Op(const NodeVector& args)
: Node(args)
{
}
op::Op::Op(const OutputVector& args)
: Node(args)
{
}
op::Op::Op(const std::string& node_type, const NodeVector& args) op::Op::Op(const std::string& node_type, const NodeVector& args)
: Node(node_type, args) : Node(node_type, args)
{ {
......
...@@ -41,6 +41,12 @@ namespace ngraph ...@@ -41,6 +41,12 @@ namespace ngraph
virtual bool is_op() const override { return true; } virtual bool is_op() const override { return true; }
protected: protected:
Op()
: Node()
{
}
Op(const NodeVector& arguments);
Op(const OutputVector& arguments);
Op(const std::string& node_type, const NodeVector& arguments); Op(const std::string& node_type, const NodeVector& arguments);
private: private:
......
...@@ -19,12 +19,28 @@ ...@@ -19,12 +19,28 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::util::ArithmeticReduction::ArithmeticReduction()
{
}
op::util::ArithmeticReduction::ArithmeticReduction(const std::shared_ptr<Node>& arg,
const AxisSet& reduction_axes)
: Op(check_single_output_args({arg}))
{
set_reduction_axes(reduction_axes);
}
op::util::ArithmeticReduction::ArithmeticReduction(const std::string& node_type, op::util::ArithmeticReduction::ArithmeticReduction(const std::string& node_type,
const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& arg,
const AxisSet& reduction_axes) const AxisSet& reduction_axes)
: Op(node_type, check_single_output_args({arg})) : Op(node_type, check_single_output_args({arg}))
, m_reduction_axes(reduction_axes)
{ {
set_reduction_axes(reduction_axes);
}
void op::util::ArithmeticReduction::set_reduction_axes(const AxisSet& reduction_axes)
{
m_reduction_axes = reduction_axes;
} }
void op::util::ArithmeticReduction::validate_and_infer_types() void op::util::ArithmeticReduction::validate_and_infer_types()
......
...@@ -28,7 +28,23 @@ namespace ngraph ...@@ -28,7 +28,23 @@ namespace ngraph
/// are eliminated (reduced out) by repeated application of a particular binary arithmetic operation. /// are eliminated (reduced out) by repeated application of a particular binary arithmetic operation.
class ArithmeticReduction : public Op class ArithmeticReduction : public Op
{ {
public: protected:
/// \brief Constructs an arithmetic reduction operation.
ArithmeticReduction();
/// \brief Constructs an arithmetic reduction operation.
///
/// \param arg Output that produces the first input tensor.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
ArithmeticReduction(const Output<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs an arithmetic reduction operation.
///
/// \param arg Node that produces the first input tensor.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
ArithmeticReduction(const std::shared_ptr<Node>& arg,
const AxisSet& reduction_axes);
/// \brief Constructs an arithmetic reduction operation. /// \brief Constructs an arithmetic reduction operation.
/// ///
/// \param arg Node that produces the first input tensor. /// \param arg Node that produces the first input tensor.
...@@ -37,10 +53,14 @@ namespace ngraph ...@@ -37,10 +53,14 @@ namespace ngraph
const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& arg,
const AxisSet& reduction_axes); const AxisSet& reduction_axes);
public:
void validate_and_infer_types() override; void validate_and_infer_types() override;
/// \return The axis positions (0-based) to be eliminated through reduction. /// \return The axis positions (0-based) to be eliminated through reduction.
const AxisSet& get_reduction_axes() const { return m_reduction_axes; } const AxisSet& get_reduction_axes() const { return m_reduction_axes; }
/// \brief Change the reduction axes
void set_reduction_axes(const AxisSet& reduction_axes);
protected: protected:
AxisSet m_reduction_axes; AxisSet m_reduction_axes;
}; };
......
...@@ -19,6 +19,27 @@ ...@@ -19,6 +19,27 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::util::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic()
{
}
op::util::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob)
: Op({arg0, arg1})
, m_autob(autob)
{
}
op::util::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob)
: Op(check_single_output_args({arg0, arg1}))
, m_autob(autob)
{
}
op::util::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic( op::util::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(
const std::string& node_type, const std::string& node_type,
const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg0,
......
...@@ -47,7 +47,22 @@ namespace ngraph ...@@ -47,7 +47,22 @@ namespace ngraph
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape and element type as the input tensors (after auto broadcasting). | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape and element type as the input tensors (after auto broadcasting). |
class BinaryElementwiseArithmetic : public Op class BinaryElementwiseArithmetic : public Op
{ {
public: protected:
/// \brief Constructs a binary elementwise arithmetic operation.
BinaryElementwiseArithmetic();
BinaryElementwiseArithmetic(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
/// \brief Constructs a binary elementwise arithmetic operation.
///
/// \param arg0 Output that produces the first input tensor.
/// \param arg1 Output that produces the second input tensor.
BinaryElementwiseArithmetic(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
/// \brief Constructs a binary elementwise arithmetic operation. /// \brief Constructs a binary elementwise arithmetic operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
...@@ -58,9 +73,11 @@ namespace ngraph ...@@ -58,9 +73,11 @@ namespace ngraph
const std::shared_ptr<Node>& arg1, const std::shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
public:
void validate_and_infer_types() override; void validate_and_infer_types() override;
const AutoBroadcastSpec& get_autob() const { return m_autob; } const AutoBroadcastSpec& get_autob() const { return m_autob; }
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
private: private:
AutoBroadcastSpec m_autob; AutoBroadcastSpec m_autob;
}; };
......
...@@ -19,6 +19,26 @@ ...@@ -19,6 +19,26 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::util::BinaryElementwiseComparison::BinaryElementwiseComparison()
{
}
op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob)
: Op(check_single_output_args({arg0, arg1}))
, m_autob(autob)
{
}
op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob)
: Op({arg0, arg1})
, m_autob(autob)
{
}
op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const string& node_type, op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const string& node_type,
const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1, const shared_ptr<Node>& arg1,
......
...@@ -47,7 +47,28 @@ namespace ngraph ...@@ -47,7 +47,28 @@ namespace ngraph
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape as the input tensors, and the element type `bool`. | /// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape as the input tensors, and the element type `bool`. |
class BinaryElementwiseComparison : public Op class BinaryElementwiseComparison : public Op
{ {
public: protected:
/// \brief Constructs a binary elementwise comparison operation.
BinaryElementwiseComparison();
/// \brief Constructs a binary elementwise comparison operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob AutoBroadcast mode.
BinaryElementwiseComparison(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
/// \brief Constructs a binary elementwise comparison operation.
///
/// \param arg0 Output that produces the first input tensor.
/// \param arg1 Output that produces the second input tensor.
/// \param autob AutoBroadcast mode.
BinaryElementwiseComparison(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
/// \brief Constructs a binary elementwise comparison operation. /// \brief Constructs a binary elementwise comparison operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
...@@ -58,9 +79,11 @@ namespace ngraph ...@@ -58,9 +79,11 @@ namespace ngraph
const std::shared_ptr<Node>& arg1, const std::shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
public:
void validate_and_infer_types() override; void validate_and_infer_types() override;
const AutoBroadcastSpec& get_autob() const { return m_autob; } const AutoBroadcastSpec& get_autob() const { return m_autob; }
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
private: private:
AutoBroadcastSpec m_autob; AutoBroadcastSpec m_autob;
}; };
......
...@@ -19,6 +19,26 @@ ...@@ -19,6 +19,26 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::util::BinaryElementwiseLogical::BinaryElementwiseLogical()
{
}
op::util::BinaryElementwiseLogical::BinaryElementwiseLogical(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob)
: Op({arg0, arg1})
, m_autob(autob)
{
}
op::util::BinaryElementwiseLogical::BinaryElementwiseLogical(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob)
: Op(check_single_output_args({arg0, arg1}))
, m_autob(autob)
{
}
op::util::BinaryElementwiseLogical::BinaryElementwiseLogical(const string& node_type, op::util::BinaryElementwiseLogical::BinaryElementwiseLogical(const string& node_type,
const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1, const shared_ptr<Node>& arg1,
......
...@@ -47,7 +47,25 @@ namespace ngraph ...@@ -47,7 +47,25 @@ namespace ngraph
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape as the input tensors, and the element type `bool`. | /// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape as the input tensors, and the element type `bool`. |
class BinaryElementwiseLogical : public Op class BinaryElementwiseLogical : public Op
{ {
public: protected:
BinaryElementwiseLogical();
/// \brief Constructs a binary elementwise logical operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
BinaryElementwiseLogical(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
/// \brief Constructs a binary elementwise logical operation.
///
/// \param arg0 Output that produces the first input tensor.
/// \param arg1 Output that produces the second input tensor.
BinaryElementwiseLogical(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
/// \brief Constructs a binary elementwise logical operation. /// \brief Constructs a binary elementwise logical operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
...@@ -58,9 +76,11 @@ namespace ngraph ...@@ -58,9 +76,11 @@ namespace ngraph
const std::shared_ptr<Node>& arg1, const std::shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
public:
void validate_and_infer_types() override; void validate_and_infer_types() override;
const AutoBroadcastSpec& get_autob() const { return m_autob; } const AutoBroadcastSpec& get_autob() const { return m_autob; }
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
private: private:
AutoBroadcastSpec m_autob; AutoBroadcastSpec m_autob;
}; };
......
...@@ -20,6 +20,16 @@ ...@@ -20,6 +20,16 @@
using namespace ngraph; using namespace ngraph;
op::util::FusedOp::FusedOp()
: Op()
{
}
op::util::FusedOp::FusedOp(const NodeVector& args)
: Op(args)
{
}
op::util::FusedOp::FusedOp(const std::string& node_type, const NodeVector& args) op::util::FusedOp::FusedOp(const std::string& node_type, const NodeVector& args)
: Op(node_type, args) : Op(node_type, args)
{ {
......
...@@ -44,6 +44,13 @@ namespace ngraph ...@@ -44,6 +44,13 @@ namespace ngraph
const NodeVector& deltas) override; const NodeVector& deltas) override;
protected: protected:
FusedOp();
/// \brief Constructs a FusedOp
///
/// \param args Nodes that produce the input tensors for the fused op
FusedOp(const NodeVector& args);
/// \brief Constructs a FusedOp /// \brief Constructs a FusedOp
/// ///
/// \param args Nodes that produce the input tensors for the fused op /// \param args Nodes that produce the input tensors for the fused op
......
...@@ -21,15 +21,53 @@ ...@@ -21,15 +21,53 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::util::IndexReduction::IndexReduction()
{
}
op::util::IndexReduction::IndexReduction(const Output<Node>& arg,
size_t axis,
const element::Type& index_element_type)
: Op({arg})
{
set_reduction_axis(axis);
set_index_element_type(index_element_type);
}
op::util::IndexReduction::IndexReduction(const std::shared_ptr<Node>& arg,
size_t axis,
const element::Type& index_element_type)
: Op(check_single_output_args({arg}))
{
set_reduction_axis(axis);
set_index_element_type(index_element_type);
}
op::util::IndexReduction::IndexReduction(const std::string& node_type, op::util::IndexReduction::IndexReduction(const std::string& node_type,
const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& arg,
size_t axis, size_t axis,
const element::Type& index_element_type) const element::Type& index_element_type)
: Op(node_type, check_single_output_args({arg})) : Op(node_type, check_single_output_args({arg}))
, m_axis(axis)
, m_index_element_type(index_element_type)
{ {
constructor_validate_and_infer_types(); set_reduction_axis(axis);
set_index_element_type(index_element_type);
}
size_t op::util::IndexReduction::get_reduction_axis() const
{
return m_axis;
}
void op::util::IndexReduction::set_reduction_axis(size_t value)
{
m_axis = value;
}
element::Type op::util::IndexReduction::get_index_element_type() const
{
return m_index_element_type;
}
void op::util::IndexReduction::set_index_element_type(const element::Type& index_element_type)
{
m_index_element_type = index_element_type;
} }
void op::util::IndexReduction::validate_and_infer_types() void op::util::IndexReduction::validate_and_infer_types()
......
...@@ -16,6 +16,11 @@ ...@@ -16,6 +16,11 @@
#pragma once #pragma once
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
namespace ngraph namespace ngraph
...@@ -26,14 +31,28 @@ namespace ngraph ...@@ -26,14 +31,28 @@ namespace ngraph
{ {
class IndexReduction : public Op class IndexReduction : public Op
{ {
public: protected:
size_t get_reduction_axis() const { return m_axis; } IndexReduction();
element::Type get_index_element_type() const { return m_index_element_type; }
IndexReduction(const Output<Node>& arg,
size_t axis,
const element::Type& index_element_type);
IndexReduction(const std::shared_ptr<Node>& arg,
size_t axis,
const element::Type& index_element_type);
IndexReduction(const std::string& node_type, IndexReduction(const std::string& node_type,
const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& arg,
size_t axis, size_t axis,
const element::Type& index_element_type); const element::Type& index_element_type);
public:
size_t get_reduction_axis() const;
void set_reduction_axis(size_t value);
element::Type get_index_element_type() const;
void set_index_element_type(const element::Type& index_element_type);
protected: protected:
size_t m_axis; size_t m_axis;
element::Type m_index_element_type; element::Type m_index_element_type;
......
...@@ -19,12 +19,39 @@ ...@@ -19,12 +19,39 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::util::LogicalReduction::LogicalReduction()
{
}
op::util::LogicalReduction::LogicalReduction(const Output<Node>& arg, const AxisSet& reduction_axes)
: Op({arg})
{
set_reduction_axes(reduction_axes);
}
op::util::LogicalReduction::LogicalReduction(const std::shared_ptr<Node>& arg,
const AxisSet& reduction_axes)
: Op(check_single_output_args({arg}))
{
set_reduction_axes(reduction_axes);
}
op::util::LogicalReduction::LogicalReduction(const std::string& node_type, op::util::LogicalReduction::LogicalReduction(const std::string& node_type,
const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& arg,
const AxisSet& reduction_axes) const AxisSet& reduction_axes)
: Op(node_type, check_single_output_args({arg})) : Op(node_type, check_single_output_args({arg}))
, m_reduction_axes(reduction_axes)
{ {
set_reduction_axes(reduction_axes);
}
const AxisSet& op::util::LogicalReduction::get_reduction_axes() const
{
return m_reduction_axes;
}
void op::util::LogicalReduction::set_reduction_axes(const AxisSet& reduction_axes)
{
m_reduction_axes = reduction_axes;
} }
void op::util::LogicalReduction::validate_and_infer_types() void op::util::LogicalReduction::validate_and_infer_types()
......
...@@ -28,7 +28,19 @@ namespace ngraph ...@@ -28,7 +28,19 @@ namespace ngraph
/// are eliminated (reduced out) by repeated application of a particular binary logical operation. /// are eliminated (reduced out) by repeated application of a particular binary logical operation.
class LogicalReduction : public Op class LogicalReduction : public Op
{ {
public: protected:
/// \brief Constructs a logical reduction operation.
LogicalReduction();
/// \brief Constructs a logical reduction operation.
///
/// \param arg Output that produces the first input tensor.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
LogicalReduction(const Output<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs a logical reduction operation.
///
/// \param arg Node that produces the first input tensor.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
LogicalReduction(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs a logical reduction operation. /// \brief Constructs a logical reduction operation.
/// ///
/// \param arg Node that produces the first input tensor. /// \param arg Node that produces the first input tensor.
...@@ -37,10 +49,13 @@ namespace ngraph ...@@ -37,10 +49,13 @@ namespace ngraph
const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& arg,
const AxisSet& reduction_axes); const AxisSet& reduction_axes);
public:
void validate_and_infer_types() override; void validate_and_infer_types() override;
/// \return The axis positions (0-based) to be eliminated through reduction. /// \return The axis positions (0-based) to be eliminated through reduction.
const AxisSet& get_reduction_axes() const { return m_reduction_axes; } const AxisSet& get_reduction_axes() const;
void set_reduction_axes(const AxisSet& reduction_axes);
protected: protected:
AxisSet m_reduction_axes; AxisSet m_reduction_axes;
}; };
......
...@@ -18,6 +18,21 @@ ...@@ -18,6 +18,21 @@
using namespace ngraph; using namespace ngraph;
op::util::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic()
: Op()
{
}
op::util::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic(const Output<Node>& arg)
: Op({arg})
{
}
op::util::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic(const std::shared_ptr<Node>& arg)
: Op(check_single_output_args({arg}))
{
}
op::util::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic(const std::string& node_type, op::util::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic(const std::string& node_type,
const std::shared_ptr<Node>& arg) const std::shared_ptr<Node>& arg)
: Op(node_type, check_single_output_args({arg})) : Op(node_type, check_single_output_args({arg}))
......
...@@ -44,12 +44,24 @@ namespace ngraph ...@@ -44,12 +44,24 @@ namespace ngraph
class UnaryElementwiseArithmetic : public Op class UnaryElementwiseArithmetic : public Op
{ {
protected: protected:
/// \brief Constructs a unary elementwise arithmetic operation.
UnaryElementwiseArithmetic();
/// \brief Constructs a unary elementwise arithmetic operation.
///
/// \param arg Node that produces the input tensor.
UnaryElementwiseArithmetic(const std::shared_ptr<Node>& arg);
/// \brief Constructs a unary elementwise arithmetic operation.
///
/// \param arg Output that produces the input tensor.
UnaryElementwiseArithmetic(const Output<Node>& arg);
/// \brief Constructs a unary elementwise arithmetic operation. /// \brief Constructs a unary elementwise arithmetic operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
UnaryElementwiseArithmetic(const std::string& node_type, UnaryElementwiseArithmetic(const std::string& node_type,
const std::shared_ptr<Node>& arg); const std::shared_ptr<Node>& arg);
public:
void validate_and_infer_types() override; void validate_and_infer_types() override;
}; };
} }
......
...@@ -117,7 +117,7 @@ static void random_init(shared_ptr<runtime::Tensor> tv) ...@@ -117,7 +117,7 @@ static void random_init(shared_ptr<runtime::Tensor> tv)
case element::Type_t::i8: init_int_tv<int8_t>(tv, -1, 1); break; case element::Type_t::i8: init_int_tv<int8_t>(tv, -1, 1); break;
case element::Type_t::i16: init_int_tv<int16_t>(tv, -1, 1); break; case element::Type_t::i16: init_int_tv<int16_t>(tv, -1, 1); break;
case element::Type_t::i32: init_int_tv<int32_t>(tv, 0, 1); break; case element::Type_t::i32: init_int_tv<int32_t>(tv, 0, 1); break;
case element::Type_t::i64: init_int_tv<int64_t>(tv, -1, 1); break; case element::Type_t::i64: init_int_tv<int64_t>(tv, 0, 1); break;
case element::Type_t::u8: init_int_tv<uint8_t>(tv, 0, 1); break; case element::Type_t::u8: init_int_tv<uint8_t>(tv, 0, 1); break;
case element::Type_t::u16: init_int_tv<uint16_t>(tv, 0, 1); break; case element::Type_t::u16: init_int_tv<uint16_t>(tv, 0, 1); break;
case element::Type_t::u32: init_int_tv<uint32_t>(tv, 0, 1); break; case element::Type_t::u32: init_int_tv<uint32_t>(tv, 0, 1); break;
......
...@@ -97,11 +97,6 @@ TEST(build_graph, tensor) ...@@ -97,11 +97,6 @@ TEST(build_graph, tensor)
ASSERT_EQ(int32_0->get_shape(), ishape); ASSERT_EQ(int32_0->get_shape(), ishape);
} }
// Check argument inverses
TEST(build_graph, arg_inverse)
{
}
// Check functions with undeclared parameters // Check functions with undeclared parameters
TEST(build_graph, function_undeclared_parameters) TEST(build_graph, function_undeclared_parameters)
{ {
...@@ -131,3 +126,27 @@ TEST(build_graph, function_undeclared_parameters) ...@@ -131,3 +126,27 @@ TEST(build_graph, function_undeclared_parameters)
FAIL() << "Function construction failed for unexpected reason"; FAIL() << "Function construction failed for unexpected reason";
} }
} }
// Check no-arg construction
TEST(build_graph, no_arg_construction)
{
// The ops
// Parameters aren't converted yet
auto arg0 = make_shared<op::Parameter>(element::f32, Shape{7});
auto arg1 = make_shared<op::Parameter>(element::f32, Shape{7});
auto arg2 = make_shared<op::Parameter>(element::f32, Shape{7});
auto arg3 = make_shared<op::Parameter>(element::f32, Shape{7});
auto add0 = make_shared<op::Add>();
auto abs0 = make_shared<op::Abs>();
auto acos0 = make_shared<op::Acos>();
auto add1 = make_shared<op::Add>();
add0->set_argument(1, arg0);
add0->set_argument(0, arg1);
abs0->set_argument(0, add0);
acos0->set_argument(0, add0);
add1->set_argument(0, acos0);
add1->set_argument(1, abs0);
NodeVector ops{arg0, arg1, add0, abs0, acos0, add1};
validate_nodes_and_infer_types(ops);
ASSERT_EQ(add1->get_output_shape(0), Shape{7});
}
...@@ -538,6 +538,32 @@ TEST(util, enum_mask_operators) ...@@ -538,6 +538,32 @@ TEST(util, enum_mask_operators)
EXPECT_EQ(true, n[Type::b]); EXPECT_EQ(true, n[Type::b]);
} }
TEST(graph, huge)
{
Function* f;
std::vector<std::weak_ptr<Node>> weak_nodes;
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 3});
std::shared_ptr<Node> n = param;
for (size_t i = 0; i < 1000000; i++)
{
n = make_shared<op::Negative>(n);
}
f = new Function(NodeVector{n}, ParameterVector{param});
for (auto node : f->get_ops())
{
weak_nodes.push_back(node);
}
}
delete f;
for (auto weak_node : weak_nodes)
{
EXPECT_TRUE(weak_node.expired());
}
}
TEST(util, apply_permutation) TEST(util, apply_permutation)
{ {
ASSERT_EQ(apply_permutation(Shape{0, 1, 2, 3}, AxisVector{2, 1, 0, 3}), (Shape{2, 1, 0, 3})); ASSERT_EQ(apply_permutation(Shape{0, 1, 2, 3}, AxisVector{2, 1, 0, 3}), (Shape{2, 1, 0, 3}));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment