Unverified commit 840bf1a3 authored by Scott Cyphers, committed by GitHub

Merge branch 'master' into mlir

parents c1b08b42 a073c39e
......@@ -33,9 +33,26 @@ descriptor::Input::Input(Node* node, size_t index, Output& output)
output.add_input(this);
}
descriptor::Input::Input(Node* node, size_t index)
: m_node(node)
, m_index(index)
, m_output(nullptr)
, m_is_relevant_to_shape(false)
, m_is_relevant_to_value(true)
{
}
descriptor::Input::~Input()
{
remove_output();
}
void descriptor::Input::replace_output(Output& new_output)
{
if (m_output != nullptr)
{
m_output->remove_input(this);
}
new_output.add_input(this);
m_output = &new_output;
m_src_node = std::shared_ptr<Node>(new_output.get_node());
......@@ -56,6 +73,16 @@ void descriptor::Input::replace_output(std::shared_ptr<Node> node, size_t i)
replace_output(node->m_outputs.at(i));
}
void descriptor::Input::remove_output()
{
if (m_output != nullptr)
{
m_output->remove_input(this);
m_src_node = nullptr;
m_output = nullptr;
}
}
std::shared_ptr<Node> descriptor::Input::get_node() const
{
return m_node->shared_from_this();
......
......@@ -38,6 +38,11 @@ namespace ngraph
/// \param index The position of this tensor in all input tensors
/// \param output The output that supplies a value for this input
Input(Node* node, size_t index, Output& output);
/// \brief Create an Input that is not connected to an output
/// \param node The node that owns this input
/// \param index The position of this tensor in all input tensors
Input(Node* node, size_t index);
~Input();
/// \return the node that this is an input of
std::shared_ptr<Node> get_node() const;
......@@ -50,14 +55,20 @@ namespace ngraph
const Output& get_output() const { return *m_output; }
/// \return the connected output
Output& get_output() { return *m_output; }
/// \return true if an output is connected to the input.
bool has_output() const { return m_output != nullptr; }
/// \return the tensor of the connected output
const Tensor& get_tensor() const;
/// \return the tensor of the connected output
Tensor& get_tensor();
/// \brief Replace the current output that supplies a value for this input with output i of node
void replace_output(std::shared_ptr<Node> node, size_t i);
/// \brief Replace the current output that supplies a value for this input with output
void replace_output(Output& output);
/// \brief Remove the output from this input. The node will not be valid until another output is supplied.
void remove_output();
/// \return true if the value of this input is relevant to the output shapes of the
/// corresponding node. (Usually this is false.)
......
......@@ -31,6 +31,18 @@ descriptor::Tensor::Tensor(const element::Type& element_type,
{
}
descriptor::Tensor::Tensor(const element::Type& element_type,
const PartialShape& pshape,
Node* node,
size_t node_output_number)
: m_element_type(element_type)
, m_shape(pshape.is_static() ? pshape.to_shape() : Shape{})
, m_partial_shape(pshape)
, m_node(node)
, m_node_output_number(node_output_number)
{
}
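// The node pointer and output index captured here let get_name() (below)
// synthesize a default tensor name on demand instead of eagerly at
// construction time.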
void descriptor::Tensor::set_tensor_type(const element::Type& element_type,
const PartialShape& pshape)
{
......@@ -95,6 +107,16 @@ void descriptor::Tensor::set_tensor_layout(
m_tensor_layout = tensor_layout;
}
const std::string& descriptor::Tensor::get_name() const
{
if (m_name.empty() && m_node != nullptr)
{
const_cast<Tensor*>(this)->m_name =
m_node->get_name() + "_" + to_string(m_node_output_number);
}
return m_name;
}
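// A minimal, self-contained sketch (illustrative, not nGraph code) of the
// lazy-naming idiom above: the cache is filled on first use, and a `mutable`
// member expresses the intent more directly than const_cast. Neither variant
// is thread-safe if two threads race on the first call.
#include <cstddef>
#include <string>

class LazilyNamed
{
public:
    explicit LazilyNamed(std::size_t id)
        : m_id(id)
    {
    }

    const std::string& get_name() const
    {
        if (m_name.empty())
        {
            m_name = "node_" + std::to_string(m_id); // fill the cache once
        }
        return m_name;
    }

private:
    std::size_t m_id;
    mutable std::string m_name; // `mutable` replaces the const_cast
};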
ostream& operator<<(ostream& out, const descriptor::Tensor& tensor)
{
out << "Tensor(" << tensor.get_name() << ")";
......
......@@ -45,8 +45,12 @@ namespace ngraph
Tensor(const element::Type& element_type,
const PartialShape& pshape,
const std::string& name);
Tensor(const element::Type& element_type,
const PartialShape& pshape,
Node* node,
size_t node_output_number);
const std::string& get_name() const;
void set_tensor_type(const element::Type& element_type, const PartialShape& pshape);
const element::Type& get_element_type() const { return m_element_type; }
......@@ -73,6 +77,8 @@ namespace ngraph
// should refactor so that get_shape returns by value.
Shape m_shape;
PartialShape m_partial_shape;
Node* m_node{nullptr};
size_t m_node_output_number{0};
std::string m_name;
std::shared_ptr<layout::TensorLayout> m_tensor_layout;
......
......@@ -33,21 +33,131 @@ using namespace ngraph;
atomic<size_t> Node::m_next_instance_id(0);
Node::Node(size_t output_size)
: Node()
{
set_output_size(output_size);
}
Node::Node(const std::string& node_type, const NodeVector& arguments, size_t output_size)
: m_node_type(node_type)
{
set_arguments(arguments);
set_output_size(output_size);
}
Node::Node(const NodeVector& arguments, size_t output_size)
: Node()
{
set_arguments(arguments);
set_output_size(output_size);
}
Node::Node(const OutputVector& arguments, size_t output_size)
: Node()
{
set_arguments(arguments);
set_output_size(output_size);
}
Node::~Node()
{
for (descriptor::Input& input : m_inputs)
{
if (input.has_output())
{
// The shared_ptr returned by get_node() adds 1 to the actual count, so a count of 2 means this input holds the only other reference to the node.
if (input.get_output().get_node().use_count() == 2)
{
// Don't want to trigger a deep recursive delete
NodeVector nodes{input.get_output().get_node()};
input.remove_output();
safe_delete(nodes, true);
return;
}
input.remove_output();
}
}
}
void Node::safe_delete(NodeVector& nodes, bool recurse)
{
for (auto& input : m_inputs)
{
if (input.has_output())
{
// The shared_ptr returned by get_node() adds 1 to the actual count, so a count of 2 means this input holds the only other reference to the node.
auto node = input.get_output().get_node();
if (node.use_count() == 2)
{
// Move the node from the input to nodes so we don't trigger a deep recursive delete
nodes.push_back(node);
}
input.remove_output();
}
}
if (recurse)
{
while (nodes.size() > 0)
{
auto node = nodes.back();
nodes.pop_back();
node->safe_delete(nodes, false);
}
}
}
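// A self-contained sketch (illustrative, not nGraph code) of why ~Node and
// safe_delete drain sole-owned neighbors into an explicit worklist: letting
// shared_ptr destructors recurse through a million-node chain overflows the
// stack, while an iterative drain stays one frame deep. nGraph's use_count
// check plays the role of the "sole owner" assumption made below.
#include <memory>
#include <vector>

struct Chain
{
    std::shared_ptr<Chain> next;

    ~Chain()
    {
        // Detach the whole tail into a worklist instead of letting the
        // implicit destruction of `next` recurse node by node.
        std::vector<std::shared_ptr<Chain>> worklist;
        std::shared_ptr<Chain> n = std::move(next);
        while (n != nullptr)
        {
            std::shared_ptr<Chain> tail = std::move(n->next); // n now has no tail
            worklist.push_back(std::move(n));
            n = std::move(tail);
        }
        // Each worklist entry is destroyed here with a null `next`,
        // so no destructor recursion occurs.
    }
};

int main()
{
    auto head = std::make_shared<Chain>();
    for (int i = 0; i < 1000000; ++i)
    {
        auto c = std::make_shared<Chain>();
        c->next = std::move(head);
        head = std::move(c);
    }
    // head's destructor runs iteratively; a naively recursive ~Chain would
    // overflow the stack here.
}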
void Node::set_arguments(const NodeVector& arguments)
{
OutputVector outputs;
for (auto arg : arguments)
{
for (auto& output : arg->outputs())
{
outputs.push_back(output);
}
}
set_arguments(outputs);
}
void Node::set_arguments(const OutputVector& arguments)
{
// Add this node as a user of each argument.
size_t i = 0;
for (auto& output : arguments)
{
auto output_node = output.get_node();
auto& output_descriptor = output_node->get_outputs().at(output.get_index());
m_inputs.emplace_back(this, i++, output_descriptor);
}
}
descriptor::Input& Node::get_input_descriptor(size_t position)
{
while (m_inputs.size() <= position)
{
m_inputs.emplace_back(this, m_inputs.size());
}
return m_inputs.at(position);
}
descriptor::Output& Node::get_output_descriptor(size_t position)
{
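// Grow the outputs on demand: any missing descriptors up to `position` are
// created with dynamic element type and shape, and their tensors pick up a
// name lazily from the owning node (see descriptor::Tensor::get_name above).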
while (m_outputs.size() <= position)
{
size_t i = m_outputs.size();
auto tensor_descriptor =
make_shared<descriptor::Tensor>(element::dynamic, PartialShape::dynamic(), this, i);
m_outputs.emplace_back(this, i, tensor_descriptor);
}
return m_outputs.at(position);
}
void Node::set_argument(size_t position, const Output<Node>& argument)
{
auto output_node = argument.get_node();
auto& output_descriptor = output_node->get_output_descriptor(argument.get_index());
get_input_descriptor(position).replace_output(output_descriptor);
}
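// Together with the new default constructors, set_argument/set_arguments
// support building a graph in two phases: default-construct the ops, wire
// their inputs, then run validate_nodes_and_infer_types over the whole set
// (exercised by the no_arg_construction test at the end of this diff).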
// While we are still doing validation and type inference in the constructor, this is true
......@@ -75,9 +185,8 @@ void Node::set_output_size(size_t n)
NGRAPH_CHECK(n >= m_outputs.size(), "shrinking ", m_outputs.size(), " to ", n);
for (size_t i = m_outputs.size(); i < n; ++i)
{
// create the descriptors
get_output_descriptor(i);
}
}
......@@ -97,7 +206,7 @@ void Node::set_input_is_relevant_to_value(size_t i, bool relevant)
void Node::set_output_type(size_t i, const element::Type& element_type, const PartialShape& pshape)
{
get_output_descriptor(i).get_tensor_ptr()->set_tensor_type(element_type, pshape);
}
std::deque<descriptor::Output>& Node::get_outputs()
......@@ -134,13 +243,17 @@ const std::string& Node::get_friendly_name() const
{
if (m_friendly_name.empty())
{
return get_name();
}
return m_friendly_name;
}
const std::string& Node::get_name() const
{
if (m_unique_name.empty())
{
const_cast<Node*>(this)->m_unique_name = description() + "_" + to_string(m_instance_id);
}
return m_unique_name;
}
......@@ -204,14 +317,6 @@ std::shared_ptr<Node> Node::get_argument(size_t index) const
return m_inputs.at(index).get_output().get_node();
}
NodeVector Node::get_arguments() const
{
NodeVector result;
......@@ -442,18 +547,6 @@ std::string ngraph::node_validation_failure_loc_string(const Node* node)
return ss.str();
}
const std::shared_ptr<Node>& ngraph::check_single_output_arg(const std::shared_ptr<Node>& node,
size_t i)
{
......
......@@ -49,6 +49,7 @@ namespace ngraph
class Node;
using NodeVector = std::vector<std::shared_ptr<Node>>;
using OutputVector = std::vector<Output<Node>>;
class Function;
......@@ -108,11 +109,42 @@ namespace ngraph
void validate_and_infer_elementwise_logical(
const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec());
/// \brief Construct an uninitialized Node
Node() {}
/// \brief Construct an uninitialized Node
/// \param output_size Number of outputs for this node
Node(size_t output_size);
/// \brief Constructor for Node subclasses that have metaclasses.
/// \param arguments Output i will connect to input i
/// \param output_size Number of outputs for this node
Node(const OutputVector& arguments, size_t output_size = 1);
/// \brief Construct a node with arguments. Will be deprecated.
Node(const std::string& node_type, const NodeVector& arguments, size_t output_size = 1);
/// \brief Constructor for Node subclasses that have metaclasses. Will be deprecated.
/// \param arguments The 0th output of node i will connect to input i
/// \param output_size Number of outputs for this node
Node(const NodeVector& arguments, size_t output_size = 1);
virtual void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) {}
/// \brief Moves nodes that are about to be deleted from this node's inputs into `nodes`, to avoid stack overflow on deep networks.
void safe_delete(NodeVector& nodes, bool recurse);
public:
virtual ~Node();
/// Sets/replaces the arguments with new arguments.
void set_arguments(const NodeVector& arguments);
/// Sets/replaces the arguments with new arguments.
void set_arguments(const OutputVector& arguments);
/// Sets/replaces the arguments with new arguments.
void set_argument(size_t position, const Output<Node>& argument);
/// Sets the number of outputs
void set_output_size(size_t output_size);
void revalidate_and_infer_types() { validate_and_infer_types(); }
// Called after transition
void delayed_validate_and_infer_types();
......@@ -120,8 +152,7 @@ namespace ngraph
/// \brief Get the string name for the type of the node, such as `Add` or `Multiply`.
/// The class name must not contain spaces, as it is used for codegen.
/// \returns A const reference to the node's type name
virtual const std::string& description() const;
/// \brief Get the unique name of the node.
/// \returns A const reference to the node's unique name.
const std::string& get_name() const;
......@@ -285,10 +316,11 @@ namespace ngraph
std::unordered_set<descriptor::Tensor*> liveness_new_list;
std::unordered_set<descriptor::Tensor*> liveness_free_list;
// Will be deprecated
virtual NodeVector get_arguments() const;
// Will be deprecated
std::shared_ptr<Node> get_argument(size_t index) const;
// Will be replaced with an OutputVector version
virtual std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const = 0;
virtual std::vector<std::shared_ptr<Function>> get_functions() const;
......@@ -353,16 +385,15 @@ namespace ngraph
/// \throw std::out_of_range if the node does not have at least `output_index+1` outputs.
Output<const Node> output(size_t output_index) const;
protected:
descriptor::Input& get_input_descriptor(size_t position);
descriptor::Output& get_output_descriptor(size_t position);
private:
std::set<std::shared_ptr<Node>> m_control_dependencies;
const std::string m_node_type;
size_t m_instance_id{m_next_instance_id.fetch_add(1)};
std::string m_friendly_name;
std::string m_unique_name;
static std::atomic<size_t> m_next_instance_id;
std::unordered_set<std::string> m_provenance_tags;
std::deque<descriptor::Input> m_inputs;
......@@ -409,6 +440,11 @@ namespace ngraph
{
return m_node->m_inputs.at(m_index).get_output().get_tensor();
}
/// \return A shared pointer to the tensor descriptor for this input.
std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const
{
return m_node->m_inputs.at(m_index).get_output().get_tensor_ptr();
}
/// \return true if this input is relevant to its node's output shapes; else false.
bool get_is_relevant_to_shapes() const
{
......@@ -490,6 +526,11 @@ namespace ngraph
{
return m_node->m_outputs.at(m_index).get_tensor();
}
/// \return A shared pointer to the tensor descriptor for this output.
std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const
{
return m_node->m_outputs.at(m_index).get_tensor_ptr();
}
/// \return The element type of the output referred to by this output handle.
const element::Type& get_element_type() const
{
......@@ -684,9 +725,23 @@ namespace ngraph
const Node& m_node;
bool m_is_short;
};
}
#define NODE_VALIDATION_CHECK(node, cond, ...) \
NGRAPH_CHECK_HELPER(::ngraph::NodeValidationFailure, (node), (cond), ##__VA_ARGS__)
namespace ngraph
{
template <typename T>
void check_new_args_count(const Node* node, T new_args)
{
NODE_VALIDATION_CHECK(node,
new_args.size() == node->get_arguments().size(),
"copy_with_new_args() expected ",
node->get_arguments().size(),
" argument",
(node->get_arguments().size() == 1 ? "" : "s"),
" but got ",
new_args.size());
}
} // namespace ngraph
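// The check_new_args_count template above expands NODE_VALIDATION_CHECK, so it
// must appear after the macro definition; making it a template over the
// container type lets it keep working when copy_with_new_args moves from
// NodeVector to an OutputVector-based signature (see "Will be replaced with an
// OutputVector version" above). A sketch of a typical caller, mirroring the
// ArgMax implementation in this diff:
//
//     shared_ptr<Node> op::Abs::copy_with_new_args(const NodeVector& new_args) const
//     {
//         check_new_args_count(this, new_args); // T deduced as NodeVector
//         return make_shared<Abs>(new_args.at(0));
//     }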
......@@ -21,8 +21,14 @@
using namespace std;
using namespace ngraph;
const string op::Abs::type_name{"Abs"};
op::Abs::Abs()
{
}
op::Abs::Abs(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
constructor_validate_and_infer_types();
}
......
......@@ -29,18 +29,22 @@ namespace ngraph
class Abs : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an absolute value operation.
Abs();
/// \brief Constructs an absolute value operation.
///
/// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]`
///
/// Output `[d1, ...]`
///
Abs(const Output<Node>& arg);
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
......@@ -32,8 +32,14 @@
using namespace std;
using namespace ngraph;
const string op::Acos::type_name{"Acos"};
op::Acos::Acos()
{
}
op::Acos::Acos(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
constructor_validate_and_infer_types();
}
......
......@@ -29,17 +29,21 @@ namespace ngraph
class Acos : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an arccos operation.
Acos();
/// \brief Constructs an arccos operation.
///
/// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]`
///
/// Output `[d1, ...]`
///
Acos(const Output<Node>& arg);
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
......@@ -19,10 +19,14 @@
using namespace std;
using namespace ngraph;
const string op::Add::type_name{"Add"};
op::Add::Add()
{
}
op::Add::Add(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic(arg0, arg1, autob)
{
constructor_validate_and_infer_types();
}
......
......@@ -29,22 +29,27 @@ namespace ngraph
class Add : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an uninitialized addition operation
Add();
/// \brief Constructs an addition operation.
///
/// \param arg0 Output that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Output that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param autob Auto broadcast specification
///
/// Output `[d0, ...]`
///
Add(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
......@@ -19,8 +19,14 @@
using namespace std;
using namespace ngraph;
const string op::All::type_name{"All"};
op::All::All()
{
}
op::All::All(const Output<Node>& arg, const AxisSet& reduction_axes)
: LogicalReduction(arg, reduction_axes)
{
constructor_validate_and_infer_types();
}
......
......@@ -29,14 +29,18 @@ namespace ngraph
class All : public util::LogicalReduction
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an "all" reduction operation.
All();
/// \brief Constructs an "all" reduction operation.
///
/// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
All(const Output<Node>& arg, const AxisSet& reduction_axes);
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
/// \return The default value for All.
virtual std::shared_ptr<Node> get_default_value() const override
......
......@@ -19,8 +19,14 @@
using namespace std;
using namespace ngraph;
const string op::AllReduce::type_name{"AllReduce"};
op::AllReduce::AllReduce()
{
}
op::AllReduce::AllReduce(const shared_ptr<Node>& arg)
: Op("AllReduce", check_single_output_args({arg}))
: Op(check_single_output_args({arg}))
{
constructor_validate_and_infer_types();
}
......
......@@ -26,12 +26,15 @@ namespace ngraph
class AllReduce : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
AllReduce();
AllReduce(const std::shared_ptr<Node>& arg);
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
};
}
}
......@@ -19,10 +19,14 @@
using namespace std;
using namespace ngraph;
const string op::And::type_name{"And"};
op::And::And()
{
}
op::And::And(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
: BinaryElementwiseLogical(arg0, arg1, autob)
{
constructor_validate_and_infer_types();
}
......
......@@ -29,22 +29,27 @@ namespace ngraph
class And : public util::BinaryElementwiseLogical
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a logical-and operation.
And();
/// \brief Constructs a logical-and operation.
///
/// \param arg0 Output that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Output that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param autob Auto broadcast specification
///
/// Output `[d0, ...]`
///
And(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual bool is_commutative() override { return true; }
......
......@@ -19,8 +19,14 @@
using namespace std;
using namespace ngraph;
const string op::Any::type_name{"Any"};
op::Any::Any()
{
}
op::Any::Any(const Output<Node>& arg, const AxisSet& reduction_axes)
: LogicalReduction(arg, reduction_axes)
{
constructor_validate_and_infer_types();
}
......
......@@ -29,11 +29,16 @@ namespace ngraph
class Any : public util::LogicalReduction
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an "any" reduction operation.
Any();
/// \brief Constructs an "any" reduction operation.
///
/// \param arg The tensor to be reduced.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
Any(const Output<Node>& arg, const AxisSet& reduction_axes);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -19,8 +19,21 @@
using namespace std;
using namespace ngraph;
const string op::ArgMax::type_name{"ArgMax"};
op::ArgMax::ArgMax()
{
}
op::ArgMax::ArgMax(const Output<Node>& arg, size_t axis, const element::Type& index_element_type)
: op::util::IndexReduction(arg, axis, index_element_type)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::ArgMax::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<ArgMax>(new_args.at(0), m_axis, this->get_element_type());
}
......@@ -28,17 +28,17 @@ namespace ngraph
class ArgMax : public op::util::IndexReduction
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an ArgMax operation.
ArgMax();
/// \brief Constructs an ArgMax operation.
///
/// \param arg The input tensor
/// \param axis The axis along which to compute an index for maximum
/// \param index_element_type The element type of the produced indices. Currently, only int64 or int32 are supported
ArgMax(const Output<Node>& arg, size_t axis, const element::Type& index_element_type);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -19,6 +19,18 @@
using namespace std;
using namespace ngraph;
const string op::ArgMin::type_name{"ArgMin"};
op::ArgMin::ArgMin()
{
}
op::ArgMin::ArgMin(const Output<Node>& arg, size_t axis, const element::Type& index_element_type)
: op::util::IndexReduction(arg, axis, index_element_type)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::ArgMin::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......
......@@ -28,17 +28,18 @@ namespace ngraph
class ArgMin : public op::util::IndexReduction
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an ArgMin operation.
ArgMin();
/// \brief Constructs an ArgMin operation.
///
/// \param arg The input tensor
/// \param axis The axis along which to compute an index for minimum
/// \param index_element_type The element type of the produced indices. Currently, only int64 or int32 are supported
ArgMin(const Output<Node>& arg, size_t axis, const element::Type& index_element_type);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -31,8 +31,14 @@
using namespace std;
using namespace ngraph;
const string op::Asin::type_name{"Asin"};
op::Asin::Asin()
{
}
op::Asin::Asin(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
constructor_validate_and_infer_types();
}
......
......@@ -29,14 +29,19 @@ namespace ngraph
class Asin : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an arcsin operation.
Asin();
/// \brief Constructs an arcsin operation.
///
/// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]`
///
/// Output `[d1, ...]`
///
Asin(const Output<Node>& arg);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -30,8 +30,14 @@
using namespace std;
using namespace ngraph;
const string op::Atan::type_name{"Atan"};
op::Atan::Atan()
{
}
op::Atan::Atan(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
constructor_validate_and_infer_types();
}
......
......@@ -29,14 +29,20 @@ namespace ngraph
class Atan : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an arctan operation.
Atan();
/// \brief Constructs an arctan operation.
///
/// \param arg Output that produces the input tensor.<br>
/// `[d1, ...]`
///
/// Output `[d1, ...]`
///
Atan(const Output<Node>& arg);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -21,14 +21,20 @@
using namespace std;
using namespace ngraph;
const string op::AvgPool::type_name{"AvgPool"};
op::AvgPool::AvgPool()
{
}
op::AvgPool::AvgPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above,
bool include_padding_in_avg_computation,
const PadType& pad_type)
: Op("AvgPool", check_single_output_args({arg}))
: Op({arg})
, m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below)
......@@ -91,18 +97,78 @@ void op::AvgPool::validate_and_infer_types()
m_include_padding_in_avg_computation));
}
op::AvgPool::AvgPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides)
: AvgPool(arg, window_shape, window_movement_strides, Shape(), Shape(), false)
{
}
op::AvgPool::AvgPool(const Output<Node>& arg, const Shape& window_shape)
: AvgPool(arg, window_shape, Strides(), Shape(), Shape(), false)
{
}
const Shape& op::AvgPool::get_window_shape() const
{
return m_window_shape;
}
void op::AvgPool::set_window_shape(const Shape& window_shape)
{
m_window_shape = window_shape;
}
const Strides& op::AvgPool::get_window_movement_strides() const
{
return m_window_movement_strides;
}
void op::AvgPool::set_window_movement_strides(const Strides& window_movement_strides)
{
m_window_movement_strides = window_movement_strides;
}
const Shape& op::AvgPool::get_padding_below() const
{
return m_padding_below;
}
void op::AvgPool::set_padding_below(const Shape& padding_below)
{
m_padding_below = padding_below;
}
const Shape& op::AvgPool::get_padding_above() const
{
return m_padding_above;
}
void op::AvgPool::set_padding_above(const Shape& padding_above)
{
m_padding_above = padding_above;
}
bool op::AvgPool::get_include_padding_in_avg_computation() const
{
return m_include_padding_in_avg_computation;
}
void op::AvgPool::set_include_padding_in_avg_computation(bool include_padding_in_avg_computation)
{
m_include_padding_in_avg_computation = include_padding_in_avg_computation;
}
const op::PadType& op::AvgPool::get_pad_type() const
{
return m_pad_type;
}
void op::AvgPool::set_pad_type(const op::PadType& pad_type)
{
m_pad_type = pad_type;
}
shared_ptr<Node> op::AvgPool::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......@@ -114,6 +180,12 @@ shared_ptr<Node> op::AvgPool::copy_with_new_args(const NodeVector& new_args) con
m_include_padding_in_avg_computation);
}
const string op::AvgPoolBackprop::type_name("AvgPoolBackprop");
op::AvgPoolBackprop::AvgPoolBackprop()
{
}
op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
const shared_ptr<Node>& delta,
const Shape& window_shape,
......@@ -121,7 +193,7 @@ op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
const Shape& padding_below,
const Shape& padding_above,
bool include_padding_in_avg_computation)
: Op("AvgPoolBackprop", check_single_output_args({delta}))
: Op(check_single_output_args({delta}))
, m_forward_arg_shape(forward_arg_shape)
, m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides)
......@@ -166,6 +238,67 @@ void op::AvgPoolBackprop::validate_and_infer_types()
set_output_type(0, get_input_element_type(0), m_forward_arg_shape);
}
const Shape& op::AvgPoolBackprop::get_forward_arg_shape() const
{
return m_forward_arg_shape;
}
void op::AvgPoolBackprop::set_forward_arg_shape(const Shape& forward_arg_shape)
{
m_forward_arg_shape = forward_arg_shape;
}
const Shape& op::AvgPoolBackprop::get_window_shape() const
{
return m_window_shape;
}
void op::AvgPoolBackprop::set_window_shape(const Shape& window_shape)
{
m_window_shape = window_shape;
}
const Strides& op::AvgPoolBackprop::get_window_movement_strides() const
{
return m_window_movement_strides;
}
void op::AvgPoolBackprop::set_window_movement_strides(const Strides& window_movement_strides)
{
m_window_movement_strides = window_movement_strides;
}
const Shape& op::AvgPoolBackprop::get_padding_below() const
{
return m_padding_below;
}
void op::AvgPoolBackprop::set_padding_below(const Shape& padding_below)
{
m_padding_below = padding_below;
}
const Shape& op::AvgPoolBackprop::get_padding_above() const
{
return m_padding_above;
}
void op::AvgPoolBackprop::set_padding_above(const Shape& padding_above)
{
m_padding_above = padding_above;
}
bool op::AvgPoolBackprop::get_include_padding_in_avg_computation() const
{
return m_include_padding_in_avg_computation;
}
void op::AvgPoolBackprop::set_include_padding_in_avg_computation(
bool include_padding_in_avg_computation)
{
m_include_padding_in_avg_computation = include_padding_in_avg_computation;
}
shared_ptr<Node> op::AvgPoolBackprop::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......
......@@ -29,9 +29,14 @@ namespace ngraph
class AvgPool : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a batched average pooling operation.
AvgPool();
/// \brief Constructs a batched average pooling operation.
///
/// \param arg The output producing the input data batch tensor.<br>
/// `[d1, dn]`
/// \param window_shape The window shape.<br>
/// `[n]`
......@@ -44,7 +49,7 @@ namespace ngraph
/// \param include_padding_in_avg_computation If true then averages include padding
/// elements, each treated as the number zero. If false, padding elements are entirely
/// ignored when computing averages.
AvgPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
......@@ -54,23 +59,23 @@ namespace ngraph
/// \brief Constructs a batched, unpadded average pooling operation (i.e., all padding shapes are set to 0).
///
/// \param arg The output producing the input data batch tensor.<br>
/// `[d1, ..., dn]`
/// \param window_shape The window shape.<br>
/// `[n]`
/// \param window_movement_strides The window movement strides.<br>
/// `[n]`
AvgPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides);
/// \brief Constructs an unstrided batched average pooling operation (i.e., all window movement strides are 1 and all padding shapes are set to 0).
///
/// \param arg The output producing the input data batch tensor.<br>
/// `[d1, ..., dn]`
/// \param window_shape The window shape.<br>
/// `[n]`
AvgPool(const Output<Node>& arg, const Shape& window_shape);
void validate_and_infer_types() override;
......@@ -81,19 +86,22 @@ namespace ngraph
const NodeVector& deltas) override;
/// \return The window shape.
const Shape& get_window_shape() const;
void set_window_shape(const Shape& window_shape);
/// \return The window movement strides.
const Strides& get_window_movement_strides() const;
void set_window_movement_strides(const Strides& window_movement_strides);
/// \return The below-padding shape.
const Shape& get_padding_below() const;
void set_padding_below(const Shape& padding_below);
/// \return The above-padding shape.
const Shape& get_padding_above() const;
void set_padding_above(const Shape& padding_above);
bool get_include_padding_in_avg_computation() const;
void set_include_padding_in_avg_computation(bool include_padding_in_avg_computation);
/// \return The pad type for pooling.
const PadType& get_pad_type() const;
void set_pad_type(const PadType& pad_type);
/// \return The default value for AvgPool.
virtual std::shared_ptr<Node> get_default_value() const override
{
......@@ -112,6 +120,9 @@ namespace ngraph
class AvgPoolBackprop : public Op
{
public:
static const std::string type_name;
const std::string& description() const override { return type_name; }
AvgPoolBackprop();
AvgPoolBackprop(const Shape& forward_arg_shape,
const std::shared_ptr<Node>& delta,
const Shape& window_shape,
......@@ -125,15 +136,18 @@ namespace ngraph
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
const Shape& get_forward_arg_shape() const;
void set_forward_arg_shape(const Shape& forward_arg_shape);
const Shape& get_window_shape() const;
void set_window_shape(const Shape& window_shape);
const Strides& get_window_movement_strides() const;
void set_window_movement_strides(const Strides& window_movement_strides);
const Shape& get_padding_below() const;
void set_padding_below(const Shape& padding_below);
const Shape& get_padding_above() const;
void set_padding_above(const Shape& padding_above);
bool get_include_padding_in_avg_computation() const;
void set_include_padding_in_avg_computation(bool include_padding_in_avg_computation);
protected:
Shape m_forward_arg_shape;
......
......@@ -25,6 +25,16 @@
using namespace ngraph;
using namespace std;
op::Op::Op(const NodeVector& args)
: Node(args)
{
}
op::Op::Op(const OutputVector& args)
: Node(args)
{
}
op::Op::Op(const std::string& node_type, const NodeVector& args)
: Node(node_type, args)
{
......
......@@ -41,6 +41,12 @@ namespace ngraph
virtual bool is_op() const override { return true; }
protected:
Op()
: Node()
{
}
Op(const NodeVector& arguments);
Op(const OutputVector& arguments);
Op(const std::string& node_type, const NodeVector& arguments);
private:
......
......@@ -19,12 +19,28 @@
using namespace std;
using namespace ngraph;
op::util::ArithmeticReduction::ArithmeticReduction()
{
}
op::util::ArithmeticReduction::ArithmeticReduction(const std::shared_ptr<Node>& arg,
const AxisSet& reduction_axes)
: Op(check_single_output_args({arg}))
{
set_reduction_axes(reduction_axes);
}
op::util::ArithmeticReduction::ArithmeticReduction(const std::string& node_type,
const std::shared_ptr<Node>& arg,
const AxisSet& reduction_axes)
: Op(node_type, check_single_output_args({arg}))
{
set_reduction_axes(reduction_axes);
}
void op::util::ArithmeticReduction::set_reduction_axes(const AxisSet& reduction_axes)
{
m_reduction_axes = reduction_axes;
}
void op::util::ArithmeticReduction::validate_and_infer_types()
......
......@@ -28,7 +28,23 @@ namespace ngraph
/// are eliminated (reduced out) by repeated application of a particular binary arithmetic operation.
class ArithmeticReduction : public Op
{
protected:
/// \brief Constructs an arithmetic reduction operation.
ArithmeticReduction();
/// \brief Constructs an arithmetic reduction operation.
///
/// \param arg Output that produces the first input tensor.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
ArithmeticReduction(const Output<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs an arithmetic reduction operation.
///
/// \param arg Node that produces the first input tensor.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
ArithmeticReduction(const std::shared_ptr<Node>& arg,
const AxisSet& reduction_axes);
/// \brief Constructs an arithmetic reduction operation.
///
/// \param arg Node that produces the first input tensor.
......@@ -37,10 +53,14 @@ namespace ngraph
const std::shared_ptr<Node>& arg,
const AxisSet& reduction_axes);
public:
void validate_and_infer_types() override;
/// \return The axis positions (0-based) to be eliminated through reduction.
const AxisSet& get_reduction_axes() const { return m_reduction_axes; }
/// \brief Change the reduction axes
void set_reduction_axes(const AxisSet& reduction_axes);
protected:
AxisSet m_reduction_axes;
};
......
......@@ -19,6 +19,27 @@
using namespace std;
using namespace ngraph;
op::util::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic()
{
}
op::util::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob)
: Op({arg0, arg1})
, m_autob(autob)
{
}
op::util::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob)
: Op(check_single_output_args({arg0, arg1}))
, m_autob(autob)
{
}
op::util::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(
const std::string& node_type,
const std::shared_ptr<Node>& arg0,
......
......@@ -47,7 +47,22 @@ namespace ngraph
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape and element type as the input tensors (after auto broadcasting). |
class BinaryElementwiseArithmetic : public Op
{
protected:
/// \brief Constructs a binary elementwise arithmetic operation.
BinaryElementwiseArithmetic();
BinaryElementwiseArithmetic(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
/// \brief Constructs a binary elementwise arithmetic operation.
///
/// \param arg0 Output that produces the first input tensor.
/// \param arg1 Output that produces the second input tensor.
BinaryElementwiseArithmetic(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
/// \brief Constructs a binary elementwise arithmetic operation.
///
/// \param arg0 Node that produces the first input tensor.
......@@ -58,9 +73,11 @@ namespace ngraph
const std::shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
public:
void validate_and_infer_types() override;
const AutoBroadcastSpec& get_autob() const { return m_autob; }
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
private:
AutoBroadcastSpec m_autob;
};
......
......@@ -19,6 +19,26 @@
using namespace std;
using namespace ngraph;
op::util::BinaryElementwiseComparison::BinaryElementwiseComparison()
{
}
op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob)
: Op(check_single_output_args({arg0, arg1}))
, m_autob(autob)
{
}
op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob)
: Op({arg0, arg1})
, m_autob(autob)
{
}
op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const string& node_type,
const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
......
......@@ -47,7 +47,28 @@ namespace ngraph
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape as the input tensors, and the element type `bool`. |
class BinaryElementwiseComparison : public Op
{
protected:
/// \brief Constructs a binary elementwise comparison operation.
BinaryElementwiseComparison();
/// \brief Constructs a binary elementwise comparison operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob AutoBroadcast mode.
BinaryElementwiseComparison(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
/// \brief Constructs a binary elementwise comparison operation.
///
/// \param arg0 Output that produces the first input tensor.
/// \param arg1 Output that produces the second input tensor.
/// \param autob AutoBroadcast mode.
BinaryElementwiseComparison(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
/// \brief Constructs a binary elementwise comparison operation.
///
/// \param arg0 Node that produces the first input tensor.
......@@ -58,9 +79,11 @@ namespace ngraph
const std::shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
public:
void validate_and_infer_types() override;
const AutoBroadcastSpec& get_autob() const { return m_autob; }
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
private:
AutoBroadcastSpec m_autob;
};
......
......@@ -19,6 +19,26 @@
using namespace std;
using namespace ngraph;
op::util::BinaryElementwiseLogical::BinaryElementwiseLogical()
{
}
op::util::BinaryElementwiseLogical::BinaryElementwiseLogical(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob)
: Op({arg0, arg1})
, m_autob(autob)
{
}
op::util::BinaryElementwiseLogical::BinaryElementwiseLogical(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob)
: Op(check_single_output_args({arg0, arg1}))
, m_autob(autob)
{
}
op::util::BinaryElementwiseLogical::BinaryElementwiseLogical(const string& node_type,
const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
......
......@@ -47,7 +47,25 @@ namespace ngraph
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape as the input tensors, and the element type `bool`. |
class BinaryElementwiseLogical : public Op
{
protected:
BinaryElementwiseLogical();
/// \brief Constructs a binary elementwise logical operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
BinaryElementwiseLogical(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
/// \brief Constructs a binary elementwise logical operation.
///
/// \param arg0 Output that produces the first input tensor.
/// \param arg1 Output that produces the second input tensor.
BinaryElementwiseLogical(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
/// \brief Constructs a binary elementwise logical operation.
///
/// \param arg0 Node that produces the first input tensor.
......@@ -58,9 +76,11 @@ namespace ngraph
const std::shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
public:
void validate_and_infer_types() override;
const AutoBroadcastSpec& get_autob() const { return m_autob; }
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
private:
AutoBroadcastSpec m_autob;
};
......
......@@ -20,6 +20,16 @@
using namespace ngraph;
op::util::FusedOp::FusedOp()
: Op()
{
}
op::util::FusedOp::FusedOp(const NodeVector& args)
: Op(args)
{
}
op::util::FusedOp::FusedOp(const std::string& node_type, const NodeVector& args)
: Op(node_type, args)
{
......
......@@ -44,6 +44,13 @@ namespace ngraph
const NodeVector& deltas) override;
protected:
FusedOp();
/// \brief Constructs a FusedOp
///
/// \param args Nodes that produce the input tensors for the fused op
FusedOp(const NodeVector& args);
/// \brief Constructs a FusedOp
///
/// \param args Nodes that produce the input tensors for the fused op
......
......@@ -21,15 +21,53 @@
using namespace std;
using namespace ngraph;
op::util::IndexReduction::IndexReduction()
{
}
op::util::IndexReduction::IndexReduction(const Output<Node>& arg,
size_t axis,
const element::Type& index_element_type)
: Op({arg})
{
set_reduction_axis(axis);
set_index_element_type(index_element_type);
}
op::util::IndexReduction::IndexReduction(const std::shared_ptr<Node>& arg,
size_t axis,
const element::Type& index_element_type)
: Op(check_single_output_args({arg}))
{
set_reduction_axis(axis);
set_index_element_type(index_element_type);
}
op::util::IndexReduction::IndexReduction(const std::string& node_type,
const std::shared_ptr<Node>& arg,
size_t axis,
const element::Type& index_element_type)
: Op(node_type, check_single_output_args({arg}))
{
set_reduction_axis(axis);
set_index_element_type(index_element_type);
}
size_t op::util::IndexReduction::get_reduction_axis() const
{
return m_axis;
}
void op::util::IndexReduction::set_reduction_axis(size_t value)
{
m_axis = value;
}
element::Type op::util::IndexReduction::get_index_element_type() const
{
return m_index_element_type;
}
void op::util::IndexReduction::set_index_element_type(const element::Type& index_element_type)
{
m_index_element_type = index_element_type;
}
void op::util::IndexReduction::validate_and_infer_types()
......
......@@ -16,6 +16,11 @@
#pragma once
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include "ngraph/op/op.hpp"
namespace ngraph
......@@ -26,14 +31,28 @@ namespace ngraph
{
class IndexReduction : public Op
{
protected:
IndexReduction();
IndexReduction(const Output<Node>& arg,
size_t axis,
const element::Type& index_element_type);
IndexReduction(const std::shared_ptr<Node>& arg,
size_t axis,
const element::Type& index_element_type);
IndexReduction(const std::string& node_type,
const std::shared_ptr<Node>& arg,
size_t axis,
const element::Type& index_element_type);
public:
size_t get_reduction_axis() const;
void set_reduction_axis(size_t value);
element::Type get_index_element_type() const;
void set_index_element_type(const element::Type& index_element_type);
protected:
size_t m_axis{0};
element::Type m_index_element_type;
......
......@@ -19,12 +19,39 @@
using namespace std;
using namespace ngraph;
op::util::LogicalReduction::LogicalReduction()
{
}
op::util::LogicalReduction::LogicalReduction(const Output<Node>& arg, const AxisSet& reduction_axes)
: Op({arg})
{
set_reduction_axes(reduction_axes);
}
op::util::LogicalReduction::LogicalReduction(const std::shared_ptr<Node>& arg,
const AxisSet& reduction_axes)
: Op(check_single_output_args({arg}))
{
set_reduction_axes(reduction_axes);
}
op::util::LogicalReduction::LogicalReduction(const std::string& node_type,
const std::shared_ptr<Node>& arg,
const AxisSet& reduction_axes)
: Op(node_type, check_single_output_args({arg}))
{
set_reduction_axes(reduction_axes);
}
const AxisSet& op::util::LogicalReduction::get_reduction_axes() const
{
return m_reduction_axes;
}
void op::util::LogicalReduction::set_reduction_axes(const AxisSet& reduction_axes)
{
m_reduction_axes = reduction_axes;
}
void op::util::LogicalReduction::validate_and_infer_types()
......
......@@ -28,7 +28,19 @@ namespace ngraph
/// are eliminated (reduced out) by repeated application of a particular binary logical operation.
class LogicalReduction : public Op
{
protected:
/// \brief Constructs a logical reduction operation.
LogicalReduction();
/// \brief Constructs a logical reduction operation.
///
/// \param arg Output that produces the first input tensor.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
LogicalReduction(const Output<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs a logical reduction operation.
///
/// \param arg Node that produces the first input tensor.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
LogicalReduction(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs a logical reduction operation.
///
/// \param arg Node that produces the first input tensor.
......@@ -37,10 +49,13 @@ namespace ngraph
const std::shared_ptr<Node>& arg,
const AxisSet& reduction_axes);
public:
void validate_and_infer_types() override;
/// \return The axis positions (0-based) to be eliminated through reduction.
const AxisSet& get_reduction_axes() const;
void set_reduction_axes(const AxisSet& reduction_axes);
protected:
AxisSet m_reduction_axes;
};
......
......@@ -18,6 +18,21 @@
using namespace ngraph;
op::util::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic()
: Op()
{
}
op::util::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic(const Output<Node>& arg)
: Op({arg})
{
}
op::util::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic(const std::shared_ptr<Node>& arg)
: Op(check_single_output_args({arg}))
{
}
op::util::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic(const std::string& node_type,
const std::shared_ptr<Node>& arg)
: Op(node_type, check_single_output_args({arg}))
......
......@@ -44,12 +44,24 @@ namespace ngraph
class UnaryElementwiseArithmetic : public Op
{
protected:
/// \brief Constructs a unary elementwise arithmetic operation.
UnaryElementwiseArithmetic();
/// \brief Constructs a unary elementwise arithmetic operation.
///
/// \param arg Node that produces the input tensor.
UnaryElementwiseArithmetic(const std::shared_ptr<Node>& arg);
/// \brief Constructs a unary elementwise arithmetic operation.
///
/// \param arg Output that produces the input tensor.
UnaryElementwiseArithmetic(const Output<Node>& arg);
/// \brief Constructs a unary elementwise arithmetic operation.
///
/// \param arg Node that produces the input tensor.
UnaryElementwiseArithmetic(const std::string& node_type,
const std::shared_ptr<Node>& arg);
public:
void validate_and_infer_types() override;
};
}
......
......@@ -117,7 +117,7 @@ static void random_init(shared_ptr<runtime::Tensor> tv)
case element::Type_t::i8: init_int_tv<int8_t>(tv, -1, 1); break;
case element::Type_t::i16: init_int_tv<int16_t>(tv, -1, 1); break;
case element::Type_t::i32: init_int_tv<int32_t>(tv, 0, 1); break;
case element::Type_t::i64: init_int_tv<int64_t>(tv, 0, 1); break;
case element::Type_t::u8: init_int_tv<uint8_t>(tv, 0, 1); break;
case element::Type_t::u16: init_int_tv<uint16_t>(tv, 0, 1); break;
case element::Type_t::u32: init_int_tv<uint32_t>(tv, 0, 1); break;
......
......@@ -97,11 +97,6 @@ TEST(build_graph, tensor)
ASSERT_EQ(int32_0->get_shape(), ishape);
}
// Check functions with undeclared parameters
TEST(build_graph, function_undeclared_parameters)
{
......@@ -131,3 +126,27 @@ TEST(build_graph, function_undeclared_parameters)
FAIL() << "Function construction failed for unexpected reason";
}
}
// Check no-arg construction
TEST(build_graph, no_arg_construction)
{
// The ops
// Parameters aren't converted yet
auto arg0 = make_shared<op::Parameter>(element::f32, Shape{7});
auto arg1 = make_shared<op::Parameter>(element::f32, Shape{7});
auto arg2 = make_shared<op::Parameter>(element::f32, Shape{7});
auto arg3 = make_shared<op::Parameter>(element::f32, Shape{7});
auto add0 = make_shared<op::Add>();
auto abs0 = make_shared<op::Abs>();
auto acos0 = make_shared<op::Acos>();
auto add1 = make_shared<op::Add>();
add0->set_argument(1, arg0);
add0->set_argument(0, arg1);
abs0->set_argument(0, add0);
acos0->set_argument(0, add0);
add1->set_argument(0, acos0);
add1->set_argument(1, abs0);
NodeVector ops{arg0, arg1, add0, abs0, acos0, add1};
validate_nodes_and_infer_types(ops);
ASSERT_EQ(add1->get_output_shape(0), Shape{7});
}
......@@ -538,6 +538,32 @@ TEST(util, enum_mask_operators)
EXPECT_EQ(true, n[Type::b]);
}
TEST(graph, huge)
{
Function* f;
std::vector<std::weak_ptr<Node>> weak_nodes;
{
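// Build a 1,000,000-node chain; without Node::safe_delete, deleting the
// Function below would recurse once per node and overflow the stack.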
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 3});
std::shared_ptr<Node> n = param;
for (size_t i = 0; i < 1000000; i++)
{
n = make_shared<op::Negative>(n);
}
f = new Function(NodeVector{n}, ParameterVector{param});
for (auto node : f->get_ops())
{
weak_nodes.push_back(node);
}
}
delete f;
for (auto weak_node : weak_nodes)
{
EXPECT_TRUE(weak_node.expired());
}
}
TEST(util, apply_permutation)
{
ASSERT_EQ(apply_permutation(Shape{0, 1, 2, 3}, AxisVector{2, 1, 0, 3}), (Shape{2, 1, 0, 3}));
......