Unverified Commit 28e5c2f7 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge branch 'master' into cyphers/noexport

parents 13d63f21 8dd818e5
...@@ -43,6 +43,8 @@ namespace ...@@ -43,6 +43,8 @@ namespace
using namespace mlir::edsc::op; using namespace mlir::edsc::op;
using namespace ngraph::runtime; using namespace ngraph::runtime;
using namespace ngraph::runtime::ngmlir; using namespace ngraph::runtime::ngmlir;
// Index notation to generate standard (i.e., non-affine) loads and stores.
using StdIndexedValue = TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
class DialectLoweringPass; class DialectLoweringPass;
...@@ -682,7 +684,8 @@ namespace ...@@ -682,7 +684,8 @@ namespace
// Create view to write into result. // Create view to write into result.
MemRefView vRes(result), vParams(params), vIndices(indices); MemRefView vRes(result), vParams(params), vIndices(indices);
// Indexed Values // Indexed Values
IndexedValue iRes(result), iParams(params), iIndices(indices); IndexedValue iRes(result), iIndices(indices);
StdIndexedValue iParams(params);
// Construct outer loop for params dims. Exclude the axis dim. // Construct outer loop for params dims. Exclude the axis dim.
SmallVector<ValueHandle, 4> paramsLbs, paramsUbs; SmallVector<ValueHandle, 4> paramsLbs, paramsUbs;
...@@ -894,7 +897,8 @@ namespace ...@@ -894,7 +897,8 @@ namespace
// Views // Views
MemRefView vRes(result), vArg(arg); MemRefView vRes(result), vArg(arg);
// Index Values // Index Values
IndexedValue iRes(result), iArg(arg); StdIndexedValue iRes(result), stdArg(arg);
IndexedValue affineArg(arg);
// Bounds Index Handles // Bounds Index Handles
auto resLbs = vRes.getLbs(); auto resLbs = vRes.getLbs();
auto resUbs = vRes.getUbs(); auto resUbs = vRes.getUbs();
...@@ -944,9 +948,9 @@ namespace ...@@ -944,9 +948,9 @@ namespace
ValueHandle newRedIdx = ValueHandle newRedIdx =
std::is_same<RedOp, NGArgMinRedOp>() std::is_same<RedOp, NGArgMinRedOp>()
? edsc::intrinsics::select( ? edsc::intrinsics::select(
iArg(allIVs) < iArg(tempIVs), allIVs[axis], currRedIdx) affineArg(allIVs) < stdArg(tempIVs), allIVs[axis], currRedIdx)
: edsc::intrinsics::select( : edsc::intrinsics::select(
iArg(tempIVs) < iArg(allIVs), allIVs[axis], currRedIdx); stdArg(tempIVs) < affineArg(allIVs), allIVs[axis], currRedIdx);
iRes(nonRedIVs) = ValueHandle::create<IndexCastOp>(newRedIdx, resTy); iRes(nonRedIVs) = ValueHandle::create<IndexCastOp>(newRedIdx, resTy);
}); });
......
...@@ -173,6 +173,7 @@ namespace ngraph ...@@ -173,6 +173,7 @@ namespace ngraph
class AvgPoolBackprop : public Op class AvgPoolBackprop : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
AvgPoolBackprop() = default; AvgPoolBackprop() = default;
......
...@@ -92,6 +92,7 @@ namespace ngraph ...@@ -92,6 +92,7 @@ namespace ngraph
class BatchNormInference : public Op class BatchNormInference : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
BatchNormInference() = default; BatchNormInference() = default;
......
...@@ -24,8 +24,10 @@ ...@@ -24,8 +24,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::BatchMatMul::type_name{"BatchMatMul"};
op::BatchMatMul::BatchMatMul(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1) op::BatchMatMul::BatchMatMul(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
: Op("BatchMatMul", check_single_output_args({arg0, arg1})) : Op(check_single_output_args({arg0, arg1}))
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -32,6 +32,9 @@ namespace ngraph ...@@ -32,6 +32,9 @@ namespace ngraph
class BatchMatMul : public Op class BatchMatMul : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a batch of matmul product operation. /// \brief Constructs a batch of matmul product operation.
/// ///
/// \param arg0 The node producing the first argument. /// \param arg0 The node producing the first argument.
......
...@@ -23,6 +23,8 @@ ...@@ -23,6 +23,8 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::CompiledKernel::type_name{"CompiledKernel"};
shared_ptr<Node> ngraph::op::CompiledKernel::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> ngraph::op::CompiledKernel::copy_with_new_args(const NodeVector& new_args) const
{ {
auto args = inputs(); auto args = inputs();
...@@ -64,7 +66,7 @@ shared_ptr<Node> ngraph::op::CompiledKernel::copy_with_new_args(const NodeVector ...@@ -64,7 +66,7 @@ shared_ptr<Node> ngraph::op::CompiledKernel::copy_with_new_args(const NodeVector
ngraph::op::CompiledKernel::CompiledKernel(const NodeVector& node_list, ngraph::op::CompiledKernel::CompiledKernel(const NodeVector& node_list,
const NodeVector& outputs, const NodeVector& outputs,
const NodeVector& args) const NodeVector& args)
: Op("CompiledKernel", check_single_output_args({args})) : Op(check_single_output_args({args}))
, m_node_list(node_list) , m_node_list(node_list)
, m_output_nodes(outputs) , m_output_nodes(outputs)
{ {
......
...@@ -32,6 +32,9 @@ namespace ngraph ...@@ -32,6 +32,9 @@ namespace ngraph
class CompiledKernel : public ngraph::op::Op class CompiledKernel : public ngraph::op::Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
CompiledKernel(const NodeVector& node_list, CompiledKernel(const NodeVector& node_list,
const NodeVector& outputs, const NodeVector& outputs,
const NodeVector& args); const NodeVector& args);
......
...@@ -20,10 +20,12 @@ ...@@ -20,10 +20,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::DynBroadcast::type_name{"DynBroadcast"};
op::DynBroadcast::DynBroadcast(const shared_ptr<Node>& arg, op::DynBroadcast::DynBroadcast(const shared_ptr<Node>& arg,
const shared_ptr<Node>& shape, const shared_ptr<Node>& shape,
const shared_ptr<Node>& broadcast_axes) const shared_ptr<Node>& broadcast_axes)
: Op("DynBroadcast", check_single_output_args({arg, shape, broadcast_axes})) : Op(check_single_output_args({arg, shape, broadcast_axes}))
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -28,6 +28,9 @@ namespace ngraph ...@@ -28,6 +28,9 @@ namespace ngraph
class DynBroadcast : public Op class DynBroadcast : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a dynamic broadcast operation. /// \brief Constructs a dynamic broadcast operation.
/// ///
/// \param arg Node that produces the input tensor to be broadcast. /// \param arg Node that produces the input tensor to be broadcast.
......
...@@ -19,12 +19,14 @@ ...@@ -19,12 +19,14 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::DynPad::type_name{"DynPad"};
op::DynPad::DynPad(const std::shared_ptr<Node>& arg, op::DynPad::DynPad(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& padding_below, const std::shared_ptr<Node>& padding_below,
const std::shared_ptr<Node>& padding_above, const std::shared_ptr<Node>& padding_above,
const std::shared_ptr<Node>& padding_value, const std::shared_ptr<Node>& padding_value,
op::PadMode pad_mode) op::PadMode pad_mode)
: Op("DynPad", check_single_output_args({arg, padding_below, padding_above, padding_value})) : Op(check_single_output_args({arg, padding_below, padding_above, padding_value}))
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -27,6 +27,9 @@ namespace ngraph ...@@ -27,6 +27,9 @@ namespace ngraph
class DynPad : public Op class DynPad : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Perform dynamic padding of a tensor /// \brief Perform dynamic padding of a tensor
/// ///
/// \param arg The node producing input tensor to be padded. /// \param arg The node producing input tensor to be padded.
......
...@@ -24,6 +24,8 @@ ...@@ -24,6 +24,8 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::DynReplaceSlice::type_name{"DynReplaceSlice"};
op::DynReplaceSlice::DynReplaceSlice(const shared_ptr<Node>& arg, op::DynReplaceSlice::DynReplaceSlice(const shared_ptr<Node>& arg,
const shared_ptr<Node>& replacement, const shared_ptr<Node>& replacement,
const shared_ptr<Node>& lower_bounds, const shared_ptr<Node>& lower_bounds,
...@@ -34,8 +36,7 @@ op::DynReplaceSlice::DynReplaceSlice(const shared_ptr<Node>& arg, ...@@ -34,8 +36,7 @@ op::DynReplaceSlice::DynReplaceSlice(const shared_ptr<Node>& arg,
const AxisSet& new_axis, const AxisSet& new_axis,
const AxisSet& shrink_axis, const AxisSet& shrink_axis,
const AxisSet& ellipsis_mask) const AxisSet& ellipsis_mask)
: Op("DynReplaceSlice", : Op(check_single_output_args({arg, replacement, lower_bounds, upper_bounds, strides}))
check_single_output_args({arg, replacement, lower_bounds, upper_bounds, strides}))
, m_lower_bounds_mask(lower_bounds_mask) , m_lower_bounds_mask(lower_bounds_mask)
, m_upper_bounds_mask(upper_bounds_mask) , m_upper_bounds_mask(upper_bounds_mask)
, m_new_axis(new_axis) , m_new_axis(new_axis)
......
...@@ -27,6 +27,9 @@ namespace ngraph ...@@ -27,6 +27,9 @@ namespace ngraph
class DynReplaceSlice : public Op class DynReplaceSlice : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a dynamic tensor replace-slice operation. /// \brief Constructs a dynamic tensor replace-slice operation.
/// ///
/// \param arg The tensor in which to replace the slice. /// \param arg The tensor in which to replace the slice.
......
...@@ -24,10 +24,12 @@ ...@@ -24,10 +24,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::DynReshape::type_name{"DynReshape"};
op::DynReshape::DynReshape(const shared_ptr<Node>& arg, op::DynReshape::DynReshape(const shared_ptr<Node>& arg,
const shared_ptr<Node>& pattern, const shared_ptr<Node>& pattern,
bool zero_flag) bool zero_flag)
: Op("DynReshape", check_single_output_args({arg, pattern})) : Op(check_single_output_args({arg, pattern}))
, m_zero_flag(zero_flag) , m_zero_flag(zero_flag)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
......
...@@ -31,6 +31,9 @@ namespace ngraph ...@@ -31,6 +31,9 @@ namespace ngraph
class DynReshape : public Op class DynReshape : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a dynamic reshape operation. This operation does not perform transpose. /// \brief Constructs a dynamic reshape operation. This operation does not perform transpose.
/// ///
/// \param arg The tensor to be reshaped. /// \param arg The tensor to be reshaped.
......
...@@ -24,6 +24,8 @@ ...@@ -24,6 +24,8 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::DynSlice::type_name{"DynSlice"};
op::DynSlice::DynSlice(const shared_ptr<Node>& arg, op::DynSlice::DynSlice(const shared_ptr<Node>& arg,
const shared_ptr<Node>& lower_bounds, const shared_ptr<Node>& lower_bounds,
const shared_ptr<Node>& upper_bounds, const shared_ptr<Node>& upper_bounds,
...@@ -33,7 +35,7 @@ op::DynSlice::DynSlice(const shared_ptr<Node>& arg, ...@@ -33,7 +35,7 @@ op::DynSlice::DynSlice(const shared_ptr<Node>& arg,
const AxisSet& new_axis, const AxisSet& new_axis,
const AxisSet& shrink_axis, const AxisSet& shrink_axis,
const AxisSet& ellipsis_mask) const AxisSet& ellipsis_mask)
: Op("DynSlice", check_single_output_args({arg, lower_bounds, upper_bounds, strides})) : Op(check_single_output_args({arg, lower_bounds, upper_bounds, strides}))
, m_lower_bounds_mask(lower_bounds_mask) , m_lower_bounds_mask(lower_bounds_mask)
, m_upper_bounds_mask(upper_bounds_mask) , m_upper_bounds_mask(upper_bounds_mask)
, m_new_axis(new_axis) , m_new_axis(new_axis)
......
...@@ -27,6 +27,9 @@ namespace ngraph ...@@ -27,6 +27,9 @@ namespace ngraph
class DynSlice : public Op class DynSlice : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a dynamic tensor slice operation. /// \brief Constructs a dynamic tensor slice operation.
/// ///
/// \param arg The tensor to be sliced. /// \param arg The tensor to be sliced.
......
...@@ -19,10 +19,12 @@ ...@@ -19,10 +19,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::CTCGreedyDecoder::type_name{"CTCGreedyDecoder"};
op::CTCGreedyDecoder::CTCGreedyDecoder(const shared_ptr<Node>& input, op::CTCGreedyDecoder::CTCGreedyDecoder(const shared_ptr<Node>& input,
const std::shared_ptr<Node>& seq_len, const std::shared_ptr<Node>& seq_len,
const bool ctc_merge_repeated) const bool ctc_merge_repeated)
: Op("CTCGreedyDecoder", check_single_output_args({input, seq_len})) : Op(check_single_output_args({input, seq_len}))
, m_ctc_merge_repeated(ctc_merge_repeated) , m_ctc_merge_repeated(ctc_merge_repeated)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
......
...@@ -25,6 +25,9 @@ namespace ngraph ...@@ -25,6 +25,9 @@ namespace ngraph
class CTCGreedyDecoder : public Op class CTCGreedyDecoder : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a CTCGreedyDecoder operation /// \brief Constructs a CTCGreedyDecoder operation
/// ///
/// \param input Logits on which greedy decoding is performed /// \param input Logits on which greedy decoding is performed
......
...@@ -21,14 +21,15 @@ ...@@ -21,14 +21,15 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::DetectionOutput::type_name{"DetectionOutput"};
op::DetectionOutput::DetectionOutput(const std::shared_ptr<Node>& box_logits, op::DetectionOutput::DetectionOutput(const std::shared_ptr<Node>& box_logits,
const std::shared_ptr<Node>& class_preds, const std::shared_ptr<Node>& class_preds,
const std::shared_ptr<Node>& proposals, const std::shared_ptr<Node>& proposals,
const std::shared_ptr<Node>& aux_class_preds, const std::shared_ptr<Node>& aux_class_preds,
const std::shared_ptr<Node>& aux_box_preds, const std::shared_ptr<Node>& aux_box_preds,
const DetectionOutputAttrs& attrs) const DetectionOutputAttrs& attrs)
: Op("DetectionOutput", : Op(check_single_output_args(
check_single_output_args(
{box_logits, class_preds, proposals, aux_class_preds, aux_box_preds})) {box_logits, class_preds, proposals, aux_class_preds, aux_box_preds}))
, m_attrs(attrs) , m_attrs(attrs)
{ {
......
...@@ -47,6 +47,9 @@ namespace ngraph ...@@ -47,6 +47,9 @@ namespace ngraph
class DetectionOutput : public Op class DetectionOutput : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a DetectionOutput operation /// \brief Constructs a DetectionOutput operation
/// ///
/// \param box_logits Box logits /// \param box_logits Box logits
......
...@@ -21,10 +21,12 @@ ...@@ -21,10 +21,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Interpolate::type_name{"Interpolate"};
op::Interpolate::Interpolate(const std::shared_ptr<Node>& image, op::Interpolate::Interpolate(const std::shared_ptr<Node>& image,
const std::shared_ptr<Node>& output_shape, const std::shared_ptr<Node>& output_shape,
const InterpolateAttrs& attrs) const InterpolateAttrs& attrs)
: Op("Interpolate", check_single_output_args({image, output_shape})) : Op(check_single_output_args({image, output_shape}))
, m_attrs(attrs) , m_attrs(attrs)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
......
...@@ -36,6 +36,9 @@ namespace ngraph ...@@ -36,6 +36,9 @@ namespace ngraph
class Interpolate : public Op class Interpolate : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a Interpolate operation /// \brief Constructs a Interpolate operation
/// ///
/// \param image Input image /// \param image Input image
......
...@@ -21,10 +21,12 @@ ...@@ -21,10 +21,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::PriorBox::type_name{"PriorBox"};
op::PriorBox::PriorBox(const shared_ptr<Node>& layer_shape, op::PriorBox::PriorBox(const shared_ptr<Node>& layer_shape,
const shared_ptr<Node>& image_shape, const shared_ptr<Node>& image_shape,
const PriorBoxAttrs& attrs) const PriorBoxAttrs& attrs)
: Op("PriorBox", check_single_output_args({layer_shape, image_shape})) : Op(check_single_output_args({layer_shape, image_shape}))
, m_attrs(attrs) , m_attrs(attrs)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
......
...@@ -49,6 +49,9 @@ namespace ngraph ...@@ -49,6 +49,9 @@ namespace ngraph
class PriorBox : public Op class PriorBox : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a PriorBox operation /// \brief Constructs a PriorBox operation
/// ///
/// \param layer_shape Shape of layer for which prior boxes are computed /// \param layer_shape Shape of layer for which prior boxes are computed
......
...@@ -21,10 +21,12 @@ ...@@ -21,10 +21,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::PriorBoxClustered::type_name{"PriorBoxClustered"};
op::PriorBoxClustered::PriorBoxClustered(const shared_ptr<Node>& layer_shape, op::PriorBoxClustered::PriorBoxClustered(const shared_ptr<Node>& layer_shape,
const shared_ptr<Node>& image_shape, const shared_ptr<Node>& image_shape,
const PriorBoxClusteredAttrs& attrs) const PriorBoxClusteredAttrs& attrs)
: Op("PriorBoxClustered", check_single_output_args({layer_shape, image_shape})) : Op(check_single_output_args({layer_shape, image_shape}))
, m_attrs(attrs) , m_attrs(attrs)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
......
...@@ -47,6 +47,9 @@ namespace ngraph ...@@ -47,6 +47,9 @@ namespace ngraph
class PriorBoxClustered : public Op class PriorBoxClustered : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a PriorBoxClustered operation /// \brief Constructs a PriorBoxClustered operation
/// ///
/// \param layer_shape Shape of layer for which prior boxes are computed /// \param layer_shape Shape of layer for which prior boxes are computed
......
...@@ -21,11 +21,13 @@ ...@@ -21,11 +21,13 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Proposal::type_name{"Proposal"};
op::Proposal::Proposal(const std::shared_ptr<Node>& class_probs, op::Proposal::Proposal(const std::shared_ptr<Node>& class_probs,
const std::shared_ptr<Node>& class_logits, const std::shared_ptr<Node>& class_logits,
const std::shared_ptr<Node>& image_shape, const std::shared_ptr<Node>& image_shape,
const ProposalAttrs& attrs) const ProposalAttrs& attrs)
: Op("Proposal", check_single_output_args({class_probs, class_logits, image_shape})) : Op(check_single_output_args({class_probs, class_logits, image_shape}))
, m_attrs(attrs) , m_attrs(attrs)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
......
...@@ -57,6 +57,9 @@ namespace ngraph ...@@ -57,6 +57,9 @@ namespace ngraph
class Proposal : public Op class Proposal : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a Proposal operation /// \brief Constructs a Proposal operation
/// ///
/// \param class_probs Class probability scores /// \param class_probs Class probability scores
......
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::PSROIPooling::type_name{"PSROIPooling"};
op::PSROIPooling::PSROIPooling(const shared_ptr<Node>& input, op::PSROIPooling::PSROIPooling(const shared_ptr<Node>& input,
const std::shared_ptr<Node>& coords, const std::shared_ptr<Node>& coords,
const size_t output_dim, const size_t output_dim,
...@@ -26,7 +28,7 @@ op::PSROIPooling::PSROIPooling(const shared_ptr<Node>& input, ...@@ -26,7 +28,7 @@ op::PSROIPooling::PSROIPooling(const shared_ptr<Node>& input,
const float spatial_scale, const float spatial_scale,
const Shape& num_bins, const Shape& num_bins,
const std::string& kind) const std::string& kind)
: Op("PSROIPooling", check_single_output_args({input, coords})) : Op(check_single_output_args({input, coords}))
, m_output_dim(output_dim) , m_output_dim(output_dim)
, m_group_size(group_size) , m_group_size(group_size)
, m_spatial_scale(spatial_scale) , m_spatial_scale(spatial_scale)
......
...@@ -25,6 +25,9 @@ namespace ngraph ...@@ -25,6 +25,9 @@ namespace ngraph
class PSROIPooling : public Op class PSROIPooling : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a PSROIPooling operation /// \brief Constructs a PSROIPooling operation
/// ///
/// \param input Input feature map {N, C, ...} /// \param input Input feature map {N, C, ...}
......
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::RegionYolo::type_name{"RegionYolo"};
op::RegionYolo::RegionYolo(const shared_ptr<Node>& input, op::RegionYolo::RegionYolo(const shared_ptr<Node>& input,
const size_t num_coords, const size_t num_coords,
const size_t num_classes, const size_t num_classes,
...@@ -27,7 +29,7 @@ op::RegionYolo::RegionYolo(const shared_ptr<Node>& input, ...@@ -27,7 +29,7 @@ op::RegionYolo::RegionYolo(const shared_ptr<Node>& input,
const vector<int64_t>& mask, const vector<int64_t>& mask,
const int axis, const int axis,
const int end_axis) const int end_axis)
: Op("RegionYolo", check_single_output_args({input})) : Op(check_single_output_args({input}))
, m_num_coords(num_coords) , m_num_coords(num_coords)
, m_num_classes(num_classes) , m_num_classes(num_classes)
, m_num_regions(num_regions) , m_num_regions(num_regions)
......
...@@ -25,6 +25,9 @@ namespace ngraph ...@@ -25,6 +25,9 @@ namespace ngraph
class RegionYolo : public Op class RegionYolo : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a RegionYolo operation /// \brief Constructs a RegionYolo operation
/// ///
/// \param input Input /// \param input Input
......
...@@ -21,8 +21,10 @@ ...@@ -21,8 +21,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::ReorgYolo::type_name{"ReorgYolo"};
op::ReorgYolo::ReorgYolo(const shared_ptr<Node>& input, const Strides& strides) op::ReorgYolo::ReorgYolo(const shared_ptr<Node>& input, const Strides& strides)
: Op("ReorgYolo", check_single_output_args({input})) : Op(check_single_output_args({input}))
, m_strides(strides) , m_strides(strides)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
......
...@@ -25,6 +25,9 @@ namespace ngraph ...@@ -25,6 +25,9 @@ namespace ngraph
class ReorgYolo : public Op class ReorgYolo : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a ReorgYolo operation /// \brief Constructs a ReorgYolo operation
/// ///
/// \param input Input /// \param input Input
......
...@@ -19,12 +19,14 @@ ...@@ -19,12 +19,14 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::ROIPooling::type_name{"ROIPooling"};
op::ROIPooling::ROIPooling(const shared_ptr<Node>& input, op::ROIPooling::ROIPooling(const shared_ptr<Node>& input,
const std::shared_ptr<Node>& coords, const std::shared_ptr<Node>& coords,
const Shape& output_size, const Shape& output_size,
const float spatial_scale, const float spatial_scale,
const std::string& kind) const std::string& kind)
: Op("ROIPooling", check_single_output_args({input, coords})) : Op(check_single_output_args({input, coords}))
, m_output_size(output_size) , m_output_size(output_size)
, m_spatial_scale(spatial_scale) , m_spatial_scale(spatial_scale)
, m_kind(kind) , m_kind(kind)
......
...@@ -25,6 +25,9 @@ namespace ngraph ...@@ -25,6 +25,9 @@ namespace ngraph
class ROIPooling : public Op class ROIPooling : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a ROIPooling operation /// \brief Constructs a ROIPooling operation
/// ///
/// \param input Input feature map {N, C, ...} /// \param input Input feature map {N, C, ...}
......
...@@ -23,8 +23,10 @@ ...@@ -23,8 +23,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::QuantizedConcat::type_name{"QuantizedConcat"};
op::QuantizedConcat::QuantizedConcat(const NodeVector& args, size_t concatenation_axis) op::QuantizedConcat::QuantizedConcat(const NodeVector& args, size_t concatenation_axis)
: Op("QuantizedConcat", check_single_output_args(args)) : Op(check_single_output_args(args))
, m_concatenation_axis(concatenation_axis) , m_concatenation_axis(concatenation_axis)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
......
...@@ -28,6 +28,9 @@ namespace ngraph ...@@ -28,6 +28,9 @@ namespace ngraph
class QuantizedConcat : public Op class QuantizedConcat : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a concatenation operation. /// \brief Constructs a concatenation operation.
/// ///
/// \param args The nodes producing the input tensors. /// \param args The nodes producing the input tensors.
......
...@@ -20,8 +20,10 @@ ...@@ -20,8 +20,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::ShapeOf::type_name{"ShapeOf"};
op::ShapeOf::ShapeOf(const shared_ptr<Node>& arg) op::ShapeOf::ShapeOf(const shared_ptr<Node>& arg)
: Op("ShapeOf", check_single_output_args({arg})) : Op(check_single_output_args({arg}))
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,6 +26,9 @@ namespace ngraph ...@@ -26,6 +26,9 @@ namespace ngraph
class ShapeOf : public Op class ShapeOf : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a shape-of operation. /// \brief Constructs a shape-of operation.
ShapeOf(const std::shared_ptr<Node>& arg); ShapeOf(const std::shared_ptr<Node>& arg);
......
...@@ -21,8 +21,10 @@ ...@@ -21,8 +21,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Tile::type_name{"Tile"};
op::Tile::Tile(const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& repeats) op::Tile::Tile(const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& repeats)
: Op("Tile", check_single_output_args({arg, repeats})) : Op(check_single_output_args({arg, repeats}))
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -27,6 +27,9 @@ namespace ngraph ...@@ -27,6 +27,9 @@ namespace ngraph
class Tile : public Op class Tile : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Perform dynamic padding of a tensor /// \brief Perform dynamic padding of a tensor
/// ///
/// \param arg The node producing input tensor to be padded. /// \param arg The node producing input tensor to be padded.
......
...@@ -22,8 +22,10 @@ ...@@ -22,8 +22,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Transpose::type_name{"Transpose"};
op::Transpose::Transpose(const shared_ptr<Node>& arg, const shared_ptr<Node>& input_order) op::Transpose::Transpose(const shared_ptr<Node>& arg, const shared_ptr<Node>& input_order)
: Op("Transpose", check_single_output_args({arg, input_order})) : Op(check_single_output_args({arg, input_order}))
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -28,6 +28,9 @@ namespace ngraph ...@@ -28,6 +28,9 @@ namespace ngraph
class Transpose : public Op class Transpose : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a transpose operation. /// \brief Constructs a transpose operation.
/// ///
/// \param arg Node producing the tensor to be transposed. /// \param arg Node producing the tensor to be transposed.
......
...@@ -21,12 +21,14 @@ ...@@ -21,12 +21,14 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Pad::type_name{"Pad"};
op::Pad::Pad(const shared_ptr<Node>& arg, op::Pad::Pad(const shared_ptr<Node>& arg,
const shared_ptr<Node>& arg_pad_value, const shared_ptr<Node>& arg_pad_value,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
PadMode pad_mode) PadMode pad_mode)
: Op("Pad", check_single_output_args({arg, arg_pad_value})) : Op(check_single_output_args({arg, arg_pad_value}))
, m_padding_below(padding_below) , m_padding_below(padding_below)
, m_padding_above(padding_above) , m_padding_above(padding_above)
, m_padding_interior_fake(padding_below.size()) , m_padding_interior_fake(padding_below.size())
......
...@@ -28,6 +28,9 @@ namespace ngraph ...@@ -28,6 +28,9 @@ namespace ngraph
class Pad : public Op class Pad : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a generic padding operation. /// \brief Constructs a generic padding operation.
/// ///
/// \param arg The node producing input tensor to be padded. /// \param arg The node producing input tensor to be padded.
......
...@@ -21,10 +21,12 @@ ...@@ -21,10 +21,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Parameter::type_name{"Parameter"};
op::Parameter::Parameter(const element::Type& element_type, op::Parameter::Parameter(const element::Type& element_type,
const PartialShape& pshape, const PartialShape& pshape,
const bool cacheable) const bool cacheable)
: Op("Parameter", {}) : Op(NodeVector{})
, m_cacheable(cacheable) , m_cacheable(cacheable)
, m_partial_shape(pshape) , m_partial_shape(pshape)
, m_element_type(element_type) , m_element_type(element_type)
......
...@@ -35,6 +35,9 @@ namespace ngraph ...@@ -35,6 +35,9 @@ namespace ngraph
const NodeVector& deltas) override; const NodeVector& deltas) override;
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructions a tensor-typed parameter node. /// \brief Constructions a tensor-typed parameter node.
/// ///
/// \param element_type The element type of the parameter. /// \param element_type The element type of the parameter.
......
...@@ -18,12 +18,17 @@ ...@@ -18,12 +18,17 @@
#include "ngraph/op/passthrough.hpp" #include "ngraph/op/passthrough.hpp"
using namespace std;
using namespace ngraph;
const string op::Passthrough::type_name{"Passthrough"};
ngraph::op::Passthrough::Passthrough(const std::string& logical_type, ngraph::op::Passthrough::Passthrough(const std::string& logical_type,
const std::string& language, const std::string& language,
const std::string& function, const std::string& function,
const NodeVector& args, const NodeVector& args,
std::vector<std::tuple<element::Type, PartialShape>> outputs) std::vector<std::tuple<element::Type, PartialShape>> outputs)
: Op{"Passthrough", args} : Op{args}
, m_logical_type{logical_type} , m_logical_type{logical_type}
, m_language{language} , m_language{language}
, m_function{function} , m_function{function}
...@@ -65,5 +70,5 @@ std::shared_ptr<ngraph::Node> ...@@ -65,5 +70,5 @@ std::shared_ptr<ngraph::Node>
"Passthrough node input counts cannot be changed for a given Passthrough function"}; "Passthrough node input counts cannot be changed for a given Passthrough function"};
} }
return std::make_shared<Passthrough>( return std::make_shared<Passthrough>(
description(), m_language, m_function, new_args, m_output_shapes); m_logical_type, m_language, m_function, new_args, m_output_shapes);
} }
...@@ -38,6 +38,9 @@ namespace ngraph ...@@ -38,6 +38,9 @@ namespace ngraph
class ngraph::op::Passthrough final : public Op class ngraph::op::Passthrough final : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Passthrough(const std::string& logical_type, // aka "What this operation is doing" Passthrough(const std::string& logical_type, // aka "What this operation is doing"
const std::string& language, // The language the implementation is written in const std::string& language, // The language the implementation is written in
const std::string& function, // The operation implementation const std::string& function, // The operation implementation
......
...@@ -22,10 +22,12 @@ ...@@ -22,10 +22,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Power::type_name{"Power"};
op::Power::Power(const shared_ptr<Node>& arg0, op::Power::Power(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1, const shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob) const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic("Power", arg0, arg1, autob) : BinaryElementwiseArithmetic(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -39,6 +39,9 @@ namespace ngraph ...@@ -39,6 +39,9 @@ namespace ngraph
class Power : public util::BinaryElementwiseArithmetic class Power : public util::BinaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an exponentiation operation. /// \brief Constructs an exponentiation operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
......
...@@ -20,8 +20,11 @@ ...@@ -20,8 +20,11 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Relu::type_name{"Relu"};
const string op::ReluBackprop::type_name{"ReluBackprop"};
op::Relu::Relu(shared_ptr<Node> arg) op::Relu::Relu(shared_ptr<Node> arg)
: UnaryElementwiseArithmetic("Relu", {arg}) : UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -33,7 +36,7 @@ shared_ptr<Node> op::Relu::copy_with_new_args(const NodeVector& new_args) const ...@@ -33,7 +36,7 @@ shared_ptr<Node> op::Relu::copy_with_new_args(const NodeVector& new_args) const
} }
op::ReluBackprop::ReluBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta) op::ReluBackprop::ReluBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta)
: BinaryElementwiseArithmetic("ReluBackprop", arg, delta) : BinaryElementwiseArithmetic(arg, delta)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -33,6 +33,9 @@ namespace ngraph ...@@ -33,6 +33,9 @@ namespace ngraph
class Relu : public ngraph::op::util::UnaryElementwiseArithmetic class Relu : public ngraph::op::util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a Relu operation. /// \brief Constructs a Relu operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
...@@ -50,6 +53,9 @@ namespace ngraph ...@@ -50,6 +53,9 @@ namespace ngraph
class ReluBackprop : public ngraph::op::util::BinaryElementwiseArithmetic class ReluBackprop : public ngraph::op::util::BinaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a ReluBackprop operation. /// \brief Constructs a ReluBackprop operation.
/// ///
/// \param arg Node that produces the relu forward input tensor. /// \param arg Node that produces the relu forward input tensor.
......
...@@ -23,8 +23,10 @@ ...@@ -23,8 +23,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Reverse::type_name{"Reverse"};
op::Reverse::Reverse(const shared_ptr<Node>& arg, const AxisSet& reversed_axes) op::Reverse::Reverse(const shared_ptr<Node>& arg, const AxisSet& reversed_axes)
: Op("Reverse", check_single_output_args({arg})) : Op(check_single_output_args({arg}))
, m_reversed_axes(reversed_axes) , m_reversed_axes(reversed_axes)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
......
...@@ -46,6 +46,9 @@ namespace ngraph ...@@ -46,6 +46,9 @@ namespace ngraph
class Reverse : public Op class Reverse : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a reverse operation. /// \brief Constructs a reverse operation.
/// ///
/// \param arg The input tensor, some of whose axes are to be reversed. /// \param arg The input tensor, some of whose axes are to be reversed.
......
...@@ -25,11 +25,13 @@ ...@@ -25,11 +25,13 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::ReverseSequence::type_name{"ReverseSequence"};
op::ReverseSequence::ReverseSequence(const std::shared_ptr<Node> arg, op::ReverseSequence::ReverseSequence(const std::shared_ptr<Node> arg,
const std::shared_ptr<Node> seq_indices, const std::shared_ptr<Node> seq_indices,
size_t batch_axis, size_t batch_axis,
size_t seq_axis) size_t seq_axis)
: Op("ReverseSequence", check_single_output_args({arg, seq_indices})) : Op(check_single_output_args({arg, seq_indices}))
, m_batch_axis(batch_axis) , m_batch_axis(batch_axis)
, m_seq_axis(seq_axis) , m_seq_axis(seq_axis)
{ {
......
...@@ -25,6 +25,9 @@ namespace ngraph ...@@ -25,6 +25,9 @@ namespace ngraph
class ReverseSequence : public Op class ReverseSequence : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an arcsin operation. /// \brief Constructs an arcsin operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
......
...@@ -24,6 +24,8 @@ static int INPUTS = 0; ...@@ -24,6 +24,8 @@ static int INPUTS = 0;
static int INDICES = 1; static int INDICES = 1;
static int UPDATES = 2; static int UPDATES = 2;
const string op::ScatterAdd::type_name{"ScatterAdd"};
shared_ptr<Node> op::ScatterAdd::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ScatterAdd::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -26,13 +26,16 @@ namespace ngraph ...@@ -26,13 +26,16 @@ namespace ngraph
class ScatterAdd : public Op class ScatterAdd : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \param inputs Tensor /// \param inputs Tensor
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64` /// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
/// \param updates Tensor: Must have same type as inputs /// \param updates Tensor: Must have same type as inputs
ScatterAdd(const std::shared_ptr<Node>& inputs, ScatterAdd(const std::shared_ptr<Node>& inputs,
const std::shared_ptr<Node>& indices, const std::shared_ptr<Node>& indices,
const std::shared_ptr<Node>& updates) const std::shared_ptr<Node>& updates)
: Op("ScatterAdd", check_single_output_args({inputs, indices, updates})) : Op(check_single_output_args({inputs, indices, updates}))
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -24,6 +24,8 @@ static int INPUTS = 0; ...@@ -24,6 +24,8 @@ static int INPUTS = 0;
static int INDICES = 1; static int INDICES = 1;
static int UPDATES = 2; static int UPDATES = 2;
const string op::ScatterNDAdd::type_name{"ScatterNDAdd"};
shared_ptr<Node> op::ScatterNDAdd::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ScatterNDAdd::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -26,13 +26,16 @@ namespace ngraph ...@@ -26,13 +26,16 @@ namespace ngraph
class ScatterNDAdd : public Op class ScatterNDAdd : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \param inputs Tensor /// \param inputs Tensor
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64` /// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
/// \param updates Tensor: Must have same type as inputs /// \param updates Tensor: Must have same type as inputs
ScatterNDAdd(const std::shared_ptr<Node>& inputs, ScatterNDAdd(const std::shared_ptr<Node>& inputs,
const std::shared_ptr<Node>& indices, const std::shared_ptr<Node>& indices,
const std::shared_ptr<Node>& updates) const std::shared_ptr<Node>& updates)
: Op("ScatterNDAdd", check_single_output_args({inputs, indices, updates})) : Op(check_single_output_args({inputs, indices, updates}))
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -21,6 +21,9 @@ ...@@ -21,6 +21,9 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Sigmoid::type_name{"Sigmoid"};
const string op::SigmoidBackprop::type_name{"SigmoidBackprop"};
shared_ptr<Node> op::Sigmoid::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Sigmoid::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
...@@ -28,13 +31,13 @@ shared_ptr<Node> op::Sigmoid::copy_with_new_args(const NodeVector& new_args) con ...@@ -28,13 +31,13 @@ shared_ptr<Node> op::Sigmoid::copy_with_new_args(const NodeVector& new_args) con
} }
op::Sigmoid::Sigmoid(shared_ptr<Node> arg) op::Sigmoid::Sigmoid(shared_ptr<Node> arg)
: UnaryElementwiseArithmetic("Sigmoid", arg) : UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::SigmoidBackprop::SigmoidBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta) op::SigmoidBackprop::SigmoidBackprop(shared_ptr<Node> arg, shared_ptr<Node> delta)
: BinaryElementwiseArithmetic("SigmoidBackprop", arg, delta) : BinaryElementwiseArithmetic(arg, delta)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -28,6 +28,9 @@ namespace ngraph ...@@ -28,6 +28,9 @@ namespace ngraph
class Sigmoid : public util::UnaryElementwiseArithmetic class Sigmoid : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Sigmoid(std::shared_ptr<Node> arg); Sigmoid(std::shared_ptr<Node> arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -40,6 +43,9 @@ namespace ngraph ...@@ -40,6 +43,9 @@ namespace ngraph
class SigmoidBackprop : public util::BinaryElementwiseArithmetic class SigmoidBackprop : public util::BinaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a SigmoidBackprop operation. /// \brief Constructs a SigmoidBackprop operation.
/// ///
/// \param arg Node that produces the Sigmoid forward input tensor. /// \param arg Node that produces the Sigmoid forward input tensor.
......
...@@ -19,8 +19,10 @@ ...@@ -19,8 +19,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Sign::type_name{"Sign"};
op::Sign::Sign(const shared_ptr<Node>& arg) op::Sign::Sign(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Sign", arg) : UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -27,6 +27,9 @@ namespace ngraph ...@@ -27,6 +27,9 @@ namespace ngraph
class Sign : public util::UnaryElementwiseArithmetic class Sign : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an elementwise sign operation. /// \brief Constructs an elementwise sign operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
......
...@@ -21,8 +21,10 @@ ...@@ -21,8 +21,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Sin::type_name{"Sin"};
op::Sin::Sin(const shared_ptr<Node>& arg) op::Sin::Sin(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Sin", arg) : UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -38,6 +38,9 @@ namespace ngraph ...@@ -38,6 +38,9 @@ namespace ngraph
class Sin : public util::UnaryElementwiseArithmetic class Sin : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a sine operation. /// \brief Constructs a sine operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
......
...@@ -21,8 +21,10 @@ ...@@ -21,8 +21,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Sinh::type_name{"Sinh"};
op::Sinh::Sinh(const shared_ptr<Node>& arg) op::Sinh::Sinh(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Sinh", arg) : UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,6 +26,9 @@ namespace ngraph ...@@ -26,6 +26,9 @@ namespace ngraph
class Sinh : public util::UnaryElementwiseArithmetic class Sinh : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a hyperbolic sine operation. /// \brief Constructs a hyperbolic sine operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
......
...@@ -29,8 +29,10 @@ ...@@ -29,8 +29,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Softmax::type_name{"Softmax"};
op::Softmax::Softmax(const shared_ptr<Node>& arg, const AxisSet& axes) op::Softmax::Softmax(const shared_ptr<Node>& arg, const AxisSet& axes)
: UnaryElementwiseArithmetic("Softmax", arg) : UnaryElementwiseArithmetic(arg)
, m_axes(axes) , m_axes(axes)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
......
...@@ -27,6 +27,9 @@ namespace ngraph ...@@ -27,6 +27,9 @@ namespace ngraph
class Softmax : public util::UnaryElementwiseArithmetic class Softmax : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a softmax operation. /// \brief Constructs a softmax operation.
/// ///
/// \param arg Node that produces the first input tensor.<br> /// \param arg Node that produces the first input tensor.<br>
......
...@@ -21,8 +21,10 @@ ...@@ -21,8 +21,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Sqrt::type_name{"Sqrt"};
op::Sqrt::Sqrt(const shared_ptr<Node>& arg) op::Sqrt::Sqrt(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Sqrt", arg) : UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -38,6 +38,9 @@ namespace ngraph ...@@ -38,6 +38,9 @@ namespace ngraph
class Sqrt : public util::UnaryElementwiseArithmetic class Sqrt : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a square operation. /// \brief Constructs a square operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
......
...@@ -21,8 +21,10 @@ ...@@ -21,8 +21,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::StopGradient::type_name{"StopGradient"};
op::StopGradient::StopGradient(const shared_ptr<Node>& arg) op::StopGradient::StopGradient(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("StopGradient", arg) : UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,6 +26,9 @@ namespace ngraph ...@@ -26,6 +26,9 @@ namespace ngraph
class StopGradient : public util::UnaryElementwiseArithmetic class StopGradient : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs StopGradient /// \brief Constructs StopGradient
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
......
...@@ -22,8 +22,10 @@ ...@@ -22,8 +22,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Tan::type_name{"Tan"};
op::Tan::Tan(const shared_ptr<Node>& arg) op::Tan::Tan(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Tan", arg) : UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -38,6 +38,9 @@ namespace ngraph ...@@ -38,6 +38,9 @@ namespace ngraph
class Tan : public util::UnaryElementwiseArithmetic class Tan : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a tangent operation. /// \brief Constructs a tangent operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
......
...@@ -21,8 +21,10 @@ ...@@ -21,8 +21,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Tanh::type_name{"Tanh"};
op::Tanh::Tanh(const shared_ptr<Node>& arg) op::Tanh::Tanh(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Tanh", arg) : UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,6 +26,9 @@ namespace ngraph ...@@ -26,6 +26,9 @@ namespace ngraph
class Tanh : public util::UnaryElementwiseArithmetic class Tanh : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a hyperbolic tangent operation. /// \brief Constructs a hyperbolic tangent operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "ngraph/op/equal.hpp" #include "ngraph/op/equal.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp" #include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp" #include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/shape_of.hpp" #include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/transpose.hpp" #include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp" #include "ngraph/op/floor.hpp"
...@@ -53,6 +54,7 @@ ...@@ -53,6 +54,7 @@
#include "ngraph/op/relu.hpp" #include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp" #include "ngraph/op/reverse.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/sign.hpp" #include "ngraph/op/sign.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/op/sqrt.hpp" #include "ngraph/op/sqrt.hpp"
...@@ -86,9 +88,11 @@ ...@@ -86,9 +88,11 @@
#include "ngraph/runtime/reference/pad.hpp" #include "ngraph/runtime/reference/pad.hpp"
#include "ngraph/runtime/reference/product.hpp" #include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/quantize.hpp" #include "ngraph/runtime/reference/quantize.hpp"
#include "ngraph/runtime/reference/range.hpp"
#include "ngraph/runtime/reference/relu.hpp" #include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/reshape.hpp" #include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/reverse.hpp" #include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/sign.hpp" #include "ngraph/runtime/reference/sign.hpp"
#include "ngraph/runtime/reference/slice.hpp" #include "ngraph/runtime/reference/slice.hpp"
#include "ngraph/runtime/reference/sqrt.hpp" #include "ngraph/runtime/reference/sqrt.hpp"
...@@ -2247,3 +2251,196 @@ void pass::ConstantFolding::construct_constant_dyn_slice() ...@@ -2247,3 +2251,196 @@ void pass::ConstantFolding::construct_constant_dyn_slice()
make_shared<pattern::Matcher>(dyn_slice_op, "ConstantFolding.ConstantDynSlice"); make_shared<pattern::Matcher>(dyn_slice_op, "ConstantFolding.ConstantDynSlice");
this->add_matcher(dyn_slice_matcher, constant_dyn_slice_callback, all_pass_property_off); this->add_matcher(dyn_slice_matcher, constant_dyn_slice_callback, all_pass_property_off);
} }
template <class T>
shared_ptr<op::Constant> fold_constant_range(shared_ptr<op::Constant> start,
shared_ptr<op::Constant> step,
shared_ptr<op::Range> range)
{
vector<T> out_vec(shape_size(range->get_shape()));
runtime::reference::range<T>(start->get_vector<T>().data(),
step->get_vector<T>().data(),
range->get_shape(),
out_vec.data());
return make_shared<op::Constant>(range->get_element_type(), range->get_shape(), out_vec);
}
void pass::ConstantFolding::construct_constant_range()
{
auto start_label =
make_shared<pattern::op::Label>(element::i64, Shape{}, pattern::has_class<op::Constant>());
auto stop_label =
make_shared<pattern::op::Label>(element::i64, Shape{}, pattern::has_class<op::Constant>());
auto step_label =
make_shared<pattern::op::Label>(element::i64, Shape{}, pattern::has_class<op::Constant>());
auto range_op = make_shared<op::Range>(start_label, stop_label, step_label);
auto constant_range_callback = [start_label, stop_label, step_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_range_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto start_node = static_pointer_cast<op::Constant>(pattern_map[start_label]);
auto stop_node = static_pointer_cast<op::Constant>(pattern_map[stop_label]);
auto step_node = static_pointer_cast<op::Constant>(pattern_map[step_label]);
auto range = static_pointer_cast<op::Range>(m.get_match_root());
std::shared_ptr<op::Constant> replacement;
switch (range->get_output_element_type(0).get_type_enum())
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_range_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_range_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_range<char>(start_node, step_node, range);
break;
case element::Type_t::bf16:
replacement = fold_constant_range<bfloat16>(start_node, step_node, range);
break;
case element::Type_t::f16:
replacement = fold_constant_range<float16>(start_node, step_node, range);
break;
case element::Type_t::f32:
replacement = fold_constant_range<float>(start_node, step_node, range);
break;
case element::Type_t::f64:
replacement = fold_constant_range<double>(start_node, step_node, range);
break;
case element::Type_t::i8:
replacement = fold_constant_range<int8_t>(start_node, step_node, range);
break;
case element::Type_t::i16:
replacement = fold_constant_range<int16_t>(start_node, step_node, range);
break;
case element::Type_t::i32:
replacement = fold_constant_range<int32_t>(start_node, step_node, range);
break;
case element::Type_t::i64:
replacement = fold_constant_range<int64_t>(start_node, step_node, range);
break;
case element::Type_t::u8:
replacement = fold_constant_range<uint8_t>(start_node, step_node, range);
break;
case element::Type_t::u16:
replacement = fold_constant_range<uint16_t>(start_node, step_node, range);
break;
case element::Type_t::u32:
replacement = fold_constant_range<uint32_t>(start_node, step_node, range);
break;
case element::Type_t::u64:
replacement = fold_constant_range<uint64_t>(start_node, step_node, range);
break;
}
replace_node(m.get_match_root(), replacement);
return true;
};
auto range_matcher = make_shared<pattern::Matcher>(range_op, "ConstantFolding.ConstantRange");
this->add_matcher(range_matcher, constant_range_callback, all_pass_property_off);
}
template <class T>
shared_ptr<op::Constant> fold_constant_select(shared_ptr<op::Constant> selection,
shared_ptr<op::Constant> t,
shared_ptr<op::Constant> f,
shared_ptr<op::Select> select)
{
auto out_shape = select->get_shape();
vector<T> out_vec(shape_size(out_shape));
runtime::reference::select<T>(selection->get_data_ptr<char>(),
t->get_data_ptr<T>(),
f->get_data_ptr<T>(),
out_vec.data(),
shape_size(out_shape));
return make_shared<op::Constant>(select->get_element_type(), out_shape, out_vec);
}
void pass::ConstantFolding::construct_constant_select()
{
auto selection_label = make_shared<pattern::op::Label>(
element::boolean, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
auto t_label = make_shared<pattern::op::Label>(
element::i64, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
auto f_label = make_shared<pattern::op::Label>(
element::i64, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
auto select_op = make_shared<op::Select>(selection_label, t_label, f_label);
auto constant_select_callback = [selection_label, t_label, f_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_select_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto selection_node = static_pointer_cast<op::Constant>(pattern_map[selection_label]);
auto t_node = static_pointer_cast<op::Constant>(pattern_map[t_label]);
auto f_node = static_pointer_cast<op::Constant>(pattern_map[f_label]);
auto select = static_pointer_cast<op::Select>(m.get_match_root());
std::shared_ptr<op::Constant> replacement;
switch (select->get_output_element_type(0).get_type_enum())
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_select_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_select_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_select<char>(selection_node, t_node, f_node, select);
break;
case element::Type_t::bf16:
replacement = fold_constant_select<bfloat16>(selection_node, t_node, f_node, select);
break;
case element::Type_t::f16:
replacement = fold_constant_select<float16>(selection_node, t_node, f_node, select);
break;
case element::Type_t::f32:
replacement = fold_constant_select<float>(selection_node, t_node, f_node, select);
break;
case element::Type_t::f64:
replacement = fold_constant_select<double>(selection_node, t_node, f_node, select);
break;
case element::Type_t::i8:
replacement = fold_constant_select<int8_t>(selection_node, t_node, f_node, select);
break;
case element::Type_t::i16:
replacement = fold_constant_select<int16_t>(selection_node, t_node, f_node, select);
break;
case element::Type_t::i32:
replacement = fold_constant_select<int32_t>(selection_node, t_node, f_node, select);
break;
case element::Type_t::i64:
replacement = fold_constant_select<int64_t>(selection_node, t_node, f_node, select);
break;
case element::Type_t::u8:
replacement = fold_constant_select<uint8_t>(selection_node, t_node, f_node, select);
break;
case element::Type_t::u16:
replacement = fold_constant_select<uint16_t>(selection_node, t_node, f_node, select);
break;
case element::Type_t::u32:
replacement = fold_constant_select<uint32_t>(selection_node, t_node, f_node, select);
break;
case element::Type_t::u64:
replacement = fold_constant_select<uint64_t>(selection_node, t_node, f_node, select);
break;
}
replace_node(m.get_match_root(), replacement);
return true;
};
auto select_matcher =
make_shared<pattern::Matcher>(select_op, "ConstantFolding.ConstantSelect");
this->add_matcher(select_matcher, constant_select_callback, all_pass_property_off);
}
...@@ -49,7 +49,9 @@ public: ...@@ -49,7 +49,9 @@ public:
SLICE, SLICE,
DYN_SLICE, DYN_SLICE,
DYN_RESHAPE, DYN_RESHAPE,
TRANSPOSE TRANSPOSE,
RANGE,
SELECT
}; };
ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap()) ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap())
...@@ -74,6 +76,8 @@ public: ...@@ -74,6 +76,8 @@ public:
construct_constant_dyn_slice(); construct_constant_dyn_slice();
construct_constant_dyn_reshape(); construct_constant_dyn_reshape();
construct_constant_transpose(); construct_constant_transpose();
construct_constant_range();
construct_constant_select();
} }
//this allows to specify the order in which matchers will be run //this allows to specify the order in which matchers will be run
...@@ -105,6 +109,8 @@ public: ...@@ -105,6 +109,8 @@ public:
case CFTransformations::DYN_SLICE: construct_constant_dyn_slice(); break; case CFTransformations::DYN_SLICE: construct_constant_dyn_slice(); break;
case CFTransformations::DYN_RESHAPE: construct_constant_dyn_reshape(); break; case CFTransformations::DYN_RESHAPE: construct_constant_dyn_reshape(); break;
case CFTransformations::TRANSPOSE: construct_constant_transpose(); break; case CFTransformations::TRANSPOSE: construct_constant_transpose(); break;
case CFTransformations::RANGE: construct_constant_range(); break;
case CFTransformations::SELECT: construct_constant_select(); break;
} }
} }
} }
...@@ -128,6 +134,8 @@ private: ...@@ -128,6 +134,8 @@ private:
void construct_constant_dyn_slice(); void construct_constant_dyn_slice();
void construct_constant_dyn_reshape(); void construct_constant_dyn_reshape();
void construct_constant_transpose(); void construct_constant_transpose();
void construct_constant_range();
void construct_constant_select();
ngraph::BuildNodeExecutorMap m_cfmap; ngraph::BuildNodeExecutorMap m_cfmap;
}; };
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/reference/range.hpp"
#include "ngraph/slice_plan.hpp" #include "ngraph/slice_plan.hpp"
using namespace std; using namespace std;
...@@ -342,8 +343,7 @@ void pass::DynElimination::construct_dyn_reshape() ...@@ -342,8 +343,7 @@ void pass::DynElimination::construct_dyn_reshape()
} }
template <typename T> template <typename T>
std::shared_ptr<op::Constant> std::shared_ptr<op::Constant> make_range_replacement(const element::Type& et,
make_range_replacement_integral(const element::Type& et,
const Shape& shape, const Shape& shape,
const std::shared_ptr<op::Constant>& start_arg, const std::shared_ptr<op::Constant>& start_arg,
const std::shared_ptr<op::Constant>& step_arg) const std::shared_ptr<op::Constant>& step_arg)
...@@ -354,40 +354,7 @@ std::shared_ptr<op::Constant> ...@@ -354,40 +354,7 @@ std::shared_ptr<op::Constant>
NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1); NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1);
T start = start_vec[0]; runtime::reference::range<T>(start_vec.data(), step_vec.data(), shape, elements.data());
T step = step_vec[0];
T val = start;
for (size_t i = 0; i < elements.size(); i++)
{
elements[i] = val;
val = val + step;
}
return make_shared<op::Constant>(et, shape, elements);
}
template <typename T>
std::shared_ptr<op::Constant>
make_range_replacement_floating(const element::Type& et,
const Shape& shape,
const std::shared_ptr<op::Constant>& start_arg,
const std::shared_ptr<op::Constant>& step_arg)
{
std::vector<T> elements(shape_size(shape));
std::vector<T> start_vec = start_arg->get_vector<T>();
std::vector<T> step_vec = step_arg->get_vector<T>();
NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1);
T start = start_vec[0];
T step = step_vec[0];
for (size_t i = 0; i < elements.size(); i++)
{
elements[i] = start + (static_cast<T>(i) * step);
}
return make_shared<op::Constant>(et, shape, elements); return make_shared<op::Constant>(et, shape, elements);
} }
...@@ -426,40 +393,40 @@ void pass::DynElimination::construct_range() ...@@ -426,40 +393,40 @@ void pass::DynElimination::construct_range()
switch (et.get_type_enum()) switch (et.get_type_enum())
{ {
case element::Type_t::bf16: case element::Type_t::bf16:
replacement = make_range_replacement_floating<bfloat16>(et, shape, start_arg, step_arg); replacement = make_range_replacement<bfloat16>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::f16: case element::Type_t::f16:
replacement = make_range_replacement_floating<float16>(et, shape, start_arg, step_arg); replacement = make_range_replacement<float16>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::f32: case element::Type_t::f32:
replacement = make_range_replacement_floating<float>(et, shape, start_arg, step_arg); replacement = make_range_replacement<float>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::f64: case element::Type_t::f64:
replacement = make_range_replacement_floating<double>(et, shape, start_arg, step_arg); replacement = make_range_replacement<double>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::i8: case element::Type_t::i8:
replacement = make_range_replacement_integral<int8_t>(et, shape, start_arg, step_arg); replacement = make_range_replacement<int8_t>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::i16: case element::Type_t::i16:
replacement = make_range_replacement_integral<int16_t>(et, shape, start_arg, step_arg); replacement = make_range_replacement<int16_t>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::i32: case element::Type_t::i32:
replacement = make_range_replacement_integral<int32_t>(et, shape, start_arg, step_arg); replacement = make_range_replacement<int32_t>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::i64: case element::Type_t::i64:
replacement = make_range_replacement_integral<int64_t>(et, shape, start_arg, step_arg); replacement = make_range_replacement<int64_t>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::u8: case element::Type_t::u8:
replacement = make_range_replacement_integral<uint8_t>(et, shape, start_arg, step_arg); replacement = make_range_replacement<uint8_t>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::u16: case element::Type_t::u16:
replacement = make_range_replacement_integral<uint16_t>(et, shape, start_arg, step_arg); replacement = make_range_replacement<uint16_t>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::u32: case element::Type_t::u32:
replacement = make_range_replacement_integral<uint32_t>(et, shape, start_arg, step_arg); replacement = make_range_replacement<uint32_t>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::u64: case element::Type_t::u64:
replacement = make_range_replacement_integral<uint64_t>(et, shape, start_arg, step_arg); replacement = make_range_replacement<uint64_t>(et, shape, start_arg, step_arg);
break; break;
case element::Type_t::undefined: case element::Type_t::undefined:
case element::Type_t::dynamic: case element::Type_t::dynamic:
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cmath>
#include <type_traits>
#include "ngraph/axis_vector.hpp"
#include "ngraph/check.hpp"
#include "ngraph/coordinate_transform.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
// Return type is `void`, only enabled if `T` is a built-in FP
// type, or nGraph's `bfloat16` or `float16` type.
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value ||
std::is_same<T, bfloat16>::value ||
std::is_same<T, float16>::value>::type
range(const T* start, const T* step, const Shape& out_shape, T* out)
{
for (size_t i = 0; i < shape_size(out_shape); i++)
{
out[i] = *start + (static_cast<T>(i) * (*step));
}
}
// Return type is `void`, only enabled if `T` is `is_integral`.
template <typename T>
typename std::enable_if<std::is_integral<T>::value>::type
range(const T* start, const T* step, const Shape& out_shape, T* out)
{
T val = *start;
for (size_t i = 0; i < shape_size(out_shape); i++)
{
out[i] = val;
val += *step;
}
}
}
}
}
...@@ -891,6 +891,94 @@ TEST(constant_folding, constant_transpose) ...@@ -891,6 +891,94 @@ TEST(constant_folding, constant_transpose)
ASSERT_TRUE(test::all_close_f(values_permute, values_out, MIN_FLOAT_TOLERANCE_BITS)); ASSERT_TRUE(test::all_close_f(values_permute, values_out, MIN_FLOAT_TOLERANCE_BITS));
} }
void range_test_check(const vector<double>& values_out, const vector<double>& values_expected)
{
ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
}
void range_test_check(const vector<float>& values_out, const vector<float>& values_expected)
{
ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
}
template <typename T>
typename std::enable_if<std::is_integral<T>::value>::type
range_test_check(const vector<T>& values_out, const vector<T>& values_expected)
{
ASSERT_EQ(values_out, values_expected);
}
template <typename T>
void range_test(T start, T stop, T step, const vector<T>& values_expected)
{
vector<T> values_start{start};
vector<T> values_stop{stop};
vector<T> values_step{step};
auto constant_start = make_shared<op::Constant>(element::from<T>(), Shape{}, values_start);
auto constant_stop = make_shared<op::Constant>(element::from<T>(), Shape{}, values_stop);
auto constant_step = make_shared<op::Constant>(element::from<T>(), Shape{}, values_step);
auto range = make_shared<op::Range>(constant_start, constant_stop, constant_step);
auto f = make_shared<Function>(range, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Range>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->template get_vector<T>();
range_test_check(values_out, values_expected);
}
TEST(constant_folding, constant_range)
{
range_test<int8_t>(5, 12, 2, {5, 7, 9, 11});
range_test<int32_t>(5, 12, 2, {5, 7, 9, 11});
range_test<int64_t>(5, 12, 2, {5, 7, 9, 11});
range_test<uint64_t>(5, 12, 2, {5, 7, 9, 11});
range_test<double>(5, 12, 2, {5, 7, 9, 11});
range_test<float>(5, 12, 2, {5, 7, 9, 11});
range_test<int32_t>(5, 12, -2, {});
range_test<float>(12, 4, -2, {12, 10, 8, 6});
}
TEST(constant_folding, constant_select)
{
Shape shape{2, 4};
vector<char> values_selection{0, 1, 1, 0, 1, 0, 0, 1};
vector<int64_t> values_t{2, 4, 6, 8, 10, 12, 14, 16};
vector<int64_t> values_f{1, 3, 5, 7, 9, 11, 13, 15};
auto constant_selection = make_shared<op::Constant>(element::boolean, shape, values_selection);
auto constant_t = make_shared<op::Constant>(element::i64, shape, values_t);
auto constant_f = make_shared<op::Constant>(element::i64, shape, values_f);
auto select = make_shared<op::Select>(constant_selection, constant_t, constant_f);
auto f = make_shared<Function>(select, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Select>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int64_t>();
vector<int64_t> values_expected{1, 4, 6, 7, 10, 11, 13, 16};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, pass_property) TEST(constant_folding, pass_property)
{ {
auto pass = std::make_shared<ngraph::pass::ConstantFolding>(); auto pass = std::make_shared<ngraph::pass::ConstantFolding>();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment