Unverified Commit 6f1e728a authored by Fenglei Tian, committed by GitHub

Merge branch 'master' into tfl/send_recv_op

parents 9c2230aa d0a83a35
@@ -22,11 +22,6 @@
 using namespace ngraph;
 
-NGRAPH_API const reduction::Type reduction::sum(reduction::Type_t::sum);
-NGRAPH_API const reduction::Type reduction::prod(reduction::Type_t::prod);
-NGRAPH_API const reduction::Type reduction::min(reduction::Type_t::min);
-NGRAPH_API const reduction::Type reduction::max(reduction::Type_t::max);
-
 std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& obj)
 {
 #if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
@@ -34,12 +29,12 @@ std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& obj)
 #pragma GCC diagnostic error "-Wswitch"
 #pragma GCC diagnostic error "-Wswitch-enum"
 #endif
-    switch (obj.get_type())
+    switch (obj)
     {
-    case reduction::Type_t::sum: out << "sum"; break;
-    case reduction::Type_t::prod: out << "prod"; break;
-    case reduction::Type_t::min: out << "min"; break;
-    case reduction::Type_t::max: out << "max"; break;
+    case reduction::Type::SUM: out << "SUM"; break;
+    case reduction::Type::PROD: out << "PROD"; break;
+    case reduction::Type::MIN: out << "MIN"; break;
+    case reduction::Type::MAX: out << "MAX"; break;
     }
 #if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
 #pragma GCC diagnostic pop
@@ -47,16 +42,6 @@ std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& obj)
     return out;
 };
 
-bool reduction::Type::operator==(const reduction::Type& other) const
-{
-    return m_type == other.m_type;
-}
-
-reduction::Type_t reduction::Type::get_type() const
-{
-    return m_type;
-}
-
 static std::unique_ptr<DistributedInterface> s_distributed_interface;
 
 void ngraph::set_distributed_interface(std::unique_ptr<DistributedInterface> distributed_interface)
...
@@ -26,34 +26,15 @@ namespace ngraph
 {
     namespace reduction
     {
-        enum class Type_t
+        enum class Type
         {
-            sum,
-            prod,
-            min,
-            max,
+            SUM,
+            PROD,
+            MIN,
+            MAX,
         };
 
-        class Type
-        {
-        public:
-            Type(const Type_t t)
-                : m_type(t)
-            {
-            }
-            friend std::ostream& operator<<(std::ostream&, const Type&);
-            bool operator==(const Type& other) const;
-            bool operator!=(const Type& other) const { return !(*this == other); }
-            Type_t get_type() const;
-
-        private:
-            Type_t m_type;
-        };
-
         std::ostream& operator<<(std::ostream& out, const Type& obj);
-
-        extern NGRAPH_API const Type sum;
-        extern NGRAPH_API const Type prod;
-        extern NGRAPH_API const Type min;
-        extern NGRAPH_API const Type max;
     }
 
     class DistributedInterface
...
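With the wrapper class gone, reduction::Type is an ordinary scoped enum: values compare with the built-in operators and still stream through reduction::operator<<. A minimal sketch of a call site, assuming only the ngraph/distributed.hpp header from this change:

    #include <iostream>
    #include "ngraph/distributed.hpp"

    int main()
    {
        ngraph::reduction::Type t = ngraph::reduction::Type::MAX;
        if (t != ngraph::reduction::Type::SUM) // built-in enum comparison, no operator== member
        {
            std::cout << t << std::endl; // prints "MAX" via reduction::operator<<
        }
    }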
@@ -92,14 +92,14 @@ namespace ngraph
 #pragma GCC diagnostic error "-Wswitch"
 #pragma GCC diagnostic error "-Wswitch-enum"
 #endif
-                switch (reduce_type.get_type())
+                switch (reduce_type)
                 {
-                case reduction::Type_t::sum: mlsl_reduce_type = MLSL::RT_SUM; break;
-                case reduction::Type_t::prod:
+                case reduction::Type::SUM: mlsl_reduce_type = MLSL::RT_SUM; break;
+                case reduction::Type::PROD:
                     throw std::runtime_error("MLSL doesn't support allreduce prod");
                     break;
-                case reduction::Type_t::min: mlsl_reduce_type = MLSL::RT_MIN; break;
-                case reduction::Type_t::max: mlsl_reduce_type = MLSL::RT_MAX; break;
+                case reduction::Type::MIN: mlsl_reduce_type = MLSL::RT_MIN; break;
+                case reduction::Type::MAX: mlsl_reduce_type = MLSL::RT_MAX; break;
                 }
 #if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
 #pragma GCC diagnostic pop
...
@@ -104,12 +104,12 @@ namespace ngraph
 #pragma GCC diagnostic error "-Wswitch"
 #pragma GCC diagnostic error "-Wswitch-enum"
 #endif
-                switch (reduce_type.get_type())
+                switch (reduce_type)
                 {
-                case reduction::Type_t::sum: mpi_reduce_type = MPI_SUM; break;
-                case reduction::Type_t::prod: mpi_reduce_type = MPI_PROD; break;
-                case reduction::Type_t::min: mpi_reduce_type = MPI_MIN; break;
-                case reduction::Type_t::max: mpi_reduce_type = MPI_MAX; break;
+                case reduction::Type::SUM: mpi_reduce_type = MPI_SUM; break;
+                case reduction::Type::PROD: mpi_reduce_type = MPI_PROD; break;
+                case reduction::Type::MIN: mpi_reduce_type = MPI_MIN; break;
+                case reduction::Type::MAX: mpi_reduce_type = MPI_MAX; break;
                 }
 #if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
 #pragma GCC diagnostic pop
...
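Both mapping sites lean on the same guard: with -Wswitch and -Wswitch-enum promoted to errors and no default label, adding a new reduction::Type enumerator turns every unhandled switch into a compile error. A standalone sketch of the pattern; to_backend_op and its return values are hypothetical, not part of ngraph, MPI, or MLSL:

    #include <stdexcept>
    #include "ngraph/distributed.hpp"

    // Hypothetical mapper from the enum to a backend constant.
    static int to_backend_op(ngraph::reduction::Type t)
    {
    #pragma GCC diagnostic push
    #pragma GCC diagnostic error "-Wswitch"
    #pragma GCC diagnostic error "-Wswitch-enum"
        switch (t) // no default: a new enumerator fails to compile here
        {
        case ngraph::reduction::Type::SUM: return 0;
        case ngraph::reduction::Type::PROD: return 1;
        case ngraph::reduction::Type::MIN: return 2;
        case ngraph::reduction::Type::MAX: return 3;
        }
    #pragma GCC diagnostic pop
        throw std::runtime_error("unreachable");
    }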
@@ -233,6 +233,11 @@ std::list<std::shared_ptr<ngraph::Node>>
                 // There is a friendly name for this node so copy it
                 cloned_node->set_friendly_name(node->get_friendly_name());
             }
+            for (auto tag : node->get_provenance_tags())
+            {
+                cloned_node->add_provenance_tag(tag);
+            }
             node_map[node.get()] = cloned_node;
         }
     }
...
@@ -559,6 +559,16 @@ const NodeVector& ngraph::check_single_output_args(const NodeVector& args)
     return args;
 }
 
+OutputVector ngraph::as_output_vector(const NodeVector& args)
+{
+    OutputVector output_vector;
+    for (auto& arg : check_single_output_args(args))
+    {
+        output_vector.push_back(arg);
+    }
+    return output_vector;
+}
+
 std::tuple<element::Type, PartialShape>
     Node::validate_and_infer_elementwise_args(const op::AutoBroadcastSpec& autob)
 {
...
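The helper is the bridge between the legacy NodeVector world and the Output<Node>-based constructors introduced throughout this change: each node is checked to have exactly one output, then wrapped as an Output<Node>. A sketch of typical use, assuming the umbrella ngraph/ngraph.hpp header:

    #include "ngraph/ngraph.hpp"

    using namespace ngraph;

    void example()
    {
        auto a = std::make_shared<op::Parameter>(element::f32, Shape{2, 2});
        auto b = std::make_shared<op::Parameter>(element::f32, Shape{2, 2});

        NodeVector nodes{a, b};
        OutputVector outputs = as_output_vector(nodes); // rejects multi-output nodes
    }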
@@ -73,6 +73,8 @@ namespace ngraph
                                  size_t i);
 
     const NodeVector& check_single_output_args(const NodeVector& args);
 
+    OutputVector as_output_vector(const NodeVector& args);
+
     /// Alias useful for cloning
     using NodeMap = std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>>;
@@ -487,7 +489,7 @@ namespace ngraph
         /// \param node A pointer to the node for the output handle.
         /// \param index The index of the output.
         Output(NodeType* node, size_t index)
-            : m_node(node)
+            : m_node(node->shared_from_this())
             , m_index(index)
         {
         }
@@ -498,7 +500,7 @@ namespace ngraph
         ///
         /// TODO: Make a plan to deprecate this.
         Output(const std::shared_ptr<NodeType>& node, size_t index)
-            : m_node(node.get())
+            : m_node(node)
             , m_index(index)
         {
         }
@@ -511,12 +513,15 @@ namespace ngraph
         {
         }
 
+        // A null output
+        Output() = default;
+
         /// \return A pointer to the node referred to by this output handle.
-        NodeType* get_node() const { return m_node; }
+        NodeType* get_node() const { return m_node.get(); }
         /// \return A `shared_ptr` to the node referred to by this output handle.
         ///
        /// TODO: Make a plan to deprecate this.
-        std::shared_ptr<NodeType> get_node_shared_ptr() const { return m_node->shared_from_this(); }
+        std::shared_ptr<NodeType> get_node_shared_ptr() const { return m_node; }
         /// \return The index of the output referred to by this output handle.
         size_t get_index() const { return m_index; }
         /// \return A reference to the tensor descriptor for this output.
@@ -568,8 +573,8 @@ namespace ngraph
         bool operator<=(const Output& other) const { return !(*this > other); }
         bool operator>=(const Output& other) const { return !(*this < other); }
     private:
-        NodeType* const m_node;
-        const size_t m_index;
+        std::shared_ptr<NodeType> m_node;
+        size_t m_index{0};
     };
 
     inline Input<Node> Node::input(size_t input_index)
...
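The member change from a raw NodeType* to a shared_ptr<NodeType> means an Output<Node> handle now keeps its node alive, and the new defaulted constructor yields a well-defined null handle (no node, index 0). A sketch of both behaviors, assuming ngraph/ngraph.hpp:

    #include <cassert>
    #include "ngraph/ngraph.hpp"

    using namespace ngraph;

    void example()
    {
        Output<Node> out; // null output: no node, m_index defaults to 0
        {
            auto p = std::make_shared<op::Parameter>(element::f32, Shape{4});
            out = Output<Node>(p, 0);
        }
        // p is out of scope, but the handle's shared_ptr keeps the node alive;
        // with the previous raw-pointer member this access would have dangled.
        assert(out.get_node_shared_ptr() != nullptr);
        assert(out.get_index() == 0);
    }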
@@ -23,10 +23,6 @@ using namespace ngraph;
 
 const string op::Abs::type_name{"Abs"};
 
-op::Abs::Abs()
-{
-}
-
 op::Abs::Abs(const Output<Node>& arg)
     : UnaryElementwiseArithmetic(arg)
 {
...
@@ -33,7 +33,7 @@ namespace ngraph
             static const std::string type_name;
             const std::string& description() const override { return type_name; }
             /// \brief Constructs an absolute value operation.
-            Abs();
+            Abs() = default;
 
             /// \brief Constructs an absolute value operation.
             ///
...
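The same mechanical edit repeats for the ops below: the empty out-of-line default constructor is deleted from the .cpp file and the header declaration becomes defaulted. Behavior is unchanged; the op stays default-constructible, which presumably supports two-phase construction (e.g. a deserializer filling the op in afterwards). A sketch:

    #include <memory>
    #include "ngraph/ngraph.hpp"

    void example()
    {
        // Still legal after the change: default-construct now,
        // attach inputs/attributes later via the new setters.
        auto abs_op = std::make_shared<ngraph::op::Abs>();
    }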
@@ -34,10 +34,6 @@ using namespace ngraph;
 
 const string op::Acos::type_name{"Acos"};
 
-op::Acos::Acos()
-{
-}
-
 op::Acos::Acos(const Output<Node>& arg)
     : UnaryElementwiseArithmetic(arg)
 {
...
@@ -33,7 +33,7 @@ namespace ngraph
             static const std::string type_name;
             const std::string& description() const override { return type_name; }
             /// \brief Constructs an arccos operation.
-            Acos();
+            Acos() = default;
 
             /// \brief Constructs an arccos operation.
             ///
             /// \param arg Output that produces the input tensor.<br>
...
@@ -21,10 +21,6 @@ using namespace ngraph;
 
 const string op::Add::type_name{"Add"};
 
-op::Add::Add()
-{
-}
-
 op::Add::Add(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
     : BinaryElementwiseArithmetic(arg0, arg1, autob)
 {
...
@@ -33,7 +33,7 @@ namespace ngraph
             static const std::string type_name;
             const std::string& description() const override { return type_name; }
             /// \brief Constructs an uninitialized addition operation
-            Add();
+            Add() = default;
 
             /// \brief Constructs an addition operation.
             ///
...
@@ -21,10 +21,6 @@ using namespace ngraph;
 
 const string op::All::type_name{"All"};
 
-op::All::All()
-{
-}
-
 op::All::All(const Output<Node>& arg, const AxisSet& reduction_axes)
     : LogicalReduction(arg, reduction_axes)
 {
...
@@ -33,7 +33,7 @@ namespace ngraph
             static const std::string type_name;
             const std::string& description() const override { return type_name; }
             /// \brief Constructs an "all" reduction operation.
-            All();
+            All() = default;
 
             /// \brief Constructs an "all" reduction operation.
             ///
             /// \param arg The tensor to be reduced.
...
@@ -21,13 +21,8 @@ using namespace ngraph;
 
 const string op::AllReduce::type_name{"AllReduce"};
 
-op::AllReduce::AllReduce()
-    : m_reduce_type(reduction::sum)
-{
-}
-
-op::AllReduce::AllReduce(const shared_ptr<Node>& arg, const reduction::Type reduce_type)
-    : Op(check_single_output_args({arg}))
+op::AllReduce::AllReduce(const Output<Node>& arg, reduction::Type reduce_type)
+    : Op({arg})
     , m_reduce_type(reduce_type)
 {
     constructor_validate_and_infer_types();
@@ -56,3 +51,8 @@ reduction::Type op::AllReduce::get_reduce_type() const
 {
     return m_reduce_type;
 }
+
+void op::AllReduce::set_reduce_type(reduction::Type reduce_type)
+{
+    m_reduce_type = reduce_type;
+}
@@ -29,17 +29,17 @@ namespace ngraph
             NGRAPH_API
             static const std::string type_name;
             const std::string& description() const override { return type_name; }
-            AllReduce();
-            AllReduce(const std::shared_ptr<Node>& arg,
-                      const reduction::Type reduce_type = reduction::sum);
+            AllReduce() = default;
+            AllReduce(const Output<Node>& arg, reduction::Type reduce_type = reduction::Type::SUM);
 
             void validate_and_infer_types() override;
 
             std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
 
             reduction::Type get_reduce_type() const;
+            void set_reduce_type(reduction::Type reduce_type);
 
         private:
-            const reduction::Type m_reduce_type;
+            reduction::Type m_reduce_type{reduction::Type::SUM};
         };
     }
 }
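A sketch of constructing the op against the revised signature: the reduce type now defaults to reduction::Type::SUM, and the attribute is mutable through the new setter (assuming ngraph/ngraph.hpp):

    #include "ngraph/ngraph.hpp"

    using namespace ngraph;

    void example()
    {
        auto data = std::make_shared<op::Parameter>(element::f32, Shape{8});

        auto sum_all = std::make_shared<op::AllReduce>(data); // reduce_type defaults to SUM
        auto min_all = std::make_shared<op::AllReduce>(data, reduction::Type::MIN);

        sum_all->set_reduce_type(reduction::Type::MAX); // attribute is no longer const
    }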
@@ -21,10 +21,6 @@ using namespace ngraph;
 
 const string op::And::type_name{"And"};
 
-op::And::And()
-{
-}
-
 op::And::And(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
     : BinaryElementwiseLogical(arg0, arg1, autob)
 {
...
@@ -33,7 +33,7 @@ namespace ngraph
             static const std::string type_name;
             const std::string& description() const override { return type_name; }
             /// \brief Constructs a logical-and operation.
-            And();
+            And() = default;
 
             /// \brief Constructs a logical-and operation.
             ///
...
@@ -21,10 +21,6 @@ using namespace ngraph;
 
 const string op::Any::type_name{"Any"};
 
-op::Any::Any()
-{
-}
-
 op::Any::Any(const Output<Node>& arg, const AxisSet& reduction_axes)
     : LogicalReduction(arg, reduction_axes)
 {
...
@@ -33,7 +33,7 @@ namespace ngraph
             static const std::string type_name;
             const std::string& description() const override { return type_name; }
             /// \brief Constructs an "any" reduction operation.
-            Any();
+            Any() = default;
 
             /// \brief Constructs an "any" reduction operation.
             ///
             /// \param arg The tensor to be reduced.
...
@@ -21,10 +21,6 @@ using namespace ngraph;
 
 const string op::ArgMax::type_name{"ArgMax"};
 
-op::ArgMax::ArgMax()
-{
-}
-
 op::ArgMax::ArgMax(const Output<Node>& arg, size_t axis, const element::Type& index_element_type)
     : op::util::IndexReduction(arg, axis, index_element_type)
 {
...
@@ -32,7 +32,7 @@ namespace ngraph
             static const std::string type_name;
             const std::string& description() const override { return type_name; }
             /// \brief Constructs an ArgMax operation.
-            ArgMax();
+            ArgMax() = default;
 
             /// \brief Constructs an ArgMax operation.
             ///
             /// \param arg The input tensor
...
@@ -21,10 +21,6 @@ using namespace ngraph;
 
 const string op::ArgMin::type_name{"ArgMin"};
 
-op::ArgMin::ArgMin()
-{
-}
-
 op::ArgMin::ArgMin(const Output<Node>& arg, size_t axis, const element::Type& index_element_type)
     : op::util::IndexReduction(arg, axis, index_element_type)
 {
...
@@ -32,7 +32,7 @@ namespace ngraph
             static const std::string type_name;
             const std::string& description() const override { return type_name; }
             /// \brief Constructs an ArgMin operation.
-            ArgMin();
+            ArgMin() = default;
 
             /// \brief Constructs an ArgMin operation.
             ///
...
@@ -33,10 +33,6 @@ using namespace ngraph;
 
 const string op::Asin::type_name{"Asin"};
 
-op::Asin::Asin()
-{
-}
-
 op::Asin::Asin(const Output<Node>& arg)
     : UnaryElementwiseArithmetic(arg)
 {
...
@@ -33,7 +33,7 @@ namespace ngraph
             static const std::string type_name;
             const std::string& description() const override { return type_name; }
             /// \brief Constructs an arcsin operation.
-            Asin();
+            Asin() = default;
 
             /// \brief Constructs an arcsin operation.
             ///
             /// \param arg Output that produces the input tensor.<br>
...
@@ -32,10 +32,6 @@ using namespace ngraph;
 
 const string op::Atan::type_name{"Atan"};
 
-op::Atan::Atan()
-{
-}
-
 op::Atan::Atan(const Output<Node>& arg)
     : UnaryElementwiseArithmetic(arg)
 {
...
@@ -33,7 +33,7 @@ namespace ngraph
             static const std::string type_name;
             const std::string& description() const override { return type_name; }
             /// \brief Constructs an arctan operation.
-            Atan();
+            Atan() = default;
 
             /// \brief Constructs an arctan operation.
             ///
...
@@ -23,10 +23,6 @@ using namespace ngraph;
 
 const string op::AvgPool::type_name{"AvgPool"};
 
-op::AvgPool::AvgPool()
-{
-}
-
 op::AvgPool::AvgPool(const Output<Node>& arg,
                      const Shape& window_shape,
                      const Strides& window_movement_strides,
@@ -231,10 +227,6 @@ shared_ptr<Node> op::AvgPool::copy_with_new_args(const NodeVector& new_args) const
 
 const string op::AvgPoolBackprop::type_name("AvgPoolBackprop");
 
-op::AvgPoolBackprop::AvgPoolBackprop()
-{
-}
-
 op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
                                      const shared_ptr<Node>& delta,
                                      const Shape& window_shape,
...
@@ -33,7 +33,7 @@ namespace ngraph
             static const std::string type_name;
             const std::string& description() const override { return type_name; }
             /// \brief Constructs a batched average pooling operation.
-            AvgPool();
+            AvgPool() = default;
 
             /// \brief Constructs a batched average pooling operation.
             ///
@@ -175,7 +175,7 @@ namespace ngraph
         public:
             static const std::string type_name;
             const std::string& description() const override { return type_name; }
-            AvgPoolBackprop();
+            AvgPoolBackprop() = default;
             AvgPoolBackprop(const Shape& forward_arg_shape,
                             const std::shared_ptr<Node>& delta,
                             const Shape& window_shape,
...
@@ -22,11 +22,13 @@
 #include "ngraph/op/get_output_element.hpp"
 #include "ngraph/validation_util.hpp"
 
-ngraph::op::BatchNormTraining::BatchNormTraining(std::shared_ptr<ngraph::Node> input,
-                                                 std::shared_ptr<ngraph::Node> gamma,
-                                                 std::shared_ptr<ngraph::Node> beta,
+const std::string ngraph::op::BatchNormTraining::type_name{"BatchNormTraining"};
+
+ngraph::op::BatchNormTraining::BatchNormTraining(Output<ngraph::Node> input,
+                                                 Output<ngraph::Node> gamma,
+                                                 Output<ngraph::Node> beta,
                                                  double epsilon)
-    : Op("BatchNormTraining", check_single_output_args({gamma, beta, input}))
+    : Op({gamma, beta, input})
     , m_epsilon(epsilon)
 {
     constructor_validate_and_infer_types();
@@ -34,10 +36,10 @@
 // DEPRECATED
 ngraph::op::BatchNormTraining::BatchNormTraining(double eps,
-                                                 std::shared_ptr<ngraph::Node> gamma,
-                                                 std::shared_ptr<ngraph::Node> beta,
-                                                 std::shared_ptr<ngraph::Node> input)
-    : Op("BatchNormTraining", check_single_output_args({gamma, beta, input}))
+                                                 Output<ngraph::Node> gamma,
+                                                 Output<ngraph::Node> beta,
+                                                 Output<ngraph::Node> input)
+    : Op({gamma, beta, input})
     , m_epsilon(eps)
 {
     constructor_validate_and_infer_types();
@@ -111,13 +113,15 @@ void ngraph::op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoints,
     adjoints.add_delta(beta, dbeta);
 }
 
-ngraph::op::BatchNormInference::BatchNormInference(std::shared_ptr<ngraph::Node> input,
-                                                   std::shared_ptr<ngraph::Node> gamma,
-                                                   std::shared_ptr<ngraph::Node> beta,
-                                                   std::shared_ptr<ngraph::Node> mean,
-                                                   std::shared_ptr<ngraph::Node> variance,
+const std::string ngraph::op::BatchNormInference::type_name{"BatchNormInference"};
+
+ngraph::op::BatchNormInference::BatchNormInference(Output<ngraph::Node> input,
+                                                   Output<ngraph::Node> gamma,
+                                                   Output<ngraph::Node> beta,
+                                                   Output<ngraph::Node> mean,
+                                                   Output<ngraph::Node> variance,
                                                    double epsilon)
-    : Op("BatchNormInference", check_single_output_args({gamma, beta, input, mean, variance}))
+    : Op({gamma, beta, input, mean, variance})
     , m_epsilon(epsilon)
 {
     constructor_validate_and_infer_types();
@@ -125,12 +129,12 @@
 // DEPRECATED
 ngraph::op::BatchNormInference::BatchNormInference(double eps,
-                                                   std::shared_ptr<ngraph::Node> gamma,
-                                                   std::shared_ptr<ngraph::Node> beta,
-                                                   std::shared_ptr<ngraph::Node> input,
-                                                   std::shared_ptr<ngraph::Node> mean,
-                                                   std::shared_ptr<ngraph::Node> variance)
-    : Op("BatchNormInference", check_single_output_args({gamma, beta, input, mean, variance}))
+                                                   Output<ngraph::Node> gamma,
+                                                   Output<ngraph::Node> beta,
+                                                   Output<ngraph::Node> input,
+                                                   Output<ngraph::Node> mean,
+                                                   Output<ngraph::Node> variance)
+    : Op({gamma, beta, input, mean, variance})
     , m_epsilon(eps)
 {
     constructor_validate_and_infer_types();
@@ -167,16 +171,16 @@ std::shared_ptr<ngraph::Node>
         new_args.at(2), new_args.at(0), new_args.at(1), new_args.at(3), new_args.at(4), m_epsilon);
 }
 
-ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(
-    std::shared_ptr<ngraph::Node> input,
-    std::shared_ptr<ngraph::Node> gamma,
-    std::shared_ptr<ngraph::Node> beta,
-    std::shared_ptr<ngraph::Node> mean,
-    std::shared_ptr<ngraph::Node> variance,
-    std::shared_ptr<ngraph::Node> delta,
-    double epsilon)
-    : Op("BatchNormTrainingBackprop",
-         check_single_output_args({gamma, beta, input, mean, variance, delta}))
+const std::string ngraph::op::BatchNormTrainingBackprop::type_name{"BatchNormTrainingBackprop"};
+
+ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(Output<ngraph::Node> input,
+                                                                 Output<ngraph::Node> gamma,
+                                                                 Output<ngraph::Node> beta,
+                                                                 Output<ngraph::Node> mean,
+                                                                 Output<ngraph::Node> variance,
+                                                                 Output<ngraph::Node> delta,
+                                                                 double epsilon)
+    : Op({gamma, beta, input, mean, variance, delta})
     , m_epsilon(epsilon)
 {
@@ -184,16 +188,14 @@
     constructor_validate_and_infer_types();
 }
 
-ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(
-    double epsilon,
-    std::shared_ptr<ngraph::Node> gamma,
-    std::shared_ptr<ngraph::Node> beta,
-    std::shared_ptr<ngraph::Node> input,
-    std::shared_ptr<ngraph::Node> mean,
-    std::shared_ptr<ngraph::Node> variance,
-    std::shared_ptr<ngraph::Node> delta)
-    : Op("BatchNormTrainingBackprop",
-         check_single_output_args({gamma, beta, input, mean, variance, delta}))
+ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
+                                                                 Output<ngraph::Node> gamma,
+                                                                 Output<ngraph::Node> beta,
+                                                                 Output<ngraph::Node> input,
+                                                                 Output<ngraph::Node> mean,
+                                                                 Output<ngraph::Node> variance,
+                                                                 Output<ngraph::Node> delta)
+    : Op({gamma, beta, input, mean, variance, delta})
     , m_epsilon(epsilon)
 {
...
@@ -31,13 +31,17 @@ namespace ngraph
         class BatchNormTraining : public Op
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            BatchNormTraining() = default;
             /// \param input Must have rank >= 2, [., C, ...]
             /// \param gamma gamma scaling for normalized value. [C]
             /// \param beta bias added to the scaled normalized value [C]
             /// \param epsilon Avoids division by 0 if input has 0 variance
-            BatchNormTraining(std::shared_ptr<Node> input,
-                              std::shared_ptr<Node> gamma,
-                              std::shared_ptr<Node> beta,
+            BatchNormTraining(Output<Node> input,
+                              Output<Node> gamma,
+                              Output<Node> beta,
                               double epsilon);
 
             NGRAPH_DEPRECATED_DOC
@@ -62,13 +66,14 @@
             /// output[2]: shall have rank 1, with the same span as input's channel axis.
             NGRAPH_DEPRECATED("Use another constructor")
             BatchNormTraining(double eps,
-                              std::shared_ptr<Node> gamma,
-                              std::shared_ptr<Node> beta,
-                              std::shared_ptr<Node> input);
+                              Output<Node> gamma,
+                              Output<Node> beta,
+                              Output<Node> input);
 
             void validate_and_infer_types() override;
 
             double get_eps_value() const { return m_epsilon; }
+            void set_eps_value(double epsilon) { m_epsilon = epsilon; }
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
@@ -87,17 +92,20 @@
         class BatchNormInference : public Op
         {
         public:
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            BatchNormInference() = default;
             /// \param input [., C, ...]
             /// \param gamma gamma scaling for normalized value. [C]
             /// \param beta bias added to the scaled normalized value [C]
             /// \param mean value for mean normalization [C]
             /// \param variance value for variance normalization [C]
             /// \param epsilon Avoids division by 0 if input has 0 variance
-            BatchNormInference(std::shared_ptr<ngraph::Node> input,
-                               std::shared_ptr<ngraph::Node> gamma,
-                               std::shared_ptr<ngraph::Node> beta,
-                               std::shared_ptr<ngraph::Node> mean,
-                               std::shared_ptr<ngraph::Node> variance,
+            BatchNormInference(Output<ngraph::Node> input,
+                               Output<ngraph::Node> gamma,
+                               Output<ngraph::Node> beta,
+                               Output<ngraph::Node> mean,
+                               Output<ngraph::Node> variance,
                                double epsilon);
 
             NGRAPH_DEPRECATED_DOC
@@ -120,15 +128,16 @@
             /// output: shall have the same shape as 'input'.
             NGRAPH_DEPRECATED("Use another constructor")
             BatchNormInference(double eps,
-                               std::shared_ptr<ngraph::Node> gamma,
-                               std::shared_ptr<ngraph::Node> beta,
-                               std::shared_ptr<ngraph::Node> input,
-                               std::shared_ptr<ngraph::Node> mean,
-                               std::shared_ptr<ngraph::Node> variance);
+                               Output<ngraph::Node> gamma,
+                               Output<ngraph::Node> beta,
+                               Output<ngraph::Node> input,
+                               Output<ngraph::Node> mean,
+                               Output<ngraph::Node> variance);
 
             void validate_and_infer_types() override;
 
             double get_eps_value() const { return m_epsilon; }
+            void set_eps_value(double epsilon) { m_epsilon = epsilon; }
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
@@ -152,28 +161,33 @@
         class BatchNormTrainingBackprop : public Op
         {
         public:
-            BatchNormTrainingBackprop(std::shared_ptr<Node> input,
-                                      std::shared_ptr<Node> gamma,
-                                      std::shared_ptr<Node> beta,
-                                      std::shared_ptr<Node> mean,
-                                      std::shared_ptr<Node> variance,
-                                      std::shared_ptr<Node> delta,
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            BatchNormTrainingBackprop() = default;
+            BatchNormTrainingBackprop(Output<Node> input,
+                                      Output<Node> gamma,
+                                      Output<Node> beta,
+                                      Output<Node> mean,
+                                      Output<Node> variance,
+                                      Output<Node> delta,
                                       double epsilon);
 
             NGRAPH_DEPRECATED_DOC
             NGRAPH_DEPRECATED("Use another constructor")
             BatchNormTrainingBackprop(double epsilon,
-                                      std::shared_ptr<Node> gamma,
-                                      std::shared_ptr<Node> beta,
-                                      std::shared_ptr<Node> input,
-                                      std::shared_ptr<Node> mean,
-                                      std::shared_ptr<Node> variance,
-                                      std::shared_ptr<Node> delta);
+                                      Output<Node> gamma,
+                                      Output<Node> beta,
+                                      Output<Node> input,
+                                      Output<Node> mean,
+                                      Output<Node> variance,
+                                      Output<Node> delta);
 
             void validate_and_infer_types() override;
 
             double get_eps_value() const { return m_epsilon; }
+            void set_eps_value(double epsilon) { m_epsilon = epsilon; }
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
...
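A sketch of constructing the training op with the new signature (input, gamma, beta, epsilon); Output<Node> converts implicitly from a shared_ptr node, so existing call sites passing node pointers keep compiling (assuming ngraph/ngraph.hpp):

    #include "ngraph/ngraph.hpp"

    using namespace ngraph;

    void example()
    {
        Shape data_shape{2, 3, 4, 4}; // [N, C, H, W]
        auto input = std::make_shared<op::Parameter>(element::f32, data_shape);
        auto gamma = std::make_shared<op::Parameter>(element::f32, Shape{3}); // [C]
        auto beta = std::make_shared<op::Parameter>(element::f32, Shape{3});  // [C]

        auto bn = std::make_shared<op::BatchNormTraining>(input, gamma, beta, 1e-5);
        bn->set_eps_value(1e-3); // epsilon is adjustable via the new setter
    }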
@@ -20,21 +20,20 @@
 using namespace std;
 using namespace ngraph;
 
-op::Broadcast::Broadcast(const std::string& name,
-                         const NodeVector& args,
+const string op::Broadcast::type_name{"Broadcast"};
+
+op::Broadcast::Broadcast(const OutputVector& args,
                          const Shape& shape,
                          const AxisSet& broadcast_axes)
-    : Op(name, check_single_output_args(args))
+    : Op(args)
     , m_shape(shape)
     , m_broadcast_axes(broadcast_axes)
 {
     constructor_validate_and_infer_types();
 }
 
-op::Broadcast::Broadcast(const shared_ptr<Node>& arg,
-                         const Shape& shape,
-                         const AxisSet& broadcast_axes)
-    : Broadcast("Broadcast", {arg}, shape, broadcast_axes)
+op::Broadcast::Broadcast(const Output<Node>& arg, const Shape& shape, const AxisSet& broadcast_axes)
+    : Broadcast(OutputVector{arg}, shape, broadcast_axes)
 {
 }
 
@@ -96,10 +95,12 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
     adjoints.add_delta(x, make_shared<op::Sum>(delta, m_broadcast_axes));
 }
 
-op::BroadcastLike::BroadcastLike(const std::shared_ptr<Node>& arg,
-                                 const std::shared_ptr<Node>& like_arg,
+const string op::BroadcastLike::type_name{"BroadcastLike"};
+
+op::BroadcastLike::BroadcastLike(const Output<Node>& arg,
+                                 const Output<Node>& like_arg,
                                  const AxisSet& initial_broadcast_axes)
-    : Broadcast("BroadcastLike", {arg, like_arg}, {}, {})
+    : Broadcast({arg, like_arg}, {}, {})
     , m_initial_broadcast_axes(initial_broadcast_axes)
 {
     constructor_validate_and_infer_types();
...
@@ -27,15 +27,18 @@ namespace ngraph
         class Broadcast : public Op
         {
         public:
-            /// \brief Constructs a conversion operation.
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a broadcast operation.
+            Broadcast() = default;
+            /// \brief Constructs a broadcast operation.
             ///
             /// \param arg Node that produces the input tensor to be broadcast.
             /// \param shape The shape of the output tensor.
             /// \param broadcast_axes The axis positions (0-based) in the result that are being broadcast. The
             /// remaining axes in shape must be the same as the shape of arg.
-            Broadcast(const std::shared_ptr<Node>& arg,
-                      const Shape& shape,
-                      const AxisSet& broadcast_axes);
+            Broadcast(const Output<Node>& arg, const Shape& shape, const AxisSet& broadcast_axes);
 
             void validate_and_infer_types() override;
 
@@ -44,12 +47,14 @@
             /// \return A set containing the indices of the broadcast axes (0-based).
             const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
+            void set_broadcast_axes(const AxisSet& broadcast_axes)
+            {
+                m_broadcast_axes = broadcast_axes;
+            }
             const Shape& get_broadcast_shape() const { return m_shape; }
+            void set_broadcast_shape(const Shape& shape) { m_shape = shape; }
         protected:
-            Broadcast(const std::string& node_type,
-                      const NodeVector& args,
-                      const Shape& shape,
-                      const AxisSet& broadcast_axes);
+            Broadcast(const OutputVector& args, const Shape& shape, const AxisSet& broadcast_axes);
 
             virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                            const NodeVector& deltas) override;
@@ -63,6 +68,11 @@
         class BroadcastLike : public Broadcast
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Broadcast arg to the same shape as like_arg.
+            BroadcastLike() = default;
             /// \brief Broadcast arg to the same shape as like_arg.
             ///
             /// Once the shape of like_arg is known, this op will be replaced with an equivalent
@@ -72,8 +82,8 @@
             /// \param like_arg Provides the shape for the result.
             /// \param initial_broadcast_axes indicates which axes will be broadcast. If empty,
             /// arg must be scalar and all axes are broadcast.
-            BroadcastLike(const std::shared_ptr<Node>& arg,
-                          const std::shared_ptr<Node>& like_arg,
+            BroadcastLike(const Output<Node>& arg,
+                          const Output<Node>& like_arg,
                           const AxisSet& initial_broadcast_axes);
 
             virtual std::shared_ptr<Node>
@@ -81,6 +91,11 @@
             void infer_shape() override;
 
             const AxisSet& get_initial_broadcast_axes() const { return m_initial_broadcast_axes; }
+            void set_initial_broadcast_axes(const AxisSet& initial_broadcast_axes)
+            {
+                m_initial_broadcast_axes = initial_broadcast_axes;
+            }
 
         protected:
             AxisSet m_initial_broadcast_axes;
         };
...
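A sketch of the public constructor after the change, broadcasting a scalar to a 2x4 tensor; both result axes appear in broadcast_axes because neither comes from the argument (assuming ngraph/ngraph.hpp):

    #include "ngraph/ngraph.hpp"

    using namespace ngraph;

    void example()
    {
        auto scalar = std::make_shared<op::Parameter>(element::f32, Shape{});
        // Result shape {2, 4}; axes 0 and 1 are created by the broadcast.
        auto bcast = std::make_shared<op::Broadcast>(scalar, Shape{2, 4}, AxisSet{0, 1});
    }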
@@ -19,8 +19,10 @@
 using namespace std;
 using namespace ngraph;
 
-op::BroadcastDistributed::BroadcastDistributed(const shared_ptr<Node>& arg, int root_id)
-    : Op("BroadcastDistributed", check_single_output_args({arg}))
+const string op::BroadcastDistributed::type_name{"BroadcastDistributed"};
+
+op::BroadcastDistributed::BroadcastDistributed(const Output<Node>& arg, int root_id)
+    : Op({arg})
     , m_root_id(root_id)
 {
     constructor_validate_and_infer_types();
@@ -49,3 +51,8 @@ int op::BroadcastDistributed::get_root_id() const
 {
     return m_root_id;
 }
+
+void op::BroadcastDistributed::set_root_id(int root_id)
+{
+    m_root_id = root_id;
+}
@@ -27,16 +27,21 @@ namespace ngraph
         class BroadcastDistributed : public Op
         {
         public:
-            BroadcastDistributed(const std::shared_ptr<Node>& arg, int root_id = 0);
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            BroadcastDistributed() = default;
+            BroadcastDistributed(const Output<Node>& arg, int root_id = 0);
             void validate_and_infer_types() override;
 
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
             int get_root_id() const;
+            void set_root_id(int root_id);
 
         private:
-            const int m_root_id;
+            int m_root_id;
         };
     }
 }
@@ -19,8 +19,10 @@
 using namespace std;
 using namespace ngraph;
 
-op::Ceiling::Ceiling(const shared_ptr<Node>& arg)
-    : UnaryElementwiseArithmetic("Ceiling", arg)
+const string op::Ceiling::type_name{"Ceiling"};
+
+op::Ceiling::Ceiling(const Output<Node>& arg)
+    : UnaryElementwiseArithmetic(arg)
 {
     constructor_validate_and_infer_types();
 }
...
@@ -26,10 +26,15 @@ namespace ngraph
         class Ceiling : public util::UnaryElementwiseArithmetic
        {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a ceiling operation.
+            Ceiling() = default;
             /// \brief Constructs a ceiling operation.
             ///
             /// \param arg Node that produces the input tensor.
-            Ceiling(const std::shared_ptr<Node>& arg);
+            Ceiling(const Output<Node>& arg);
 
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
...
@@ -22,13 +22,20 @@
 using namespace std;
 using namespace ngraph;
 
-op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
-    : Op("Concat", check_single_output_args(args))
+const string op::Concat::type_name{"Concat"};
+
+op::Concat::Concat(const OutputVector& args, size_t concatenation_axis)
+    : Op(args)
     , m_concatenation_axis(concatenation_axis)
 {
     constructor_validate_and_infer_types();
 }
 
+op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
+    : Concat(as_output_vector(args), concatenation_axis)
+{
+}
+
 void op::Concat::validate_and_infer_types()
 {
     NODE_VALIDATION_CHECK(this, get_input_size() >= 1, "At least one argument required.");
...
@@ -28,6 +28,17 @@ namespace ngraph
         class Concat : public Op
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a concatenation operation.
+            Concat() = default;
+            /// \brief Constructs a concatenation operation.
+            ///
+            /// \param args The outputs producing the input tensors.
+            /// \param concatenation_axis The axis along which to concatenate the input tensors.
+            Concat(const OutputVector& args, size_t concatenation_axis);
+
             /// \brief Constructs a concatenation operation.
             ///
             /// \param args The nodes producing the input tensors.
@@ -41,10 +52,15 @@
             /// \return The concatenation axis.
             size_t get_concatenation_axis() const { return m_concatenation_axis; }
+            void set_concatenation_axis(size_t concatenation_axis)
+            {
+                m_concatenation_axis = concatenation_axis;
+            }
         protected:
             virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                            const NodeVector& deltas) override;
 
-            const size_t m_concatenation_axis;
+            size_t m_concatenation_axis;
         };
     }
 }
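A sketch of both Concat constructors after this change; the legacy NodeVector overload now just funnels through as_output_vector into the new OutputVector one (assuming ngraph/ngraph.hpp):

    #include "ngraph/ngraph.hpp"

    using namespace ngraph;

    void example()
    {
        auto a = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
        auto b = std::make_shared<op::Parameter>(element::f32, Shape{2, 5});

        // New primary constructor takes an OutputVector; result shape is {2, 8}.
        auto c1 = std::make_shared<op::Concat>(OutputVector{a, b}, 1);

        // Legacy NodeVector constructor still compiles, delegating to the one above.
        auto c2 = std::make_shared<op::Concat>(NodeVector{a, b}, 1);
    }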
@@ -45,6 +45,8 @@ string to_cpp_string(T value)
     return rc;
 }
 
+const string op::Constant::type_name{"Constant"};
+
 op::Constant::~Constant()
 {
 }
...
@@ -34,6 +34,9 @@ namespace ngraph
         class Constant : public Node
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
             /// \brief Constructs a tensor constant.
             ///
             /// \param type The element type of the tensor constant.
@@ -78,7 +81,7 @@
             /// \param shape The shape of the tensor constant.
             /// \param values A list of string values to use as the constant data.
             Constant(const element::Type& type, Shape shape, const std::vector<std::string>& values)
-                : Node("Constant", {})
+                : Node({})
                 , m_element_type(type)
                 , m_shape(shape)
                 , m_data(new runtime::AlignedBuffer(shape_size(m_shape) * m_element_type.size(),
@@ -135,7 +138,7 @@
             /// \param shape The shape of the tensor constant.
             /// \param data A void* to constant data.
             Constant(const element::Type& type, const Shape& shape, const void* data)
-                : Node("Constant", {})
+                : Node({})
                 , m_element_type(type)
                 , m_shape(shape)
                 , m_data(nullptr)
...
@@ -21,8 +21,10 @@
 using namespace std;
 using namespace ngraph;
 
-op::Convert::Convert(const shared_ptr<Node>& arg, const element::Type& element_type)
-    : Op("Convert", check_single_output_args({arg}))
+const string op::Convert::type_name{"Convert"};
+
+op::Convert::Convert(const Output<Node>& arg, const element::Type& element_type)
+    : Op({arg})
     , m_element_type(element_type)
 {
     constructor_validate_and_infer_types();
...
@@ -26,11 +26,16 @@ namespace ngraph
         class Convert : public Op
        {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a conversion operation.
+            Convert() = default;
             /// \brief Constructs a conversion operation.
             ///
             /// \param arg Node that produces the input tensor.
             /// \param element_type Element type for the output tensor.
-            Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type);
+            Convert(const Output<Node>& arg, const ngraph::element::Type& element_type);
 
             void validate_and_infer_types() override;
 
@@ -38,8 +43,13 @@
                 copy_with_new_args(const NodeVector& new_args) const override;
 
             const element::Type& get_convert_element_type() const { return m_element_type; }
+            void set_convert_element_type(const element::Type& element_type)
+            {
+                m_element_type = element_type;
+            }
 
         protected:
-            const ngraph::element::Type m_element_type;
+            ngraph::element::Type m_element_type;
             virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                            const NodeVector& deltas) override;
         };
...
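A sketch of the conversion op with the updated signature; since the target element type is no longer a const member, it can be swapped and the output type re-inferred (assuming ngraph/ngraph.hpp):

    #include "ngraph/ngraph.hpp"

    using namespace ngraph;

    void example()
    {
        auto x = std::make_shared<op::Parameter>(element::f32, Shape{2, 2});
        auto cvt = std::make_shared<op::Convert>(x, element::i32);

        cvt->set_convert_element_type(element::i64); // attribute is now mutable
        cvt->validate_and_infer_types();             // refresh the inferred output type
    }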
...@@ -27,15 +27,17 @@ ...@@ -27,15 +27,17 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Convolution::Convolution(const shared_ptr<Node>& data_batch, const string op::Convolution::type_name{"Convolution"};
const shared_ptr<Node>& filters,
op::Convolution::Convolution(const Output<Node>& data_batch,
const Output<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const PadType& pad_type) const PadType& pad_type)
: Op("Convolution", check_single_output_args({data_batch, filters})) : Op({data_batch, filters})
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides) , m_window_dilation_strides(window_dilation_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
...@@ -114,8 +116,8 @@ void op::Convolution::validate_and_infer_types() ...@@ -114,8 +116,8 @@ void op::Convolution::validate_and_infer_types()
    set_output_type(0, result_et, result_shape);
}

-op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
-                             const shared_ptr<Node>& filters,
+op::Convolution::Convolution(const Output<Node>& data_batch,
+                             const Output<Node>& filters,
                              const Strides& window_movement_strides,
                              const Strides& window_dilation_strides,
                              const CoordinateDiff& padding_below,

@@ -130,8 +132,8 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
{
}

-op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
-                             const shared_ptr<Node>& filters,
+op::Convolution::Convolution(const Output<Node>& data_batch,
+                             const Output<Node>& filters,
                              const Strides& window_movement_strides,
                              const Strides& window_dilation_strides)
    : Convolution(data_batch,

@@ -143,8 +145,8 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
{
}

-op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
-                             const shared_ptr<Node>& filters,
+op::Convolution::Convolution(const Output<Node>& data_batch,
+                             const Output<Node>& filters,
                              const Strides& window_movement_strides)
    : Convolution(data_batch,
                  filters,

@@ -155,7 +157,7 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
{
}

-op::Convolution::Convolution(const shared_ptr<Node>& data_batch, const shared_ptr<Node>& filters)
+op::Convolution::Convolution(const Output<Node>& data_batch, const Output<Node>& filters)
    : Convolution(data_batch, filters, Strides(), Strides(), CoordinateDiff(), CoordinateDiff())
{
}

@@ -204,15 +206,17 @@ void op::Convolution::generate_adjoints(autodiff::Adjoints& adjoints, const Node
                                        m_data_dilation_strides));
}

+const string op::ConvolutionBackpropData::type_name{"ConvolutionBackpropData"};
+
op::ConvolutionBackpropData::ConvolutionBackpropData(const Shape& data_batch_shape,
-                                                     const shared_ptr<Node>& filters,
-                                                     const shared_ptr<Node>& output_delta,
+                                                     const Output<Node>& filters,
+                                                     const Output<Node>& output_delta,
                                                      const Strides& window_movement_strides_forward,
                                                      const Strides& window_dilation_strides_forward,
                                                      const CoordinateDiff& padding_below_forward,
                                                      const CoordinateDiff& padding_above_forward,
                                                      const Strides& data_dilation_strides_forward)
-    : Op("ConvolutionBackpropData", check_single_output_args({filters, output_delta}))
+    : Op({filters, output_delta})
    , m_data_batch_shape(data_batch_shape)
    , m_window_movement_strides_forward(window_movement_strides_forward)
    , m_window_dilation_strides_forward(window_dilation_strides_forward)

@@ -332,14 +336,14 @@ void op::ConvolutionBackpropData::generate_adjoints(autodiff::Adjoints& adjoints
                             m_data_dilation_strides_forward[i]);
    }

-    auto swap_NC = [](const shared_ptr<Node> n) {
-        AxisVector ax_order = ngraph::get_default_order(n->get_shape());
+    auto swap_NC = [](const Output<Node>& n) {
+        AxisVector ax_order = ngraph::get_default_order(n.get_shape());
        ax_order[0] = 1;
        ax_order[1] = 0;
-        auto new_shape = n->get_shape();
-        new_shape[0] = n->get_shape()[1];
-        new_shape[1] = n->get_shape()[0];
+        auto new_shape = n.get_shape();
+        new_shape[0] = n.get_shape()[1];
+        new_shape[1] = n.get_shape()[0];
        return make_shared<op::Reshape>(n, ax_order, new_shape);
    };

@@ -422,16 +426,18 @@ CoordinateDiff op::ConvolutionBackpropData::compute_backward_delta_out_pad_above
    return backward_delta_out_pad_above;
}

+const string op::ConvolutionBackpropFilters::type_name{"ConvolutionBackpropFilters"};
+
op::ConvolutionBackpropFilters::ConvolutionBackpropFilters(
-    const shared_ptr<Node>& data_batch,
+    const Output<Node>& data_batch,
    const Shape& filters_shape,
-    const shared_ptr<Node>& output_delta,
+    const Output<Node>& output_delta,
    const Strides& window_movement_strides_forward,
    const Strides& window_dilation_strides_forward,
    const CoordinateDiff& padding_below_forward,
    const CoordinateDiff& padding_above_forward,
    const Strides& data_dilation_strides_forward)
-    : Op("ConvolutionBackpropFilters", check_single_output_args({data_batch, output_delta}))
+    : Op({data_batch, output_delta})
    , m_filters_shape(filters_shape)
    , m_window_movement_strides_forward(window_movement_strides_forward)
    , m_window_dilation_strides_forward(window_dilation_strides_forward)
......
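A quick sketch of how the new `Output<Node>`-based signatures are exercised. The helper name and shapes are illustrative, not part of this change; a `shared_ptr<Node>` converts implicitly to the node's sole `Output<Node>`:

```cpp
#include <memory>
#include "ngraph/ngraph.hpp"

using namespace ngraph;

std::shared_ptr<Node> make_example_conv()
{
    // NCHW data and OIHW filters; strides/dilations default to 1, padding to 0.
    auto data    = std::make_shared<op::Parameter>(element::f32, Shape{1, 3, 224, 224});
    auto filters = std::make_shared<op::Parameter>(element::f32, Shape{8, 3, 3, 3});
    return std::make_shared<op::Convolution>(data, filters);
}
```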
@@ -22,8 +22,10 @@
using namespace std;
using namespace ngraph;

-op::Cos::Cos(const shared_ptr<Node>& arg)
-    : UnaryElementwiseArithmetic("Cos", arg)
+const string op::Cos::type_name{"Cos"};
+
+op::Cos::Cos(const Output<Node>& arg)
+    : UnaryElementwiseArithmetic(arg)
{
    constructor_validate_and_infer_types();
}
......
@@ -26,10 +26,15 @@ namespace ngraph
        class Cos : public util::UnaryElementwiseArithmetic
        {
        public:
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           /// \brief Constructs a cosine operation.
+           Cos() = default;
            /// \brief Constructs a cosine operation.
            ///
            /// \param arg Node that produces the input tensor.
-           Cos(const std::shared_ptr<Node>& arg);
+           Cos(const Output<Node>& arg);

            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
......
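The unary element-wise ops below all follow the same pattern after this change; a minimal usage sketch for `op::Cos` (helper name and shape are illustrative):

```cpp
#include <memory>
#include "ngraph/ngraph.hpp"

using namespace ngraph;

std::shared_ptr<Node> make_example_cos()
{
    auto x = std::make_shared<op::Parameter>(element::f32, Shape{4});
    return std::make_shared<op::Cos>(x); // shared_ptr<Node> converts to Output<Node>
}
```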
@@ -21,8 +21,10 @@
using namespace std;
using namespace ngraph;

-op::Cosh::Cosh(const shared_ptr<Node>& arg)
-    : UnaryElementwiseArithmetic("Cosh", arg)
+const string op::Cosh::type_name{"Cosh"};
+
+op::Cosh::Cosh(const Output<Node>& arg)
+    : UnaryElementwiseArithmetic(arg)
{
    constructor_validate_and_infer_types();
}
......
@@ -26,10 +26,15 @@ namespace ngraph
        class Cosh : public util::UnaryElementwiseArithmetic
        {
        public:
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           /// \brief Constructs a hyperbolic cosine operation.
+           Cosh() = default;
            /// \brief Constructs a hyperbolic cosine operation.
            ///
            /// \param arg Node that produces the input tensor.
-           Cosh(const std::shared_ptr<Node>& arg);
+           Cosh(const Output<Node>& arg);

            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
......
@@ -20,13 +20,15 @@
using namespace std;
using namespace ngraph;

-op::Dequantize::Dequantize(const shared_ptr<Node>& input,
-                           const shared_ptr<Node>& scale,
-                           const shared_ptr<Node>& zero_point,
+const string op::Dequantize::type_name{"Dequantize"};
+
+op::Dequantize::Dequantize(const Output<Node>& input,
+                           const Output<Node>& scale,
+                           const Output<Node>& zero_point,
                            const element::Type& type,
                            const AxisSet& axes)
-    : Op("Dequantize", check_single_output_args({input, scale, zero_point}))
+    : Op({input, scale, zero_point})
    , m_type(type)
    , m_axes(axes)
{
......
@@ -30,31 +30,40 @@ namespace ngraph
        class Dequantize : public ngraph::op::Op
        {
        public:
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           /// \brief Constructs a Dequantize operation
+           Dequantize() = default;
            /// \brief Constructs a Dequantize operation
            /// \param input quantized input
            /// \param scale scale used for mapping
            /// \param zero_point zero point used for mapping
            /// \param type output element type
            /// \param axes axis positions on which `scale` and `zero_point` are specified
-           Dequantize(const std::shared_ptr<Node>& input,
-                      const std::shared_ptr<Node>& scale,
-                      const std::shared_ptr<Node>& zero_point,
-                      const ngraph::element::Type& type,
-                      const ngraph::AxisSet& axes);
+           Dequantize(const Output<Node>& input,
+                      const Output<Node>& scale,
+                      const Output<Node>& zero_point,
+                      const element::Type& type,
+                      const AxisSet& axes);

            void validate_and_infer_types() override;

            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;

-           const ngraph::AxisSet& get_axes() const { return m_axes; }
+           const AxisSet& get_axes() const { return m_axes; }
+           void set_axes(const AxisSet& axes) { m_axes = axes; }
+           const element::Type& get_type() const { return m_type; }
+           void set_type(const element::Type& type) { m_type = type; }
        protected:
            virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                           const NodeVector& deltas) override;

        private:
-           ngraph::element::Type m_type;
-           ngraph::AxisSet m_axes;
+           element::Type m_type;
+           AxisSet m_axes;
        };
    }
}
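A hedged construction sketch for the new signature (helper name, shapes, and element types are illustrative): with scalar `scale`/`zero_point` and an empty `AxisSet`, one quantization parameter applies to the whole tensor.

```cpp
#include <memory>
#include "ngraph/ngraph.hpp"

using namespace ngraph;

std::shared_ptr<Node> make_example_dequantize()
{
    auto q     = std::make_shared<op::Parameter>(element::i8, Shape{64});
    auto scale = std::make_shared<op::Parameter>(element::f32, Shape{});
    auto zero  = std::make_shared<op::Parameter>(element::i8, Shape{});
    // Maps quantized i8 values back to f32 via scale and zero_point.
    return std::make_shared<op::Dequantize>(q, scale, zero, element::f32, AxisSet{});
}
```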
@@ -21,20 +21,21 @@
using namespace std;
using namespace ngraph;

-op::Divide::Divide(const shared_ptr<Node>& arg0,
-                   const shared_ptr<Node>& arg1,
+const string op::Divide::type_name{"Divide"};
+
+op::Divide::Divide(const Output<Node>& arg0,
+                   const Output<Node>& arg1,
                    const AutoBroadcastSpec& autob)
-    : BinaryElementwiseArithmetic("Divide", arg0, arg1, autob)
-    , m_pythondiv(true)
+    : BinaryElementwiseArithmetic(arg0, arg1, autob)
{
    constructor_validate_and_infer_types();
}

-op::Divide::Divide(const shared_ptr<Node>& arg0,
-                   const shared_ptr<Node>& arg1,
+op::Divide::Divide(const Output<Node>& arg0,
+                   const Output<Node>& arg1,
                    bool pythondiv,
                    const AutoBroadcastSpec& autob)
-    : BinaryElementwiseArithmetic("Divide", arg0, arg1, autob)
+    : BinaryElementwiseArithmetic(arg0, arg1, autob)
    , m_pythondiv(pythondiv)
{
    constructor_validate_and_infer_types();

@@ -63,7 +64,7 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
    adjoints.add_delta(y, -delta * shared_from_this() / y);
}

-shared_ptr<Node> ngraph::operator/(const shared_ptr<Node> arg0, const shared_ptr<Node> arg1)
+shared_ptr<Node> ngraph::operator/(const Output<Node> arg0, const Output<Node> arg1)
{
    return make_shared<op::Divide>(arg0, arg1);
}
@@ -26,14 +26,19 @@ namespace ngraph
        class Divide : public util::BinaryElementwiseArithmetic
        {
        public:
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           /// \brief Constructs a division operation.
+           Divide() = default;
            /// \brief Constructs a division operation.
            ///
            /// \param arg0 Node that produces the first input tensor.
            /// \param arg1 Node that produces the second input tensor.
            /// \param pythondiv Use Python style rounding for integral type
            /// \param autob Auto broadcast specification
-           Divide(const std::shared_ptr<Node>& arg0,
-                  const std::shared_ptr<Node>& arg1,
+           Divide(const Output<Node>& arg0,
+                  const Output<Node>& arg1,
                   bool pythondiv,
                   const AutoBroadcastSpec& autob = AutoBroadcastSpec());

@@ -42,11 +47,12 @@ namespace ngraph
            /// \param arg0 Node that produces the first input tensor.
            /// \param arg1 Node that produces the second input tensor.
            /// \param autob Auto broadcast specification
-           Divide(const std::shared_ptr<Node>& arg0,
-                  const std::shared_ptr<Node>& arg1,
+           Divide(const Output<Node>& arg0,
+                  const Output<Node>& arg1,
                   const AutoBroadcastSpec& autob = AutoBroadcastSpec());

            bool is_pythondiv() const { return m_pythondiv; }
+           void set_is_pythondiv(bool pythondiv) { m_pythondiv = pythondiv; }
            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;

@@ -54,10 +60,10 @@ namespace ngraph
                                           const NodeVector& deltas) override;

        protected:
-           bool m_pythondiv;
+           bool m_pythondiv{true};
        };
    }

-   std::shared_ptr<ngraph::Node> operator/(const std::shared_ptr<ngraph::Node> arg0,
-                                           const std::shared_ptr<ngraph::Node> arg1);
+   std::shared_ptr<ngraph::Node> operator/(const Output<ngraph::Node> arg0,
+                                           const Output<ngraph::Node> arg1);
}
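For reference, a sketch of the two construction paths and the free operator; the Python-rounding default now lives in the `m_pythondiv{true}` member initializer rather than the delegating constructor. Names and shapes are illustrative:

```cpp
#include <memory>
#include "ngraph/ngraph.hpp"

using namespace ngraph;

void divide_examples()
{
    auto a = std::make_shared<op::Parameter>(element::i32, Shape{4});
    auto b = std::make_shared<op::Parameter>(element::i32, Shape{4});

    auto py  = std::make_shared<op::Divide>(a, b);        // Python-style rounding (default)
    auto c89 = std::make_shared<op::Divide>(a, b, false); // C-style truncating division
    auto sugar = a / b;                                   // free operator/ builds an op::Divide
}
```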
@@ -29,16 +29,18 @@
using namespace std;
using namespace ngraph;

-op::Dot::Dot(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
+const string op::Dot::type_name{"Dot"};
+
+op::Dot::Dot(const Output<Node>& arg0, const Output<Node>& arg1)
    : Dot(arg0, arg1, 0, false)
{
}

-op::Dot::Dot(const shared_ptr<Node>& arg0,
-             const shared_ptr<Node>& arg1,
+op::Dot::Dot(const Output<Node>& arg0,
+             const Output<Node>& arg1,
             size_t reduction_axes_count,
             bool has_reduction_axes_count)
-    : Op("Dot", check_single_output_args({arg0, arg1}))
+    : Op({arg0, arg1})
    , m_reduction_axes_count(reduction_axes_count)
    , m_has_reduction_axes_count(has_reduction_axes_count)
{

@@ -154,7 +156,7 @@ void op::Dot::validate_and_infer_types()
    set_output_type(0, result_et, result_shape);
}

-shared_ptr<op::Reshape> make_reshape_axes_to_front(const shared_ptr<Node>& n,
+shared_ptr<op::Reshape> make_reshape_axes_to_front(const Output<Node>& n,
                                                    const Shape& front_shape,
                                                    const Shape& back_shape)
{
......
@@ -28,13 +28,18 @@ namespace ngraph
        class Dot : public Op
        {
        public:
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           /// \brief Constructs a dot product operation.
+           Dot() = default;
            /// \brief Constructs a dot product operation.
            ///
            /// \param arg0 The node producing the first argument.
            /// \param arg1 The node producing the second argument.
            /// \param reduction_axes_count The number of axes to dot.
-           Dot(const std::shared_ptr<Node>& arg0,
-               const std::shared_ptr<Node>& arg1,
+           Dot(const Output<Node>& arg0,
+               const Output<Node>& arg1,
                size_t reduction_axes_count,
                bool has_reduction_axes_count = true);

@@ -48,11 +53,20 @@ namespace ngraph
            ///
            /// \param arg0 The node producing the first argument.
            /// \param arg1 The node producing the second argument.
-           Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
+           Dot(const Output<Node>& arg0, const Output<Node>& arg1);

            void validate_and_infer_types() override;

            size_t get_reduction_axes_count() const { return m_reduction_axes_count; }
+           void set_reduction_axes_count(size_t reduction_axes_count)
+           {
+               m_reduction_axes_count = reduction_axes_count;
+           }
+           bool get_has_reduction_axes_count() const { return m_has_reduction_axes_count; }
+           void set_has_reduction_axes_count(bool has_reduction_axes_count)
+           {
+               m_has_reduction_axes_count = has_reduction_axes_count;
+           }
            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override
            {
......
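A sketch of what `reduction_axes_count` means: with 2-D arguments and one contracted axis, `Dot` is an ordinary matrix product (helper name and shapes are illustrative):

```cpp
#include <memory>
#include "ngraph/ngraph.hpp"

using namespace ngraph;

std::shared_ptr<Node> make_example_matmul()
{
    auto A = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
    auto B = std::make_shared<op::Parameter>(element::f32, Shape{3, 4});
    // Contract one axis: the last axis of A against the first of B -> Shape{2, 4}.
    return std::make_shared<op::Dot>(A, B, 1);
}
```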
@@ -19,6 +19,8 @@
using namespace std;
using namespace ngraph;

+const string op::EmbeddingLookup::type_name{"EmbeddingLookup"};
+
void op::EmbeddingLookup::validate_and_infer_types()
{
    element::Type result_et = get_input_element_type(1);
......
@@ -28,6 +28,11 @@ namespace ngraph
        class EmbeddingLookup : public Op
        {
        public:
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           /// \brief Constructs an EmbeddingLookup operation.
+           EmbeddingLookup() = default;
            /// \brief Constructs an EmbeddingLookup operation.
            ///
            /// EmbeddingLookup constructs an output tensor by replacing every index in a given input tensor

@@ -36,8 +41,8 @@ namespace ngraph
            /// \param data The input indices for tokens to be translated into embeddings
            /// \param weights is a dense matrix [N,M] where each row 0..N
            /// corresponds to an embedding (i.e. typically, a vector of real numbers) of length M
-           EmbeddingLookup(const std::shared_ptr<Node>& data, const std::shared_ptr<Node>& weights)
-               : Op("EmbeddingLookup", check_single_output_args({data, weights}))
+           EmbeddingLookup(const Output<Node>& data, const Output<Node>& weights)
+               : Op({data, weights})
            {
                constructor_validate_and_infer_types();
            }
......
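Following the doc comment above, a small sketch: each index in `data` selects one row of the [N,M] weight matrix (helper name and shapes are illustrative):

```cpp
#include <memory>
#include "ngraph/ngraph.hpp"

using namespace ngraph;

std::shared_ptr<Node> make_example_embedding()
{
    auto indices = std::make_shared<op::Parameter>(element::i32, Shape{10});
    auto weights = std::make_shared<op::Parameter>(element::f32, Shape{100, 16}); // N=100, M=16
    return std::make_shared<op::EmbeddingLookup>(indices, weights);               // Shape{10, 16}
}
```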
@@ -19,10 +19,10 @@
using namespace std;
using namespace ngraph;

-op::Equal::Equal(const shared_ptr<Node>& arg0,
-                 const shared_ptr<Node>& arg1,
-                 const AutoBroadcastSpec& autob)
-    : BinaryElementwiseComparison("Equal", arg0, arg1, autob)
+const string op::Equal::type_name{"Equal"};
+
+op::Equal::Equal(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
+    : BinaryElementwiseComparison(arg0, arg1, autob)
{
    constructor_validate_and_infer_types();
}
......
@@ -40,13 +40,18 @@ namespace ngraph
        class Equal : public util::BinaryElementwiseComparison
        {
        public:
-           /// \brief Constructs an is-equal operation.
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           /// \brief Constructs an equal operation.
+           Equal() = default;
+           /// \brief Constructs an equal operation.
            ///
            /// \param arg0 Node that produces the first input tensor.
            /// \param arg1 Node that produces the second input tensor.
            /// \param autob Auto broadcast specification
-           Equal(const std::shared_ptr<Node>& arg0,
-                 const std::shared_ptr<Node>& arg1,
+           Equal(const Output<Node>& arg0,
+                 const Output<Node>& arg1,
                  const AutoBroadcastSpec& autob = AutoBroadcastSpec());

            virtual std::shared_ptr<Node>
......
@@ -21,14 +21,16 @@
using namespace std;
using namespace ngraph;

+const string op::Erf::type_name{"Erf"};
+
shared_ptr<Node> op::Erf::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
    return make_shared<Erf>(new_args.at(0));
}

-op::Erf::Erf(shared_ptr<Node> arg)
-    : UnaryElementwiseArithmetic("Erf", arg)
+op::Erf::Erf(const Output<Node>& arg)
+    : UnaryElementwiseArithmetic(arg)
{
    constructor_validate_and_infer_types();
}
@@ -27,7 +27,11 @@ namespace ngraph
        class Erf : public util::UnaryElementwiseArithmetic
        {
        public:
-           Erf(std::shared_ptr<Node> arg);
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           Erf() = default;
+           Erf(const Output<Node>& arg);

            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
......
@@ -20,8 +20,10 @@
using namespace std;
using namespace ngraph;

-op::Exp::Exp(const shared_ptr<Node>& arg)
-    : UnaryElementwiseArithmetic("Exp", arg)
+const string op::Exp::type_name{"Exp"};
+
+op::Exp::Exp(const Output<Node>& arg)
+    : UnaryElementwiseArithmetic(arg)
{
    constructor_validate_and_infer_types();
}
......
@@ -26,10 +26,15 @@ namespace ngraph
        class Exp : public util::UnaryElementwiseArithmetic
        {
        public:
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           /// \brief Constructs an exponential operation.
+           Exp() = default;
            /// \brief Constructs an exponential operation.
            ///
            /// \param arg Node that produces the input tensor.
-           Exp(const std::shared_ptr<Node>& arg);
+           Exp(const Output<Node>& arg);

            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
......
@@ -19,8 +19,10 @@
using namespace std;
using namespace ngraph;

-op::Floor::Floor(const shared_ptr<Node>& arg)
-    : UnaryElementwiseArithmetic("Floor", arg)
+const string op::Floor::type_name{"Floor"};
+
+op::Floor::Floor(const Output<Node>& arg)
+    : UnaryElementwiseArithmetic(arg)
{
    constructor_validate_and_infer_types();
}
......
@@ -26,10 +26,15 @@ namespace ngraph
        class Floor : public util::UnaryElementwiseArithmetic
        {
        public:
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           /// \brief Constructs a floor operation.
+           Floor() = default;
            /// \brief Constructs a floor operation.
            ///
            /// \param arg Node that produces the input tensor.
-           Floor(const std::shared_ptr<Node>& arg);
+           Floor(const Output<Node>& arg);

            virtual std::shared_ptr<Node>
                copy_with_new_args(const NodeVector& new_args) const override;
......
@@ -72,6 +72,13 @@ void op::Split::pre_validate_and_infer_types()
                              dimension_at_axis,
                              " has to be equal to the sum of splits passed to the op: ",
                              sum_splits);

+       const bool all_splits_positive =
+           all_of(begin(m_splits), end(m_splits), [](const size_t v) { return v > 0; });
+
+       NODE_VALIDATION_CHECK(this,
+                             all_splits_positive == true,
+                             "All values of the 'splits' attribute must be greater than zero");
    }
}
......
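The added check is the standard `std::all_of` idiom; in isolation (function name and values illustrative):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

bool splits_are_positive(const std::vector<std::size_t>& splits)
{
    // Rejects any zero-length split, e.g. {2, 0, 3} -> false.
    return std::all_of(splits.begin(), splits.end(),
                       [](std::size_t v) { return v > 0; });
}
```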
@@ -24,10 +24,12 @@
using namespace std;
using namespace ngraph;

-op::Reshape::Reshape(const shared_ptr<Node>& arg,
+const string op::Reshape::type_name{"Reshape"};
+
+op::Reshape::Reshape(const Output<Node>& arg,
                     const AxisVector& input_order,
                     const Shape& output_shape)
-    : Op("Reshape", check_single_output_args({arg}))
+    : Op({arg})
    , m_input_order(input_order)
    , m_output_shape(output_shape)
{
......
@@ -60,6 +60,11 @@ namespace ngraph
        class Reshape : public Op
        {
        public:
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           /// \brief Constructs a reshape operation.
+           Reshape() = default;
            /// \brief Constructs a reshape operation.
            ///
            /// \param arg The tensor to be reshaped.

@@ -67,7 +72,7 @@ namespace ngraph
            /// sequence \f$(0,\dots,n-1)\f$ where \f$n\f$ is the rank of the input tensor.
            /// \param output_shape The output shape. If the input shape is \f$(a_0,\dots,a_{k-1})\f$ then the output shape must
            /// be of the form \f$(b_0,\dots,b_{j-1})\f$ where \f$\Pi(a_i) = \Pi(b_i)\f$.
-           Reshape(const std::shared_ptr<Node>& arg,
+           Reshape(const Output<Node>& arg,
                    const AxisVector& input_order,
                    const Shape& output_shape);

@@ -78,15 +83,18 @@ namespace ngraph
            /// \return The order in which to iterate over input axes.
            const AxisVector& get_input_order() const { return m_input_order; }
+           void set_input_order(const AxisVector& input_order) { m_input_order = input_order; }
            /// \return The shape of the output tensor.
            const Shape& get_output_shape() const { return m_output_shape; }
+           void set_output_shape(const Shape& output_shape) { m_output_shape = output_shape; }
            bool get_is_transpose() const { return m_is_transpose; }
+           void set_is_transpose(bool is_transpose) { m_is_transpose = is_transpose; }
        protected:
            virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                           const NodeVector& deltas) override;

-           const AxisVector m_input_order;
-           const Shape m_output_shape;
+           AxisVector m_input_order;
+           Shape m_output_shape;
            bool m_is_transpose{false};
        };
    }
......
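A sketch of the `input_order`/`output_shape` contract: permuting the axes of a 2x3 tensor yields a 3x2 output with the same element count, which is the transpose case that `m_is_transpose` tracks (helper name and shapes are illustrative):

```cpp
#include <memory>
#include "ngraph/ngraph.hpp"

using namespace ngraph;

std::shared_ptr<Node> make_example_transpose()
{
    auto x = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
    // Read input axes in order (1, 0) and emit Shape{3, 2}.
    return std::make_shared<op::Reshape>(x, AxisVector{1, 0}, Shape{3, 2});
}
```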
@@ -24,8 +24,10 @@
using namespace std;
using namespace ngraph;

-op::Result::Result(const shared_ptr<Node>& arg, bool needs_default_layout)
-    : Op("Result", check_single_output_args({arg}))
+const string op::Result::type_name{"Result"};
+
+op::Result::Result(const Output<Node>& arg, bool needs_default_layout)
+    : Op({arg})
    , m_needs_default_layout(needs_default_layout)
{
    constructor_validate_and_infer_types();
......
@@ -27,10 +27,15 @@ namespace ngraph
        class Result : public Op
        {
        public:
+           NGRAPH_API
+           static const std::string type_name;
+           const std::string& description() const override { return type_name; }
+           /// \brief Allows a value to be used as a function result.
+           Result() = default;
            /// \brief Allows a value to be used as a function result.
            ///
            /// \param arg Node that produces the input tensor.
-           Result(const std::shared_ptr<Node>& arg, bool needs_default_layout = false);
+           Result(const Output<Node>& arg, bool needs_default_layout = false);

            void validate_and_infer_types() override;
......
@@ -298,9 +298,14 @@ namespace ngraph
                if (graph_node->is_commutative())
                {
-                   std::sort(
-                       begin(pattern_args),
-                       end(pattern_args)); // TODO: [nikolayk] we don't really have to use lexicographically-based perms, heap's algo should be faster
+                   // TODO: [nikolayk] we don't really have to use lexicographically-based perms, heap's algo should be faster
+                   std::sort(begin(pattern_args),
+                             end(pattern_args),
+                             [](const std::shared_ptr<ngraph::Node>& n1,
+                                const std::shared_ptr<ngraph::Node>& n2) {
+                                 return n1->get_instance_id() < n2->get_instance_id();
+                             });
                    do
                    {
                        NGRAPH_DEBUG << pad(2 * m_depth) << "Running a permutation for graph_node "

@@ -311,7 +316,13 @@ namespace ngraph
                            pattern_map.insert(begin(copy), end(copy));
                            return true;
                        }
-                   } while (std::next_permutation(begin(pattern_args), end(pattern_args)));
+                   } while (std::next_permutation(begin(pattern_args),
+                                                  end(pattern_args),
+                                                  [](const std::shared_ptr<ngraph::Node>& n1,
+                                                     const std::shared_ptr<ngraph::Node>& n2) {
+                                                      return n1->get_instance_id() <
+                                                             n2->get_instance_id();
+                                                  }));
                }
                else
                {
......
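For context on why the same comparator appears twice: `std::next_permutation` enumerates every ordering exactly once only when iteration starts from the sequence sorted under that comparator, so the sort and the loop must share the instance-id lambda. The idiom in isolation:

```cpp
#include <algorithm>
#include <vector>

void visit_all_orderings()
{
    std::vector<int> ids{3, 1, 2};
    auto by_value = [](int a, int b) { return a < b; };
    std::sort(ids.begin(), ids.end(), by_value); // start from the comparator's first permutation
    do
    {
        // each of the 3! orderings is visited exactly once
    } while (std::next_permutation(ids.begin(), ids.end(), by_value));
}
```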
@@ -90,3 +90,9 @@ std::shared_ptr<runtime::Executable> runtime::Backend::load(istream& input_strea
{
    throw runtime_error("load opertion unimplemented.");
}

+bool runtime::Backend::set_config(const map<string, string>& config, string& error)
+{
+    error = "set_config not supported";
+    return false;
+}
@@ -139,4 +139,13 @@ public:
    /// \param op_name is the name of the backend specific op
    /// \returns a shared pointer to the op if found, else nullptr
    virtual std::shared_ptr<ngraph::Node> get_backend_op(const std::string& op_name, ...);

+   /// \brief Allows sending backend specific configuration. The map contains key, value pairs
+   ///     specific to a particular backend. The definition of these key, value pairs is
+   ///     defined by each backend.
+   /// \param config The configuration map sent to the backend
+   /// \param error An error string describing any error encountered
+   /// \returns true if the configuration is supported, false otherwise. On false the error
+   ///     parameter value is valid.
+   virtual bool set_config(const std::map<std::string, std::string>& config, std::string& error);
};
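Callers can treat the return value as a capability probe. A sketch against the INTERPRETER backend's `test_echo` key (exercised by the test further below; the function name is illustrative):

```cpp
#include <iostream>
#include <map>
#include <string>
#include "ngraph/runtime/backend.hpp"

using namespace ngraph;

void probe_config()
{
    auto backend = runtime::Backend::create("INTERPRETER");
    std::map<std::string, std::string> config{{"test_echo", "hello"}};
    std::string error;
    if (!backend->set_config(config, error))
    {
        // The base implementation rejects everything with "set_config not supported".
        std::cerr << "config rejected: " << error << "\n";
    }
}
```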
@@ -105,3 +105,16 @@ std::shared_ptr<runtime::Executable> runtime::interpreter::INTBackend::load(istr
    }
    return exec;
}

+bool runtime::interpreter::INTBackend::set_config(const map<string, string>& config, string& error)
+{
+    bool rc = false;
+    auto it = config.find("test_echo");
+    error = "";
+    if (it != config.end())
+    {
+        error = it->second;
+        rc = true;
+    }
+    return rc;
+}
@@ -58,6 +58,8 @@ public:
    bool is_supported(const Node& node) const override;

+   bool set_config(const std::map<std::string, std::string>& config, std::string& error) override;
+
private:
    std::set<std::string> m_unsupported_op_name_list;
};
@@ -141,6 +141,7 @@
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
+#include "ngraph/provenance.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"

@@ -1803,6 +1804,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
        {
            node->set_friendly_name(node_name);
        }
+       if (ngraph::get_provenance_enabled())
+       {
+           std::vector<json> prov_js = node_js.at("provenance_tags");
+           for (auto prov_tag : prov_js)
+           {
+               node->add_provenance_tag(prov_tag);
+           }
+       }
        m_node_map[node_name] = node;
    }
    catch (...)

@@ -1914,6 +1923,15 @@ json JSONSerializer::serialize_node(const Node& n)
        }
        node["output_shapes"] = output_shapes;
    }
+   if (ngraph::get_provenance_enabled())
+   {
+       json provenance_tags = json::array();
+       for (auto prov_tag : n.get_provenance_tags())
+       {
+           provenance_tags.push_back(prov_tag);
+       }
+       node["provenance_tags"] = provenance_tags;
+   }

    string node_op = n.description();
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
......
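A sketch of the round-trip this enables. `set_provenance_enabled` is assumed to be the setter counterpart of the `get_provenance_enabled()` used above, declared in the newly included provenance.hpp; the tag string and function name are illustrative:

```cpp
#include <memory>
#include <string>
#include "ngraph/ngraph.hpp"
#include "ngraph/provenance.hpp"
#include "ngraph/serializer.hpp"

using namespace ngraph;

std::string serialize_with_provenance(std::shared_ptr<Function> f,
                                      std::shared_ptr<Node> node)
{
    set_provenance_enabled(true);                 // assumed setter from ngraph/provenance.hpp
    node->add_provenance_tag("fusion/conv_bias"); // tag recorded on the node
    return serialize(f);                          // each node now carries "provenance_tags"
}
```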
@@ -37,6 +37,28 @@ TEST(backend_api, invalid_name)
    ASSERT_ANY_THROW(ngraph::runtime::Backend::create("COMPLETELY-BOGUS-NAME"));
}

+TEST(backend_api, config)
+{
+    auto backend = runtime::Backend::create("INTERPRETER");
+    string error;
+    string message = "hello";
+    map<string, string> config = {{"test_echo", message}};
+    EXPECT_TRUE(backend->set_config(config, error));
+    EXPECT_STREQ(error.c_str(), message.c_str());
+    EXPECT_FALSE(backend->set_config({}, error));
+    EXPECT_STREQ(error.c_str(), "");
+}
+
+TEST(backend_api, config_unsupported)
+{
+    auto backend = runtime::Backend::create("NOP");
+    string error;
+    string message = "hello";
+    map<string, string> config = {{"test_echo", message}};
+    EXPECT_FALSE(backend->set_config(config, error));
+    EXPECT_FALSE(error == "");
+}
+
#ifndef NGRAPH_JSON_DISABLE
TEST(backend_api, save_load)
{
......
@@ -50,25 +50,25 @@ static void test_allreduce_common(reduction::Type reduce_type)
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
-   switch (reduce_type.get_type())
+   switch (reduce_type)
    {
-   case reduction::Type_t::sum:
+   case reduction::Type::SUM:
        copy_data(a, v);
        std::transform(
            v.begin(), v.end(), v.begin(), std::bind1st(std::multiplies<float>(), comm_size));
        break;
-   case reduction::Type_t::prod:
+   case reduction::Type::PROD:
        copy_data(a, v);
        std::transform(v.begin(), v.end(), v.begin(), [&](float elm) -> float {
            return pow(elm, comm_size);
        });
        break;
-   case reduction::Type_t::min:
-   case reduction::Type_t::max:
+   case reduction::Type::MIN:
+   case reduction::Type::MAX:
        auto shift = get_distributed_interface()->get_rank();
        std::rotate(v.begin(), v.begin() + shift % v.size(), v.end());
        copy_data(a, v);
-       if (reduce_type == reduction::Type_t::min)
+       if (reduce_type == reduction::Type::MIN)
        {
            std::fill(v.begin(), v.end(), 1);
            for (int i = 1; i < static_cast<int>(v.size()) - comm_size + 1; i++)

@@ -93,23 +93,23 @@ static void test_allreduce_common(reduction::Type reduce_type)
TEST(distributed_${BACKEND_NAME}, allreduce_sum)
{
-   test_allreduce_common(reduction::sum);
+   test_allreduce_common(reduction::Type::SUM);
}

TEST(distributed_${BACKEND_NAME}, allreduce_min)
{
-   test_allreduce_common(reduction::min);
+   test_allreduce_common(reduction::Type::MIN);
}

TEST(distributed_${BACKEND_NAME}, allreduce_max)
{
-   test_allreduce_common(reduction::max);
+   test_allreduce_common(reduction::Type::MAX);
}

#if !defined(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
TEST(distributed_${BACKEND_NAME}, allreduce_prod)
{
-   test_allreduce_common(reduction::prod);
+   test_allreduce_common(reduction::Type::PROD);
}
#endif
......
@@ -514,6 +514,33 @@ TEST(pattern, previous_matches)
    }
}

+TEST(pattern, test_sort)
+{
+    using ngraph::pattern::Matcher;
+    Shape shape{};
+
+    auto a = make_shared<op::Parameter>(element::i32, shape);
+    auto b = make_shared<op::Parameter>(element::i32, shape);
+    auto abs1 = make_shared<op::Abs>(a);
+    auto abs2 = make_shared<op::Abs>(b);
+    auto add = abs1 + abs2;
+
+    auto pa = make_shared<op::Parameter>(element::i32, shape);
+    auto pb = make_shared<op::Parameter>(element::i32, shape);
+    auto pabs1 = make_shared<op::Abs>(pa);
+    auto pabs1_label = std::make_shared<pattern::op::Label>(pabs1);
+    auto pabs2 = make_shared<op::Abs>(b);
+    auto padd = pabs1_label + pabs2;
+
+    {
+        Matcher n1(padd);
+        ASSERT_TRUE(n1.match(add));
+        auto r1 = n1.get_pattern_map()[pabs1_label];
+        ASSERT_TRUE(n1.match(add));
+        ASSERT_EQ(r1, n1.get_pattern_map()[pabs1_label]);
+    }
+}
+
TEST(pattern, recurrent_pattern)
{
    using ngraph::pattern::RecurrentMatcher;
......
@@ -90,20 +90,7 @@ namespace ngraph
            auto c_vec = read_vector<T>(c_arg);
            fill(c_vec.begin(), c_vec.end(), static_cast<T>(0));

-           static std::unordered_map<std::shared_ptr<Function>,
-                                     std::shared_ptr<runtime::Executable>>
-               s_compiled_functions;
-           auto it = s_compiled_functions.find(df);
-           std::shared_ptr<runtime::Executable> df_handle;
-           if (it == s_compiled_functions.end())
-           {
-               df_handle = backend->compile(df);
-               s_compiled_functions.insert({df, df_handle});
-           }
-           else
-           {
-               df_handle = it->second;
-           }
+           auto df_handle = backend->compile(df);

            // for each element of the adjoint
            // same as saying for each element of y

@@ -212,20 +199,7 @@ namespace ngraph
                s_clone_fwd_map[f] = clone_function(*fprop_cache.fprop);
            }
            auto clone_fwd = s_clone_fwd_map[f];

-           static std::unordered_map<std::shared_ptr<Function>,
-                                     std::shared_ptr<runtime::Executable>>
-               s_compiled_functions;
-           auto it = s_compiled_functions.find(clone_fwd);
-           std::shared_ptr<runtime::Executable> clone_fwd_handle;
-           if (it == s_compiled_functions.end())
-           {
-               clone_fwd_handle = backend->compile(clone_fwd);
-               s_compiled_functions.insert({clone_fwd, clone_fwd_handle});
-           }
-           else
-           {
-               clone_fwd_handle = it->second;
-           }
+           auto clone_fwd_handle = backend->compile(clone_fwd);

            clone_fwd_handle->call_with_validate(mod_f_output_args, f_input_args);
......