Unverified Commit 6f1e728a authored by Fenglei Tian, committed by GitHub

Merge branch 'master' into tfl/send_recv_op

parents 9c2230aa d0a83a35
......@@ -22,11 +22,6 @@
using namespace ngraph;
NGRAPH_API const reduction::Type reduction::sum(reduction::Type_t::sum);
NGRAPH_API const reduction::Type reduction::prod(reduction::Type_t::prod);
NGRAPH_API const reduction::Type reduction::min(reduction::Type_t::min);
NGRAPH_API const reduction::Type reduction::max(reduction::Type_t::max);
std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& obj)
{
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
......@@ -34,12 +29,12 @@ std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& ob
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (obj.get_type())
switch (obj)
{
case reduction::Type_t::sum: out << "sum"; break;
case reduction::Type_t::prod: out << "prod"; break;
case reduction::Type_t::min: out << "min"; break;
case reduction::Type_t::max: out << "max"; break;
case reduction::Type::SUM: out << "SUM"; break;
case reduction::Type::PROD: out << "PROD"; break;
case reduction::Type::MIN: out << "MIN"; break;
case reduction::Type::MAX: out << "MAX"; break;
}
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
......@@ -47,16 +42,6 @@ std::ostream& reduction::operator<<(std::ostream& out, const reduction::Type& ob
return out;
};
bool reduction::Type::operator==(const reduction::Type& other) const
{
return m_type == other.m_type;
}
reduction::Type_t reduction::Type::get_type() const
{
return m_type;
}
static std::unique_ptr<DistributedInterface> s_distributed_interface;
void ngraph::set_distributed_interface(std::unique_ptr<DistributedInterface> distributed_interface)
......
......@@ -26,34 +26,15 @@ namespace ngraph
{
namespace reduction
{
enum class Type_t
enum class Type
{
sum,
prod,
min,
max,
SUM,
PROD,
MIN,
MAX,
};
class Type
{
public:
Type(const Type_t t)
: m_type(t)
{
}
friend std::ostream& operator<<(std::ostream&, const Type&);
bool operator==(const Type& other) const;
bool operator!=(const Type& other) const { return !(*this == other); }
Type_t get_type() const;
private:
Type_t m_type;
};
std::ostream& operator<<(std::ostream& out, const Type& obj);
extern NGRAPH_API const Type sum;
extern NGRAPH_API const Type prod;
extern NGRAPH_API const Type min;
extern NGRAPH_API const Type max;
}
class DistributedInterface
......
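For reference, a minimal sketch of the simplified API after this change: with the wrapper class removed, reduction::Type is a plain enum class that can be compared and streamed directly (the header path is assumed to be ngraph/distributed.hpp, matching this file):

    #include <iostream>
    #include "ngraph/distributed.hpp" // assumed location of reduction::Type

    int main()
    {
        ngraph::reduction::Type t = ngraph::reduction::Type::SUM;
        std::cout << t << std::endl;                        // operator<< prints "SUM"
        bool is_max = (t == ngraph::reduction::Type::MAX);  // built-in enum comparison
        return is_max ? 1 : 0;
    }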
......@@ -92,14 +92,14 @@ namespace ngraph
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (reduce_type.get_type())
switch (reduce_type)
{
case reduction::Type_t::sum: mlsl_reduce_type = MLSL::RT_SUM; break;
case reduction::Type_t::prod:
case reduction::Type::SUM: mlsl_reduce_type = MLSL::RT_SUM; break;
case reduction::Type::PROD:
throw std::runtime_error("MLSL doesn't support allreduce prod");
break;
case reduction::Type_t::min: mlsl_reduce_type = MLSL::RT_MIN; break;
case reduction::Type_t::max: mlsl_reduce_type = MLSL::RT_MAX; break;
case reduction::Type::MIN: mlsl_reduce_type = MLSL::RT_MIN; break;
case reduction::Type::MAX: mlsl_reduce_type = MLSL::RT_MAX; break;
}
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
......
......@@ -104,12 +104,12 @@ namespace ngraph
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (reduce_type.get_type())
switch (reduce_type)
{
case reduction::Type_t::sum: mpi_reduce_type = MPI_SUM; break;
case reduction::Type_t::prod: mpi_reduce_type = MPI_PROD; break;
case reduction::Type_t::min: mpi_reduce_type = MPI_MIN; break;
case reduction::Type_t::max: mpi_reduce_type = MPI_MAX; break;
case reduction::Type::SUM: mpi_reduce_type = MPI_SUM; break;
case reduction::Type::PROD: mpi_reduce_type = MPI_PROD; break;
case reduction::Type::MIN: mpi_reduce_type = MPI_MIN; break;
case reduction::Type::MAX: mpi_reduce_type = MPI_MAX; break;
}
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
......
......@@ -233,6 +233,11 @@ std::list<std::shared_ptr<ngraph::Node>>
// There is a friendly name for this node, so copy it
cloned_node->set_friendly_name(node->get_friendly_name());
}
for (auto tag : node->get_provenance_tags())
{
cloned_node->add_provenance_tag(tag);
}
node_map[node.get()] = cloned_node;
}
}
......
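A small illustrative sketch of the behavior this hunk adds (node names here are hypothetical; clone_function is the public entry point that reaches this loop):

    #include "ngraph/ngraph.hpp"
    using namespace ngraph;

    void provenance_survives_clone()
    {
        auto arg = std::make_shared<op::Parameter>(element::f32, Shape{4});
        auto abs = std::make_shared<op::Abs>(arg);
        abs->add_provenance_tag("layer1/abs");
        auto f = std::make_shared<Function>(NodeVector{abs}, ParameterVector{arg});
        auto g = clone_function(*f);
        // Each cloned node now carries its original's tags, so the cloned
        // Abs reports "layer1/abs" from get_provenance_tags().
    }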
......@@ -559,6 +559,16 @@ const NodeVector& ngraph::check_single_output_args(const NodeVector& args)
return args;
}
OutputVector ngraph::as_output_vector(const NodeVector& args)
{
OutputVector output_vector;
for (auto& arg : check_single_output_args(args))
{
output_vector.push_back(arg);
}
return output_vector;
}
std::tuple<element::Type, PartialShape>
Node::validate_and_infer_elementwise_args(const op::AutoBroadcastSpec& autob)
{
......
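A usage sketch of the new helper (the argument nodes are hypothetical single-output Parameters):

    #include "ngraph/ngraph.hpp"
    using namespace ngraph;

    OutputVector to_outputs()
    {
        auto a = std::make_shared<op::Parameter>(element::f32, Shape{2});
        auto b = std::make_shared<op::Parameter>(element::f32, Shape{2});
        // Each node contributes Output<Node>{node, 0}; nodes with more than
        // one output are rejected by check_single_output_args.
        return as_output_vector(NodeVector{a, b});
    }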
......@@ -73,6 +73,8 @@ namespace ngraph
size_t i);
const NodeVector& check_single_output_args(const NodeVector& args);
OutputVector as_output_vector(const NodeVector& args);
/// Alias useful for cloning
using NodeMap = std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>>;
......@@ -487,7 +489,7 @@ namespace ngraph
/// \param node A pointer to the node for the output handle.
/// \param index The index of the output.
Output(NodeType* node, size_t index)
: m_node(node)
: m_node(node->shared_from_this())
, m_index(index)
{
}
......@@ -498,7 +500,7 @@ namespace ngraph
///
/// TODO: Make a plan to deprecate this.
Output(const std::shared_ptr<NodeType>& node, size_t index)
: m_node(node.get())
: m_node(node)
, m_index(index)
{
}
......@@ -511,12 +513,15 @@ namespace ngraph
{
}
// A null output
Output() = default;
/// \return A pointer to the node referred to by this output handle.
NodeType* get_node() const { return m_node; }
NodeType* get_node() const { return m_node.get(); }
/// \return A `shared_ptr` to the node referred to by this output handle.
///
/// TODO: Make a plan to deprecate this.
std::shared_ptr<NodeType> get_node_shared_ptr() const { return m_node->shared_from_this(); }
std::shared_ptr<NodeType> get_node_shared_ptr() const { return m_node; }
/// \return The index of the output referred to by this output handle.
size_t get_index() const { return m_index; }
/// \return A reference to the tensor descriptor for this output.
......@@ -568,8 +573,8 @@ namespace ngraph
bool operator<=(const Output& other) const { return !(*this > other); }
bool operator>=(const Output& other) const { return !(*this < other); }
private:
NodeType* const m_node;
const size_t m_index;
std::shared_ptr<NodeType> m_node;
size_t m_index{0};
};
inline Input<Node> Node::input(size_t input_index)
......
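Note the ownership change above: Output<NodeType> now stores a shared_ptr instead of a raw pointer, so a handle keeps its node alive and get_node_shared_ptr() no longer depends on shared_from_this(). A sketch of the consequence:

    #include <memory>
    #include "ngraph/ngraph.hpp"
    using namespace ngraph;

    Output<Node> make_handle()
    {
        auto p = std::make_shared<op::Parameter>(element::f32, Shape{1});
        return Output<Node>(p, 0); // the handle now co-owns the node
    } // p goes out of scope here, but the returned Output keeps the node alive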
......@@ -23,10 +23,6 @@ using namespace ngraph;
const string op::Abs::type_name{"Abs"};
op::Abs::Abs()
{
}
op::Abs::Abs(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
......
......@@ -33,7 +33,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an absolute value operation.
Abs();
Abs() = default;
/// \brief Constructs an absolute value operation.
///
......
......@@ -34,10 +34,6 @@ using namespace ngraph;
const string op::Acos::type_name{"Acos"};
op::Acos::Acos()
{
}
op::Acos::Acos(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
......
......@@ -33,7 +33,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an arccos operation.
Acos();
Acos() = default;
/// \brief Constructs an arccos operation.
///
/// \param arg Output that produces the input tensor.<br>
......
......@@ -21,10 +21,6 @@ using namespace ngraph;
const string op::Add::type_name{"Add"};
op::Add::Add()
{
}
op::Add::Add(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic(arg0, arg1, autob)
{
......
......@@ -33,7 +33,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an uninitialized addition operation
Add();
Add() = default;
/// \brief Constructs an addition operation.
///
......
......@@ -21,10 +21,6 @@ using namespace ngraph;
const string op::All::type_name{"All"};
op::All::All()
{
}
op::All::All(const Output<Node>& arg, const AxisSet& reduction_axes)
: LogicalReduction(arg, reduction_axes)
{
......
......@@ -33,7 +33,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an "all" reduction operation.
All();
All() = default;
/// \brief Constructs an "all" reduction operation.
///
/// \param arg The tensor to be reduced.
......
......@@ -21,13 +21,8 @@ using namespace ngraph;
const string op::AllReduce::type_name{"AllReduce"};
op::AllReduce::AllReduce()
: m_reduce_type(reduction::sum)
{
}
op::AllReduce::AllReduce(const shared_ptr<Node>& arg, const reduction::Type reduce_type)
: Op(check_single_output_args({arg}))
op::AllReduce::AllReduce(const Output<Node>& arg, reduction::Type reduce_type)
: Op({arg})
, m_reduce_type(reduce_type)
{
constructor_validate_and_infer_types();
......@@ -56,3 +51,8 @@ reduction::Type op::AllReduce::get_reduce_type() const
{
return m_reduce_type;
}
void op::AllReduce::set_reduce_type(reduction::Type reduce_type)
{
m_reduce_type = reduce_type;
}
......@@ -29,17 +29,17 @@ namespace ngraph
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
AllReduce();
AllReduce(const std::shared_ptr<Node>& arg,
const reduction::Type reduce_type = reduction::sum);
AllReduce() = default;
AllReduce(const Output<Node>& arg, reduction::Type reduce_type = reduction::Type::SUM);
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
reduction::Type get_reduce_type() const;
void set_reduce_type(reduction::Type reduce_type);
private:
const reduction::Type m_reduce_type;
reduction::Type m_reduce_type{reduction::Type::SUM};
};
}
}
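A usage sketch for the revised AllReduce interface (the parameter shape is illustrative):

    #include "ngraph/ngraph.hpp"
    using namespace ngraph;

    void allreduce_example()
    {
        auto arg = std::make_shared<op::Parameter>(element::f32, Shape{8});
        auto ar = std::make_shared<op::AllReduce>(arg); // reduce type defaults to SUM
        ar->set_reduce_type(reduction::Type::MAX);      // new mutator added above
        reduction::Type t = ar->get_reduce_type();      // Type::MAX
        (void)t;
    }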
......@@ -21,10 +21,6 @@ using namespace ngraph;
const string op::And::type_name{"And"};
op::And::And()
{
}
op::And::And(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
: BinaryElementwiseLogical(arg0, arg1, autob)
{
......
......@@ -33,7 +33,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a logical-and operation.
And();
And() = default;
/// \brief Constructs a logical-and operation.
///
......
......@@ -21,10 +21,6 @@ using namespace ngraph;
const string op::Any::type_name{"Any"};
op::Any::Any()
{
}
op::Any::Any(const Output<Node>& arg, const AxisSet& reduction_axes)
: LogicalReduction(arg, reduction_axes)
{
......
......@@ -33,7 +33,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an "any" reduction operation.
Any();
Any() = default;
/// \brief Constructs an "any" reduction operation.
///
/// \param arg The tensor to be reduced.
......
......@@ -21,10 +21,6 @@ using namespace ngraph;
const string op::ArgMax::type_name{"ArgMax"};
op::ArgMax::ArgMax()
{
}
op::ArgMax::ArgMax(const Output<Node>& arg, size_t axis, const element::Type& index_element_type)
: op::util::IndexReduction(arg, axis, index_element_type)
{
......
......@@ -32,7 +32,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an ArgMax operation.
ArgMax();
ArgMax() = default;
/// \brief Constructs an ArgMax operation.
///
/// \param arg The input tensor
......
......@@ -21,10 +21,6 @@ using namespace ngraph;
const string op::ArgMin::type_name{"ArgMin"};
op::ArgMin::ArgMin()
{
}
op::ArgMin::ArgMin(const Output<Node>& arg, size_t axis, const element::Type& index_element_type)
: op::util::IndexReduction(arg, axis, index_element_type)
{
......
......@@ -32,7 +32,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an ArgMin operation.
ArgMin();
ArgMin() = default;
/// \brief Constructs an ArgMin operation.
///
......
......@@ -33,10 +33,6 @@ using namespace ngraph;
const string op::Asin::type_name{"Asin"};
op::Asin::Asin()
{
}
op::Asin::Asin(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
......
......@@ -33,7 +33,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an arcsin operation.
Asin();
Asin() = default;
/// \brief Constructs an arcsin operation.
///
/// \param arg Output that produces the input tensor.<br>
......
......@@ -32,10 +32,6 @@ using namespace ngraph;
const string op::Atan::type_name{"Atan"};
op::Atan::Atan()
{
}
op::Atan::Atan(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
......
......@@ -33,7 +33,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an arctan operation.
Atan();
Atan() = default;
/// \brief Constructs an arctan operation.
///
......
......@@ -23,10 +23,6 @@ using namespace ngraph;
const string op::AvgPool::type_name{"AvgPool"};
op::AvgPool::AvgPool()
{
}
op::AvgPool::AvgPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
......@@ -231,10 +227,6 @@ shared_ptr<Node> op::AvgPool::copy_with_new_args(const NodeVector& new_args) con
const string op::AvgPoolBackprop::type_name("AvgPoolBackprop");
op::AvgPoolBackprop::AvgPoolBackprop()
{
}
op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
const shared_ptr<Node>& delta,
const Shape& window_shape,
......
......@@ -33,7 +33,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a batched average pooling operation.
AvgPool();
AvgPool() = default;
/// \brief Constructs a batched average pooling operation.
///
......@@ -175,7 +175,7 @@ namespace ngraph
public:
static const std::string type_name;
const std::string& description() const override { return type_name; }
AvgPoolBackprop();
AvgPoolBackprop() = default;
AvgPoolBackprop(const Shape& forward_arg_shape,
const std::shared_ptr<Node>& delta,
const Shape& window_shape,
......
......@@ -22,11 +22,13 @@
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/validation_util.hpp"
ngraph::op::BatchNormTraining::BatchNormTraining(std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
const std::string ngraph::op::BatchNormTraining::type_name{"BatchNormTraining"};
ngraph::op::BatchNormTraining::BatchNormTraining(Output<ngraph::Node> input,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
double epsilon)
: Op("BatchNormTraining", check_single_output_args({gamma, beta, input}))
: Op({gamma, beta, input})
, m_epsilon(epsilon)
{
constructor_validate_and_infer_types();
......@@ -34,10 +36,10 @@ ngraph::op::BatchNormTraining::BatchNormTraining(std::shared_ptr<ngraph::Node> i
// DEPRECATED
ngraph::op::BatchNormTraining::BatchNormTraining(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input)
: Op("BatchNormTraining", check_single_output_args({gamma, beta, input}))
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> input)
: Op({gamma, beta, input})
, m_epsilon(eps)
{
constructor_validate_and_infer_types();
......@@ -111,13 +113,15 @@ void ngraph::op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoin
adjoints.add_delta(beta, dbeta);
}
ngraph::op::BatchNormInference::BatchNormInference(std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance,
const std::string ngraph::op::BatchNormInference::type_name{"BatchNormInference"};
ngraph::op::BatchNormInference::BatchNormInference(Output<ngraph::Node> input,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance,
double epsilon)
: Op("BatchNormInference", check_single_output_args({gamma, beta, input, mean, variance}))
: Op({gamma, beta, input, mean, variance})
, m_epsilon(epsilon)
{
constructor_validate_and_infer_types();
......@@ -125,12 +129,12 @@ ngraph::op::BatchNormInference::BatchNormInference(std::shared_ptr<ngraph::Node>
// DEPRECATED
ngraph::op::BatchNormInference::BatchNormInference(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance)
: Op("BatchNormInference", check_single_output_args({gamma, beta, input, mean, variance}))
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> input,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance)
: Op({gamma, beta, input, mean, variance})
, m_epsilon(eps)
{
constructor_validate_and_infer_types();
......@@ -167,16 +171,16 @@ std::shared_ptr<ngraph::Node>
new_args.at(2), new_args.at(0), new_args.at(1), new_args.at(3), new_args.at(4), m_epsilon);
}
ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance,
std::shared_ptr<ngraph::Node> delta,
double epsilon)
: Op("BatchNormTrainingBackprop",
check_single_output_args({gamma, beta, input, mean, variance, delta}))
const std::string ngraph::op::BatchNormTrainingBackprop::type_name{"BatchNormTrainingBackprop"};
ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(Output<ngraph::Node> input,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance,
Output<ngraph::Node> delta,
double epsilon)
: Op({gamma, beta, input, mean, variance, delta})
, m_epsilon(epsilon)
{
......@@ -184,16 +188,14 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(
constructor_validate_and_infer_types();
}
ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(
double epsilon,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance,
std::shared_ptr<ngraph::Node> delta)
: Op("BatchNormTrainingBackprop",
check_single_output_args({gamma, beta, input, mean, variance, delta}))
ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> input,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance,
Output<ngraph::Node> delta)
: Op({gamma, beta, input, mean, variance, delta})
, m_epsilon(epsilon)
{
......
......@@ -31,13 +31,17 @@ namespace ngraph
class BatchNormTraining : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
BatchNormTraining() = default;
/// \param input Must have rank >= 2, [., C, ...]
/// \param gamma gamma scaling for normalized value. [C]
/// \param beta bias added to the scaled normalized value [C]
/// \param epsilon Avoids division by 0 if input has 0 variance
BatchNormTraining(std::shared_ptr<Node> input,
std::shared_ptr<Node> gamma,
std::shared_ptr<Node> beta,
BatchNormTraining(Output<Node> input,
Output<Node> gamma,
Output<Node> beta,
double epsilon);
NGRAPH_DEPRECATED_DOC
......@@ -62,13 +66,14 @@ namespace ngraph
/// output[2]: shall have rank 1, with the same span as input's channel axis.
NGRAPH_DEPRECATED("Use another constructor")
BatchNormTraining(double eps,
std::shared_ptr<Node> gamma,
std::shared_ptr<Node> beta,
std::shared_ptr<Node> input);
Output<Node> gamma,
Output<Node> beta,
Output<Node> input);
void validate_and_infer_types() override;
double get_eps_value() const { return m_epsilon; }
void set_eps_value(double epsilon) { m_epsilon = epsilon; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......@@ -87,17 +92,20 @@ namespace ngraph
class BatchNormInference : public Op
{
public:
static const std::string type_name;
const std::string& description() const override { return type_name; }
BatchNormInference() = default;
/// \param input [., C, ...]
/// \param gamma gamma scaling for normalized value. [C]
/// \param beta bias added to the scaled normalized value [C]
/// \param mean value for mean normalization [C]
/// \param variance value for variance normalization [C]
/// \param epsilon Avoids division by 0 if input has 0 variance
BatchNormInference(std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance,
BatchNormInference(Output<ngraph::Node> input,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance,
double epsilon);
NGRAPH_DEPRECATED_DOC
......@@ -120,15 +128,16 @@ namespace ngraph
/// output: shall have the same shape as 'input'.
NGRAPH_DEPRECATED("Use another constructor")
BatchNormInference(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance);
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> input,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance);
void validate_and_infer_types() override;
double get_eps_value() const { return m_epsilon; }
void set_eps_value(double epsilon) { m_epsilon = epsilon; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......@@ -152,28 +161,33 @@ namespace ngraph
class BatchNormTrainingBackprop : public Op
{
public:
BatchNormTrainingBackprop(std::shared_ptr<Node> input,
std::shared_ptr<Node> gamma,
std::shared_ptr<Node> beta,
std::shared_ptr<Node> mean,
std::shared_ptr<Node> variance,
std::shared_ptr<Node> delta,
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
BatchNormTrainingBackprop() = default;
BatchNormTrainingBackprop(Output<Node> input,
Output<Node> gamma,
Output<Node> beta,
Output<Node> mean,
Output<Node> variance,
Output<Node> delta,
double epsilon);
NGRAPH_DEPRECATED_DOC
NGRAPH_DEPRECATED("Use another constructor")
BatchNormTrainingBackprop(double epsilon,
std::shared_ptr<Node> gamma,
std::shared_ptr<Node> beta,
std::shared_ptr<Node> input,
Output<Node> gamma,
Output<Node> beta,
Output<Node> input,
std::shared_ptr<Node> mean,
std::shared_ptr<Node> variance,
std::shared_ptr<Node> delta);
Output<Node> mean,
Output<Node> variance,
Output<Node> delta);
void validate_and_infer_types() override;
double get_eps_value() const { return m_epsilon; }
void set_eps_value(double epsilon) { m_epsilon = epsilon; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
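A construction sketch for the Output<Node>-based BatchNormInference signature (shapes are illustrative, with channel axis C = 3):

    #include "ngraph/ngraph.hpp"
    using namespace ngraph;

    std::shared_ptr<Node> make_bn_inference()
    {
        auto input = std::make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 4});
        auto gamma = std::make_shared<op::Parameter>(element::f32, Shape{3});
        auto beta  = std::make_shared<op::Parameter>(element::f32, Shape{3});
        auto mean  = std::make_shared<op::Parameter>(element::f32, Shape{3});
        auto var   = std::make_shared<op::Parameter>(element::f32, Shape{3});
        // shared_ptr<Node> converts implicitly to Output<Node>, so call sites
        // written against the old signature keep compiling.
        return std::make_shared<op::BatchNormInference>(input, gamma, beta, mean, var, 1e-5);
    }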
......@@ -20,21 +20,20 @@
using namespace std;
using namespace ngraph;
op::Broadcast::Broadcast(const std::string& name,
const NodeVector& args,
const string op::Broadcast::type_name{"Broadcast"};
op::Broadcast::Broadcast(const OutputVector& args,
const Shape& shape,
const AxisSet& broadcast_axes)
: Op(name, check_single_output_args(args))
: Op(args)
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
{
constructor_validate_and_infer_types();
}
op::Broadcast::Broadcast(const shared_ptr<Node>& arg,
const Shape& shape,
const AxisSet& broadcast_axes)
: Broadcast("Broadcast", {arg}, shape, broadcast_axes)
op::Broadcast::Broadcast(const Output<Node>& arg, const Shape& shape, const AxisSet& broadcast_axes)
: Broadcast(OutputVector{arg}, shape, broadcast_axes)
{
}
......@@ -96,10 +95,12 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVe
adjoints.add_delta(x, make_shared<op::Sum>(delta, m_broadcast_axes));
}
op::BroadcastLike::BroadcastLike(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& like_arg,
const string op::BroadcastLike::type_name{"BroadcastLike"};
op::BroadcastLike::BroadcastLike(const Output<Node>& arg,
const Output<Node>& like_arg,
const AxisSet& initial_broadcast_axes)
: Broadcast("BroadcastLike", {arg, like_arg}, {}, {})
: Broadcast({arg, like_arg}, {}, {})
, m_initial_broadcast_axes(initial_broadcast_axes)
{
constructor_validate_and_infer_types();
......
......@@ -27,15 +27,18 @@ namespace ngraph
class Broadcast : public Op
{
public:
/// \brief Constructs a conversion operation.
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a broadcast operation.
Broadcast() = default;
/// \brief Constructs a broadcast operation.
///
/// \param arg Node that produces the input tensor to be broadcast.
/// \param shape The shape of the output tensor.
/// \param broadcast_axes The axis positions (0-based) in the result that are being broadcast. The
/// remaining axes in shape must be the same as the shape of arg.
Broadcast(const std::shared_ptr<Node>& arg,
const Shape& shape,
const AxisSet& broadcast_axes);
Broadcast(const Output<Node>& arg, const Shape& shape, const AxisSet& broadcast_axes);
void validate_and_infer_types() override;
......@@ -44,12 +47,14 @@ namespace ngraph
/// \return A set containing the indices of the broadcast axes (0-based).
const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
void set_broadcast_axes(const AxisSet& broadcast_axes)
{
m_broadcast_axes = broadcast_axes;
}
const Shape& get_broadcast_shape() const { return m_shape; }
void set_broadcast_shape(const Shape& shape) { m_shape = shape; }
protected:
Broadcast(const std::string& node_type,
const NodeVector& args,
const Shape& shape,
const AxisSet& broadcast_axes);
Broadcast(const OutputVector& args, const Shape& shape, const AxisSet& broadcast_axes);
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
......@@ -63,6 +68,11 @@ namespace ngraph
class BroadcastLike : public Broadcast
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Broadcast arg to the same shape as like_arg.
BroadcastLike() = default;
/// \brief Broadcast arg to the same shape as like_arg.
///
/// Once the shape of like_arg is known, this op will be replaced with an equivalent
......@@ -72,8 +82,8 @@ namespace ngraph
/// \param like_arg Provides the shape for the result.
/// \param initial_broadcast_axes indicates which axes will be broadcast. If empty,
/// arg must be scalar and all axes are broadcast.
BroadcastLike(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& like_arg,
BroadcastLike(const Output<Node>& arg,
const Output<Node>& like_arg,
const AxisSet& initial_broadcast_axes);
virtual std::shared_ptr<Node>
......@@ -81,6 +91,11 @@ namespace ngraph
void infer_shape() override;
const AxisSet& get_initial_broadcast_axes() const { return m_initial_broadcast_axes; }
void set_initial_broadcast_axes(const AxisSet& initial_broadcast_axes)
{
m_initial_broadcast_axes = initial_broadcast_axes;
}
protected:
AxisSet m_initial_broadcast_axes;
};
......
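A usage sketch for the Output<Node>-based Broadcast constructor (shapes are illustrative):

    #include "ngraph/ngraph.hpp"
    using namespace ngraph;

    std::shared_ptr<Node> broadcast_scalar()
    {
        auto scalar = std::make_shared<op::Parameter>(element::f32, Shape{});
        // Broadcast a scalar to shape {2, 4}; both result axes are broadcast axes.
        return std::make_shared<op::Broadcast>(scalar, Shape{2, 4}, AxisSet{0, 1});
    }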
......@@ -19,8 +19,10 @@
using namespace std;
using namespace ngraph;
op::BroadcastDistributed::BroadcastDistributed(const shared_ptr<Node>& arg, int root_id)
: Op("BroadcastDistributed", check_single_output_args({arg}))
const string op::BroadcastDistributed::type_name{"BroadcastDistributed"};
op::BroadcastDistributed::BroadcastDistributed(const Output<Node>& arg, int root_id)
: Op({arg})
, m_root_id(root_id)
{
constructor_validate_and_infer_types();
......@@ -49,3 +51,8 @@ int op::BroadcastDistributed::get_root_id() const
{
return m_root_id;
}
void op::BroadcastDistributed::set_root_id(int root_id)
{
m_root_id = root_id;
}
......@@ -27,16 +27,21 @@ namespace ngraph
class BroadcastDistributed : public Op
{
public:
BroadcastDistributed(const std::shared_ptr<Node>& arg, int root_id = 0);
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
BroadcastDistributed() = default;
BroadcastDistributed(const Output<Node>& arg, int root_id = 0);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
int get_root_id() const;
void set_root_id(int root_id);
private:
const int m_root_id;
int m_root_id;
};
}
}
......@@ -19,8 +19,10 @@
using namespace std;
using namespace ngraph;
op::Ceiling::Ceiling(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Ceiling", arg)
const string op::Ceiling::type_name{"Ceiling"};
op::Ceiling::Ceiling(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
constructor_validate_and_infer_types();
}
......
......@@ -26,10 +26,15 @@ namespace ngraph
class Ceiling : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a ceiling operation.
Ceiling() = default;
/// \brief Constructs a ceiling operation.
///
/// \param arg Node that produces the input tensor.
Ceiling(const std::shared_ptr<Node>& arg);
Ceiling(const Output<Node>& arg);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -22,13 +22,20 @@
using namespace std;
using namespace ngraph;
op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
: Op("Concat", check_single_output_args(args))
const string op::Concat::type_name{"Concat"};
op::Concat::Concat(const OutputVector& args, size_t concatenation_axis)
: Op(args)
, m_concatenation_axis(concatenation_axis)
{
constructor_validate_and_infer_types();
}
op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
: Concat(as_output_vector(args), concatenation_axis)
{
}
void op::Concat::validate_and_infer_types()
{
NODE_VALIDATION_CHECK(this, get_input_size() >= 1, "At least one argument required.");
......
......@@ -28,6 +28,17 @@ namespace ngraph
class Concat : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a concatenation operation.
Concat() = default;
/// \brief Constructs a concatenation operation.
///
/// \param args The outputs producing the input tensors.
/// \param concatenation_axis The axis along which to concatenate the input tensors.
Concat(const OutputVector& args, size_t concatenation_axis);
/// \brief Constructs a concatenation operation.
///
/// \param args The nodes producing the input tensors.
......@@ -41,10 +52,15 @@ namespace ngraph
/// \return The concatenation axis.
size_t get_concatenation_axis() const { return m_concatenation_axis; }
void set_concatenation_axis(size_t concatenation_axis)
{
m_concatenation_axis = concatenation_axis;
}
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
const size_t m_concatenation_axis;
size_t m_concatenation_axis;
};
}
}
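A sketch showing both Concat overloads after this change (shapes illustrative; the NodeVector form now delegates through as_output_vector):

    #include "ngraph/ngraph.hpp"
    using namespace ngraph;

    void concat_example()
    {
        auto a = std::make_shared<op::Parameter>(element::f32, Shape{2, 2});
        auto b = std::make_shared<op::Parameter>(element::f32, Shape{3, 2});
        auto c1 = std::make_shared<op::Concat>(OutputVector{a, b}, 0); // new overload
        auto c2 = std::make_shared<op::Concat>(NodeVector{a, b}, 0);   // delegating overload
        (void)c1;
        (void)c2;
    }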
......@@ -45,6 +45,8 @@ string to_cpp_string(T value)
return rc;
}
const string op::Constant::type_name{"Constant"};
op::Constant::~Constant()
{
}
......
......@@ -34,6 +34,9 @@ namespace ngraph
class Constant : public Node
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a tensor constant.
///
/// \param type The element type of the tensor constant.
......@@ -78,7 +81,7 @@ namespace ngraph
/// \param shape The shape of the tensor constant.
/// \param values A list of string values to use as the constant data.
Constant(const element::Type& type, Shape shape, const std::vector<std::string>& values)
: Node("Constant", {})
: Node({})
, m_element_type(type)
, m_shape(shape)
, m_data(new runtime::AlignedBuffer(shape_size(m_shape) * m_element_type.size(),
......@@ -135,7 +138,7 @@ namespace ngraph
/// \param shape The shape of the tensor constant.
/// \param data A void* to constant data.
Constant(const element::Type& type, const Shape& shape, const void* data)
: Node("Constant", {})
: Node({})
, m_element_type(type)
, m_shape(shape)
, m_data(nullptr)
......
......@@ -21,8 +21,10 @@
using namespace std;
using namespace ngraph;
op::Convert::Convert(const shared_ptr<Node>& arg, const element::Type& element_type)
: Op("Convert", check_single_output_args({arg}))
const string op::Convert::type_name{"Convert"};
op::Convert::Convert(const Output<Node>& arg, const element::Type& element_type)
: Op({arg})
, m_element_type(element_type)
{
constructor_validate_and_infer_types();
......
......@@ -26,11 +26,16 @@ namespace ngraph
class Convert : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a conversion operation.
Convert() = default;
/// \brief Constructs a conversion operation.
///
/// \param arg Node that produces the input tensor.
/// \param element_type Element type for the output tensor.
Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type);
Convert(const Output<Node>& arg, const ngraph::element::Type& element_type);
void validate_and_infer_types() override;
......@@ -38,8 +43,13 @@ namespace ngraph
copy_with_new_args(const NodeVector& new_args) const override;
const element::Type& get_convert_element_type() const { return m_element_type; }
void set_convert_element_type(const element::Type& element_type)
{
m_element_type = element_type;
}
protected:
const ngraph::element::Type m_element_type;
ngraph::element::Type m_element_type;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
......
......@@ -27,15 +27,17 @@
using namespace std;
using namespace ngraph;
op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
const shared_ptr<Node>& filters,
const string op::Convolution::type_name{"Convolution"};
op::Convolution::Convolution(const Output<Node>& data_batch,
const Output<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const PadType& pad_type)
: Op("Convolution", check_single_output_args({data_batch, filters}))
: Op({data_batch, filters})
, m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides)
, m_padding_below(padding_below)
......@@ -114,8 +116,8 @@ void op::Convolution::validate_and_infer_types()
set_output_type(0, result_et, result_shape);
}
op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
const shared_ptr<Node>& filters,
op::Convolution::Convolution(const Output<Node>& data_batch,
const Output<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
......@@ -130,8 +132,8 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
{
}
op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
const shared_ptr<Node>& filters,
op::Convolution::Convolution(const Output<Node>& data_batch,
const Output<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides)
: Convolution(data_batch,
......@@ -143,8 +145,8 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
{
}
op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
const shared_ptr<Node>& filters,
op::Convolution::Convolution(const Output<Node>& data_batch,
const Output<Node>& filters,
const Strides& window_movement_strides)
: Convolution(data_batch,
filters,
......@@ -155,7 +157,7 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
{
}
op::Convolution::Convolution(const shared_ptr<Node>& data_batch, const shared_ptr<Node>& filters)
op::Convolution::Convolution(const Output<Node>& data_batch, const Output<Node>& filters)
: Convolution(data_batch, filters, Strides(), Strides(), CoordinateDiff(), CoordinateDiff())
{
}
......@@ -204,15 +206,17 @@ void op::Convolution::generate_adjoints(autodiff::Adjoints& adjoints, const Node
m_data_dilation_strides));
}
const string op::ConvolutionBackpropData::type_name{"ConvolutionBackpropData"};
op::ConvolutionBackpropData::ConvolutionBackpropData(const Shape& data_batch_shape,
const shared_ptr<Node>& filters,
const shared_ptr<Node>& output_delta,
const Output<Node>& filters,
const Output<Node>& output_delta,
const Strides& window_movement_strides_forward,
const Strides& window_dilation_strides_forward,
const CoordinateDiff& padding_below_forward,
const CoordinateDiff& padding_above_forward,
const Strides& data_dilation_strides_forward)
: Op("ConvolutionBackpropData", check_single_output_args({filters, output_delta}))
: Op({filters, output_delta})
, m_data_batch_shape(data_batch_shape)
, m_window_movement_strides_forward(window_movement_strides_forward)
, m_window_dilation_strides_forward(window_dilation_strides_forward)
......@@ -332,14 +336,14 @@ void op::ConvolutionBackpropData::generate_adjoints(autodiff::Adjoints& adjoints
m_data_dilation_strides_forward[i]);
}
auto swap_NC = [](const shared_ptr<Node> n) {
AxisVector ax_order = ngraph::get_default_order(n->get_shape());
auto swap_NC = [](const Output<Node>& n) {
AxisVector ax_order = ngraph::get_default_order(n.get_shape());
ax_order[0] = 1;
ax_order[1] = 0;
auto new_shape = n->get_shape();
new_shape[0] = n->get_shape()[1];
new_shape[1] = n->get_shape()[0];
auto new_shape = n.get_shape();
new_shape[0] = n.get_shape()[1];
new_shape[1] = n.get_shape()[0];
return make_shared<op::Reshape>(n, ax_order, new_shape);
};
......@@ -422,16 +426,18 @@ CoordinateDiff op::ConvolutionBackpropData::compute_backward_delta_out_pad_above
return backward_delta_out_pad_above;
}
const string op::ConvolutionBackpropFilters::type_name{"ConvolutionBackpropFilters"};
op::ConvolutionBackpropFilters::ConvolutionBackpropFilters(
const shared_ptr<Node>& data_batch,
const Output<Node>& data_batch,
const Shape& filters_shape,
const shared_ptr<Node>& output_delta,
const Output<Node>& output_delta,
const Strides& window_movement_strides_forward,
const Strides& window_dilation_strides_forward,
const CoordinateDiff& padding_below_forward,
const CoordinateDiff& padding_above_forward,
const Strides& data_dilation_strides_forward)
: Op("ConvolutionBackpropFilters", check_single_output_args({data_batch, output_delta}))
: Op({data_batch, output_delta})
, m_filters_shape(filters_shape)
, m_window_movement_strides_forward(window_movement_strides_forward)
, m_window_dilation_strides_forward(window_dilation_strides_forward)
......
......@@ -22,8 +22,10 @@
using namespace std;
using namespace ngraph;
op::Cos::Cos(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Cos", arg)
const string op::Cos::type_name{"Cos"};
op::Cos::Cos(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
constructor_validate_and_infer_types();
}
......
......@@ -26,10 +26,15 @@ namespace ngraph
class Cos : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a cosine operation.
Cos() = default;
/// \brief Constructs a cosine operation.
///
/// \param arg Node that produces the input tensor.
Cos(const std::shared_ptr<Node>& arg);
Cos(const Output<Node>& arg);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -21,8 +21,10 @@
using namespace std;
using namespace ngraph;
op::Cosh::Cosh(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Cosh", arg)
const string op::Cosh::type_name{"Cosh"};
op::Cosh::Cosh(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
constructor_validate_and_infer_types();
}
......
......@@ -26,10 +26,15 @@ namespace ngraph
class Cosh : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a hyperbolic cosine operation.
Cosh() = default;
/// \brief Constructs a hyperbolic cosine operation.
///
/// \param arg Node that produces the input tensor.
Cosh(const std::shared_ptr<Node>& arg);
Cosh(const Output<Node>& arg);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -20,13 +20,15 @@
using namespace std;
using namespace ngraph;
op::Dequantize::Dequantize(const shared_ptr<Node>& input,
const shared_ptr<Node>& scale,
const shared_ptr<Node>& zero_point,
const string op::Dequantize::type_name{"Dequantize"};
op::Dequantize::Dequantize(const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
const element::Type& type,
const AxisSet& axes)
: Op("Dequantize", check_single_output_args({input, scale, zero_point}))
: Op({input, scale, zero_point})
, m_type(type)
, m_axes(axes)
{
......
......@@ -30,31 +30,40 @@ namespace ngraph
class Dequantize : public ngraph::op::Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a Dequantize operation
Dequantize() = default;
/// \brief Constructs a Dequantize operation
/// \param input quantized input
/// \param scale scale used for mapping
/// \param zero_point zero point used for mapping
/// \param type output element type
/// \param axes axis positions on which `scale` and `zero_point` are specified
Dequantize(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& scale,
const std::shared_ptr<Node>& zero_point,
const ngraph::element::Type& type,
const ngraph::AxisSet& axes);
Dequantize(const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
const element::Type& type,
const AxisSet& axes);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
const ngraph::AxisSet& get_axes() const { return m_axes; }
const AxisSet& get_axes() const { return m_axes; }
void set_axes(const AxisSet& axes) { m_axes = axes; }
const element::Type& get_type() const { return m_type; }
void set_type(const element::Type& type) { m_type = type; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
private:
ngraph::element::Type m_type;
ngraph::AxisSet m_axes;
element::Type m_type;
AxisSet m_axes;
};
}
}
......@@ -21,20 +21,21 @@
using namespace std;
using namespace ngraph;
op::Divide::Divide(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const string op::Divide::type_name{"Divide"};
op::Divide::Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic("Divide", arg0, arg1, autob)
, m_pythondiv(true)
: BinaryElementwiseArithmetic(arg0, arg1, autob)
{
constructor_validate_and_infer_types();
}
op::Divide::Divide(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
op::Divide::Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
bool pythondiv,
const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic("Divide", arg0, arg1, autob)
: BinaryElementwiseArithmetic(arg0, arg1, autob)
, m_pythondiv(pythondiv)
{
constructor_validate_and_infer_types();
......@@ -63,7 +64,7 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
adjoints.add_delta(y, -delta * shared_from_this() / y);
}
shared_ptr<Node> ngraph::operator/(const shared_ptr<Node> arg0, const shared_ptr<Node> arg1)
shared_ptr<Node> ngraph::operator/(const Output<Node> arg0, const Output<Node> arg1)
{
return make_shared<op::Divide>(arg0, arg1);
}
......@@ -26,14 +26,19 @@ namespace ngraph
class Divide : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a division operation.
Divide() = default;
/// \brief Constructs a division operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param pythondiv Use Python style rounding for integral type
/// \param autob Auto broadcast specification
Divide(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
bool pythondiv,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
......@@ -42,11 +47,12 @@ namespace ngraph
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
Divide(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
bool is_pythondiv() const { return m_pythondiv; }
void set_is_pythondiv(bool pythondiv) { m_pythondiv = pythondiv; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......@@ -54,10 +60,10 @@ namespace ngraph
const NodeVector& deltas) override;
protected:
bool m_pythondiv;
bool m_pythondiv{true};
};
}
std::shared_ptr<ngraph::Node> operator/(const std::shared_ptr<ngraph::Node> arg0,
const std::shared_ptr<ngraph::Node> arg1);
std::shared_ptr<ngraph::Node> operator/(const Output<ngraph::Node> arg0,
const Output<ngraph::Node> arg1);
}
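A sketch of the Divide constructors after this change (pythondiv now defaults to true via the member initializer; shapes are illustrative):

    #include "ngraph/ngraph.hpp"
    using namespace ngraph;

    void divide_example()
    {
        auto x = std::make_shared<op::Parameter>(element::i32, Shape{2});
        auto y = std::make_shared<op::Parameter>(element::i32, Shape{2});
        auto d0 = std::make_shared<op::Divide>(x, y);        // m_pythondiv == true
        auto d1 = std::make_shared<op::Divide>(x, y, false); // C-style truncating division
        auto d2 = x / y; // free operator/, now declared in terms of Output<Node>
        (void)d0;
        (void)d1;
        (void)d2;
    }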
......@@ -29,16 +29,18 @@
using namespace std;
using namespace ngraph;
op::Dot::Dot(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
const string op::Dot::type_name{"Dot"};
op::Dot::Dot(const Output<Node>& arg0, const Output<Node>& arg1)
: Dot(arg0, arg1, 0, false)
{
}
op::Dot::Dot(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
op::Dot::Dot(const Output<Node>& arg0,
const Output<Node>& arg1,
size_t reduction_axes_count,
bool has_reduction_axes_count)
: Op("Dot", check_single_output_args({arg0, arg1}))
: Op({arg0, arg1})
, m_reduction_axes_count(reduction_axes_count)
, m_has_reduction_axes_count(has_reduction_axes_count)
{
......@@ -154,7 +156,7 @@ void op::Dot::validate_and_infer_types()
set_output_type(0, result_et, result_shape);
}
shared_ptr<op::Reshape> make_reshape_axes_to_front(const shared_ptr<Node>& n,
shared_ptr<op::Reshape> make_reshape_axes_to_front(const Output<Node>& n,
const Shape& front_shape,
const Shape& back_shape)
{
......
......@@ -28,13 +28,18 @@ namespace ngraph
class Dot : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a dot product operation.
Dot() = default;
/// \brief Constructs a dot product operation.
///
/// \param arg0 The node producing the first argument.
/// \param arg1 The node producing the second argument.
/// \param reduction_axes_count The number of axes to dot.
Dot(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
Dot(const Output<Node>& arg0,
const Output<Node>& arg1,
size_t reduction_axes_count,
bool has_reduction_axes_count = true);
......@@ -48,11 +53,20 @@ namespace ngraph
///
/// \param arg0 The node producing the first argument.
/// \param arg1 The node producing the second argument.
Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
Dot(const Output<Node>& arg0, const Output<Node>& arg1);
void validate_and_infer_types() override;
size_t get_reduction_axes_count() const { return m_reduction_axes_count; }
void set_reduction_axes_count(size_t reduction_axes_count)
{
m_reduction_axes_count = reduction_axes_count;
}
bool get_has_reduction_axes_count() const { return m_has_reduction_axes_count; }
void set_has_reduction_axes_count(bool has_reduction_axes_count)
{
m_has_reduction_axes_count = has_reduction_axes_count;
}
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override
{
......
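A usage sketch of the Output<Node>-based Dot constructor (shapes illustrative; contracting one axis gives a matrix product):

    #include "ngraph/ngraph.hpp"
    using namespace ngraph;

    std::shared_ptr<Node> matmul_example()
    {
        auto A = std::make_shared<op::Parameter>(element::f32, Shape{3, 4});
        auto B = std::make_shared<op::Parameter>(element::f32, Shape{4, 5});
        // Contract one reduction axis (the inner dimension); result shape is {3, 5}.
        return std::make_shared<op::Dot>(A, B, 1);
    }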
......@@ -19,6 +19,8 @@
using namespace std;
using namespace ngraph;
const string op::EmbeddingLookup::type_name{"EmbeddingLookup"};
void op::EmbeddingLookup::validate_and_infer_types()
{
element::Type result_et = get_input_element_type(1);
......
......@@ -28,6 +28,11 @@ namespace ngraph
class EmbeddingLookup : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an EmbeddingLookup operation.
EmbeddingLookup() = default;
/// \brief Constructs an EmbeddingLookup operation.
///
/// EmbeddingLookup constructs an output tensor by replacing every index in a given input tensor
......@@ -36,8 +41,8 @@ namespace ngraph
/// \param data The input indices for tokens to be translated into embeddings
/// \param weights is a dense matrix [N,M] where each row 0..N
/// corresponds to an embedding (i.e. typically, a vector of real numbers) of length M
EmbeddingLookup(const std::shared_ptr<Node>& data, const std::shared_ptr<Node>& weights)
: Op("EmbeddingLookup", check_single_output_args({data, weights}))
EmbeddingLookup(const Output<Node>& data, const Output<Node>& weights)
: Op({data, weights})
{
constructor_validate_and_infer_types();
}
......
......@@ -19,10 +19,10 @@
using namespace std;
using namespace ngraph;
op::Equal::Equal(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("Equal", arg0, arg1, autob)
const string op::Equal::type_name{"Equal"};
op::Equal::Equal(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison(arg0, arg1, autob)
{
constructor_validate_and_infer_types();
}
......
......@@ -40,13 +40,18 @@ namespace ngraph
class Equal : public util::BinaryElementwiseComparison
{
public:
/// \brief Constructs an is-equal operation.
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an equal operation.
Equal() = default;
/// \brief Constructs an equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
Equal(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
Equal(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
......
......@@ -21,14 +21,16 @@
using namespace std;
using namespace ngraph;
const string op::Erf::type_name{"Erf"};
shared_ptr<Node> op::Erf::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Erf>(new_args.at(0));
}
op::Erf::Erf(shared_ptr<Node> arg)
: UnaryElementwiseArithmetic("Erf", arg)
op::Erf::Erf(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
constructor_validate_and_infer_types();
}
......@@ -27,7 +27,11 @@ namespace ngraph
class Erf : public util::UnaryElementwiseArithmetic
{
public:
Erf(std::shared_ptr<Node> arg);
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Erf() = default;
Erf(const Output<Node>& arg);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -20,8 +20,10 @@
using namespace std;
using namespace ngraph;
op::Exp::Exp(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Exp", arg)
const string op::Exp::type_name{"Exp"};
op::Exp::Exp(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
constructor_validate_and_infer_types();
}
......
......@@ -26,10 +26,15 @@ namespace ngraph
class Exp : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an exponential operation.
Exp() = default;
/// \brief Constructs an exponential operation.
///
/// \param arg Node that produces the input tensor.
Exp(const std::shared_ptr<Node>& arg);
Exp(const Output<Node>& arg);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -19,8 +19,10 @@
using namespace std;
using namespace ngraph;
op::Floor::Floor(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Floor", arg)
const string op::Floor::type_name{"Floor"};
op::Floor::Floor(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
constructor_validate_and_infer_types();
}
......
......@@ -26,10 +26,15 @@ namespace ngraph
class Floor : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a floor operation.
Floor() = default;
/// \brief Constructs a floor operation.
///
/// \param arg Node that produces the input tensor.
Floor(const std::shared_ptr<Node>& arg);
Floor(const Output<Node>& arg);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -72,6 +72,13 @@ void op::Split::pre_validate_and_infer_types()
dimension_at_axis,
" has to be equal to the sum of splits passed to the op: ",
sum_splits);
const bool all_splits_positive =
all_of(begin(m_splits), end(m_splits), [](const size_t v) { return v > 0; });
NODE_VALIDATION_CHECK(this,
all_splits_positive == true,
"All values of the 'splits' attribute must be greater than zero");
}
}
......
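The added check is a straightforward std::all_of over the splits attribute. Standalone, the predicate looks like this (a hypothetical free function mirroring the code above):

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    bool all_splits_positive(const std::vector<size_t>& splits)
    {
        // e.g. {2, 4} passes; {2, 0, 4} is rejected by the new NODE_VALIDATION_CHECK
        return std::all_of(splits.begin(), splits.end(),
                           [](size_t v) { return v > 0; });
    }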
......@@ -24,10 +24,12 @@
using namespace std;
using namespace ngraph;
op::Reshape::Reshape(const shared_ptr<Node>& arg,
const string op::Reshape::type_name{"Reshape"};
op::Reshape::Reshape(const Output<Node>& arg,
const AxisVector& input_order,
const Shape& output_shape)
: Op("Reshape", check_single_output_args({arg}))
: Op({arg})
, m_input_order(input_order)
, m_output_shape(output_shape)
{
......
......@@ -60,6 +60,11 @@ namespace ngraph
class Reshape : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a reshape operation.
Reshape() = default;
/// \brief Constructs a reshape operation.
///
/// \param arg The tensor to be reshaped.
......@@ -67,7 +72,7 @@ namespace ngraph
/// sequence \f$(0,\dots,n-1)\f$ where \f$n\f$ is the rank of the input tensor.
/// \param output_shape The output shape. If the input shape is \f$(a_0,\dots,a_{k-1})\f$ then the output shape must
/// be of the form \f$(b_0,\dots,b_{j-1})\f$ where \f$\Pi(a_i) = \Pi(b_i)\f$.
Reshape(const std::shared_ptr<Node>& arg,
Reshape(const Output<Node>& arg,
const AxisVector& input_order,
const Shape& output_shape);
......@@ -78,15 +83,18 @@ namespace ngraph
/// \return The order in which to iterate over input axes.
const AxisVector& get_input_order() const { return m_input_order; }
void set_input_order(const AxisVector& input_order) { m_input_order = input_order; }
/// \return The shape of the output tensor.
const Shape& get_output_shape() const { return m_output_shape; }
void set_output_shape(const Shape& output_shape) { m_output_shape = output_shape; }
bool get_is_transpose() const { return m_is_transpose; }
void set_is_transpose(bool is_transpose) { m_is_transpose = is_transpose; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
const AxisVector m_input_order;
const Shape m_output_shape;
AxisVector m_input_order;
Shape m_output_shape;
bool m_is_transpose{false};
};
}
......
......@@ -24,8 +24,10 @@
using namespace std;
using namespace ngraph;
op::Result::Result(const shared_ptr<Node>& arg, bool needs_default_layout)
: Op("Result", check_single_output_args({arg}))
const string op::Result::type_name{"Result"};
op::Result::Result(const Output<Node>& arg, bool needs_default_layout)
: Op({arg})
, m_needs_default_layout(needs_default_layout)
{
constructor_validate_and_infer_types();
......
......@@ -27,10 +27,15 @@ namespace ngraph
class Result : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Allows a value to be used as a function result.
Result() = default;
/// \brief Allows a value to be used as a function result.
///
/// \param arg Node that produces the input tensor.
Result(const std::shared_ptr<Node>& arg, bool needs_default_layout = false);
Result(const Output<Node>& arg, bool needs_default_layout = false);
void validate_and_infer_types() override;
......
......@@ -298,9 +298,14 @@ namespace ngraph
if (graph_node->is_commutative())
{
std::sort(
begin(pattern_args),
end(pattern_args)); // TODO: [nikolayk] we don't really have to use lexicographically-based perms, heap's algo should be faster
// TODO: [nikolayk] we don't really have to use lexicographically-based perms, heap's algo should be faster
std::sort(begin(pattern_args),
end(pattern_args),
[](const std::shared_ptr<ngraph::Node>& n1,
const std::shared_ptr<ngraph::Node>& n2) {
return n1->get_instance_id() < n2->get_instance_id();
});
do
{
NGRAPH_DEBUG << pad(2 * m_depth) << "Running a permutation for graph_node "
......@@ -311,7 +316,13 @@ namespace ngraph
pattern_map.insert(begin(copy), end(copy));
return true;
}
} while (std::next_permutation(begin(pattern_args), end(pattern_args)));
} while (std::next_permutation(begin(pattern_args),
end(pattern_args),
[](const std::shared_ptr<ngraph::Node>& n1,
const std::shared_ptr<ngraph::Node>& n2) {
return n1->get_instance_id() <
n2->get_instance_id();
}));
}
else
{
......
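Note that the comparator handed to std::sort is repeated in std::next_permutation: next_permutation only enumerates every ordering when the range starts at the first permutation under the same strict weak ordering. A standalone illustration of that contract:

// Sketch only: enumerate all orderings of ids deterministically,
// mirroring the instance-id comparator above.
std::vector<int> ids{3, 1, 2};
auto by_id = [](int a, int b) { return a < b; };
std::sort(ids.begin(), ids.end(), by_id);
do
{
    // ... attempt a match against this ordering ...
} while (std::next_permutation(ids.begin(), ids.end(), by_id));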
......@@ -90,3 +90,9 @@ std::shared_ptr<runtime::Executable> runtime::Backend::load(istream& input_strea
{
throw runtime_error("load opertion unimplemented.");
}
bool runtime::Backend::set_config(const map<string, string>& config, string& error)
{
error = "set_config not supported";
return false;
}
......@@ -139,4 +139,13 @@ public:
/// \param op_name is the name of the backend specific op
/// \returns a shared pointer to the op if found, else nullptr
virtual std::shared_ptr<ngraph::Node> get_backend_op(const std::string& op_name, ...);
/// \brief Allows sending backend-specific configuration. The map contains key/value
///     pairs whose meaning is defined by each particular backend.
/// \param config The configuration map sent to the backend
/// \param error An error string describing any error encountered
/// \returns true if the configuration is supported, false otherwise. When false is
///     returned, the error parameter contains a description of the failure.
virtual bool set_config(const std::map<std::string, std::string>& config, std::string& error);
};
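A hedged usage sketch of the new entry point; the key name here is made up, since each backend defines its own configuration vocabulary:

// Sketch only: backends that do not override set_config inherit the base
// implementation above and return false with "set_config not supported".
std::string error;
auto backend = ngraph::runtime::Backend::create("INTERPRETER");
if (!backend->set_config({{"some_backend_key", "some_value"}}, error))
{
    std::cerr << "config rejected: " << error << "\n";
}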
......@@ -105,3 +105,16 @@ std::shared_ptr<runtime::Executable> runtime::interpreter::INTBackend::load(istr
}
return exec;
}
bool runtime::interpreter::INTBackend::set_config(const map<string, string>& config, string& error)
{
bool rc = false;
auto it = config.find("test_echo");
error = "";
if (it != config.end())
{
error = it->second;
rc = true;
}
return rc;
}
......@@ -58,6 +58,8 @@ public:
bool is_supported(const Node& node) const override;
bool set_config(const std::map<std::string, std::string>& config, std::string& error) override;
private:
std::set<std::string> m_unsupported_op_name_list;
};
......@@ -141,6 +141,7 @@
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/provenance.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
......@@ -1803,6 +1804,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
{
node->set_friendly_name(node_name);
}
if (ngraph::get_provenance_enabled())
{
std::vector<json> prov_js = node_js.at("provenance_tags");
for (auto prov_tag : prov_js)
{
node->add_provenance_tag(prov_tag);
}
}
m_node_map[node_name] = node;
}
catch (...)
......@@ -1914,6 +1923,15 @@ json JSONSerializer::serialize_node(const Node& n)
}
node["output_shapes"] = output_shapes;
}
if (ngraph::get_provenance_enabled())
{
json provenance_tags = json::array();
for (auto prov_tag : n.get_provenance_tags())
{
provenance_tags.push_back(prov_tag);
}
node["provenance_tags"] = provenance_tags;
}
string node_op = n.description();
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
......
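A sketch of the round trip these two hunks enable; set_provenance_enabled is assumed to be the setter paired with the get_provenance_enabled guard shown here, and f/node stand for an existing Function and one of its nodes:

// Sketch only: tags attached before serialization reappear on the
// deserialized node because both paths check get_provenance_enabled().
ngraph::set_provenance_enabled(true); // assumed setter in ngraph/provenance.hpp
node->add_provenance_tag("source:original_model");
std::string js = ngraph::serialize(f);
auto f2 = ngraph::deserialize(js);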
......@@ -37,6 +37,28 @@ TEST(backend_api, invalid_name)
ASSERT_ANY_THROW(ngraph::runtime::Backend::create("COMPLETELY-BOGUS-NAME"));
}
TEST(backend_api, config)
{
auto backend = runtime::Backend::create("INTERPRETER");
string error;
string message = "hello";
map<string, string> config = {{"test_echo", message}};
EXPECT_TRUE(backend->set_config(config, error));
EXPECT_STREQ(error.c_str(), message.c_str());
EXPECT_FALSE(backend->set_config({}, error));
EXPECT_STREQ(error.c_str(), "");
}
TEST(backend_api, config_unsupported)
{
auto backend = runtime::Backend::create("NOP");
string error;
string message = "hello";
map<string, string> config = {{"test_echo", message}};
EXPECT_FALSE(backend->set_config(config, error));
EXPECT_FALSE(error == "");
}
#ifndef NGRAPH_JSON_DISABLE
TEST(backend_api, save_load)
{
......
......@@ -50,25 +50,25 @@ static void test_allreduce_common(reduction::Type reduce_type)
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (reduce_type.get_type())
switch (reduce_type)
{
case reduction::Type_t::sum:
case reduction::Type::SUM:
copy_data(a, v);
        std::transform(v.begin(), v.end(), v.begin(), [comm_size](float elm) {
            return elm * comm_size;
        });
break;
case reduction::Type_t::prod:
case reduction::Type::PROD:
copy_data(a, v);
std::transform(v.begin(), v.end(), v.begin(), [&](float elm) -> float {
return pow(elm, comm_size);
});
break;
case reduction::Type_t::min:
case reduction::Type_t::max:
case reduction::Type::MIN:
case reduction::Type::MAX:
auto shift = get_distributed_interface()->get_rank();
std::rotate(v.begin(), v.begin() + shift % v.size(), v.end());
copy_data(a, v);
if (reduce_type == reduction::Type_t::min)
if (reduce_type == reduction::Type::MIN)
{
std::fill(v.begin(), v.end(), 1);
for (int i = 1; i < static_cast<int>(v.size()) - comm_size + 1; i++)
......@@ -93,23 +93,23 @@ static void test_allreduce_common(reduction::Type reduce_type)
TEST(distributed_${BACKEND_NAME}, allreduce_sum)
{
test_allreduce_common(reduction::sum);
test_allreduce_common(reduction::Type::SUM);
}
TEST(distributed_${BACKEND_NAME}, allreduce_min)
{
test_allreduce_common(reduction::min);
test_allreduce_common(reduction::Type::MIN);
}
TEST(distributed_${BACKEND_NAME}, allreduce_max)
{
test_allreduce_common(reduction::max);
test_allreduce_common(reduction::Type::MAX);
}
#if !defined(NGRAPH_DISTRIBUTED_MLSL_ENABLE)
TEST(distributed_${BACKEND_NAME}, allreduce_prod)
{
test_allreduce_common(reduction::prod);
test_allreduce_common(reduction::Type::PROD);
}
#endif
......
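For completeness, a hedged sketch of feeding one of these enum values to the op itself; an AllReduce constructor taking a reduction::Type is an assumption based on the enum change in this PR:

// Sketch only.
auto p = std::make_shared<op::Parameter>(element::f32, Shape{4});
auto allreduce = std::make_shared<op::AllReduce>(p, reduction::Type::SUM);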
......@@ -514,6 +514,33 @@ TEST(pattern, previous_matches)
}
}
TEST(pattern, test_sort)
{
using ngraph::pattern::Matcher;
Shape shape{};
auto a = make_shared<op::Parameter>(element::i32, shape);
auto b = make_shared<op::Parameter>(element::i32, shape);
auto abs1 = make_shared<op::Abs>(a);
auto abs2 = make_shared<op::Abs>(b);
auto add = abs1 + abs2;
auto pa = make_shared<op::Parameter>(element::i32, shape);
auto pb = make_shared<op::Parameter>(element::i32, shape);
auto pabs1 = make_shared<op::Abs>(pa);
auto pabs1_label = std::make_shared<pattern::op::Label>(pabs1);
auto pabs2 = make_shared<op::Abs>(b);
auto padd = pabs1_label + pabs2;
{
Matcher n1(padd);
ASSERT_TRUE(n1.match(add));
auto r1 = n1.get_pattern_map()[pabs1_label];
ASSERT_TRUE(n1.match(add));
ASSERT_EQ(r1, n1.get_pattern_map()[pabs1_label]);
}
}
TEST(pattern, recurrent_pattern)
{
using ngraph::pattern::RecurrentMatcher;
......
......@@ -90,20 +90,7 @@ namespace ngraph
auto c_vec = read_vector<T>(c_arg);
fill(c_vec.begin(), c_vec.end(), static_cast<T>(0));
static std::unordered_map<std::shared_ptr<Function>,
std::shared_ptr<runtime::Executable>>
s_compiled_functions;
auto it = s_compiled_functions.find(df);
std::shared_ptr<runtime::Executable> df_handle;
if (it == s_compiled_functions.end())
{
df_handle = backend->compile(df);
s_compiled_functions.insert({df, df_handle});
}
else
{
df_handle = it->second;
}
auto df_handle = backend->compile(df);
// for each element of the adjoint
// same as saying for each element of y
......@@ -212,20 +199,7 @@ namespace ngraph
s_clone_fwd_map[f] = clone_function(*fprop_cache.fprop);
}
auto clone_fwd = s_clone_fwd_map[f];
static std::unordered_map<std::shared_ptr<Function>,
std::shared_ptr<runtime::Executable>>
s_compiled_functions;
auto it = s_compiled_functions.find(clone_fwd);
std::shared_ptr<runtime::Executable> clone_fwd_handle;
if (it == s_compiled_functions.end())
{
clone_fwd_handle = backend->compile(clone_fwd);
s_compiled_functions.insert({clone_fwd, clone_fwd_handle});
}
else
{
clone_fwd_handle = it->second;
}
auto clone_fwd_handle = backend->compile(clone_fwd);
clone_fwd_handle->call_with_validate(mod_f_output_args, f_input_args);
......