Unverified Commit 0ca40376 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Convert some more ops to use Output<Node> inputs (#3307)

* Convert some ops to use Output<Node> inputs

* Remove duplicate validation
parent 0c523507
......@@ -21,13 +21,15 @@
using namespace std;
using namespace ngraph;
op::QuantizedAvgPool::QuantizedAvgPool(const shared_ptr<Node>& arg,
const string op::QuantizedAvgPool::type_name{"QuantizedAvgPool"};
op::QuantizedAvgPool::QuantizedAvgPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above,
bool include_padding_in_avg_computation)
: Op("QuantizedAvgPool", check_single_output_args({arg}))
: Op({arg})
, m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below)
......@@ -35,15 +37,16 @@ op::QuantizedAvgPool::QuantizedAvgPool(const shared_ptr<Node>& arg,
, m_include_padding_in_avg_computation(include_padding_in_avg_computation)
{
constructor_validate_and_infer_types();
}
if (arg->get_element_type() != element::u8 && arg->get_element_type() != element::i8)
void op::QuantizedAvgPool::validate_and_infer_types()
{
auto arg(input(0).get_source_output());
if (arg.get_element_type() != element::u8 && arg.get_element_type() != element::i8)
{
throw ngraph_error("QuantizedAvgPool supported only for i8/u8!");
}
}
void op::QuantizedAvgPool::validate_and_infer_types()
{
auto& arg_shape = get_input_shape(0);
if (0 == m_window_movement_strides.size() && arg_shape.size() > 2)
......
......@@ -28,6 +28,11 @@ namespace ngraph
class QuantizedAvgPool : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a batched average pooling operation.
QuantizedAvgPool() = default;
/// \brief Constructs a batched average pooling operation.
///
/// \param arg The node producing the input data batch tensor.<br>
......@@ -43,7 +48,7 @@ namespace ngraph
/// \param include_padding_in_avg_computation If true then averages include padding
/// elements, each treated as the number zero. If false, padding elements are entirely
/// ignored when computing averages.
QuantizedAvgPool(const std::shared_ptr<Node>& arg,
QuantizedAvgPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
......
......@@ -27,17 +27,19 @@
using namespace std;
using namespace ngraph;
op::QuantizedConvolutionBias::QuantizedConvolutionBias(const shared_ptr<Node>& data_batch,
const shared_ptr<Node>& filters,
const shared_ptr<Node>& bias,
const string op::QuantizedConvolutionBias::type_name{"QuantizedConvolutionBias"};
op::QuantizedConvolutionBias::QuantizedConvolutionBias(const Output<Node>& data_batch,
const Output<Node>& filters,
const Output<Node>& bias,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const shared_ptr<Node>& scale,
const Output<Node>& scale,
const bool with_relu)
: Op("QuantizedConvolutionBias", check_single_output_args({data_batch, filters, bias, scale}))
: Op({data_batch, filters, bias, scale})
, m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides)
, m_padding_below(padding_below)
......@@ -47,8 +49,8 @@ op::QuantizedConvolutionBias::QuantizedConvolutionBias(const shared_ptr<Node>& d
{
constructor_validate_and_infer_types();
auto& data_batch_shape = data_batch->get_shape();
auto& filters_shape = filters->get_shape();
auto& data_batch_shape = data_batch.get_shape();
auto& filters_shape = filters.get_shape();
// TODO: call ngraph util
// util::validate_convbias_shapes(data_batch_shape, filters_shape, bias->get_shape());
......@@ -92,20 +94,21 @@ shared_ptr<Node> op::QuantizedConvolutionBias::copy_with_new_args(const NodeVect
m_with_relu));
}
op::QuantizedConvolutionBiasAdd::QuantizedConvolutionBiasAdd(const shared_ptr<Node>& data_batch,
const shared_ptr<Node>& filters,
const shared_ptr<Node>& bias,
const shared_ptr<Node>& sum_input,
const string op::QuantizedConvolutionBiasAdd::type_name{"QuantizedConvolutionBiasAdd"};
op::QuantizedConvolutionBiasAdd::QuantizedConvolutionBiasAdd(const Output<Node>& data_batch,
const Output<Node>& filters,
const Output<Node>& bias,
const Output<Node>& sum_input,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const shared_ptr<Node>& scale,
const shared_ptr<Node>& sum_scale,
const Output<Node>& scale,
const Output<Node>& sum_scale,
const bool with_relu)
: Op("QuantizedConvolutionBiasAdd",
check_single_output_args({data_batch, filters, bias, sum_input, scale, sum_scale}))
: Op({data_batch, filters, bias, sum_input, scale, sum_scale})
, m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides)
, m_padding_below(padding_below)
......@@ -115,8 +118,8 @@ op::QuantizedConvolutionBiasAdd::QuantizedConvolutionBiasAdd(const shared_ptr<No
{
constructor_validate_and_infer_types();
auto& data_batch_shape = data_batch->get_shape();
auto& filters_shape = filters->get_shape();
auto& data_batch_shape = data_batch.get_shape();
auto& filters_shape = filters.get_shape();
// TODO: call ngraph util
// util::validate_convbias_shapes(data_batch_shape, filters_shape, bias->get_shape());
......@@ -163,21 +166,22 @@ shared_ptr<Node>
m_with_relu));
}
const string op::QuantizedConvolutionBiasSignedAdd::type_name{"QuantizedConvolutionBiasSignedAdd"};
op::QuantizedConvolutionBiasSignedAdd::QuantizedConvolutionBiasSignedAdd(
const shared_ptr<Node>& data_batch,
const shared_ptr<Node>& filters,
const shared_ptr<Node>& bias,
const shared_ptr<Node>& sum_input,
const Output<Node>& data_batch,
const Output<Node>& filters,
const Output<Node>& bias,
const Output<Node>& sum_input,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const shared_ptr<Node>& scale,
const shared_ptr<Node>& sum_scale,
const Output<Node>& scale,
const Output<Node>& sum_scale,
const bool with_relu)
: Op("QuantizedConvolutionBiasSignedAdd",
check_single_output_args({data_batch, filters, bias, sum_input, scale, sum_scale}))
: Op({data_batch, filters, bias, sum_input, scale, sum_scale})
, m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides)
, m_padding_below(padding_below)
......@@ -187,8 +191,8 @@ op::QuantizedConvolutionBiasSignedAdd::QuantizedConvolutionBiasSignedAdd(
{
constructor_validate_and_infer_types();
auto& data_batch_shape = data_batch->get_shape();
auto& filters_shape = filters->get_shape();
auto& data_batch_shape = data_batch.get_shape();
auto& filters_shape = filters.get_shape();
// TODO: call ngraph util
// util::validate_convbias_shapes(data_batch_shape, filters_shape, bias->get_shape());
......
......@@ -27,15 +27,18 @@ namespace ngraph
class QuantizedConvolutionBias : public Op
{
public:
QuantizedConvolutionBias(const std::shared_ptr<Node>& data_batch,
const std::shared_ptr<Node>& filters,
const std::shared_ptr<Node>& bias,
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
QuantizedConvolutionBias(const Output<Node>& data_batch,
const Output<Node>& filters,
const Output<Node>& bias,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node>& scale,
const Output<Node>& scale,
const bool with_relu = false);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
......@@ -43,9 +46,9 @@ namespace ngraph
const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
std::shared_ptr<Node> get_bias() { return get_argument(2); }
std::shared_ptr<Node> get_filters() { return get_argument(1); }
std::shared_ptr<Node> get_data_batch() { return get_argument(0); }
Output<Node> get_bias() { return input(2).get_source_output(); }
Output<Node> get_filters() { return input(1).get_source_output(); }
Output<Node> get_data_batch() { return input(0).get_source_output(); }
bool with_relu() const { return m_with_relu; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......@@ -62,17 +65,20 @@ namespace ngraph
class QuantizedConvolutionBiasAdd : public Op
{
public:
QuantizedConvolutionBiasAdd(const std::shared_ptr<Node>& data_batch,
const std::shared_ptr<Node>& filters,
const std::shared_ptr<Node>& bias,
const std::shared_ptr<Node>& sum_input,
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
QuantizedConvolutionBiasAdd(const Output<Node>& data_batch,
const Output<Node>& filters,
const Output<Node>& bias,
const Output<Node>& sum_input,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node>& scale,
const std::shared_ptr<Node>& sum_scale,
const Output<Node>& scale,
const Output<Node>& sum_scale,
const bool with_relu = false);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
......@@ -80,9 +86,9 @@ namespace ngraph
const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
std::shared_ptr<Node> get_bias() { return get_argument(2); }
std::shared_ptr<Node> get_filters() { return get_argument(1); }
std::shared_ptr<Node> get_data_batch() { return get_argument(0); }
Output<Node> get_bias() { return input(2).get_source_output(); }
Output<Node> get_filters() { return input(1).get_source_output(); }
Output<Node> get_data_batch() { return input(0).get_source_output(); }
bool with_relu() const { return m_with_relu; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......@@ -99,17 +105,20 @@ namespace ngraph
class QuantizedConvolutionBiasSignedAdd : public Op
{
public:
QuantizedConvolutionBiasSignedAdd(const std::shared_ptr<Node>& data_batch,
const std::shared_ptr<Node>& filters,
const std::shared_ptr<Node>& bias,
const std::shared_ptr<Node>& sum_input,
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
QuantizedConvolutionBiasSignedAdd(const Output<Node>& data_batch,
const Output<Node>& filters,
const Output<Node>& bias,
const Output<Node>& sum_input,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node>& scale,
const std::shared_ptr<Node>& sum_scale,
const Output<Node>& scale,
const Output<Node>& sum_scale,
const bool with_relu = false);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
......@@ -117,9 +126,9 @@ namespace ngraph
const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
std::shared_ptr<Node> get_bias() { return get_argument(2); }
std::shared_ptr<Node> get_filters() { return get_argument(1); }
std::shared_ptr<Node> get_data_batch() { return get_argument(0); }
Output<Node> get_bias() { return input(2).get_source_output(); }
Output<Node> get_filters() { return input(1).get_source_output(); }
Output<Node> get_data_batch() { return input(0).get_source_output(); }
bool with_relu() const { return m_with_relu; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -26,15 +26,17 @@
using namespace std;
using namespace ngraph;
op::QuantizedConvolutionRelu::QuantizedConvolutionRelu(const shared_ptr<Node>& data_batch,
const shared_ptr<Node>& filters,
const string op::QuantizedConvolutionRelu::type_name{"QuantizedConvolutionRelu"};
op::QuantizedConvolutionRelu::QuantizedConvolutionRelu(const Output<Node>& data_batch,
const Output<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const shared_ptr<Node>& scale)
: Op("QuantizedConvolutionRelu", check_single_output_args({data_batch, filters, scale}))
const Output<Node>& scale)
: Op({data_batch, filters, scale})
, m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides)
, m_padding_below(padding_below)
......@@ -43,8 +45,8 @@ op::QuantizedConvolutionRelu::QuantizedConvolutionRelu(const shared_ptr<Node>& d
{
constructor_validate_and_infer_types();
auto& data_batch_shape = data_batch->get_shape();
auto& filters_shape = filters->get_shape();
auto& data_batch_shape = data_batch.get_shape();
auto& filters_shape = filters.get_shape();
set_output_type(0,
element::u8,
......
......@@ -27,22 +27,25 @@ namespace ngraph
class QuantizedConvolutionRelu : public Op
{
public:
QuantizedConvolutionRelu(const std::shared_ptr<Node>& data_batch,
const std::shared_ptr<Node>& filters,
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
QuantizedConvolutionRelu(const Output<Node>& data_batch,
const Output<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node>& scale);
const Output<Node>& scale);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; }
const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
std::shared_ptr<Node> get_filters() { return get_argument(1); }
std::shared_ptr<Node> get_data_batch() { return get_argument(0); }
Output<Node> get_filters() { return input(1).get_source_output(); }
Output<Node> get_data_batch() { return input(0).get_source_output(); }
bool with_relu() const { return true; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -24,19 +24,21 @@
using namespace std;
using namespace ngraph;
op::QuantizedDot::QuantizedDot(const shared_ptr<Node>& data,
const shared_ptr<Node>& weights,
const shared_ptr<Node>& scale,
const string op::QuantizedDot::type_name{"QuantizedDot"};
op::QuantizedDot::QuantizedDot(const Output<Node>& data,
const Output<Node>& weights,
const Output<Node>& scale,
bool requantize,
bool with_relu)
: Op("QuantizedDot", check_single_output_args({data, weights, scale}))
: Op({data, weights, scale})
, m_requantize(requantize)
, m_with_relu(with_relu)
{
constructor_validate_and_infer_types();
auto& data_shape = data->get_shape();
auto& weights_shape = weights->get_shape();
auto& data_shape = data.get_shape();
auto& weights_shape = weights.get_shape();
// QuantizedDot does [m ,n] * [n, k] = [m, k]
NODE_VALIDATION_CHECK(this,
data_shape.size() == 2 && weights_shape.size() == 2 &&
......@@ -47,7 +49,7 @@ op::QuantizedDot::QuantizedDot(const shared_ptr<Node>& data,
weights_shape);
auto output_et = requantize ? (with_relu ? element::u8 : element::i8) : element::i32;
if (data->get_element_type() == element::u8 && weights->get_element_type() == element::u8)
if (data.get_element_type() == element::u8 && weights.get_element_type() == element::u8)
{
output_et = element::u8;
}
......
......@@ -27,9 +27,12 @@ namespace ngraph
class QuantizedDot : public Op
{
public:
QuantizedDot(const std::shared_ptr<Node>& data,
const std::shared_ptr<Node>& weights,
const std::shared_ptr<Node>& scale,
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
QuantizedDot(const Output<Node>& data,
const Output<Node>& weights,
const Output<Node>& scale,
bool requantize = true,
bool with_relu = false);
......
......@@ -24,21 +24,23 @@
using namespace std;
using namespace ngraph;
op::QuantizedDotBias::QuantizedDotBias(const shared_ptr<Node>& data,
const shared_ptr<Node>& weights,
const shared_ptr<Node>& bias,
const shared_ptr<Node>& scale,
const string op::QuantizedDotBias::type_name{"QuantizedDotBias"};
op::QuantizedDotBias::QuantizedDotBias(const Output<Node>& data,
const Output<Node>& weights,
const Output<Node>& bias,
const Output<Node>& scale,
bool requantize,
bool with_relu)
: Op("QuantizedDotBias", check_single_output_args({data, weights, bias, scale}))
: Op({data, weights, bias, scale})
, m_requantize(requantize)
, m_with_relu(with_relu)
{
constructor_validate_and_infer_types();
auto& data_shape = data->get_shape();
auto& weights_shape = weights->get_shape();
auto& bias_shape = bias->get_shape();
auto& data_shape = data.get_shape();
auto& weights_shape = weights.get_shape();
auto& bias_shape = bias.get_shape();
NODE_VALIDATION_CHECK(this,
data_shape.size() == 2 && weights_shape.size() == 2 &&
data_shape[1] == weights_shape[1],
......
......@@ -27,10 +27,13 @@ namespace ngraph
class QuantizedDotBias : public Op
{
public:
QuantizedDotBias(const std::shared_ptr<Node>& data,
const std::shared_ptr<Node>& weights,
const std::shared_ptr<Node>& bias,
const std::shared_ptr<Node>& scale,
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
QuantizedDotBias(const Output<Node>& data,
const Output<Node>& weights,
const Output<Node>& bias,
const Output<Node>& scale,
bool requantize = true,
bool with_relu = false);
......
......@@ -22,12 +22,14 @@
using namespace std;
using namespace ngraph;
op::QuantizedMaxPool::QuantizedMaxPool(const shared_ptr<Node>& arg,
const string op::QuantizedMaxPool::type_name{"QuantizedMaxPool"};
op::QuantizedMaxPool::QuantizedMaxPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
const Shape& padding_above)
: Op("QuantizedMaxPool", check_single_output_args({arg}))
: Op({arg})
, m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below)
......@@ -35,7 +37,7 @@ op::QuantizedMaxPool::QuantizedMaxPool(const shared_ptr<Node>& arg,
{
constructor_validate_and_infer_types();
if (arg->get_element_type() != element::u8 && arg->get_element_type() != element::i8)
if (arg.get_element_type() != element::u8 && arg.get_element_type() != element::i8)
{
throw ngraph_error("QuantizedMaxPool supported only for i8/u8!");
}
......
......@@ -26,6 +26,9 @@ namespace ngraph
class QuantizedMaxPool : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a batched max pooling operation.
///
/// \param arg The node producing the input data batch tensor.
......@@ -33,7 +36,7 @@ namespace ngraph
/// \param window_movement_strides The window movement strides.
/// \param padding_below The below-padding shape.
/// \param padding_above The above-padding shape.
QuantizedMaxPool(const std::shared_ptr<Node>& arg,
QuantizedMaxPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
const Shape& padding_below,
......
......@@ -20,14 +20,16 @@
using namespace std;
using namespace ngraph;
op::Quantize::Quantize(const shared_ptr<Node>& input,
const shared_ptr<Node>& scale,
const shared_ptr<Node>& zero_point,
const string op::Quantize::type_name{"Quantize"};
op::Quantize::Quantize(const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
const element::Type& type,
const AxisSet& axes,
RoundMode round_mode)
: Op("Quantize", check_single_output_args({input, scale, zero_point}))
: Op({input, scale, zero_point})
, m_type(type)
, m_axes(axes)
, m_round_mode(round_mode)
......
......@@ -30,6 +30,9 @@ namespace ngraph
class Quantize : public ngraph::op::Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
enum class RoundMode
{
// round to nearest integer
......@@ -82,9 +85,9 @@ namespace ngraph
/// \param type output element type
/// \param axes axis positions on which `scale` and `zero_point` are specified
/// \param round_mode describes how to perform ROUND function (see above)
Quantize(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& scale,
const std::shared_ptr<Node>& zero_point,
Quantize(const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
const ngraph::element::Type& type,
const ngraph::AxisSet& axes,
RoundMode round_mode);
......
......@@ -24,32 +24,33 @@
using namespace std;
using namespace ngraph;
op::QuantizedConvolution::QuantizedConvolution(const shared_ptr<Node>& input,
const shared_ptr<Node>& filters,
const string op::QuantizedConvolution::type_name{"QuantizedConvolution"};
op::QuantizedConvolution::QuantizedConvolution(const Output<Node>& input,
const Output<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const shared_ptr<Node>& input_scale,
const shared_ptr<Node>& input_zero_point,
const shared_ptr<Node>& filter_scale,
const shared_ptr<Node>& filter_zero_point,
const shared_ptr<Node>& output_scale,
const shared_ptr<Node>& output_zero_point,
const Output<Node>& input_scale,
const Output<Node>& input_zero_point,
const Output<Node>& filter_scale,
const Output<Node>& filter_zero_point,
const Output<Node>& output_scale,
const Output<Node>& output_zero_point,
const element::Type& output_type,
const AxisSet& input_axes,
const AxisSet& filter_axes,
const AxisSet& output_axes)
: Op("QuantizedConvolution",
check_single_output_args({input,
filters,
input_scale,
input_zero_point,
filter_scale,
filter_zero_point,
output_scale,
output_zero_point}))
: Op({input,
filters,
input_scale,
input_zero_point,
filter_scale,
filter_zero_point,
output_scale,
output_zero_point})
, m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides)
, m_padding_below(padding_below)
......
......@@ -26,6 +26,9 @@ namespace ngraph
class QuantizedConvolution : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a quantized convolution operation.
///
/// \param input The node producing the input data batch tensor.
......@@ -45,19 +48,19 @@ namespace ngraph
/// \param input_axes Input axes set for channel wise quantization
/// \param filter_axes Filter axes set for channel wise quantization
/// \param output_axes Output axes set for channel wise quantization
QuantizedConvolution(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& filters,
QuantizedConvolution(const Output<Node>& input,
const Output<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const std::shared_ptr<Node>& input_scale,
const std::shared_ptr<Node>& input_zero_point,
const std::shared_ptr<Node>& filter_scale,
const std::shared_ptr<Node>& filter_zero_point,
const std::shared_ptr<Node>& output_scale,
const std::shared_ptr<Node>& output_zero_point,
const Output<Node>& input_scale,
const Output<Node>& input_zero_point,
const Output<Node>& filter_scale,
const Output<Node>& filter_zero_point,
const Output<Node>& output_scale,
const Output<Node>& output_zero_point,
const ngraph::element::Type& output_type,
const ngraph::AxisSet& input_axes = ngraph::AxisSet{},
const ngraph::AxisSet& filter_axes = ngraph::AxisSet{},
......
......@@ -26,10 +26,10 @@
using namespace std;
using namespace ngraph;
op::Select::Select(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const shared_ptr<Node>& arg2)
: Op("Select", check_single_output_args({arg0, arg1, arg2}))
const string op::Select::type_name{"Select"};
op::Select::Select(const Output<Node>& arg0, const Output<Node>& arg1, const Output<Node>& arg2)
: Op({arg0, arg1, arg2})
{
constructor_validate_and_infer_types();
}
......
......@@ -40,14 +40,17 @@ namespace ngraph
class Select : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a selection operation.
Select() = default;
/// \brief Constructs a selection operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param arg2 Node that produces the third input tensor.
Select(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
const std::shared_ptr<Node>& arg2);
Select(const Output<Node>& arg0, const Output<Node>& arg1, const Output<Node>& arg2);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -20,10 +20,12 @@
using namespace std;
using namespace ngraph;
op::Subtract::Subtract(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const string op::Subtract::type_name{"Subtract"};
op::Subtract::Subtract(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic("Subtract", arg0, arg1, autob)
: BinaryElementwiseArithmetic(arg0, arg1, autob)
{
constructor_validate_and_infer_types();
}
......@@ -50,8 +52,7 @@ void op::Subtract::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec
adjoints.add_delta(y, -delta);
}
shared_ptr<ngraph::Node> ngraph::operator-(const shared_ptr<ngraph::Node> arg0,
const shared_ptr<ngraph::Node> arg1)
shared_ptr<ngraph::Node> ngraph::operator-(const Output<Node> arg0, const Output<Node> arg1)
{
return make_shared<ngraph::op::Subtract>(arg0, arg1);
}
......@@ -26,13 +26,17 @@ namespace ngraph
class Subtract : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Subtract() = default;
/// \brief Constructs an subtraction operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
Subtract(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
Subtract(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
......@@ -42,6 +46,6 @@ namespace ngraph
const NodeVector& deltas) override;
};
}
std::shared_ptr<ngraph::Node> operator-(const std::shared_ptr<ngraph::Node> arg0,
const std::shared_ptr<ngraph::Node> arg1);
std::shared_ptr<ngraph::Node> operator-(const Output<ngraph::Node> arg0,
const Output<ngraph::Node> arg1);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment