Commit f5b2d581 authored by Rob Earhart, committed by Scott Cyphers

API cleanup & performance passes (#2242)

* nGraph compat updates

* Simplify op structure

* Add zero-dim elimination pass

* Add logical data conversion pass

* Add graphviz support

* Add implicit broadcast pass

* Elide unnecessary reshape ops

* Merge reshapes into convolutions

* Add winograd convolution support

* Allow three-input maxpool backprop

* Add concat elision

* Combine replication operations

* Elide unneeded replicates

* Style update
parent c11644ec
@@ -41,10 +41,21 @@ set(SRC
     plaidml_ops_pool.cpp
     plaidml_ops_reduce.cpp
     plaidml_ops_replace_slice.cpp
+    plaidml_ops_replicate.cpp
     plaidml_ops_reverse.cpp
     plaidml_ops_slice.cpp
     plaidml_ops_softmax.cpp
+    plaidml_ops_tile.cpp
     plaidml_ops_transcendental.cpp
+    plaidml_ops_winograd.cpp
+    plaidml_pass_concat_elision.cpp
+    plaidml_pass_explicit_logicals.cpp
+    plaidml_pass_implicit_broadcast.cpp
+    plaidml_pass_lower_convolutions.cpp
+    plaidml_pass_replicate_combination.cpp
+    plaidml_pass_replicate_elision.cpp
+    plaidml_pass_reshape_elision.cpp
+    plaidml_pass_winograd.cpp
     plaidml_tensor.cpp
     plaidml_translate.cpp
 )
@@ -41,10 +41,11 @@ std::shared_ptr<ngraph::runtime::Tensor> ngraph::runtime::plaidml::PlaidML_Backe
         &m_config, element_type, shape, "direct_data", memory_pointer);
 }

-bool ngraph::runtime::plaidml::PlaidML_Backend::compile(std::shared_ptr<Function> func)
+std::shared_ptr<ngraph::Function>
+    ngraph::runtime::plaidml::PlaidML_Backend::compile(std::shared_ptr<Function> func)
 {
     m_cache.compile(func, &m_compiler);
-    return true;
+    return func;
 }

 bool ngraph::runtime::plaidml::PlaidML_Backend::call(
@@ -46,7 +46,7 @@ public:
     std::shared_ptr<ngraph::runtime::Tensor> create_tensor(
         const ngraph::element::Type& element_type, const Shape& shape, void* memory_pointer) final;

-    bool compile(std::shared_ptr<Function> func) final;
+    std::shared_ptr<Function> compile(std::shared_ptr<Function> func) final;

     bool call(std::shared_ptr<Function> func,
               const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
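A quick caller-side sketch of the revised interface (a hedged illustration, not code from this commit; the function `f` and the tensors are assumed to exist already): compile() now hands back the Function it was given, and that handle is what call() expects.

    // Minimal sketch, assuming an ngraph::Function `f` plus input/output tensors.
    auto backend = ngraph::runtime::Backend::create("PlaidML");
    std::shared_ptr<ngraph::Function> handle = backend->compile(f);   // previously returned bool
    backend->call(handle, {result_tensor}, {input_tensor});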
@@ -18,6 +18,7 @@
 #include <stdexcept>
 #include <utility>

+#include "ngraph/except.hpp"
 #include "ngraph/runtime/plaidml/plaidml_builder.hpp"
 #include "ngraph/runtime/plaidml/plaidml_logger.hpp"
@@ -467,6 +468,24 @@ ngraph::runtime::plaidml::builder::Input&&
     return std::move(*this);
 }

+void ngraph::runtime::plaidml::builder::Input::apply_transpose(const AxisVector& axes)
+{
+    if (axes.size() != m_dims.size())
+    {
+        throw ngraph_error{"Mismatched shape in input transposition"};
+    }
+    std::size_t idx = 0;
+    std::vector<std::string> dims(m_dims.size());
+    for (auto dim : m_dims)
+    {
+        dims[axes[idx]] = dim;
+        idx++;
+    }
+    m_dims.clear();
+    m_dims.insert(m_dims.end(), dims.begin(), dims.end());
+}
+
 ngraph::runtime::plaidml::builder::Output::Output(std::string name)
     : m_name{std::move(name)}
 {
@@ -611,6 +630,28 @@ ngraph::runtime::plaidml::builder::ContractionOutput&&
     return std::move(*this);
 }

+void ngraph::runtime::plaidml::builder::ContractionOutput::apply_transpose(const AxisVector& axes)
+{
+    if (axes.size() != m_dims.size() || axes.size() != m_indices.size())
+    {
+        throw ngraph_error{"Mismatched shape in contraction output transposition"};
+    }
+    std::vector<std::string> dims{m_dims.begin(), m_dims.end()};
+    m_dims.clear();
+    for (auto idx : axes)
+    {
+        m_dims.emplace_back(dims[idx]);
+    }
+    std::vector<std::string> indices{m_indices.begin(), m_indices.end()};
+    m_indices.clear();
+    for (auto idx : axes)
+    {
+        m_indices.emplace_back(indices[idx]);
+    }
+}
+
 ngraph::runtime::plaidml::builder::ContractionInput&
     ngraph::runtime::plaidml::builder::ContractionInput::add_indices(std::string prefix,
                                                                      std::size_t first,
@@ -675,6 +716,24 @@ ngraph::runtime::plaidml::builder::ContractionInput&&
     return std::move(*this);
 }

+void ngraph::runtime::plaidml::builder::ContractionInput::apply_transpose(const AxisVector& axes)
+{
+    if (axes.size() != m_indices.size())
+    {
+        throw ngraph_error{"Mismatched shape in contraction input transposition"};
+    }
+    std::size_t idx = 0;
+    std::vector<std::string> indices(m_indices.size());
+    for (auto dim : m_indices)
+    {
+        indices[axes[idx]] = dim;
+        idx++;
+    }
+    m_indices.clear();
+    m_indices.insert(m_indices.end(), indices.begin(), indices.end());
+}
+
 ngraph::runtime::plaidml::builder::UnaryContraction::UnaryContraction(std::string agg_op)
     : m_agg_op{std::move(agg_op)}
 {
@@ -24,6 +24,7 @@
 #include <string>
 #include <utility>

+#include "ngraph/axis_vector.hpp"
 #include "ngraph/runtime/plaidml/plaidml_config.hpp"

 // Utilities for constructing PlaidML functions.
@@ -136,9 +137,22 @@ public:
         return std::move(*this);
     }

+    Input& transpose(const AxisVector& axes) &
+    {
+        apply_transpose(axes);
+        return *this;
+    }
+
+    Input&& transpose(const AxisVector& axes) &&
+    {
+        apply_transpose(axes);
+        return std::move(*this);
+    }
+
 private:
     friend class Function;

+    void apply_transpose(const AxisVector& axes);
+
     vertexai::plaidml::variable m_var;
     std::string m_name;
     std::list<std::string> m_dims;
@@ -230,9 +244,22 @@ public:
         return std::move(*this);
     }

+    ContractionOutput& transpose(const AxisVector& axes) &
+    {
+        apply_transpose(axes);
+        return *this;
+    }
+
+    ContractionOutput&& transpose(const AxisVector& axes) &&
+    {
+        apply_transpose(axes);
+        return std::move(*this);
+    }
+
 private:
     friend class Function;

+    void apply_transpose(const AxisVector& axes);
+
     std::string m_name;
     std::list<std::string> m_indices;
     std::list<std::string> m_dims;
@@ -268,9 +295,22 @@ public:
         return std::move(*this);
     }

+    ContractionInput& transpose(const AxisVector& axes) &
+    {
+        apply_transpose(axes);
+        return *this;
+    }
+
+    ContractionInput&& transpose(const AxisVector& axes) &&
+    {
+        apply_transpose(axes);
+        return std::move(*this);
+    }
+
 private:
     friend class Function;

+    void apply_transpose(const AxisVector& axes);
+
     std::string m_name;
     std::list<std::string> m_indices;
 };
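The paired `&`/`&&` overloads added above are the usual ref-qualified builder idiom: chaining keeps working whether transpose() is invoked on a named builder or on a temporary. A minimal self-contained sketch of the same pattern, using a hypothetical Widget type that is not part of this commit:

    #include <utility>

    struct Widget
    {
        Widget& set(int v) &        // lvalue: mutate in place, return a reference for chaining
        {
            value = v;
            return *this;
        }
        Widget&& set(int v) &&      // rvalue: mutate the temporary and move it onward
        {
            value = v;
            return std::move(*this);
        }
        int value = 0;
    };

    // Usage: both forms chain without copying or dangling references:
    //   Widget w; w.set(1).set(2);
    //   auto made = Widget{}.set(3);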
@@ -15,6 +15,7 @@
 //*****************************************************************************

 #include "ngraph/runtime/plaidml/plaidml_compiler.hpp"
+#include "ngraph/graph_util.hpp"
 #include "ngraph/log.hpp"
 #include "ngraph/pass/algebraic_simplification.hpp"
 #include "ngraph/pass/any_all_replacement.hpp"
@@ -25,8 +26,18 @@
 #include "ngraph/pass/liveness.hpp"
 #include "ngraph/pass/manager.hpp"
 #include "ngraph/pass/nop_elimination.hpp"
+#include "ngraph/pass/visualize_tree.hpp"
+#include "ngraph/pass/zero_dim_tensor_elimination.hpp"
 #include "ngraph/runtime/plaidml/plaidml_impl.hpp"
 #include "ngraph/runtime/plaidml/plaidml_logger.hpp"
+#include "ngraph/runtime/plaidml/plaidml_pass_concat_elision.hpp"
+#include "ngraph/runtime/plaidml/plaidml_pass_explicit_logicals.hpp"
+#include "ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.hpp"
+#include "ngraph/runtime/plaidml/plaidml_pass_lower_convolutions.hpp"
+#include "ngraph/runtime/plaidml/plaidml_pass_replicate_combination.hpp"
+#include "ngraph/runtime/plaidml/plaidml_pass_replicate_elision.hpp"
+#include "ngraph/runtime/plaidml/plaidml_pass_reshape_elision.hpp"
+#include "ngraph/runtime/plaidml/plaidml_pass_winograd.hpp"

 namespace
 {
@@ -78,15 +89,44 @@ std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction>
     pass_manager.register_pass<ngraph::pass::AnyAllReplacement>();
     pass_manager.register_pass<ngraph::pass::LikeReplacement>();
     pass_manager.register_pass<ngraph::pass::NopElimination>();
+    pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
     pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
     pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
     pass_manager.register_pass<ngraph::pass::CoreFusion>();
     // N.B. We'd like to register ngraph::pass::GetOutputElementElimination, but it breaks BatchNorm
     // backprop
     pass_manager.register_pass<ngraph::pass::Liveness>();
+    pass_manager.register_pass<ngraph::runtime::plaidml::pass::ExplicitLogicals>();
+    pass_manager.register_pass<ngraph::runtime::plaidml::pass::ConcatElision>();
+    pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateElision>();
+    pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateCombination>();
+    pass_manager.register_pass<ngraph::runtime::plaidml::pass::ImplicitBroadcast>();
+    pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReshapeElision>();
+    pass_manager.register_pass<ngraph::runtime::plaidml::pass::LowerConvolutions>();
+    if (m_config->winograd)
+    {
+        pass_manager.register_pass<ngraph::runtime::plaidml::pass::Winograd>();
+    }
+    if (!m_config->graphviz.empty())
+    {
+        pass_manager.register_pass<ngraph::pass::VisualizeTree>(m_config->graphviz);
+    }
+
+    // N.B. When we rewrite the graph, there are cases where we
+    // produce nodes that contain validation errors. A good example
+    // is in the ImplicitBroadcast pass -- after this pass, there may
+    // be elementwise operations whose inputs are not all the same
+    // shape.
+    //
+    // The caller may wish to perform operations (e.g. clone) on their
+    // supplied function that will cause validation to occur. So
+    // before we rewrite, we make our own copy of the function.
+    func = clone_function(*func);
+
+    // Apply passes.
     pass_manager.run_passes(func);

+    // Compile the resulting function.
     Build b;
     build(std::move(func), &b);
     return std::make_shared<CompiledFunction>(std::move(b));
@@ -98,7 +138,7 @@ void ngraph::runtime::plaidml::Compiler::build(std::shared_ptr<Function> func, B
     b->config = m_config;
     b->func = func;

-    const auto* op_map = OpImplMap();
+    const auto* op_map = GlobalOpImplMap();

     for (const auto& op_ptr : func->get_ordered_ops())
     {
@@ -114,6 +154,6 @@ void ngraph::runtime::plaidml::Compiler::build(std::shared_ptr<Function> func, B
                 std::string{"The PlaidML backend doesn't currently implement the '"} +
                 op->description() + "' operation"};
         }
-        it->second(b, *op);
+        it->second->Apply(b, op);
     }
 }
@@ -77,8 +77,10 @@ ngraph::runtime::plaidml::Config
     bool help = false;
     bool list = false;
     bool debug = false;
+    bool winograd = false;
     std::size_t device_idx = 0;
     std::string eventlog_config;
+    std::string graphviz;

 #ifdef NGRAPH_DEBUG_ENABLE
     debug = true;
@@ -155,7 +157,7 @@ ngraph::runtime::plaidml::Config
             return (oname_end - oname_begin == len) && !strncmp(oname_begin, opt, len);
         };

-        auto oval_len = oval_end - oval_begin;
+        std::size_t oval_len = oval_end - oval_begin;
         bool has_oval = oval_begin != oname_end;

         // N.B. oval_len != 0 => has_oval, but there's no other relationship.
@@ -229,6 +231,25 @@ ngraph::runtime::plaidml::Config
             continue;
         }

+        // Check for visualization (GraphViz output)
+        if (is_opt("graphviz"))
+        {
+            if (!oval_len)
+            {
+                throw std::invalid_argument{"PlaidML graphviz requires a value"};
+            }
+            graphviz = std::string{oval_begin, oval_len};
+            continue;
+        }
+
+        // Check for Winograd. (Winograd is sometimes a performance
+        // boost, but not always, so we make it optional.)
+        if (is_opt("winograd"))
+        {
+            winograd = true;
+            continue;
+        }
+
         // Reject unknown options
         err = true;
     }
@@ -236,7 +257,7 @@ ngraph::runtime::plaidml::Config
     constexpr char help_text[] =
         "PlaidML Backend Specification: \""
         "PlaidML[:[device_index][,debug][,help][,list_devices][,"
-        "eventlog=<filename>]]\". For example: \"PlaidML\", \""
+        "eventlog=<filename>][,graphviz=<filename>][,winograd]]\". For example: \"PlaidML\", \""
         "PlaidML:0,list_devices\"";
     if (err)
     {
result.debug = debug; result.debug = debug;
result.graphviz = graphviz;
result.winograd = winograd;
return result; return result;
} }
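As a hedged usage sketch (the device index and filename below are placeholders, not values taken from this commit), the extended specification string lets a caller opt into both new features when the backend is created:

    // e.g. enable Winograd convolutions and dump the optimized graph to GraphViz:
    auto backend = ngraph::runtime::Backend::create("PlaidML:0,winograd,graphviz=plaidml.dot");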
@@ -39,4 +39,6 @@ struct ngraph::runtime::plaidml::Config
     std::shared_ptr<vertexai::ctx> ctx;
     std::shared_ptr<vertexai::plaidml::device> dev;
     bool debug;
+    bool winograd;
+    std::string graphviz;
 };
@@ -171,7 +171,7 @@ ngraph::runtime::plaidml::ConvPoolFormatter::ConvPoolFormatter(
 }

 ngraph::runtime::plaidml::builder::Input
-    ngraph::runtime::plaidml::ConvPoolFormatter::F_in_header(vertexai::plaidml::variable var)
+    ngraph::runtime::plaidml::ConvPoolFormatter::F_in_header(vertexai::plaidml::variable var) const
 {
     if (m_op != OpType::Conv)
     {
@@ -191,7 +191,7 @@ ngraph::runtime::plaidml::builder::Input
 }

 ngraph::runtime::plaidml::builder::Input
-    ngraph::runtime::plaidml::ConvPoolFormatter::I_in_header(vertexai::plaidml::variable var)
+    ngraph::runtime::plaidml::ConvPoolFormatter::I_in_header(vertexai::plaidml::variable var) const
 {
     if (m_deriv == DerivType::Data && m_op == OpType::Conv)
     {
@@ -216,7 +216,7 @@ ngraph::runtime::plaidml::builder::Input
 }

 ngraph::runtime::plaidml::builder::Input
-    ngraph::runtime::plaidml::ConvPoolFormatter::O_in_header(vertexai::plaidml::variable var)
+    ngraph::runtime::plaidml::ConvPoolFormatter::O_in_header(vertexai::plaidml::variable var) const
 {
     if (m_deriv == DerivType::None)
     {
@@ -240,7 +240,7 @@ ngraph::runtime::plaidml::builder::Input
 }

 ngraph::runtime::plaidml::builder::Output
-    ngraph::runtime::plaidml::ConvPoolFormatter::F_out_header()
+    ngraph::runtime::plaidml::ConvPoolFormatter::F_out_header() const
 {
     if (m_op != OpType::Conv)
     {
@@ -254,7 +254,7 @@ ngraph::runtime::plaidml::builder::Output
 }

 ngraph::runtime::plaidml::builder::Output
-    ngraph::runtime::plaidml::ConvPoolFormatter::I_out_header()
+    ngraph::runtime::plaidml::ConvPoolFormatter::I_out_header() const
 {
     if (m_deriv != DerivType::Data)
     {
@@ -272,7 +272,7 @@ ngraph::runtime::plaidml::builder::Output
 }

 ngraph::runtime::plaidml::builder::Output
-    ngraph::runtime::plaidml::ConvPoolFormatter::O_out_header()
+    ngraph::runtime::plaidml::ConvPoolFormatter::O_out_header() const
 {
     if (m_deriv != DerivType::None)
     {
@@ -282,7 +282,7 @@ ngraph::runtime::plaidml::builder::Output
 }

 ngraph::runtime::plaidml::builder::ContractionOutput
-    ngraph::runtime::plaidml::ConvPoolFormatter::F_out_body()
+    ngraph::runtime::plaidml::ConvPoolFormatter::F_out_body() const
 {
     if (m_op != OpType::Conv)
     {
@@ -307,7 +307,7 @@ ngraph::runtime::plaidml::builder::ContractionOutput
 }

 ngraph::runtime::plaidml::builder::ContractionOutput
-    ngraph::runtime::plaidml::ConvPoolFormatter::I_out_body()
+    ngraph::runtime::plaidml::ConvPoolFormatter::I_out_body() const
 {
     if (m_deriv != DerivType::Data)
     {
@@ -353,7 +353,7 @@ ngraph::runtime::plaidml::builder::ContractionOutput
 }

 ngraph::runtime::plaidml::builder::ContractionOutput
-    ngraph::runtime::plaidml::ConvPoolFormatter::O_out_body()
+    ngraph::runtime::plaidml::ConvPoolFormatter::O_out_body() const
 {
     if (m_deriv != DerivType::None && m_op == OpType::Conv)
     {
@@ -405,7 +405,7 @@ ngraph::runtime::plaidml::builder::ContractionOutput
 }

 ngraph::runtime::plaidml::builder::ContractionInput
-    ngraph::runtime::plaidml::ConvPoolFormatter::F_in_body()
+    ngraph::runtime::plaidml::ConvPoolFormatter::F_in_body() const
 {
     if (m_op != OpType::Conv)
     {
@@ -425,7 +425,7 @@ ngraph::runtime::plaidml::builder::ContractionInput
 }

 ngraph::runtime::plaidml::builder::ContractionInput
-    ngraph::runtime::plaidml::ConvPoolFormatter::I_in_body()
+    ngraph::runtime::plaidml::ConvPoolFormatter::I_in_body() const
 {
     if (m_deriv == DerivType::Data && m_op == OpType::Conv)
     {
@@ -449,7 +449,7 @@ ngraph::runtime::plaidml::builder::ContractionInput
 }

 ngraph::runtime::plaidml::builder::ContractionInput
-    ngraph::runtime::plaidml::ConvPoolFormatter::O_in_body()
+    ngraph::runtime::plaidml::ConvPoolFormatter::O_in_body() const
 {
     if (m_deriv == DerivType::None)
     {
@@ -482,7 +482,7 @@ ngraph::runtime::plaidml::builder::ContractionInput
 }

 ngraph::runtime::plaidml::builder::UnaryContraction
-    ngraph::runtime::plaidml::ConvPoolFormatter::Broadcast_Ones()
+    ngraph::runtime::plaidml::ConvPoolFormatter::Broadcast_Ones() const
 {
     if (m_op != OpType::AvgPool)
     {
@@ -501,7 +501,7 @@ ngraph::runtime::plaidml::builder::UnaryContraction
 }

 ngraph::runtime::plaidml::builder::UnaryContraction
-    ngraph::runtime::plaidml::ConvPoolFormatter::Count()
+    ngraph::runtime::plaidml::ConvPoolFormatter::Count() const
 {
     if (m_op != OpType::AvgPool)
     {
@@ -535,7 +535,7 @@ ngraph::runtime::plaidml::builder::UnaryContraction
 }

 ngraph::runtime::plaidml::builder::UnaryContraction
-    ngraph::runtime::plaidml::ConvPoolFormatter::PoolContraction()
+    ngraph::runtime::plaidml::ConvPoolFormatter::PoolContraction() const
 {
     std::string agg_op;
     switch (m_op)
@@ -558,7 +558,7 @@ ngraph::runtime::plaidml::builder::UnaryContraction
 }

 ngraph::runtime::plaidml::builder::TernaryContraction
-    ngraph::runtime::plaidml::ConvPoolFormatter::PoolDerivContraction()
+    ngraph::runtime::plaidml::ConvPoolFormatter::PoolDerivContraction() const
 {
     builder::ContractionOutput output{"DI"};
     output.add_indices({n(), c()}).add_dims({N(), C()});
@@ -595,27 +595,27 @@ ngraph::runtime::plaidml::builder::TernaryContraction
         .set_third(incoming_deriv);
 }

-std::string ngraph::runtime::plaidml::ConvPoolFormatter::c()
+std::string ngraph::runtime::plaidml::ConvPoolFormatter::c() const
 {
     return "c";
 }

-std::string ngraph::runtime::plaidml::ConvPoolFormatter::ci()
+std::string ngraph::runtime::plaidml::ConvPoolFormatter::ci() const
 {
     return "ci";
 }

-std::string ngraph::runtime::plaidml::ConvPoolFormatter::co()
+std::string ngraph::runtime::plaidml::ConvPoolFormatter::co() const
 {
     return "co";
 }

-std::string ngraph::runtime::plaidml::ConvPoolFormatter::n()
+std::string ngraph::runtime::plaidml::ConvPoolFormatter::n() const
 {
     return "n";
 }

-std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xfs()
+std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xfs() const
 {
     if (m_xfs.empty())
     {
@@ -629,7 +629,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xfs()
     return m_xfs;
 }

-std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xis()
+std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xis() const
 {
     if (m_xis.empty())
     {
@@ -652,7 +652,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xis()
     return m_xis;
 }

-std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xos()
+std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xos() const
 {
     if (m_xos.empty())
     {
@@ -666,27 +666,27 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xos()
     return m_xos;
 }

-std::string ngraph::runtime::plaidml::ConvPoolFormatter::C()
+std::string ngraph::runtime::plaidml::ConvPoolFormatter::C() const
 {
     return "C";
 }

-std::string ngraph::runtime::plaidml::ConvPoolFormatter::CI()
+std::string ngraph::runtime::plaidml::ConvPoolFormatter::CI() const
 {
     return "CI";
 }

-std::string ngraph::runtime::plaidml::ConvPoolFormatter::CO()
+std::string ngraph::runtime::plaidml::ConvPoolFormatter::CO() const
 {
     return "CO";
 }

-std::string ngraph::runtime::plaidml::ConvPoolFormatter::N()
+std::string ngraph::runtime::plaidml::ConvPoolFormatter::N() const
 {
     return "N";
 }

-std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XFs()
+std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XFs() const
 {
     if (m_XFs.empty())
     {
@@ -707,7 +707,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XFs()
     return m_XFs;
 }

-std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XIs()
+std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XIs() const
 {
     if (m_XIs.empty())
     {
@@ -728,7 +728,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XIs()
     return m_XIs;
 }

-std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XOs()
+std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XOs() const
 {
     if (m_XOs.empty())
     {
@@ -765,7 +765,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XOs()
     return m_XOs;
 }

-std::string ngraph::runtime::plaidml::ConvPoolFormatter::F()
+std::string ngraph::runtime::plaidml::ConvPoolFormatter::F() const
 {
     if (m_deriv == DerivType::Filter)
     {
@@ -774,7 +774,7 @@ std::string ngraph::runtime::plaidml::ConvPoolFormatter::F()
     return "F";
 }

-std::string ngraph::runtime::plaidml::ConvPoolFormatter::I()
+std::string ngraph::runtime::plaidml::ConvPoolFormatter::I() const
 {
     if (m_deriv == DerivType::Data && m_op == OpType::Conv)
     {
@@ -783,7 +783,7 @@ std::string ngraph::runtime::plaidml::ConvPoolFormatter::I()
     return "I";
 }

-std::string ngraph::runtime::plaidml::ConvPoolFormatter::O()
+std::string ngraph::runtime::plaidml::ConvPoolFormatter::O() const
 {
     if (m_deriv != DerivType::None)
     {
@@ -72,47 +72,47 @@ public:
                       ConvPoolFormatter::DerivType deriv);

     // Formatted tensors
-    builder::Input F_in_header(vertexai::plaidml::variable var);
-    builder::Input I_in_header(vertexai::plaidml::variable var);
-    builder::Input O_in_header(vertexai::plaidml::variable var);
-    builder::Output F_out_header();
-    builder::Output I_out_header();
-    builder::Output O_out_header();
-    builder::ContractionOutput F_out_body();
-    builder::ContractionOutput I_out_body();
-    builder::ContractionOutput O_out_body();
-    builder::ContractionInput F_in_body();
-    builder::ContractionInput I_in_body();
-    builder::ContractionInput O_in_body();
+    builder::Input F_in_header(vertexai::plaidml::variable var) const;
+    builder::Input I_in_header(vertexai::plaidml::variable var) const;
+    builder::Input O_in_header(vertexai::plaidml::variable var) const;
+    builder::Output F_out_header() const;
+    builder::Output I_out_header() const;
+    builder::Output O_out_header() const;
+    builder::ContractionOutput F_out_body() const;
+    builder::ContractionOutput I_out_body() const;
+    builder::ContractionOutput O_out_body() const;
+    builder::ContractionInput F_in_body() const;
+    builder::ContractionInput I_in_body() const;
+    builder::ContractionInput O_in_body() const;

     // Special Operations
-    builder::UnaryContraction Broadcast_Ones();
-    builder::UnaryContraction Count();
-    builder::UnaryContraction PoolContraction();
-    builder::TernaryContraction PoolDerivContraction();
+    builder::UnaryContraction Broadcast_Ones() const;
+    builder::UnaryContraction Count() const;
+    builder::UnaryContraction PoolContraction() const;
+    builder::TernaryContraction PoolDerivContraction() const;

     // Index names / formulas
-    std::string c();
-    std::string ci();
-    std::string co();
-    std::string n();
-    std::vector<std::string> xfs();
-    std::vector<std::string> xis();
-    std::vector<std::string> xos();
+    std::string c() const;
+    std::string ci() const;
+    std::string co() const;
+    std::string n() const;
+    std::vector<std::string> xfs() const;
+    std::vector<std::string> xis() const;
+    std::vector<std::string> xos() const;

     // Dimension names / formulas
-    std::string C();
-    std::string CI();
-    std::string CO();
-    std::string N();
-    std::vector<std::string> XFs();
-    std::vector<std::string> XIs();
-    std::vector<std::string> XOs();
+    std::string C() const;
+    std::string CI() const;
+    std::string CO() const;
+    std::string N() const;
+    std::vector<std::string> XFs() const;
+    std::vector<std::string> XIs() const;
+    std::vector<std::string> XOs() const;

     // Tensor names
-    std::string F();
-    std::string I();
-    std::string O();
+    std::string F() const;
+    std::string I() const;
+    std::string O() const;

 private:
     std::size_t m_rank;
@@ -126,10 +126,10 @@ private:
     DerivType m_deriv = DerivType::None;
     ngraph::Shape m_filters_shape;
     ngraph::Shape m_data_batch_shape;
-    std::vector<std::string> m_xfs;
-    std::vector<std::string> m_xis;
-    std::vector<std::string> m_xos;
-    std::vector<std::string> m_XFs;
-    std::vector<std::string> m_XIs;
-    std::vector<std::string> m_XOs;
+    mutable std::vector<std::string> m_xfs;
+    mutable std::vector<std::string> m_xis;
+    mutable std::vector<std::string> m_xos;
+    mutable std::vector<std::string> m_XFs;
+    mutable std::vector<std::string> m_XIs;
+    mutable std::vector<std::string> m_XOs;
 };
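The header change pairs const accessors with mutable cache members: the accessors (see xfs(), XFs(), and friends in the .cpp diff above) fill their name vectors on first use, so the cache has to be writable from const member functions. A minimal standalone sketch of that idiom, using a hypothetical class rather than ConvPoolFormatter itself:

    #include <string>
    #include <vector>

    class NameCache
    {
    public:
        // Logically const: callers only read names; the lazy fill is an implementation detail.
        const std::vector<std::string>& names() const
        {
            if (m_names.empty())
            {
                for (int i = 0; i < 3; ++i)
                {
                    m_names.push_back("x" + std::to_string(i));   // computed on first use
                }
            }
            return m_names;
        }

    private:
        mutable std::vector<std::string> m_names;   // mutable: may change inside const members
    };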
@@ -16,20 +16,8 @@

 #include "ngraph/runtime/plaidml/plaidml_impl.hpp"

-namespace ngraph
+ngraph::runtime::plaidml::OpImplMap* ngraph::runtime::plaidml::GlobalOpImplMap()
 {
-    namespace runtime
-    {
-        namespace plaidml
-        {
-            std::unordered_map<std::type_index, std::function<void(Build*, const ngraph::Node&)>>*
-                OpImplMap()
-            {
-                static std::unordered_map<std::type_index,
-                                          std::function<void(Build*, const ngraph::Node&)>>
-                    op_impl_map;
-                return &op_impl_map;
-            }
-        }
-    }
+    static OpImplMap op_impl_map;
+    return &op_impl_map;
 }
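GlobalOpImplMap() keeps the existing design -- a process-wide registry keyed by the op's std::type_index and consulted by Compiler::build() -- but hides the map type behind the OpImplMap typedef and the NGRAPH_PLAIDML_OP_CLASS macro, both defined in plaidml_impl.hpp, which this diff does not show. A generic, hedged illustration of that registration pattern (not the actual header contents):

    #include <functional>
    #include <typeindex>
    #include <unordered_map>

    struct Node { virtual ~Node() = default; };

    using Handler = std::function<void(const Node&)>;
    using HandlerMap = std::unordered_map<std::type_index, Handler>;

    HandlerMap* GlobalHandlerMap()
    {
        static HandlerMap handlers;   // one registry, lazily constructed on first use
        return &handlers;
    }

    // Each implementation file creates one static Registration, inserting its handler at startup.
    struct Registration
    {
        Registration(std::type_index key, Handler handler)
        {
            GlobalHandlerMap()->emplace(key, std::move(handler));
        }
    };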
[One file's diff is collapsed in the web view and not shown.]
@@ -23,95 +23,86 @@ namespace ngraph
     {
         namespace plaidml
         {
-            // Concat views a tensor as a new type.
-            template <>
-            void Impl<op::Concat>::operator()()
-            {
-                [... body elided; apart from reformatting it is identical to ImplConcat::Apply() below ...]
-            }
-
-            namespace
-            {
-                Impl<op::Concat>::Registration register_concat;
-            }
+            NGRAPH_PLAIDML_OP_CLASS(ImplConcat, OpImpl<op::Concat>);
         }
     }
 }
+
+// Concat views a tensor as a new type.
+void ngraph::runtime::plaidml::ImplConcat::Apply()
+{
+    check_outputs(1);
+
+    auto f = start_tile_function();
+    f.add(builder::Output{"O"});
+    std::size_t dim_count = op().get_shape().size();
+    std::ostringstream offset;
+    std::ostringstream oexpr;
+    std::ostringstream concat_dsize;
+    bool saw_non_zero_tensor = false;
+    for (std::size_t iidx = 0; iidx < op().get_inputs().size(); ++iidx)
+    {
+        if (!shape_size(op().get_input_shape(iidx)))
+        {
+            continue;
+        }
+        if (saw_non_zero_tensor)
+        {
+            concat_dsize << "+";
+        }
+        saw_non_zero_tensor = true;
+        concat_dsize << "I" << iidx << "_D" << op().get_concatenation_axis();
+    }
+
+    saw_non_zero_tensor = false;
+    for (std::size_t iidx = 0; iidx < op().get_inputs().size(); ++iidx)
+    {
+        if (!shape_size(op().get_input_shape(iidx)))
+        {
+            continue;
+        }
+        std::string sidx{std::to_string(iidx)};
+        f.add(builder::Input{op_input(iidx), "I" + sidx}.add_dims("I" + sidx + "_D", 0, dim_count));
+        f.add(builder::UnaryContraction{"="}
+                  .set(builder::ContractionOutput{"E" + sidx}
+                           .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
+                               for (std::size_t idx = 0; idx < dim_count; ++idx)
+                               {
+                                   std::ostringstream s;
+                                   if (idx == op().get_concatenation_axis())
+                                   {
+                                       out = concat_dsize.str();
+                                   }
+                                   else
+                                   {
+                                       s << "I" << iidx << "_D" << idx;
+                                       out = s.str();
+                                   }
+                               }
+                           })
+                           .add_indices([&](std::back_insert_iterator<std::list<std::string>> out) {
+                               for (std::size_t idx = 0; idx < dim_count; ++idx)
+                               {
+                                   std::ostringstream s;
+                                   s << "d" << idx;
+                                   if (saw_non_zero_tensor && idx == op().get_concatenation_axis())
+                                   {
+                                       s << " + " << offset.str();
+                                   }
+                                   out = s.str();
+                               }
+                           }))
+                  .set(builder::ContractionInput{"I" + sidx}.add_indices("d", 0, dim_count)));
+        if (saw_non_zero_tensor)
+        {
+            oexpr << " + ";
+            offset << " + ";
+        }
+        oexpr << "E" << sidx;
+        offset << "I" << iidx << "_D" << op().get_concatenation_axis();
+        saw_non_zero_tensor = true;
+    }
+    f.add(builder::Elementwise{"O", oexpr.str()});
+
+    set_output(f.finalize());
+}
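To make the generated Tile concrete, a small hand-worked example (not part of the diff): concatenating inputs of shapes (2, 3) and (2, 5) along axis 1 gives concat_dsize = "I0_D1+I1_D1"; E0 is written with indices {d0, d1} and output dims {I0_D0, I0_D1+I1_D1}; E1 uses the shifted index d1 + I0_D1 at the concatenation axis; and the final elementwise statement is O = E0 + E1.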
@@ -24,25 +24,20 @@ namespace ngraph
     {
         namespace plaidml
         {
-            // Convert views a tensor as a new type.
-            template <>
-            void Impl<op::Convert>::operator()()
-            {
-                [... body elided; apart from reformatting it is identical to ImplConvert::Apply() below ...]
-            }
-
-            namespace
-            {
-                Impl<op::Convert>::Registration register_convert;
-            }
+            NGRAPH_PLAIDML_OP_CLASS(ImplConvert, OpImpl<op::Convert>);
         }
     }
 }
+
+// Convert views a tensor as a new type.
+void ngraph::runtime::plaidml::ImplConvert::Apply()
+{
+    check_inputs(1);
+    check_outputs(1);
+    set_output(start_tile_function()
+                   .add(builder::Input{op_input(), "I"})
+                   .add(builder::Output{"O"})
+                   .add(builder::Elementwise{
+                       "O", tile_converter("I", to_plaidml(op().get_convert_element_type()))})
+                   .finalize());
+}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/axis_vector.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace op
{
class Convolution;
class ConvolutionBackpropData;
class ConvolutionBackpropFilters;
}
}
}
}
class ngraph::runtime::plaidml::op::Convolution final : public ngraph::op::Op
{
public:
Convolution(std::shared_ptr<ngraph::op::Convolution> src,
const NodeVector& args,
AxisVector data_axes,
AxisVector filters_axes,
AxisVector output_axes);
void validate_and_infer_types() final;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const final;
const std::shared_ptr<ngraph::op::Convolution>& get_src() const { return m_src; }
const AxisVector& get_data_axes() const { return m_data_axes; }
const AxisVector& get_filters_axes() const { return m_filters_axes; }
const AxisVector& get_output_axes() const { return m_output_axes; }
private:
std::shared_ptr<ngraph::op::Convolution> m_src;
AxisVector m_data_axes;
AxisVector m_filters_axes;
AxisVector m_output_axes;
};
class ngraph::runtime::plaidml::op::ConvolutionBackpropData final : public ngraph::op::Op
{
public:
ConvolutionBackpropData(std::shared_ptr<ngraph::op::ConvolutionBackpropData> src,
const NodeVector& args,
AxisVector filters_axes,
AxisVector output_axes,
AxisVector data_axes);
void validate_and_infer_types() final;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const final;
const std::shared_ptr<ngraph::op::ConvolutionBackpropData>& get_src() const { return m_src; }
const AxisVector& get_filters_axes() const { return m_filters_axes; }
const AxisVector& get_output_axes() const { return m_output_axes; }
const AxisVector& get_data_axes() const { return m_data_axes; }
private:
std::shared_ptr<ngraph::op::ConvolutionBackpropData> m_src;
AxisVector m_filters_axes;
AxisVector m_output_axes;
AxisVector m_data_axes;
};
class ngraph::runtime::plaidml::op::ConvolutionBackpropFilters final : public ngraph::op::Op
{
public:
ConvolutionBackpropFilters(std::shared_ptr<ngraph::op::ConvolutionBackpropFilters> src,
const NodeVector& args,
AxisVector data_axes,
AxisVector output_axes,
AxisVector filters_axes);
void validate_and_infer_types() final;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const final;
const std::shared_ptr<ngraph::op::ConvolutionBackpropFilters>& get_src() const { return m_src; }
const AxisVector& get_data_axes() const { return m_data_axes; }
const AxisVector& get_output_axes() const { return m_output_axes; }
const AxisVector& get_filters_axes() const { return m_filters_axes; }
private:
std::shared_ptr<ngraph::op::ConvolutionBackpropFilters> m_src;
AxisVector m_data_axes;
AxisVector m_output_axes;
AxisVector m_filters_axes;
};
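These wrapper ops presumably carry, alongside the original nGraph convolution they were derived from, the axis orderings chosen when reshapes were merged into the convolution (the LowerConvolutions pass). A hedged construction sketch -- the variable names and axis values below are invented for illustration and are not taken from this commit:

    // Hypothetical values only; shows the constructor shape, not actual pass output.
    auto wrapped = std::make_shared<ngraph::runtime::plaidml::op::Convolution>(
        original_conv,                       // std::shared_ptr<ngraph::op::Convolution>
        NodeVector{data_arg, filters_arg},   // the (possibly reshaped) inputs
        AxisVector{0, 2, 3, 1},              // data_axes
        AxisVector{1, 0, 2, 3},              // filters_axes
        AxisVector{0, 3, 1, 2});             // output_axes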
@@ -26,54 +26,49 @@ namespace ngraph
     {
         namespace plaidml
        {
-            // Dot is a generalized dot product operation -- scalar-tensor,
-            // matrix-vector, and matrix multiplication.
-            template <>
-            void Impl<op::Dot>::operator()()
-            {
-                [... body elided; apart from reformatting it is identical to ImplDot::Apply() below ...]
-            }
-
-            namespace
-            {
-                Impl<op::Dot>::Registration register_dot;
-            }
+            NGRAPH_PLAIDML_OP_CLASS(ImplDot, OpImpl<op::Dot>);
         }
     }
 }
+
+// Dot is a generalized dot product operation -- scalar-tensor,
+// matrix-vector, and matrix multiplication.
+void ngraph::runtime::plaidml::ImplDot::Apply()
+{
+    check_inputs(2);
+    check_outputs(1);
+
+    auto l_dim_limit = op().get_inputs()[0].get_shape().size();
+    auto r_dim_limit = op().get_inputs()[1].get_shape().size();
+    auto reduce_limit = op().get_reduction_axes_count();
+    auto l_dim_mac = l_dim_limit - reduce_limit;
+    auto r_dim_mic = reduce_limit;
+
+    NGRAPH_DEBUG << "l_dim_limit=" << l_dim_limit;
+    NGRAPH_DEBUG << "r_dim_limit=" << r_dim_limit;
+    NGRAPH_DEBUG << "reduce_limit=" << reduce_limit;
+    NGRAPH_DEBUG << "l_dim_mac=" << l_dim_mac;
+    NGRAPH_DEBUG << "r_dim_mic=" << r_dim_mic;
+
+    set_output(start_tile_function()
+                   .add(builder::Input{op_input(0), "L"}
+                            .add_dims("DL", 1, l_dim_mac + 1)
+                            .add_dims("DC", 1, reduce_limit + 1))
+                   .add(builder::Input{op_input(1), "R"}
+                            .add_dims("DC", 1, reduce_limit + 1)
+                            .add_dims("DR", r_dim_mic + 1, r_dim_limit + 1))
+                   .add(builder::Output{"O"})
+                   .add(builder::BinaryContraction{"+", "*"}
+                            .set(builder::ContractionOutput{"O"}
+                                     .add_indices("dl", 1, l_dim_mac + 1)
+                                     .add_indices("dr", r_dim_mic + 1, r_dim_limit + 1)
+                                     .add_dims("DL", 1, l_dim_mac + 1)
+                                     .add_dims("DR", r_dim_mic + 1, r_dim_limit + 1))
+                            .set_lhs(builder::ContractionInput{"L"}
+                                         .add_indices("dl", 1, l_dim_mac + 1)
+                                         .add_indices("dc", 1, reduce_limit + 1))
+                            .set_rhs(builder::ContractionInput{"R"}
+                                         .add_indices("dc", 1, reduce_limit + 1)
+                                         .add_indices("dr", r_dim_mic + 1, r_dim_limit + 1)))
+                   .finalize());
+}
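A worked size example (not from the diff): for L of shape (2, 3, 4), R of shape (4, 5), and reduction_axes_count = 1, we get l_dim_limit = 3, r_dim_limit = 2, reduce_limit = 1, l_dim_mac = 2, and r_dim_mic = 1; the contraction therefore keeps indices dl1, dl2 from L and dr2 from R, sums over dc1, and produces O with dimensions DL1 x DL2 x DR2 = 2 x 3 x 5.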
@@ -25,33 +25,28 @@ namespace ngraph
     {
         namespace plaidml
         {
-            // FunctionCall invokes a sub-function.
-            template <>
-            void Impl<op::FunctionCall>::operator()()
-            {
-                [... body elided; apart from reformatting it is identical to ImplFunctionCall::Apply() below ...]
-            }
-
-            namespace
-            {
-                Impl<op::FunctionCall>::Registration register_function_call;
-            }
+            NGRAPH_PLAIDML_OP_CLASS(ImplFunctionCall, OpImpl<op::FunctionCall>);
         }
     }
 }
+
+// FunctionCall invokes a sub-function.
+void ngraph::runtime::plaidml::ImplFunctionCall::Apply()
+{
+    Build b;
+    build()->compiler->build(op().get_functions()[0], &b);
+    vertexai::plaidml::function f{b.composer};
+    vertexai::plaidml::function::parameters_t inputs;
+    for (std::size_t idx = 0; idx < op().get_input_size(); ++idx)
+    {
+        auto* oitv = op().get_inputs()[idx].get_output().get_tensor_ptr().get();
+        auto* iitv = b.func->get_parameters()[idx]->get_outputs()[0].get_tensor_ptr().get();
+        inputs.emplace_back(b.input_names.at(iitv), build()->bindings.at(oitv).var);
+    }
+    vertexai::plaidml::application app{f.apply(inputs)};
+    for (std::size_t idx = 0; idx < op().get_output_size(); ++idx)
+    {
+        auto* iotv = b.func->get_results()[idx]->get_output_tensor_ptr().get();
+        set_output(idx, app.get_output(b.output_names[iotv]));
+    }
+}
@@ -26,19 +26,14 @@
         namespace plaidml
         {
             template <typename O>
-            class IndexReductionImpl : public BaseImpl<O>
+            class IndexReductionBase : public OpImpl<O>
             {
-            public:
-                IndexReductionImpl(Build* build, const O& op)
-                    : BaseImpl<O>{build, op}
-                {
-                }
-
+            protected:
                 void build_index_reduction(const char* agg_op);
             };

             template <typename O>
-            void IndexReductionImpl<O>::build_index_reduction(const char* agg_op)
+            void IndexReductionBase<O>::build_index_reduction(const char* agg_op)
             {
                 this->check_inputs(1);
                 this->check_outputs(1);
@@ -117,37 +112,20 @@
                     .finalize());
             }

-            template <>
-            struct ParentImpl<op::ArgMax>
-            {
-                using Type = IndexReductionImpl<op::ArgMax>;
-            };
-
-            template <>
-            struct ParentImpl<op::ArgMin>
-            {
-                using Type = IndexReductionImpl<op::ArgMin>;
-            };
-
-            // ArgMax computes the maximum index along a tensor axis.
-            template <>
-            void Impl<op::ArgMax>::operator()()
-            {
-                build_index_reduction(">");
-            }
-
-            // ArgMin computes the minimum index along a tensor axis.
-            template <>
-            void Impl<op::ArgMin>::operator()()
-            {
-                build_index_reduction("<");
-            }
-
-            namespace
-            {
-                Impl<op::ArgMax>::Registration register_argmax;
-                Impl<op::ArgMin>::Registration register_argmin;
-            }
+            NGRAPH_PLAIDML_OP_CLASS(ImplArgMax, IndexReductionBase<op::ArgMax>);
+            NGRAPH_PLAIDML_OP_CLASS(ImplArgMin, IndexReductionBase<op::ArgMin>);
         }
     }
 }
+
+// ArgMax computes the maximum index along a tensor axis.
+void ngraph::runtime::plaidml::ImplArgMax::Apply()
+{
+    build_index_reduction(">");
+}
+
+// ArgMin computes the minimum index along a tensor axis.
+void ngraph::runtime::plaidml::ImplArgMin::Apply()
+{
+    build_index_reduction("<");
+}
@@ -26,38 +26,33 @@
     {
         namespace plaidml
         {
-            // Parameter binds a descriptor::Tensor to a PlaidML Placeholder.
-            template <>
-            void Impl<op::Parameter>::operator()()
-            {
-                [... body elided; apart from reformatting it is identical to ImplParameter::Apply() below ...]
-            }
-
-            // Result binds a PlaidML variable to a composed function output.
-            template <>
-            void Impl<op::Result>::operator()()
-            {
-                [... body elided; apart from reformatting it is identical to ImplResult::Apply() below ...]
-            }
-
-            namespace
-            {
-                Impl<op::Parameter>::Registration register_parameter;
-                Impl<op::Result>::Registration register_result;
-            }
+            NGRAPH_PLAIDML_OP_CLASS(ImplParameter, OpImpl<op::Parameter>);
+            NGRAPH_PLAIDML_OP_CLASS(ImplResult, OpImpl<op::Result>);
         }
     }
 }
+
+// Parameter binds a descriptor::Tensor to a PlaidML Placeholder.
+void ngraph::runtime::plaidml::ImplParameter::Apply()
+{
+    check_inputs(0);
+    check_outputs(1);
+    vp::placeholder ph{build()->io_dim_override ? build()->io_dim_override_count
+                                                : op().get_output_shape(0).size()};
+    std::string name = std::string{"I"} + std::to_string(build()->input_names.size());
+    descriptor::Tensor* tv = op().get_output_tensor_ptr().get();
+    build()->bindings.emplace(tv, TensorInfo{ph, TensorContents::DATA});
+    build()->composer.input(name, ph);
+    build()->input_names.emplace(tv, std::move(name));
+}
+
+// Result binds a PlaidML variable to a composed function output.
+void ngraph::runtime::plaidml::ImplResult::Apply()
+{
+    check_inputs(1);
+    check_outputs(1);
+    std::string name = std::string{"O"} + std::to_string(build()->output_names.size());
+    descriptor::Tensor* tv = op().get_output_tensor_ptr().get();
+    build()->composer.output(name, op_input());
+    build()->output_names.emplace(tv, std::move(name));
+}
@@ -23,46 +23,39 @@
     {
         namespace plaidml
        {
-            // LRN implements Local Response Normalization
-            template <>
-            void Impl<op::LRN>::operator()()
-            {
-                [... body elided; apart from reformatting it is identical to ImplLRN::Apply() below ...]
-            }
-
-            namespace
-            {
-                Impl<op::LRN>::Registration register_local_response_norm;
-            }
+            NGRAPH_PLAIDML_OP_CLASS(ImplLRN, OpImpl<op::LRN>);
         }
     }
 }
+
+// LRN implements Local Response Normalization
+void ngraph::runtime::plaidml::ImplLRN::Apply()
+{
+    check_inputs(1);
+    check_outputs(1);
+
+    auto dim_limit = op().get_inputs()[0].get_shape().size();
+    auto rank = dim_limit - 2;
+    auto distance = op().get_nsize() / 2;
+    std::ostringstream div_expr;
+    div_expr << "I / pow(" << op().get_bias() << ".0 + ((" << op().get_alpha() << ".0 / "
+             << op().get_nsize() << ".0) * S), " << op().get_beta() << ".0)";
+
+    set_output(
+        start_tile_function()
+            .add(builder::Input{op_input(), "I"}.add_dims({"N", "C"}).add_dims("D", 0, rank))
+            .add(builder::Output{"O"})
+            .add(builder::Elementwise{"ISQ", "I * I"})
+            .add(builder::UnaryContraction{"+"}
+                     .set(builder::ContractionOutput{"S"}
+                              .add_indices({"n", "c"})
+                              .add_indices("d", 0, rank)
+                              .add_dims({"N", "C"})
+                              .add_dims("D", 0, rank))
+                     .set(builder::ContractionInput{"ISQ"}
+                              .add_indices({"n", "c + z - " + std::to_string(distance)})
+                              .add_indices("d", 0, rank))
+                     .add_constraints([&](std::back_insert_iterator<std::list<std::string>> out) {
+                         out = "z < " + std::to_string(op().get_nsize());
+                     }))
+            .add(builder::Elementwise{"O", div_expr.str()})
+            .finalize());
+}
@@ -25,56 +25,47 @@
     {
         namespace plaidml
         {
-            // And performs a simple elementwise logical and.
-            template <>
-            void Impl<op::And>::operator()()
-            {
-                [... body elided; identical to ImplAnd::Apply() below, except that op_input() and set_output() passed TensorContents::LOGICAL ...]
-            }
-
-            // Not performs a simple elementwise logical not.
-            template <>
-            void Impl<op::Not>::operator()()
-            {
-                [... body elided; identical to ImplNot::Apply() below, except that op_input() and set_output() passed TensorContents::LOGICAL ...]
-            }
-
-            // Or performs a simple elementwise logical or.
-            template <>
-            void Impl<op::Or>::operator()()
-            {
-                [... body elided; identical to ImplOr::Apply() below, except that op_input() and set_output() passed TensorContents::LOGICAL ...]
-            }
-
-            namespace
-            {
-                Impl<op::And>::Registration register_and;
-                Impl<op::Not>::Registration register_not;
-                Impl<op::Or>::Registration register_or;
-            }
+            NGRAPH_PLAIDML_OP_CLASS(ImplAnd, OpImpl<op::And>);
+            NGRAPH_PLAIDML_OP_CLASS(ImplNot, OpImpl<op::Not>);
+            NGRAPH_PLAIDML_OP_CLASS(ImplOr, OpImpl<op::Or>);
         }
     }
 }
+
+// And performs a simple elementwise logical and.
+void ngraph::runtime::plaidml::ImplAnd::Apply()
+{
+    check_inputs(2);
+    check_outputs(1);
+    set_output(start_tile_function()
+                   .add(builder::Input{op_input(0), "A"})
+                   .add(builder::Input{op_input(1), "B"})
+                   .add(builder::Output{"C"})
+                   .add(builder::Elementwise{"C", "A ? B : A"})
+                   .finalize());
+}
+
+// Not performs a simple elementwise logical not.
+void ngraph::runtime::plaidml::ImplNot::Apply()
+{
+    check_inputs(1);
+    check_outputs(1);
+    set_output(start_tile_function()
+                   .add(builder::Input{op_input(0), "I"})
+                   .add(builder::Output{"O"})
+                   .add(builder::Elementwise{"O", "cmp_eq(I, 0)"})
+                   .finalize());
+}
+
+// Or performs a simple elementwise logical or.
+void ngraph::runtime::plaidml::ImplOr::Apply()
+{
+    check_inputs(2);
+    check_outputs(1);
+    set_output(start_tile_function()
+                   .add(builder::Input{op_input(0), "A"})
+                   .add(builder::Input{op_input(1), "B"})
+                   .add(builder::Output{"C"})
+                   .add(builder::Elementwise{"C", "A ? A : B"})
+                   .finalize());
+}
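Worked out by hand (not from the diff), the three elementwise expressions do implement the boolean operators: A ? B : A yields B when A is true and A (i.e. false) otherwise, which is AND; A ? A : B yields A when A is true and B otherwise, which is OR; and cmp_eq(I, 0) is 1 exactly when I is 0, which is NOT. Dropping the TensorContents::LOGICAL arguments presumably reflects this commit moving logical/data conversion into the new ExplicitLogicals pass.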
@@ -26,87 +26,79 @@
     {
         namespace plaidml
         {
-            // OneHot performs one-hot encoding along the requested axis.
-            template <>
-            void Impl<op::OneHot>::operator()()
-            {
-                [... body elided; apart from reformatting it is identical to ImplOneHot::Apply() below ...]
-            }
-
-            namespace
-            {
-                Impl<op::OneHot>::Registration register_one_hot;
-            }
+            NGRAPH_PLAIDML_OP_CLASS(ImplOneHot, OpImpl<op::OneHot>);
         }
     }
 }
+
+// OneHot performs one-hot encoding along the requested axis.
+void ngraph::runtime::plaidml::ImplOneHot::Apply()
+{
+    check_inputs(1);
+    check_outputs(1);
+
+    // Here's what's going on to implement OneHot:
+    //
+    // * We reshape the input tensor to add a size=1 dimension where we want the one-hot axis to be,
+    //
+    // * We create an index tensor that's size=1 on every dimension except the one-hot dimension,
+    //
+    // * We perform an elementwise conditional across them to assign the one-hot values.
+    //
+    // The broadcast rules will expand the index tensor on all non-one-hot dimensions to match the
+    // input, and will expand the input tensor on the one-hot dimension to match the index.
+    //
+    // In theory, it'd be pretty easy to implement all this with purely elementwise operations. The
+    // current definition of index() requires an input tensor of the index() output shape, and it's
+    // a little tricky to fix that, so we generate a zero tensor of the correct shape using a
+    // contraction. TODO: Optimize out the zero tensor contraction.
+
+    const auto& in_shape = op().get_inputs()[0].get_shape();
+    const auto& out_shape = op().get_shape();
+
+    std::ostringstream in_reshape;
+    for (std::size_t idx = 0; idx < out_shape.size(); ++idx)
+    {
+        if (idx)
+        {
+            in_reshape << ", ";
+        }
+        if (idx == op().get_one_hot_axis())
+        {
+            in_reshape << 1;
+        }
+        else
+        {
+            in_reshape << out_shape[idx];
+        }
+    }
+
+    set_output(
+        start_tile_function()
+            .add(builder::Input{op_input(), "I"}.add_dims("D", 0, in_shape.size()))
+            .add(builder::Input{static_cast<std::int64_t>(0), "Zero"})
+            .add(builder::Output{"O"})
+            .add(builder::UnaryContraction{"="}
+                     .set(builder::ContractionOutput{"ZS"}
+                              .add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
+                                  for (std::size_t idx = 0; idx < out_shape.size(); ++idx)
+                                  {
+                                      if (idx == op().get_one_hot_axis())
+                                      {
+                                          out = std::to_string(out_shape[idx]);
+                                      }
+                                      else
+                                      {
+                                          out = "1";
+                                      }
+                                  }
+                              })
+                              .add_indices("d", 0, out_shape.size()))
+                     .set(builder::ContractionInput{"Zero"}))
+            .add(builder::Elementwise{"Idx",
+                                      "index(ZS, " + std::to_string(op().get_one_hot_axis()) + ")"})
+            .add(builder::Elementwise{"IS", "reshape(I, " + in_reshape.str() + ")"})
+            .add(builder::Elementwise{"OV", "IS == Idx ? 1 : 0"})
+            .add(builder::Elementwise{"O", tile_converter("OV", op().get_element_type())})
+            .finalize());
+}
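A worked example (not in the diff): with input shape (3), output shape (3, 4), and one_hot_axis = 1, the reshape string becomes "3, 1", ZS is built as a 1x4 zero tensor, Idx = index(ZS, 1) runs 0..3 along axis 1, IS = reshape(I, 3, 1), and OV = (IS == Idx ? 1 : 0) broadcasts both operands to the full 3x4 result.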
[20 more file diffs are collapsed in the web view and not shown.]
@@ -130,12 +130,3 @@ std::string ngraph::runtime::plaidml::tile_converter(const std::string& tensor_n
     }
     return tile_converter(tensor_name, to_plaidml(element_type));
 }
-
-vp::variable ngraph::runtime::plaidml::plaidml_logical_to_data(vp::variable var, bool debug)
-{
-    return builder::Function{"logicalToData", debug}
-        .add(builder::Input{var, "I"})
-        .add(builder::Output{"O"})
-        .add(builder::Elementwise{"O", "as_int(I ? 1 : 0, 8)"})
-        .finalize();
-}
@@ -46,9 +46,6 @@ namespace ngraph
             std::string tile_converter(const std::string& tensor_name,
                                        const ngraph::element::Type& element_type);
-
-            vertexai::plaidml::variable plaidml_logical_to_data(vertexai::plaidml::variable var,
-                                                                bool debug);
         }
     }
 }