Unverified Commit 5b59c095 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Convert PlaidML ops to new form (#3344)

* Convert PlaidML ops to new form

* style
parent a6c2f23b
...@@ -36,12 +36,14 @@ namespace ngraph ...@@ -36,12 +36,14 @@ namespace ngraph
} }
} }
const std::string ngraph::runtime::plaidml::op::Convolution::type_name{"PlaidMLConvolution"};
ngraph::runtime::plaidml::op::Convolution::Convolution(std::shared_ptr<ngraph::op::Convolution> src, ngraph::runtime::plaidml::op::Convolution::Convolution(std::shared_ptr<ngraph::op::Convolution> src,
const NodeVector& args, const OutputVector& args,
AxisVector data_axes, AxisVector data_axes,
AxisVector filters_axes, AxisVector filters_axes,
AxisVector output_axes) AxisVector output_axes)
: Op{"PlaidMLConvolution", args} : Op{args}
, m_src{std::move(src)} , m_src{std::move(src)}
, m_data_axes{std::move(data_axes)} , m_data_axes{std::move(data_axes)}
, m_filters_axes{std::move(filters_axes)} , m_filters_axes{std::move(filters_axes)}
...@@ -69,16 +71,19 @@ std::shared_ptr<ngraph::Node> ...@@ -69,16 +71,19 @@ std::shared_ptr<ngraph::Node>
throw ngraph_error{"PlaidMLConvolution requires two inputs (data and filters)"}; throw ngraph_error{"PlaidMLConvolution requires two inputs (data and filters)"};
} }
return std::make_shared<Convolution>( return std::make_shared<Convolution>(
m_src, new_args, m_data_axes, m_filters_axes, m_output_axes); m_src, as_output_vector(new_args), m_data_axes, m_filters_axes, m_output_axes);
} }
const std::string ngraph::runtime::plaidml::op::ConvolutionBackpropData::type_name{
"PlaidMLConvolutionBackpropData"};
ngraph::runtime::plaidml::op::ConvolutionBackpropData::ConvolutionBackpropData( ngraph::runtime::plaidml::op::ConvolutionBackpropData::ConvolutionBackpropData(
std::shared_ptr<ngraph::op::ConvolutionBackpropData> src, std::shared_ptr<ngraph::op::ConvolutionBackpropData> src,
const NodeVector& args, const OutputVector& args,
AxisVector filters_axes, AxisVector filters_axes,
AxisVector output_axes, AxisVector output_axes,
AxisVector data_axes) AxisVector data_axes)
: Op{"PlaidMLConvolutionBackpropData", args} : Op{args}
, m_src{std::move(src)} , m_src{std::move(src)}
, m_filters_axes{std::move(filters_axes)} , m_filters_axes{std::move(filters_axes)}
, m_output_axes{std::move(output_axes)} , m_output_axes{std::move(output_axes)}
...@@ -107,16 +112,19 @@ std::shared_ptr<ngraph::Node> ...@@ -107,16 +112,19 @@ std::shared_ptr<ngraph::Node>
throw ngraph_error{"PlaidMLConvolutionBackpropData requires two inputs (data and output)"}; throw ngraph_error{"PlaidMLConvolutionBackpropData requires two inputs (data and output)"};
} }
return std::make_shared<ConvolutionBackpropData>( return std::make_shared<ConvolutionBackpropData>(
m_src, new_args, m_filters_axes, m_output_axes, m_data_axes); m_src, as_output_vector(new_args), m_filters_axes, m_output_axes, m_data_axes);
} }
const std::string ngraph::runtime::plaidml::op::ConvolutionBackpropFilters::type_name{
"PlaidMLConvolutionBackpropFilters"};
ngraph::runtime::plaidml::op::ConvolutionBackpropFilters::ConvolutionBackpropFilters( ngraph::runtime::plaidml::op::ConvolutionBackpropFilters::ConvolutionBackpropFilters(
std::shared_ptr<ngraph::op::ConvolutionBackpropFilters> src, std::shared_ptr<ngraph::op::ConvolutionBackpropFilters> src,
const NodeVector& args, const OutputVector& args,
AxisVector data_axes, AxisVector data_axes,
AxisVector output_axes, AxisVector output_axes,
AxisVector filters_axes) AxisVector filters_axes)
: Op{"PlaidMLConvolutionBackpropFilters", args} : Op{args}
, m_src{std::move(src)} , m_src{std::move(src)}
, m_data_axes{std::move(data_axes)} , m_data_axes{std::move(data_axes)}
, m_output_axes{std::move(output_axes)} , m_output_axes{std::move(output_axes)}
...@@ -146,7 +154,7 @@ std::shared_ptr<ngraph::Node> ...@@ -146,7 +154,7 @@ std::shared_ptr<ngraph::Node>
"PlaidMLConvolutionBackpropFilters requires two inputs (filters and output)"}; "PlaidMLConvolutionBackpropFilters requires two inputs (filters and output)"};
} }
return std::make_shared<ConvolutionBackpropFilters>( return std::make_shared<ConvolutionBackpropFilters>(
m_src, new_args, m_data_axes, m_output_axes, m_filters_axes); m_src, as_output_vector(new_args), m_data_axes, m_output_axes, m_filters_axes);
} }
// Convolution implements a standard ML convolultion, with optional striding, padding, and dilation. // Convolution implements a standard ML convolultion, with optional striding, padding, and dilation.
......
...@@ -39,8 +39,11 @@ namespace ngraph ...@@ -39,8 +39,11 @@ namespace ngraph
class ngraph::runtime::plaidml::op::Convolution final : public ngraph::op::Op class ngraph::runtime::plaidml::op::Convolution final : public ngraph::op::Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Convolution(std::shared_ptr<ngraph::op::Convolution> src, Convolution(std::shared_ptr<ngraph::op::Convolution> src,
const NodeVector& args, const OutputVector& args,
AxisVector data_axes, AxisVector data_axes,
AxisVector filters_axes, AxisVector filters_axes,
AxisVector output_axes); AxisVector output_axes);
...@@ -63,8 +66,11 @@ private: ...@@ -63,8 +66,11 @@ private:
class ngraph::runtime::plaidml::op::ConvolutionBackpropData final : public ngraph::op::Op class ngraph::runtime::plaidml::op::ConvolutionBackpropData final : public ngraph::op::Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
ConvolutionBackpropData(std::shared_ptr<ngraph::op::ConvolutionBackpropData> src, ConvolutionBackpropData(std::shared_ptr<ngraph::op::ConvolutionBackpropData> src,
const NodeVector& args, const OutputVector& args,
AxisVector filters_axes, AxisVector filters_axes,
AxisVector output_axes, AxisVector output_axes,
AxisVector data_axes); AxisVector data_axes);
...@@ -87,8 +93,11 @@ private: ...@@ -87,8 +93,11 @@ private:
class ngraph::runtime::plaidml::op::ConvolutionBackpropFilters final : public ngraph::op::Op class ngraph::runtime::plaidml::op::ConvolutionBackpropFilters final : public ngraph::op::Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
ConvolutionBackpropFilters(std::shared_ptr<ngraph::op::ConvolutionBackpropFilters> src, ConvolutionBackpropFilters(std::shared_ptr<ngraph::op::ConvolutionBackpropFilters> src,
const NodeVector& args, const OutputVector& args,
AxisVector data_axes, AxisVector data_axes,
AxisVector output_axes, AxisVector output_axes,
AxisVector filters_axes); AxisVector filters_axes);
......
...@@ -30,9 +30,11 @@ namespace ngraph ...@@ -30,9 +30,11 @@ namespace ngraph
} }
} }
ngraph::runtime::plaidml::op::ImplicitBroadcast::ImplicitBroadcast(std::shared_ptr<Node> input, const std::string ngraph::runtime::plaidml::op::ImplicitBroadcast::type_name{"ImplicitBroadcast"};
ngraph::runtime::plaidml::op::ImplicitBroadcast::ImplicitBroadcast(const Output<Node>& input,
const Shape& shape) const Shape& shape)
: Op{"ImplicitBroadcast", {input}} : Op{{input}}
, m_shape{shape} , m_shape{shape}
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
......
...@@ -40,7 +40,10 @@ namespace ngraph ...@@ -40,7 +40,10 @@ namespace ngraph
class ngraph::runtime::plaidml::op::ImplicitBroadcast final : public ngraph::op::Op class ngraph::runtime::plaidml::op::ImplicitBroadcast final : public ngraph::op::Op
{ {
public: public:
ImplicitBroadcast(std::shared_ptr<Node> input, const Shape& shape); NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
ImplicitBroadcast(const Output<Node>& input, const Shape& shape);
void validate_and_infer_types() final; void validate_and_infer_types() final;
......
...@@ -28,22 +28,24 @@ namespace ngraph ...@@ -28,22 +28,24 @@ namespace ngraph
} }
} }
ngraph::runtime::plaidml::op::Replicate::Replicate(std::shared_ptr<Node> arg, const std::string ngraph::runtime::plaidml::op::Replicate::type_name{"Replicate"};
ngraph::runtime::plaidml::op::Replicate::Replicate(const Output<Node>& arg,
std::size_t replication_axis, std::size_t replication_axis,
std::size_t replication_count) std::size_t replication_count)
: Op{"Replicate", NodeVector{arg}} : Op{{arg}}
, m_replication_axes(arg->get_shape().size(), 1) , m_replication_axes(arg.get_shape().size(), 1)
{ {
m_replication_axes.at(replication_axis) = replication_count; m_replication_axes.at(replication_axis) = replication_count;
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
ngraph::runtime::plaidml::op::Replicate::Replicate(std::shared_ptr<Node> arg, ngraph::runtime::plaidml::op::Replicate::Replicate(const Output<Node>& arg,
std::vector<std::size_t> replication_axes) std::vector<std::size_t> replication_axes)
: Op{"Replicate", NodeVector{arg}} : Op{{arg}}
, m_replication_axes(std::move(replication_axes)) , m_replication_axes(std::move(replication_axes))
{ {
if (arg->get_shape().size() != m_replication_axes.size()) if (arg.get_shape().size() != m_replication_axes.size())
{ {
throw ngraph_error{"Replicate requires compatible axes dimensions"}; throw ngraph_error{"Replicate requires compatible axes dimensions"};
} }
......
...@@ -39,11 +39,12 @@ namespace ngraph ...@@ -39,11 +39,12 @@ namespace ngraph
class ngraph::runtime::plaidml::op::Replicate final : public ngraph::op::Op class ngraph::runtime::plaidml::op::Replicate final : public ngraph::op::Op
{ {
public: public:
Replicate(std::shared_ptr<Node> arg, NGRAPH_API
std::size_t replication_axis, static const std::string type_name;
std::size_t replication_count); const std::string& description() const override { return type_name; }
Replicate(const Output<Node>& arg, std::size_t replication_axis, std::size_t replication_count);
Replicate(std::shared_ptr<Node> arg, std::vector<std::size_t> replication_axes); Replicate(const Output<Node>& arg, std::vector<std::size_t> replication_axes);
void validate_and_infer_types() final; void validate_and_infer_types() final;
......
...@@ -30,9 +30,11 @@ namespace ngraph ...@@ -30,9 +30,11 @@ namespace ngraph
} }
} }
const std::string ngraph::runtime::plaidml::op::Winograd::type_name{"Winograd"};
ngraph::runtime::plaidml::op::Winograd::Winograd(std::shared_ptr<plaidml::op::Convolution> conv, ngraph::runtime::plaidml::op::Winograd::Winograd(std::shared_ptr<plaidml::op::Convolution> conv,
const NodeVector& args) const OutputVector& args)
: Op{"Winograd", args} : Op{args}
, m_conv{std::move(conv)} , m_conv{std::move(conv)}
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
...@@ -50,7 +52,7 @@ std::shared_ptr<ngraph::Node> ...@@ -50,7 +52,7 @@ std::shared_ptr<ngraph::Node>
{ {
throw ngraph_error{"Winograd requires five inputs (data, filters, A, B, and G)"}; throw ngraph_error{"Winograd requires five inputs (data, filters, A, B, and G)"};
} }
return std::make_shared<Winograd>(m_conv, new_args); return std::make_shared<Winograd>(m_conv, as_output_vector(new_args));
} }
void ngraph::runtime::plaidml::ImplWinograd::Apply() void ngraph::runtime::plaidml::ImplWinograd::Apply()
......
...@@ -38,7 +38,10 @@ namespace ngraph ...@@ -38,7 +38,10 @@ namespace ngraph
class ngraph::runtime::plaidml::op::Winograd final : public ngraph::op::Op class ngraph::runtime::plaidml::op::Winograd final : public ngraph::op::Op
{ {
public: public:
Winograd(std::shared_ptr<Convolution> conv, const NodeVector& args); NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Winograd(std::shared_ptr<Convolution> conv, const OutputVector& args);
void validate_and_infer_types() final; void validate_and_infer_types() final;
......
...@@ -97,7 +97,7 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions() ...@@ -97,7 +97,7 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
{ {
replace_node(target, replace_node(target,
std::make_shared<plaidml::op::Convolution>(conv, std::make_shared<plaidml::op::Convolution>(conv,
NodeVector{lhs, rhs}, OutputVector{lhs, rhs},
std::move(lhs_axes), std::move(lhs_axes),
std::move(rhs_axes), std::move(rhs_axes),
std::move(out_axes))); std::move(out_axes)));
...@@ -113,7 +113,7 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions() ...@@ -113,7 +113,7 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
replace_node( replace_node(
target, target,
std::make_shared<plaidml::op::ConvolutionBackpropData>(conv_bp_data, std::make_shared<plaidml::op::ConvolutionBackpropData>(conv_bp_data,
NodeVector{lhs, rhs}, OutputVector{lhs, rhs},
std::move(lhs_axes), std::move(lhs_axes),
std::move(rhs_axes), std::move(rhs_axes),
std::move(out_axes))); std::move(out_axes)));
...@@ -126,13 +126,13 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions() ...@@ -126,13 +126,13 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
std::dynamic_pointer_cast<ngraph::op::ConvolutionBackpropFilters>(node); std::dynamic_pointer_cast<ngraph::op::ConvolutionBackpropFilters>(node);
if (conv_bp_filters) if (conv_bp_filters)
{ {
replace_node( replace_node(target,
target, std::make_shared<plaidml::op::ConvolutionBackpropFilters>(
std::make_shared<plaidml::op::ConvolutionBackpropFilters>(conv_bp_filters, conv_bp_filters,
NodeVector{lhs, rhs}, OutputVector{lhs, rhs},
std::move(lhs_axes), std::move(lhs_axes),
std::move(rhs_axes), std::move(rhs_axes),
std::move(out_axes))); std::move(out_axes)));
return true; return true;
} }
} }
......
...@@ -113,7 +113,11 @@ ngraph::runtime::plaidml::pass::Winograd::Winograd() ...@@ -113,7 +113,11 @@ ngraph::runtime::plaidml::pass::Winograd::Winograd()
auto callback = [](pattern::Matcher& m) { auto callback = [](pattern::Matcher& m) {
auto conv = std::static_pointer_cast<plaidml::op::Convolution>(m.get_match_root()); auto conv = std::static_pointer_cast<plaidml::op::Convolution>(m.get_match_root());
NodeVector args = conv->get_arguments(); OutputVector args;
for (auto input : conv->inputs())
{
args.push_back(input.get_source_output());
}
std::shared_ptr<ngraph::op::Constant> a; std::shared_ptr<ngraph::op::Constant> a;
std::shared_ptr<ngraph::op::Constant> b; std::shared_ptr<ngraph::op::Constant> b;
std::shared_ptr<ngraph::op::Constant> g; std::shared_ptr<ngraph::op::Constant> g;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment