Unverified commit 5b59c095 authored by Scott Cyphers, committed by GitHub

Convert PlaidML ops to new form (#3344)

* Convert PlaidML ops to new form

* style
parent a6c2f23b
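
This patch moves the PlaidML backend ops off the old string-tag base constructor (`Op{"Name", NodeVector}`) onto the newer op form: an `OutputVector`-based constructor, an exported static `type_name`, and a `description()` override that returns it. Below is a minimal sketch of that pattern, condensed from the hunks that follow; `MyOp` is a hypothetical op used only for illustration, not part of the patch.

```cpp
// Sketch of the "new form" these ops are converted to (MyOp is hypothetical;
// the real ops in this patch follow the same shape).
#include "ngraph/op/op.hpp"

class MyOp final : public ngraph::op::Op
{
public:
    NGRAPH_API
    static const std::string type_name;
    const std::string& description() const override { return type_name; }

    // Inputs arrive as an OutputVector rather than a NodeVector, and the
    // base Op constructor no longer takes a type string.
    MyOp(const ngraph::OutputVector& args)
        : Op{args}
    {
        constructor_validate_and_infer_types();
    }
};

const std::string MyOp::type_name{"MyOp"};
```
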
@@ -36,12 +36,14 @@ namespace ngraph
}
}
const std::string ngraph::runtime::plaidml::op::Convolution::type_name{"PlaidMLConvolution"};
ngraph::runtime::plaidml::op::Convolution::Convolution(std::shared_ptr<ngraph::op::Convolution> src,
const NodeVector& args,
const OutputVector& args,
AxisVector data_axes,
AxisVector filters_axes,
AxisVector output_axes)
: Op{"PlaidMLConvolution", args}
: Op{args}
, m_src{std::move(src)}
, m_data_axes{std::move(data_axes)}
, m_filters_axes{std::move(filters_axes)}
@@ -69,16 +71,19 @@ std::shared_ptr<ngraph::Node>
throw ngraph_error{"PlaidMLConvolution requires two inputs (data and filters)"};
}
return std::make_shared<Convolution>(
m_src, new_args, m_data_axes, m_filters_axes, m_output_axes);
m_src, as_output_vector(new_args), m_data_axes, m_filters_axes, m_output_axes);
}
const std::string ngraph::runtime::plaidml::op::ConvolutionBackpropData::type_name{
"PlaidMLConvolutionBackpropData"};
ngraph::runtime::plaidml::op::ConvolutionBackpropData::ConvolutionBackpropData(
std::shared_ptr<ngraph::op::ConvolutionBackpropData> src,
const NodeVector& args,
const OutputVector& args,
AxisVector filters_axes,
AxisVector output_axes,
AxisVector data_axes)
: Op{"PlaidMLConvolutionBackpropData", args}
: Op{args}
, m_src{std::move(src)}
, m_filters_axes{std::move(filters_axes)}
, m_output_axes{std::move(output_axes)}
@@ -107,16 +112,19 @@ std::shared_ptr<ngraph::Node>
throw ngraph_error{"PlaidMLConvolutionBackpropData requires two inputs (data and output)"};
}
return std::make_shared<ConvolutionBackpropData>(
m_src, new_args, m_filters_axes, m_output_axes, m_data_axes);
m_src, as_output_vector(new_args), m_filters_axes, m_output_axes, m_data_axes);
}
const std::string ngraph::runtime::plaidml::op::ConvolutionBackpropFilters::type_name{
"PlaidMLConvolutionBackpropFilters"};
ngraph::runtime::plaidml::op::ConvolutionBackpropFilters::ConvolutionBackpropFilters(
std::shared_ptr<ngraph::op::ConvolutionBackpropFilters> src,
const NodeVector& args,
const OutputVector& args,
AxisVector data_axes,
AxisVector output_axes,
AxisVector filters_axes)
: Op{"PlaidMLConvolutionBackpropFilters", args}
: Op{args}
, m_src{std::move(src)}
, m_data_axes{std::move(data_axes)}
, m_output_axes{std::move(output_axes)}
@@ -146,7 +154,7 @@ std::shared_ptr<ngraph::Node>
"PlaidMLConvolutionBackpropFilters requires two inputs (filters and output)"};
}
return std::make_shared<ConvolutionBackpropFilters>(
m_src, new_args, m_data_axes, m_output_axes, m_filters_axes);
m_src, as_output_vector(new_args), m_data_axes, m_output_axes, m_filters_axes);
}
// Convolution implements a standard ML convolution, with optional striding, padding, and dilation.
......
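
One wrinkle in the conversion: `copy_with_new_args` still receives a `NodeVector`, so each clone call above wraps `new_args` with `as_output_vector` before forwarding to the new constructor. Here is a condensed sketch of that path, again using the hypothetical `MyOp` and an assumed two-input check.

```cpp
// Sketch of the copy_with_new_args pattern used above (MyOp is hypothetical).
std::shared_ptr<ngraph::Node> MyOp::copy_with_new_args(const ngraph::NodeVector& new_args) const
{
    if (new_args.size() != 2)
    {
        throw ngraph::ngraph_error{"MyOp requires two inputs"};
    }
    // Bridge the old NodeVector interface to the OutputVector constructor.
    return std::make_shared<MyOp>(ngraph::as_output_vector(new_args));
}
```
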
@@ -39,8 +39,11 @@ namespace ngraph
class ngraph::runtime::plaidml::op::Convolution final : public ngraph::op::Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Convolution(std::shared_ptr<ngraph::op::Convolution> src,
const NodeVector& args,
const OutputVector& args,
AxisVector data_axes,
AxisVector filters_axes,
AxisVector output_axes);
@@ -63,8 +66,11 @@ private:
class ngraph::runtime::plaidml::op::ConvolutionBackpropData final : public ngraph::op::Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
ConvolutionBackpropData(std::shared_ptr<ngraph::op::ConvolutionBackpropData> src,
const NodeVector& args,
const OutputVector& args,
AxisVector filters_axes,
AxisVector output_axes,
AxisVector data_axes);
@@ -87,8 +93,11 @@ private:
class ngraph::runtime::plaidml::op::ConvolutionBackpropFilters final : public ngraph::op::Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
ConvolutionBackpropFilters(std::shared_ptr<ngraph::op::ConvolutionBackpropFilters> src,
const NodeVector& args,
const OutputVector& args,
AxisVector data_axes,
AxisVector output_axes,
AxisVector filters_axes);
......
@@ -30,9 +30,11 @@ namespace ngraph
}
}
ngraph::runtime::plaidml::op::ImplicitBroadcast::ImplicitBroadcast(std::shared_ptr<Node> input,
const std::string ngraph::runtime::plaidml::op::ImplicitBroadcast::type_name{"ImplicitBroadcast"};
ngraph::runtime::plaidml::op::ImplicitBroadcast::ImplicitBroadcast(const Output<Node>& input,
const Shape& shape)
: Op{"ImplicitBroadcast", {input}}
: Op{{input}}
, m_shape{shape}
{
constructor_validate_and_infer_types();
......
@@ -40,7 +40,10 @@ namespace ngraph
class ngraph::runtime::plaidml::op::ImplicitBroadcast final : public ngraph::op::Op
{
public:
ImplicitBroadcast(std::shared_ptr<Node> input, const Shape& shape);
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
ImplicitBroadcast(const Output<Node>& input, const Shape& shape);
void validate_and_infer_types() final;
......
@@ -28,22 +28,24 @@ namespace ngraph
}
}
ngraph::runtime::plaidml::op::Replicate::Replicate(std::shared_ptr<Node> arg,
const std::string ngraph::runtime::plaidml::op::Replicate::type_name{"Replicate"};
ngraph::runtime::plaidml::op::Replicate::Replicate(const Output<Node>& arg,
std::size_t replication_axis,
std::size_t replication_count)
: Op{"Replicate", NodeVector{arg}}
, m_replication_axes(arg->get_shape().size(), 1)
: Op{{arg}}
, m_replication_axes(arg.get_shape().size(), 1)
{
m_replication_axes.at(replication_axis) = replication_count;
constructor_validate_and_infer_types();
}
ngraph::runtime::plaidml::op::Replicate::Replicate(std::shared_ptr<Node> arg,
ngraph::runtime::plaidml::op::Replicate::Replicate(const Output<Node>& arg,
std::vector<std::size_t> replication_axes)
: Op{"Replicate", NodeVector{arg}}
: Op{{arg}}
, m_replication_axes(std::move(replication_axes))
{
if (arg->get_shape().size() != m_replication_axes.size())
if (arg.get_shape().size() != m_replication_axes.size())
{
throw ngraph_error{"Replicate requires compatible axes dimensions"};
}
......
@@ -39,11 +39,12 @@ namespace ngraph
class ngraph::runtime::plaidml::op::Replicate final : public ngraph::op::Op
{
public:
Replicate(std::shared_ptr<Node> arg,
std::size_t replication_axis,
std::size_t replication_count);
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Replicate(const Output<Node>& arg, std::size_t replication_axis, std::size_t replication_count);
Replicate(std::shared_ptr<Node> arg, std::vector<std::size_t> replication_axes);
Replicate(const Output<Node>& arg, std::vector<std::size_t> replication_axes);
void validate_and_infer_types() final;
......
@@ -30,9 +30,11 @@ namespace ngraph
}
}
const std::string ngraph::runtime::plaidml::op::Winograd::type_name{"Winograd"};
ngraph::runtime::plaidml::op::Winograd::Winograd(std::shared_ptr<plaidml::op::Convolution> conv,
const NodeVector& args)
: Op{"Winograd", args}
const OutputVector& args)
: Op{args}
, m_conv{std::move(conv)}
{
constructor_validate_and_infer_types();
@@ -50,7 +52,7 @@ std::shared_ptr<ngraph::Node>
{
throw ngraph_error{"Winograd requires five inputs (data, filters, A, B, and G)"};
}
return std::make_shared<Winograd>(m_conv, new_args);
return std::make_shared<Winograd>(m_conv, as_output_vector(new_args));
}
void ngraph::runtime::plaidml::ImplWinograd::Apply()
......
@@ -38,7 +38,10 @@ namespace ngraph
class ngraph::runtime::plaidml::op::Winograd final : public ngraph::op::Op
{
public:
Winograd(std::shared_ptr<Convolution> conv, const NodeVector& args);
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Winograd(std::shared_ptr<Convolution> conv, const OutputVector& args);
void validate_and_infer_types() final;
......
@@ -97,7 +97,7 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
{
replace_node(target,
std::make_shared<plaidml::op::Convolution>(conv,
NodeVector{lhs, rhs},
OutputVector{lhs, rhs},
std::move(lhs_axes),
std::move(rhs_axes),
std::move(out_axes)));
@@ -113,7 +113,7 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
replace_node(
target,
std::make_shared<plaidml::op::ConvolutionBackpropData>(conv_bp_data,
NodeVector{lhs, rhs},
OutputVector{lhs, rhs},
std::move(lhs_axes),
std::move(rhs_axes),
std::move(out_axes)));
@@ -126,10 +126,10 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
std::dynamic_pointer_cast<ngraph::op::ConvolutionBackpropFilters>(node);
if (conv_bp_filters)
{
replace_node(
target,
std::make_shared<plaidml::op::ConvolutionBackpropFilters>(conv_bp_filters,
NodeVector{lhs, rhs},
replace_node(target,
std::make_shared<plaidml::op::ConvolutionBackpropFilters>(
conv_bp_filters,
OutputVector{lhs, rhs},
std::move(lhs_axes),
std::move(rhs_axes),
std::move(out_axes)));
......
@@ -113,7 +113,11 @@ ngraph::runtime::plaidml::pass::Winograd::Winograd()
auto callback = [](pattern::Matcher& m) {
auto conv = std::static_pointer_cast<plaidml::op::Convolution>(m.get_match_root());
NodeVector args = conv->get_arguments();
OutputVector args;
for (auto input : conv->inputs())
{
args.push_back(input.get_source_output());
}
std::shared_ptr<ngraph::op::Constant> a;
std::shared_ptr<ngraph::op::Constant> b;
std::shared_ptr<ngraph::op::Constant> g;
......
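
The last hunk drops `get_arguments()` (which returned a `NodeVector`) in favor of collecting each input's source `Output<Node>` by hand. As a standalone idiom, that loop looks like the sketch below; the helper name `collect_args` is illustrative and not part of the patch.

```cpp
// Sketch: gather a node's inputs as an OutputVector, as the Winograd pass
// callback now does instead of calling get_arguments().
ngraph::OutputVector collect_args(const std::shared_ptr<ngraph::Node>& node)
{
    ngraph::OutputVector args;
    for (auto input : node->inputs())
    {
        args.push_back(input.get_source_output());
    }
    return args;
}
```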