Unverified Commit 0ca40376 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Convert some more ops to use Output<Node> inputs (#3307)

* Convert some ops to use Output<Node> inputs

* Remove duplicate validation
parent 0c523507
...@@ -21,13 +21,15 @@ ...@@ -21,13 +21,15 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::QuantizedAvgPool::QuantizedAvgPool(const shared_ptr<Node>& arg, const string op::QuantizedAvgPool::type_name{"QuantizedAvgPool"};
op::QuantizedAvgPool::QuantizedAvgPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above, const Shape& padding_above,
bool include_padding_in_avg_computation) bool include_padding_in_avg_computation)
: Op("QuantizedAvgPool", check_single_output_args({arg})) : Op({arg})
, m_window_shape(window_shape) , m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
...@@ -35,15 +37,16 @@ op::QuantizedAvgPool::QuantizedAvgPool(const shared_ptr<Node>& arg, ...@@ -35,15 +37,16 @@ op::QuantizedAvgPool::QuantizedAvgPool(const shared_ptr<Node>& arg,
, m_include_padding_in_avg_computation(include_padding_in_avg_computation) , m_include_padding_in_avg_computation(include_padding_in_avg_computation)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
}
if (arg->get_element_type() != element::u8 && arg->get_element_type() != element::i8) void op::QuantizedAvgPool::validate_and_infer_types()
{
auto arg(input(0).get_source_output());
if (arg.get_element_type() != element::u8 && arg.get_element_type() != element::i8)
{ {
throw ngraph_error("QuantizedAvgPool supported only for i8/u8!"); throw ngraph_error("QuantizedAvgPool supported only for i8/u8!");
} }
}
void op::QuantizedAvgPool::validate_and_infer_types()
{
auto& arg_shape = get_input_shape(0); auto& arg_shape = get_input_shape(0);
if (0 == m_window_movement_strides.size() && arg_shape.size() > 2) if (0 == m_window_movement_strides.size() && arg_shape.size() > 2)
......
...@@ -28,6 +28,11 @@ namespace ngraph ...@@ -28,6 +28,11 @@ namespace ngraph
class QuantizedAvgPool : public Op class QuantizedAvgPool : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a batched average pooling operation.
QuantizedAvgPool() = default;
/// \brief Constructs a batched average pooling operation. /// \brief Constructs a batched average pooling operation.
/// ///
/// \param arg The node producing the input data batch tensor.<br> /// \param arg The node producing the input data batch tensor.<br>
...@@ -43,7 +48,7 @@ namespace ngraph ...@@ -43,7 +48,7 @@ namespace ngraph
/// \param include_padding_in_avg_computation If true then averages include padding /// \param include_padding_in_avg_computation If true then averages include padding
/// elements, each treated as the number zero. If false, padding elements are entirely /// elements, each treated as the number zero. If false, padding elements are entirely
/// ignored when computing averages. /// ignored when computing averages.
QuantizedAvgPool(const std::shared_ptr<Node>& arg, QuantizedAvgPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
......
...@@ -27,17 +27,19 @@ ...@@ -27,17 +27,19 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::QuantizedConvolutionBias::QuantizedConvolutionBias(const shared_ptr<Node>& data_batch, const string op::QuantizedConvolutionBias::type_name{"QuantizedConvolutionBias"};
const shared_ptr<Node>& filters,
const shared_ptr<Node>& bias, op::QuantizedConvolutionBias::QuantizedConvolutionBias(const Output<Node>& data_batch,
const Output<Node>& filters,
const Output<Node>& bias,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const shared_ptr<Node>& scale, const Output<Node>& scale,
const bool with_relu) const bool with_relu)
: Op("QuantizedConvolutionBias", check_single_output_args({data_batch, filters, bias, scale})) : Op({data_batch, filters, bias, scale})
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides) , m_window_dilation_strides(window_dilation_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
...@@ -47,8 +49,8 @@ op::QuantizedConvolutionBias::QuantizedConvolutionBias(const shared_ptr<Node>& d ...@@ -47,8 +49,8 @@ op::QuantizedConvolutionBias::QuantizedConvolutionBias(const shared_ptr<Node>& d
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
auto& data_batch_shape = data_batch->get_shape(); auto& data_batch_shape = data_batch.get_shape();
auto& filters_shape = filters->get_shape(); auto& filters_shape = filters.get_shape();
// TODO: call ngraph util // TODO: call ngraph util
// util::validate_convbias_shapes(data_batch_shape, filters_shape, bias->get_shape()); // util::validate_convbias_shapes(data_batch_shape, filters_shape, bias->get_shape());
...@@ -92,20 +94,21 @@ shared_ptr<Node> op::QuantizedConvolutionBias::copy_with_new_args(const NodeVect ...@@ -92,20 +94,21 @@ shared_ptr<Node> op::QuantizedConvolutionBias::copy_with_new_args(const NodeVect
m_with_relu)); m_with_relu));
} }
op::QuantizedConvolutionBiasAdd::QuantizedConvolutionBiasAdd(const shared_ptr<Node>& data_batch, const string op::QuantizedConvolutionBiasAdd::type_name{"QuantizedConvolutionBiasAdd"};
const shared_ptr<Node>& filters,
const shared_ptr<Node>& bias, op::QuantizedConvolutionBiasAdd::QuantizedConvolutionBiasAdd(const Output<Node>& data_batch,
const shared_ptr<Node>& sum_input, const Output<Node>& filters,
const Output<Node>& bias,
const Output<Node>& sum_input,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const shared_ptr<Node>& scale, const Output<Node>& scale,
const shared_ptr<Node>& sum_scale, const Output<Node>& sum_scale,
const bool with_relu) const bool with_relu)
: Op("QuantizedConvolutionBiasAdd", : Op({data_batch, filters, bias, sum_input, scale, sum_scale})
check_single_output_args({data_batch, filters, bias, sum_input, scale, sum_scale}))
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides) , m_window_dilation_strides(window_dilation_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
...@@ -115,8 +118,8 @@ op::QuantizedConvolutionBiasAdd::QuantizedConvolutionBiasAdd(const shared_ptr<No ...@@ -115,8 +118,8 @@ op::QuantizedConvolutionBiasAdd::QuantizedConvolutionBiasAdd(const shared_ptr<No
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
auto& data_batch_shape = data_batch->get_shape(); auto& data_batch_shape = data_batch.get_shape();
auto& filters_shape = filters->get_shape(); auto& filters_shape = filters.get_shape();
// TODO: call ngraph util // TODO: call ngraph util
// util::validate_convbias_shapes(data_batch_shape, filters_shape, bias->get_shape()); // util::validate_convbias_shapes(data_batch_shape, filters_shape, bias->get_shape());
...@@ -163,21 +166,22 @@ shared_ptr<Node> ...@@ -163,21 +166,22 @@ shared_ptr<Node>
m_with_relu)); m_with_relu));
} }
const string op::QuantizedConvolutionBiasSignedAdd::type_name{"QuantizedConvolutionBiasSignedAdd"};
op::QuantizedConvolutionBiasSignedAdd::QuantizedConvolutionBiasSignedAdd( op::QuantizedConvolutionBiasSignedAdd::QuantizedConvolutionBiasSignedAdd(
const shared_ptr<Node>& data_batch, const Output<Node>& data_batch,
const shared_ptr<Node>& filters, const Output<Node>& filters,
const shared_ptr<Node>& bias, const Output<Node>& bias,
const shared_ptr<Node>& sum_input, const Output<Node>& sum_input,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const shared_ptr<Node>& scale, const Output<Node>& scale,
const shared_ptr<Node>& sum_scale, const Output<Node>& sum_scale,
const bool with_relu) const bool with_relu)
: Op("QuantizedConvolutionBiasSignedAdd", : Op({data_batch, filters, bias, sum_input, scale, sum_scale})
check_single_output_args({data_batch, filters, bias, sum_input, scale, sum_scale}))
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides) , m_window_dilation_strides(window_dilation_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
...@@ -187,8 +191,8 @@ op::QuantizedConvolutionBiasSignedAdd::QuantizedConvolutionBiasSignedAdd( ...@@ -187,8 +191,8 @@ op::QuantizedConvolutionBiasSignedAdd::QuantizedConvolutionBiasSignedAdd(
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
auto& data_batch_shape = data_batch->get_shape(); auto& data_batch_shape = data_batch.get_shape();
auto& filters_shape = filters->get_shape(); auto& filters_shape = filters.get_shape();
// TODO: call ngraph util // TODO: call ngraph util
// util::validate_convbias_shapes(data_batch_shape, filters_shape, bias->get_shape()); // util::validate_convbias_shapes(data_batch_shape, filters_shape, bias->get_shape());
......
...@@ -27,15 +27,18 @@ namespace ngraph ...@@ -27,15 +27,18 @@ namespace ngraph
class QuantizedConvolutionBias : public Op class QuantizedConvolutionBias : public Op
{ {
public: public:
QuantizedConvolutionBias(const std::shared_ptr<Node>& data_batch, NGRAPH_API
const std::shared_ptr<Node>& filters, static const std::string type_name;
const std::shared_ptr<Node>& bias, const std::string& description() const override { return type_name; }
QuantizedConvolutionBias(const Output<Node>& data_batch,
const Output<Node>& filters,
const Output<Node>& bias,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const std::shared_ptr<Node>& scale, const Output<Node>& scale,
const bool with_relu = false); const bool with_relu = false);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
...@@ -43,9 +46,9 @@ namespace ngraph ...@@ -43,9 +46,9 @@ namespace ngraph
const CoordinateDiff& get_padding_below() const { return m_padding_below; } const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; } const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; } const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
std::shared_ptr<Node> get_bias() { return get_argument(2); } Output<Node> get_bias() { return input(2).get_source_output(); }
std::shared_ptr<Node> get_filters() { return get_argument(1); } Output<Node> get_filters() { return input(1).get_source_output(); }
std::shared_ptr<Node> get_data_batch() { return get_argument(0); } Output<Node> get_data_batch() { return input(0).get_source_output(); }
bool with_relu() const { return m_with_relu; } bool with_relu() const { return m_with_relu; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -62,17 +65,20 @@ namespace ngraph ...@@ -62,17 +65,20 @@ namespace ngraph
class QuantizedConvolutionBiasAdd : public Op class QuantizedConvolutionBiasAdd : public Op
{ {
public: public:
QuantizedConvolutionBiasAdd(const std::shared_ptr<Node>& data_batch, NGRAPH_API
const std::shared_ptr<Node>& filters, static const std::string type_name;
const std::shared_ptr<Node>& bias, const std::string& description() const override { return type_name; }
const std::shared_ptr<Node>& sum_input, QuantizedConvolutionBiasAdd(const Output<Node>& data_batch,
const Output<Node>& filters,
const Output<Node>& bias,
const Output<Node>& sum_input,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const std::shared_ptr<Node>& scale, const Output<Node>& scale,
const std::shared_ptr<Node>& sum_scale, const Output<Node>& sum_scale,
const bool with_relu = false); const bool with_relu = false);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
...@@ -80,9 +86,9 @@ namespace ngraph ...@@ -80,9 +86,9 @@ namespace ngraph
const CoordinateDiff& get_padding_below() const { return m_padding_below; } const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; } const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; } const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
std::shared_ptr<Node> get_bias() { return get_argument(2); } Output<Node> get_bias() { return input(2).get_source_output(); }
std::shared_ptr<Node> get_filters() { return get_argument(1); } Output<Node> get_filters() { return input(1).get_source_output(); }
std::shared_ptr<Node> get_data_batch() { return get_argument(0); } Output<Node> get_data_batch() { return input(0).get_source_output(); }
bool with_relu() const { return m_with_relu; } bool with_relu() const { return m_with_relu; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -99,17 +105,20 @@ namespace ngraph ...@@ -99,17 +105,20 @@ namespace ngraph
class QuantizedConvolutionBiasSignedAdd : public Op class QuantizedConvolutionBiasSignedAdd : public Op
{ {
public: public:
QuantizedConvolutionBiasSignedAdd(const std::shared_ptr<Node>& data_batch, NGRAPH_API
const std::shared_ptr<Node>& filters, static const std::string type_name;
const std::shared_ptr<Node>& bias, const std::string& description() const override { return type_name; }
const std::shared_ptr<Node>& sum_input, QuantizedConvolutionBiasSignedAdd(const Output<Node>& data_batch,
const Output<Node>& filters,
const Output<Node>& bias,
const Output<Node>& sum_input,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const std::shared_ptr<Node>& scale, const Output<Node>& scale,
const std::shared_ptr<Node>& sum_scale, const Output<Node>& sum_scale,
const bool with_relu = false); const bool with_relu = false);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
...@@ -117,9 +126,9 @@ namespace ngraph ...@@ -117,9 +126,9 @@ namespace ngraph
const CoordinateDiff& get_padding_below() const { return m_padding_below; } const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; } const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; } const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
std::shared_ptr<Node> get_bias() { return get_argument(2); } Output<Node> get_bias() { return input(2).get_source_output(); }
std::shared_ptr<Node> get_filters() { return get_argument(1); } Output<Node> get_filters() { return input(1).get_source_output(); }
std::shared_ptr<Node> get_data_batch() { return get_argument(0); } Output<Node> get_data_batch() { return input(0).get_source_output(); }
bool with_relu() const { return m_with_relu; } bool with_relu() const { return m_with_relu; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -26,15 +26,17 @@ ...@@ -26,15 +26,17 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::QuantizedConvolutionRelu::QuantizedConvolutionRelu(const shared_ptr<Node>& data_batch, const string op::QuantizedConvolutionRelu::type_name{"QuantizedConvolutionRelu"};
const shared_ptr<Node>& filters,
op::QuantizedConvolutionRelu::QuantizedConvolutionRelu(const Output<Node>& data_batch,
const Output<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const shared_ptr<Node>& scale) const Output<Node>& scale)
: Op("QuantizedConvolutionRelu", check_single_output_args({data_batch, filters, scale})) : Op({data_batch, filters, scale})
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides) , m_window_dilation_strides(window_dilation_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
...@@ -43,8 +45,8 @@ op::QuantizedConvolutionRelu::QuantizedConvolutionRelu(const shared_ptr<Node>& d ...@@ -43,8 +45,8 @@ op::QuantizedConvolutionRelu::QuantizedConvolutionRelu(const shared_ptr<Node>& d
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
auto& data_batch_shape = data_batch->get_shape(); auto& data_batch_shape = data_batch.get_shape();
auto& filters_shape = filters->get_shape(); auto& filters_shape = filters.get_shape();
set_output_type(0, set_output_type(0,
element::u8, element::u8,
......
...@@ -27,22 +27,25 @@ namespace ngraph ...@@ -27,22 +27,25 @@ namespace ngraph
class QuantizedConvolutionRelu : public Op class QuantizedConvolutionRelu : public Op
{ {
public: public:
QuantizedConvolutionRelu(const std::shared_ptr<Node>& data_batch, NGRAPH_API
const std::shared_ptr<Node>& filters, static const std::string type_name;
const std::string& description() const override { return type_name; }
QuantizedConvolutionRelu(const Output<Node>& data_batch,
const Output<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const std::shared_ptr<Node>& scale); const Output<Node>& scale);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; } const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; }
const CoordinateDiff& get_padding_below() const { return m_padding_below; } const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; } const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; } const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
std::shared_ptr<Node> get_filters() { return get_argument(1); } Output<Node> get_filters() { return input(1).get_source_output(); }
std::shared_ptr<Node> get_data_batch() { return get_argument(0); } Output<Node> get_data_batch() { return input(0).get_source_output(); }
bool with_relu() const { return true; } bool with_relu() const { return true; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -24,19 +24,21 @@ ...@@ -24,19 +24,21 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::QuantizedDot::QuantizedDot(const shared_ptr<Node>& data, const string op::QuantizedDot::type_name{"QuantizedDot"};
const shared_ptr<Node>& weights,
const shared_ptr<Node>& scale, op::QuantizedDot::QuantizedDot(const Output<Node>& data,
const Output<Node>& weights,
const Output<Node>& scale,
bool requantize, bool requantize,
bool with_relu) bool with_relu)
: Op("QuantizedDot", check_single_output_args({data, weights, scale})) : Op({data, weights, scale})
, m_requantize(requantize) , m_requantize(requantize)
, m_with_relu(with_relu) , m_with_relu(with_relu)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
auto& data_shape = data->get_shape(); auto& data_shape = data.get_shape();
auto& weights_shape = weights->get_shape(); auto& weights_shape = weights.get_shape();
// QuantizedDot does [m ,n] * [n, k] = [m, k] // QuantizedDot does [m ,n] * [n, k] = [m, k]
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
data_shape.size() == 2 && weights_shape.size() == 2 && data_shape.size() == 2 && weights_shape.size() == 2 &&
...@@ -47,7 +49,7 @@ op::QuantizedDot::QuantizedDot(const shared_ptr<Node>& data, ...@@ -47,7 +49,7 @@ op::QuantizedDot::QuantizedDot(const shared_ptr<Node>& data,
weights_shape); weights_shape);
auto output_et = requantize ? (with_relu ? element::u8 : element::i8) : element::i32; auto output_et = requantize ? (with_relu ? element::u8 : element::i8) : element::i32;
if (data->get_element_type() == element::u8 && weights->get_element_type() == element::u8) if (data.get_element_type() == element::u8 && weights.get_element_type() == element::u8)
{ {
output_et = element::u8; output_et = element::u8;
} }
......
...@@ -27,9 +27,12 @@ namespace ngraph ...@@ -27,9 +27,12 @@ namespace ngraph
class QuantizedDot : public Op class QuantizedDot : public Op
{ {
public: public:
QuantizedDot(const std::shared_ptr<Node>& data, NGRAPH_API
const std::shared_ptr<Node>& weights, static const std::string type_name;
const std::shared_ptr<Node>& scale, const std::string& description() const override { return type_name; }
QuantizedDot(const Output<Node>& data,
const Output<Node>& weights,
const Output<Node>& scale,
bool requantize = true, bool requantize = true,
bool with_relu = false); bool with_relu = false);
......
...@@ -24,21 +24,23 @@ ...@@ -24,21 +24,23 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::QuantizedDotBias::QuantizedDotBias(const shared_ptr<Node>& data, const string op::QuantizedDotBias::type_name{"QuantizedDotBias"};
const shared_ptr<Node>& weights,
const shared_ptr<Node>& bias, op::QuantizedDotBias::QuantizedDotBias(const Output<Node>& data,
const shared_ptr<Node>& scale, const Output<Node>& weights,
const Output<Node>& bias,
const Output<Node>& scale,
bool requantize, bool requantize,
bool with_relu) bool with_relu)
: Op("QuantizedDotBias", check_single_output_args({data, weights, bias, scale})) : Op({data, weights, bias, scale})
, m_requantize(requantize) , m_requantize(requantize)
, m_with_relu(with_relu) , m_with_relu(with_relu)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
auto& data_shape = data->get_shape(); auto& data_shape = data.get_shape();
auto& weights_shape = weights->get_shape(); auto& weights_shape = weights.get_shape();
auto& bias_shape = bias->get_shape(); auto& bias_shape = bias.get_shape();
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
data_shape.size() == 2 && weights_shape.size() == 2 && data_shape.size() == 2 && weights_shape.size() == 2 &&
data_shape[1] == weights_shape[1], data_shape[1] == weights_shape[1],
......
...@@ -27,10 +27,13 @@ namespace ngraph ...@@ -27,10 +27,13 @@ namespace ngraph
class QuantizedDotBias : public Op class QuantizedDotBias : public Op
{ {
public: public:
QuantizedDotBias(const std::shared_ptr<Node>& data, NGRAPH_API
const std::shared_ptr<Node>& weights, static const std::string type_name;
const std::shared_ptr<Node>& bias, const std::string& description() const override { return type_name; }
const std::shared_ptr<Node>& scale, QuantizedDotBias(const Output<Node>& data,
const Output<Node>& weights,
const Output<Node>& bias,
const Output<Node>& scale,
bool requantize = true, bool requantize = true,
bool with_relu = false); bool with_relu = false);
......
...@@ -22,12 +22,14 @@ ...@@ -22,12 +22,14 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::QuantizedMaxPool::QuantizedMaxPool(const shared_ptr<Node>& arg, const string op::QuantizedMaxPool::type_name{"QuantizedMaxPool"};
op::QuantizedMaxPool::QuantizedMaxPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above) const Shape& padding_above)
: Op("QuantizedMaxPool", check_single_output_args({arg})) : Op({arg})
, m_window_shape(window_shape) , m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
...@@ -35,7 +37,7 @@ op::QuantizedMaxPool::QuantizedMaxPool(const shared_ptr<Node>& arg, ...@@ -35,7 +37,7 @@ op::QuantizedMaxPool::QuantizedMaxPool(const shared_ptr<Node>& arg,
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
if (arg->get_element_type() != element::u8 && arg->get_element_type() != element::i8) if (arg.get_element_type() != element::u8 && arg.get_element_type() != element::i8)
{ {
throw ngraph_error("QuantizedMaxPool supported only for i8/u8!"); throw ngraph_error("QuantizedMaxPool supported only for i8/u8!");
} }
......
...@@ -26,6 +26,9 @@ namespace ngraph ...@@ -26,6 +26,9 @@ namespace ngraph
class QuantizedMaxPool : public Op class QuantizedMaxPool : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a batched max pooling operation. /// \brief Constructs a batched max pooling operation.
/// ///
/// \param arg The node producing the input data batch tensor. /// \param arg The node producing the input data batch tensor.
...@@ -33,7 +36,7 @@ namespace ngraph ...@@ -33,7 +36,7 @@ namespace ngraph
/// \param window_movement_strides The window movement strides. /// \param window_movement_strides The window movement strides.
/// \param padding_below The below-padding shape. /// \param padding_below The below-padding shape.
/// \param padding_above The above-padding shape. /// \param padding_above The above-padding shape.
QuantizedMaxPool(const std::shared_ptr<Node>& arg, QuantizedMaxPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
......
...@@ -20,14 +20,16 @@ ...@@ -20,14 +20,16 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Quantize::Quantize(const shared_ptr<Node>& input, const string op::Quantize::type_name{"Quantize"};
const shared_ptr<Node>& scale,
const shared_ptr<Node>& zero_point, op::Quantize::Quantize(const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
const element::Type& type, const element::Type& type,
const AxisSet& axes, const AxisSet& axes,
RoundMode round_mode) RoundMode round_mode)
: Op("Quantize", check_single_output_args({input, scale, zero_point})) : Op({input, scale, zero_point})
, m_type(type) , m_type(type)
, m_axes(axes) , m_axes(axes)
, m_round_mode(round_mode) , m_round_mode(round_mode)
......
...@@ -30,6 +30,9 @@ namespace ngraph ...@@ -30,6 +30,9 @@ namespace ngraph
class Quantize : public ngraph::op::Op class Quantize : public ngraph::op::Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
enum class RoundMode enum class RoundMode
{ {
// round to nearest integer // round to nearest integer
...@@ -82,9 +85,9 @@ namespace ngraph ...@@ -82,9 +85,9 @@ namespace ngraph
/// \param type output element type /// \param type output element type
/// \param axes axis positions on which `scale` and `zero_point` are specified /// \param axes axis positions on which `scale` and `zero_point` are specified
/// \param round_mode describes how to perform ROUND function (see above) /// \param round_mode describes how to perform ROUND function (see above)
Quantize(const std::shared_ptr<Node>& input, Quantize(const Output<Node>& input,
const std::shared_ptr<Node>& scale, const Output<Node>& scale,
const std::shared_ptr<Node>& zero_point, const Output<Node>& zero_point,
const ngraph::element::Type& type, const ngraph::element::Type& type,
const ngraph::AxisSet& axes, const ngraph::AxisSet& axes,
RoundMode round_mode); RoundMode round_mode);
......
...@@ -24,32 +24,33 @@ ...@@ -24,32 +24,33 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& input, const string op::QuantizedConvolution::type_name{"QuantizedConvolution"};
const shared_ptr<Node>& filters,
op::QuantizedConvolution::QuantizedConvolution(const Output<Node>& input,
const Output<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const shared_ptr<Node>& input_scale, const Output<Node>& input_scale,
const shared_ptr<Node>& input_zero_point, const Output<Node>& input_zero_point,
const shared_ptr<Node>& filter_scale, const Output<Node>& filter_scale,
const shared_ptr<Node>& filter_zero_point, const Output<Node>& filter_zero_point,
const shared_ptr<Node>& output_scale, const Output<Node>& output_scale,
const shared_ptr<Node>& output_zero_point, const Output<Node>& output_zero_point,
const element::Type& output_type, const element::Type& output_type,
const AxisSet& input_axes, const AxisSet& input_axes,
const AxisSet& filter_axes, const AxisSet& filter_axes,
const AxisSet& output_axes) const AxisSet& output_axes)
: Op("QuantizedConvolution", : Op({input,
check_single_output_args({input, filters,
filters, input_scale,
input_scale, input_zero_point,
input_zero_point, filter_scale,
filter_scale, filter_zero_point,
filter_zero_point, output_scale,
output_scale, output_zero_point})
output_zero_point}))
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides) , m_window_dilation_strides(window_dilation_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
......
...@@ -26,6 +26,9 @@ namespace ngraph ...@@ -26,6 +26,9 @@ namespace ngraph
class QuantizedConvolution : public Op class QuantizedConvolution : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a quantized convolution operation. /// \brief Constructs a quantized convolution operation.
/// ///
/// \param input The node producing the input data batch tensor. /// \param input The node producing the input data batch tensor.
...@@ -45,19 +48,19 @@ namespace ngraph ...@@ -45,19 +48,19 @@ namespace ngraph
/// \param input_axes Input axes set for channel wise quantization /// \param input_axes Input axes set for channel wise quantization
/// \param filter_axes Filter axes set for channel wise quantization /// \param filter_axes Filter axes set for channel wise quantization
/// \param output_axes Output axes set for channel wise quantization /// \param output_axes Output axes set for channel wise quantization
QuantizedConvolution(const std::shared_ptr<Node>& input, QuantizedConvolution(const Output<Node>& input,
const std::shared_ptr<Node>& filters, const Output<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above, const CoordinateDiff& padding_above,
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const std::shared_ptr<Node>& input_scale, const Output<Node>& input_scale,
const std::shared_ptr<Node>& input_zero_point, const Output<Node>& input_zero_point,
const std::shared_ptr<Node>& filter_scale, const Output<Node>& filter_scale,
const std::shared_ptr<Node>& filter_zero_point, const Output<Node>& filter_zero_point,
const std::shared_ptr<Node>& output_scale, const Output<Node>& output_scale,
const std::shared_ptr<Node>& output_zero_point, const Output<Node>& output_zero_point,
const ngraph::element::Type& output_type, const ngraph::element::Type& output_type,
const ngraph::AxisSet& input_axes = ngraph::AxisSet{}, const ngraph::AxisSet& input_axes = ngraph::AxisSet{},
const ngraph::AxisSet& filter_axes = ngraph::AxisSet{}, const ngraph::AxisSet& filter_axes = ngraph::AxisSet{},
......
...@@ -26,10 +26,10 @@ ...@@ -26,10 +26,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Select::Select(const shared_ptr<Node>& arg0, const string op::Select::type_name{"Select"};
const shared_ptr<Node>& arg1,
const shared_ptr<Node>& arg2) op::Select::Select(const Output<Node>& arg0, const Output<Node>& arg1, const Output<Node>& arg2)
: Op("Select", check_single_output_args({arg0, arg1, arg2})) : Op({arg0, arg1, arg2})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -40,14 +40,17 @@ namespace ngraph ...@@ -40,14 +40,17 @@ namespace ngraph
class Select : public Op class Select : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a selection operation.
Select() = default;
/// \brief Constructs a selection operation. /// \brief Constructs a selection operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param arg2 Node that produces the third input tensor. /// \param arg2 Node that produces the third input tensor.
Select(const std::shared_ptr<Node>& arg0, Select(const Output<Node>& arg0, const Output<Node>& arg1, const Output<Node>& arg2);
const std::shared_ptr<Node>& arg1,
const std::shared_ptr<Node>& arg2);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -20,10 +20,12 @@ ...@@ -20,10 +20,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Subtract::Subtract(const shared_ptr<Node>& arg0, const string op::Subtract::type_name{"Subtract"};
const shared_ptr<Node>& arg1,
op::Subtract::Subtract(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob) const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic("Subtract", arg0, arg1, autob) : BinaryElementwiseArithmetic(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -50,8 +52,7 @@ void op::Subtract::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec ...@@ -50,8 +52,7 @@ void op::Subtract::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec
adjoints.add_delta(y, -delta); adjoints.add_delta(y, -delta);
} }
shared_ptr<ngraph::Node> ngraph::operator-(const shared_ptr<ngraph::Node> arg0, shared_ptr<ngraph::Node> ngraph::operator-(const Output<Node> arg0, const Output<Node> arg1)
const shared_ptr<ngraph::Node> arg1)
{ {
return make_shared<ngraph::op::Subtract>(arg0, arg1); return make_shared<ngraph::op::Subtract>(arg0, arg1);
} }
...@@ -26,13 +26,17 @@ namespace ngraph ...@@ -26,13 +26,17 @@ namespace ngraph
class Subtract : public util::BinaryElementwiseArithmetic class Subtract : public util::BinaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Subtract() = default;
/// \brief Constructs an subtraction operation. /// \brief Constructs an subtraction operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
Subtract(const std::shared_ptr<Node>& arg0, Subtract(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
...@@ -42,6 +46,6 @@ namespace ngraph ...@@ -42,6 +46,6 @@ namespace ngraph
const NodeVector& deltas) override; const NodeVector& deltas) override;
}; };
} }
std::shared_ptr<ngraph::Node> operator-(const std::shared_ptr<ngraph::Node> arg0, std::shared_ptr<ngraph::Node> operator-(const Output<ngraph::Node> arg0,
const std::shared_ptr<ngraph::Node> arg1); const Output<ngraph::Node> arg1);
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment