Commit f5b2d581 authored by Rob Earhart, committed by Scott Cyphers

API cleanup & performance passes (#2242)

* nGraph compat updates

* Simplify op structure

* Add zero-dim elimination pass

* Add logical data conversion pass

* Add graphviz support

* Add implicit broadcast pass

* Elide unnecessary reshape ops

* Merge reshapes into convolutions

* Add winograd convolution support

* Allow three-input maxpool backprop

* Add concat elision

* Combine replication operations

* Elide unneeded replicates

* Style update
parent c11644ec
......@@ -41,10 +41,21 @@ set(SRC
plaidml_ops_pool.cpp
plaidml_ops_reduce.cpp
plaidml_ops_replace_slice.cpp
plaidml_ops_replicate.cpp
plaidml_ops_reverse.cpp
plaidml_ops_slice.cpp
plaidml_ops_softmax.cpp
plaidml_ops_tile.cpp
plaidml_ops_transcendental.cpp
plaidml_ops_winograd.cpp
plaidml_pass_concat_elision.cpp
plaidml_pass_explicit_logicals.cpp
plaidml_pass_implicit_broadcast.cpp
plaidml_pass_lower_convolutions.cpp
plaidml_pass_replicate_combination.cpp
plaidml_pass_replicate_elision.cpp
plaidml_pass_reshape_elision.cpp
plaidml_pass_winograd.cpp
plaidml_tensor.cpp
plaidml_translate.cpp
)
......
......@@ -41,10 +41,11 @@ std::shared_ptr<ngraph::runtime::Tensor> ngraph::runtime::plaidml::PlaidML_Backe
&m_config, element_type, shape, "direct_data", memory_pointer);
}
bool ngraph::runtime::plaidml::PlaidML_Backend::compile(std::shared_ptr<Function> func)
std::shared_ptr<ngraph::Function>
ngraph::runtime::plaidml::PlaidML_Backend::compile(std::shared_ptr<Function> func)
{
m_cache.compile(func, &m_compiler);
return true;
return func;
}
bool ngraph::runtime::plaidml::PlaidML_Backend::call(
......
......@@ -46,7 +46,7 @@ public:
std::shared_ptr<ngraph::runtime::Tensor> create_tensor(
const ngraph::element::Type& element_type, const Shape& shape, void* memory_pointer) final;
bool compile(std::shared_ptr<Function> func) final;
std::shared_ptr<Function> compile(std::shared_ptr<Function> func) final;
bool call(std::shared_ptr<Function> func,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
......
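With this signature change, compile() hands back the function handle to use with call() instead of a bool. A minimal usage sketch, assuming the standard nGraph backend API of this release (the surrounding variable names are illustrative, not taken from the diff):

auto backend = ngraph::runtime::Backend::create("PlaidML");
std::shared_ptr<ngraph::Function> handle = backend->compile(func);  // previously: bool ok = backend->compile(func)
backend->call(handle, outputs, inputs);  // outputs/inputs: vectors of std::shared_ptr<runtime::Tensor>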
......@@ -18,6 +18,7 @@
#include <stdexcept>
#include <utility>
#include "ngraph/except.hpp"
#include "ngraph/runtime/plaidml/plaidml_builder.hpp"
#include "ngraph/runtime/plaidml/plaidml_logger.hpp"
......@@ -467,6 +468,24 @@ ngraph::runtime::plaidml::builder::Input&&
return std::move(*this);
}
void ngraph::runtime::plaidml::builder::Input::apply_transpose(const AxisVector& axes)
{
if (axes.size() != m_dims.size())
{
throw ngraph_error{"Mismatched shape in input transposition"};
}
std::size_t idx = 0;
std::vector<std::string> dims(m_dims.size());
for (auto dim : m_dims)
{
dims[axes[idx]] = dim;
idx++;
}
m_dims.clear();
m_dims.insert(m_dims.end(), dims.begin(), dims.end());
}
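// For example (illustrative values): with m_dims = {"N", "C", "H", "W"} and axes = {0, 2, 3, 1},
// each dimension name is scattered to position axes[i], so m_dims becomes {"N", "W", "C", "H"}.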
ngraph::runtime::plaidml::builder::Output::Output(std::string name)
: m_name{std::move(name)}
{
......@@ -611,6 +630,28 @@ ngraph::runtime::plaidml::builder::ContractionOutput&&
return std::move(*this);
}
void ngraph::runtime::plaidml::builder::ContractionOutput::apply_transpose(const AxisVector& axes)
{
if (axes.size() != m_dims.size() || axes.size() != m_indices.size())
{
throw ngraph_error{"Mismatched shape in contraction output transposition"};
}
std::vector<std::string> dims{m_dims.begin(), m_dims.end()};
m_dims.clear();
for (auto idx : axes)
{
m_dims.emplace_back(dims[idx]);
}
std::vector<std::string> indices{m_indices.begin(), m_indices.end()};
m_indices.clear();
for (auto idx : axes)
{
m_indices.emplace_back(indices[idx]);
}
}
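// For example (illustrative values): with m_dims = {"N", "W", "C", "H"} and axes = {0, 2, 3, 1},
// dimensions (and likewise indices) are gathered from position axes[i], restoring {"N", "C", "H", "W"}.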
ngraph::runtime::plaidml::builder::ContractionInput&
ngraph::runtime::plaidml::builder::ContractionInput::add_indices(std::string prefix,
std::size_t first,
......@@ -675,6 +716,24 @@ ngraph::runtime::plaidml::builder::ContractionInput&&
return std::move(*this);
}
void ngraph::runtime::plaidml::builder::ContractionInput::apply_transpose(const AxisVector& axes)
{
if (axes.size() != m_indices.size())
{
throw ngraph_error{"Mismatched shape in contraction input transposition"};
}
std::size_t idx = 0;
std::vector<std::string> indices(m_indices.size());
for (auto dim : m_indices)
{
indices[axes[idx]] = dim;
idx++;
}
m_indices.clear();
m_indices.insert(m_indices.end(), indices.begin(), indices.end());
}
ngraph::runtime::plaidml::builder::UnaryContraction::UnaryContraction(std::string agg_op)
: m_agg_op{std::move(agg_op)}
{
......
......@@ -24,6 +24,7 @@
#include <string>
#include <utility>
#include "ngraph/axis_vector.hpp"
#include "ngraph/runtime/plaidml/plaidml_config.hpp"
// Utilities for constructing PlaidML functions.
......@@ -136,9 +137,22 @@ public:
return std::move(*this);
}
Input& transpose(const AxisVector& axes) &
{
apply_transpose(axes);
return *this;
}
Input&& transpose(const AxisVector& axes) &&
{
apply_transpose(axes);
return std::move(*this);
}
private:
friend class Function;
void apply_transpose(const AxisVector& axes);
vertexai::plaidml::variable m_var;
std::string m_name;
std::list<std::string> m_dims;
......@@ -230,9 +244,22 @@ public:
return std::move(*this);
}
ContractionOutput& transpose(const AxisVector& axes) &
{
apply_transpose(axes);
return *this;
}
ContractionOutput&& transpose(const AxisVector& axes) &&
{
apply_transpose(axes);
return std::move(*this);
}
private:
friend class Function;
void apply_transpose(const AxisVector& axes);
std::string m_name;
std::list<std::string> m_indices;
std::list<std::string> m_dims;
......@@ -268,9 +295,22 @@ public:
return std::move(*this);
}
ContractionInput& transpose(const AxisVector& axes) &
{
apply_transpose(axes);
return *this;
}
ContractionInput&& transpose(const AxisVector& axes) &&
{
apply_transpose(axes);
return std::move(*this);
}
private:
friend class Function;
void apply_transpose(const AxisVector& axes);
std::string m_name;
std::list<std::string> m_indices;
};
......
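The transpose() helpers added above let a graph pass permute a builder element's dimensions or indices in place rather than materializing a Reshape node around the Tile function. A rough usage sketch (the variable and the NCHW-to-NHWC axis order are assumptions, not taken from this diff):

f.add(builder::Input{var, "I"}.add_dims("D", 0, 4).transpose(AxisVector{0, 2, 3, 1}));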
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/runtime/plaidml/plaidml_compiler.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/any_all_replacement.hpp"
......@@ -25,8 +26,18 @@
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/nop_elimination.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pass/zero_dim_tensor_elimination.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_logger.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_concat_elision.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_explicit_logicals.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_lower_convolutions.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_combination.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_elision.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_reshape_elision.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_winograd.hpp"
namespace
{
......@@ -78,15 +89,44 @@ std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction>
pass_manager.register_pass<ngraph::pass::AnyAllReplacement>();
pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<ngraph::pass::NopElimination>();
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.register_pass<ngraph::pass::CoreFusion>();
// N.B. We'd like to register ngraph::pass::GetOutputElementElimination, but it breaks BatchNorm
// backprop
pass_manager.register_pass<ngraph::pass::Liveness>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ExplicitLogicals>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ConcatElision>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateElision>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateCombination>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ImplicitBroadcast>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReshapeElision>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::LowerConvolutions>();
if (m_config->winograd)
{
pass_manager.register_pass<ngraph::runtime::plaidml::pass::Winograd>();
}
if (!m_config->graphviz.empty())
{
pass_manager.register_pass<ngraph::pass::VisualizeTree>(m_config->graphviz);
}
// N.B. When we rewrite the graph, there are cases where we
// produce nodes that contain validation errors. A good example
// is in the ImplicitBroadcast pass -- after this pass, there may
// be elementwise operations whose inputs are not all the same
// shape.
//
// The caller may wish to perform operations (e.g. clone) on their
// supplied function that will cause validation to occur. So
// before we rewrite, we make our own copy of the function.
func = clone_function(*func);
// Apply passes.
pass_manager.run_passes(func);
// Compile the resulting function.
Build b;
build(std::move(func), &b);
return std::make_shared<CompiledFunction>(std::move(b));
......@@ -98,7 +138,7 @@ void ngraph::runtime::plaidml::Compiler::build(std::shared_ptr<Function> func, B
b->config = m_config;
b->func = func;
const auto* op_map = OpImplMap();
const auto* op_map = GlobalOpImplMap();
for (const auto& op_ptr : func->get_ordered_ops())
{
......@@ -114,6 +154,6 @@ void ngraph::runtime::plaidml::Compiler::build(std::shared_ptr<Function> func, B
std::string{"The PlaidML backend doesn't currently implement the '"} +
op->description() + "' operation"};
}
it->second(b, *op);
it->second->Apply(b, op);
}
}
......@@ -77,8 +77,10 @@ ngraph::runtime::plaidml::Config
bool help = false;
bool list = false;
bool debug = false;
bool winograd = false;
std::size_t device_idx = 0;
std::string eventlog_config;
std::string graphviz;
#ifdef NGRAPH_DEBUG_ENABLE
debug = true;
......@@ -155,7 +157,7 @@ ngraph::runtime::plaidml::Config
return (oname_end - oname_begin == len) && !strncmp(oname_begin, opt, len);
};
auto oval_len = oval_end - oval_begin;
std::size_t oval_len = oval_end - oval_begin;
bool has_oval = oval_begin != oname_end;
// N.B. oval_len != 0 => has_oval, but there's no other relationship.
......@@ -229,6 +231,25 @@ ngraph::runtime::plaidml::Config
continue;
}
// Check for visualization (GraphViz output)
if (is_opt("graphviz"))
{
if (!oval_len)
{
throw std::invalid_argument{"PlaidML graphviz requires a value"};
}
graphviz = std::string{oval_begin, oval_len};
continue;
}
// Check for Winograd. (Winograd is sometimes a performance
// boost, but not always, so we make it optional.)
if (is_opt("winograd"))
{
winograd = true;
continue;
}
// Reject unknown options
err = true;
}
......@@ -236,7 +257,7 @@ ngraph::runtime::plaidml::Config
constexpr char help_text[] =
"PlaidML Backend Specification: \""
"PlaidML[:[device_index][,debug][,help][,list_devices][,"
"eventlog=<filename>]]\". For example: \"PlaidML\", \""
"eventlog=<filename>][,graphviz=<filename>][,winograd]]\". For example: \"PlaidML\", \""
"PlaidML:0,list_devices\"";
if (err)
{
......@@ -269,5 +290,9 @@ ngraph::runtime::plaidml::Config
result.debug = debug;
result.graphviz = graphviz;
result.winograd = winograd;
return result;
}
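With both options parsed into the Config, the GraphViz dump and Winograd convolutions can be requested directly in the backend specification string, for example (the dump filename here is just a placeholder):

auto backend = ngraph::runtime::Backend::create("PlaidML:0,winograd,graphviz=graph.dot");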
......@@ -39,4 +39,6 @@ struct ngraph::runtime::plaidml::Config
std::shared_ptr<vertexai::ctx> ctx;
std::shared_ptr<vertexai::plaidml::device> dev;
bool debug;
bool winograd;
std::string graphviz;
};
......@@ -171,7 +171,7 @@ ngraph::runtime::plaidml::ConvPoolFormatter::ConvPoolFormatter(
}
ngraph::runtime::plaidml::builder::Input
ngraph::runtime::plaidml::ConvPoolFormatter::F_in_header(vertexai::plaidml::variable var)
ngraph::runtime::plaidml::ConvPoolFormatter::F_in_header(vertexai::plaidml::variable var) const
{
if (m_op != OpType::Conv)
{
......@@ -191,7 +191,7 @@ ngraph::runtime::plaidml::builder::Input
}
ngraph::runtime::plaidml::builder::Input
ngraph::runtime::plaidml::ConvPoolFormatter::I_in_header(vertexai::plaidml::variable var)
ngraph::runtime::plaidml::ConvPoolFormatter::I_in_header(vertexai::plaidml::variable var) const
{
if (m_deriv == DerivType::Data && m_op == OpType::Conv)
{
......@@ -216,7 +216,7 @@ ngraph::runtime::plaidml::builder::Input
}
ngraph::runtime::plaidml::builder::Input
ngraph::runtime::plaidml::ConvPoolFormatter::O_in_header(vertexai::plaidml::variable var)
ngraph::runtime::plaidml::ConvPoolFormatter::O_in_header(vertexai::plaidml::variable var) const
{
if (m_deriv == DerivType::None)
{
......@@ -240,7 +240,7 @@ ngraph::runtime::plaidml::builder::Input
}
ngraph::runtime::plaidml::builder::Output
ngraph::runtime::plaidml::ConvPoolFormatter::F_out_header()
ngraph::runtime::plaidml::ConvPoolFormatter::F_out_header() const
{
if (m_op != OpType::Conv)
{
......@@ -254,7 +254,7 @@ ngraph::runtime::plaidml::builder::Output
}
ngraph::runtime::plaidml::builder::Output
ngraph::runtime::plaidml::ConvPoolFormatter::I_out_header()
ngraph::runtime::plaidml::ConvPoolFormatter::I_out_header() const
{
if (m_deriv != DerivType::Data)
{
......@@ -272,7 +272,7 @@ ngraph::runtime::plaidml::builder::Output
}
ngraph::runtime::plaidml::builder::Output
ngraph::runtime::plaidml::ConvPoolFormatter::O_out_header()
ngraph::runtime::plaidml::ConvPoolFormatter::O_out_header() const
{
if (m_deriv != DerivType::None)
{
......@@ -282,7 +282,7 @@ ngraph::runtime::plaidml::builder::Output
}
ngraph::runtime::plaidml::builder::ContractionOutput
ngraph::runtime::plaidml::ConvPoolFormatter::F_out_body()
ngraph::runtime::plaidml::ConvPoolFormatter::F_out_body() const
{
if (m_op != OpType::Conv)
{
......@@ -307,7 +307,7 @@ ngraph::runtime::plaidml::builder::ContractionOutput
}
ngraph::runtime::plaidml::builder::ContractionOutput
ngraph::runtime::plaidml::ConvPoolFormatter::I_out_body()
ngraph::runtime::plaidml::ConvPoolFormatter::I_out_body() const
{
if (m_deriv != DerivType::Data)
{
......@@ -353,7 +353,7 @@ ngraph::runtime::plaidml::builder::ContractionOutput
}
ngraph::runtime::plaidml::builder::ContractionOutput
ngraph::runtime::plaidml::ConvPoolFormatter::O_out_body()
ngraph::runtime::plaidml::ConvPoolFormatter::O_out_body() const
{
if (m_deriv != DerivType::None && m_op == OpType::Conv)
{
......@@ -405,7 +405,7 @@ ngraph::runtime::plaidml::builder::ContractionOutput
}
ngraph::runtime::plaidml::builder::ContractionInput
ngraph::runtime::plaidml::ConvPoolFormatter::F_in_body()
ngraph::runtime::plaidml::ConvPoolFormatter::F_in_body() const
{
if (m_op != OpType::Conv)
{
......@@ -425,7 +425,7 @@ ngraph::runtime::plaidml::builder::ContractionInput
}
ngraph::runtime::plaidml::builder::ContractionInput
ngraph::runtime::plaidml::ConvPoolFormatter::I_in_body()
ngraph::runtime::plaidml::ConvPoolFormatter::I_in_body() const
{
if (m_deriv == DerivType::Data && m_op == OpType::Conv)
{
......@@ -449,7 +449,7 @@ ngraph::runtime::plaidml::builder::ContractionInput
}
ngraph::runtime::plaidml::builder::ContractionInput
ngraph::runtime::plaidml::ConvPoolFormatter::O_in_body()
ngraph::runtime::plaidml::ConvPoolFormatter::O_in_body() const
{
if (m_deriv == DerivType::None)
{
......@@ -482,7 +482,7 @@ ngraph::runtime::plaidml::builder::ContractionInput
}
ngraph::runtime::plaidml::builder::UnaryContraction
ngraph::runtime::plaidml::ConvPoolFormatter::Broadcast_Ones()
ngraph::runtime::plaidml::ConvPoolFormatter::Broadcast_Ones() const
{
if (m_op != OpType::AvgPool)
{
......@@ -501,7 +501,7 @@ ngraph::runtime::plaidml::builder::UnaryContraction
}
ngraph::runtime::plaidml::builder::UnaryContraction
ngraph::runtime::plaidml::ConvPoolFormatter::Count()
ngraph::runtime::plaidml::ConvPoolFormatter::Count() const
{
if (m_op != OpType::AvgPool)
{
......@@ -535,7 +535,7 @@ ngraph::runtime::plaidml::builder::UnaryContraction
}
ngraph::runtime::plaidml::builder::UnaryContraction
ngraph::runtime::plaidml::ConvPoolFormatter::PoolContraction()
ngraph::runtime::plaidml::ConvPoolFormatter::PoolContraction() const
{
std::string agg_op;
switch (m_op)
......@@ -558,7 +558,7 @@ ngraph::runtime::plaidml::builder::UnaryContraction
}
ngraph::runtime::plaidml::builder::TernaryContraction
ngraph::runtime::plaidml::ConvPoolFormatter::PoolDerivContraction()
ngraph::runtime::plaidml::ConvPoolFormatter::PoolDerivContraction() const
{
builder::ContractionOutput output{"DI"};
output.add_indices({n(), c()}).add_dims({N(), C()});
......@@ -595,27 +595,27 @@ ngraph::runtime::plaidml::builder::TernaryContraction
.set_third(incoming_deriv);
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::c()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::c() const
{
return "c";
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::ci()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::ci() const
{
return "ci";
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::co()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::co() const
{
return "co";
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::n()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::n() const
{
return "n";
}
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xfs()
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xfs() const
{
if (m_xfs.empty())
{
......@@ -629,7 +629,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xfs()
return m_xfs;
}
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xis()
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xis() const
{
if (m_xis.empty())
{
......@@ -652,7 +652,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xis()
return m_xis;
}
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xos()
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xos() const
{
if (m_xos.empty())
{
......@@ -666,27 +666,27 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xos()
return m_xos;
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::C()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::C() const
{
return "C";
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::CI()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::CI() const
{
return "CI";
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::CO()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::CO() const
{
return "CO";
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::N()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::N() const
{
return "N";
}
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XFs()
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XFs() const
{
if (m_XFs.empty())
{
......@@ -707,7 +707,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XFs()
return m_XFs;
}
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XIs()
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XIs() const
{
if (m_XIs.empty())
{
......@@ -728,7 +728,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XIs()
return m_XIs;
}
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XOs()
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XOs() const
{
if (m_XOs.empty())
{
......@@ -765,7 +765,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XOs()
return m_XOs;
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::F()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::F() const
{
if (m_deriv == DerivType::Filter)
{
......@@ -774,7 +774,7 @@ std::string ngraph::runtime::plaidml::ConvPoolFormatter::F()
return "F";
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::I()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::I() const
{
if (m_deriv == DerivType::Data && m_op == OpType::Conv)
{
......@@ -783,7 +783,7 @@ std::string ngraph::runtime::plaidml::ConvPoolFormatter::I()
return "I";
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::O()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::O() const
{
if (m_deriv != DerivType::None)
{
......
......@@ -72,47 +72,47 @@ public:
ConvPoolFormatter::DerivType deriv);
// Formatted tensors
builder::Input F_in_header(vertexai::plaidml::variable var);
builder::Input I_in_header(vertexai::plaidml::variable var);
builder::Input O_in_header(vertexai::plaidml::variable var);
builder::Output F_out_header();
builder::Output I_out_header();
builder::Output O_out_header();
builder::ContractionOutput F_out_body();
builder::ContractionOutput I_out_body();
builder::ContractionOutput O_out_body();
builder::ContractionInput F_in_body();
builder::ContractionInput I_in_body();
builder::ContractionInput O_in_body();
builder::Input F_in_header(vertexai::plaidml::variable var) const;
builder::Input I_in_header(vertexai::plaidml::variable var) const;
builder::Input O_in_header(vertexai::plaidml::variable var) const;
builder::Output F_out_header() const;
builder::Output I_out_header() const;
builder::Output O_out_header() const;
builder::ContractionOutput F_out_body() const;
builder::ContractionOutput I_out_body() const;
builder::ContractionOutput O_out_body() const;
builder::ContractionInput F_in_body() const;
builder::ContractionInput I_in_body() const;
builder::ContractionInput O_in_body() const;
// Special Operations
builder::UnaryContraction Broadcast_Ones();
builder::UnaryContraction Count();
builder::UnaryContraction PoolContraction();
builder::TernaryContraction PoolDerivContraction();
builder::UnaryContraction Broadcast_Ones() const;
builder::UnaryContraction Count() const;
builder::UnaryContraction PoolContraction() const;
builder::TernaryContraction PoolDerivContraction() const;
// Index names / formulas
std::string c();
std::string ci();
std::string co();
std::string n();
std::vector<std::string> xfs();
std::vector<std::string> xis();
std::vector<std::string> xos();
std::string c() const;
std::string ci() const;
std::string co() const;
std::string n() const;
std::vector<std::string> xfs() const;
std::vector<std::string> xis() const;
std::vector<std::string> xos() const;
// Dimension names / formulas
std::string C();
std::string CI();
std::string CO();
std::string N();
std::vector<std::string> XFs();
std::vector<std::string> XIs();
std::vector<std::string> XOs();
std::string C() const;
std::string CI() const;
std::string CO() const;
std::string N() const;
std::vector<std::string> XFs() const;
std::vector<std::string> XIs() const;
std::vector<std::string> XOs() const;
// Tensor names
std::string F();
std::string I();
std::string O();
std::string F() const;
std::string I() const;
std::string O() const;
private:
std::size_t m_rank;
......@@ -126,10 +126,10 @@ private:
DerivType m_deriv = DerivType::None;
ngraph::Shape m_filters_shape;
ngraph::Shape m_data_batch_shape;
std::vector<std::string> m_xfs;
std::vector<std::string> m_xis;
std::vector<std::string> m_xos;
std::vector<std::string> m_XFs;
std::vector<std::string> m_XIs;
std::vector<std::string> m_XOs;
mutable std::vector<std::string> m_xfs;
mutable std::vector<std::string> m_xis;
mutable std::vector<std::string> m_xos;
mutable std::vector<std::string> m_XFs;
mutable std::vector<std::string> m_XIs;
mutable std::vector<std::string> m_XOs;
};
......@@ -16,20 +16,8 @@
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
namespace ngraph
ngraph::runtime::plaidml::OpImplMap* ngraph::runtime::plaidml::GlobalOpImplMap()
{
namespace runtime
{
namespace plaidml
{
std::unordered_map<std::type_index, std::function<void(Build*, const ngraph::Node&)>>*
OpImplMap()
{
static std::unordered_map<std::type_index,
std::function<void(Build*, const ngraph::Node&)>>
op_impl_map;
return &op_impl_map;
}
}
}
static OpImplMap op_impl_map;
return &op_impl_map;
}
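The op-implementation registry is now a map from std::type_index to polymorphic implementation objects, invoked via Apply() (see it->second->Apply(b, op) in plaidml_compiler.cpp above) instead of a std::function callback. A self-contained sketch of that registry pattern, using generic names rather than the backend's actual classes:

#include <memory>
#include <typeindex>
#include <unordered_map>

struct Node { virtual ~Node() = default; };
struct OpImplBase
{
    virtual ~OpImplBase() = default;
    virtual void Apply(const Node* node) = 0;  // one override per supported op
};
using OpImplMap = std::unordered_map<std::type_index, std::unique_ptr<OpImplBase>>;

OpImplMap* GlobalOpImplMap()  // function-local static, matching the shape of the new accessor above
{
    static OpImplMap op_impl_map;
    return &op_impl_map;
}

// Dispatch then mirrors Compiler::build():
//   auto it = GlobalOpImplMap()->find(std::type_index(typeid(*node)));
//   if (it != GlobalOpImplMap()->end()) { it->second->Apply(node); }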
......@@ -23,95 +23,86 @@ namespace ngraph
{
namespace plaidml
{
// Concat joins its input tensors along the concatenation axis.
template <>
void Impl<op::Concat>::operator()()
{
check_outputs(1);
auto f = start_tile_function();
f.add(builder::Output{"O"});
std::size_t dim_count = op().get_shape().size();
std::ostringstream offset;
std::ostringstream oexpr;
std::ostringstream concat_dsize;
bool saw_non_zero_tensor = false;
for (std::size_t iidx = 0; iidx < op().get_inputs().size(); ++iidx)
{
if (!shape_size(op().get_input_shape(iidx)))
{
continue;
}
if (saw_non_zero_tensor)
{
concat_dsize << "+";
}
saw_non_zero_tensor = true;
concat_dsize << "I" << iidx << "_D" << op().get_concatenation_axis();
}
NGRAPH_PLAIDML_OP_CLASS(ImplConcat, OpImpl<op::Concat>);
}
}
}
saw_non_zero_tensor = false;
for (std::size_t iidx = 0; iidx < op().get_inputs().size(); ++iidx)
{
if (!shape_size(op().get_input_shape(iidx)))
{
continue;
}
std::string sidx{std::to_string(iidx)};
f.add(builder::Input{op_input(iidx), "I" + sidx}.add_dims(
"I" + sidx + "_D", 0, dim_count));
f.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"E" + sidx}
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_count; ++idx)
{
std::ostringstream s;
if (idx == op().get_concatenation_axis())
{
out = concat_dsize.str();
}
else
{
s << "I" << iidx << "_D" << idx;
out = s.str();
}
}
})
.add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_count; ++idx)
{
std::ostringstream s;
s << "d" << idx;
if (saw_non_zero_tensor &&
idx == op().get_concatenation_axis())
{
s << " + " << offset.str();
}
out = s.str();
}
}))
.set(builder::ContractionInput{"I" + sidx}.add_indices(
"d", 0, dim_count)));
if (saw_non_zero_tensor)
{
oexpr << " + ";
offset << " + ";
}
oexpr << "E" << sidx;
offset << "I" << iidx << "_D" << op().get_concatenation_axis();
saw_non_zero_tensor = true;
}
f.add(builder::Elementwise{"O", oexpr.str()});
// Concat joins its input tensors along the concatenation axis.
void ngraph::runtime::plaidml::ImplConcat::Apply()
{
check_outputs(1);
set_output(f.finalize());
}
auto f = start_tile_function();
f.add(builder::Output{"O"});
std::size_t dim_count = op().get_shape().size();
std::ostringstream offset;
std::ostringstream oexpr;
std::ostringstream concat_dsize;
bool saw_non_zero_tensor = false;
for (std::size_t iidx = 0; iidx < op().get_inputs().size(); ++iidx)
{
if (!shape_size(op().get_input_shape(iidx)))
{
continue;
}
if (saw_non_zero_tensor)
{
concat_dsize << "+";
}
saw_non_zero_tensor = true;
concat_dsize << "I" << iidx << "_D" << op().get_concatenation_axis();
}
namespace
{
Impl<op::Concat>::Registration register_concat;
}
saw_non_zero_tensor = false;
for (std::size_t iidx = 0; iidx < op().get_inputs().size(); ++iidx)
{
if (!shape_size(op().get_input_shape(iidx)))
{
continue;
}
std::string sidx{std::to_string(iidx)};
f.add(builder::Input{op_input(iidx), "I" + sidx}.add_dims("I" + sidx + "_D", 0, dim_count));
f.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"E" + sidx}
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_count; ++idx)
{
std::ostringstream s;
if (idx == op().get_concatenation_axis())
{
out = concat_dsize.str();
}
else
{
s << "I" << iidx << "_D" << idx;
out = s.str();
}
}
})
.add_indices([&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_count; ++idx)
{
std::ostringstream s;
s << "d" << idx;
if (saw_non_zero_tensor && idx == op().get_concatenation_axis())
{
s << " + " << offset.str();
}
out = s.str();
}
}))
.set(builder::ContractionInput{"I" + sidx}.add_indices("d", 0, dim_count)));
if (saw_non_zero_tensor)
{
oexpr << " + ";
offset << " + ";
}
oexpr << "E" << sidx;
offset << "I" << iidx << "_D" << op().get_concatenation_axis();
saw_non_zero_tensor = true;
}
f.add(builder::Elementwise{"O", oexpr.str()});
set_output(f.finalize());
}
......@@ -24,25 +24,20 @@ namespace ngraph
{
namespace plaidml
{
// Convert views a tensor as a new type.
template <>
void Impl<op::Convert>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(
start_tile_function()
.add(builder::Input{op_input(), "I"})
.add(builder::Output{"O"})
.add(builder::Elementwise{
"O", tile_converter("I", to_plaidml(op().get_convert_element_type()))})
.finalize());
}
namespace
{
Impl<op::Convert>::Registration register_convert;
}
NGRAPH_PLAIDML_OP_CLASS(ImplConvert, OpImpl<op::Convert>);
}
}
}
// Convert views a tensor as a new type.
void ngraph::runtime::plaidml::ImplConvert::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
.add(builder::Input{op_input(), "I"})
.add(builder::Output{"O"})
.add(builder::Elementwise{
"O", tile_converter("I", to_plaidml(op().get_convert_element_type()))})
.finalize());
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/axis_vector.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace op
{
class Convolution;
class ConvolutionBackpropData;
class ConvolutionBackpropFilters;
}
}
}
}
class ngraph::runtime::plaidml::op::Convolution final : public ngraph::op::Op
{
public:
Convolution(std::shared_ptr<ngraph::op::Convolution> src,
const NodeVector& args,
AxisVector data_axes,
AxisVector filters_axes,
AxisVector output_axes);
void validate_and_infer_types() final;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const final;
const std::shared_ptr<ngraph::op::Convolution>& get_src() const { return m_src; }
const AxisVector& get_data_axes() const { return m_data_axes; }
const AxisVector& get_filters_axes() const { return m_filters_axes; }
const AxisVector& get_output_axes() const { return m_output_axes; }
private:
std::shared_ptr<ngraph::op::Convolution> m_src;
AxisVector m_data_axes;
AxisVector m_filters_axes;
AxisVector m_output_axes;
};
class ngraph::runtime::plaidml::op::ConvolutionBackpropData final : public ngraph::op::Op
{
public:
ConvolutionBackpropData(std::shared_ptr<ngraph::op::ConvolutionBackpropData> src,
const NodeVector& args,
AxisVector filters_axes,
AxisVector output_axes,
AxisVector data_axes);
void validate_and_infer_types() final;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const final;
const std::shared_ptr<ngraph::op::ConvolutionBackpropData>& get_src() const { return m_src; }
const AxisVector& get_filters_axes() const { return m_filters_axes; }
const AxisVector& get_output_axes() const { return m_output_axes; }
const AxisVector& get_data_axes() const { return m_data_axes; }
private:
std::shared_ptr<ngraph::op::ConvolutionBackpropData> m_src;
AxisVector m_filters_axes;
AxisVector m_output_axes;
AxisVector m_data_axes;
};
class ngraph::runtime::plaidml::op::ConvolutionBackpropFilters final : public ngraph::op::Op
{
public:
ConvolutionBackpropFilters(std::shared_ptr<ngraph::op::ConvolutionBackpropFilters> src,
const NodeVector& args,
AxisVector data_axes,
AxisVector output_axes,
AxisVector filters_axes);
void validate_and_infer_types() final;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const final;
const std::shared_ptr<ngraph::op::ConvolutionBackpropFilters>& get_src() const { return m_src; }
const AxisVector& get_data_axes() const { return m_data_axes; }
const AxisVector& get_output_axes() const { return m_output_axes; }
const AxisVector& get_filters_axes() const { return m_filters_axes; }
private:
std::shared_ptr<ngraph::op::ConvolutionBackpropFilters> m_src;
AxisVector m_data_axes;
AxisVector m_output_axes;
AxisVector m_filters_axes;
};
......@@ -26,54 +26,49 @@ namespace ngraph
{
namespace plaidml
{
// Dot is a generalized dot product operation -- scalar-tensor,
// matrix-vector, and matrix multiplication.
template <>
void Impl<op::Dot>::operator()()
{
check_inputs(2);
check_outputs(1);
NGRAPH_PLAIDML_OP_CLASS(ImplDot, OpImpl<op::Dot>);
}
}
}
auto l_dim_limit = op().get_inputs()[0].get_shape().size();
auto r_dim_limit = op().get_inputs()[1].get_shape().size();
auto reduce_limit = op().get_reduction_axes_count();
auto l_dim_mac = l_dim_limit - reduce_limit;
auto r_dim_mic = reduce_limit;
// Dot is a generalized dot product operation -- scalar-tensor,
// matrix-vector, and matrix multiplication.
void ngraph::runtime::plaidml::ImplDot::Apply()
{
check_inputs(2);
check_outputs(1);
NGRAPH_DEBUG << "l_dim_limit=" << l_dim_limit;
NGRAPH_DEBUG << "r_dim_limit=" << r_dim_limit;
NGRAPH_DEBUG << "reduce_limit=" << reduce_limit;
NGRAPH_DEBUG << "l_dim_mac=" << l_dim_mac;
NGRAPH_DEBUG << "r_dim_mic=" << r_dim_mic;
auto l_dim_limit = op().get_inputs()[0].get_shape().size();
auto r_dim_limit = op().get_inputs()[1].get_shape().size();
auto reduce_limit = op().get_reduction_axes_count();
auto l_dim_mac = l_dim_limit - reduce_limit;
auto r_dim_mic = reduce_limit;
set_output(
start_tile_function()
.add(builder::Input{op_input(0), "L"}
.add_dims("DL", 1, l_dim_mac + 1)
.add_dims("DC", 1, reduce_limit + 1))
.add(builder::Input{op_input(1), "R"}
.add_dims("DC", 1, reduce_limit + 1)
.add_dims("DR", r_dim_mic + 1, r_dim_limit + 1))
.add(builder::Output{"O"})
.add(builder::BinaryContraction{"+", "*"}
.set(builder::ContractionOutput{"O"}
.add_indices("dl", 1, l_dim_mac + 1)
.add_indices("dr", r_dim_mic + 1, r_dim_limit + 1)
.add_dims("DL", 1, l_dim_mac + 1)
.add_dims("DR", r_dim_mic + 1, r_dim_limit + 1))
.set_lhs(builder::ContractionInput{"L"}
.add_indices("dl", 1, l_dim_mac + 1)
.add_indices("dc", 1, reduce_limit + 1))
.set_rhs(builder::ContractionInput{"R"}
.add_indices("dc", 1, reduce_limit + 1)
.add_indices("dr", r_dim_mic + 1, r_dim_limit + 1)))
.finalize());
}
NGRAPH_DEBUG << "l_dim_limit=" << l_dim_limit;
NGRAPH_DEBUG << "r_dim_limit=" << r_dim_limit;
NGRAPH_DEBUG << "reduce_limit=" << reduce_limit;
NGRAPH_DEBUG << "l_dim_mac=" << l_dim_mac;
NGRAPH_DEBUG << "r_dim_mic=" << r_dim_mic;
namespace
{
Impl<op::Dot>::Registration register_dot;
}
}
}
set_output(start_tile_function()
.add(builder::Input{op_input(0), "L"}
.add_dims("DL", 1, l_dim_mac + 1)
.add_dims("DC", 1, reduce_limit + 1))
.add(builder::Input{op_input(1), "R"}
.add_dims("DC", 1, reduce_limit + 1)
.add_dims("DR", r_dim_mic + 1, r_dim_limit + 1))
.add(builder::Output{"O"})
.add(builder::BinaryContraction{"+", "*"}
.set(builder::ContractionOutput{"O"}
.add_indices("dl", 1, l_dim_mac + 1)
.add_indices("dr", r_dim_mic + 1, r_dim_limit + 1)
.add_dims("DL", 1, l_dim_mac + 1)
.add_dims("DR", r_dim_mic + 1, r_dim_limit + 1))
.set_lhs(builder::ContractionInput{"L"}
.add_indices("dl", 1, l_dim_mac + 1)
.add_indices("dc", 1, reduce_limit + 1))
.set_rhs(builder::ContractionInput{"R"}
.add_indices("dc", 1, reduce_limit + 1)
.add_indices("dr", r_dim_mic + 1, r_dim_limit + 1)))
.finalize());
}
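As a concrete check of the index bookkeeping: for a plain matrix multiplication of an [M, K] tensor with a [K, N] tensor and get_reduction_axes_count() == 1, we have l_dim_limit = 2, r_dim_limit = 2, reduce_limit = 1, l_dim_mac = 1, and r_dim_mic = 1. The contraction therefore sums over the single shared index dc1 (of size DC1 = K) and writes an output indexed by dl1 (size DL1 = M) and dr2 (size DR2 = N), i.e. O[m, n] = sum over k of L[m, k] * R[k, n].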
......@@ -25,33 +25,28 @@ namespace ngraph
{
namespace plaidml
{
// FunctionCall invokes a sub-function.
template <>
void Impl<op::FunctionCall>::operator()()
{
Build b;
build()->compiler->build(op().get_functions()[0], &b);
vertexai::plaidml::function f{b.composer};
vertexai::plaidml::function::parameters_t inputs;
for (std::size_t idx = 0; idx < op().get_input_size(); ++idx)
{
auto* oitv = op().get_inputs()[idx].get_output().get_tensor_ptr().get();
auto* iitv =
b.func->get_parameters()[idx]->get_outputs()[0].get_tensor_ptr().get();
inputs.emplace_back(b.input_names.at(iitv), build()->bindings.at(oitv).var);
}
vertexai::plaidml::application app{f.apply(inputs)};
for (std::size_t idx = 0; idx < op().get_output_size(); ++idx)
{
auto* iotv = b.func->get_results()[idx]->get_output_tensor_ptr().get();
set_output(idx, app.get_output(b.output_names[iotv]));
}
}
namespace
{
Impl<op::FunctionCall>::Registration register_function_call;
}
NGRAPH_PLAIDML_OP_CLASS(ImplFunctionCall, OpImpl<op::FunctionCall>);
}
}
}
// FunctionCall invokes a sub-function.
void ngraph::runtime::plaidml::ImplFunctionCall::Apply()
{
Build b;
build()->compiler->build(op().get_functions()[0], &b);
vertexai::plaidml::function f{b.composer};
vertexai::plaidml::function::parameters_t inputs;
for (std::size_t idx = 0; idx < op().get_input_size(); ++idx)
{
auto* oitv = op().get_inputs()[idx].get_output().get_tensor_ptr().get();
auto* iitv = b.func->get_parameters()[idx]->get_outputs()[0].get_tensor_ptr().get();
inputs.emplace_back(b.input_names.at(iitv), build()->bindings.at(oitv).var);
}
vertexai::plaidml::application app{f.apply(inputs)};
for (std::size_t idx = 0; idx < op().get_output_size(); ++idx)
{
auto* iotv = b.func->get_results()[idx]->get_output_tensor_ptr().get();
set_output(idx, app.get_output(b.output_names[iotv]));
}
}
......@@ -26,19 +26,14 @@ namespace ngraph
namespace plaidml
{
template <typename O>
class IndexReductionImpl : public BaseImpl<O>
class IndexReductionBase : public OpImpl<O>
{
public:
IndexReductionImpl(Build* build, const O& op)
: BaseImpl<O>{build, op}
{
}
protected:
void build_index_reduction(const char* agg_op);
};
template <typename O>
void IndexReductionImpl<O>::build_index_reduction(const char* agg_op)
void IndexReductionBase<O>::build_index_reduction(const char* agg_op)
{
this->check_inputs(1);
this->check_outputs(1);
......@@ -117,37 +112,20 @@ namespace ngraph
.finalize());
}
template <>
struct ParentImpl<op::ArgMax>
{
using Type = IndexReductionImpl<op::ArgMax>;
};
template <>
struct ParentImpl<op::ArgMin>
{
using Type = IndexReductionImpl<op::ArgMin>;
};
// ArgMax computes the maximum index along a tensor axis.
template <>
void Impl<op::ArgMax>::operator()()
{
build_index_reduction(">");
}
// ArgMin computes the minimum index along a tensor axis.
template <>
void Impl<op::ArgMin>::operator()()
{
build_index_reduction("<");
}
namespace
{
Impl<op::ArgMax>::Registration register_argmax;
Impl<op::ArgMin>::Registration register_argmin;
}
NGRAPH_PLAIDML_OP_CLASS(ImplArgMax, IndexReductionBase<op::ArgMax>);
NGRAPH_PLAIDML_OP_CLASS(ImplArgMin, IndexReductionBase<op::ArgMin>);
}
}
}
// ArgMax computes the maximum index along a tensor axis.
void ngraph::runtime::plaidml::ImplArgMax::Apply()
{
build_index_reduction(">");
}
// ArgMin computes the minimum index along a tensor axis.
void ngraph::runtime::plaidml::ImplArgMin::Apply()
{
build_index_reduction("<");
}
......@@ -26,38 +26,33 @@ namespace ngraph
{
namespace plaidml
{
// Parameter binds a descriptor::Tensor to a PlaidML Placeholder.
template <>
void Impl<op::Parameter>::operator()()
{
check_inputs(0);
check_outputs(1);
vp::placeholder ph{build()->io_dim_override ? build()->io_dim_override_count
: op().get_output_shape(0).size()};
std::string name = std::string{"I"} + std::to_string(build()->input_names.size());
descriptor::Tensor* tv = op().get_output_tensor_ptr().get();
build()->bindings.emplace(tv, TensorInfo{ph, TensorContents::DATA});
build()->composer.input(name, ph);
build()->input_names.emplace(tv, std::move(name));
}
// Result binds a PlaidML variable to a composed function output.
template <>
void Impl<op::Result>::operator()()
{
check_inputs(1);
check_outputs(1);
std::string name = std::string{"O"} + std::to_string(build()->output_names.size());
descriptor::Tensor* tv = op().get_output_tensor_ptr().get();
build()->composer.output(name, op_input());
build()->output_names.emplace(tv, std::move(name));
}
namespace
{
Impl<op::Parameter>::Registration register_parameter;
Impl<op::Result>::Registration register_result;
}
NGRAPH_PLAIDML_OP_CLASS(ImplParameter, OpImpl<op::Parameter>);
NGRAPH_PLAIDML_OP_CLASS(ImplResult, OpImpl<op::Result>);
}
}
}
// Parameter binds a descriptor::Tensor to a PlaidML Placeholder.
void ngraph::runtime::plaidml::ImplParameter::Apply()
{
check_inputs(0);
check_outputs(1);
vp::placeholder ph{build()->io_dim_override ? build()->io_dim_override_count
: op().get_output_shape(0).size()};
std::string name = std::string{"I"} + std::to_string(build()->input_names.size());
descriptor::Tensor* tv = op().get_output_tensor_ptr().get();
build()->bindings.emplace(tv, TensorInfo{ph, TensorContents::DATA});
build()->composer.input(name, ph);
build()->input_names.emplace(tv, std::move(name));
}
// Result binds a PlaidML variable to a composed function output.
void ngraph::runtime::plaidml::ImplResult::Apply()
{
check_inputs(1);
check_outputs(1);
std::string name = std::string{"O"} + std::to_string(build()->output_names.size());
descriptor::Tensor* tv = op().get_output_tensor_ptr().get();
build()->composer.output(name, op_input());
build()->output_names.emplace(tv, std::move(name));
}
......@@ -23,46 +23,39 @@ namespace ngraph
{
namespace plaidml
{
// LRN implements Local Response Normalization
template <>
void Impl<op::LRN>::operator()()
{
check_inputs(1);
check_outputs(1);
auto dim_limit = op().get_inputs()[0].get_shape().size();
auto rank = dim_limit - 2;
auto distance = op().get_nsize() / 2;
std::ostringstream div_expr;
div_expr << "I / pow(" << op().get_bias() << ".0 + ((" << op().get_alpha()
<< ".0 / " << op().get_nsize() << ".0) * S), " << op().get_beta() << ".0)";
set_output(
start_tile_function()
.add(builder::Input{op_input(), "I"}
.add_dims({"N", "C"})
.add_dims("D", 0, rank))
.add(builder::Output{"O"})
.add(builder::Elementwise{"ISQ", "I * I"})
.add(builder::UnaryContraction{"+"}
.set(builder::ContractionOutput{"S"}
.add_indices({"n", "c"})
.add_indices("d", 0, rank)
.add_dims({"N", "C"})
.add_dims("D", 0, rank))
.set(builder::ContractionInput{"ISQ"}
.add_indices({"n", "c + z - " + std::to_string(distance)})
.add_indices("d", 0, rank))
.add_constraints(
[&](std::back_insert_iterator<std::list<std::string>> out) {
out = "z < " + std::to_string(op().get_nsize());
}))
.add(builder::Elementwise{"O", div_expr.str()})
.finalize());
}
namespace
{
Impl<op::LRN>::Registration register_local_response_norm;
}
NGRAPH_PLAIDML_OP_CLASS(ImplLRN, OpImpl<op::LRN>);
}
}
}
// LRN implements Local Response Normalization
void ngraph::runtime::plaidml::ImplLRN::Apply()
{
check_inputs(1);
check_outputs(1);
auto dim_limit = op().get_inputs()[0].get_shape().size();
auto rank = dim_limit - 2;
auto distance = op().get_nsize() / 2;
std::ostringstream div_expr;
div_expr << "I / pow(" << op().get_bias() << ".0 + ((" << op().get_alpha() << ".0 / "
<< op().get_nsize() << ".0) * S), " << op().get_beta() << ".0)";
set_output(
start_tile_function()
.add(builder::Input{op_input(), "I"}.add_dims({"N", "C"}).add_dims("D", 0, rank))
.add(builder::Output{"O"})
.add(builder::Elementwise{"ISQ", "I * I"})
.add(builder::UnaryContraction{"+"}
.set(builder::ContractionOutput{"S"}
.add_indices({"n", "c"})
.add_indices("d", 0, rank)
.add_dims({"N", "C"})
.add_dims("D", 0, rank))
.set(builder::ContractionInput{"ISQ"}
.add_indices({"n", "c + z - " + std::to_string(distance)})
.add_indices("d", 0, rank))
.add_constraints([&](std::back_insert_iterator<std::list<std::string>> out) {
out = "z < " + std::to_string(op().get_nsize());
}))
.add(builder::Elementwise{"O", div_expr.str()})
.finalize());
}
......@@ -25,56 +25,47 @@ namespace ngraph
{
namespace plaidml
{
// And performs a simple elementwise logical and.
template <>
void Impl<op::And>::operator()()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
.add(builder::Input{op_input(0, TensorContents::LOGICAL), "A"})
.add(builder::Input{op_input(1, TensorContents::LOGICAL), "B"})
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A ? B : A"})
.finalize(),
TensorContents::LOGICAL);
}
NGRAPH_PLAIDML_OP_CLASS(ImplAnd, OpImpl<op::And>);
NGRAPH_PLAIDML_OP_CLASS(ImplNot, OpImpl<op::Not>);
NGRAPH_PLAIDML_OP_CLASS(ImplOr, OpImpl<op::Or>);
}
}
}
// Not performs a simple elementwise logical not.
template <>
void Impl<op::Not>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
.add(builder::Input{op_input(0, TensorContents::LOGICAL), "I"})
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "cmp_eq(I, 0)"})
.finalize(),
TensorContents::LOGICAL);
}
// And performs a simple elementwise logical and.
void ngraph::runtime::plaidml::ImplAnd::Apply()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
.add(builder::Input{op_input(0), "A"})
.add(builder::Input{op_input(1), "B"})
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A ? B : A"})
.finalize());
}
// Or performs a simple elementwise logical or.
template <>
void Impl<op::Or>::operator()()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
.add(builder::Input{op_input(0, TensorContents::LOGICAL), "A"})
.add(builder::Input{op_input(1, TensorContents::LOGICAL), "B"})
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A ? A : B"})
.finalize(),
TensorContents::LOGICAL);
}
// Not performs a simple elementwise logical not.
void ngraph::runtime::plaidml::ImplNot::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
.add(builder::Input{op_input(0), "I"})
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "cmp_eq(I, 0)"})
.finalize());
}
namespace
{
Impl<op::And>::Registration register_and;
Impl<op::Not>::Registration register_not;
Impl<op::Or>::Registration register_or;
}
}
}
// Or performs a simple elementwise logical or.
void ngraph::runtime::plaidml::ImplOr::Apply()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
.add(builder::Input{op_input(0), "A"})
.add(builder::Input{op_input(1), "B"})
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A ? A : B"})
.finalize());
}
......@@ -26,87 +26,79 @@ namespace ngraph
{
namespace plaidml
{
// OneHot performs one-hot encoding along the requested axis.
template <>
void Impl<op::OneHot>::operator()()
{
check_inputs(1);
check_outputs(1);
// Here's what's going on to implement OneHot:
//
// * We reshape the input tensor to add a size=1 dimension where we want the one-hot axis to be,
//
// * We create an index tensor that's size=1 on every dimension except the one-hot dimension,
//
// * We perform an elementwise conditional across them to assign the one-hot values.
//
// The broadcast rules will expand the index tensor on all non-one-hot dimensions to match the
// input, and will expand the input tensor on the one-hot dimension to match the index.
//
// In theory, it'd be pretty easy to implement all this with purely elementwise operations. The
// current definition of index() requires an input tensor of the index() output shape, and it's
// a little tricky to fix that, so we generate a zero tensor of the correct shape using a
// contraction. TODO: Optimize out the zero tensor contraction.
NGRAPH_PLAIDML_OP_CLASS(ImplOneHot, OpImpl<op::OneHot>);
}
}
}
const auto& in_shape = op().get_inputs()[0].get_shape();
const auto& out_shape = op().get_shape();
// OneHot performs one-hot encoding along the requested axis.
void ngraph::runtime::plaidml::ImplOneHot::Apply()
{
check_inputs(1);
check_outputs(1);
std::ostringstream in_reshape;
for (std::size_t idx = 0; idx < out_shape.size(); ++idx)
{
if (idx)
{
in_reshape << ", ";
}
if (idx == op().get_one_hot_axis())
{
in_reshape << 1;
}
else
{
in_reshape << out_shape[idx];
}
}
// Here's what's going on to implement OneHot:
//
// * We reshape the input tensor to add a size=1 dimension where we want the one-hot axis to be,
//
// * We create an index tensor that's size=1 on every dimension except the one-hot dimension,
//
// * We perform an elementwise conditional across them to assign the one-hot values.
//
// The broadcast rules will expand the index tensor on all non-one-hot dimensions to match the
// input, and will expand the input tensor on the one-hot dimension to match the index.
//
// In theory, it'd be pretty easy to implement all this with purely elementwise operations. The
// current definition of index() requires an input tensor of the index() output shape, and it's
// a little tricky to fix that, so we generate a zero tensor of the correct shape using a
// contraction. TODO: Optimize out the zero tensor contraction.
set_output(
start_tile_function()
.add(builder::Input{op_input(), "I"}.add_dims("D", 0, in_shape.size()))
.add(builder::Input{static_cast<std::int64_t>(0), "Zero"})
.add(builder::Output{"O"})
.add(
builder::UnaryContraction{"="}
.set(
builder::ContractionOutput{"ZS"}
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < out_shape.size(); ++idx)
{
if (idx == op().get_one_hot_axis())
{
out = std::to_string(out_shape[idx]);
}
else
{
out = "1";
}
}
})
.add_indices("d", 0, out_shape.size()))
.set(builder::ContractionInput{"Zero"}))
.add(builder::Elementwise{
"Idx", "index(ZS, " + std::to_string(op().get_one_hot_axis()) + ")"})
.add(builder::Elementwise{"IS", "reshape(I, " + in_reshape.str() + ")"})
.add(builder::Elementwise{"OV", "IS == Idx ? 1 : 0"})
.add(builder::Elementwise{"O",
tile_converter("OV", op().get_element_type())})
.finalize());
}
const auto& in_shape = op().get_inputs()[0].get_shape();
const auto& out_shape = op().get_shape();
namespace
{
Impl<op::OneHot>::Registration register_one_hot;
}
std::ostringstream in_reshape;
for (std::size_t idx = 0; idx < out_shape.size(); ++idx)
{
if (idx)
{
in_reshape << ", ";
}
if (idx == op().get_one_hot_axis())
{
in_reshape << 1;
}
else
{
in_reshape << out_shape[idx];
}
}
set_output(
start_tile_function()
.add(builder::Input{op_input(), "I"}.add_dims("D", 0, in_shape.size()))
.add(builder::Input{static_cast<std::int64_t>(0), "Zero"})
.add(builder::Output{"O"})
.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"ZS"}
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < out_shape.size(); ++idx)
{
if (idx == op().get_one_hot_axis())
{
out = std::to_string(out_shape[idx]);
}
else
{
out = "1";
}
}
})
.add_indices("d", 0, out_shape.size()))
.set(builder::ContractionInput{"Zero"}))
.add(builder::Elementwise{"Idx",
"index(ZS, " + std::to_string(op().get_one_hot_axis()) + ")"})
.add(builder::Elementwise{"IS", "reshape(I, " + in_reshape.str() + ")"})
.add(builder::Elementwise{"OV", "IS == Idx ? 1 : 0"})
.add(builder::Elementwise{"O", tile_converter("OV", op().get_element_type())})
.finalize());
}
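As a small worked example of the scheme described in the comment above: one-hot encoding an input of shape {3} along axis 1 into an output of shape {3, 4} reshapes the input to {3, 1}, builds the zero contraction output ZS with shape {1, 4}, and takes Idx = index(ZS, 1), which holds the values 0 through 3 along that axis. The broadcasted comparison IS == Idx is then true at exactly one position in each row, and the final tile_converter call casts that 0/1 result to the op's element type.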
......@@ -130,12 +130,3 @@ std::string ngraph::runtime::plaidml::tile_converter(const std::string& tensor_n
}
return tile_converter(tensor_name, to_plaidml(element_type));
}
vp::variable ngraph::runtime::plaidml::plaidml_logical_to_data(vp::variable var, bool debug)
{
return builder::Function{"logicalToData", debug}
.add(builder::Input{var, "I"})
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "as_int(I ? 1 : 0, 8)"})
.finalize();
}
......@@ -46,9 +46,6 @@ namespace ngraph
std::string tile_converter(const std::string& tensor_name,
const ngraph::element::Type& element_type);
vertexai::plaidml::variable plaidml_logical_to_data(vertexai::plaidml::variable var,
bool debug);
}
}
}