Commit f5b2d581 authored by Rob Earhart, committed by Scott Cyphers

API cleanup & performance passes (#2242)

* nGraph compat updates

* Simplify op structure

* Add zero-dim elimination pass

* Add logical data conversion pass

* Add graphviz support

* Add implicit broadcast pass

* Elide unnecessary reshape ops

* Merge reshapes into convolutions

* Add winograd convolution support

* Allow three-input maxpool backprop

* Add concat elision

* Combine replication operations

* Elide unneeded replicates

* Style update
parent c11644ec
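
For context, a minimal usage sketch of the new backend options and the revised compile() signature (the device index, file name, and the run_on_plaidml helper are illustrative assumptions, not part of this commit):

#include <memory>
#include <vector>
#include "ngraph/function.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/tensor.hpp"

// Hypothetical helper: run a function on the PlaidML backend with the new
// Winograd and GraphViz options enabled. Option syntax follows the help text
// updated below:
//   PlaidML[:[device_index][,debug][,help][,list_devices][,eventlog=<filename>]
//           [,graphviz=<filename>][,winograd]]
void run_on_plaidml(std::shared_ptr<ngraph::Function> func,
                    const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& outputs,
                    const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& inputs)
{
    auto backend = ngraph::runtime::Backend::create("PlaidML:0,winograd,graphviz=graph.dot");

    // compile() now returns a Function handle instead of a bool; the handle is
    // what gets passed to call().
    auto compiled = backend->compile(func);
    backend->call(compiled, outputs, inputs);
}

The graphviz option emits a GraphViz rendering of the graph via ngraph::pass::VisualizeTree, and winograd enables the optional Winograd convolution pass; both are wired up in plaidml_compiler.cpp below.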
......@@ -41,10 +41,21 @@ set(SRC
plaidml_ops_pool.cpp
plaidml_ops_reduce.cpp
plaidml_ops_replace_slice.cpp
plaidml_ops_replicate.cpp
plaidml_ops_reverse.cpp
plaidml_ops_slice.cpp
plaidml_ops_softmax.cpp
plaidml_ops_tile.cpp
plaidml_ops_transcendental.cpp
plaidml_ops_winograd.cpp
plaidml_pass_concat_elision.cpp
plaidml_pass_explicit_logicals.cpp
plaidml_pass_implicit_broadcast.cpp
plaidml_pass_lower_convolutions.cpp
plaidml_pass_replicate_combination.cpp
plaidml_pass_replicate_elision.cpp
plaidml_pass_reshape_elision.cpp
plaidml_pass_winograd.cpp
plaidml_tensor.cpp
plaidml_translate.cpp
)
......
......@@ -41,10 +41,11 @@ std::shared_ptr<ngraph::runtime::Tensor> ngraph::runtime::plaidml::PlaidML_Backe
&m_config, element_type, shape, "direct_data", memory_pointer);
}
bool ngraph::runtime::plaidml::PlaidML_Backend::compile(std::shared_ptr<Function> func)
std::shared_ptr<ngraph::Function>
ngraph::runtime::plaidml::PlaidML_Backend::compile(std::shared_ptr<Function> func)
{
m_cache.compile(func, &m_compiler);
return true;
return func;
}
bool ngraph::runtime::plaidml::PlaidML_Backend::call(
......
......@@ -46,7 +46,7 @@ public:
std::shared_ptr<ngraph::runtime::Tensor> create_tensor(
const ngraph::element::Type& element_type, const Shape& shape, void* memory_pointer) final;
bool compile(std::shared_ptr<Function> func) final;
std::shared_ptr<Function> compile(std::shared_ptr<Function> func) final;
bool call(std::shared_ptr<Function> func,
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
......
......@@ -18,6 +18,7 @@
#include <stdexcept>
#include <utility>
#include "ngraph/except.hpp"
#include "ngraph/runtime/plaidml/plaidml_builder.hpp"
#include "ngraph/runtime/plaidml/plaidml_logger.hpp"
......@@ -467,6 +468,24 @@ ngraph::runtime::plaidml::builder::Input&&
return std::move(*this);
}
void ngraph::runtime::plaidml::builder::Input::apply_transpose(const AxisVector& axes)
{
if (axes.size() != m_dims.size())
{
throw ngraph_error{"Mismatched shape in input transposition"};
}
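// Scatter: the dimension name at position idx moves to position axes[idx].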
std::size_t idx = 0;
std::vector<std::string> dims(m_dims.size());
for (auto dim : m_dims)
{
dims[axes[idx]] = dim;
idx++;
}
m_dims.clear();
m_dims.insert(m_dims.end(), dims.begin(), dims.end());
}
ngraph::runtime::plaidml::builder::Output::Output(std::string name)
: m_name{std::move(name)}
{
......@@ -611,6 +630,28 @@ ngraph::runtime::plaidml::builder::ContractionOutput&&
return std::move(*this);
}
void ngraph::runtime::plaidml::builder::ContractionOutput::apply_transpose(const AxisVector& axes)
{
if (axes.size() != m_dims.size() || axes.size() != m_indices.size())
{
throw ngraph_error{"Mismatched shape in contraction output transposition"};
}
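// Gather: position i of the result takes the dimension and index names from position axes[i].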
std::vector<std::string> dims{m_dims.begin(), m_dims.end()};
m_dims.clear();
for (auto idx : axes)
{
m_dims.emplace_back(dims[idx]);
}
std::vector<std::string> indices{m_indices.begin(), m_indices.end()};
m_indices.clear();
for (auto idx : axes)
{
m_indices.emplace_back(indices[idx]);
}
}
ngraph::runtime::plaidml::builder::ContractionInput&
ngraph::runtime::plaidml::builder::ContractionInput::add_indices(std::string prefix,
std::size_t first,
......@@ -675,6 +716,24 @@ ngraph::runtime::plaidml::builder::ContractionInput&&
return std::move(*this);
}
void ngraph::runtime::plaidml::builder::ContractionInput::apply_transpose(const AxisVector& axes)
{
if (axes.size() != m_indices.size())
{
throw ngraph_error{"Mismatched shape in contraction input transposition"};
}
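// Scatter: the index name at position idx moves to position axes[idx].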
std::size_t idx = 0;
std::vector<std::string> indices(m_indices.size());
for (auto dim : m_indices)
{
indices[axes[idx]] = dim;
idx++;
}
m_indices.clear();
m_indices.insert(m_indices.end(), indices.begin(), indices.end());
}
ngraph::runtime::plaidml::builder::UnaryContraction::UnaryContraction(std::string agg_op)
: m_agg_op{std::move(agg_op)}
{
......
......@@ -24,6 +24,7 @@
#include <string>
#include <utility>
#include "ngraph/axis_vector.hpp"
#include "ngraph/runtime/plaidml/plaidml_config.hpp"
// Utilities for constructing PlaidML functions.
......@@ -136,9 +137,22 @@ public:
return std::move(*this);
}
Input& transpose(const AxisVector& axes) &
{
apply_transpose(axes);
return *this;
}
Input&& transpose(const AxisVector& axes) &&
{
apply_transpose(axes);
return std::move(*this);
}
private:
friend class Function;
void apply_transpose(const AxisVector& axes);
vertexai::plaidml::variable m_var;
std::string m_name;
std::list<std::string> m_dims;
......@@ -230,9 +244,22 @@ public:
return std::move(*this);
}
ContractionOutput& transpose(const AxisVector& axes) &
{
apply_transpose(axes);
return *this;
}
ContractionOutput&& transpose(const AxisVector& axes) &&
{
apply_transpose(axes);
return std::move(*this);
}
private:
friend class Function;
void apply_transpose(const AxisVector& axes);
std::string m_name;
std::list<std::string> m_indices;
std::list<std::string> m_dims;
......@@ -268,9 +295,22 @@ public:
return std::move(*this);
}
ContractionInput& transpose(const AxisVector& axes) &
{
apply_transpose(axes);
return *this;
}
ContractionInput&& transpose(const AxisVector& axes) &&
{
apply_transpose(axes);
return std::move(*this);
}
private:
friend class Function;
void apply_transpose(const AxisVector& axes);
std::string m_name;
std::list<std::string> m_indices;
};
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/runtime/plaidml/plaidml_compiler.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/any_all_replacement.hpp"
......@@ -25,8 +26,18 @@
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/nop_elimination.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pass/zero_dim_tensor_elimination.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_logger.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_concat_elision.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_explicit_logicals.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_lower_convolutions.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_combination.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_elision.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_reshape_elision.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_winograd.hpp"
namespace
{
......@@ -78,15 +89,44 @@ std::shared_ptr<ngraph::runtime::plaidml::CompiledFunction>
pass_manager.register_pass<ngraph::pass::AnyAllReplacement>();
pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<ngraph::pass::NopElimination>();
pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.register_pass<ngraph::pass::CoreFusion>();
// N.B. We'd like to register ngraph::pass::GetOutputElementElimination, but it breaks BatchNorm
// backprop
pass_manager.register_pass<ngraph::pass::Liveness>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ExplicitLogicals>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ConcatElision>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateElision>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReplicateCombination>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ImplicitBroadcast>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::ReshapeElision>();
pass_manager.register_pass<ngraph::runtime::plaidml::pass::LowerConvolutions>();
if (m_config->winograd)
{
pass_manager.register_pass<ngraph::runtime::plaidml::pass::Winograd>();
}
if (!m_config->graphviz.empty())
{
pass_manager.register_pass<ngraph::pass::VisualizeTree>(m_config->graphviz);
}
// N.B. When we rewrite the graph, there are cases where we
// produce nodes that contain validation errors. A good example
// is in the ImplicitBroadcast pass -- after this pass, there may
// be elementwise operations whose inputs are not all the same
// shape.
//
// The caller may wish to perform operations (e.g. clone) on their
// supplied function that will cause validation to occur. So
// before we rewrite, we make our own copy of the function.
func = clone_function(*func);
// Apply passes.
pass_manager.run_passes(func);
// Compile the resulting function.
Build b;
build(std::move(func), &b);
return std::make_shared<CompiledFunction>(std::move(b));
......@@ -98,7 +138,7 @@ void ngraph::runtime::plaidml::Compiler::build(std::shared_ptr<Function> func, B
b->config = m_config;
b->func = func;
const auto* op_map = OpImplMap();
const auto* op_map = GlobalOpImplMap();
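// Look up each op's implementation object in the global op map; unsupported ops are reported below.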
for (const auto& op_ptr : func->get_ordered_ops())
{
......@@ -114,6 +154,6 @@ void ngraph::runtime::plaidml::Compiler::build(std::shared_ptr<Function> func, B
std::string{"The PlaidML backend doesn't currently implement the '"} +
op->description() + "' operation"};
}
it->second(b, *op);
it->second->Apply(b, op);
}
}
......@@ -77,8 +77,10 @@ ngraph::runtime::plaidml::Config
bool help = false;
bool list = false;
bool debug = false;
bool winograd = false;
std::size_t device_idx = 0;
std::string eventlog_config;
std::string graphviz;
#ifdef NGRAPH_DEBUG_ENABLE
debug = true;
......@@ -155,7 +157,7 @@ ngraph::runtime::plaidml::Config
return (oname_end - oname_begin == len) && !strncmp(oname_begin, opt, len);
};
auto oval_len = oval_end - oval_begin;
std::size_t oval_len = oval_end - oval_begin;
bool has_oval = oval_begin != oname_end;
// N.B. oval_len != 0 => has_oval, but there's no other relationship.
......@@ -229,6 +231,25 @@ ngraph::runtime::plaidml::Config
continue;
}
// Check for visualization (GraphViz output)
if (is_opt("graphviz"))
{
if (!oval_len)
{
throw std::invalid_argument{"PlaidML graphviz requires a value"};
}
graphviz = std::string{oval_begin, oval_len};
continue;
}
// Check for Winograd. (Winograd is sometimes a performance
// boost, but not always, so we make it optional.)
if (is_opt("winograd"))
{
winograd = true;
continue;
}
// Reject unknown options
err = true;
}
......@@ -236,7 +257,7 @@ ngraph::runtime::plaidml::Config
constexpr char help_text[] =
"PlaidML Backend Specification: \""
"PlaidML[:[device_index][,debug][,help][,list_devices][,"
"eventlog=<filename>]]\". For example: \"PlaidML\", \""
"eventlog=<filename>][,graphviz=<filename>][,winograd]]\". For example: \"PlaidML\", \""
"PlaidML:0,list_devices\"";
if (err)
{
......@@ -269,5 +290,9 @@ ngraph::runtime::plaidml::Config
result.debug = debug;
result.graphviz = graphviz;
result.winograd = winograd;
return result;
}
......@@ -39,4 +39,6 @@ struct ngraph::runtime::plaidml::Config
std::shared_ptr<vertexai::ctx> ctx;
std::shared_ptr<vertexai::plaidml::device> dev;
bool debug;
bool winograd;
std::string graphviz;
};
......@@ -171,7 +171,7 @@ ngraph::runtime::plaidml::ConvPoolFormatter::ConvPoolFormatter(
}
ngraph::runtime::plaidml::builder::Input
ngraph::runtime::plaidml::ConvPoolFormatter::F_in_header(vertexai::plaidml::variable var)
ngraph::runtime::plaidml::ConvPoolFormatter::F_in_header(vertexai::plaidml::variable var) const
{
if (m_op != OpType::Conv)
{
......@@ -191,7 +191,7 @@ ngraph::runtime::plaidml::builder::Input
}
ngraph::runtime::plaidml::builder::Input
ngraph::runtime::plaidml::ConvPoolFormatter::I_in_header(vertexai::plaidml::variable var)
ngraph::runtime::plaidml::ConvPoolFormatter::I_in_header(vertexai::plaidml::variable var) const
{
if (m_deriv == DerivType::Data && m_op == OpType::Conv)
{
......@@ -216,7 +216,7 @@ ngraph::runtime::plaidml::builder::Input
}
ngraph::runtime::plaidml::builder::Input
ngraph::runtime::plaidml::ConvPoolFormatter::O_in_header(vertexai::plaidml::variable var)
ngraph::runtime::plaidml::ConvPoolFormatter::O_in_header(vertexai::plaidml::variable var) const
{
if (m_deriv == DerivType::None)
{
......@@ -240,7 +240,7 @@ ngraph::runtime::plaidml::builder::Input
}
ngraph::runtime::plaidml::builder::Output
ngraph::runtime::plaidml::ConvPoolFormatter::F_out_header()
ngraph::runtime::plaidml::ConvPoolFormatter::F_out_header() const
{
if (m_op != OpType::Conv)
{
......@@ -254,7 +254,7 @@ ngraph::runtime::plaidml::builder::Output
}
ngraph::runtime::plaidml::builder::Output
ngraph::runtime::plaidml::ConvPoolFormatter::I_out_header()
ngraph::runtime::plaidml::ConvPoolFormatter::I_out_header() const
{
if (m_deriv != DerivType::Data)
{
......@@ -272,7 +272,7 @@ ngraph::runtime::plaidml::builder::Output
}
ngraph::runtime::plaidml::builder::Output
ngraph::runtime::plaidml::ConvPoolFormatter::O_out_header()
ngraph::runtime::plaidml::ConvPoolFormatter::O_out_header() const
{
if (m_deriv != DerivType::None)
{
......@@ -282,7 +282,7 @@ ngraph::runtime::plaidml::builder::Output
}
ngraph::runtime::plaidml::builder::ContractionOutput
ngraph::runtime::plaidml::ConvPoolFormatter::F_out_body()
ngraph::runtime::plaidml::ConvPoolFormatter::F_out_body() const
{
if (m_op != OpType::Conv)
{
......@@ -307,7 +307,7 @@ ngraph::runtime::plaidml::builder::ContractionOutput
}
ngraph::runtime::plaidml::builder::ContractionOutput
ngraph::runtime::plaidml::ConvPoolFormatter::I_out_body()
ngraph::runtime::plaidml::ConvPoolFormatter::I_out_body() const
{
if (m_deriv != DerivType::Data)
{
......@@ -353,7 +353,7 @@ ngraph::runtime::plaidml::builder::ContractionOutput
}
ngraph::runtime::plaidml::builder::ContractionOutput
ngraph::runtime::plaidml::ConvPoolFormatter::O_out_body()
ngraph::runtime::plaidml::ConvPoolFormatter::O_out_body() const
{
if (m_deriv != DerivType::None && m_op == OpType::Conv)
{
......@@ -405,7 +405,7 @@ ngraph::runtime::plaidml::builder::ContractionOutput
}
ngraph::runtime::plaidml::builder::ContractionInput
ngraph::runtime::plaidml::ConvPoolFormatter::F_in_body()
ngraph::runtime::plaidml::ConvPoolFormatter::F_in_body() const
{
if (m_op != OpType::Conv)
{
......@@ -425,7 +425,7 @@ ngraph::runtime::plaidml::builder::ContractionInput
}
ngraph::runtime::plaidml::builder::ContractionInput
ngraph::runtime::plaidml::ConvPoolFormatter::I_in_body()
ngraph::runtime::plaidml::ConvPoolFormatter::I_in_body() const
{
if (m_deriv == DerivType::Data && m_op == OpType::Conv)
{
......@@ -449,7 +449,7 @@ ngraph::runtime::plaidml::builder::ContractionInput
}
ngraph::runtime::plaidml::builder::ContractionInput
ngraph::runtime::plaidml::ConvPoolFormatter::O_in_body()
ngraph::runtime::plaidml::ConvPoolFormatter::O_in_body() const
{
if (m_deriv == DerivType::None)
{
......@@ -482,7 +482,7 @@ ngraph::runtime::plaidml::builder::ContractionInput
}
ngraph::runtime::plaidml::builder::UnaryContraction
ngraph::runtime::plaidml::ConvPoolFormatter::Broadcast_Ones()
ngraph::runtime::plaidml::ConvPoolFormatter::Broadcast_Ones() const
{
if (m_op != OpType::AvgPool)
{
......@@ -501,7 +501,7 @@ ngraph::runtime::plaidml::builder::UnaryContraction
}
ngraph::runtime::plaidml::builder::UnaryContraction
ngraph::runtime::plaidml::ConvPoolFormatter::Count()
ngraph::runtime::plaidml::ConvPoolFormatter::Count() const
{
if (m_op != OpType::AvgPool)
{
......@@ -535,7 +535,7 @@ ngraph::runtime::plaidml::builder::UnaryContraction
}
ngraph::runtime::plaidml::builder::UnaryContraction
ngraph::runtime::plaidml::ConvPoolFormatter::PoolContraction()
ngraph::runtime::plaidml::ConvPoolFormatter::PoolContraction() const
{
std::string agg_op;
switch (m_op)
......@@ -558,7 +558,7 @@ ngraph::runtime::plaidml::builder::UnaryContraction
}
ngraph::runtime::plaidml::builder::TernaryContraction
ngraph::runtime::plaidml::ConvPoolFormatter::PoolDerivContraction()
ngraph::runtime::plaidml::ConvPoolFormatter::PoolDerivContraction() const
{
builder::ContractionOutput output{"DI"};
output.add_indices({n(), c()}).add_dims({N(), C()});
......@@ -595,27 +595,27 @@ ngraph::runtime::plaidml::builder::TernaryContraction
.set_third(incoming_deriv);
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::c()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::c() const
{
return "c";
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::ci()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::ci() const
{
return "ci";
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::co()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::co() const
{
return "co";
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::n()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::n() const
{
return "n";
}
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xfs()
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xfs() const
{
if (m_xfs.empty())
{
......@@ -629,7 +629,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xfs()
return m_xfs;
}
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xis()
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xis() const
{
if (m_xis.empty())
{
......@@ -652,7 +652,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xis()
return m_xis;
}
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xos()
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xos() const
{
if (m_xos.empty())
{
......@@ -666,27 +666,27 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::xos()
return m_xos;
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::C()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::C() const
{
return "C";
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::CI()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::CI() const
{
return "CI";
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::CO()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::CO() const
{
return "CO";
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::N()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::N() const
{
return "N";
}
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XFs()
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XFs() const
{
if (m_XFs.empty())
{
......@@ -707,7 +707,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XFs()
return m_XFs;
}
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XIs()
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XIs() const
{
if (m_XIs.empty())
{
......@@ -728,7 +728,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XIs()
return m_XIs;
}
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XOs()
std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XOs() const
{
if (m_XOs.empty())
{
......@@ -765,7 +765,7 @@ std::vector<std::string> ngraph::runtime::plaidml::ConvPoolFormatter::XOs()
return m_XOs;
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::F()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::F() const
{
if (m_deriv == DerivType::Filter)
{
......@@ -774,7 +774,7 @@ std::string ngraph::runtime::plaidml::ConvPoolFormatter::F()
return "F";
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::I()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::I() const
{
if (m_deriv == DerivType::Data && m_op == OpType::Conv)
{
......@@ -783,7 +783,7 @@ std::string ngraph::runtime::plaidml::ConvPoolFormatter::I()
return "I";
}
std::string ngraph::runtime::plaidml::ConvPoolFormatter::O()
std::string ngraph::runtime::plaidml::ConvPoolFormatter::O() const
{
if (m_deriv != DerivType::None)
{
......
......@@ -72,47 +72,47 @@ public:
ConvPoolFormatter::DerivType deriv);
// Formatted tensors
builder::Input F_in_header(vertexai::plaidml::variable var);
builder::Input I_in_header(vertexai::plaidml::variable var);
builder::Input O_in_header(vertexai::plaidml::variable var);
builder::Output F_out_header();
builder::Output I_out_header();
builder::Output O_out_header();
builder::ContractionOutput F_out_body();
builder::ContractionOutput I_out_body();
builder::ContractionOutput O_out_body();
builder::ContractionInput F_in_body();
builder::ContractionInput I_in_body();
builder::ContractionInput O_in_body();
builder::Input F_in_header(vertexai::plaidml::variable var) const;
builder::Input I_in_header(vertexai::plaidml::variable var) const;
builder::Input O_in_header(vertexai::plaidml::variable var) const;
builder::Output F_out_header() const;
builder::Output I_out_header() const;
builder::Output O_out_header() const;
builder::ContractionOutput F_out_body() const;
builder::ContractionOutput I_out_body() const;
builder::ContractionOutput O_out_body() const;
builder::ContractionInput F_in_body() const;
builder::ContractionInput I_in_body() const;
builder::ContractionInput O_in_body() const;
// Special Operations
builder::UnaryContraction Broadcast_Ones();
builder::UnaryContraction Count();
builder::UnaryContraction PoolContraction();
builder::TernaryContraction PoolDerivContraction();
builder::UnaryContraction Broadcast_Ones() const;
builder::UnaryContraction Count() const;
builder::UnaryContraction PoolContraction() const;
builder::TernaryContraction PoolDerivContraction() const;
// Index names / formulas
std::string c();
std::string ci();
std::string co();
std::string n();
std::vector<std::string> xfs();
std::vector<std::string> xis();
std::vector<std::string> xos();
std::string c() const;
std::string ci() const;
std::string co() const;
std::string n() const;
std::vector<std::string> xfs() const;
std::vector<std::string> xis() const;
std::vector<std::string> xos() const;
// Dimension names / formulas
std::string C();
std::string CI();
std::string CO();
std::string N();
std::vector<std::string> XFs();
std::vector<std::string> XIs();
std::vector<std::string> XOs();
std::string C() const;
std::string CI() const;
std::string CO() const;
std::string N() const;
std::vector<std::string> XFs() const;
std::vector<std::string> XIs() const;
std::vector<std::string> XOs() const;
// Tensor names
std::string F();
std::string I();
std::string O();
std::string F() const;
std::string I() const;
std::string O() const;
private:
std::size_t m_rank;
......@@ -126,10 +126,10 @@ private:
DerivType m_deriv = DerivType::None;
ngraph::Shape m_filters_shape;
ngraph::Shape m_data_batch_shape;
std::vector<std::string> m_xfs;
std::vector<std::string> m_xis;
std::vector<std::string> m_xos;
std::vector<std::string> m_XFs;
std::vector<std::string> m_XIs;
std::vector<std::string> m_XOs;
mutable std::vector<std::string> m_xfs;
mutable std::vector<std::string> m_xis;
mutable std::vector<std::string> m_xos;
mutable std::vector<std::string> m_XFs;
mutable std::vector<std::string> m_XIs;
mutable std::vector<std::string> m_XOs;
};
......@@ -16,20 +16,8 @@
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
namespace ngraph
ngraph::runtime::plaidml::OpImplMap* ngraph::runtime::plaidml::GlobalOpImplMap()
{
namespace runtime
{
namespace plaidml
{
std::unordered_map<std::type_index, std::function<void(Build*, const ngraph::Node&)>>*
OpImplMap()
{
static std::unordered_map<std::type_index,
std::function<void(Build*, const ngraph::Node&)>>
op_impl_map;
static OpImplMap op_impl_map;
return &op_impl_map;
}
}
}
}
......@@ -34,10 +34,26 @@ namespace ngraph
{
namespace plaidml
{
// Abs performs a simple elementwise absolute value.
template <>
void Impl<op::Abs>::operator()()
{
NGRAPH_PLAIDML_OP_CLASS(ImplAbs, OpImpl<op::Abs>);
NGRAPH_PLAIDML_OP_CLASS(ImplAdd, OpImpl<op::Add>);
NGRAPH_PLAIDML_OP_CLASS(ImplCeiling, OpImpl<op::Ceiling>);
NGRAPH_PLAIDML_OP_CLASS(ImplDivide, OpImpl<op::Divide>);
NGRAPH_PLAIDML_OP_CLASS(ImplFloor, OpImpl<op::Floor>);
NGRAPH_PLAIDML_OP_CLASS(ImplMultiply, OpImpl<op::Multiply>);
NGRAPH_PLAIDML_OP_CLASS(ImplNegative, OpImpl<op::Negative>);
NGRAPH_PLAIDML_OP_CLASS(ImplRelu, OpImpl<op::Relu>);
NGRAPH_PLAIDML_OP_CLASS(ImplReluBackprop, OpImpl<op::ReluBackprop>);
NGRAPH_PLAIDML_OP_CLASS(ImplSigmoid, OpImpl<op::Sigmoid>);
NGRAPH_PLAIDML_OP_CLASS(ImplSigmoidBackprop, OpImpl<op::SigmoidBackprop>);
NGRAPH_PLAIDML_OP_CLASS(ImplSign, OpImpl<op::Sign>);
NGRAPH_PLAIDML_OP_CLASS(ImplSubtract, OpImpl<op::Subtract>);
}
}
}
// Abs performs a simple elementwise absolute value.
void ngraph::runtime::plaidml::ImplAbs::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -45,12 +61,11 @@ namespace ngraph
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "abs(I)"})
.finalize());
}
}
// Add performs a simple elementwise addition.
template <>
void Impl<op::Add>::operator()()
{
// Add performs a simple elementwise addition.
void ngraph::runtime::plaidml::ImplAdd::Apply()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -59,12 +74,11 @@ namespace ngraph
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A + B"})
.finalize());
}
}
// Ceiling performs a simple elementwise ceiling.
template <>
void Impl<op::Ceiling>::operator()()
{
// Ceiling performs a simple elementwise ceiling.
void ngraph::runtime::plaidml::ImplCeiling::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -72,12 +86,11 @@ namespace ngraph
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "ceil(I)"})
.finalize());
}
}
// Divide performs a simple elementwise division.
template <>
void Impl<op::Divide>::operator()()
{
// Divide performs a simple elementwise division.
void ngraph::runtime::plaidml::ImplDivide::Apply()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -86,12 +99,11 @@ namespace ngraph
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A / B"})
.finalize());
}
}
// Floor performs a simple elementwise floor.
template <>
void Impl<op::Floor>::operator()()
{
// Floor performs a simple elementwise floor.
void ngraph::runtime::plaidml::ImplFloor::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -99,12 +111,11 @@ namespace ngraph
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "floor(I)"})
.finalize());
}
}
// Multiply performs a simple elementwise multiplication.
template <>
void Impl<op::Multiply>::operator()()
{
// Multiply performs a simple elementwise multiplication.
void ngraph::runtime::plaidml::ImplMultiply::Apply()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -113,12 +124,11 @@ namespace ngraph
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A * B"})
.finalize());
}
}
// Negative performs a simple elementwise negation.
template <>
void Impl<op::Negative>::operator()()
{
// Negative performs a simple elementwise negation.
void ngraph::runtime::plaidml::ImplNegative::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -126,12 +136,11 @@ namespace ngraph
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "-I"})
.finalize());
}
}
// Relu implements a simple elementwise rectified linear unit.
template <>
void Impl<op::Relu>::operator()()
{
// Relu implements a simple elementwise rectified linear unit.
void ngraph::runtime::plaidml::ImplRelu::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -139,12 +148,11 @@ namespace ngraph
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "relu(I)"})
.finalize());
}
}
// ReluBackprop computes the derivative of Relu.
template <>
void Impl<op::ReluBackprop>::operator()()
{
// ReluBackprop computes the derivative of Relu.
void ngraph::runtime::plaidml::ImplReluBackprop::Apply()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -153,12 +161,11 @@ namespace ngraph
.add(builder::Output{"DI"})
.add(builder::Elementwise{"DI", "I > 0 ? DO : 0"})
.finalize());
}
}
// Sigmoid computes a standard ML sigmoid: 1/(1+exp(-X))
template <>
void Impl<op::Sigmoid>::operator()()
{
// Sigmoid computes a standard ML sigmoid: 1/(1+exp(-X))
void ngraph::runtime::plaidml::ImplSigmoid::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -166,13 +173,12 @@ namespace ngraph
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "1/(1+exp(-I))"})
.finalize());
}
}
// SigmoidBackprop computes the derivative of a standard ML
// sigmoid: dOutput * sigmoid(X) * (1-sigmoid(X))
template <>
void Impl<op::SigmoidBackprop>::operator()()
{
// SigmoidBackprop computes the derivative of a standard ML
// sigmoid: dOutput * sigmoid(X) * (1-sigmoid(X))
void ngraph::runtime::plaidml::ImplSigmoidBackprop::Apply()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -182,27 +188,24 @@ namespace ngraph
.add(builder::Elementwise{"O", "1/(1+exp(-I))"})
.add(builder::Elementwise{"DI", "DO * O * (1-O)"})
.finalize());
}
}
// Sign returns the sign of an element.
template <>
void Impl<op::Sign>::operator()()
{
// Sign returns the sign of an element.
void ngraph::runtime::plaidml::ImplSign::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
.add(builder::Input{op_input(0), "I"})
.add(builder::Output{"O"})
.add(builder::Elementwise{"S", "(I < 0) ? -1 : ((I > 0) ? 1 : 0)"})
.add(builder::Elementwise{
"O", tile_converter("S", op().get_element_type())})
.add(builder::Elementwise{"O", tile_converter("S", op().get_element_type())})
.finalize());
}
}
// Subtract performs a simple elementwise subtraction.
template <>
void Impl<op::Subtract>::operator()()
{
// Subtract performs a simple elementwise subtraction.
void ngraph::runtime::plaidml::ImplSubtract::Apply()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -211,24 +214,4 @@ namespace ngraph
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A - B"})
.finalize());
}
namespace
{
Impl<op::Abs>::Registration register_abs;
Impl<op::Add>::Registration register_add;
Impl<op::Ceiling>::Registration register_ceiling;
Impl<op::Divide>::Registration register_divide;
Impl<op::Floor>::Registration register_floor;
Impl<op::Multiply>::Registration register_multiply;
Impl<op::Negative>::Registration register_negative;
Impl<op::Relu>::Registration register_relu;
Impl<op::ReluBackprop>::Registration register_relu_backprop;
Impl<op::Sigmoid>::Registration register_sigmoid;
Impl<op::SigmoidBackprop>::Registration register_sigmoid_backprop;
Impl<op::Sign>::Registration register_sign;
Impl<op::Subtract>::Registration register_subtract;
}
}
}
}
......@@ -30,25 +30,34 @@ namespace ngraph
{
namespace plaidml
{
// Equal performs a simple elementwise equality.
template <>
void Impl<op::Equal>::operator()()
{
NGRAPH_PLAIDML_OP_CLASS(ImplEqual, OpImpl<op::Equal>);
NGRAPH_PLAIDML_OP_CLASS(ImplGreater, OpImpl<op::Greater>);
NGRAPH_PLAIDML_OP_CLASS(ImplGreaterEq, OpImpl<op::GreaterEq>);
NGRAPH_PLAIDML_OP_CLASS(ImplLess, OpImpl<op::Less>);
NGRAPH_PLAIDML_OP_CLASS(ImplLessEq, OpImpl<op::LessEq>);
NGRAPH_PLAIDML_OP_CLASS(ImplMaximum, OpImpl<op::Maximum>);
NGRAPH_PLAIDML_OP_CLASS(ImplMinimum, OpImpl<op::Minimum>);
NGRAPH_PLAIDML_OP_CLASS(ImplNotEqual, OpImpl<op::NotEqual>);
}
}
}
// Equal performs a simple elementwise equality.
void ngraph::runtime::plaidml::ImplEqual::Apply()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
.add(builder::Input{op_input(0, TensorContents::LOGICAL), "A"})
.add(builder::Input{op_input(1, TensorContents::LOGICAL), "B"})
.add(builder::Input{op_input(0), "A"})
.add(builder::Input{op_input(1), "B"})
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A == B"})
.finalize(),
TensorContents::LOGICAL);
}
.finalize());
}
// Greater performs a simple elementwise greater-than comparison.
template <>
void Impl<op::Greater>::operator()()
{
// Greater performs a simple elementwise greater-than comparison.
void ngraph::runtime::plaidml::ImplGreater::Apply()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -56,14 +65,12 @@ namespace ngraph
.add(builder::Input{op_input(1), "B"})
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A > B"})
.finalize(),
TensorContents::LOGICAL);
}
.finalize());
}
// GreaterEq performs a simple elementwise greater-than-or-equal-to comparison.
template <>
void Impl<op::GreaterEq>::operator()()
{
// GreaterEq performs a simple elementwise greater-than-or-equal-to comparison.
void ngraph::runtime::plaidml::ImplGreaterEq::Apply()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -71,14 +78,12 @@ namespace ngraph
.add(builder::Input{op_input(1), "B"})
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A >= B"})
.finalize(),
TensorContents::LOGICAL);
}
.finalize());
}
// Less performs a simple elementwise less-than comparison.
template <>
void Impl<op::Less>::operator()()
{
// Less performs a simple elementwise less-than comparison.
void ngraph::runtime::plaidml::ImplLess::Apply()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -86,14 +91,12 @@ namespace ngraph
.add(builder::Input{op_input(1), "B"})
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A < B"})
.finalize(),
TensorContents::LOGICAL);
}
.finalize());
}
// LessEq performs a simple elementwise less-than-or-equal-to comparison.
template <>
void Impl<op::LessEq>::operator()()
{
// LessEq performs a simple elementwise less-than-or-equal-to comparison.
void ngraph::runtime::plaidml::ImplLessEq::Apply()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -101,14 +104,12 @@ namespace ngraph
.add(builder::Input{op_input(1), "B"})
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A <= B"})
.finalize(),
TensorContents::LOGICAL);
}
.finalize());
}
// Maximum performs a simple elementwise maximum.
template <>
void Impl<op::Maximum>::operator()()
{
// Maximum performs a simple elementwise maximum.
void ngraph::runtime::plaidml::ImplMaximum::Apply()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -117,12 +118,11 @@ namespace ngraph
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "max(A, B)"})
.finalize());
}
}
// Minimum performs a simple elementwise minimum.
template <>
void Impl<op::Minimum>::operator()()
{
// Minimum performs a simple elementwise minimum.
void ngraph::runtime::plaidml::ImplMinimum::Apply()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -131,34 +131,17 @@ namespace ngraph
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "min(A, B)"})
.finalize());
}
}
// NotEqual performs a simple elementwise not-equality.
template <>
void Impl<op::NotEqual>::operator()()
{
// NotEqual performs a simple elementwise inequality comparison.
void ngraph::runtime::plaidml::ImplNotEqual::Apply()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
.add(builder::Input{op_input(0, TensorContents::LOGICAL), "A"})
.add(builder::Input{op_input(1, TensorContents::LOGICAL), "B"})
.add(builder::Input{op_input(0), "A"})
.add(builder::Input{op_input(1), "B"})
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A != B"})
.finalize(),
TensorContents::LOGICAL);
}
namespace
{
Impl<op::Equal>::Registration register_equal;
Impl<op::Greater>::Registration register_greater;
Impl<op::GreaterEq>::Registration register_greater_eq;
Impl<op::Less>::Registration register_less;
Impl<op::LessEq>::Registration register_less_eq;
Impl<op::Maximum>::Registration register_maximum;
Impl<op::Minimum>::Registration register_minimum;
Impl<op::NotEqual>::Registration register_not_equal;
}
}
}
.finalize());
}
......@@ -23,10 +23,14 @@ namespace ngraph
{
namespace plaidml
{
// Concat views a tensor as a new type.
template <>
void Impl<op::Concat>::operator()()
{
NGRAPH_PLAIDML_OP_CLASS(ImplConcat, OpImpl<op::Concat>);
}
}
}
// Concat concatenates its input tensors along the requested axis.
void ngraph::runtime::plaidml::ImplConcat::Apply()
{
check_outputs(1);
auto f = start_tile_function();
......@@ -58,12 +62,10 @@ namespace ngraph
continue;
}
std::string sidx{std::to_string(iidx)};
f.add(builder::Input{op_input(iidx), "I" + sidx}.add_dims(
"I" + sidx + "_D", 0, dim_count));
f.add(builder::Input{op_input(iidx), "I" + sidx}.add_dims("I" + sidx + "_D", 0, dim_count));
f.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"E" + sidx}
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_count; ++idx)
{
std::ostringstream s;
......@@ -78,22 +80,19 @@ namespace ngraph
}
}
})
.add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
.add_indices([&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_count; ++idx)
{
std::ostringstream s;
s << "d" << idx;
if (saw_non_zero_tensor &&
idx == op().get_concatenation_axis())
if (saw_non_zero_tensor && idx == op().get_concatenation_axis())
{
s << " + " << offset.str();
}
out = s.str();
}
}))
.set(builder::ContractionInput{"I" + sidx}.add_indices(
"d", 0, dim_count)));
.set(builder::ContractionInput{"I" + sidx}.add_indices("d", 0, dim_count)));
if (saw_non_zero_tensor)
{
oexpr << " + ";
......@@ -106,12 +105,4 @@ namespace ngraph
f.add(builder::Elementwise{"O", oexpr.str()});
set_output(f.finalize());
}
namespace
{
Impl<op::Concat>::Registration register_concat;
}
}
}
}
......@@ -24,25 +24,20 @@ namespace ngraph
{
namespace plaidml
{
// Convert views a tensor as a new type.
template <>
void Impl<op::Convert>::operator()()
{
NGRAPH_PLAIDML_OP_CLASS(ImplConvert, OpImpl<op::Convert>);
}
}
}
// Convert views a tensor as a new type.
void ngraph::runtime::plaidml::ImplConvert::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(
start_tile_function()
set_output(start_tile_function()
.add(builder::Input{op_input(), "I"})
.add(builder::Output{"O"})
.add(builder::Elementwise{
"O", tile_converter("I", to_plaidml(op().get_convert_element_type()))})
.finalize());
}
namespace
{
Impl<op::Convert>::Registration register_convert;
}
}
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/axis_vector.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace op
{
class Convolution;
class ConvolutionBackpropData;
class ConvolutionBackpropFilters;
}
}
}
}
class ngraph::runtime::plaidml::op::Convolution final : public ngraph::op::Op
{
public:
Convolution(std::shared_ptr<ngraph::op::Convolution> src,
const NodeVector& args,
AxisVector data_axes,
AxisVector filters_axes,
AxisVector output_axes);
void validate_and_infer_types() final;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const final;
const std::shared_ptr<ngraph::op::Convolution>& get_src() const { return m_src; }
const AxisVector& get_data_axes() const { return m_data_axes; }
const AxisVector& get_filters_axes() const { return m_filters_axes; }
const AxisVector& get_output_axes() const { return m_output_axes; }
private:
std::shared_ptr<ngraph::op::Convolution> m_src;
AxisVector m_data_axes;
AxisVector m_filters_axes;
AxisVector m_output_axes;
};
class ngraph::runtime::plaidml::op::ConvolutionBackpropData final : public ngraph::op::Op
{
public:
ConvolutionBackpropData(std::shared_ptr<ngraph::op::ConvolutionBackpropData> src,
const NodeVector& args,
AxisVector filters_axes,
AxisVector output_axes,
AxisVector data_axes);
void validate_and_infer_types() final;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const final;
const std::shared_ptr<ngraph::op::ConvolutionBackpropData>& get_src() const { return m_src; }
const AxisVector& get_filters_axes() const { return m_filters_axes; }
const AxisVector& get_output_axes() const { return m_output_axes; }
const AxisVector& get_data_axes() const { return m_data_axes; }
private:
std::shared_ptr<ngraph::op::ConvolutionBackpropData> m_src;
AxisVector m_filters_axes;
AxisVector m_output_axes;
AxisVector m_data_axes;
};
class ngraph::runtime::plaidml::op::ConvolutionBackpropFilters final : public ngraph::op::Op
{
public:
ConvolutionBackpropFilters(std::shared_ptr<ngraph::op::ConvolutionBackpropFilters> src,
const NodeVector& args,
AxisVector data_axes,
AxisVector output_axes,
AxisVector filters_axes);
void validate_and_infer_types() final;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const final;
const std::shared_ptr<ngraph::op::ConvolutionBackpropFilters>& get_src() const { return m_src; }
const AxisVector& get_data_axes() const { return m_data_axes; }
const AxisVector& get_output_axes() const { return m_output_axes; }
const AxisVector& get_filters_axes() const { return m_filters_axes; }
private:
std::shared_ptr<ngraph::op::ConvolutionBackpropFilters> m_src;
AxisVector m_data_axes;
AxisVector m_output_axes;
AxisVector m_filters_axes;
};
......@@ -26,11 +26,15 @@ namespace ngraph
{
namespace plaidml
{
// Dot is a generalized dot product operation -- scalar-tensor,
// matrix-vector, and matrix multiplication.
template <>
void Impl<op::Dot>::operator()()
{
NGRAPH_PLAIDML_OP_CLASS(ImplDot, OpImpl<op::Dot>);
}
}
}
// Dot is a generalized dot product operation -- scalar-tensor,
// matrix-vector, and matrix multiplication.
void ngraph::runtime::plaidml::ImplDot::Apply()
{
check_inputs(2);
check_outputs(1);
......@@ -46,8 +50,7 @@ namespace ngraph
NGRAPH_DEBUG << "l_dim_mac=" << l_dim_mac;
NGRAPH_DEBUG << "r_dim_mic=" << r_dim_mic;
set_output(
start_tile_function()
set_output(start_tile_function()
.add(builder::Input{op_input(0), "L"}
.add_dims("DL", 1, l_dim_mac + 1)
.add_dims("DC", 1, reduce_limit + 1))
......@@ -68,12 +71,4 @@ namespace ngraph
.add_indices("dc", 1, reduce_limit + 1)
.add_indices("dr", r_dim_mic + 1, r_dim_limit + 1)))
.finalize());
}
namespace
{
Impl<op::Dot>::Registration register_dot;
}
}
}
}
......@@ -25,10 +25,14 @@ namespace ngraph
{
namespace plaidml
{
// FunctionCall invokes a sub-function.
template <>
void Impl<op::FunctionCall>::operator()()
{
NGRAPH_PLAIDML_OP_CLASS(ImplFunctionCall, OpImpl<op::FunctionCall>);
}
}
}
// FunctionCall invokes a sub-function.
void ngraph::runtime::plaidml::ImplFunctionCall::Apply()
{
Build b;
build()->compiler->build(op().get_functions()[0], &b);
vertexai::plaidml::function f{b.composer};
......@@ -36,8 +40,7 @@ namespace ngraph
for (std::size_t idx = 0; idx < op().get_input_size(); ++idx)
{
auto* oitv = op().get_inputs()[idx].get_output().get_tensor_ptr().get();
auto* iitv =
b.func->get_parameters()[idx]->get_outputs()[0].get_tensor_ptr().get();
auto* iitv = b.func->get_parameters()[idx]->get_outputs()[0].get_tensor_ptr().get();
inputs.emplace_back(b.input_names.at(iitv), build()->bindings.at(oitv).var);
}
vertexai::plaidml::application app{f.apply(inputs)};
......@@ -46,12 +49,4 @@ namespace ngraph
auto* iotv = b.func->get_results()[idx]->get_output_tensor_ptr().get();
set_output(idx, app.get_output(b.output_names[iotv]));
}
}
namespace
{
Impl<op::FunctionCall>::Registration register_function_call;
}
}
}
}
......@@ -26,19 +26,14 @@ namespace ngraph
namespace plaidml
{
template <typename O>
class IndexReductionImpl : public BaseImpl<O>
class IndexReductionBase : public OpImpl<O>
{
public:
IndexReductionImpl(Build* build, const O& op)
: BaseImpl<O>{build, op}
{
}
protected:
void build_index_reduction(const char* agg_op);
};
template <typename O>
void IndexReductionImpl<O>::build_index_reduction(const char* agg_op)
void IndexReductionBase<O>::build_index_reduction(const char* agg_op)
{
this->check_inputs(1);
this->check_outputs(1);
......@@ -117,37 +112,20 @@ namespace ngraph
.finalize());
}
template <>
struct ParentImpl<op::ArgMax>
{
using Type = IndexReductionImpl<op::ArgMax>;
};
template <>
struct ParentImpl<op::ArgMin>
{
using Type = IndexReductionImpl<op::ArgMin>;
};
NGRAPH_PLAIDML_OP_CLASS(ImplArgMax, IndexReductionBase<op::ArgMax>);
NGRAPH_PLAIDML_OP_CLASS(ImplArgMin, IndexReductionBase<op::ArgMin>);
}
}
}
// ArgMax computes the maximum index along a tensor axis.
template <>
void Impl<op::ArgMax>::operator()()
{
// ArgMax computes the maximum index along a tensor axis.
void ngraph::runtime::plaidml::ImplArgMax::Apply()
{
build_index_reduction(">");
}
}
// ArgMin computes the minimum index along a tensor axis.
template <>
void Impl<op::ArgMin>::operator()()
{
// ArgMin computes the minimum index along a tensor axis.
void ngraph::runtime::plaidml::ImplArgMin::Apply()
{
build_index_reduction("<");
}
namespace
{
Impl<op::ArgMax>::Registration register_argmax;
Impl<op::ArgMin>::Registration register_argmin;
}
}
}
}
......@@ -26,10 +26,15 @@ namespace ngraph
{
namespace plaidml
{
// Parameter binds a descriptor::Tensor to a PlaidML Placeholder.
template <>
void Impl<op::Parameter>::operator()()
{
NGRAPH_PLAIDML_OP_CLASS(ImplParameter, OpImpl<op::Parameter>);
NGRAPH_PLAIDML_OP_CLASS(ImplResult, OpImpl<op::Result>);
}
}
}
// Parameter binds a descriptor::Tensor to a PlaidML Placeholder.
void ngraph::runtime::plaidml::ImplParameter::Apply()
{
check_inputs(0);
check_outputs(1);
vp::placeholder ph{build()->io_dim_override ? build()->io_dim_override_count
......@@ -39,25 +44,15 @@ namespace ngraph
build()->bindings.emplace(tv, TensorInfo{ph, TensorContents::DATA});
build()->composer.input(name, ph);
build()->input_names.emplace(tv, std::move(name));
}
}
// Result binds a PlaidML variable to a composed function output.
template <>
void Impl<op::Result>::operator()()
{
// Result binds a PlaidML variable to a composed function output.
void ngraph::runtime::plaidml::ImplResult::Apply()
{
check_inputs(1);
check_outputs(1);
std::string name = std::string{"O"} + std::to_string(build()->output_names.size());
descriptor::Tensor* tv = op().get_output_tensor_ptr().get();
build()->composer.output(name, op_input());
build()->output_names.emplace(tv, std::move(name));
}
namespace
{
Impl<op::Parameter>::Registration register_parameter;
Impl<op::Result>::Registration register_result;
}
}
}
}
......@@ -23,23 +23,25 @@ namespace ngraph
{
namespace plaidml
{
// LRN implements Local Response Normalization
template <>
void Impl<op::LRN>::operator()()
{
NGRAPH_PLAIDML_OP_CLASS(ImplLRN, OpImpl<op::LRN>);
}
}
}
// LRN implements Local Response Normalization
void ngraph::runtime::plaidml::ImplLRN::Apply()
{
check_inputs(1);
check_outputs(1);
auto dim_limit = op().get_inputs()[0].get_shape().size();
auto rank = dim_limit - 2;
auto distance = op().get_nsize() / 2;
std::ostringstream div_expr;
div_expr << "I / pow(" << op().get_bias() << ".0 + ((" << op().get_alpha()
<< ".0 / " << op().get_nsize() << ".0) * S), " << op().get_beta() << ".0)";
div_expr << "I / pow(" << op().get_bias() << ".0 + ((" << op().get_alpha() << ".0 / "
<< op().get_nsize() << ".0) * S), " << op().get_beta() << ".0)";
set_output(
start_tile_function()
.add(builder::Input{op_input(), "I"}
.add_dims({"N", "C"})
.add_dims("D", 0, rank))
.add(builder::Input{op_input(), "I"}.add_dims({"N", "C"}).add_dims("D", 0, rank))
.add(builder::Output{"O"})
.add(builder::Elementwise{"ISQ", "I * I"})
.add(builder::UnaryContraction{"+"}
......@@ -51,18 +53,9 @@ namespace ngraph
.set(builder::ContractionInput{"ISQ"}
.add_indices({"n", "c + z - " + std::to_string(distance)})
.add_indices("d", 0, rank))
.add_constraints(
[&](std::back_insert_iterator<std::list<std::string>> out) {
.add_constraints([&](std::back_insert_iterator<std::list<std::string>> out) {
out = "z < " + std::to_string(op().get_nsize());
}))
.add(builder::Elementwise{"O", div_expr.str()})
.finalize());
}
namespace
{
Impl<op::LRN>::Registration register_local_response_norm;
}
}
}
}
......@@ -25,56 +25,47 @@ namespace ngraph
{
namespace plaidml
{
// And performs a simple elementwise logical and.
template <>
void Impl<op::And>::operator()()
{
NGRAPH_PLAIDML_OP_CLASS(ImplAnd, OpImpl<op::And>);
NGRAPH_PLAIDML_OP_CLASS(ImplNot, OpImpl<op::Not>);
NGRAPH_PLAIDML_OP_CLASS(ImplOr, OpImpl<op::Or>);
}
}
}
// And performs a simple elementwise logical and.
void ngraph::runtime::plaidml::ImplAnd::Apply()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
.add(builder::Input{op_input(0, TensorContents::LOGICAL), "A"})
.add(builder::Input{op_input(1, TensorContents::LOGICAL), "B"})
.add(builder::Input{op_input(0), "A"})
.add(builder::Input{op_input(1), "B"})
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A ? B : A"})
.finalize(),
TensorContents::LOGICAL);
}
.finalize());
}
// Not performs a simple elementwise logical not.
template <>
void Impl<op::Not>::operator()()
{
// Not performs a simple elementwise logical not.
void ngraph::runtime::plaidml::ImplNot::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
.add(builder::Input{op_input(0, TensorContents::LOGICAL), "I"})
.add(builder::Input{op_input(0), "I"})
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "cmp_eq(I, 0)"})
.finalize(),
TensorContents::LOGICAL);
}
.finalize());
}
// Or performs a simple elementwise logical or.
template <>
void Impl<op::Or>::operator()()
{
// Or performs a simple elementwise logical or.
void ngraph::runtime::plaidml::ImplOr::Apply()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
.add(builder::Input{op_input(0, TensorContents::LOGICAL), "A"})
.add(builder::Input{op_input(1, TensorContents::LOGICAL), "B"})
.add(builder::Input{op_input(0), "A"})
.add(builder::Input{op_input(1), "B"})
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A ? A : B"})
.finalize(),
TensorContents::LOGICAL);
}
namespace
{
Impl<op::And>::Registration register_and;
Impl<op::Not>::Registration register_not;
Impl<op::Or>::Registration register_or;
}
}
}
.finalize());
}
......@@ -26,10 +26,14 @@ namespace ngraph
{
namespace plaidml
{
// OneHot performs one-hot encoding along the requested axis.
template <>
void Impl<op::OneHot>::operator()()
{
NGRAPH_PLAIDML_OP_CLASS(ImplOneHot, OpImpl<op::OneHot>);
}
}
}
// OneHot performs one-hot encoding along the requested axis.
void ngraph::runtime::plaidml::ImplOneHot::Apply()
{
check_inputs(1);
check_outputs(1);
......@@ -74,12 +78,9 @@ namespace ngraph
.add(builder::Input{op_input(), "I"}.add_dims("D", 0, in_shape.size()))
.add(builder::Input{static_cast<std::int64_t>(0), "Zero"})
.add(builder::Output{"O"})
.add(
builder::UnaryContraction{"="}
.set(
builder::ContractionOutput{"ZS"}
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"ZS"}
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < out_shape.size(); ++idx)
{
if (idx == op().get_one_hot_axis())
......@@ -94,19 +95,10 @@ namespace ngraph
})
.add_indices("d", 0, out_shape.size()))
.set(builder::ContractionInput{"Zero"}))
.add(builder::Elementwise{
"Idx", "index(ZS, " + std::to_string(op().get_one_hot_axis()) + ")"})
.add(builder::Elementwise{"Idx",
"index(ZS, " + std::to_string(op().get_one_hot_axis()) + ")"})
.add(builder::Elementwise{"IS", "reshape(I, " + in_reshape.str() + ")"})
.add(builder::Elementwise{"OV", "IS == Idx ? 1 : 0"})
.add(builder::Elementwise{"O",
tile_converter("OV", op().get_element_type())})
.add(builder::Elementwise{"O", tile_converter("OV", op().get_element_type())})
.finalize());
}
namespace
{
Impl<op::OneHot>::Registration register_one_hot;
}
}
}
}
......@@ -26,10 +26,17 @@ namespace ngraph
{
namespace plaidml
{
// AvgPool implements a batch average pooling operation.
template <>
void Impl<op::AvgPool>::operator()()
{
NGRAPH_PLAIDML_OP_CLASS(ImplAvgPool, OpImpl<op::AvgPool>);
NGRAPH_PLAIDML_OP_CLASS(ImplMaxPool, OpImpl<op::MaxPool>);
NGRAPH_PLAIDML_OP_CLASS(ImplAvgPoolBackprop, OpImpl<op::AvgPoolBackprop>);
NGRAPH_PLAIDML_OP_CLASS(ImplMaxPoolBackprop, OpImpl<op::MaxPoolBackprop>);
}
}
}
// AvgPool implements a batch average pooling operation.
void ngraph::runtime::plaidml::ImplAvgPool::Apply()
{
check_inputs(1);
check_outputs(1);
......@@ -98,12 +105,11 @@ namespace ngraph
f.add(cpf.PoolContraction()).add(builder::Elementwise{"O", "S / Count"});
set_output(f.finalize());
}
}
// MaxPool implements a batch max pooling operation.
template <>
void Impl<op::MaxPool>::operator()()
{
// MaxPool implements a batch max pooling operation.
void ngraph::runtime::plaidml::ImplMaxPool::Apply()
{
check_inputs(1);
check_outputs(1);
......@@ -162,11 +168,10 @@ namespace ngraph
.add(cpf.O_out_header())
.add(cpf.PoolContraction())
.finalize());
}
}
template <>
void Impl<op::AvgPoolBackprop>::operator()()
{
void ngraph::runtime::plaidml::ImplAvgPoolBackprop::Apply()
{
check_inputs(1);
check_outputs(1);
......@@ -180,8 +185,7 @@ namespace ngraph
if (include_padding)
{
throw std::runtime_error(
"Include padding in average not yet implemented in PlaidML");
throw std::runtime_error("Include padding in average not yet implemented in PlaidML");
}
ngraph::CoordinateDiff pad_above;
......@@ -236,20 +240,18 @@ namespace ngraph
{
std::ostringstream s;
s << "XI" << i - 2;
ret.add(
builder::Input{static_cast<std::int64_t>(forward_arg_shape[i]), s.str()});
ret.add(builder::Input{static_cast<std::int64_t>(forward_arg_shape[i]), s.str()});
}
set_output(ret.add(cpf.Broadcast_Ones())
.add(cpf.Count())
.add(builder::Elementwise{"S", "DO / Count"})
.add(cpf.PoolContraction())
.finalize());
}
}
template <>
void Impl<op::MaxPoolBackprop>::operator()()
{
check_inputs(2);
void ngraph::runtime::plaidml::ImplMaxPoolBackprop::Apply()
{
check_inputs_ge(2);
check_outputs(1);
auto src_dims = op().get_inputs()[0].get_shape().size() - 2;
......@@ -307,15 +309,4 @@ namespace ngraph
.add(cpf.PoolContraction())
.add(cpf.PoolDerivContraction())
.finalize());
}
namespace
{
Impl<op::AvgPool>::Registration register_avg_pool;
Impl<op::MaxPool>::Registration register_max_pool;
Impl<op::AvgPoolBackprop>::Registration register_avg_pool_backprop;
Impl<op::MaxPoolBackprop>::Registration register_max_pool_backprop;
}
}
}
}
......@@ -32,19 +32,14 @@ namespace ngraph
namespace plaidml
{
template <typename O>
class ReductionImpl : public BaseImpl<O>
class ReductionBase : public OpImpl<O>
{
public:
ReductionImpl(Build* build, const O& op)
: BaseImpl<O>{build, op}
{
}
void build_reduction(const char* agg_op);
};
template <typename O>
void ReductionImpl<O>::build_reduction(const char* agg_op)
void ReductionBase<O>::build_reduction(const char* agg_op)
{
this->check_inputs(1);
this->check_outputs(1);
......@@ -90,61 +85,36 @@ namespace ngraph
.finalize());
}
template <>
struct ParentImpl<op::Max>
{
using Type = ReductionImpl<op::Max>;
};
template <>
struct ParentImpl<op::Min>
{
using Type = ReductionImpl<op::Min>;
};
template <>
struct ParentImpl<op::Product>
{
using Type = ReductionImpl<op::Product>;
};
template <>
struct ParentImpl<op::Reduce>
{
using Type = ReductionImpl<op::Reduce>;
};
template <>
struct ParentImpl<op::Sum>
{
using Type = ReductionImpl<op::Sum>;
};
NGRAPH_PLAIDML_OP_CLASS(ImplMax, ReductionBase<op::Max>);
NGRAPH_PLAIDML_OP_CLASS(ImplMin, ReductionBase<op::Min>);
NGRAPH_PLAIDML_OP_CLASS(ImplProduct, ReductionBase<op::Product>);
NGRAPH_PLAIDML_OP_CLASS(ImplReduce, ReductionBase<op::Reduce>);
NGRAPH_PLAIDML_OP_CLASS(ImplSum, ReductionBase<op::Sum>);
}
}
}
// Max reduces a tensor, taking the maximum along the specified axes.
template <>
void Impl<op::Max>::operator()()
{
// Max reduces a tensor, taking the maximum along the specified axes.
void ngraph::runtime::plaidml::ImplMax::Apply()
{
build_reduction(">");
}
}
// Min reduces a tensor, taking the minimum along the specified axes.
template <>
void Impl<op::Min>::operator()()
{
// Min reduces a tensor, taking the minimum along the specified axes.
void ngraph::runtime::plaidml::ImplMin::Apply()
{
build_reduction("<");
}
}
// Min reduces a tensor, taking the product along the specified axes.
template <>
void Impl<op::Product>::operator()()
{
// Product reduces a tensor, taking the product along the specified axes.
void ngraph::runtime::plaidml::ImplProduct::Apply()
{
build_reduction("*");
}
}
// Reduce reduces a tensor with an arbitrary user-supplied reduction operation.
template <>
void Impl<op::Reduce>::operator()()
{
// Reduce reduces a tensor with an arbitrary user-supplied reduction operation.
void ngraph::runtime::plaidml::ImplReduce::Apply()
{
check_inputs(2);
check_outputs(1);
......@@ -185,13 +155,10 @@ namespace ngraph
start_tile_function()
.add(builder::Input{op_input(1), "I"})
.add(builder::Output{"O"})
.add(
builder::UnaryContraction{"="}
.set(
builder::ContractionOutput{"O"}
.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"O"}
.add_indices("d", 0, agg_dim_limit)
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
for (auto idx = 0; idx < agg_dim_limit; ++idx)
{
out = "1";
......@@ -210,12 +177,9 @@ namespace ngraph
.add(builder::Output{"O"})
.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"O"}
.add_indices([&](
std::back_insert_iterator<std::list<std::string>>
out) {
for (std::size_t idx = 0;
idx < input_shape.size();
++idx)
.add_indices(
[&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < input_shape.size(); ++idx)
{
if (!op().get_reduction_axes().count(idx))
{
......@@ -223,12 +187,9 @@ namespace ngraph
}
}
})
.add_dims([&](
std::back_insert_iterator<std::list<std::string>>
out) {
for (std::size_t idx = 0;
idx < input_shape.size();
++idx)
.add_dims(
[&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < input_shape.size(); ++idx)
{
if (!op().get_reduction_axes().count(idx))
{
......@@ -236,8 +197,8 @@ namespace ngraph
}
}
}))
.set(builder::ContractionInput{"I"}.add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
.set(builder::ContractionInput{"I"}.add_indices(
[&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < input_shape.size(); ++idx)
{
std::size_t cidx = 0;
......@@ -255,23 +216,10 @@ namespace ngraph
}
set_output(result);
}
}
// Sum reduces a tensor, summing the specified axes.
template <>
void Impl<op::Sum>::operator()()
{
// Sum reduces a tensor, summing the specified axes.
void ngraph::runtime::plaidml::ImplSum::Apply()
{
build_reduction("+");
}
namespace
{
Impl<op::Max>::Registration register_max;
Impl<op::Min>::Registration register_min;
Impl<op::Product>::Registration register_product;
Impl<op::Reduce>::Registration register_reduce;
Impl<op::Sum>::Registration register_sum;
}
}
}
}
......@@ -25,10 +25,14 @@ namespace ngraph
{
namespace plaidml
{
// ReplaceSlice replaces part of a tensor with another tensor.
template <>
void Impl<op::ReplaceSlice>::operator()()
{
NGRAPH_PLAIDML_OP_CLASS(ImplReplaceSlice, OpImpl<op::ReplaceSlice>);
}
}
}
// ReplaceSlice replaces part of a tensor with another tensor.
void ngraph::runtime::plaidml::ImplReplaceSlice::Apply()
{
check_inputs(2);
check_outputs(1);
......@@ -49,13 +53,11 @@ namespace ngraph
.add(builder::Input{op_input(0), "L"}.add_dims("D", 0, shape.size()))
.add(builder::Input{op_input(1), "S"}.add_dims("SD", 0, shape.size()))
.add(builder::Output{"O"})
.add(
builder::UnaryContraction{"="}
.set(
builder::ContractionOutput{"O"}
.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"O"}
.add_dims("D", 0, shape.size())
.add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
.add_indices(
[&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < shape.size(); ++idx)
{
auto stride = op().get_strides()[idx];
......@@ -81,10 +83,8 @@ namespace ngraph
out = didx.str();
}
}))
.set(builder::ContractionInput{"S"}.add_indices(
"d", 0, shape.size()))
.add_constraints(
[&](std::back_insert_iterator<std::list<std::string>> out) {
.set(builder::ContractionInput{"S"}.add_indices("d", 0, shape.size()))
.add_constraints([&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < shape.size(); ++idx)
{
out = "d" + std::to_string(idx) + " < " +
......@@ -94,12 +94,4 @@ namespace ngraph
})
.set_default("L"))
.finalize());
}
namespace
{
Impl<op::ReplaceSlice>::Registration register_replace_slice;
}
}
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/plaidml/plaidml_ops_replicate.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
NGRAPH_PLAIDML_OP_CLASS(ImplReplicate, OpImpl<plaidml::op::Replicate>);
}
}
}
ngraph::runtime::plaidml::op::Replicate::Replicate(std::shared_ptr<Node> arg,
std::size_t replication_axis,
std::size_t replication_count)
: Op{"Replicate", NodeVector{arg}}
, m_replication_axes(arg->get_shape().size(), 1)
{
m_replication_axes.at(replication_axis) = replication_count;
constructor_validate_and_infer_types();
}
ngraph::runtime::plaidml::op::Replicate::Replicate(std::shared_ptr<Node> arg,
std::vector<std::size_t> replication_axes)
: Op{"Replicate", NodeVector{arg}}
, m_replication_axes(std::move(replication_axes))
{
if (arg->get_shape().size() != m_replication_axes.size())
{
throw ngraph_error{"Replicate requires compatible axes dimensions"};
}
constructor_validate_and_infer_types();
}
void ngraph::runtime::plaidml::op::Replicate::validate_and_infer_types()
{
const auto& arg = get_arguments().at(0);
Shape shape = arg->get_shape();
for (auto rit = m_replication_axes.begin(), sit = shape.begin();
rit != m_replication_axes.end();
++rit, ++sit)
{
*sit *= *rit;
}
set_output_type(0, arg->get_element_type(), shape);
}
std::shared_ptr<ngraph::Node>
ngraph::runtime::plaidml::op::Replicate::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 1)
{
throw ngraph_error{"Replicate requires exactly one input"};
}
if (new_args.at(0)->get_shape().size() != m_replication_axes.size())
{
throw ngraph_error{"Replicate requires identical dimensions in inputs"};
}
return std::make_shared<Replicate>(new_args.at(0), m_replication_axes);
}
void ngraph::runtime::plaidml::ImplReplicate::Apply()
{
check_inputs(1);
check_outputs(1);
const auto& axes = op().get_replication_axes();
const auto& ishape = op().get_input_shape(0);
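// Note on the contraction below (a descriptive sketch of the Tile semantics): for
// each replicated axis, the otherwise-unbound index s<idx> effectively ranges over
// the replication count, so each value of s<idx> lays down one shifted copy of the
// input along that axis; axes with a count of 1 keep a plain d<idx> index.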
set_output(
start_tile_function()
.add(builder::Input{op_input(0), "I"}.add_dims("D", 0, axes.size()))
.add(builder::Output{"O"})
.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"O"}
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < axes.size(); ++idx)
{
std::string dsize = "D" + std::to_string(idx);
if (axes.at(idx) != 1)
{
dsize = dsize + " * " + std::to_string(axes.at(idx));
}
out = dsize;
}
})
.add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < axes.size(); ++idx)
{
std::string didx = "d" + std::to_string(idx);
if (axes.at(idx) != 1)
{
if (ishape.at(idx) == 1)
{
didx = didx + " + s" + std::to_string(idx);
}
else
{
didx = didx + " + (s" + std::to_string(idx) + " * " +
std::to_string(ishape.at(idx)) + ")";
}
}
out = didx;
}
}))
.set(builder::ContractionInput{"I"}.add_indices("d", 0, axes.size())))
.finalize());
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <vector>
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace op
{
class Replicate;
}
}
}
}
// Replicate works like Concat, but only over identical inputs. This
// restriction allows it to be substantially more efficient.
class ngraph::runtime::plaidml::op::Replicate final : public ngraph::op::Op
{
public:
Replicate(std::shared_ptr<Node> arg,
std::size_t replication_axis,
std::size_t replication_count);
Replicate(std::shared_ptr<Node> arg, std::vector<std::size_t> replication_axes);
void validate_and_infer_types() final;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const final;
/// \return The replication axes: axis index -> the replication count along that axis.
const std::vector<std::size_t>& get_replication_axes() const { return m_replication_axes; }
private:
std::vector<std::size_t> m_replication_axes;
};
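// Illustrative sketch (the node `arg` and its {2, 3} shape are assumed): a Replicate
// over a single axis is semantically the same as concatenating the identical
// argument that many times along that axis.
//
//   // arg : std::shared_ptr<Node>, shape {2, 3}
//   auto via_concat    = std::make_shared<ngraph::op::Concat>(NodeVector{arg, arg, arg}, 0);
//   auto via_replicate = std::make_shared<ngraph::runtime::plaidml::op::Replicate>(arg, 0, 3);
//   // Both produce a {6, 3} result; the Replicate form lets the backend emit a
//   // single contraction instead of copying each input separately.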
......@@ -25,10 +25,14 @@ namespace ngraph
{
namespace plaidml
{
// Reverse reverses the selected axes within a tensor.
template <>
void Impl<op::Reverse>::operator()()
{
NGRAPH_PLAIDML_OP_CLASS(ImplReverse, OpImpl<op::Reverse>);
}
}
}
// Reverse reverses the selected axes within a tensor.
void ngraph::runtime::plaidml::ImplReverse::Apply()
{
check_inputs(1);
check_outputs(1);
......@@ -41,8 +45,8 @@ namespace ngraph
.set(builder::ContractionOutput{"O"}
.add_indices("d", 0, shape.size())
.add_dims("D", 0, shape.size()))
.set(builder::ContractionInput{"I"}.add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
.set(builder::ContractionInput{"I"}.add_indices(
[&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < shape.size(); ++idx)
{
auto sidx = std::to_string(idx);
......@@ -57,12 +61,4 @@ namespace ngraph
}
})))
.finalize());
}
namespace
{
Impl<op::Reverse>::Registration register_reverse;
}
}
}
}
......@@ -24,10 +24,14 @@ namespace ngraph
{
namespace plaidml
{
// Slice takes a sub-slice of a tensor.
template <>
void Impl<op::Slice>::operator()()
{
NGRAPH_PLAIDML_OP_CLASS(ImplSlice, OpImpl<op::Slice>);
}
}
}
// Slice takes a sub-slice of a tensor.
void ngraph::runtime::plaidml::ImplSlice::Apply()
{
check_inputs(1);
check_outputs(1);
NGRAPH_DEBUG << "Slice: low: " << op().get_lower_bounds();
......@@ -39,21 +43,17 @@ namespace ngraph
start_tile_function()
.add(builder::Input{op_input(), "I"}.add_dims("ID", 0, dim_limit))
.add(builder::Output{"O"})
.add(
builder::UnaryContraction{"="}
.set(
builder::ContractionOutput{"O"}
.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"O"}
.add_indices("od", 0, dim_limit)
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_limit; ++idx)
{
std::ostringstream s;
std::size_t stride = op().get_strides()[idx];
std::ptrdiff_t trim_count =
op().get_lower_bounds()[idx] +
(shape[idx] - op().get_upper_bounds()[idx]) +
1 - stride;
(shape[idx] - op().get_upper_bounds()[idx]) + 1 - stride;
if ((stride != 1) && trim_count)
{
s << "(";
......@@ -106,12 +106,4 @@ namespace ngraph
}
})))
.finalize());
}
namespace
{
Impl<op::Slice>::Registration register_slice;
}
}
}
}
......@@ -25,10 +25,14 @@ namespace ngraph
{
namespace plaidml
{
// Softmax implements a standard ML softmax operation.
template <>
void Impl<op::Softmax>::operator()()
{
NGRAPH_PLAIDML_OP_CLASS(ImplSoftmax, OpImpl<op::Softmax>);
}
}
}
// Softmax implements a standard ML softmax operation.
void ngraph::runtime::plaidml::ImplSoftmax::Apply()
{
check_inputs(1);
check_outputs(1);
......@@ -36,8 +40,7 @@ namespace ngraph
auto dim_limit = shape.size();
auto f = start_tile_function();
f.add(builder::Input{op_input(0), "I"}.add_dims("D", 0, dim_limit))
.add(builder::Output{"O"});
f.add(builder::Input{op_input(0), "I"}.add_dims("D", 0, dim_limit)).add(builder::Output{"O"});
bool reorder_needed = false;
bool saw_element = false;
......@@ -78,8 +81,7 @@ namespace ngraph
{
f.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"RI"}
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
for (auto idx : group_idxs)
{
out = "D" + std::to_string(idx);
......@@ -89,8 +91,7 @@ namespace ngraph
out = "D" + std::to_string(idx);
}
})
.add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
.add_indices([&](std::back_insert_iterator<std::list<std::string>> out) {
for (auto idx : group_idxs)
{
out = "d" + std::to_string(idx);
......@@ -126,8 +127,7 @@ namespace ngraph
{
// Take the softmax.
std::ostringstream softmax;
softmax << "builtin_softmax(" << input << ", " << groups << ", " << elements
<< ")";
softmax << "builtin_softmax(" << input << ", " << groups << ", " << elements << ")";
f.add(builder::Elementwise{output, softmax.str()});
}
......@@ -169,12 +169,4 @@ namespace ngraph
}
set_output(f.finalize());
}
namespace
{
Impl<op::Softmax>::Registration register_softmax;
}
}
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <utility>
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_ops_tile.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
NGRAPH_PLAIDML_OP_CLASS(ImplTile, OpImpl<op::Tile>);
}
}
}
ngraph::runtime::plaidml::op::Tile::Tile(
const std::string& node_type,
vertexai::plaidml::function function,
const NodeVector& args,
std::vector<std::tuple<element::Type, PartialShape>> outputs)
: Node{node_type, args, outputs.size()}
, m_function{std::move(function)}
, m_output_shapes{std::move(outputs)}
{
constructor_validate_and_infer_types();
}
void ngraph::runtime::plaidml::op::Tile::validate_and_infer_types()
{
// TODO: It would be useful to have PlaidML deduce the output
// shapes, instead of having them passed in via the
// constructor. The primary barrier to doing so is that
// PlaidML placeholders always have a fixed number of
// dimensions but arbitrary dimension sizes, and the only way
// to pin them down to a concrete dimension size is to bind a
// tensor to them, which requires actually allocating the
// tensor. In principle, we could fix this pretty easily;
// we'll need to know more about where the PlaidML API is
// going before doing so, though.
if (get_input_size() != m_function.num_inputs())
{
throw ngraph_error{"Incorrect input count for Tile operation node"};
}
if (m_output_shapes.size() != m_function.num_outputs())
{
throw ngraph_error{"Incorrect output count for Tile operation node"};
}
std::size_t idx = 0;
for (auto& output_shape : m_output_shapes)
{
set_output_type(idx++, std::get<0>(output_shape), std::get<1>(output_shape));
}
}
std::shared_ptr<ngraph::Node>
ngraph::runtime::plaidml::op::Tile::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != get_input_size())
{
throw ngraph_error{"Tile node input counts cannot be changed for a given Tile function"};
}
return std::make_shared<Tile>(description(), m_function, new_args, m_output_shapes);
}
void ngraph::runtime::plaidml::ImplTile::Apply()
{
vertexai::plaidml::function::positional_t inputs;
for (std::size_t idx = 0; idx < op().get_input_size(); ++idx)
{
inputs.emplace_back(op_input(idx));
}
auto app = op().func().apply(inputs);
for (std::size_t idx = 0; idx < op().get_output_size(); ++idx)
{
set_output(idx, app.get_output(idx));
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <tuple>
#include <vector>
#include <plaidml/plaidml++.h>
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace op
{
/// An op directly representing PlaidML Tile code.
class Tile;
}
}
}
}
class ngraph::runtime::plaidml::op::Tile final : public Node
{
public:
Tile(const std::string& node_type,
vertexai::plaidml::function function,
const NodeVector& args,
std::vector<std::tuple<element::Type, PartialShape>> outputs);
void validate_and_infer_types() final;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const final;
vertexai::plaidml::function func() const { return m_function; }
private:
vertexai::plaidml::function m_function;
std::vector<std::tuple<element::Type, PartialShape>> m_output_shapes;
};
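// Minimal usage sketch (assumes an existing std::shared_ptr<Node> arg of element
// type f32; mirrors the ConvertLogicalToData node built by the ExplicitLogicals
// pass further below): wrapping a hand-written Tile function as a graph node.
//
//   auto neg = std::make_shared<ngraph::runtime::plaidml::op::Tile>(
//       "TileNegate",
//       vertexai::plaidml::function{"function (I) -> (O) { O = -I; }"},
//       NodeVector{arg},
//       std::vector<std::tuple<element::Type, PartialShape>>{
//           {element::f32, PartialShape{arg->get_output_shape(0)}}});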
......@@ -35,10 +35,26 @@ namespace ngraph
{
namespace plaidml
{
// acos performs a simple elementwise arccos function.
template <>
void Impl<op::Acos>::operator()()
{
NGRAPH_PLAIDML_OP_CLASS(ImplAcos, OpImpl<op::Acos>);
NGRAPH_PLAIDML_OP_CLASS(ImplAsin, OpImpl<op::Asin>);
NGRAPH_PLAIDML_OP_CLASS(ImplAtan, OpImpl<op::Atan>);
NGRAPH_PLAIDML_OP_CLASS(ImplCos, OpImpl<op::Cos>);
NGRAPH_PLAIDML_OP_CLASS(ImplCosh, OpImpl<op::Cosh>);
NGRAPH_PLAIDML_OP_CLASS(ImplExp, OpImpl<op::Exp>);
NGRAPH_PLAIDML_OP_CLASS(ImplLog, OpImpl<op::Log>);
NGRAPH_PLAIDML_OP_CLASS(ImplPower, OpImpl<op::Power>);
NGRAPH_PLAIDML_OP_CLASS(ImplSin, OpImpl<op::Sin>);
NGRAPH_PLAIDML_OP_CLASS(ImplSinh, OpImpl<op::Sinh>);
NGRAPH_PLAIDML_OP_CLASS(ImplSqrt, OpImpl<op::Sqrt>);
NGRAPH_PLAIDML_OP_CLASS(ImplTan, OpImpl<op::Tan>);
NGRAPH_PLAIDML_OP_CLASS(ImplTanh, OpImpl<op::Tanh>);
}
}
}
// acos performs a simple elementwise arccos function.
void ngraph::runtime::plaidml::ImplAcos::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -46,12 +62,11 @@ namespace ngraph
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "acos(I)"})
.finalize());
}
}
// asin performs a simple elementwise arcsin function.
template <>
void Impl<op::Asin>::operator()()
{
// asin performs a simple elementwise arcsin function.
void ngraph::runtime::plaidml::ImplAsin::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -59,12 +74,11 @@ namespace ngraph
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "asin(I)"})
.finalize());
}
}
// atan performs a simple elementwise arctan function.
template <>
void Impl<op::Atan>::operator()()
{
// atan performs a simple elementwise arctan function.
void ngraph::runtime::plaidml::ImplAtan::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -72,12 +86,11 @@ namespace ngraph
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "atan(I)"})
.finalize());
}
}
// cos performs a simple elementwise cos function.
template <>
void Impl<op::Cos>::operator()()
{
// cos performs a simple elementwise cos function.
void ngraph::runtime::plaidml::ImplCos::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -85,12 +98,11 @@ namespace ngraph
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "cos(I)"})
.finalize());
}
}
// cosh performs a simple elementwise hyperbolic cos function.
template <>
void Impl<op::Cosh>::operator()()
{
// cosh performs a simple elementwise hyperbolic cos function.
void ngraph::runtime::plaidml::ImplCosh::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -98,12 +110,11 @@ namespace ngraph
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "cosh(I)"})
.finalize());
}
}
// exp performs a simple elementwise natural exponential function.
template <>
void Impl<op::Exp>::operator()()
{
// exp performs a simple elementwise natural exponential function.
void ngraph::runtime::plaidml::ImplExp::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -111,12 +122,11 @@ namespace ngraph
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "exp(I)"})
.finalize());
}
}
// log performs a simple elementwise natural logarithm function.
template <>
void Impl<op::Log>::operator()()
{
// log performs a simple elementwise natural logarithm function.
void ngraph::runtime::plaidml::ImplLog::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -124,12 +134,11 @@ namespace ngraph
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "log(I)"})
.finalize());
}
}
// power performs a simple elementwise power function.
template <>
void Impl<op::Power>::operator()()
{
// power performs a simple elementwise power function.
void ngraph::runtime::plaidml::ImplPower::Apply()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -138,12 +147,11 @@ namespace ngraph
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "pow(I, E)"})
.finalize());
}
}
// sin performs a simple elementwise sin function.
template <>
void Impl<op::Sin>::operator()()
{
// sin performs a simple elementwise sin function.
void ngraph::runtime::plaidml::ImplSin::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -151,12 +159,11 @@ namespace ngraph
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "sin(I)"})
.finalize());
}
}
// sinh performs a simple elementwise hyperbolic sin function.
template <>
void Impl<op::Sinh>::operator()()
{
// sinh performs a simple elementwise hyperbolic sin function.
void ngraph::runtime::plaidml::ImplSinh::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -164,12 +171,11 @@ namespace ngraph
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "sinh(I)"})
.finalize());
}
}
// sqrt performs a simple elementwise square root function.
template <>
void Impl<op::Sqrt>::operator()()
{
// sqrt performs a simple elementwise square root function.
void ngraph::runtime::plaidml::ImplSqrt::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -177,12 +183,11 @@ namespace ngraph
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "sqrt(I)"})
.finalize());
}
}
// tan performs a simple elementwise tangent function.
template <>
void Impl<op::Tan>::operator()()
{
// tan performs a simple elementwise tangent function.
void ngraph::runtime::plaidml::ImplTan::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -190,12 +195,11 @@ namespace ngraph
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "tan(I)"})
.finalize());
}
}
// tanh performs a simple elementwise hyperbolic tangent function.
template <>
void Impl<op::Tanh>::operator()()
{
// tanh performs a simple elementwise hyperbolic tangent function.
void ngraph::runtime::plaidml::ImplTanh::Apply()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -203,24 +207,4 @@ namespace ngraph
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "tanh(I)"})
.finalize());
}
namespace
{
Impl<op::Acos>::Registration register_acos;
Impl<op::Asin>::Registration register_asin;
Impl<op::Atan>::Registration register_atan;
Impl<op::Cos>::Registration register_cos;
Impl<op::Cosh>::Registration register_cosh;
Impl<op::Exp>::Registration register_exp;
Impl<op::Log>::Registration register_log;
Impl<op::Power>::Registration register_power;
Impl<op::Sin>::Registration register_sin;
Impl<op::Sinh>::Registration register_sinh;
Impl<op::Sqrt>::Registration register_sqrt;
Impl<op::Tan>::Registration register_tan;
Impl<op::Tanh>::Registration register_tanh;
}
}
}
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/plaidml/plaidml_ops_winograd.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
namespace vp = vertexai::plaidml;
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
NGRAPH_PLAIDML_OP_CLASS(ImplWinograd, OpImpl<plaidml::op::Winograd>);
}
}
}
ngraph::runtime::plaidml::op::Winograd::Winograd(std::shared_ptr<plaidml::op::Convolution> conv,
const NodeVector& args)
: Op{"Winograd", args}
, m_conv{std::move(conv)}
{
constructor_validate_and_infer_types();
}
void ngraph::runtime::plaidml::op::Winograd::validate_and_infer_types()
{
set_output_type(0, m_conv->get_element_type(), m_conv->get_output_partial_shape(0));
}
std::shared_ptr<ngraph::Node>
ngraph::runtime::plaidml::op::Winograd::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 5)
{
throw ngraph_error{"Winograd requires five inputs (data, filters, A, B, and G)"};
}
return std::make_shared<Winograd>(m_conv, new_args);
}
void ngraph::runtime::plaidml::ImplWinograd::Apply()
{
check_inputs(5);
check_outputs(1);
const auto& data_shape = op().get_input_shape(0);
const auto& filters_shape = op().get_input_shape(1);
const auto& padding_above = op().get_conv()->get_src()->get_padding_above();
const auto& padding_below = op().get_conv()->get_src()->get_padding_below();
vp::variable xo(static_cast<int64_t>(data_shape.at(1) + padding_below.at(0) +
padding_above.at(0) - filters_shape.at(0) + 1));
vp::variable yo(static_cast<int64_t>(data_shape.at(2) + padding_below.at(1) +
padding_above.at(1) - filters_shape.at(1) + 1));
vp::variable xp(static_cast<int64_t>(padding_below.at(0)));
vp::variable yp(static_cast<int64_t>(padding_below.at(1)));
set_output(vp::function{R"(
function (I[N, X, Y, CI], K[S, S, CI, CO], A[BI, BO], B[BI, BI], G[BI, S], XO, YO, XP, YP) -> (O) {
Assert = assert_winograd_valid(BI - CI + 1 == BO);
XB = (XO + BO - 1) / BO;
YB = (YO + BO - 1) / BO;
U1[i, j, ci, co : BI, S, CI, CO] = +(G[i, k] * K[k, j, ci, co]);
U[i, j, ci, co : BI, BI, CI, CO] = +(U1[i, k, ci, co] * G[j, k]);
V1[n, i, j, x, y, ci : N, BI, BI, XB, YB, CI] = +(B[k, i] * I[n, BO*x + k - XP, BO*y + j - YP, ci]);
V[n, i, j, x, y, ci : N, BI, BI, XB, YB, CI] = +(V1[n, i, k, x, y, ci] * B[k, j]);
M[n, i, j, x, y, co : N, BI, BI, XB, YB, CO] = +(V[n, i, j, x, y, ci] * U[i, j, ci, co]);
O1[n, i, j, x, y, co : N, BO, BI, XB, YB, CO] = +(A[k, i] * M[n, k, j, x, y, co]);
O[n, BO*x + i, BO*y + j, co : N, XO, YO, CO] = +(O1[n, i, k, x, y, co] * A[k, j]) no_defract;
})"}(op_input(0), op_input(1), op_input(2), op_input(3), op_input(4), xo, yo, xp, yp));
}
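// Reading of the kernel above in standard Winograd notation (a sketch; the A/B/G
// correspondence is assumed): each BO x BO output tile follows the usual form
//
//   O = A^T * ((G * K * G^T) .* (B^T * I * B)) * A
//
// U1/U compute the filter transform G*K*G^T, V1/V the input transform B^T*I*B,
// M multiplies them while contracting over the input channels, and O1/O apply the
// inverse transform A before the tiles are written back into the XO x YO output.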
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/op/op.hpp"
#include "ngraph/runtime/plaidml/plaidml_ops_convolution.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace op
{
class Winograd;
}
}
}
}
class ngraph::runtime::plaidml::op::Winograd final : public ngraph::op::Op
{
public:
Winograd(std::shared_ptr<Convolution> conv, const NodeVector& args);
void validate_and_infer_types() final;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const final;
const std::shared_ptr<Convolution> get_conv() const { return m_conv; }
private:
std::shared_ptr<Convolution> m_conv;
};
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/plaidml/plaidml_pass_concat_elision.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/plaidml/plaidml_ops_replicate.hpp"
ngraph::runtime::plaidml::pass::ConcatElision::ConcatElision()
{
auto concat_op =
std::make_shared<pattern::op::Label>(element::i8, Shape{}, [](std::shared_ptr<Node> node) {
return dynamic_cast<ngraph::op::Concat*>(node.get()) != nullptr;
});
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto concat = std::dynamic_pointer_cast<ngraph::op::Concat>(m.get_match_root());
auto args = concat->get_arguments();
// Elide one-argument concats.
if (args.size() == 1)
{
replace_node(concat, args.at(0));
return true;
}
// Check for a run of inputs from the same source -- if we see
// one, we can simplify it, and otherwise, we already have the
// best Concat we can make.
{
bool found_input_run = false;
std::size_t prev_instance_id = concat->get_instance_id(); // This will never be an arg
for (const auto& arg : args)
{
auto current_instance_id = arg->get_instance_id();
if (current_instance_id == prev_instance_id)
{
found_input_run = true;
break;
}
prev_instance_id = current_instance_id;
}
if (!found_input_run)
{
return false;
}
}
// Merge runs with the same input into Replicate calls.
NodeVector new_args;
auto run_begin = args.begin();
// N.B. There's at least one argument to concat at this point
// (actually, two, but we only care that there's at least
// one), so run_end is still valid after this increment.
auto run_end = run_begin + 1;
for (;;)
{
// Invariants:
// * [run_begin..run_end) is a range of identical arguments
// * run_begin < run_end (there's at least one member of the range).
if (run_end == args.end() || *run_begin != *run_end)
{
// End of the range.
if (run_end - run_begin == 1)
{
new_args.emplace_back(*run_begin);
}
else
{
new_args.emplace_back(std::make_shared<plaidml::op::Replicate>(
*run_begin, concat->get_concatenation_axis(), run_end - run_begin));
}
if (run_end == args.end())
{
break;
}
run_begin = run_end;
}
++run_end;
}
// Re-check for single-input concat.
if (new_args.size() == 1)
{
replace_node(concat, new_args.at(0));
return true;
}
// Build a replacement concat.
auto new_concat =
std::make_shared<ngraph::op::Concat>(new_args, concat->get_concatenation_axis());
replace_node(std::move(concat), std::move(new_concat));
return true;
};
add_matcher(std::make_shared<pattern::Matcher>(concat_op, callback));
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace pass
{
class ConcatElision;
}
}
}
}
// A pass to elide unnecessary concats:
//
// ) Concats with a single input and single output can be entirely
// elided.
//
// ) Concats with multiples of the same input can be replaced by
// Replicate.
class ngraph::runtime::plaidml::pass::ConcatElision final : public ngraph::pass::GraphRewrite
{
public:
ConcatElision();
};
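// Illustrative effect of the pass (node names X and Y are hypothetical):
//
//   Concat(X, X, X, Y) along axis 0   ->   Concat(Replicate(X, axis 0, count 3), Y) along axis 0
//   Concat(X)                         ->   X
//
// Runs of identical arguments collapse into a single Replicate; if everything
// collapses to one argument, the Concat itself is elided.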
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <typeindex>
#include "ngraph/graph_util.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/plaidml/plaidml_ops_tile.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_explicit_logicals.hpp"
void ngraph::runtime::plaidml::pass::ExplicitLogicals::construct_logical_to_data()
{
auto producer_op =
std::make_shared<pattern::op::Label>(element::i8, Shape{}, [](std::shared_ptr<Node> node) {
static const std::unordered_set<std::type_index> logical_producers{
std::type_index{typeid(ngraph::op::And)},
std::type_index{typeid(ngraph::op::Equal)},
std::type_index{typeid(ngraph::op::Greater)},
std::type_index{typeid(ngraph::op::GreaterEq)},
std::type_index{typeid(ngraph::op::Less)},
std::type_index{typeid(ngraph::op::LessEq)},
std::type_index{typeid(ngraph::op::Not)},
std::type_index{typeid(ngraph::op::NotEqual)},
std::type_index{typeid(ngraph::op::Or)}};
const ngraph::Node* node_ptr = node.get();
// True iff this node produces a logical output.
return logical_producers.count(std::type_index(typeid(*node_ptr))) != 0;
});
auto data_consumer_op = std::make_shared<pattern::op::Any>(
element::i8,
Shape{},
[](std::shared_ptr<Node> node) {
static const std::unordered_set<std::type_index> logical_consumers{
std::type_index{typeid(ngraph::op::And)},
std::type_index{typeid(ngraph::op::Equal)},
std::type_index{typeid(ngraph::op::Not)},
std::type_index{typeid(ngraph::op::NotEqual)},
std::type_index{typeid(ngraph::op::Or)}};
const ngraph::Node* node_ptr = node.get();
// True iff this node should not be presented with a logical output.
return logical_consumers.count(std::type_index(typeid(*node_ptr))) == 0;
},
NodeVector{producer_op});
pattern::graph_rewrite_callback callback = [producer_op](pattern::Matcher& m) {
auto consumer = m.get_match_root();
auto producer = m.get_pattern_map()[producer_op];
NGRAPH_DEBUG << "Adding conversion for " << producer->description() << " -> "
<< consumer->description();
ngraph::insert_new_node_between(
producer,
consumer,
std::make_shared<op::Tile>(
"ConvertLogicalToData",
vertexai::plaidml::function{"function (I) -> (O) { O = as_int(I ? 1 : 0, 8);}"},
NodeVector{producer},
std::vector<std::tuple<element::Type, PartialShape>>{
{element::i8, PartialShape{producer->get_output_shape(0)}}}));
return true;
};
auto m = std::make_shared<pattern::Matcher>(data_consumer_op, callback);
add_matcher(m);
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace pass
{
class ExplicitLogicals;
}
}
}
}
// ExplicitLogicals handles conversion between logical and binary
// values.
//
// In PlaidML, the logical values do not have well-defined binary
// representations -- for example, due to how some frameworks handle
// vectorization, 'true' might be represented as a binary '1' or as a
// binary '-1', even within the same kernel.
//
// nGraph semantics are that 'false' == '0' and 'true' == '1', that
// booleans are exactly equivalent to binary uint8 values, and that
// binary uint8 values can be passed directly into logical operations.
//
// The ExplicitLogicals pass inserts conversions as needed to preserve
// the semantics expected by the other nGraph operations.
class ngraph::runtime::plaidml::pass::ExplicitLogicals final : public ngraph::pass::GraphRewrite
{
public:
ExplicitLogicals()
: GraphRewrite()
{
construct_logical_to_data();
}
private:
void construct_logical_to_data();
};
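// Illustrative effect (node names hypothetical):
//
//   Equal -> Add   becomes   Equal -> ConvertLogicalToData -> Add
//   Equal -> And   is left unchanged, since And consumes logical values directly
//
// where ConvertLogicalToData is the op::Tile node wrapping
// "function (I) -> (O) { O = as_int(I ? 1 : 0, 8);}".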
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/plaidml/plaidml_pass_implicit_broadcast.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/op/label.hpp"
ngraph::runtime::plaidml::pass::ImplicitBroadcast::ImplicitBroadcast()
{
auto src_op = std::make_shared<pattern::op::Label>(
element::i8, Shape{}, [](std::shared_ptr<Node>) { return true; });
auto broadcast_op = std::make_shared<op::Broadcast>(src_op, Shape{}, AxisSet{});
auto target_op = std::make_shared<pattern::op::AnyOf>(
element::i8,
Shape{},
[](std::shared_ptr<Node> node) {
return pattern::has_class<op::util::UnaryElementwiseArithmetic>()(node) ||
pattern::has_class<op::util::BinaryElementwiseArithmetic>()(node);
},
NodeVector{broadcast_op});
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
// Since the broadcast is going to an elementwise operation, we
// can always replace it with an equivalent reshape that uses ones
// for the broadcast axes.
auto src = m.get_matched_nodes().at(2);
Shape src_shape = src->get_shape();
auto broadcast = std::dynamic_pointer_cast<op::Broadcast>(m.get_matched_nodes().at(1));
AxisVector reshape_order;
Shape reshape_shape;
std::size_t input_dim = 0;
std::size_t didx_limit = broadcast->get_broadcast_shape().size();
for (std::size_t didx = 0; didx < didx_limit; ++didx)
{
if (broadcast->get_broadcast_axes().count(didx))
{
reshape_shape.emplace_back(1);
}
else
{
reshape_order.emplace_back(input_dim);
reshape_shape.emplace_back(src_shape.at(input_dim++));
}
}
auto reshape = std::make_shared<op::Reshape>(src, reshape_order, reshape_shape);
replace_node(broadcast, reshape);
return true;
};
add_matcher(std::make_shared<pattern::Matcher>(target_op, callback));
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace pass
{
class ImplicitBroadcast;
}
}
}
}
// The PlaidML nGraph runtime's implementation of the Broadcast
// operation requires a contraction, and then the broadcasted output
// needs to be read by the downstream operation.
//
// Most explicit Broadcast operations are passed as inputs to
// elementwise operations. When a tensor is used as an input to an
// elementwise operation, PlaidML automatically provides NumPy
// broadcast semantics.
//
// So eliding Broadcast operations can significantly reduce the IO
// needed by an elementwise operation, and eliminate an unnecessary
// contraction.
class ngraph::runtime::plaidml::pass::ImplicitBroadcast final : public ngraph::pass::GraphRewrite
{
public:
ImplicitBroadcast();
};
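// Illustrative effect (shapes assumed): broadcasting a {3} tensor to {2, 3, 4}
// over axes {0, 2} as input to an elementwise Add
//
//   Broadcast(src, Shape{2, 3, 4}, AxisSet{0, 2}) -> Add
//
// is replaced by
//
//   Reshape(src, AxisVector{0}, Shape{1, 3, 1}) -> Add
//
// and PlaidML's NumPy-style broadcasting expands the size-1 axes inside the
// elementwise operation itself.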
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <numeric>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/plaidml/plaidml_ops_convolution.hpp"
#include "ngraph/runtime/plaidml/plaidml_pass_lower_convolutions.hpp"
ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
{
auto convolution_op =
std::make_shared<pattern::op::Label>(element::i8, Shape{}, [](std::shared_ptr<Node> node) {
return pattern::has_class<ngraph::op::Convolution>()(node) ||
pattern::has_class<ngraph::op::ConvolutionBackpropData>()(node) ||
pattern::has_class<ngraph::op::ConvolutionBackpropFilters>()(node);
});
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto to_transpose = [](const std::shared_ptr<Node>& node) -> ngraph::op::Reshape* {
if (!node)
{
return nullptr;
}
auto* reshape = dynamic_cast<ngraph::op::Reshape*>(node.get());
if (reshape && reshape->get_is_transpose())
{
return reshape;
}
return nullptr;
};
auto to_axes = [](const std::shared_ptr<Node>& node, ngraph::op::Reshape* reshape) {
if (reshape)
{
return reshape->get_input_order();
}
AxisVector result(node->get_shape().size());
std::iota(result.begin(), result.end(), 0);
return result;
};
std::shared_ptr<Node> node = m.get_match_root();
std::shared_ptr<Node> output;
auto users = node->get_users(true);
if (users.size() == 1)
{
output = users[0];
}
auto target = node;
auto* output_transpose = to_transpose(output);
if (output_transpose)
{
target = output;
}
// N.B. For the output axes, we can either use the convolution
// or the final output op -- but there might not be an output
// op. Using target always works.
AxisVector out_axes = to_axes(target, output_transpose);
auto lhs = node->get_arguments().at(0);
auto* lhs_transpose = to_transpose(lhs);
if (lhs_transpose)
{
lhs = lhs_transpose->get_arguments().at(0);
}
AxisVector lhs_axes = to_axes(lhs, lhs_transpose);
auto rhs = node->get_arguments().at(1);
auto* rhs_transpose = to_transpose(rhs);
if (rhs_transpose)
{
rhs = rhs_transpose->get_arguments().at(0);
}
AxisVector rhs_axes = to_axes(rhs, rhs_transpose);
{
auto conv = std::dynamic_pointer_cast<ngraph::op::Convolution>(node);
if (conv)
{
replace_node(target,
std::make_shared<plaidml::op::Convolution>(conv,
NodeVector{lhs, rhs},
std::move(lhs_axes),
std::move(rhs_axes),
std::move(out_axes)));
return true;
}
}
{
auto conv_bp_data =
std::dynamic_pointer_cast<ngraph::op::ConvolutionBackpropData>(node);
if (conv_bp_data)
{
replace_node(
target,
std::make_shared<plaidml::op::ConvolutionBackpropData>(conv_bp_data,
NodeVector{lhs, rhs},
std::move(lhs_axes),
std::move(rhs_axes),
std::move(out_axes)));
return true;
}
}
{
auto conv_bp_filters =
std::dynamic_pointer_cast<ngraph::op::ConvolutionBackpropFilters>(node);
if (conv_bp_filters)
{
replace_node(
target,
std::make_shared<plaidml::op::ConvolutionBackpropFilters>(conv_bp_filters,
NodeVector{lhs, rhs},
std::move(lhs_axes),
std::move(rhs_axes),
std::move(out_axes)));
return true;
}
}
return false;
};
add_matcher(std::make_shared<pattern::Matcher>(convolution_op, callback));
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace pass
{
class LowerConvolutions;
}
}
}
}
// Lowers ngraph::op::Convolution, op::ConvolutionBackpropData, and
// op::ConvolutionBackpropFilters into the corresponding PlaidML convolution
// ops, allowing for unification with transposition operations.
class ngraph::runtime::plaidml::pass::LowerConvolutions final : public ngraph::pass::GraphRewrite
{
public:
LowerConvolutions();
};
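// Illustrative effect (layouts assumed):
//
//   Reshape/*transpose NHWC->NCHW*/ -> Convolution -> Reshape/*transpose NCHW->NHWC*/
//
// collapses into a single plaidml::op::Convolution that records the input, filter,
// and output axis orders, so the transposes become index remappings inside one
// contraction instead of separate data movements.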
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/plaidml/plaidml_pass_replicate_combination.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/plaidml/plaidml_ops_replicate.hpp"
ngraph::runtime::plaidml::pass::ReplicateCombination::ReplicateCombination()
{
auto upper_replicate_op =
std::make_shared<pattern::op::Label>(element::i8, Shape{}, [](std::shared_ptr<Node> node) {
return pattern::has_class<plaidml::op::Replicate>()(node);
});
auto lower_replicate_op = std::make_shared<pattern::op::Any>(
element::i8,
Shape{},
[](std::shared_ptr<Node> node) {
return pattern::has_class<plaidml::op::Replicate>()(node);
},
NodeVector{upper_replicate_op});
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto nodes = m.get_matched_nodes();
auto lower = std::dynamic_pointer_cast<plaidml::op::Replicate>(nodes.at(0));
auto upper = std::dynamic_pointer_cast<plaidml::op::Replicate>(nodes.at(1));
std::vector<size_t> axes = lower->get_replication_axes();
const std::vector<size_t>& upper_axes = upper->get_replication_axes();
auto uit = upper_axes.begin();
for (auto ait = axes.begin(); ait != axes.end(); ++ait, ++uit)
{
*ait *= *uit;
}
replace_node(lower,
std::make_shared<plaidml::op::Replicate>(upper->get_arguments().at(0),
std::move(axes)));
return true;
};
add_matcher(std::make_shared<pattern::Matcher>(lower_replicate_op, callback));
}
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
namespace pass
{
class ReplicateCombination;
}
}
}
}
// Combines adjacent Replicate operations.
class ngraph::runtime::plaidml::pass::ReplicateCombination final : public ngraph::pass::GraphRewrite
{
public:
ReplicateCombination();
};
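// Illustrative effect (replication vectors assumed): per-axis counts multiply
// through, so
//
//   Replicate(Replicate(X, {1, 3}), {2, 1})   becomes   Replicate(X, {2, 3})
//
// matching the element-wise product computed in the rewrite callback.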
......@@ -130,12 +130,3 @@ std::string ngraph::runtime::plaidml::tile_converter(const std::string& tensor_n
}
return tile_converter(tensor_name, to_plaidml(element_type));
}
vp::variable ngraph::runtime::plaidml::plaidml_logical_to_data(vp::variable var, bool debug)
{
return builder::Function{"logicalToData", debug}
.add(builder::Input{var, "I"})
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "as_int(I ? 1 : 0, 8)"})
.finalize();
}
......@@ -46,9 +46,6 @@ namespace ngraph
std::string tile_converter(const std::string& tensor_name,
const ngraph::element::Type& element_type);
vertexai::plaidml::variable plaidml_logical_to_data(vertexai::plaidml::variable var,
bool debug);
}
}
}