Commit f5b2d581 authored by Rob Earhart, committed by Scott Cyphers

API cleanup & performance passes (#2242)

* nGraph compat updates

* Simplify op structure

* Add zero-dim elimination pass

* Add logical data conversion pass

* Add graphviz support

* Add implicit broadcast pass

* Elide unnecessary reshape ops

* Merge reshapes into convolutions

* Add winograd convolution support

* Allow three-input maxpool backprop

* Add concat elision

* Combine replication operations

* Elide unneeded replicates

* Style update
parent c11644ec
...@@ -41,10 +41,21 @@ set(SRC ...@@ -41,10 +41,21 @@ set(SRC
plaidml_ops_pool.cpp plaidml_ops_pool.cpp
plaidml_ops_reduce.cpp plaidml_ops_reduce.cpp
plaidml_ops_replace_slice.cpp plaidml_ops_replace_slice.cpp
plaidml_ops_replicate.cpp
plaidml_ops_reverse.cpp plaidml_ops_reverse.cpp
plaidml_ops_slice.cpp plaidml_ops_slice.cpp
plaidml_ops_softmax.cpp plaidml_ops_softmax.cpp
plaidml_ops_tile.cpp
plaidml_ops_transcendental.cpp plaidml_ops_transcendental.cpp
plaidml_ops_winograd.cpp
plaidml_pass_concat_elision.cpp
plaidml_pass_explicit_logicals.cpp
plaidml_pass_implicit_broadcast.cpp
plaidml_pass_lower_convolutions.cpp
plaidml_pass_replicate_combination.cpp
plaidml_pass_replicate_elision.cpp
plaidml_pass_reshape_elision.cpp
plaidml_pass_winograd.cpp
plaidml_tensor.cpp plaidml_tensor.cpp
plaidml_translate.cpp plaidml_translate.cpp
) )
......
...@@ -41,10 +41,11 @@ std::shared_ptr<ngraph::runtime::Tensor> ngraph::runtime::plaidml::PlaidML_Backe ...@@ -41,10 +41,11 @@ std::shared_ptr<ngraph::runtime::Tensor> ngraph::runtime::plaidml::PlaidML_Backe
&m_config, element_type, shape, "direct_data", memory_pointer); &m_config, element_type, shape, "direct_data", memory_pointer);
} }
bool ngraph::runtime::plaidml::PlaidML_Backend::compile(std::shared_ptr<Function> func) std::shared_ptr<ngraph::Function>
ngraph::runtime::plaidml::PlaidML_Backend::compile(std::shared_ptr<Function> func)
{ {
m_cache.compile(func, &m_compiler); m_cache.compile(func, &m_compiler);
return true; return func;
} }
bool ngraph::runtime::plaidml::PlaidML_Backend::call( bool ngraph::runtime::plaidml::PlaidML_Backend::call(
......
...@@ -46,7 +46,7 @@ public: ...@@ -46,7 +46,7 @@ public:
std::shared_ptr<ngraph::runtime::Tensor> create_tensor( std::shared_ptr<ngraph::runtime::Tensor> create_tensor(
const ngraph::element::Type& element_type, const Shape& shape, void* memory_pointer) final; const ngraph::element::Type& element_type, const Shape& shape, void* memory_pointer) final;
bool compile(std::shared_ptr<Function> func) final; std::shared_ptr<Function> compile(std::shared_ptr<Function> func) final;
bool call(std::shared_ptr<Function> func, bool call(std::shared_ptr<Function> func,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs, const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
......
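The backend's compile() entry point now returns the function handle instead of a bool. A minimal sketch of how a caller might drive the updated interface, assuming the standard nGraph Backend factory; the identifiers (backend, func, outputs, inputs) are illustrative, not taken from this commit:

// Sketch only: function construction and tensor setup omitted.
auto backend = ngraph::runtime::Backend::create("PlaidML");

// compile() now hands back the compiled function rather than a bool...
std::shared_ptr<ngraph::Function> handle = backend->compile(func);

// ...and that handle is what subsequently gets passed to call().
backend->call(handle, outputs, inputs);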
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <stdexcept> #include <stdexcept>
#include <utility> #include <utility>
#include "ngraph/except.hpp"
#include "ngraph/runtime/plaidml/plaidml_builder.hpp" #include "ngraph/runtime/plaidml/plaidml_builder.hpp"
#include "ngraph/runtime/plaidml/plaidml_logger.hpp" #include "ngraph/runtime/plaidml/plaidml_logger.hpp"
...@@ -467,6 +468,24 @@ ngraph::runtime::plaidml::builder::Input&& ...@@ -467,6 +468,24 @@ ngraph::runtime::plaidml::builder::Input&&
return std::move(*this); return std::move(*this);
} }
void ngraph::runtime::plaidml::builder::Input::apply_transpose(const AxisVector& axes)
{
if (axes.size() != m_dims.size())
{
throw ngraph_error{"Mismatched shape in input transposition"};
}
std::size_t idx = 0;
std::vector<std::string> dims(m_dims.size());
for (auto dim : m_dims)
{
dims[axes[idx]] = dim;
idx++;
}
m_dims.clear();
m_dims.insert(m_dims.end(), dims.begin(), dims.end());
}
ngraph::runtime::plaidml::builder::Output::Output(std::string name) ngraph::runtime::plaidml::builder::Output::Output(std::string name)
: m_name{std::move(name)} : m_name{std::move(name)}
{ {
...@@ -611,6 +630,28 @@ ngraph::runtime::plaidml::builder::ContractionOutput&& ...@@ -611,6 +630,28 @@ ngraph::runtime::plaidml::builder::ContractionOutput&&
return std::move(*this); return std::move(*this);
} }
void ngraph::runtime::plaidml::builder::ContractionOutput::apply_transpose(const AxisVector& axes)
{
if (axes.size() != m_dims.size() || axes.size() != m_indices.size())
{
throw ngraph_error{"Mismatched shape in contraction output transposition"};
}
std::vector<std::string> dims{m_dims.begin(), m_dims.end()};
m_dims.clear();
for (auto idx : axes)
{
m_dims.emplace_back(dims[idx]);
}
std::vector<std::string> indices{m_indices.begin(), m_indices.end()};
m_indices.clear();
for (auto idx : axes)
{
m_indices.emplace_back(indices[idx]);
}
}
ngraph::runtime::plaidml::builder::ContractionInput& ngraph::runtime::plaidml::builder::ContractionInput&
ngraph::runtime::plaidml::builder::ContractionInput::add_indices(std::string prefix, ngraph::runtime::plaidml::builder::ContractionInput::add_indices(std::string prefix,
std::size_t first, std::size_t first,
...@@ -675,6 +716,24 @@ ngraph::runtime::plaidml::builder::ContractionInput&& ...@@ -675,6 +716,24 @@ ngraph::runtime::plaidml::builder::ContractionInput&&
return std::move(*this); return std::move(*this);
} }
void ngraph::runtime::plaidml::builder::ContractionInput::apply_transpose(const AxisVector& axes)
{
if (axes.size() != m_indices.size())
{
throw ngraph_error{"Mismatched shape in contraction input transposition"};
}
std::size_t idx = 0;
std::vector<std::string> indices(m_indices.size());
for (auto dim : m_indices)
{
indices[axes[idx]] = dim;
idx++;
}
m_indices.clear();
m_indices.insert(m_indices.end(), indices.begin(), indices.end());
}
ngraph::runtime::plaidml::builder::UnaryContraction::UnaryContraction(std::string agg_op) ngraph::runtime::plaidml::builder::UnaryContraction::UnaryContraction(std::string agg_op)
: m_agg_op{std::move(agg_op)} : m_agg_op{std::move(agg_op)}
{ {
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include "ngraph/axis_vector.hpp"
#include "ngraph/runtime/plaidml/plaidml_config.hpp" #include "ngraph/runtime/plaidml/plaidml_config.hpp"
// Utilities for constructing PlaidML functions. // Utilities for constructing PlaidML functions.
...@@ -136,9 +137,22 @@ public: ...@@ -136,9 +137,22 @@ public:
return std::move(*this); return std::move(*this);
} }
Input& transpose(const AxisVector& axes) &
{
apply_transpose(axes);
return *this;
}
Input&& transpose(const AxisVector& axes) &&
{
apply_transpose(axes);
return std::move(*this);
}
private: private:
friend class Function; friend class Function;
void apply_transpose(const AxisVector& axes);
vertexai::plaidml::variable m_var; vertexai::plaidml::variable m_var;
std::string m_name; std::string m_name;
std::list<std::string> m_dims; std::list<std::string> m_dims;
...@@ -230,9 +244,22 @@ public: ...@@ -230,9 +244,22 @@ public:
return std::move(*this); return std::move(*this);
} }
ContractionOutput& transpose(const AxisVector& axes) &
{
apply_transpose(axes);
return *this;
}
ContractionOutput&& transpose(const AxisVector& axes) &&
{
apply_transpose(axes);
return std::move(*this);
}
private: private:
friend class Function; friend class Function;
void apply_transpose(const AxisVector& axes);
std::string m_name; std::string m_name;
std::list<std::string> m_indices; std::list<std::string> m_indices;
std::list<std::string> m_dims; std::list<std::string> m_dims;
...@@ -268,9 +295,22 @@ public: ...@@ -268,9 +295,22 @@ public:
return std::move(*this); return std::move(*this);
} }
ContractionInput& transpose(const AxisVector& axes) &
{
apply_transpose(axes);
return *this;
}
ContractionInput&& transpose(const AxisVector& axes) &&
{
apply_transpose(axes);
return std::move(*this);
}
private: private:
friend class Function; friend class Function;
void apply_transpose(const AxisVector& axes);
std::string m_name; std::string m_name;
std::list<std::string> m_indices; std::list<std::string> m_indices;
}; };
......
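The new ref-qualified transpose(const AxisVector&) overloads keep the fluent builder usable on both lvalues and temporaries. A minimal sketch of reordering a Tile-function input, following the builder pattern used elsewhere in this backend; "f" (a builder::Function under construction), "var", and the axis order are illustrative assumptions:

// Hypothetical usage; the axis order shown is an arbitrary 4-d permutation.
f.add(builder::Input{var, "I"}
          .add_dims("D", 0, 4)
          .transpose(AxisVector{0, 2, 3, 1}));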
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/runtime/plaidml/plaidml_compiler.hpp" #include "ngraph/runtime/plaidml/plaidml_compiler.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/pass/algebraic_simplification.hpp" #include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/any_all_replacement.hpp" #include "ngraph/pass/any_all_replacement.hpp"
...@@ -25,8 +26,18 @@ ...@@ -25,8 +26,18 @@
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/nop_elimination.hpp" #include "ngraph/pass/nop_elimination.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pass/zero_dim_tensor_elimination.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp" #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_logger.hpp" #include "ngraph/runtime/plaidml/plaidml_logger.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_concat_elision.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_explicit_logicals.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_lower_convolutions.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_combination.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_elision.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_reshape_elision.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_winograd.hpp"
namespace namespace
{ {
...@@ -78,15 +89,44 @@ std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction> ...@@ -78,15 +89,44 @@ std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction>
pass_manager.register_pass<ngraph::pass::AnyAllReplacement>(); pass_manager.register_pass<ngraph::pass::AnyAllReplacement>();
pass_manager.register_pass<ngraph::pass::LikeReplacement>(); pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<ngraph::pass::NopElimination>(); pass_manager.register_pass<ngraph::pass::NopElimination>();
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>(); pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>(); pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.register_pass<ngraph::pass::CoreFusion>(); pass_manager.register_pass<ngraph::pass::CoreFusion>();
// N.B. We'd like to register ngraph::pass::GetOutputElementElimination, but it breaks BatchNorm // N.B. We'd like to register ngraph::pass::GetOutputElementElimination, but it breaks BatchNorm
// backprop // backprop
pass_manager.register_pass<ngraph::pass::Liveness>(); pass_manager.register_pass<ngraph::pass::Liveness>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ExplicitLogicals>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ConcatElision>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateElision>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateCombination>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ImplicitBroadcast>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReshapeElision>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::LowerConvolutions>();
if (m_config->winograd)
{
pass_manager.register_pass<ngraph::runtime::plaidml::pass::Winograd>();
}
if (!m_config->graphviz.empty())
{
pass_manager.register_pass<ngraph::pass::VisualizeTree>(m_config->graphviz);
}
// N.B. When we rewrite the graph, there are cases where we
// produce nodes that contain validation errors. A good example
// is in the ImplicitBroadcast pass -- after this pass, there may
// be elementwise operations whose inputs are not all the same
// shape.
//
// The caller may wish to perform operations (e.g. clone) on their
// supplied function that will cause validation to occur. So
// before we rewrite, we make our own copy of the function.
func = clone_function(*func);
// Apply passes.
pass_manager.run_passes(func); pass_manager.run_passes(func);
// Compile the resulting function.
Build b; Build b;
build(std::move(func), &b); build(std::move(func), &b);
return std::make_shared<CompiledFunction>(std::move(b)); return std::make_shared<CompiledFunction>(std::move(b));
...@@ -98,7 +138,7 @@ void ngraph::runtime::plaidml::Compiler::build(std::shared_ptr<Function> func, B ...@@ -98,7 +138,7 @@ void ngraph::runtime::plaidml::Compiler::build(std::shared_ptr<Function> func, B
b->config = m_config; b->config = m_config;
b->func = func; b->func = func;
const auto* op_map = OpImplMap(); const auto* op_map = GlobalOpImplMap();
for (const auto& op_ptr : func->get_ordered_ops()) for (const auto& op_ptr : func->get_ordered_ops())
{ {
...@@ -114,6 +154,6 @@ void ngraph::runtime::plaidml::Compiler::build(std::shared_ptr<Function> func, B ...@@ -114,6 +154,6 @@ void ngraph::runtime::plaidml::Compiler::build(std::shared_ptr<Function> func, B
std::string{"The PlaidML backend doesn't currently implement the '"} + std::string{"The PlaidML backend doesn't currently implement the '"} +
op->description() + "' operation"}; op->description() + "' operation"};
} }
it->second(b, *op); it->second->Apply(b, op);
} }
} }
...@@ -77,8 +77,10 @@ ngraph::runtime::plaidml::Config ...@@ -77,8 +77,10 @@ ngraph::runtime::plaidml::Config
bool help = false; bool help = false;
bool list = false; bool list = false;
bool debug = false; bool debug = false;
bool winograd = false;
std::size_t device_idx = 0; std::size_t device_idx = 0;
std::string eventlog_config; std::string eventlog_config;
std::string graphviz;
#ifdef NGRAPH_DEBUG_ENABLE #ifdef NGRAPH_DEBUG_ENABLE
debug = true; debug = true;
...@@ -155,7 +157,7 @@ ngraph::runtime::plaidml::Config ...@@ -155,7 +157,7 @@ ngraph::runtime::plaidml::Config
return (oname_end - oname_begin == len) && !strncmp(oname_begin, opt, len); return (oname_end - oname_begin == len) && !strncmp(oname_begin, opt, len);
}; };
auto oval_len = oval_end - oval_begin; std::size_t oval_len = oval_end - oval_begin;
bool has_oval = oval_begin != oname_end; bool has_oval = oval_begin != oname_end;
// N.B. oval_len != 0 => has_oval, but there's no other relationship. // N.B. oval_len != 0 => has_oval, but there's no other relationship.
...@@ -229,6 +231,25 @@ ngraph::runtime::plaidml::Config ...@@ -229,6 +231,25 @@ ngraph::runtime::plaidml::Config
continue; continue;
} }
// Check for visualization (GraphViz output)
if (is_opt("graphviz"))
{
if (!oval_len)
{
throw std::invalid_argument{"PlaidML graphviz requires a value"};
}
graphviz = std::string{oval_begin, oval_len};
continue;
}
// Check for Winograd. (Winograd is sometimes a performance
// boost, but not always, so we make it optional.)
if (is_opt("winograd"))
{
winograd = true;
continue;
}
// Reject unknown options // Reject unknown options
err = true; err = true;
} }
...@@ -236,7 +257,7 @@ ngraph::runtime::plaidml::Config ...@@ -236,7 +257,7 @@ ngraph::runtime::plaidml::Config
constexpr char help_text[] = constexpr char help_text[] =
"PlaidML Backend Specification: \"" "PlaidML Backend Specification: \""
"PlaidML[:[device_index][,debug][,help][,list_devices][," "PlaidML[:[device_index][,debug][,help][,list_devices][,"
"eventlog=<filename>]]\". For example: \"PlaidML\", \"" "eventlog=<filename>][,graphviz=<filename>][,winograd]]\". For example: \"PlaidML\", \""
"PlaidML:0,list_devices\""; "PlaidML:0,list_devices\"";
if (err) if (err)
{ {
...@@ -269,5 +290,9 @@ ngraph::runtime::plaidml::Config ...@@ -269,5 +290,9 @@ ngraph::runtime::plaidml::Config
result.debug = debug; result.debug = debug;
result.graphviz = graphviz;
result.winograd = winograd;
return result; return result;
} }
...@@ -39,4 +39,6 @@ struct ngraph::runtime::plaidml::Config ...@@ -39,4 +39,6 @@ struct ngraph::runtime::plaidml::Config
std::shared_ptr<vertexai::ctx> ctx; std::shared_ptr<vertexai::ctx> ctx;
std::shared_ptr<vertexai::plaidml::device> dev; std::shared_ptr<vertexai::plaidml::device> dev;
bool debug; bool debug;
bool winograd;
std::string graphviz;
}; };
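With the parser changes and the two new Config fields above, the backend specification string accepts graphviz=<filename> and winograd alongside the existing options, as reflected in the updated help text. An illustrative example of selecting them at backend-creation time (the device index and output filename are arbitrary):

// Device 0, Winograd convolutions enabled, and a GraphViz dump of the
// rewritten graph written to "plaidml_graph.dot".
auto backend =
    ngraph::runtime::Backend::create("PlaidML:0,winograd,graphviz=plaidml_graph.dot");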
...@@ -171,7 +171,7 @@ ngraph::runtime::plaidml::ConvPoolFormatter::ConvPoolFormatter( ...@@ -171,7 +171,7 @@ ngraph::runtime::plaidml::ConvPoolFormatter::ConvPoolFormatter(
} }
ngraph::runtime::plaidml::builder::Input ngraph::runtime::plaidml::builder::Input
ngraph::runtime::plaidml::ConvPoolFormatter::F_in_header(vertexai::plaidml::variable var) ngraph::runtime::plaidml::ConvPoolFormatter::F_in_header(vertexai::plaidml::variable var) const
{ {
if (m_op != OpType::Conv) if (m_op != OpType::Conv)
{ {
...@@ -191,7 +191,7 @@ ngraph::runtime::plaidml::builder::Input ...@@ -191,7 +191,7 @@ ngraph::runtime::plaidml::builder::Input
} }
ngraph::runtime::plaidml::builder::Input ngraph::runtime::plaidml::builder::Input
ngraph::runtime::plaidml::ConvPoolFormatter::I_in_header(vertexai::plaidml::variable var) ngraph::runtime::plaidml::ConvPoolFormatter::I_in_header(vertexai::plaidml::variable var) const
{ {
if (m_deriv == DerivType::Data && m_op == OpType::Conv) if (m_deriv == DerivType::Data && m_op == OpType::Conv)
{ {
...@@ -216,7 +216,7 @@ ngraph::runtime::plaidml::builder::Input ...@@ -216,7 +216,7 @@ ngraph::runtime::plaidml::builder::Input
} }
ngraph::runtime::plaidml::builder::Input ngraph::runtime::plaidml::builder::Input
ngraph::runtime::plaidml::ConvPoolFormatter::O_in_header(vertexai::plaidml::variable var) ngraph::runtime::plaidml::ConvPoolFormatter::O_in_header(vertexai::plaidml::variable var) const
{ {
if (m_deriv == DerivType::None) if (m_deriv == DerivType::None)
{ {
...@@ -240,7 +240,7 @@ ngraph::runtime::plaidml::builder::Input ...@@ -240,7 +240,7 @@ ngraph::runtime::plaidml::builder::Input
} }
ngraph::runtime::plaidml::builder::Output ngraph::runtime::plaidml::builder::Output
ngraph::runtime::plaidml::ConvPoolFormatter::F_out_header() ngraph::runtime::plaidml::ConvPoolFormatter::F_out_header() const
{ {
if (m_op != OpType::Conv) if (m_op != OpType::Conv)
{ {
...@@ -254,7 +254,7 @@ ngraph::runtime::plaidml::builder::Output ...@@ -254,7 +254,7 @@ ngraph::runtime::plaidml::builder::Output
} }
ngraph::runtime::plaidml::builder::Output ngraph::runtime::plaidml::builder::Output
ngraph::runtime::plaidml::ConvPoolFormatter::I_out_header() ngraph::runtime::plaidml::ConvPoolFormatter::I_out_header() const
{ {
if (m_deriv != DerivType::Data) if (m_deriv != DerivType::Data)
{ {
...@@ -272,7 +272,7 @@ ngraph::runtime::plaidml::builder::Output ...@@ -272,7 +272,7 @@ ngraph::runtime::plaidml::builder::Output
} }
ngraph::runtime::plaidml::builder::Output ngraph::runtime::plaidml::builder::Output
ngraph::runtime::plaidml::ConvPoolFormatter::O_out_header() ngraph::runtime::plaidml::ConvPoolFormatter::O_out_header() const
{ {
if (m_deriv != DerivType::None) if (m_deriv != DerivType::None)
{ {
...@@ -282,7 +282,7 @@ ngraph::runtime::plaidml::builder::Output ...@@ -282,7 +282,7 @@ ngraph::runtime::plaidml::builder::Output
} }
ngraph::runtime::plaidml::builder::ContractionOutput ngraph::runtime::plaidml::builder::ContractionOutput
ngraph::runtime::plaidml::ConvPoolFormatter::F_out_body() ngraph::runtime::plaidml::ConvPoolFormatter::F_out_body() const
{ {
if (m_op != OpType::Conv) if (m_op != OpType::Conv)
{ {
...@@ -307,7 +307,7 @@ ngraph::runtime::plaidml::builder::ContractionOutput ...@@ -307,7 +307,7 @@ ngraph::runtime::plaidml::builder::ContractionOutput
} }
ngraph::runtime::plaidml::builder::ContractionOutput ngraph::runtime::plaidml::builder::ContractionOutput
ngraph::runtime::plaidml::ConvPoolFormatter::I_out_body() ngraph::runtime::plaidml::ConvPoolFormatter::I_out_body() const
{ {
if (m_deriv != DerivType::Data) if (m_deriv != DerivType::Data)
{ {
...@@ -353,7 +353,7 @@ ngraph::runtime::plaidml::builder::ContractionOutput ...@@ -353,7 +353,7 @@ ngraph::runtime::plaidml::builder::ContractionOutput
} }
ngraph::runtime::plaidml::builder::ContractionOutput ngraph::runtime::plaidml::builder::ContractionOutput
ngraph::runtime::plaidml::ConvPoolFormatter::O_out_body() ngraph::runtime::plaidml::ConvPoolFormatter::O_out_body() const
{ {
if (m_deriv != DerivType::None && m_op == OpType::Conv) if (m_deriv != DerivType::None && m_op == OpType::Conv)
{ {
...@@ -405,7 +405,7 @@ ngraph::runtime::plaidml::builder::ContractionOutput ...@@ -405,7 +405,7 @@ ngraph::runtime::plaidml::builder::ContractionOutput
} }
ngraph::runtime::plaidml::builder::ContractionInput ngraph::runtime::plaidml::builder::ContractionInput
ngraph::runtime::plaidml::ConvPoolFormatter::F_in_body() ngraph::runtime::plaidml::ConvPoolFormatter::F_in_body() const
{ {
if (m_op != OpType::Conv) if (m_op != OpType::Conv)
{ {
...@@ -425,7 +425,7 @@ ngraph::runtime::plaidml::builder::ContractionInput ...@@ -425,7 +425,7 @@ ngraph::runtime::plaidml::builder::ContractionInput
} }
ngraph::runtime::plaidml::builder::ContractionInput ngraph::runtime::plaidml::builder::ContractionInput
ngraph::runtime::plaidml::ConvPoolFormatter::I_in_body() ngraph::runtime::plaidml::ConvPoolFormatter::I_in_body() const
{ {
if (m_deriv == DerivType::Data && m_op == OpType::Conv) if (m_deriv == DerivType::Data && m_op == OpType::Conv)
{ {
...@@ -449,7 +449,7 @@ ngraph::runtime::plaidml::builder::ContractionInput ...@@ -449,7 +449,7 @@ ngraph::runtime::plaidml::builder::ContractionInput
} }
ngraph::runtime::plaidml::builder::ContractionInput ngraph::runtime::plaidml::builder::ContractionInput
ngraph::runtime::plaidml::ConvPoolFormatter::O_in_body() ngraph::runtime::plaidml::ConvPoolFormatter::O_in_body() const
{ {
if (m_deriv == DerivType::None) if (m_deriv == DerivType::None)
{ {
...@@ -482,7 +482,7 @@ ngraph::runtime::plaidml::builder::ContractionInput ...@@ -482,7 +482,7 @@ ngraph::runtime::plaidml::builder::ContractionInput
} }
ngraph::runtime::plaidml::builder::UnaryContraction ngraph::runtime::plaidml::builder::UnaryContraction
ngraph::runtime::plaidml::ConvPoolFormatter::Broadcast_Ones() ngraph::runtime::plaidml::ConvPoolFormatter::Broadcast_Ones() const
{ {
if (m_op != OpType::AvgPool) if (m_op != OpType::AvgPool)
{ {
...@@ -501,7 +501,7 @@ ngraph::runtime::plaidml::builder::UnaryContraction ...@@ -501,7 +501,7 @@ ngraph::runtime::plaidml::builder::UnaryContraction
} }
ngraph::runtime::plaidml::builder::UnaryContraction ngraph::runtime::plaidml::builder::UnaryContraction
ngraph::runtime::plaidml::ConvPoolFormatter::Count() ngraph::runtime::plaidml::ConvPoolFormatter::Count() const
{ {
if (m_op != OpType::AvgPool) if (m_op != OpType::AvgPool)
{ {
...@@ -535,7 +535,7 @@ ngraph::runtime::plaidml::builder::UnaryContraction ...@@ -535,7 +535,7 @@ ngraph::runtime::plaidml::builder::UnaryContraction
} }
ngraph::runtime::plaidml::builder::UnaryContraction ngraph::runtime::plaidml::builder::UnaryContraction
ngraph::runtime::plaidml::ConvPoolFormatter::PoolContraction() ngraph::runtime::plaidml::ConvPoolFormatter::PoolContraction() const
{ {
std::string agg_op; std::string agg_op;
switch (m_op) switch (m_op)
...@@ -558,7 +558,7 @@ ngraph::runtime::plaidml::builder::UnaryContraction ...@@ -558,7 +558,7 @@ ngraph::runtime::plaidml::builder::UnaryContraction
} }
ngraph::runtime::plaidml::builder::TernaryContraction ngraph::runtime::plaidml::builder::TernaryContraction
ngraph::runtime::plaidml::ConvPoolFormatter::PoolDerivContraction() ngraph::runtime::plaidml::ConvPoolFormatter::PoolDerivContraction() const
{ {
builder::ContractionOutput output{"DI"}; builder::ContractionOutput output{"DI"};
output.add_indices({n(), c()}).add_dims({N(), C()}); output.add_indices({n(), c()}).add_dims({N(), C()});
...@@ -595,27 +595,27 @@ ngraph::runtime::plaidml::builder::TernaryContraction ...@@ -595,27 +595,27 @@ ngraph::runtime::plaidml::builder::TernaryContraction
.set_third(incoming_deriv); .set_third(incoming_deriv);
} }
std::string ngraph::runtime::plaidml::ConvPoolFormatter::c() std::string ngraph::runtime::plaidml::ConvPoolFormatter::c() const
{ {
return "c"; return "c";
} }
std::string ngraph::runtime::plaidml::ConvPoolFormatter::ci() std::string ngraph::runtime::plaidml::ConvPoolFormatter::ci() const
{ {
return "ci"; return "ci";
} }
std::string ngraph::runtime::plaidml::ConvPoolFormatter::co() std::string ngraph::runtime::plaidml::ConvPoolFormatter::co() const
{ {
return "co"; return "co";
} }
std::string ngraph::runtime::plaidml::ConvPoolFormatter::n() std::string ngraph::runtime::plaidml::ConvPoolFormatter::n() const
{ {
return "n"; return "n";
} }
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xfs() std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xfs() const
{ {
if (m_xfs.empty()) if (m_xfs.empty())
{ {
...@@ -629,7 +629,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xfs() ...@@ -629,7 +629,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xfs()
return m_xfs; return m_xfs;
} }
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xis() std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xis() const
{ {
if (m_xis.empty()) if (m_xis.empty())
{ {
...@@ -652,7 +652,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xis() ...@@ -652,7 +652,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xis()
return m_xis; return m_xis;
} }
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xos() std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xos() const
{ {
if (m_xos.empty()) if (m_xos.empty())
{ {
...@@ -666,27 +666,27 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xos() ...@@ -666,27 +666,27 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xos()
return m_xos; return m_xos;
} }
std::string ngraph::runtime::plaidml::ConvPoolFormatter::C() std::string ngraph::runtime::plaidml::ConvPoolFormatter::C() const
{ {
return "C"; return "C";
} }
std::string ngraph::runtime::plaidml::ConvPoolFormatter::CI() std::string ngraph::runtime::plaidml::ConvPoolFormatter::CI() const
{ {
return "CI"; return "CI";
} }
std::string ngraph::runtime::plaidml::ConvPoolFormatter::CO() std::string ngraph::runtime::plaidml::ConvPoolFormatter::CO() const
{ {
return "CO"; return "CO";
} }
std::string ngraph::runtime::plaidml::ConvPoolFormatter::N() std::string ngraph::runtime::plaidml::ConvPoolFormatter::N() const
{ {
return "N"; return "N";
} }
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XFs() std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XFs() const
{ {
if (m_XFs.empty()) if (m_XFs.empty())
{ {
...@@ -707,7 +707,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XFs() ...@@ -707,7 +707,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XFs()
return m_XFs; return m_XFs;
} }
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XIs() std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XIs() const
{ {
if (m_XIs.empty()) if (m_XIs.empty())
{ {
...@@ -728,7 +728,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XIs() ...@@ -728,7 +728,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XIs()
return m_XIs; return m_XIs;
} }
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XOs() std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XOs() const
{ {
if (m_XOs.empty()) if (m_XOs.empty())
{ {
...@@ -765,7 +765,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XOs() ...@@ -765,7 +765,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XOs()
return m_XOs; return m_XOs;
} }
std::string ngraph::runtime::plaidml::ConvPoolFormatter::F() std::string ngraph::runtime::plaidml::ConvPoolFormatter::F() const
{ {
if (m_deriv == DerivType::Filter) if (m_deriv == DerivType::Filter)
{ {
...@@ -774,7 +774,7 @@ std::string ngraph::runtime::plaidml::ConvPoolFormatter::F() ...@@ -774,7 +774,7 @@ std::string ngraph::runtime::plaidml::ConvPoolFormatter::F()
return "F"; return "F";
} }
std::string ngraph::runtime::plaidml::ConvPoolFormatter::I() std::string ngraph::runtime::plaidml::ConvPoolFormatter::I() const
{ {
if (m_deriv == DerivType::Data && m_op == OpType::Conv) if (m_deriv == DerivType::Data && m_op == OpType::Conv)
{ {
...@@ -783,7 +783,7 @@ std::string ngraph::runtime::plaidml::ConvPoolFormatter::I() ...@@ -783,7 +783,7 @@ std::string ngraph::runtime::plaidml::ConvPoolFormatter::I()
return "I"; return "I";
} }
std::string ngraph::runtime::plaidml::ConvPoolFormatter::O() std::string ngraph::runtime::plaidml::ConvPoolFormatter::O() const
{ {
if (m_deriv != DerivType::None) if (m_deriv != DerivType::None)
{ {
......
...@@ -72,47 +72,47 @@ public: ...@@ -72,47 +72,47 @@ public:
ConvPoolFormatter::DerivType deriv); ConvPoolFormatter::DerivType deriv);
// Formatted tensors // Formatted tensors
builder::Input F_in_header(vertexai::plaidml::variable var); builder::Input F_in_header(vertexai::plaidml::variable var) const;
builder::Input I_in_header(vertexai::plaidml::variable var); builder::Input I_in_header(vertexai::plaidml::variable var) const;
builder::Input O_in_header(vertexai::plaidml::variable var); builder::Input O_in_header(vertexai::plaidml::variable var) const;
builder::Output F_out_header(); builder::Output F_out_header() const;
builder::Output I_out_header(); builder::Output I_out_header() const;
builder::Output O_out_header(); builder::Output O_out_header() const;
builder::ContractionOutput F_out_body(); builder::ContractionOutput F_out_body() const;
builder::ContractionOutput I_out_body(); builder::ContractionOutput I_out_body() const;
builder::ContractionOutput O_out_body(); builder::ContractionOutput O_out_body() const;
builder::ContractionInput F_in_body(); builder::ContractionInput F_in_body() const;
builder::ContractionInput I_in_body(); builder::ContractionInput I_in_body() const;
builder::ContractionInput O_in_body(); builder::ContractionInput O_in_body() const;
// Special Operations // Special Operations
builder::UnaryContraction Broadcast_Ones(); builder::UnaryContraction Broadcast_Ones() const;
builder::UnaryContraction Count(); builder::UnaryContraction Count() const;
builder::UnaryContraction PoolContraction(); builder::UnaryContraction PoolContraction() const;
builder::TernaryContraction PoolDerivContraction(); builder::TernaryContraction PoolDerivContraction() const;
// Index names / formulas // Index names / formulas
std::string c(); std::string c() const;
std::string ci(); std::string ci() const;
std::string co(); std::string co() const;
std::string n(); std::string n() const;
std::vector<std::string> xfs(); std::vector<std::string> xfs() const;
std::vector<std::string> xis(); std::vector<std::string> xis() const;
std::vector<std::string> xos(); std::vector<std::string> xos() const;
// Dimension names / formulas // Dimension names / formulas
std::string C(); std::string C() const;
std::string CI(); std::string CI() const;
std::string CO(); std::string CO() const;
std::string N(); std::string N() const;
std::vector<std::string> XFs(); std::vector<std::string> XFs() const;
std::vector<std::string> XIs(); std::vector<std::string> XIs() const;
std::vector<std::string> XOs(); std::vector<std::string> XOs() const;
// Tensor names // Tensor names
std::string F(); std::string F() const;
std::string I(); std::string I() const;
std::string O(); std::string O() const;
private: private:
std::size_t m_rank; std::size_t m_rank;
...@@ -126,10 +126,10 @@ private: ...@@ -126,10 +126,10 @@ private:
DerivType m_deriv = DerivType::None; DerivType m_deriv = DerivType::None;
ngraph::Shape m_filters_shape; ngraph::Shape m_filters_shape;
ngraph::Shape m_data_batch_shape; ngraph::Shape m_data_batch_shape;
std::vector<std::string> m_xfs; mutable std::vector<std::string> m_xfs;
std::vector<std::string> m_xis; mutable std::vector<std::string> m_xis;
std::vector<std::string> m_xos; mutable std::vector<std::string> m_xos;
std::vector<std::string> m_XFs; mutable std::vector<std::string> m_XFs;
std::vector<std::string> m_XIs; mutable std::vector<std::string> m_XIs;
std::vector<std::string> m_XOs; mutable std::vector<std::string> m_XOs;
}; };
...@@ -16,20 +16,8 @@ ...@@ -16,20 +16,8 @@
#include "ngraph/runtime/plaidml/plaidml_impl.hpp" #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
namespace ngraph ngraph::runtime::plaidml::OpImplMap* ngraph::runtime::plaidml::GlobalOpImplMap()
{ {
namespace runtime static OpImplMap op_impl_map;
{
namespace plaidml
{
std::unordered_map<std::type_index, std::function<void(Build*, const ngraph::Node&)>>*
OpImplMap()
{
static std::unordered_map<std::type_index,
std::function<void(Build*, const ngraph::Node&)>>
op_impl_map;
return &op_impl_map; return &op_impl_map;
}
}
}
} }
...@@ -34,10 +34,26 @@ namespace ngraph ...@@ -34,10 +34,26 @@ namespace ngraph
{ {
namespace plaidml namespace plaidml
{ {
// Abs performs a simple elementwise absolute value. NGRAPH_PLAIDML_OP_CLASS(ImplAbs, OpImpl<op::Abs>);
template <> NGRAPH_PLAIDML_OP_CLASS(ImplAdd, OpImpl<op::Add>);
void Impl<op::Abs>::operator()() NGRAPH_PLAIDML_OP_CLASS(ImplCeiling, OpImpl<op::Ceiling>);
{ NGRAPH_PLAIDML_OP_CLASS(ImplDivide, OpImpl<op::Divide>);
NGRAPH_PLAIDML_OP_CLASS(ImplFloor, OpImpl<op::Floor>);
NGRAPH_PLAIDML_OP_CLASS(ImplMultiply, OpImpl<op::Multiply>);
NGRAPH_PLAIDML_OP_CLASS(ImplNegative, OpImpl<op::Negative>);
NGRAPH_PLAIDML_OP_CLASS(ImplRelu, OpImpl<op::Relu>);
NGRAPH_PLAIDML_OP_CLASS(ImplReluBackprop, OpImpl<op::ReluBackprop>);
NGRAPH_PLAIDML_OP_CLASS(ImplSigmoid, OpImpl<op::Sigmoid>);
NGRAPH_PLAIDML_OP_CLASS(ImplSigmoidBackprop, OpImpl<op::SigmoidBackprop>);
NGRAPH_PLAIDML_OP_CLASS(ImplSign, OpImpl<op::Sign>);
NGRAPH_PLAIDML_OP_CLASS(ImplSubtract, OpImpl<op::Subtract>);
}
}
}
// Abs performs a simple elementwise absolute value.
void ngraph::runtime::plaidml::ImplAbs::Apply()
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -45,12 +61,11 @@ namespace ngraph ...@@ -45,12 +61,11 @@ namespace ngraph
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "abs(I)"}) .add(builder::Elementwise{"O", "abs(I)"})
.finalize()); .finalize());
} }
// Add performs a simple elementwise addition. // Add performs a simple elementwise addition.
template <> void ngraph::runtime::plaidml::ImplAdd::Apply()
void Impl<op::Add>::operator()() {
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -59,12 +74,11 @@ namespace ngraph ...@@ -59,12 +74,11 @@ namespace ngraph
.add(builder::Output{"C"}) .add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A + B"}) .add(builder::Elementwise{"C", "A + B"})
.finalize()); .finalize());
} }
// Ceiling performs a simple elementwise ceiling. // Ceiling performs a simple elementwise ceiling.
template <> void ngraph::runtime::plaidml::ImplCeiling::Apply()
void Impl<op::Ceiling>::operator()() {
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -72,12 +86,11 @@ namespace ngraph ...@@ -72,12 +86,11 @@ namespace ngraph
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "ceil(I)"}) .add(builder::Elementwise{"O", "ceil(I)"})
.finalize()); .finalize());
} }
// Divide performs a simple elementwise division. // Divide performs a simple elementwise division.
template <> void ngraph::runtime::plaidml::ImplDivide::Apply()
void Impl<op::Divide>::operator()() {
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -86,12 +99,11 @@ namespace ngraph ...@@ -86,12 +99,11 @@ namespace ngraph
.add(builder::Output{"C"}) .add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A / B"}) .add(builder::Elementwise{"C", "A / B"})
.finalize()); .finalize());
} }
// Floor performs a simple elementwise floor. // Floor performs a simple elementwise floor.
template <> void ngraph::runtime::plaidml::ImplFloor::Apply()
void Impl<op::Floor>::operator()() {
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -99,12 +111,11 @@ namespace ngraph ...@@ -99,12 +111,11 @@ namespace ngraph
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "floor(I)"}) .add(builder::Elementwise{"O", "floor(I)"})
.finalize()); .finalize());
} }
// Multiply performs a simple elementwise multiplication. // Multiply performs a simple elementwise multiplication.
template <> void ngraph::runtime::plaidml::ImplMultiply::Apply()
void Impl<op::Multiply>::operator()() {
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -113,12 +124,11 @@ namespace ngraph ...@@ -113,12 +124,11 @@ namespace ngraph
.add(builder::Output{"C"}) .add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A * B"}) .add(builder::Elementwise{"C", "A * B"})
.finalize()); .finalize());
} }
// Negative performs a simple elementwise negation. // Negative performs a simple elementwise negation.
template <> void ngraph::runtime::plaidml::ImplNegative::Apply()
void Impl<op::Negative>::operator()() {
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -126,12 +136,11 @@ namespace ngraph ...@@ -126,12 +136,11 @@ namespace ngraph
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "-I"}) .add(builder::Elementwise{"O", "-I"})
.finalize()); .finalize());
} }
// Relu implements a simple elementwise rectified linear unit. // Relu implements a simple elementwise rectified linear unit.
template <> void ngraph::runtime::plaidml::ImplRelu::Apply()
void Impl<op::Relu>::operator()() {
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -139,12 +148,11 @@ namespace ngraph ...@@ -139,12 +148,11 @@ namespace ngraph
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "relu(I)"}) .add(builder::Elementwise{"O", "relu(I)"})
.finalize()); .finalize());
} }
// ReluBackprop computes the derivative of Relu. // ReluBackprop computes the derivative of Relu.
template <> void ngraph::runtime::plaidml::ImplReluBackprop::Apply()
void Impl<op::ReluBackprop>::operator()() {
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -153,12 +161,11 @@ namespace ngraph ...@@ -153,12 +161,11 @@ namespace ngraph
.add(builder::Output{"DI"}) .add(builder::Output{"DI"})
.add(builder::Elementwise{"DI", "I > 0 ? DO : 0"}) .add(builder::Elementwise{"DI", "I > 0 ? DO : 0"})
.finalize()); .finalize());
} }
// Sigmoid computes a standard ML sigmoid: 1/(1+exp(-X)) // Sigmoid computes a standard ML sigmoid: 1/(1+exp(-X))
template <> void ngraph::runtime::plaidml::ImplSigmoid::Apply()
void Impl<op::Sigmoid>::operator()() {
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -166,13 +173,12 @@ namespace ngraph ...@@ -166,13 +173,12 @@ namespace ngraph
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "1/(1+exp(-I))"}) .add(builder::Elementwise{"O", "1/(1+exp(-I))"})
.finalize()); .finalize());
} }
// SigmoidBackprop computes the derivative of a standard ML // SigmoidBackprop computes the derivative of a standard ML
// sigmoid: dOutput * sigmoid(X) * (1-sigmoid(X)) // sigmoid: dOutput * sigmoid(X) * (1-sigmoid(X))
template <> void ngraph::runtime::plaidml::ImplSigmoidBackprop::Apply()
void Impl<op::SigmoidBackprop>::operator()() {
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -182,27 +188,24 @@ namespace ngraph ...@@ -182,27 +188,24 @@ namespace ngraph
.add(builder::Elementwise{"O", "1/(1+exp(-I))"}) .add(builder::Elementwise{"O", "1/(1+exp(-I))"})
.add(builder::Elementwise{"DI", "DO * O * (1-O)"}) .add(builder::Elementwise{"DI", "DO * O * (1-O)"})
.finalize()); .finalize());
} }
// Sign returns the sign of an element. // Sign returns the sign of an element.
template <> void ngraph::runtime::plaidml::ImplSign::Apply()
void Impl<op::Sign>::operator()() {
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
.add(builder::Input{op_input(0), "I"}) .add(builder::Input{op_input(0), "I"})
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"S", "(I < 0) ? -1 : ((I > 0) ? 1 : 0)"}) .add(builder::Elementwise{"S", "(I < 0) ? -1 : ((I > 0) ? 1 : 0)"})
.add(builder::Elementwise{ .add(builder::Elementwise{"O", tile_converter("S", op().get_element_type())})
"O", tile_converter("S", op().get_element_type())})
.finalize()); .finalize());
} }
// Subtract performs a simple elementwise subtraction. // Subtract performs a simple elementwise subtraction.
template <> void ngraph::runtime::plaidml::ImplSubtract::Apply()
void Impl<op::Subtract>::operator()() {
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -211,24 +214,4 @@ namespace ngraph ...@@ -211,24 +214,4 @@ namespace ngraph
.add(builder::Output{"C"}) .add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A - B"}) .add(builder::Elementwise{"C", "A - B"})
.finalize()); .finalize());
}
namespace
{
Impl<op::Abs>::Registration register_abs;
Impl<op::Add>::Registration register_add;
Impl<op::Ceiling>::Registration register_ceiling;
Impl<op::Divide>::Registration register_divide;
Impl<op::Floor>::Registration register_floor;
Impl<op::Multiply>::Registration register_multiply;
Impl<op::Negative>::Registration register_negative;
Impl<op::Relu>::Registration register_relu;
Impl<op::ReluBackprop>::Registration register_relu_backprop;
Impl<op::Sigmoid>::Registration register_sigmoid;
Impl<op::SigmoidBackprop>::Registration register_sigmoid_backprop;
Impl<op::Sign>::Registration register_sign;
Impl<op::Subtract>::Registration register_subtract;
}
}
}
} }
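Under the new registration scheme, each implementation is declared once with NGRAPH_PLAIDML_OP_CLASS (inside the ngraph::runtime::plaidml namespaces) and supplies an Apply() body, replacing the old Impl<Op>::operator() specializations and per-op Registration objects. A hedged sketch of what one additional elementwise op could look like, mirroring the pattern above; op::Tanh and the "tanh" Tile call are illustrative only (the actual transcendental ops live in plaidml_ops_transcendental.cpp):

// Hypothetical example following the pattern above; not part of this commit.
// Declared inside namespace ngraph { namespace runtime { namespace plaidml {
NGRAPH_PLAIDML_OP_CLASS(ImplTanh, OpImpl<op::Tanh>);
// } } }

// Tanh performs a simple elementwise hyperbolic tangent.
void ngraph::runtime::plaidml::ImplTanh::Apply()
{
    check_inputs(1);
    check_outputs(1);
    set_output(start_tile_function()
                   .add(builder::Input{op_input(0), "I"})
                   .add(builder::Output{"O"})
                   .add(builder::Elementwise{"O", "tanh(I)"})
                   .finalize());
}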
...@@ -30,25 +30,34 @@ namespace ngraph ...@@ -30,25 +30,34 @@ namespace ngraph
{ {
namespace plaidml namespace plaidml
{ {
// Equal performs a simple elementwise equality. NGRAPH_PLAIDML_OP_CLASS(ImplEqual, OpImpl<op::Equal>);
template <> NGRAPH_PLAIDML_OP_CLASS(ImplGreater, OpImpl<op::Greater>);
void Impl<op::Equal>::operator()() NGRAPH_PLAIDML_OP_CLASS(ImplGreaterEq, OpImpl<op::GreaterEq>);
{ NGRAPH_PLAIDML_OP_CLASS(ImplLess, OpImpl<op::Less>);
NGRAPH_PLAIDML_OP_CLASS(ImplLessEq, OpImpl<op::LessEq>);
NGRAPH_PLAIDML_OP_CLASS(ImplMaximum, OpImpl<op::Maximum>);
NGRAPH_PLAIDML_OP_CLASS(ImplMinimum, OpImpl<op::Minimum>);
NGRAPH_PLAIDML_OP_CLASS(ImplNotEqual, OpImpl<op::NotEqual>);
}
}
}
// Equal performs a simple elementwise equality.
void ngraph::runtime::plaidml::ImplEqual::Apply()
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
.add(builder::Input{op_input(0, TensorContents::LOGICAL), "A"}) .add(builder::Input{op_input(0), "A"})
.add(builder::Input{op_input(1, TensorContents::LOGICAL), "B"}) .add(builder::Input{op_input(1), "B"})
.add(builder::Output{"C"}) .add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A == B"}) .add(builder::Elementwise{"C", "A == B"})
.finalize(), .finalize());
TensorContents::LOGICAL); }
}
// Greater performs a simple elementwise greater-than comparison. // Greater performs a simple elementwise greater-than comparison.
template <> void ngraph::runtime::plaidml::ImplGreater::Apply()
void Impl<op::Greater>::operator()() {
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -56,14 +65,12 @@ namespace ngraph ...@@ -56,14 +65,12 @@ namespace ngraph
.add(builder::Input{op_input(1), "B"}) .add(builder::Input{op_input(1), "B"})
.add(builder::Output{"C"}) .add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A > B"}) .add(builder::Elementwise{"C", "A > B"})
.finalize(), .finalize());
TensorContents::LOGICAL); }
}
// GreaterEq performs a simple elementwise greater-than-or-equal-to comparison. // GreaterEq performs a simple elementwise greater-than-or-equal-to comparison.
template <> void ngraph::runtime::plaidml::ImplGreaterEq::Apply()
void Impl<op::GreaterEq>::operator()() {
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -71,14 +78,12 @@ namespace ngraph ...@@ -71,14 +78,12 @@ namespace ngraph
.add(builder::Input{op_input(1), "B"}) .add(builder::Input{op_input(1), "B"})
.add(builder::Output{"C"}) .add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A >= B"}) .add(builder::Elementwise{"C", "A >= B"})
.finalize(), .finalize());
TensorContents::LOGICAL); }
}
// Less performs a simple elementwise less-than comparison. // Less performs a simple elementwise less-than comparison.
template <> void ngraph::runtime::plaidml::ImplLess::Apply()
void Impl<op::Less>::operator()() {
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -86,14 +91,12 @@ namespace ngraph ...@@ -86,14 +91,12 @@ namespace ngraph
.add(builder::Input{op_input(1), "B"}) .add(builder::Input{op_input(1), "B"})
.add(builder::Output{"C"}) .add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A < B"}) .add(builder::Elementwise{"C", "A < B"})
.finalize(), .finalize());
TensorContents::LOGICAL); }
}
// LessEq performs a simple elementwise less-than-or-equal-to comparison. // LessEq performs a simple elementwise less-than-or-equal-to comparison.
template <> void ngraph::runtime::plaidml::ImplLessEq::Apply()
void Impl<op::LessEq>::operator()() {
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -101,14 +104,12 @@ namespace ngraph ...@@ -101,14 +104,12 @@ namespace ngraph
.add(builder::Input{op_input(1), "B"}) .add(builder::Input{op_input(1), "B"})
.add(builder::Output{"C"}) .add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A <= B"}) .add(builder::Elementwise{"C", "A <= B"})
.finalize(), .finalize());
TensorContents::LOGICAL); }
}
// Maximum performs a simple elementwise maximum. // Maximum performs a simple elementwise maximum.
template <> void ngraph::runtime::plaidml::ImplMaximum::Apply()
void Impl<op::Maximum>::operator()() {
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -117,12 +118,11 @@ namespace ngraph ...@@ -117,12 +118,11 @@ namespace ngraph
.add(builder::Output{"C"}) .add(builder::Output{"C"})
.add(builder::Elementwise{"C", "max(A, B)"}) .add(builder::Elementwise{"C", "max(A, B)"})
.finalize()); .finalize());
} }
// Minimum performs a simple elementwise minimum. // Minimum performs a simple elementwise minimum.
template <> void ngraph::runtime::plaidml::ImplMinimum::Apply()
void Impl<op::Minimum>::operator()() {
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -131,34 +131,17 @@ namespace ngraph ...@@ -131,34 +131,17 @@ namespace ngraph
.add(builder::Output{"C"}) .add(builder::Output{"C"})
.add(builder::Elementwise{"C", "min(A, B)"}) .add(builder::Elementwise{"C", "min(A, B)"})
.finalize()); .finalize());
} }
// NotEqual performs a simple elementwise not-equality. // NotEqual performs a simple elementwise not-equality.
template <> void ngraph::runtime::plaidml::ImplNotEqual::Apply()
void Impl<op::NotEqual>::operator()() {
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
.add(builder::Input{op_input(0, TensorContents::LOGICAL), "A"}) .add(builder::Input{op_input(0), "A"})
.add(builder::Input{op_input(1, TensorContents::LOGICAL), "B"}) .add(builder::Input{op_input(1), "B"})
.add(builder::Output{"C"}) .add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A != B"}) .add(builder::Elementwise{"C", "A != B"})
.finalize(), .finalize());
TensorContents::LOGICAL);
}
namespace
{
Impl<op::Equal>::Registration register_equal;
Impl<op::Greater>::Registration register_greater;
Impl<op::GreaterEq>::Registration register_greater_eq;
Impl<op::Less>::Registration register_less;
Impl<op::LessEq>::Registration register_less_eq;
Impl<op::Maximum>::Registration register_maximum;
Impl<op::Minimum>::Registration register_minimum;
Impl<op::NotEqual>::Registration register_not_equal;
}
}
}
} }
...@@ -23,10 +23,14 @@ namespace ngraph ...@@ -23,10 +23,14 @@ namespace ngraph
{ {
namespace plaidml namespace plaidml
{ {
// Concat views a tensor as a new type. NGRAPH_PLAIDML_OP_CLASS(ImplConcat, OpImpl<op::Concat>);
template <> }
void Impl<op::Concat>::operator()() }
{ }
// Concat concatenates its inputs along the concatenation axis.
void ngraph::runtime::plaidml::ImplConcat::Apply()
{
check_outputs(1); check_outputs(1);
auto f = start_tile_function(); auto f = start_tile_function();
...@@ -58,12 +62,10 @@ namespace ngraph ...@@ -58,12 +62,10 @@ namespace ngraph
continue; continue;
} }
std::string sidx{std::to_string(iidx)}; std::string sidx{std::to_string(iidx)};
f.add(builder::Input{op_input(iidx), "I" + sidx}.add_dims( f.add(builder::Input{op_input(iidx), "I" + sidx}.add_dims("I" + sidx + "_D", 0, dim_count));
"I" + sidx + "_D", 0, dim_count));
f.add(builder::UnaryContraction{"="} f.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"E" + sidx} .set(builder::ContractionOutput{"E" + sidx}
.add_dims([&]( .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_count; ++idx) for (std::size_t idx = 0; idx < dim_count; ++idx)
{ {
std::ostringstream s; std::ostringstream s;
...@@ -78,22 +80,19 @@ namespace ngraph ...@@ -78,22 +80,19 @@ namespace ngraph
} }
} }
}) })
.add_indices([&]( .add_indices([&](std::back_insert_iterator<std::list<std::string>> out) {
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_count; ++idx) for (std::size_t idx = 0; idx < dim_count; ++idx)
{ {
std::ostringstream s; std::ostringstream s;
s << "d" << idx; s << "d" << idx;
if (saw_non_zero_tensor && if (saw_non_zero_tensor && idx == op().get_concatenation_axis())
idx == op().get_concatenation_axis())
{ {
s << " + " << offset.str(); s << " + " << offset.str();
} }
out = s.str(); out = s.str();
} }
})) }))
.set(builder::ContractionInput{"I" + sidx}.add_indices( .set(builder::ContractionInput{"I" + sidx}.add_indices("d", 0, dim_count)));
"d", 0, dim_count)));
if (saw_non_zero_tensor) if (saw_non_zero_tensor)
{ {
oexpr << " + "; oexpr << " + ";
...@@ -106,12 +105,4 @@ namespace ngraph ...@@ -106,12 +105,4 @@ namespace ngraph
f.add(builder::Elementwise{"O", oexpr.str()}); f.add(builder::Elementwise{"O", oexpr.str()});
set_output(f.finalize()); set_output(f.finalize());
}
namespace
{
Impl<op::Concat>::Registration register_concat;
}
}
}
} }
...@@ -24,25 +24,20 @@ namespace ngraph ...@@ -24,25 +24,20 @@ namespace ngraph
{ {
namespace plaidml namespace plaidml
{ {
// Convert views a tensor as a new type. NGRAPH_PLAIDML_OP_CLASS(ImplConvert, OpImpl<op::Convert>);
template <> }
void Impl<op::Convert>::operator()() }
{ }
// Convert views a tensor as a new type.
void ngraph::runtime::plaidml::ImplConvert::Apply()
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output( set_output(start_tile_function()
start_tile_function()
.add(builder::Input{op_input(), "I"}) .add(builder::Input{op_input(), "I"})
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{ .add(builder::Elementwise{
"O", tile_converter("I", to_plaidml(op().get_convert_element_type()))}) "O", tile_converter("I", to_plaidml(op().get_convert_element_type()))})
.finalize()); .finalize());
}
namespace
{
Impl<op::Convert>::Registration register_convert;
}
}
}
} }
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/axis_vector.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace op
{
class Convolution;
class ConvolutionBackpropData;
class ConvolutionBackpropFilters;
}
}
}
}
class ngraph::runtime::plaidml::op::Convolution final : public ngraph::op::Op
{
public:
Convolution(std::shared_ptr<ngraph::op::Convolution> src,
const NodeVector& args,
AxisVector data_axes,
AxisVector filters_axes,
AxisVector output_axes);
void validate_and_infer_types() final;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const final;
const std::shared_ptr<ngraph::op::Convolution>& get_src() const { return m_src; }
const AxisVector& get_data_axes() const { return m_data_axes; }
const AxisVector& get_filters_axes() const { return m_filters_axes; }
const AxisVector& get_output_axes() const { return m_output_axes; }
private:
std::shared_ptr<ngraph::op::Convolution> m_src;
AxisVector m_data_axes;
AxisVector m_filters_axes;
AxisVector m_output_axes;
};
class ngraph::runtime::plaidml::op::ConvolutionBackpropData final : public ngraph::op::Op
{
public:
ConvolutionBackpropData(std::shared_ptr<ngraph::op::ConvolutionBackpropData> src,
const NodeVector& args,
AxisVector filters_axes,
AxisVector output_axes,
AxisVector data_axes);
void validate_and_infer_types() final;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const final;
const std::shared_ptr<ngraph::op::ConvolutionBackpropData>& get_src() const { return m_src; }
const AxisVector& get_filters_axes() const { return m_filters_axes; }
const AxisVector& get_output_axes() const { return m_output_axes; }
const AxisVector& get_data_axes() const { return m_data_axes; }
private:
std::shared_ptr<ngraph::op::ConvolutionBackpropData> m_src;
AxisVector m_filters_axes;
AxisVector m_output_axes;
AxisVector m_data_axes;
};
class ngraph::runtime::plaidml::op::ConvolutionBackpropFilters final : public ngraph::op::Op
{
public:
ConvolutionBackpropFilters(std::shared_ptr<ngraph::op::ConvolutionBackpropFilters> src,
const NodeVector& args,
AxisVector data_axes,
AxisVector output_axes,
AxisVector filters_axes);
void validate_and_infer_types() final;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const final;
const std::shared_ptr<ngraph::op::ConvolutionBackpropFilters>& get_src() const { return m_src; }
const AxisVector& get_data_axes() const { return m_data_axes; }
const AxisVector& get_output_axes() const { return m_output_axes; }
const AxisVector& get_filters_axes() const { return m_filters_axes; }
private:
std::shared_ptr<ngraph::op::ConvolutionBackpropFilters> m_src;
AxisVector m_data_axes;
AxisVector m_output_axes;
AxisVector m_filters_axes;
};
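These wrapper ops carry the original nGraph convolution plus explicit axis orderings for the data, filters, and output tensors, which is what lets the LowerConvolutions pass merge adjacent reshapes into the convolution itself (per the commit bullet above). A hedged sketch of constructing one; src_conv, data, filters, and the NCHW/NHWC-style axis orders are illustrative assumptions, not taken from the pass:

// Illustrative only: wrap an existing ngraph::op::Convolution with explicit
// axis orderings for its data, filters, and output tensors.
auto wrapped = std::make_shared<ngraph::runtime::plaidml::op::Convolution>(
    src_conv,                   // the original ngraph::op::Convolution
    NodeVector{data, filters},  // the (possibly re-sourced) inputs
    AxisVector{0, 2, 3, 1},     // data axis order
    AxisVector{2, 3, 1, 0},     // filters axis order
    AxisVector{0, 3, 1, 2});    // output axis order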
...@@ -26,11 +26,15 @@ namespace ngraph ...@@ -26,11 +26,15 @@ namespace ngraph
{ {
namespace plaidml namespace plaidml
{ {
// Dot is a generalized dot product operation -- scalar-tensor, NGRAPH_PLAIDML_OP_CLASS(ImplDot, OpImpl<op::Dot>);
// matrix-vector, and matrix multiplication. }
template <> }
void Impl<op::Dot>::operator()() }
{
// Dot is a generalized dot product operation -- scalar-tensor,
// matrix-vector, and matrix multiplication.
void ngraph::runtime::plaidml::ImplDot::Apply()
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
...@@ -46,8 +50,7 @@ namespace ngraph ...@@ -46,8 +50,7 @@ namespace ngraph
NGRAPH_DEBUG << "l_dim_mac=" << l_dim_mac; NGRAPH_DEBUG << "l_dim_mac=" << l_dim_mac;
NGRAPH_DEBUG << "r_dim_mic=" << r_dim_mic; NGRAPH_DEBUG << "r_dim_mic=" << r_dim_mic;
set_output( set_output(start_tile_function()
start_tile_function()
.add(builder::Input{op_input(0), "L"} .add(builder::Input{op_input(0), "L"}
.add_dims("DL", 1, l_dim_mac + 1) .add_dims("DL", 1, l_dim_mac + 1)
.add_dims("DC", 1, reduce_limit + 1)) .add_dims("DC", 1, reduce_limit + 1))
...@@ -68,12 +71,4 @@ namespace ngraph ...@@ -68,12 +71,4 @@ namespace ngraph
.add_indices("dc", 1, reduce_limit + 1) .add_indices("dc", 1, reduce_limit + 1)
.add_indices("dr", r_dim_mic + 1, r_dim_limit + 1))) .add_indices("dr", r_dim_mic + 1, r_dim_limit + 1)))
.finalize()); .finalize());
}
namespace
{
Impl<op::Dot>::Registration register_dot;
}
}
}
} }
...@@ -25,10 +25,14 @@ namespace ngraph ...@@ -25,10 +25,14 @@ namespace ngraph
{ {
namespace plaidml namespace plaidml
{ {
// FunctionCall invokes a sub-function. NGRAPH_PLAIDML_OP_CLASS(ImplFunctionCall, OpImpl<op::FunctionCall>);
template <> }
void Impl<op::FunctionCall>::operator()() }
{ }
// FunctionCall invokes a sub-function.
void ngraph::runtime::plaidml::ImplFunctionCall::Apply()
{
Build b; Build b;
build()->compiler->build(op().get_functions()[0], &b); build()->compiler->build(op().get_functions()[0], &b);
vertexai::plaidml::function f{b.composer}; vertexai::plaidml::function f{b.composer};
...@@ -36,8 +40,7 @@ namespace ngraph ...@@ -36,8 +40,7 @@ namespace ngraph
for (std::size_t idx = 0; idx < op().get_input_size(); ++idx) for (std::size_t idx = 0; idx < op().get_input_size(); ++idx)
{ {
auto* oitv = op().get_inputs()[idx].get_output().get_tensor_ptr().get(); auto* oitv = op().get_inputs()[idx].get_output().get_tensor_ptr().get();
auto* iitv = auto* iitv = b.func->get_parameters()[idx]->get_outputs()[0].get_tensor_ptr().get();
b.func->get_parameters()[idx]->get_outputs()[0].get_tensor_ptr().get();
inputs.emplace_back(b.input_names.at(iitv), build()->bindings.at(oitv).var); inputs.emplace_back(b.input_names.at(iitv), build()->bindings.at(oitv).var);
} }
vertexai::plaidml::application app{f.apply(inputs)}; vertexai::plaidml::application app{f.apply(inputs)};
...@@ -46,12 +49,4 @@ namespace ngraph ...@@ -46,12 +49,4 @@ namespace ngraph
auto* iotv = b.func->get_results()[idx]->get_output_tensor_ptr().get(); auto* iotv = b.func->get_results()[idx]->get_output_tensor_ptr().get();
set_output(idx, app.get_output(b.output_names[iotv])); set_output(idx, app.get_output(b.output_names[iotv]));
} }
}
namespace
{
Impl<op::FunctionCall>::Registration register_function_call;
}
}
}
} }
...@@ -26,19 +26,14 @@ namespace ngraph ...@@ -26,19 +26,14 @@ namespace ngraph
namespace plaidml namespace plaidml
{ {
template <typename O> template <typename O>
class IndexReductionImpl : public BaseImpl<O> class IndexReductionBase : public OpImpl<O>
{ {
public: protected:
IndexReductionImpl(Build* build, const O& op)
: BaseImpl<O>{build, op}
{
}
void build_index_reduction(const char* agg_op); void build_index_reduction(const char* agg_op);
}; };
template <typename O> template <typename O>
void IndexReductionImpl<O>::build_index_reduction(const char* agg_op) void IndexReductionBase<O>::build_index_reduction(const char* agg_op)
{ {
this->check_inputs(1); this->check_inputs(1);
this->check_outputs(1); this->check_outputs(1);
...@@ -117,37 +112,20 @@ namespace ngraph ...@@ -117,37 +112,20 @@ namespace ngraph
.finalize()); .finalize());
} }
template <> NGRAPH_PLAIDML_OP_CLASS(ImplArgMax, IndexReductionBase<op::ArgMax>);
struct ParentImpl<op::ArgMax> NGRAPH_PLAIDML_OP_CLASS(ImplArgMin, IndexReductionBase<op::ArgMin>);
{ }
using Type = IndexReductionImpl<op::ArgMax>; }
}; }
template <>
struct ParentImpl<op::ArgMin>
{
using Type = IndexReductionImpl<op::ArgMin>;
};
// ArgMax computes the maximum index along a tensor axis. // ArgMax computes the maximum index along a tensor axis.
template <> void ngraph::runtime::plaidml::ImplArgMax::Apply()
void Impl<op::ArgMax>::operator()() {
{
build_index_reduction(">"); build_index_reduction(">");
} }
// ArgMin computes the minimum index along a tensor axis. // ArgMin computes the minimum index along a tensor axis.
template <> void ngraph::runtime::plaidml::ImplArgMin::Apply()
void Impl<op::ArgMin>::operator()() {
{
build_index_reduction("<"); build_index_reduction("<");
}
namespace
{
Impl<op::ArgMax>::Registration register_argmax;
Impl<op::ArgMin>::Registration register_argmin;
}
}
}
} }
...@@ -26,10 +26,15 @@ namespace ngraph ...@@ -26,10 +26,15 @@ namespace ngraph
{ {
namespace plaidml namespace plaidml
{ {
// Parameter binds a descriptor::Tensor to a PlaidML Placeholder. NGRAPH_PLAIDML_OP_CLASS(ImplParameter, OpImpl<op::Parameter>);
template <> NGRAPH_PLAIDML_OP_CLASS(ImplResult, OpImpl<op::Result>);
void Impl<op::Parameter>::operator()() }
{ }
}
// Parameter binds a descriptor::Tensor to a PlaidML Placeholder.
void ngraph::runtime::plaidml::ImplParameter::Apply()
{
check_inputs(0); check_inputs(0);
check_outputs(1); check_outputs(1);
vp::placeholder ph{build()->io_dim_override ? build()->io_dim_override_count vp::placeholder ph{build()->io_dim_override ? build()->io_dim_override_count
...@@ -39,25 +44,15 @@ namespace ngraph ...@@ -39,25 +44,15 @@ namespace ngraph
build()->bindings.emplace(tv, TensorInfo{ph, TensorContents::DATA}); build()->bindings.emplace(tv, TensorInfo{ph, TensorContents::DATA});
build()->composer.input(name, ph); build()->composer.input(name, ph);
build()->input_names.emplace(tv, std::move(name)); build()->input_names.emplace(tv, std::move(name));
} }
// Result binds a PlaidML variable to a composed function output. // Result binds a PlaidML variable to a composed function output.
template <> void ngraph::runtime::plaidml::ImplResult::Apply()
void Impl<op::Result>::operator()() {
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
std::string name = std::string{"O"} + std::to_string(build()->output_names.size()); std::string name = std::string{"O"} + std::to_string(build()->output_names.size());
descriptor::Tensor* tv = op().get_output_tensor_ptr().get(); descriptor::Tensor* tv = op().get_output_tensor_ptr().get();
build()->composer.output(name, op_input()); build()->composer.output(name, op_input());
build()->output_names.emplace(tv, std::move(name)); build()->output_names.emplace(tv, std::move(name));
}
namespace
{
Impl<op::Parameter>::Registration register_parameter;
Impl<op::Result>::Registration register_result;
}
}
}
} }
...@@ -23,23 +23,25 @@ namespace ngraph ...@@ -23,23 +23,25 @@ namespace ngraph
{ {
namespace plaidml namespace plaidml
{ {
// LRN implements Local Response Normalization NGRAPH_PLAIDML_OP_CLASS(ImplLRN, OpImpl<op::LRN>);
template <> }
void Impl<op::LRN>::operator()() }
{ }
// LRN implements Local Response Normalization
void ngraph::runtime::plaidml::ImplLRN::Apply()
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
auto dim_limit = op().get_inputs()[0].get_shape().size(); auto dim_limit = op().get_inputs()[0].get_shape().size();
auto rank = dim_limit - 2; auto rank = dim_limit - 2;
auto distance = op().get_nsize() / 2; auto distance = op().get_nsize() / 2;
std::ostringstream div_expr; std::ostringstream div_expr;
div_expr << "I / pow(" << op().get_bias() << ".0 + ((" << op().get_alpha() div_expr << "I / pow(" << op().get_bias() << ".0 + ((" << op().get_alpha() << ".0 / "
<< ".0 / " << op().get_nsize() << ".0) * S), " << op().get_beta() << ".0)"; << op().get_nsize() << ".0) * S), " << op().get_beta() << ".0)";
set_output( set_output(
start_tile_function() start_tile_function()
.add(builder::Input{op_input(), "I"} .add(builder::Input{op_input(), "I"}.add_dims({"N", "C"}).add_dims("D", 0, rank))
.add_dims({"N", "C"})
.add_dims("D", 0, rank))
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"ISQ", "I * I"}) .add(builder::Elementwise{"ISQ", "I * I"})
.add(builder::UnaryContraction{"+"} .add(builder::UnaryContraction{"+"}
...@@ -51,18 +53,9 @@ namespace ngraph ...@@ -51,18 +53,9 @@ namespace ngraph
.set(builder::ContractionInput{"ISQ"} .set(builder::ContractionInput{"ISQ"}
.add_indices({"n", "c + z - " + std::to_string(distance)}) .add_indices({"n", "c + z - " + std::to_string(distance)})
.add_indices("d", 0, rank)) .add_indices("d", 0, rank))
.add_constraints( .add_constraints([&](std::back_insert_iterator<std::list<std::string>> out) {
[&](std::back_insert_iterator<std::list<std::string>> out) {
out = "z < " + std::to_string(op().get_nsize()); out = "z < " + std::to_string(op().get_nsize());
})) }))
.add(builder::Elementwise{"O", div_expr.str()}) .add(builder::Elementwise{"O", div_expr.str()})
.finalize()); .finalize());
}
namespace
{
Impl<op::LRN>::Registration register_local_response_norm;
}
}
}
} }
...@@ -25,56 +25,47 @@ namespace ngraph ...@@ -25,56 +25,47 @@ namespace ngraph
{ {
namespace plaidml namespace plaidml
{ {
// And performs a simple elementwise logical and. NGRAPH_PLAIDML_OP_CLASS(ImplAnd, OpImpl<op::And>);
template <> NGRAPH_PLAIDML_OP_CLASS(ImplNot, OpImpl<op::Not>);
void Impl<op::And>::operator()() NGRAPH_PLAIDML_OP_CLASS(ImplOr, OpImpl<op::Or>);
{ }
}
}
// And performs a simple elementwise logical and.
void ngraph::runtime::plaidml::ImplAnd::Apply()
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
.add(builder::Input{op_input(0, TensorContents::LOGICAL), "A"}) .add(builder::Input{op_input(0), "A"})
.add(builder::Input{op_input(1, TensorContents::LOGICAL), "B"}) .add(builder::Input{op_input(1), "B"})
.add(builder::Output{"C"}) .add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A ? B : A"}) .add(builder::Elementwise{"C", "A ? B : A"})
.finalize(), .finalize());
TensorContents::LOGICAL); }
}
// Not performs a simple elementwise logical not. // Not performs a simple elementwise logical not.
template <> void ngraph::runtime::plaidml::ImplNot::Apply()
void Impl<op::Not>::operator()() {
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
.add(builder::Input{op_input(0, TensorContents::LOGICAL), "I"}) .add(builder::Input{op_input(0), "I"})
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "cmp_eq(I, 0)"}) .add(builder::Elementwise{"O", "cmp_eq(I, 0)"})
.finalize(), .finalize());
TensorContents::LOGICAL); }
}
// Or performs a simple elementwise logical or. // Or performs a simple elementwise logical or.
template <> void ngraph::runtime::plaidml::ImplOr::Apply()
void Impl<op::Or>::operator()() {
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
.add(builder::Input{op_input(0, TensorContents::LOGICAL), "A"}) .add(builder::Input{op_input(0), "A"})
.add(builder::Input{op_input(1, TensorContents::LOGICAL), "B"}) .add(builder::Input{op_input(1), "B"})
.add(builder::Output{"C"}) .add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A ? A : B"}) .add(builder::Elementwise{"C", "A ? A : B"})
.finalize(), .finalize());
TensorContents::LOGICAL);
}
namespace
{
Impl<op::And>::Registration register_and;
Impl<op::Not>::Registration register_not;
Impl<op::Or>::Registration register_or;
}
}
}
} }
...@@ -26,10 +26,14 @@ namespace ngraph ...@@ -26,10 +26,14 @@ namespace ngraph
{ {
namespace plaidml namespace plaidml
{ {
// OneHot performs one-hot encoding along the requested axis. NGRAPH_PLAIDML_OP_CLASS(ImplOneHot, OpImpl<op::OneHot>);
template <> }
void Impl<op::OneHot>::operator()() }
{ }
// OneHot performs one-hot encoding along the requested axis.
void ngraph::runtime::plaidml::ImplOneHot::Apply()
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
...@@ -74,12 +78,9 @@ namespace ngraph ...@@ -74,12 +78,9 @@ namespace ngraph
.add(builder::Input{op_input(), "I"}.add_dims("D", 0, in_shape.size())) .add(builder::Input{op_input(), "I"}.add_dims("D", 0, in_shape.size()))
.add(builder::Input{static_cast<std::int64_t>(0), "Zero"}) .add(builder::Input{static_cast<std::int64_t>(0), "Zero"})
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add( .add(builder::UnaryContraction{"="}
builder::UnaryContraction{"="} .set(builder::ContractionOutput{"ZS"}
.set( .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
builder::ContractionOutput{"ZS"}
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < out_shape.size(); ++idx) for (std::size_t idx = 0; idx < out_shape.size(); ++idx)
{ {
if (idx == op().get_one_hot_axis()) if (idx == op().get_one_hot_axis())
...@@ -94,19 +95,10 @@ namespace ngraph ...@@ -94,19 +95,10 @@ namespace ngraph
}) })
.add_indices("d", 0, out_shape.size())) .add_indices("d", 0, out_shape.size()))
.set(builder::ContractionInput{"Zero"})) .set(builder::ContractionInput{"Zero"}))
.add(builder::Elementwise{ .add(builder::Elementwise{"Idx",
"Idx", "index(ZS, " + std::to_string(op().get_one_hot_axis()) + ")"}) "index(ZS, " + std::to_string(op().get_one_hot_axis()) + ")"})
.add(builder::Elementwise{"IS", "reshape(I, " + in_reshape.str() + ")"}) .add(builder::Elementwise{"IS", "reshape(I, " + in_reshape.str() + ")"})
.add(builder::Elementwise{"OV", "IS == Idx ? 1 : 0"}) .add(builder::Elementwise{"OV", "IS == Idx ? 1 : 0"})
.add(builder::Elementwise{"O", .add(builder::Elementwise{"O", tile_converter("OV", op().get_element_type())})
tile_converter("OV", op().get_element_type())})
.finalize()); .finalize());
}
namespace
{
Impl<op::OneHot>::Registration register_one_hot;
}
}
}
} }
...@@ -26,10 +26,17 @@ namespace ngraph ...@@ -26,10 +26,17 @@ namespace ngraph
{ {
namespace plaidml namespace plaidml
{ {
// AvgPool implements a batch average pooling operation. NGRAPH_PLAIDML_OP_CLASS(ImplAvgPool, OpImpl<op::AvgPool>);
template <> NGRAPH_PLAIDML_OP_CLASS(ImplMaxPool, OpImpl<op::MaxPool>);
void Impl<op::AvgPool>::operator()() NGRAPH_PLAIDML_OP_CLASS(ImplAvgPoolBackprop, OpImpl<op::AvgPoolBackprop>);
{ NGRAPH_PLAIDML_OP_CLASS(ImplMaxPoolBackprop, OpImpl<op::MaxPoolBackprop>);
}
}
}
// AvgPool implements a batch average pooling operation.
void ngraph::runtime::plaidml::ImplAvgPool::Apply()
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
...@@ -98,12 +105,11 @@ namespace ngraph ...@@ -98,12 +105,11 @@ namespace ngraph
f.add(cpf.PoolContraction()).add(builder::Elementwise{"O", "S / Count"}); f.add(cpf.PoolContraction()).add(builder::Elementwise{"O", "S / Count"});
set_output(f.finalize()); set_output(f.finalize());
} }
// MaxPool implements a batch max pooling operation. // MaxPool implements a batch max pooling operation.
template <> void ngraph::runtime::plaidml::ImplMaxPool::Apply()
void Impl<op::MaxPool>::operator()() {
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
...@@ -162,11 +168,10 @@ namespace ngraph ...@@ -162,11 +168,10 @@ namespace ngraph
.add(cpf.O_out_header()) .add(cpf.O_out_header())
.add(cpf.PoolContraction()) .add(cpf.PoolContraction())
.finalize()); .finalize());
} }
template <> void ngraph::runtime::plaidml::ImplAvgPoolBackprop::Apply()
void Impl<op::AvgPoolBackprop>::operator()() {
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
...@@ -180,8 +185,7 @@ namespace ngraph ...@@ -180,8 +185,7 @@ namespace ngraph
if (include_padding) if (include_padding)
{ {
throw std::runtime_error( throw std::runtime_error("Include padding in average not yet implemented in PlaidML");
"Include padding in average not yet implemented in PlaidML");
} }
ngraph::CoordinateDiff pad_above; ngraph::CoordinateDiff pad_above;
...@@ -236,20 +240,18 @@ namespace ngraph ...@@ -236,20 +240,18 @@ namespace ngraph
{ {
std::ostringstream s; std::ostringstream s;
s << "XI" << i - 2; s << "XI" << i - 2;
ret.add( ret.add(builder::Input{static_cast<std::int64_t>(forward_arg_shape[i]), s.str()});
builder::Input{static_cast<std::int64_t>(forward_arg_shape[i]), s.str()});
} }
set_output(ret.add(cpf.Broadcast_Ones()) set_output(ret.add(cpf.Broadcast_Ones())
.add(cpf.Count()) .add(cpf.Count())
.add(builder::Elementwise{"S", "DO / Count"}) .add(builder::Elementwise{"S", "DO / Count"})
.add(cpf.PoolContraction()) .add(cpf.PoolContraction())
.finalize()); .finalize());
} }
template <> void ngraph::runtime::plaidml::ImplMaxPoolBackprop::Apply()
void Impl<op::MaxPoolBackprop>::operator()() {
{ check_inputs_ge(2);
check_inputs(2);
check_outputs(1); check_outputs(1);
auto src_dims = op().get_inputs()[0].get_shape().size() - 2; auto src_dims = op().get_inputs()[0].get_shape().size() - 2;
...@@ -307,15 +309,4 @@ namespace ngraph ...@@ -307,15 +309,4 @@ namespace ngraph
.add(cpf.PoolContraction()) .add(cpf.PoolContraction())
.add(cpf.PoolDerivContraction()) .add(cpf.PoolDerivContraction())
.finalize()); .finalize());
}
namespace
{
Impl<op::AvgPool>::Registration register_avg_pool;
Impl<op::MaxPool>::Registration register_max_pool;
Impl<op::AvgPoolBackprop>::Registration register_avg_pool_backprop;
Impl<op::MaxPoolBackprop>::Registration register_max_pool_backprop;
}
}
}
} }
...@@ -32,19 +32,14 @@ namespace ngraph ...@@ -32,19 +32,14 @@ namespace ngraph
namespace plaidml namespace plaidml
{ {
template <typename O> template <typename O>
class ReductionImpl : public BaseImpl<O> class ReductionBase : public OpImpl<O>
{ {
public: public:
ReductionImpl(Build* build, const O& op)
: BaseImpl<O>{build, op}
{
}
void build_reduction(const char* agg_op); void build_reduction(const char* agg_op);
}; };
template <typename O> template <typename O>
void ReductionImpl<O>::build_reduction(const char* agg_op) void ReductionBase<O>::build_reduction(const char* agg_op)
{ {
this->check_inputs(1); this->check_inputs(1);
this->check_outputs(1); this->check_outputs(1);
...@@ -90,61 +85,36 @@ namespace ngraph ...@@ -90,61 +85,36 @@ namespace ngraph
.finalize()); .finalize());
} }
template <> NGRAPH_PLAIDML_OP_CLASS(ImplMax, ReductionBase<op::Max>);
struct ParentImpl<op::Max> NGRAPH_PLAIDML_OP_CLASS(ImplMin, ReductionBase<op::Min>);
{ NGRAPH_PLAIDML_OP_CLASS(ImplProduct, ReductionBase<op::Product>);
using Type = ReductionImpl<op::Max>; NGRAPH_PLAIDML_OP_CLASS(ImplReduce, ReductionBase<op::Reduce>);
}; NGRAPH_PLAIDML_OP_CLASS(ImplSum, ReductionBase<op::Sum>);
}
template <> }
struct ParentImpl<op::Min> }
{
using Type = ReductionImpl<op::Min>;
};
template <>
struct ParentImpl<op::Product>
{
using Type = ReductionImpl<op::Product>;
};
template <>
struct ParentImpl<op::Reduce>
{
using Type = ReductionImpl<op::Reduce>;
};
template <>
struct ParentImpl<op::Sum>
{
using Type = ReductionImpl<op::Sum>;
};
// Max reduces a tensor, taking the maximum along the specified axes. // Max reduces a tensor, taking the maximum along the specified axes.
template <> void ngraph::runtime::plaidml::ImplMax::Apply()
void Impl<op::Max>::operator()() {
{
build_reduction(">"); build_reduction(">");
} }
// Min reduces a tensor, taking the minimum along the specified axes. // Min reduces a tensor, taking the minimum along the specified axes.
template <> void ngraph::runtime::plaidml::ImplMin::Apply()
void Impl<op::Min>::operator()() {
{
build_reduction("<"); build_reduction("<");
} }
            // Product reduces a tensor, taking the product along the specified axes.     // Product reduces a tensor, taking the product along the specified axes.
template <> void ngraph::runtime::plaidml::ImplProduct::Apply()
void Impl<op::Product>::operator()() {
{
build_reduction("*"); build_reduction("*");
} }
// Reduce reduces a tensor with an arbitrary user-supplied reduction operation. // Reduce reduces a tensor with an arbitrary user-supplied reduction operation.
template <> void ngraph::runtime::plaidml::ImplReduce::Apply()
void Impl<op::Reduce>::operator()() {
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
...@@ -185,13 +155,10 @@ namespace ngraph ...@@ -185,13 +155,10 @@ namespace ngraph
start_tile_function() start_tile_function()
.add(builder::Input{op_input(1), "I"}) .add(builder::Input{op_input(1), "I"})
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add( .add(builder::UnaryContraction{"="}
builder::UnaryContraction{"="} .set(builder::ContractionOutput{"O"}
.set(
builder::ContractionOutput{"O"}
.add_indices("d", 0, agg_dim_limit) .add_indices("d", 0, agg_dim_limit)
.add_dims([&]( .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
std::back_insert_iterator<std::list<std::string>> out) {
for (auto idx = 0; idx < agg_dim_limit; ++idx) for (auto idx = 0; idx < agg_dim_limit; ++idx)
{ {
out = "1"; out = "1";
...@@ -210,12 +177,9 @@ namespace ngraph ...@@ -210,12 +177,9 @@ namespace ngraph
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::UnaryContraction{"="} .add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"O"} .set(builder::ContractionOutput{"O"}
.add_indices([&]( .add_indices(
std::back_insert_iterator<std::list<std::string>> [&](std::back_insert_iterator<std::list<std::string>> out) {
out) { for (std::size_t idx = 0; idx < input_shape.size(); ++idx)
for (std::size_t idx = 0;
idx < input_shape.size();
++idx)
{ {
if (!op().get_reduction_axes().count(idx)) if (!op().get_reduction_axes().count(idx))
{ {
...@@ -223,12 +187,9 @@ namespace ngraph ...@@ -223,12 +187,9 @@ namespace ngraph
} }
} }
}) })
.add_dims([&]( .add_dims(
std::back_insert_iterator<std::list<std::string>> [&](std::back_insert_iterator<std::list<std::string>> out) {
out) { for (std::size_t idx = 0; idx < input_shape.size(); ++idx)
for (std::size_t idx = 0;
idx < input_shape.size();
++idx)
{ {
if (!op().get_reduction_axes().count(idx)) if (!op().get_reduction_axes().count(idx))
{ {
...@@ -236,8 +197,8 @@ namespace ngraph ...@@ -236,8 +197,8 @@ namespace ngraph
} }
} }
})) }))
.set(builder::ContractionInput{"I"}.add_indices([&]( .set(builder::ContractionInput{"I"}.add_indices(
std::back_insert_iterator<std::list<std::string>> out) { [&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < input_shape.size(); ++idx) for (std::size_t idx = 0; idx < input_shape.size(); ++idx)
{ {
std::size_t cidx = 0; std::size_t cidx = 0;
...@@ -255,23 +216,10 @@ namespace ngraph ...@@ -255,23 +216,10 @@ namespace ngraph
} }
set_output(result); set_output(result);
} }
// Sum reduces a tensor, summing the specified axes. // Sum reduces a tensor, summing the specified axes.
template <> void ngraph::runtime::plaidml::ImplSum::Apply()
void Impl<op::Sum>::operator()() {
{
build_reduction("+"); build_reduction("+");
}
namespace
{
Impl<op::Max>::Registration register_max;
Impl<op::Min>::Registration register_min;
Impl<op::Product>::Registration register_product;
Impl<op::Reduce>::Registration register_reduce;
Impl<op::Sum>::Registration register_sum;
}
}
}
} }
...@@ -25,10 +25,14 @@ namespace ngraph ...@@ -25,10 +25,14 @@ namespace ngraph
{ {
namespace plaidml namespace plaidml
{ {
// ReplaceSlice replaces part of a tensor with another tensor. NGRAPH_PLAIDML_OP_CLASS(ImplReplaceSlice, OpImpl<op::ReplaceSlice>);
template <> }
void Impl<op::ReplaceSlice>::operator()() }
{ }
// ReplaceSlice replaces part of a tensor with another tensor.
void ngraph::runtime::plaidml::ImplReplaceSlice::Apply()
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
...@@ -49,13 +53,11 @@ namespace ngraph ...@@ -49,13 +53,11 @@ namespace ngraph
.add(builder::Input{op_input(0), "L"}.add_dims("D", 0, shape.size())) .add(builder::Input{op_input(0), "L"}.add_dims("D", 0, shape.size()))
.add(builder::Input{op_input(1), "S"}.add_dims("SD", 0, shape.size())) .add(builder::Input{op_input(1), "S"}.add_dims("SD", 0, shape.size()))
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add( .add(builder::UnaryContraction{"="}
builder::UnaryContraction{"="} .set(builder::ContractionOutput{"O"}
.set(
builder::ContractionOutput{"O"}
.add_dims("D", 0, shape.size()) .add_dims("D", 0, shape.size())
.add_indices([&]( .add_indices(
std::back_insert_iterator<std::list<std::string>> out) { [&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < shape.size(); ++idx) for (std::size_t idx = 0; idx < shape.size(); ++idx)
{ {
auto stride = op().get_strides()[idx]; auto stride = op().get_strides()[idx];
...@@ -81,10 +83,8 @@ namespace ngraph ...@@ -81,10 +83,8 @@ namespace ngraph
out = didx.str(); out = didx.str();
} }
})) }))
.set(builder::ContractionInput{"S"}.add_indices( .set(builder::ContractionInput{"S"}.add_indices("d", 0, shape.size()))
"d", 0, shape.size())) .add_constraints([&](std::back_insert_iterator<std::list<std::string>> out) {
.add_constraints(
[&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < shape.size(); ++idx) for (std::size_t idx = 0; idx < shape.size(); ++idx)
{ {
out = "d" + std::to_string(idx) + " < " + out = "d" + std::to_string(idx) + " < " +
...@@ -94,12 +94,4 @@ namespace ngraph ...@@ -94,12 +94,4 @@ namespace ngraph
}) })
.set_default("L")) .set_default("L"))
.finalize()); .finalize());
}
namespace
{
Impl<op::ReplaceSlice>::Registration register_replace_slice;
}
}
}
} }
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/plaidml/plaidml_ops_replicate.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
NGRAPH_PLAIDML_OP_CLASS(ImplReplicate, OpImpl<plaidml::op::Replicate>);
}
}
}
ngraph::runtime::plaidml::op::Replicate::Replicate(std::shared_ptr<Node> arg,
std::size_t replication_axis,
std::size_t replication_count)
: Op{"Replicate", NodeVector{arg}}
, m_replication_axes(arg->get_shape().size(), 1)
{
m_replication_axes.at(replication_axis) = replication_count;
constructor_validate_and_infer_types();
}
ngraph::runtime::plaidml::op::Replicate::Replicate(std::shared_ptr<Node> arg,
std::vector<std::size_t> replication_axes)
: Op{"Replicate", NodeVector{arg}}
, m_replication_axes(std::move(replication_axes))
{
if (arg->get_shape().size() != m_replication_axes.size())
{
throw ngraph_error{"Replicate requires compatible axes dimensions"};
}
constructor_validate_and_infer_types();
}
void ngraph::runtime::plaidml::op::Replicate::validate_and_infer_types()
{
const auto& arg = get_arguments().at(0);
Shape shape = arg->get_shape();
for (auto rit = m_replication_axes.begin(), sit = shape.begin();
rit != m_replication_axes.end();
++rit, ++sit)
{
*sit *= *rit;
}
set_output_type(0, arg->get_element_type(), shape);
}
std::shared_ptr<ngraph::Node>
ngraph::runtime::plaidml::op::Replicate::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error{"Replicate requires exactly one input"};
}
if (new_args.at(0)->get_shape().size() != m_replication_axes.size())
{
throw ngraph_error{"Replicate requires identical dimensions in inputs"};
}
return std::make_shared<Replicate>(new_args.at(0), m_replication_axes);
}
void ngraph::runtime::plaidml::ImplReplicate::Apply()
{
check_inputs(1);
check_outputs(1);
const auto& axes = op().get_replication_axes();
const auto& ishape = op().get_input_shape(0);
set_output(
start_tile_function()
.add(builder::Input{op_input(0), "I"}.add_dims("D", 0, axes.size()))
.add(builder::Output{"O"})
.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"O"}
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < axes.size(); ++idx)
{
std::string dsize = "D" + std::to_string(idx);
if (axes.at(idx) != 1)
{
dsize = dsize + " * " + std::to_string(axes.at(idx));
}
out = dsize;
}
})
.add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < axes.size(); ++idx)
{
std::string didx = "d" + std::to_string(idx);
if (axes.at(idx) != 1)
{
if (ishape.at(idx) == 1)
{
didx = didx + " + s" + std::to_string(idx);
}
else
{
didx = didx + " + (s" + std::to_string(idx) + " * " +
std::to_string(ishape.at(idx)) + ")";
}
}
out = didx;
}
}))
.set(builder::ContractionInput{"I"}.add_indices("d", 0, axes.size())))
.finalize());
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <vector>
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace op
{
class Replicate;
}
}
}
}
// Replicate works like Concat, but only over identical inputs. This
// restriction allows it to be substantially more efficient.
class ngraph::runtime::plaidml::op::Replicate final : public ngraph::op::Op
{
public:
Replicate(std::shared_ptr<Node> arg,
std::size_t replication_axis,
std::size_t replication_count);
Replicate(std::shared_ptr<Node> arg, std::vector<std::size_t> replication_axes);
void validate_and_infer_types() final;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const final;
/// \return The replication axes: axis index -> the replication count along that axis.
const std::vector<std::size_t>& get_replication_axes() const { return m_replication_axes; }
private:
std::vector<std::size_t> m_replication_axes;
};
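As a rough usage sketch (editorial, not part of this change set), the two constructors declared above are interchangeable; for an assumed {4, 3} input, replicating axis 0 twice yields an {8, 3} output:
// Hypothetical sketch -- `arg`, shapes, and the helper name are assumptions.
#include <memory>
#include <vector>
#include "ngraph/op/parameter.hpp"
#include "ngraph/runtime/plaidml/plaidml_ops_replicate.hpp"
void replicate_sketch()
{
    auto arg = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{4, 3});
    // Single-axis form: replicate axis 0 twice.
    auto rep_a = std::make_shared<ngraph::runtime::plaidml::op::Replicate>(arg, 0, 2);
    // Per-axis form: one replication count per input dimension ({2, 1} here).
    auto rep_b = std::make_shared<ngraph::runtime::plaidml::op::Replicate>(
        arg, std::vector<std::size_t>{2, 1});
    // Either way, validate_and_infer_types() infers an output shape of {8, 3}.
}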
...@@ -25,10 +25,14 @@ namespace ngraph ...@@ -25,10 +25,14 @@ namespace ngraph
{ {
namespace plaidml namespace plaidml
{ {
// Reverse reverses the selected axes within a tensor. NGRAPH_PLAIDML_OP_CLASS(ImplReverse, OpImpl<op::Reverse>);
template <> }
void Impl<op::Reverse>::operator()() }
{ }
// Reverse reverses the selected axes within a tensor.
void ngraph::runtime::plaidml::ImplReverse::Apply()
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
...@@ -41,8 +45,8 @@ namespace ngraph ...@@ -41,8 +45,8 @@ namespace ngraph
.set(builder::ContractionOutput{"O"} .set(builder::ContractionOutput{"O"}
.add_indices("d", 0, shape.size()) .add_indices("d", 0, shape.size())
.add_dims("D", 0, shape.size())) .add_dims("D", 0, shape.size()))
.set(builder::ContractionInput{"I"}.add_indices([&]( .set(builder::ContractionInput{"I"}.add_indices(
std::back_insert_iterator<std::list<std::string>> out) { [&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < shape.size(); ++idx) for (std::size_t idx = 0; idx < shape.size(); ++idx)
{ {
auto sidx = std::to_string(idx); auto sidx = std::to_string(idx);
...@@ -57,12 +61,4 @@ namespace ngraph ...@@ -57,12 +61,4 @@ namespace ngraph
} }
}))) })))
.finalize()); .finalize());
}
namespace
{
Impl<op::Reverse>::Registration register_reverse;
}
}
}
} }
...@@ -24,10 +24,14 @@ namespace ngraph ...@@ -24,10 +24,14 @@ namespace ngraph
{ {
namespace plaidml namespace plaidml
{ {
// Slice takes a sub-slice of a tensor. NGRAPH_PLAIDML_OP_CLASS(ImplSlice, OpImpl<op::Slice>);
template <> }
void Impl<op::Slice>::operator()() }
{ }
// Slice takes a sub-slice of a tensor.
void ngraph::runtime::plaidml::ImplSlice::Apply()
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
NGRAPH_DEBUG << "Slice: low: " << op().get_lower_bounds(); NGRAPH_DEBUG << "Slice: low: " << op().get_lower_bounds();
...@@ -39,21 +43,17 @@ namespace ngraph ...@@ -39,21 +43,17 @@ namespace ngraph
start_tile_function() start_tile_function()
.add(builder::Input{op_input(), "I"}.add_dims("ID", 0, dim_limit)) .add(builder::Input{op_input(), "I"}.add_dims("ID", 0, dim_limit))
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add( .add(builder::UnaryContraction{"="}
builder::UnaryContraction{"="} .set(builder::ContractionOutput{"O"}
.set(
builder::ContractionOutput{"O"}
.add_indices("od", 0, dim_limit) .add_indices("od", 0, dim_limit)
.add_dims([&]( .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_limit; ++idx) for (std::size_t idx = 0; idx < dim_limit; ++idx)
{ {
std::ostringstream s; std::ostringstream s;
std::size_t stride = op().get_strides()[idx]; std::size_t stride = op().get_strides()[idx];
std::ptrdiff_t trim_count = std::ptrdiff_t trim_count =
op().get_lower_bounds()[idx] + op().get_lower_bounds()[idx] +
(shape[idx] - op().get_upper_bounds()[idx]) + (shape[idx] - op().get_upper_bounds()[idx]) + 1 - stride;
1 - stride;
if ((stride != 1) && trim_count) if ((stride != 1) && trim_count)
{ {
s << "("; s << "(";
...@@ -106,12 +106,4 @@ namespace ngraph ...@@ -106,12 +106,4 @@ namespace ngraph
} }
}))) })))
.finalize()); .finalize());
}
namespace
{
Impl<op::Slice>::Registration register_slice;
}
}
}
} }
...@@ -25,10 +25,14 @@ namespace ngraph ...@@ -25,10 +25,14 @@ namespace ngraph
{ {
namespace plaidml namespace plaidml
{ {
// Softmax implements a standard ML softmax operation. NGRAPH_PLAIDML_OP_CLASS(ImplSoftmax, OpImpl<op::Softmax>);
template <> }
void Impl<op::Softmax>::operator()() }
{ }
// Softmax implements a standard ML softmax operation.
void ngraph::runtime::plaidml::ImplSoftmax::Apply()
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
...@@ -36,8 +40,7 @@ namespace ngraph ...@@ -36,8 +40,7 @@ namespace ngraph
auto dim_limit = shape.size(); auto dim_limit = shape.size();
auto f = start_tile_function(); auto f = start_tile_function();
f.add(builder::Input{op_input(0), "I"}.add_dims("D", 0, dim_limit)) f.add(builder::Input{op_input(0), "I"}.add_dims("D", 0, dim_limit)).add(builder::Output{"O"});
.add(builder::Output{"O"});
bool reorder_needed = false; bool reorder_needed = false;
bool saw_element = false; bool saw_element = false;
...@@ -78,8 +81,7 @@ namespace ngraph ...@@ -78,8 +81,7 @@ namespace ngraph
{ {
f.add(builder::UnaryContraction{"="} f.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"RI"} .set(builder::ContractionOutput{"RI"}
.add_dims([&]( .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
std::back_insert_iterator<std::list<std::string>> out) {
for (auto idx : group_idxs) for (auto idx : group_idxs)
{ {
out = "D" + std::to_string(idx); out = "D" + std::to_string(idx);
...@@ -89,8 +91,7 @@ namespace ngraph ...@@ -89,8 +91,7 @@ namespace ngraph
out = "D" + std::to_string(idx); out = "D" + std::to_string(idx);
} }
}) })
.add_indices([&]( .add_indices([&](std::back_insert_iterator<std::list<std::string>> out) {
std::back_insert_iterator<std::list<std::string>> out) {
for (auto idx : group_idxs) for (auto idx : group_idxs)
{ {
out = "d" + std::to_string(idx); out = "d" + std::to_string(idx);
...@@ -126,8 +127,7 @@ namespace ngraph ...@@ -126,8 +127,7 @@ namespace ngraph
{ {
// Take the softmax. // Take the softmax.
std::ostringstream softmax; std::ostringstream softmax;
softmax << "builtin_softmax(" << input << ", " << groups << ", " << elements softmax << "builtin_softmax(" << input << ", " << groups << ", " << elements << ")";
<< ")";
f.add(builder::Elementwise{output, softmax.str()}); f.add(builder::Elementwise{output, softmax.str()});
} }
...@@ -169,12 +169,4 @@ namespace ngraph ...@@ -169,12 +169,4 @@ namespace ngraph
} }
set_output(f.finalize()); set_output(f.finalize());
}
namespace
{
Impl<op::Softmax>::Registration register_softmax;
}
}
}
} }
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <utility>
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_ops_tile.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
NGRAPH_PLAIDML_OP_CLASS(ImplTile, OpImpl<op::Tile>);
}
}
}
ngraph::runtime::plaidml::op::Tile::Tile(
const std::string& node_type,
vertexai::plaidml::function function,
const NodeVector& args,
std::vector<std::tuple<element::Type, PartialShape>> outputs)
: Node{node_type, args, outputs.size()}
, m_function{std::move(function)}
, m_output_shapes{std::move(outputs)}
{
constructor_validate_and_infer_types();
}
void ngraph::runtime::plaidml::op::Tile::validate_and_infer_types()
{
// TODO: It would be useful to have PlaidML deduce the output
// shapes, instead of having them passed in via the
// constructor. The primary barrier to doing so is that
// PlaidML placeholders always have a fixed number of
// dimensions but arbitrary dimension sizes, and the only way
// to pin them down to a concrete dimension size is to bind a
// tensor to them, which requires actually allocating the
    // tensor. In principle, we could fix this pretty easily;
// we'll need to know more about where the PlaidML API is
// going before doing so, though.
if (get_input_size() != m_function.num_inputs())
{
throw ngraph_error{"Incorrect input count for Tile operation node"};
}
if (m_output_shapes.size() != m_function.num_outputs())
{
throw ngraph_error{"Incorrect output count for Tile operation node"};
}
std::size_t idx = 0;
for (auto& output_shape : m_output_shapes)
{
set_output_type(idx++, std::get<0>(output_shape), std::get<1>(output_shape));
}
}
std::shared_ptr<ngraph::Node>
ngraph::runtime::plaidml::op::Tile::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != get_input_size())
{
throw ngraph_error{"Tile node input counts cannot be changed for a given Tile function"};
}
return std::make_shared<Tile>(description(), m_function, new_args, m_output_shapes);
}
void ngraph::runtime::plaidml::ImplTile::Apply()
{
vertexai::plaidml::function::positional_t inputs;
for (std::size_t idx = 0; idx < op().get_input_size(); ++idx)
{
inputs.emplace_back(op_input(idx));
}
auto app = op().func().apply(inputs);
for (std::size_t idx = 0; idx < op().get_output_size(); ++idx)
{
set_output(idx, app.get_output(idx));
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <tuple>
#include <vector>
#include <plaidml/plaidml++.h>
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace op
{
/// An op directly representing PlaidML Tile code.
class Tile;
}
}
}
}
class ngraph::runtime::plaidml::op::Tile final : public Node
{
public:
Tile(const std::string& node_type,
vertexai::plaidml::function function,
const NodeVector& args,
std::vector<std::tuple<element::Type, PartialShape>> outputs);
void validate_and_infer_types() final;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const final;
vertexai::plaidml::function func() const { return m_function; }
private:
vertexai::plaidml::function m_function;
std::vector<std::tuple<element::Type, PartialShape>> m_output_shapes;
};
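A minimal sketch of constructing such a node (editorial; the node name, Tile code string, and shapes are assumptions, and the plaidml++ function-from-code constructor is assumed to be available as used elsewhere in this backend):
#include <memory>
#include <tuple>
#include <vector>
#include <plaidml/plaidml++.h>
#include "ngraph/op/parameter.hpp"
#include "ngraph/runtime/plaidml/plaidml_ops_tile.hpp"
void tile_sketch()
{
    namespace vp = vertexai::plaidml;
    auto x = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::Shape{4});
    // Elementwise doubling expressed directly in Tile code: one input, one output.
    vp::function doubler{"function (I[N]) -> (O) { O = 2 * I; }"};
    auto tile = std::make_shared<ngraph::runtime::plaidml::op::Tile>(
        "DoubleTile",
        doubler,
        ngraph::NodeVector{x},
        std::vector<std::tuple<ngraph::element::Type, ngraph::PartialShape>>{
            std::make_tuple(ngraph::element::f32, ngraph::PartialShape{4})});
    // The output element type and shape must be supplied explicitly, as noted in
    // validate_and_infer_types() above.
}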
...@@ -35,10 +35,26 @@ namespace ngraph ...@@ -35,10 +35,26 @@ namespace ngraph
{ {
namespace plaidml namespace plaidml
{ {
// acos performs a simple elementwise arccos function. NGRAPH_PLAIDML_OP_CLASS(ImplAcos, OpImpl<op::Acos>);
template <> NGRAPH_PLAIDML_OP_CLASS(ImplAsin, OpImpl<op::Asin>);
void Impl<op::Acos>::operator()() NGRAPH_PLAIDML_OP_CLASS(ImplAtan, OpImpl<op::Atan>);
{ NGRAPH_PLAIDML_OP_CLASS(ImplCos, OpImpl<op::Cos>);
NGRAPH_PLAIDML_OP_CLASS(ImplCosh, OpImpl<op::Cosh>);
NGRAPH_PLAIDML_OP_CLASS(ImplExp, OpImpl<op::Exp>);
NGRAPH_PLAIDML_OP_CLASS(ImplLog, OpImpl<op::Log>);
NGRAPH_PLAIDML_OP_CLASS(ImplPower, OpImpl<op::Power>);
NGRAPH_PLAIDML_OP_CLASS(ImplSin, OpImpl<op::Sin>);
NGRAPH_PLAIDML_OP_CLASS(ImplSinh, OpImpl<op::Sinh>);
NGRAPH_PLAIDML_OP_CLASS(ImplSqrt, OpImpl<op::Sqrt>);
NGRAPH_PLAIDML_OP_CLASS(ImplTan, OpImpl<op::Tan>);
NGRAPH_PLAIDML_OP_CLASS(ImplTanh, OpImpl<op::Tanh>);
}
}
}
// acos performs a simple elementwise arccos function.
void ngraph::runtime::plaidml::ImplAcos::Apply()
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -46,12 +62,11 @@ namespace ngraph ...@@ -46,12 +62,11 @@ namespace ngraph
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "acos(I)"}) .add(builder::Elementwise{"O", "acos(I)"})
.finalize()); .finalize());
} }
// asin performs a simple elementwise arcsin function. // asin performs a simple elementwise arcsin function.
template <> void ngraph::runtime::plaidml::ImplAsin::Apply()
void Impl<op::Asin>::operator()() {
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -59,12 +74,11 @@ namespace ngraph ...@@ -59,12 +74,11 @@ namespace ngraph
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "asin(I)"}) .add(builder::Elementwise{"O", "asin(I)"})
.finalize()); .finalize());
} }
// atan performs a simple elementwise arctan function. // atan performs a simple elementwise arctan function.
template <> void ngraph::runtime::plaidml::ImplAtan::Apply()
void Impl<op::Atan>::operator()() {
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -72,12 +86,11 @@ namespace ngraph ...@@ -72,12 +86,11 @@ namespace ngraph
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "atan(I)"}) .add(builder::Elementwise{"O", "atan(I)"})
.finalize()); .finalize());
} }
// cos performs a simple elementwise cos function. // cos performs a simple elementwise cos function.
template <> void ngraph::runtime::plaidml::ImplCos::Apply()
void Impl<op::Cos>::operator()() {
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -85,12 +98,11 @@ namespace ngraph ...@@ -85,12 +98,11 @@ namespace ngraph
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "cos(I)"}) .add(builder::Elementwise{"O", "cos(I)"})
.finalize()); .finalize());
} }
// cosh performs a simple elementwise hyperbolic cos function. // cosh performs a simple elementwise hyperbolic cos function.
template <> void ngraph::runtime::plaidml::ImplCosh::Apply()
void Impl<op::Cosh>::operator()() {
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -98,12 +110,11 @@ namespace ngraph ...@@ -98,12 +110,11 @@ namespace ngraph
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "cosh(I)"}) .add(builder::Elementwise{"O", "cosh(I)"})
.finalize()); .finalize());
} }
// exp performs a simple elementwise natural exponential function. // exp performs a simple elementwise natural exponential function.
template <> void ngraph::runtime::plaidml::ImplExp::Apply()
void Impl<op::Exp>::operator()() {
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -111,12 +122,11 @@ namespace ngraph ...@@ -111,12 +122,11 @@ namespace ngraph
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "exp(I)"}) .add(builder::Elementwise{"O", "exp(I)"})
.finalize()); .finalize());
} }
// log performs a simple elementwise natural logarithm function. // log performs a simple elementwise natural logarithm function.
template <> void ngraph::runtime::plaidml::ImplLog::Apply()
void Impl<op::Log>::operator()() {
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -124,12 +134,11 @@ namespace ngraph ...@@ -124,12 +134,11 @@ namespace ngraph
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "log(I)"}) .add(builder::Elementwise{"O", "log(I)"})
.finalize()); .finalize());
} }
// power performs a simple elementwise power function. // power performs a simple elementwise power function.
template <> void ngraph::runtime::plaidml::ImplPower::Apply()
void Impl<op::Power>::operator()() {
{
check_inputs(2); check_inputs(2);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -138,12 +147,11 @@ namespace ngraph ...@@ -138,12 +147,11 @@ namespace ngraph
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "pow(I, E)"}) .add(builder::Elementwise{"O", "pow(I, E)"})
.finalize()); .finalize());
} }
// sin performs a simple elementwise sin function. // sin performs a simple elementwise sin function.
template <> void ngraph::runtime::plaidml::ImplSin::Apply()
void Impl<op::Sin>::operator()() {
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -151,12 +159,11 @@ namespace ngraph ...@@ -151,12 +159,11 @@ namespace ngraph
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "sin(I)"}) .add(builder::Elementwise{"O", "sin(I)"})
.finalize()); .finalize());
} }
// sinh performs a simple elementwise hyperbolic sin function. // sinh performs a simple elementwise hyperbolic sin function.
template <> void ngraph::runtime::plaidml::ImplSinh::Apply()
void Impl<op::Sinh>::operator()() {
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -164,12 +171,11 @@ namespace ngraph ...@@ -164,12 +171,11 @@ namespace ngraph
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "sinh(I)"}) .add(builder::Elementwise{"O", "sinh(I)"})
.finalize()); .finalize());
} }
// sqrt performs a simple elementwise square root function. // sqrt performs a simple elementwise square root function.
template <> void ngraph::runtime::plaidml::ImplSqrt::Apply()
void Impl<op::Sqrt>::operator()() {
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -177,12 +183,11 @@ namespace ngraph ...@@ -177,12 +183,11 @@ namespace ngraph
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "sqrt(I)"}) .add(builder::Elementwise{"O", "sqrt(I)"})
.finalize()); .finalize());
} }
// tan performs a simple elementwise tangent function. // tan performs a simple elementwise tangent function.
template <> void ngraph::runtime::plaidml::ImplTan::Apply()
void Impl<op::Tan>::operator()() {
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -190,12 +195,11 @@ namespace ngraph ...@@ -190,12 +195,11 @@ namespace ngraph
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "tan(I)"}) .add(builder::Elementwise{"O", "tan(I)"})
.finalize()); .finalize());
} }
// tanh performs a simple elementwise hyperbolic tangent function. // tanh performs a simple elementwise hyperbolic tangent function.
template <> void ngraph::runtime::plaidml::ImplTanh::Apply()
void Impl<op::Tanh>::operator()() {
{
check_inputs(1); check_inputs(1);
check_outputs(1); check_outputs(1);
set_output(start_tile_function() set_output(start_tile_function()
...@@ -203,24 +207,4 @@ namespace ngraph ...@@ -203,24 +207,4 @@ namespace ngraph
.add(builder::Output{"O"}) .add(builder::Output{"O"})
.add(builder::Elementwise{"O", "tanh(I)"}) .add(builder::Elementwise{"O", "tanh(I)"})
.finalize()); .finalize());
}
namespace
{
Impl<op::Acos>::Registration register_acos;
Impl<op::Asin>::Registration register_asin;
Impl<op::Atan>::Registration register_atan;
Impl<op::Cos>::Registration register_cos;
Impl<op::Cosh>::Registration register_cosh;
Impl<op::Exp>::Registration register_exp;
Impl<op::Log>::Registration register_log;
Impl<op::Power>::Registration register_power;
Impl<op::Sin>::Registration register_sin;
Impl<op::Sinh>::Registration register_sinh;
Impl<op::Sqrt>::Registration register_sqrt;
Impl<op::Tan>::Registration register_tan;
Impl<op::Tanh>::Registration register_tanh;
}
}
}
} }
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/plaidml/plaidml_ops_winograd.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
namespace vp = vertexai::plaidml;
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
NGRAPH_PLAIDML_OP_CLASS(ImplWinograd, OpImpl<plaidml::op::Winograd>);
}
}
}
ngraph::runtime::plaidml::op::Winograd::Winograd(std::shared_ptr<plaidml::op::Convolution> conv,
const NodeVector& args)
: Op{"Winograd", args}
, m_conv{std::move(conv)}
{
constructor_validate_and_infer_types();
}
void ngraph::runtime::plaidml::op::Winograd::validate_and_infer_types()
{
set_output_type(0, m_conv->get_element_type(), m_conv->get_output_partial_shape(0));
}
std::shared_ptr<ngraph::Node>
ngraph::runtime::plaidml::op::Winograd::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 5)
{
throw ngraph_error{"Winograd requires five inputs (data, filters, A, B, and G)"};
}
return std::make_shared<Winograd>(m_conv, new_args);
}
void ngraph::runtime::plaidml::ImplWinograd::Apply()
{
check_inputs(5);
check_outputs(1);
const auto& data_shape = op().get_input_shape(0);
const auto& filters_shape = op().get_input_shape(1);
const auto& padding_above = op().get_conv()->get_src()->get_padding_above();
const auto& padding_below = op().get_conv()->get_src()->get_padding_below();
vp::variable xo(static_cast<int64_t>(data_shape.at(1) + padding_below.at(0) +
padding_above.at(0) - filters_shape.at(0) + 1));
vp::variable yo(static_cast<int64_t>(data_shape.at(2) + padding_below.at(1) +
padding_above.at(1) - filters_shape.at(1) + 1));
vp::variable xp(static_cast<int64_t>(padding_below.at(0)));
vp::variable yp(static_cast<int64_t>(padding_below.at(1)));
set_output(vp::function{R"(
function (I[N, X, Y, CI], K[S, S, CI, CO], A[BI, BO], B[BI, BI], G[BI, S], XO, YO, XP, YP) -> (O) {
Assert = assert_winograd_valid(BI - CI + 1 == BO);
XB = (XO + BO - 1) / BO;
YB = (YO + BO - 1) / BO;
U1[i, j, ci, co : BI, S, CI, CO] = +(G[i, k] * K[k, j, ci, co]);
U[i, j, ci, co : BI, BI, CI, CO] = +(U1[i, k, ci, co] * G[j, k]);
V1[n, i, j, x, y, ci : N, BI, BI, XB, YB, CI] = +(B[k, i] * I[n, BO*x + k - XP, BO*y + j - YP, ci]);
V[n, i, j, x, y, ci : N, BI, BI, XB, YB, CI] = +(V1[n, i, k, x, y, ci] * B[k, j]);
M[n, i, j, x, y, co : N, BI, BI, XB, YB, CO] = +(V[n, i, j, x, y, ci] * U[i, j, ci, co]);
O1[n, i, j, x, y, co : N, BO, BI, XB, YB, CO] = +(A[k, i] * M[n, k, j, x, y, co]);
O[n, BO*x + i, BO*y + j, co : N, XO, YO, CO] = +(O1[n, i, k, x, y, co] * A[k, j]) no_defract;
})"}(op_input(0), op_input(1), op_input(2), op_input(3), op_input(4), xo, yo, xp, yp));
}
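For orientation, the contractions in the Tile code above follow the usual Winograd factorization; this summary is editorial and only restates what the kernel computes:
//   U = G * K * G^T     filter transform (per input/output channel pair)
//   V = B^T * d * B     data transform (per input tile)
//   M = U .* V          elementwise product in the Winograd domain, summed over input channels
//   O = A^T * M * A     inverse transform back to spatial coordinates (no_defract)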
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/op/op.hpp"
#include "ngraph/runtime/plaidml/plaidml_ops_convolution.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace op
{
class Winograd;
}
}
}
}
class ngraph::runtime::plaidml::op::Winograd final : public ngraph::op::Op
{
public:
Winograd(std::shared_ptr<Convolution> conv, const NodeVector& args);
void validate_and_infer_types() final;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const final;
    const std::shared_ptr<Convolution>& get_conv() const { return m_conv; }
private:
std::shared_ptr<Convolution> m_conv;
};
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/plaidml/plaidml_pass_concat_elision.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/plaidml/plaidml_ops_replicate.hpp"
ngraph::runtime::plaidml::pass::ConcatElision::ConcatElision()
{
auto concat_op =
std::make_shared<pattern::op::Label>(element::i8, Shape{}, [](std::shared_ptr<Node> node) {
return dynamic_cast<ngraph::op::Concat*>(node.get()) != nullptr;
});
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto concat = std::dynamic_pointer_cast<ngraph::op::Concat>(m.get_match_root());
auto args = concat->get_arguments();
// Elide one-argument concats.
if (args.size() == 1)
{
replace_node(concat, args.at(0));
return true;
}
// Check for a run of inputs from the same source -- if we see
// one, we can simplify it, and otherwise, we already have the
// best Concat we can make.
{
bool found_input_run = false;
std::size_t prev_instance_id = concat->get_instance_id(); // This will never be an arg
for (const auto& arg : args)
{
auto current_instance_id = arg->get_instance_id();
if (current_instance_id == prev_instance_id)
{
found_input_run = true;
break;
}
prev_instance_id = current_instance_id;
}
if (!found_input_run)
{
return false;
}
}
// Merge runs with the same input into Replicate calls.
NodeVector new_args;
auto run_begin = args.begin();
// N.B. There's at least one argument to concat at this point
// (actually, two, but we only care that there's at least
// one), so run_end is still valid after this increment.
auto run_end = run_begin + 1;
for (;;)
{
// Invariants:
// * [run_begin..run_end) is a range of identical arguments
// * run_begin < run_end (there's at least one member of the range).
if (run_end == args.end() || *run_begin != *run_end)
{
// End of the range.
if (run_end - run_begin == 1)
{
new_args.emplace_back(*run_begin);
}
else
{
new_args.emplace_back(std::make_shared<plaidml::op::Replicate>(
*run_begin, concat->get_concatenation_axis(), run_end - run_begin));
}
if (run_end == args.end())
{
break;
}
run_begin = run_end;
}
++run_end;
}
// Re-check for single-input concat.
if (new_args.size() == 1)
{
replace_node(concat, new_args.at(0));
return true;
}
// Build a replacement concat.
auto new_concat =
std::make_shared<ngraph::op::Concat>(new_args, concat->get_concatenation_axis());
replace_node(std::move(concat), std::move(new_concat));
return true;
};
add_matcher(std::make_shared<pattern::Matcher>(concat_op, callback));
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace pass
{
class ConcatElision;
}
}
}
}
// A pass to elide unnecessary concats:
//
// ) Concats with a single input and single output can be entirely
// elided.
//
// ) Runs of identical adjacent inputs to a Concat can be replaced by
// a single Replicate of that input.
class ngraph::runtime::plaidml::pass::ConcatElision final : public ngraph::pass::GraphRewrite
{
public:
ConcatElision();
};
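// A minimal usage sketch. The graph and shapes below are illustrative, and the
// Function/ParameterVector construction assumes the stock nGraph API of this era.
#include <memory>
#include "ngraph/function.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_concat_elision.hpp"
static void concat_elision_example()
{
    using namespace ngraph;
    // Concatenate the same tensor three times along axis 0.
    auto arg = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
    auto concat = std::make_shared<op::Concat>(NodeVector{arg, arg, arg}, 0);
    auto f = std::make_shared<Function>(concat, ParameterVector{arg});
    pass::Manager pm;
    pm.register_pass<runtime::plaidml::pass::ConcatElision>();
    pm.run_passes(f);
    // The whole run of identical inputs collapses: the Concat is replaced by a
    // single plaidml::op::Replicate that repeats `arg` three times along axis 0.
}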
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <typeindex>
#include "ngraph/graph_util.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/plaidml/plaidml_ops_tile.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_explicit_logicals.hpp"
void ngraph::runtime::plaidml::pass::ExplicitLogicals::construct_logical_to_data()
{
auto producer_op =
std::make_shared<pattern::op::Label>(element::i8, Shape{}, [](std::shared_ptr<Node> node) {
static const std::unordered_set<std::type_index> logical_producers{
std::type_index{typeid(ngraph::op::And)},
std::type_index{typeid(ngraph::op::Equal)},
std::type_index{typeid(ngraph::op::Greater)},
std::type_index{typeid(ngraph::op::GreaterEq)},
std::type_index{typeid(ngraph::op::Less)},
std::type_index{typeid(ngraph::op::LessEq)},
std::type_index{typeid(ngraph::op::Not)},
std::type_index{typeid(ngraph::op::NotEqual)},
std::type_index{typeid(ngraph::op::Or)}};
const ngraph::Node* node_ptr = node.get();
// True iff this node produces a logical output.
return logical_producers.count(std::type_index(typeid(*node_ptr))) != 0;
});
auto data_consumer_op = std::make_shared<pattern::op::Any>(
element::i8,
Shape{},
[](std::shared_ptr<Node> node) {
static const std::unordered_set<std::type_index> logical_consumers{
std::type_index{typeid(ngraph::op::And)},
std::type_index{typeid(ngraph::op::Equal)},
std::type_index{typeid(ngraph::op::Not)},
std::type_index{typeid(ngraph::op::NotEqual)},
std::type_index{typeid(ngraph::op::Or)}};
const ngraph::Node* node_ptr = node.get();
// True iff this node should not be presented with a logical output.
return logical_consumers.count(std::type_index(typeid(*node_ptr))) == 0;
},
NodeVector{producer_op});
pattern::graph_rewrite_callback callback = [producer_op](pattern::Matcher& m) {
auto consumer = m.get_match_root();
auto producer = m.get_pattern_map()[producer_op];
NGRAPH_DEBUG << "Adding conversion for " << producer->description() << " -> "
<< consumer->description();
ngraph::insert_new_node_between(
producer,
consumer,
std::make_shared<op::Tile>(
"ConvertLogicalToData",
vertexai::plaidml::function{"function (I) -> (O) { O = as_int(I ? 1 : 0, 8);}"},
NodeVector{producer},
std::vector<std::tuple<element::Type, PartialShape>>{
{element::i8, PartialShape{producer->get_output_shape(0)}}}));
return true;
};
auto m = std::make_shared<pattern::Matcher>(data_consumer_op, callback);
add_matcher(m);
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace pass
{
class ExplicitLogicals;
}
}
}
}
// ExplicitLogicals handles conversion between logical and binary
// values.
//
// In PlaidML, the logical values do not have well-defined binary
// representations -- for example, due to how some frameworks handle
// vectorization, 'true' might be represented as a binary '1' or as a
// binary '-1', even within the same kernel.
//
// nGraph semantics are that 'false' == '0' and 'true' == '1', that
// booleans are exactly equivalent to binary uint8 values, and that
// binary uint8 values can be passed directly into logical operations.
//
// The ExplicitLogicals pass inserts conversions as needed to preserve
// the semantics expected by the other nGraph operations.
class ngraph::runtime::plaidml::pass::ExplicitLogicals final : public ngraph::pass::GraphRewrite
{
public:
ExplicitLogicals()
: GraphRewrite()
{
construct_logical_to_data();
}
private:
void construct_logical_to_data();
};
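// A small illustration of where this pass inserts a conversion (the node names
// are illustrative):
//
//   Equal(a, b) --> And(...)   // the consumer is itself logical: no conversion
//   Equal(a, b) --> Add(...)   // the consumer expects data: the pass splices in a
//                              // Tile op computing as_int(I ? 1 : 0, 8), so the Add
//                              // sees canonical 0/1 eight-bit values.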
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/op/label.hpp"
ngraph::runtime::plaidml::pass::ImplicitBroadcast::ImplicitBroadcast()
{
auto src_op = std::make_shared<pattern::op::Label>(
element::i8, Shape{}, [](std::shared_ptr<Node>) { return true; });
auto broadcast_op = std::make_shared<op::Broadcast>(src_op, Shape{}, AxisSet{});
auto target_op = std::make_shared<pattern::op::AnyOf>(
element::i8,
Shape{},
[](std::shared_ptr<Node> node) {
return pattern::has_class<op::util::UnaryElementwiseArithmetic>()(node) ||
pattern::has_class<op::util::BinaryElementwiseArithmetic>()(node);
},
NodeVector{broadcast_op});
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
// Since the broadcast is going to an elementwise operation, we
// can always replace it with an equivalent reshape that uses ones
// for the broadcast axes.
auto src = m.get_matched_nodes().at(2);
Shape src_shape = src->get_shape();
auto broadcast = std::dynamic_pointer_cast<op::Broadcast>(m.get_matched_nodes().at(1));
AxisVector reshape_order;
Shape reshape_shape;
std::size_t input_dim = 0;
std::size_t didx_limit = broadcast->get_broadcast_shape().size();
for (std::size_t didx = 0; didx < didx_limit; ++didx)
{
if (broadcast->get_broadcast_axes().count(didx))
{
reshape_shape.emplace_back(1);
}
else
{
reshape_order.emplace_back(input_dim);
reshape_shape.emplace_back(src_shape.at(input_dim++));
}
}
auto reshape = std::make_shared<op::Reshape>(src, reshape_order, reshape_shape);
replace_node(broadcast, reshape);
return true;
};
add_matcher(std::make_shared<pattern::Matcher>(target_op, callback));
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace pass
{
class ImplicitBroadcast;
}
}
}
}
// The PlaidML nGraph runtime's implementation of the Broadcast
// operation requires a contraction, and then the broadcasted output
// needs to be read by the downstream operation.
//
// Most explicit Broadcast operations are passed as inputs to
// elementwise operations. When a tensor is used as an input to an
// elementwise operation, PlaidML automatically provides NumPy
// broadcast semantics.
//
// So eliding Broadcast operations can significantly reduce the IO
// needed by an elementwise operation and eliminate an unnecessary
// contraction.
class ngraph::runtime::plaidml::pass::ImplicitBroadcast final : public ngraph::pass::GraphRewrite
{
public:
ImplicitBroadcast();
};
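// A worked example of the rewrite (the shapes are illustrative): a source tensor S
// of shape {3}, broadcast to {2, 3, 4} with broadcast axes {0, 2} and fed into an
// elementwise Add, has its Broadcast replaced by Reshape(S, order={0}, shape={1, 3, 1}).
// The Add then relies on PlaidML's implicit NumPy-style broadcasting to expand the
// size-1 axes, so no separate contraction is emitted for the broadcast.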
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <numeric>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/plaidml/plaidml_ops_convolution.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_lower_convolutions.hpp"
ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
{
auto convolution_op =
std::make_shared<pattern::op::Label>(element::i8, Shape{}, [](std::shared_ptr<Node> node) {
return pattern::has_class<ngraph::op::Convolution>()(node) ||
pattern::has_class<ngraph::op::ConvolutionBackpropData>()(node) ||
pattern::has_class<ngraph::op::ConvolutionBackpropFilters>()(node);
});
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto to_transpose = [](const std::shared_ptr<Node>& node) -> ngraph::op::Reshape* {
if (!node)
{
return nullptr;
}
auto* reshape = dynamic_cast<ngraph::op::Reshape*>(node.get());
if (reshape && reshape->get_is_transpose())
{
return reshape;
}
return nullptr;
};
auto to_axes = [](const std::shared_ptr<Node>& node, ngraph::op::Reshape* reshape) {
if (reshape)
{
return reshape->get_input_order();
}
AxisVector result(node->get_shape().size());
std::iota(result.begin(), result.end(), 0);
return result;
};
std::shared_ptr<Node> node = m.get_match_root();
std::shared_ptr<Node> output;
auto users = node->get_users(true);
if (users.size() == 1)
{
output = users[0];
}
auto target = node;
auto* output_transpose = to_transpose(output);
if (output_transpose)
{
target = output;
}
// N.B. For the output axes, we can either use the convolution
// or the final output op -- but there might not be an output
// op. Using target always works.
AxisVector out_axes = to_axes(target, output_transpose);
auto lhs = node->get_arguments().at(0);
auto* lhs_transpose = to_transpose(lhs);
if (lhs_transpose)
{
lhs = lhs_transpose->get_arguments().at(0);
}
AxisVector lhs_axes = to_axes(lhs, lhs_transpose);
auto rhs = node->get_arguments().at(1);
auto* rhs_transpose = to_transpose(rhs);
if (rhs_transpose)
{
rhs = rhs_transpose->get_arguments().at(0);
}
AxisVector rhs_axes = to_axes(rhs, rhs_transpose);
{
auto conv = std::dynamic_pointer_cast<ngraph::op::Convolution>(node);
if (conv)
{
replace_node(target,
std::make_shared<plaidml::op::Convolution>(conv,
NodeVector{lhs, rhs},
std::move(lhs_axes),
std::move(rhs_axes),
std::move(out_axes)));
return true;
}
}
{
auto conv_bp_data =
std::dynamic_pointer_cast<ngraph::op::ConvolutionBackpropData>(node);
if (conv_bp_data)
{
replace_node(
target,
std::make_shared<plaidml::op::ConvolutionBackpropData>(conv_bp_data,
NodeVector{lhs, rhs},
std::move(lhs_axes),
std::move(rhs_axes),
std::move(out_axes)));
return true;
}
}
{
auto conv_bp_filters =
std::dynamic_pointer_cast<ngraph::op::ConvolutionBackpropFilters>(node);
if (conv_bp_filters)
{
replace_node(
target,
std::make_shared<plaidml::op::ConvolutionBackpropFilters>(conv_bp_filters,
NodeVector{lhs, rhs},
std::move(lhs_axes),
std::move(rhs_axes),
std::move(out_axes)));
return true;
}
}
return false;
};
add_matcher(std::make_shared<pattern::Matcher>(convolution_op, callback));
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace pass
{
class LowerConvolutions;
}
}
}
}
// Lowers ngraph::op::Convolution, ngraph::op::ConvolutionBackpropData, and
// ngraph::op::ConvolutionBackpropFilters into their PlaidML counterparts
// (plaidml::op::Convolution, ConvolutionBackpropData, and
// ConvolutionBackpropFilters), allowing adjacent transposition operations to
// be folded into the lowered convolution.
class ngraph::runtime::plaidml::pass::LowerConvolutions final : public ngraph::pass::GraphRewrite
{
public:
LowerConvolutions();
};
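// A worked example of the unification (the axis orders are illustrative):
//
//   data --> Reshape(order={0, 3, 1, 2}) --\
//                                           Convolution --> Reshape(order={0, 2, 3, 1})
//   filt --> Reshape(order={3, 2, 0, 1}) --/
//
// When the surrounding Reshapes are pure transposes (and the trailing Reshape is the
// convolution's only user), the pass replaces the whole pattern with a single
// plaidml::op::Convolution built directly from `data` and `filt`, carrying the three
// axis orders as its lhs/rhs/output axis vectors, so no transposed intermediate
// tensors are materialized.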
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_combination.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/plaidml/plaidml_ops_replicate.hpp"
ngraph::runtime::plaidml::pass::ReplicateCombination::ReplicateCombination()
{
auto upper_replicate_op =
std::make_shared<pattern::op::Label>(element::i8, Shape{}, [](std::shared_ptr<Node> node) {
return pattern::has_class<plaidml::op::Replicate>()(node);
});
auto lower_replicate_op = std::make_shared<pattern::op::Any>(
element::i8,
Shape{},
[](std::shared_ptr<Node> node) {
return pattern::has_class<plaidml::op::Replicate>()(node);
},
NodeVector{upper_replicate_op});
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto nodes = m.get_matched_nodes();
auto lower = std::dynamic_pointer_cast<plaidml::op::Replicate>(nodes.at(0));
auto upper = std::dynamic_pointer_cast<plaidml::op::Replicate>(nodes.at(1));
std::vector<size_t> axes = lower->get_replication_axes();
const std::vector<size_t>& upper_axes = upper->get_replication_axes();
auto uit = upper_axes.begin();
for (auto ait = axes.begin(); ait != axes.end(); ++ait, ++uit)
{
*ait *= *uit;
}
replace_node(lower,
std::make_shared<plaidml::op::Replicate>(upper->get_arguments().at(0),
std::move(axes)));
return true;
};
add_matcher(std::make_shared<pattern::Matcher>(lower_replicate_op, callback));
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace pass
{
class ReplicateCombination;
}
}
}
}
// Combines adjacent Replicate operations.
class ngraph::runtime::plaidml::pass::ReplicateCombination final : public ngraph::pass::GraphRewrite
{
public:
ReplicateCombination();
};
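// A worked example (the counts are illustrative): a Replicate that doubles axis 0
// of a tensor, feeding a second Replicate that triples axis 0 of the result, is
// fused into one Replicate of the original tensor whose per-axis replication counts
// are the elementwise product -- six copies along axis 0, one along every other axis.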
...@@ -130,12 +130,3 @@ std::string ngraph::runtime::plaidml::tile_converter(const std::string& tensor_n ...@@ -130,12 +130,3 @@ std::string ngraph::runtime::plaidml::tile_converter(const std::string& tensor_n
} }
return tile_converter(tensor_name, to_plaidml(element_type)); return tile_converter(tensor_name, to_plaidml(element_type));
} }
vp::variable ngraph::runtime::plaidml::plaidml_logical_to_data(vp::variable var, bool debug)
{
return builder::Function{"logicalToData", debug}
.add(builder::Input{var, "I"})
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "as_int(I ? 1 : 0, 8)"})
.finalize();
}
...@@ -46,9 +46,6 @@ namespace ngraph ...@@ -46,9 +46,6 @@ namespace ngraph
std::string tile_converter(const std::string& tensor_name, std::string tile_converter(const std::string& tensor_name,
const ngraph::element::Type& element_type); const ngraph::element::Type& element_type);
vertexai::plaidml::variable plaidml_logical_to_data(vertexai::plaidml::variable var,
bool debug);
} }
} }
} }