Commit 61df6725 authored by Rob Earhart, committed by Robert Kimball

[PlaidML] Specialize within namespaces (for Linux) (#1948)

parent 5698fa75
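Note on the change: the earlier code defined each explicit specialization at global scope through a qualified name (e.g. `template <> void ngraph::runtime::plaidml::Impl<ngraph::op::Abs>::operator()()`). Clang accepts that form, but GCC, used for the Linux builds, rejects definitions of explicit specializations made from outside the template's enclosing namespace, which appears to be why this commit moves every specialization inside nested `ngraph { runtime { plaidml { ... }}}` blocks. Below is a minimal, self-contained sketch of the pattern; the `demo`/`Impl<int>` names are illustrative only and are not part of nGraph or this commit.

    #include <iostream>

    namespace demo
    {
        template <typename T>
        struct Impl
        {
            void run();
        };
    }

    // Before: specializing the member through a qualified name at global scope.
    // Clang tends to accept this; GCC (at least the versions targeted here)
    // diagnoses it as a specialization declared in a different namespace.
    //
    // template <>
    // void demo::Impl<int>::run() { std::cout << "int\n"; }

    // After: the explicit specialization is defined inside the namespace itself,
    // which both compilers accept. This mirrors the restructuring in the diff.
    namespace demo
    {
        template <>
        void Impl<int>::run() { std::cout << "int\n"; }
    }

    int main()
    {
        demo::Impl<int>{}.run(); // prints "int"
        return 0;
    }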
......@@ -28,10 +28,16 @@
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_translate.hpp"
// Abs performs a simple elementwise absolute value.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Abs>::operator()()
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
// Abs performs a simple elementwise absolute value.
template <>
void Impl<op::Abs>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -39,12 +45,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Abs>::operator()()
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "abs(I)"})
.finalize());
}
}
// Add performs a simple elementwise addition.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Add>::operator()()
{
// Add performs a simple elementwise addition.
template <>
void Impl<op::Add>::operator()()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -53,12 +59,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Add>::operator()()
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A + B"})
.finalize());
}
}
// Ceiling performs a simple elementwise ceiling.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Ceiling>::operator()()
{
// Ceiling performs a simple elementwise ceiling.
template <>
void Impl<op::Ceiling>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -66,12 +72,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Ceiling>::operator()()
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "ceil(I)"})
.finalize());
}
}
// Divide performs a simple elementwise division.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Divide>::operator()()
{
// Divide performs a simple elementwise division.
template <>
void Impl<op::Divide>::operator()()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -80,12 +86,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Divide>::operator()()
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A / B"})
.finalize());
}
}
// Floor performs a simple elementwise floor.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Floor>::operator()()
{
// Floor performs a simple elementwise floor.
template <>
void Impl<op::Floor>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -93,12 +99,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Floor>::operator()()
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "floor(I)"})
.finalize());
}
}
// Multiply performs a simple elementwise multiplication.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Multiply>::operator()()
{
// Multiply performs a simple elementwise multiplication.
template <>
void Impl<op::Multiply>::operator()()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -107,12 +113,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Multiply>::operator()()
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A * B"})
.finalize());
}
}
// Negative performs a simple elementwise negation.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Negative>::operator()()
{
// Negative performs a simple elementwise negation.
template <>
void Impl<op::Negative>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -120,12 +126,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Negative>::operator()()
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "-I"})
.finalize());
}
}
// Relu implements a simple elementwise rectified linear unit.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Relu>::operator()()
{
// Relu implements a simple elementwise rectified linear unit.
template <>
void Impl<op::Relu>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -133,12 +139,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Relu>::operator()()
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "relu(I)"})
.finalize());
}
}
// ReluBackprop computes the derivative of Relu.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::ReluBackprop>::operator()()
{
// ReluBackprop computes the derivative of Relu.
template <>
void Impl<op::ReluBackprop>::operator()()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -147,12 +153,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ReluBackprop>::operator()()
.add(builder::Output{"DI"})
.add(builder::Elementwise{"DI", "I > 0 ? DO : 0"})
.finalize());
}
}
// Sigmoid computes a standard ML sigmoid: 1/(1+exp(-X))
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Sigmoid>::operator()()
{
// Sigmoid computes a standard ML sigmoid: 1/(1+exp(-X))
template <>
void Impl<op::Sigmoid>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -160,13 +166,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Sigmoid>::operator()()
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "1/(1+exp(-I))"})
.finalize());
}
}
// SigmoidBackprop computes the derivative of a standard ML
// sigmoid: dOutput * sigmoid(X) * (1-sigmoid(X))
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::SigmoidBackprop>::operator()()
{
// SigmoidBackprop computes the derivative of a standard ML
// sigmoid: dOutput * sigmoid(X) * (1-sigmoid(X))
template <>
void Impl<op::SigmoidBackprop>::operator()()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -176,26 +182,27 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::SigmoidBackprop>::operator()()
.add(builder::Elementwise{"O", "1/(1+exp(-I))"})
.add(builder::Elementwise{"DI", "DO * O * (1-O)"})
.finalize());
}
}
// Sign returns the sign of an element.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Sign>::operator()()
{
// Sign returns the sign of an element.
template <>
void Impl<op::Sign>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
.add(builder::Input{op_input(0), "I"})
.add(builder::Output{"O"})
.add(builder::Elementwise{"S", "(I < 0) ? -1 : ((I > 0) ? 1 : 0)"})
.add(builder::Elementwise{"O", tile_converter("S", op().get_element_type())})
.add(builder::Elementwise{
"O", tile_converter("S", op().get_element_type())})
.finalize());
}
}
// Subtract performs a simple elementwise subtraction.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Subtract>::operator()()
{
// Subtract performs a simple elementwise subtraction.
template <>
void Impl<op::Subtract>::operator()()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -204,22 +211,24 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Subtract>::operator()()
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "A - B"})
.finalize());
}
}
namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::Abs>::Registration register_abs;
ngraph::runtime::plaidml::Impl<ngraph::op::Add>::Registration register_add;
ngraph::runtime::plaidml::Impl<ngraph::op::Ceiling>::Registration register_ceiling;
ngraph::runtime::plaidml::Impl<ngraph::op::Divide>::Registration register_divide;
ngraph::runtime::plaidml::Impl<ngraph::op::Floor>::Registration register_floor;
ngraph::runtime::plaidml::Impl<ngraph::op::Multiply>::Registration register_multiply;
ngraph::runtime::plaidml::Impl<ngraph::op::Negative>::Registration register_negative;
ngraph::runtime::plaidml::Impl<ngraph::op::Relu>::Registration register_relu;
ngraph::runtime::plaidml::Impl<ngraph::op::ReluBackprop>::Registration register_relu_backprop;
ngraph::runtime::plaidml::Impl<ngraph::op::Sigmoid>::Registration register_sigmoid;
ngraph::runtime::plaidml::Impl<ngraph::op::SigmoidBackprop>::Registration
register_sigmoid_backprop;
ngraph::runtime::plaidml::Impl<ngraph::op::Sign>::Registration register_sign;
ngraph::runtime::plaidml::Impl<ngraph::op::Subtract>::Registration register_subtract;
namespace
{
Impl<op::Abs>::Registration register_abs;
Impl<op::Add>::Registration register_add;
Impl<op::Ceiling>::Registration register_ceiling;
Impl<op::Divide>::Registration register_divide;
Impl<op::Floor>::Registration register_floor;
Impl<op::Multiply>::Registration register_multiply;
Impl<op::Negative>::Registration register_negative;
Impl<op::Relu>::Registration register_relu;
Impl<op::ReluBackprop>::Registration register_relu_backprop;
Impl<op::Sigmoid>::Registration register_sigmoid;
Impl<op::SigmoidBackprop>::Registration register_sigmoid_backprop;
Impl<op::Sign>::Registration register_sign;
Impl<op::Subtract>::Registration register_subtract;
}
}
}
}
......@@ -18,11 +18,17 @@
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// BatchNormInference implements batch normalization for inference, in
// which the mean and variance to use are supplied.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormInference>::operator()()
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
// BatchNormInference implements batch normalization for inference, in
// which the mean and variance to use are supplied.
template <>
void Impl<op::BatchNormInference>::operator()()
{
auto& input_shape = op().get_input_shape(2);
check_inputs(5);
check_outputs(1);
......@@ -45,12 +51,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormInference>::operator()(
if (input_shape.size() <= 2)
{
f.add(builder::Elementwise{"GammaP", "Gamma"}).add(builder::Elementwise{"BetaP", "Beta"});
f.add(builder::Elementwise{"GammaP", "Gamma"})
.add(builder::Elementwise{"BetaP", "Beta"});
}
else
{
f.add(builder::Elementwise{"GammaP", std::string{"reshape(Gamma, C"} + ones + ")"})
.add(builder::Elementwise{"BetaP", std::string{"reshape(Beta, C"} + ones + ")"});
f.add(builder::Elementwise{"GammaP",
std::string{"reshape(Gamma, C"} + ones + ")"})
.add(builder::Elementwise{"BetaP",
std::string{"reshape(Beta, C"} + ones + ")"});
}
if (input_shape.size() <= 2)
......@@ -59,7 +68,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormInference>::operator()(
}
else
{
f.add(builder::Elementwise{"MeanP", std::string{"reshape(Mean, C"} + ones + ")"});
f.add(
builder::Elementwise{"MeanP", std::string{"reshape(Mean, C"} + ones + ")"});
}
if (input_shape.size() <= 2)
......@@ -68,24 +78,26 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormInference>::operator()(
}
else
{
f.add(builder::Elementwise{"VarianceP", std::string{"reshape(Variance, C"} + ones + ")"});
f.add(builder::Elementwise{"VarianceP",
std::string{"reshape(Variance, C"} + ones + ")"});
}
f.add(builder::Elementwise{"Normalized",
"(((Input-MeanP) / sqrt(VarianceP + " +
std::to_string(op().get_eps_value()) + ")) * GammaP) + BetaP"});
std::to_string(op().get_eps_value()) +
")) * GammaP) + BetaP"});
auto app = f.finalize();
set_output(app);
}
}
// BatchNormTraining implements batch normalization for training, in
// which the mean and variance are to be computed from the supplied
// input.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::operator()()
{
// BatchNormTraining implements batch normalization for training, in
// which the mean and variance are to be computed from the supplied
// input.
template <>
void Impl<op::BatchNormTraining>::operator()()
{
auto& input_shape = op().get_input_shape(2);
check_inputs(3);
check_outputs(3);
......@@ -108,12 +120,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::operator()()
if (input_shape.size() <= 2)
{
f.add(builder::Elementwise{"GammaP", "Gamma"}).add(builder::Elementwise{"BetaP", "Beta"});
f.add(builder::Elementwise{"GammaP", "Gamma"})
.add(builder::Elementwise{"BetaP", "Beta"});
}
else
{
f.add(builder::Elementwise{"GammaP", std::string{"reshape(Gamma, C"} + ones + ")"})
.add(builder::Elementwise{"BetaP", std::string{"reshape(Beta, C"} + ones + ")"});
f.add(builder::Elementwise{"GammaP",
std::string{"reshape(Gamma, C"} + ones + ")"})
.add(builder::Elementwise{"BetaP",
std::string{"reshape(Beta, C"} + ones + ")"});
}
if (input_shape.size() <= 2)
......@@ -131,7 +146,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::operator()()
}
f.add(builder::UnaryContraction{"+"}
.set(builder::ContractionOutput{"SumInput"}.add_indices({"c"}).add_dims({"C"}))
.set(builder::ContractionOutput{"SumInput"}.add_indices({"c"}).add_dims(
{"C"}))
.set(builder::ContractionInput{"Input"}
.add_indices({"b", "c"})
.add_indices("di", 3, input_shape.size() + 1)));
......@@ -143,13 +159,16 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::operator()()
}
else
{
f.add(builder::Elementwise{"MeanP", std::string{"reshape(Mean, C"} + ones + ")"});
f.add(
builder::Elementwise{"MeanP", std::string{"reshape(Mean, C"} + ones + ")"});
}
f.add(builder::Elementwise{"DiffV", "(Input - MeanP)"})
.add(builder::Elementwise{"SqDiffV", "DiffV*DiffV"})
.add(builder::UnaryContraction{"+"}
.set(builder::ContractionOutput{"SumSqDiffV"}.add_indices({"c"}).add_dims({"C"}))
.set(builder::ContractionOutput{"SumSqDiffV"}
.add_indices({"c"})
.add_dims({"C"}))
.set(builder::ContractionInput{"SqDiffV"}
.add_indices({"b", "c"})
.add_indices("di", 3, input_shape.size() + 1)))
......@@ -161,23 +180,25 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::operator()()
}
else
{
f.add(builder::Elementwise{"VarianceP", std::string{"reshape(Variance, C"} + ones + ")"});
f.add(builder::Elementwise{"VarianceP",
std::string{"reshape(Variance, C"} + ones + ")"});
}
f.add(builder::Elementwise{"Normalized",
"(((Input-MeanP) / sqrt(VarianceP + " +
std::to_string(op().get_eps_value()) + ")) * GammaP) + BetaP"});
std::to_string(op().get_eps_value()) +
")) * GammaP) + BetaP"});
auto app = f.finalize();
set_output(0, app.get_output(0));
set_output(1, app.get_output(1));
set_output(2, app.get_output(2));
}
}
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::operator()()
{
template <>
void Impl<op::BatchNormTrainingBackprop>::operator()()
{
// WARNING: I'm unconvinced that we have sufficient test coverage for BatchNorm
// backprop, and in particular I'm concerned that Gamma/Beta and Mean/Var could be
// swapped without the tests catching it.
......@@ -232,10 +253,10 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::oper
.set(builder::ContractionInput{"Input"}
.add_indices({"n", "c"})
.add_indices("x", 3, input_shape.size() + 1)));
f.add(builder::Elementwise{"BatchMean", "BatchMeanNumerator / " + reduction_dims.str()});
f.add(builder::Elementwise{"BatchMean",
"BatchMeanNumerator / " + reduction_dims.str()});
f.add(builder::Elementwise{"NegBatchMean", "-BatchMean"});
f.add(
builder::BinaryContraction{"=", "+"}
f.add(builder::BinaryContraction{"=", "+"}
.set(builder::ContractionOutput{"Deviation"}
.add_indices({"n", "c"})
.add_indices("x", 3, input_shape.size() + 1)
......@@ -244,7 +265,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::oper
.set_lhs(builder::ContractionInput{"Input"}
.add_indices({"n", "c"})
.add_indices("x", 3, input_shape.size() + 1))
.set_rhs(builder::ContractionInput{"NegBatchMean"}.add_indices({"0", "c", "0", "0"})));
.set_rhs(builder::ContractionInput{"NegBatchMean"}.add_indices(
{"0", "c", "0", "0"})));
f.add(builder::BinaryContraction{"+", "*"}
.set(builder::ContractionOutput{"BatchVarNumerator"}
.add_indices({"0", "c", "0", "0"})
......@@ -255,7 +277,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::oper
.set_rhs(builder::ContractionInput{"Deviation"}
.add_indices({"n", "c"})
.add_indices("x", 3, input_shape.size() + 1)));
f.add(builder::Elementwise{"BatchVar", "BatchVarNumerator / " + reduction_dims.str()});
f.add(builder::Elementwise{"BatchVar",
"BatchVarNumerator / " + reduction_dims.str()});
f.add(builder::Elementwise{"BatchStdDev", "sqrt(BatchVar + " + epsilon + ")"});
f.add(builder::Elementwise{"NormedInput", "(Input - BatchMean) / BatchStdDev"});
......@@ -266,12 +289,14 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::oper
f.add(builder::Elementwise{"DNormedInput", "DOutput * BroadcastGamma"});
f.add(builder::UnaryContraction{"+"}
.set(builder::ContractionOutput{"SumDOutput"}.add_indices({"c"}).add_dims({"C"}))
.set(builder::ContractionOutput{"SumDOutput"}.add_indices({"c"}).add_dims(
{"C"}))
.set(builder::ContractionInput{"DOutput"}
.add_indices({"n", "c"})
.add_indices("x", 3, input_shape.size() + 1)));
f.add(builder::BinaryContraction{"+", "*"}
.set(builder::ContractionOutput{"DGamma"}.add_indices({"c"}).add_dims({"C"}))
.set(builder::ContractionOutput{"DGamma"}.add_indices({"c"}).add_dims(
{"C"}))
.set_lhs(builder::ContractionInput{"DOutput"}
.add_indices({"n", "c"})
.add_indices("x", 3, input_shape.size() + 1))
......@@ -295,14 +320,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::oper
set_output(0, app.get_output(0));
set_output(1, app.get_output(1));
set_output(2, app.get_output(2));
}
}
namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormInference>::Registration
register_batch_norm_inference;
ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::Registration
register_batch_norm_training;
ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::Registration
namespace
{
Impl<op::BatchNormInference>::Registration register_batch_norm_inference;
Impl<op::BatchNormTraining>::Registration register_batch_norm_training;
Impl<op::BatchNormTrainingBackprop>::Registration
register_batch_norm_training_backprop;
}
}
}
}
......@@ -24,10 +24,16 @@
#include "ngraph/op/not_equal.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// Equal performs a simple elementwise equality.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Equal>::operator()()
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
// Equal performs a simple elementwise equality.
template <>
void Impl<op::Equal>::operator()()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -37,12 +43,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Equal>::operator()()
.add(builder::Elementwise{"C", "A == B"})
.finalize(),
TensorContents::LOGICAL);
}
}
// Greater performs a simple elementwise greater-than comparison.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Greater>::operator()()
{
// Greater performs a simple elementwise greater-than comparison.
template <>
void Impl<op::Greater>::operator()()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -52,12 +58,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Greater>::operator()()
.add(builder::Elementwise{"C", "A > B"})
.finalize(),
TensorContents::LOGICAL);
}
}
// GreaterEq performs a simple elementwise greater-than-or-equal-to comparison.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::GreaterEq>::operator()()
{
// GreaterEq performs a simple elementwise greater-than-or-equal-to comparison.
template <>
void Impl<op::GreaterEq>::operator()()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -67,12 +73,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::GreaterEq>::operator()()
.add(builder::Elementwise{"C", "A >= B"})
.finalize(),
TensorContents::LOGICAL);
}
}
// Less performs a simple elementwise less-than comparison.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Less>::operator()()
{
// Less performs a simple elementwise less-than comparison.
template <>
void Impl<op::Less>::operator()()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -82,12 +88,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Less>::operator()()
.add(builder::Elementwise{"C", "A < B"})
.finalize(),
TensorContents::LOGICAL);
}
}
// LessEq performs a simple elementwise less-than-or-equal-to comparison.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::LessEq>::operator()()
{
// LessEq performs a simple elementwise less-than-or-equal-to comparison.
template <>
void Impl<op::LessEq>::operator()()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -97,12 +103,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::LessEq>::operator()()
.add(builder::Elementwise{"C", "A <= B"})
.finalize(),
TensorContents::LOGICAL);
}
}
// Maximum performs a simple elementwise maximum.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Maximum>::operator()()
{
// Maximum performs a simple elementwise maximum.
template <>
void Impl<op::Maximum>::operator()()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -111,12 +117,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Maximum>::operator()()
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "max(A, B)"})
.finalize());
}
}
// Minimum performs a simple elementwise minimum.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Minimum>::operator()()
{
// Minimum performs a simple elementwise minimum.
template <>
void Impl<op::Minimum>::operator()()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -125,12 +131,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Minimum>::operator()()
.add(builder::Output{"C"})
.add(builder::Elementwise{"C", "min(A, B)"})
.finalize());
}
}
// NotEqual performs a simple elementwise not-equal comparison.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::NotEqual>::operator()()
{
// NotEqual performs a simple elementwise not-equal comparison.
template <>
void Impl<op::NotEqual>::operator()()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -140,16 +146,19 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::NotEqual>::operator()()
.add(builder::Elementwise{"C", "A != B"})
.finalize(),
TensorContents::LOGICAL);
}
}
namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::Equal>::Registration register_equal;
ngraph::runtime::plaidml::Impl<ngraph::op::Greater>::Registration register_greater;
ngraph::runtime::plaidml::Impl<ngraph::op::GreaterEq>::Registration register_greater_eq;
ngraph::runtime::plaidml::Impl<ngraph::op::Less>::Registration register_less;
ngraph::runtime::plaidml::Impl<ngraph::op::LessEq>::Registration register_less_eq;
ngraph::runtime::plaidml::Impl<ngraph::op::Maximum>::Registration register_maximum;
ngraph::runtime::plaidml::Impl<ngraph::op::Minimum>::Registration register_minimum;
ngraph::runtime::plaidml::Impl<ngraph::op::NotEqual>::Registration register_not_equal;
namespace
{
Impl<op::Equal>::Registration register_equal;
Impl<op::Greater>::Registration register_greater;
Impl<op::GreaterEq>::Registration register_greater_eq;
Impl<op::Less>::Registration register_less;
Impl<op::LessEq>::Registration register_less_eq;
Impl<op::Maximum>::Registration register_maximum;
Impl<op::Minimum>::Registration register_minimum;
Impl<op::NotEqual>::Registration register_not_equal;
}
}
}
}
......@@ -17,10 +17,16 @@
#include "ngraph/op/concat.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// Concat concatenates its input tensors along the requested axis.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Concat>::operator()()
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
// Concat concatenates its input tensors along the requested axis.
template <>
void Impl<op::Concat>::operator()()
{
check_outputs(1);
auto f = start_tile_function();
......@@ -52,10 +58,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Concat>::operator()()
continue;
}
std::string sidx{std::to_string(iidx)};
f.add(builder::Input{op_input(iidx), "I" + sidx}.add_dims("I" + sidx + "_D", 0, dim_count));
f.add(builder::Input{op_input(iidx), "I" + sidx}.add_dims(
"I" + sidx + "_D", 0, dim_count));
f.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"E" + sidx}
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_count; ++idx)
{
std::ostringstream s;
......@@ -70,19 +78,22 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Concat>::operator()()
}
}
})
.add_indices([&](std::back_insert_iterator<std::list<std::string>> out) {
.add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_count; ++idx)
{
std::ostringstream s;
s << "d" << idx;
if (saw_non_zero_tensor && idx == op().get_concatenation_axis())
if (saw_non_zero_tensor &&
idx == op().get_concatenation_axis())
{
s << " + " << offset.str();
}
out = s.str();
}
}))
.set(builder::ContractionInput{"I" + sidx}.add_indices("d", 0, dim_count)));
.set(builder::ContractionInput{"I" + sidx}.add_indices(
"d", 0, dim_count)));
if (saw_non_zero_tensor)
{
oexpr << " + ";
......@@ -95,9 +106,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Concat>::operator()()
f.add(builder::Elementwise{"O", oexpr.str()});
set_output(f.finalize());
}
}
namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::Concat>::Registration register_concat;
namespace
{
Impl<op::Concat>::Registration register_concat;
}
}
}
}
......@@ -18,21 +18,31 @@
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_translate.hpp"
// Convert views a tensor as a new type.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Convert>::operator()()
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
// Convert views a tensor as a new type.
template <>
void Impl<op::Convert>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
set_output(
start_tile_function()
.add(builder::Input{op_input(), "I"})
.add(builder::Output{"O"})
.add(builder::Elementwise{
"O", tile_converter("I", to_plaidml(op().get_convert_element_type()))})
.finalize());
}
}
namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::Convert>::Registration register_convert;
namespace
{
Impl<op::Convert>::Registration register_convert;
}
}
}
}
......@@ -50,32 +50,29 @@ namespace ngraph
std::size_t output_channel_axis_result,
bool rotate_filter);
};
}
}
}
template <>
struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::Convolution>
{
using Type = ngraph::runtime::plaidml::ConvolutionImpl<ngraph::op::Convolution>;
};
template <>
struct ParentImpl<op::Convolution>
{
using Type = ConvolutionImpl<op::Convolution>;
};
template <>
struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::ConvolutionBackpropFilters>
{
using Type = ngraph::runtime::plaidml::ConvolutionImpl<ngraph::op::ConvolutionBackpropFilters>;
};
template <>
struct ParentImpl<op::ConvolutionBackpropFilters>
{
using Type = ConvolutionImpl<op::ConvolutionBackpropFilters>;
};
template <>
struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::ConvolutionBackpropData>
{
using Type = ngraph::runtime::plaidml::ConvolutionImpl<ngraph::op::ConvolutionBackpropData>;
};
template <>
struct ParentImpl<op::ConvolutionBackpropData>
{
using Type = ConvolutionImpl<op::ConvolutionBackpropData>;
};
// Convolution implements a standard ML convolution, with optional striding, padding, and dilation.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Convolution>::operator()()
{
// Convolution implements a standard ML convolution, with optional striding, padding, and dilation.
template <>
void Impl<op::Convolution>::operator()()
{
this->check_inputs(2);
this->check_outputs(1);
......@@ -122,13 +119,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Convolution>::operator()()
.set_lhs(cpf.I_in_body())
.set_rhs(cpf.F_in_body()))
.finalize());
}
}
// ConvolutionBackpropFilters implements the derivative of a convolution with respect to its filter
// input.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::ConvolutionBackpropFilters>::operator()()
{
// ConvolutionBackpropFilters implements the derivative of a convolution with respect to its filter
// input.
template <>
void Impl<op::ConvolutionBackpropFilters>::operator()()
{
this->check_inputs(2);
this->check_outputs(1);
......@@ -177,13 +174,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ConvolutionBackpropFilters>::ope
.set_lhs(cpf.O_in_body())
.set_rhs(cpf.I_in_body()))
.finalize());
}
}
// ConvolutionBackpropData implements the derivative of a convolution with respect to its data
// input.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::ConvolutionBackpropData>::operator()()
{
// ConvolutionBackpropData implements the derivative of a convolution with respect to its data
// input.
template <>
void Impl<op::ConvolutionBackpropData>::operator()()
{
this->check_inputs(2);
this->check_outputs(1);
......@@ -232,11 +229,10 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ConvolutionBackpropData>::operat
.set_lhs(cpf.O_in_body())
.set_rhs(cpf.F_in_body()))
.finalize());
}
}
template <typename O>
inline void ngraph::runtime::plaidml::ConvolutionImpl<O>::LogConvolution(
vertexai::plaidml::variable image,
template <typename O>
inline void ConvolutionImpl<O>::LogConvolution(vertexai::plaidml::variable image,
vertexai::plaidml::variable filter,
std::size_t image_dims,
const Strides& window_movement_strides,
......@@ -251,7 +247,7 @@ inline void ngraph::runtime::plaidml::ConvolutionImpl<O>::LogConvolution(
std::size_t batch_axis_result,
std::size_t output_channel_axis_result,
bool rotate_filter)
{
{
this->check_inputs(2);
this->check_outputs(1);
......@@ -271,13 +267,15 @@ inline void ngraph::runtime::plaidml::ConvolutionImpl<O>::LogConvolution(
NGRAPH_DEBUG << "batch_axis_result: " << batch_axis_result;
NGRAPH_DEBUG << "output_channel_axis_result: " << output_channel_axis_result;
NGRAPH_DEBUG << "rotate_filter: " << rotate_filter;
}
}
namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::Convolution>::Registration register_convolution;
ngraph::runtime::plaidml::Impl<ngraph::op::ConvolutionBackpropFilters>::Registration
namespace
{
Impl<op::Convolution>::Registration register_convolution;
Impl<op::ConvolutionBackpropFilters>::Registration
register_convolution_backprop_filters;
ngraph::runtime::plaidml::Impl<ngraph::op::ConvolutionBackpropData>::Registration
register_convolution_backprop_data;
Impl<op::ConvolutionBackpropData>::Registration register_convolution_backprop_data;
}
}
}
}
......@@ -20,11 +20,17 @@
#include "ngraph/op/dot.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// Dot is a generalized dot product operation -- scalar-tensor,
// matrix-vector, and matrix multiplication.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Dot>::operator()()
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
// Dot is a generalized dot product operation -- scalar-tensor,
// matrix-vector, and matrix multiplication.
template <>
void Impl<op::Dot>::operator()()
{
check_inputs(2);
check_outputs(1);
......@@ -40,7 +46,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Dot>::operator()()
NGRAPH_DEBUG << "l_dim_mac=" << l_dim_mac;
NGRAPH_DEBUG << "r_dim_mic=" << r_dim_mic;
set_output(start_tile_function()
set_output(
start_tile_function()
.add(builder::Input{op_input(0), "L"}
.add_dims("DL", 1, l_dim_mac + 1)
.add_dims("DC", 1, reduce_limit + 1))
......@@ -61,9 +68,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Dot>::operator()()
.add_indices("dc", 1, reduce_limit + 1)
.add_indices("dr", r_dim_mic + 1, r_dim_limit + 1)))
.finalize());
}
}
namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::Dot>::Registration register_dot;
namespace
{
Impl<op::Dot>::Registration register_dot;
}
}
}
}
......@@ -19,10 +19,16 @@
#include "ngraph/runtime/plaidml/plaidml_compiler.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// FunctionCall invokes a sub-function.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::FunctionCall>::operator()()
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
// FunctionCall invokes a sub-function.
template <>
void Impl<op::FunctionCall>::operator()()
{
Build b;
build()->compiler->build(op().get_functions()[0], &b);
vertexai::plaidml::function f{b.composer};
......@@ -30,7 +36,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::FunctionCall>::operator()()
for (std::size_t idx = 0; idx < op().get_input_size(); ++idx)
{
auto* oitv = op().get_inputs()[idx].get_output().get_tensor_ptr().get();
auto* iitv = b.func->get_parameters()[idx]->get_outputs()[0].get_tensor_ptr().get();
auto* iitv =
b.func->get_parameters()[idx]->get_outputs()[0].get_tensor_ptr().get();
inputs.emplace_back(b.input_names.at(iitv), build()->bindings.at(oitv).var);
}
vertexai::plaidml::application app{f.apply(inputs)};
......@@ -39,9 +46,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::FunctionCall>::operator()()
auto* iotv = b.func->get_results()[idx]->get_output_tensor_ptr().get();
set_output(idx, app.get_output(b.output_names[iotv]));
}
}
}
namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::FunctionCall>::Registration register_function_call;
namespace
{
Impl<op::FunctionCall>::Registration register_function_call;
}
}
}
}
......@@ -28,10 +28,16 @@
namespace vp = vertexai::plaidml;
// Broadcast broadcasts a tensor to a wider shape.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Broadcast>::operator()()
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
// Broadcast broadcasts a tensor to a wider shape.
template <>
void Impl<op::Broadcast>::operator()()
{
check_inputs(1);
check_outputs(1);
......@@ -57,15 +63,19 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Broadcast>::operator()()
start_tile_function()
.add(builder::Input{op_input(0), "I"}.add_rdims("D", in_dim_limit, 0))
.add(builder::Output{"O"})
.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"O"}
.add(
builder::UnaryContraction{"="}
.set(
builder::ContractionOutput{"O"}
.add_rindices("o", out_dim_limit, 0)
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < out_dim_limit; ++idx)
{
if (op().get_broadcast_axes().count(idx))
{
out = std::to_string(op().get_broadcast_shape()[idx]);
out = std::to_string(
op().get_broadcast_shape()[idx]);
}
else
{
......@@ -81,12 +91,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Broadcast>::operator()()
}
})))
.finalize());
}
}
// Constant fills in a tensor constant.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Constant>::operator()()
{
// Constant fills in a tensor constant.
template <>
void Impl<op::Constant>::operator()()
{
check_inputs(0);
check_outputs(1);
......@@ -105,48 +115,51 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Constant>::operator()()
switch (to_plaidml(op().get_element_type()))
{
case PLAIDML_DATA_BOOLEAN:
set_output(static_cast<std::int64_t>(*static_cast<const char*>(op().get_data_ptr())));
set_output(static_cast<std::int64_t>(
*static_cast<const char*>(op().get_data_ptr())));
return;
case PLAIDML_DATA_INT8:
set_output(
static_cast<std::int64_t>(*static_cast<const std::int8_t*>(op().get_data_ptr())));
set_output(static_cast<std::int64_t>(
*static_cast<const std::int8_t*>(op().get_data_ptr())));
return;
case PLAIDML_DATA_INT16:
set_output(
static_cast<std::int64_t>(*static_cast<const std::int16_t*>(op().get_data_ptr())));
set_output(static_cast<std::int64_t>(
*static_cast<const std::int16_t*>(op().get_data_ptr())));
return;
case PLAIDML_DATA_INT32:
set_output(
static_cast<std::int64_t>(*static_cast<const std::int32_t*>(op().get_data_ptr())));
set_output(static_cast<std::int64_t>(
*static_cast<const std::int32_t*>(op().get_data_ptr())));
return;
case PLAIDML_DATA_INT64:
set_output(*static_cast<const std::int64_t*>(op().get_data_ptr()));
return;
case PLAIDML_DATA_UINT8:
set_output(
static_cast<std::int64_t>(*static_cast<const std::uint8_t*>(op().get_data_ptr())));
set_output(static_cast<std::int64_t>(
*static_cast<const std::uint8_t*>(op().get_data_ptr())));
return;
case PLAIDML_DATA_UINT16:
set_output(
static_cast<std::int64_t>(*static_cast<const std::uint16_t*>(op().get_data_ptr())));
set_output(static_cast<std::int64_t>(
*static_cast<const std::uint16_t*>(op().get_data_ptr())));
return;
case PLAIDML_DATA_UINT32:
set_output(
static_cast<std::int64_t>(*static_cast<const std::uint32_t*>(op().get_data_ptr())));
set_output(static_cast<std::int64_t>(
*static_cast<const std::uint32_t*>(op().get_data_ptr())));
return;
case PLAIDML_DATA_UINT64:
set_output(
static_cast<std::int64_t>(*static_cast<const std::uint64_t*>(op().get_data_ptr())));
set_output(static_cast<std::int64_t>(
*static_cast<const std::uint64_t*>(op().get_data_ptr())));
return;
case PLAIDML_DATA_FLOAT16:
set_output(static_cast<double>(
static_cast<float>(*static_cast<const half*>(op().get_data_ptr()))));
return;
case PLAIDML_DATA_FLOAT32:
set_output(static_cast<double>(*static_cast<const float*>(op().get_data_ptr())));
set_output(
static_cast<double>(*static_cast<const float*>(op().get_data_ptr())));
return;
case PLAIDML_DATA_FLOAT64:
set_output(static_cast<double>(*static_cast<const double*>(op().get_data_ptr())));
set_output(
static_cast<double>(*static_cast<const double*>(op().get_data_ptr())));
return;
default: break;
}
......@@ -163,22 +176,22 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Constant>::operator()()
}
set_output(tensor);
}
}
// GetOutputElement pipes one of its N inputs to its output.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::GetOutputElement>::operator()()
{
// GetOutputElement pipes one of its N inputs to its output.
template <>
void Impl<op::GetOutputElement>::operator()()
{
check_inputs_ge(op().get_n() + 1);
check_outputs(1);
set_output(op_input(op().get_n()));
}
}
// Pad adds interior and exterior padding to a tensor.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::operator()()
{
// Pad adds interior and exterior padding to a tensor.
template <>
void Impl<op::Pad>::operator()()
{
check_inputs(2);
check_outputs(1);
......@@ -214,7 +227,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::operator()()
auto out_dsize = [&](std::size_t idx) {
std::ostringstream s;
std::size_t total_pad = op().get_padding_below().at(idx) + op().get_padding_above().at(idx);
std::size_t total_pad =
op().get_padding_below().at(idx) + op().get_padding_above().at(idx);
std::size_t in_dsize = op().get_input_shape(0).at(idx);
if (in_dsize)
{
......@@ -267,40 +281,48 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::operator()()
if (!any_zero_dims)
{
f.add(builder::Input{op_input(0), "I"}.add_dims("DI", 1, dim_limit + 1))
.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"P"}
.add_indices(
[&](std::back_insert_iterator<std::list<std::string>> out) {
.add(
builder::UnaryContraction{"="}
.set(
builder::ContractionOutput{"P"}
.add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_limit; ++idx)
{
out = out_didx(idx);
}
})
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_limit; ++idx)
{
out = out_dsize(idx);
}
}))
.set(builder::ContractionInput{"I"}.add_indices("d", 1, dim_limit + 1)))
.set(builder::ContractionInput{"I"}.add_indices(
"d", 1, dim_limit + 1)))
.add(builder::Elementwise{"T", "1"})
.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"F"}
.add_indices(
[&](std::back_insert_iterator<std::list<std::string>> out) {
.add(
builder::UnaryContraction{"="}
.set(
builder::ContractionOutput{"F"}
.add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_limit; ++idx)
{
out = out_didx(idx);
}
})
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_limit; ++idx)
{
out = out_dsize(idx);
}
}))
.set(builder::ContractionInput{"T"})
.add_constraints([&](std::back_insert_iterator<std::list<std::string>> out) {
.add_constraints(
[&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_limit; ++idx)
{
out = flag_constraints(idx);
......@@ -313,7 +335,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::operator()()
f.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"O"}
.add_indices("d", 0, dim_limit)
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_limit; ++idx)
{
out = out_dsize(idx);
......@@ -323,12 +346,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::operator()()
}
set_output(f.finalize());
}
}
// Reshape reshapes an input tensor.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Reshape>::operator()()
{
// Reshape reshapes an input tensor.
template <>
void Impl<op::Reshape>::operator()()
{
check_inputs(1);
check_outputs(1);
......@@ -351,13 +374,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reshape>::operator()()
.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"O"}
.add_indices("d", 0, out_shape.size())
.add_dims(
[&](std::back_insert_iterator<std::list<std::string>> out) {
std::transform(
out_shape.begin(),
.add_dims([&](
std::back_insert_iterator<std::list<std::string>>
out) {
std::transform(out_shape.begin(),
out_shape.end(),
out,
[](std::size_t sz) { return std::to_string(sz); });
[](std::size_t sz) {
return std::to_string(sz);
});
}))
.set(builder::ContractionInput{"I"}))
.finalize());
......@@ -374,28 +399,31 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reshape>::operator()()
// it's also rearranging the elements of the input tensor. This is pretty easy to
// handle with a contraction.
src =
start_tile_function()
src = start_tile_function()
.add(builder::Input{src, "I"}.add_dims("D", 1, dim_limit + 1))
.add(builder::Output{"O"})
.add(
builder::UnaryContraction{"="}
.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"O"}
.add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_limit; ++idx)
.add_indices([&](std::back_insert_iterator<
std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_limit;
++idx)
{
out = "d" + std::to_string(input_order[idx] + 1);
out = "d" + std::to_string(
input_order[idx] + 1);
}
})
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_limit; ++idx)
.add_dims([&](std::back_insert_iterator<
std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_limit;
++idx)
{
out = "D" + std::to_string(input_order[idx] + 1);
out = "D" + std::to_string(
input_order[idx] + 1);
}
}))
.set(builder::ContractionInput{"I"}.add_indices("d", 1, dim_limit + 1)))
.set(builder::ContractionInput{"I"}.add_indices(
"d", 1, dim_limit + 1)))
.finalize();
break;
}
......@@ -414,12 +442,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reshape>::operator()()
.add(builder::Output{"O"})
.add(builder::Elementwise("O", reshape_expr.str()))
.finalize());
}
}
// Select conditionally selects elements from input tensors.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Select>::operator()()
{
// Select conditionally selects elements from input tensors.
template <>
void Impl<op::Select>::operator()()
{
check_inputs(3);
check_outputs(1);
......@@ -430,26 +458,28 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Select>::operator()()
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "C ? T : F"})
.finalize());
}
}
// Used by nGraph for bprop graph generation, no-op as a kernel
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::StopGradient>::operator()()
{
// Used by nGraph for bprop graph generation, no-op as a kernel
template <>
void Impl<op::StopGradient>::operator()()
{
set_output(start_tile_function()
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "0"})
.finalize());
}
}
namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::Broadcast>::Registration register_broadcast;
ngraph::runtime::plaidml::Impl<ngraph::op::Constant>::Registration register_constant;
ngraph::runtime::plaidml::Impl<ngraph::op::GetOutputElement>::Registration
register_get_output_element;
ngraph::runtime::plaidml::Impl<ngraph::op::Pad>::Registration register_pad;
ngraph::runtime::plaidml::Impl<ngraph::op::Reshape>::Registration register_reshape;
ngraph::runtime::plaidml::Impl<ngraph::op::Select>::Registration register_select;
ngraph::runtime::plaidml::Impl<ngraph::op::StopGradient>::Registration register_stop_gradient;
namespace
{
Impl<op::Broadcast>::Registration register_broadcast;
Impl<op::Constant>::Registration register_constant;
Impl<op::GetOutputElement>::Registration register_get_output_element;
Impl<op::Pad>::Registration register_pad;
Impl<op::Reshape>::Registration register_reshape;
Impl<op::Select>::Registration register_select;
Impl<op::StopGradient>::Registration register_stop_gradient;
}
}
}
}
......@@ -36,13 +36,10 @@ namespace ngraph
void build_index_reduction(const char* agg_op);
};
}
}
}
template <typename O>
void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(const char* agg_op)
{
template <typename O>
void IndexReductionImpl<O>::build_index_reduction(const char* agg_op)
{
this->check_inputs(1);
this->check_outputs(1);
......@@ -56,16 +53,20 @@ void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(cons
.add(builder::Output{"O"})
.add( // Compute the maxes along the specified axis in the input
builder::UnaryContraction{agg_op}
.set(builder::ContractionOutput{"SelVal"}
.set(
builder::ContractionOutput{"SelVal"}
.add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (auto idx = 0; idx < dim_limit; ++idx)
{
out = (idx == this->op().get_reduction_axis() ? "rd" : "d") +
out =
(idx == this->op().get_reduction_axis() ? "rd"
: "d") +
std::to_string(idx);
}
})
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (auto idx = 0; idx < dim_limit; ++idx)
{
if (idx == this->op().get_reduction_axis())
......@@ -82,13 +83,14 @@ void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(cons
.add( // Compare the input against the (broadcasted) max values, and select the indices
// where the max val occurs
builder::Elementwise{"SelValIdxs",
"I == SelVal ? index(I, " + reduction_axis_str + ") : D" +
reduction_axis_str})
"I == SelVal ? index(I, " + reduction_axis_str +
") : D" + reduction_axis_str})
.add( // Select the maximum index
builder::UnaryContraction{"<"}
.set(builder::ContractionOutput{"SelIdx"}
.add_indices(
[&](std::back_insert_iterator<std::list<std::string>> out) {
.set(
builder::ContractionOutput{"SelIdx"}
.add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (auto idx = 0; idx < dim_limit; ++idx)
{
if (idx != this->op().get_reduction_axis())
......@@ -97,7 +99,8 @@ void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(cons
}
}
})
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (auto idx = 0; idx < dim_limit; ++idx)
{
if (idx != this->op().get_reduction_axis())
......@@ -106,41 +109,45 @@ void ngraph::runtime::plaidml::IndexReductionImpl<O>::build_index_reduction(cons
}
}
}))
.set(builder::ContractionInput{"SelValIdxs"}.add_indices("d", 0, dim_limit)))
.set(builder::ContractionInput{"SelValIdxs"}.add_indices(
"d", 0, dim_limit)))
.add( // Convert to the requested output element type (if any)
builder::Elementwise{"O",
tile_converter("SelIdx", this->op().get_index_element_type())})
builder::Elementwise{
"O", tile_converter("SelIdx", this->op().get_index_element_type())})
.finalize());
}
}
template <>
struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::ArgMax>
{
using Type = IndexReductionImpl<ngraph::op::ArgMax>;
};
template <>
struct ParentImpl<op::ArgMax>
{
using Type = IndexReductionImpl<op::ArgMax>;
};
template <>
struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::ArgMin>
{
using Type = IndexReductionImpl<ngraph::op::ArgMin>;
};
template <>
struct ParentImpl<op::ArgMin>
{
using Type = IndexReductionImpl<op::ArgMin>;
};
// ArgMax computes the index of the maximum value along a tensor axis.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::ArgMax>::operator()()
{
// ArgMax computes the index of the maximum value along a tensor axis.
template <>
void Impl<op::ArgMax>::operator()()
{
build_index_reduction(">");
}
}
// ArgMin computes the index of the minimum value along a tensor axis.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::ArgMin>::operator()()
{
// ArgMin computes the index of the minimum value along a tensor axis.
template <>
void Impl<op::ArgMin>::operator()()
{
build_index_reduction("<");
}
}
namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::ArgMax>::Registration register_argmax;
ngraph::runtime::plaidml::Impl<ngraph::op::ArgMin>::Registration register_argmin;
namespace
{
Impl<op::ArgMax>::Registration register_argmax;
Impl<op::ArgMin>::Registration register_argmin;
}
}
}
}
......@@ -20,10 +20,16 @@
namespace vp = vertexai::plaidml;
// Parameter binds a descriptor::Tensor to a PlaidML Placeholder.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Parameter>::operator()()
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
// Parameter binds a descriptor::Tensor to a PlaidML Placeholder.
template <>
void Impl<op::Parameter>::operator()()
{
check_inputs(0);
check_outputs(1);
vp::placeholder ph{build()->io_dim_override ? build()->io_dim_override_count
......@@ -33,22 +39,25 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Parameter>::operator()()
build()->bindings.emplace(tv, TensorInfo{ph, TensorContents::DATA});
build()->composer.input(name, ph);
build()->input_names.emplace(tv, std::move(name));
}
}
// Result binds a PlaidML variable to a composed function output.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Result>::operator()()
{
// Result binds a PlaidML variable to a composed function output.
template <>
void Impl<op::Result>::operator()()
{
check_inputs(1);
check_outputs(1);
std::string name = std::string{"O"} + std::to_string(build()->output_names.size());
descriptor::Tensor* tv = op().get_output_tensor_ptr().get();
build()->composer.output(name, op_input());
build()->output_names.emplace(tv, std::move(name));
}
}
namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::Parameter>::Registration register_parameter;
ngraph::runtime::plaidml::Impl<ngraph::op::Result>::Registration register_result;
namespace
{
Impl<op::Parameter>::Registration register_parameter;
Impl<op::Result>::Registration register_result;
}
}
}
}
......@@ -17,21 +17,29 @@
#include "ngraph/op/lrn.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// LRN implements Local Response Normalization
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::LRN>::operator()()
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
// LRN implements Local Response Normalization
template <>
void Impl<op::LRN>::operator()()
{
check_inputs(1);
check_outputs(1);
auto dim_limit = op().get_inputs()[0].get_shape().size();
auto rank = dim_limit - 2;
auto distance = op().get_nsize() / 2;
std::ostringstream div_expr;
div_expr << "I / pow(" << op().get_bias() << ".0 + ((" << op().get_alpha() << ".0 / "
<< op().get_nsize() << ".0) * S), " << op().get_beta() << ".0)";
div_expr << "I / pow(" << op().get_bias() << ".0 + ((" << op().get_alpha()
<< ".0 / " << op().get_nsize() << ".0) * S), " << op().get_beta() << ".0)";
set_output(
start_tile_function()
.add(builder::Input{op_input(), "I"}.add_dims({"N", "C"}).add_dims("D", 0, rank))
.add(builder::Input{op_input(), "I"}
.add_dims({"N", "C"})
.add_dims("D", 0, rank))
.add(builder::Output{"O"})
.add(builder::Elementwise{"ISQ", "I * I"})
.add(builder::UnaryContraction{"+"}
......@@ -43,14 +51,18 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::LRN>::operator()()
.set(builder::ContractionInput{"ISQ"}
.add_indices({"n", "c + z - " + std::to_string(distance)})
.add_indices("d", 0, rank))
.add_constraints([&](std::back_insert_iterator<std::list<std::string>> out) {
.add_constraints(
[&](std::back_insert_iterator<std::list<std::string>> out) {
out = "z < " + std::to_string(op().get_nsize());
}))
.add(builder::Elementwise{"O", div_expr.str()})
.finalize());
}
}
namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::LRN>::Registration register_local_response_norm;
namespace
{
Impl<op::LRN>::Registration register_local_response_norm;
}
}
}
}
......@@ -19,10 +19,16 @@
#include "ngraph/op/or.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// And performs a simple elementwise logical and.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::And>::operator()()
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
// And performs a simple elementwise logical and.
template <>
void Impl<op::And>::operator()()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -32,12 +38,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::And>::operator()()
.add(builder::Elementwise{"C", "A ? B : A"})
.finalize(),
TensorContents::LOGICAL);
}
}
// Not performs a simple elementwise logical not.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Not>::operator()()
{
// Not performs a simple elementwise logical not.
template <>
void Impl<op::Not>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -46,12 +52,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Not>::operator()()
.add(builder::Elementwise{"O", "cmp_eq(I, 0)"})
.finalize(),
TensorContents::LOGICAL);
}
}
// Or performs a simple elementwise logical or.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Or>::operator()()
{
// Or performs a simple elementwise logical or.
template <>
void Impl<op::Or>::operator()()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -61,11 +67,14 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Or>::operator()()
.add(builder::Elementwise{"C", "A ? A : B"})
.finalize(),
TensorContents::LOGICAL);
}
}
namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::And>::Registration register_and;
ngraph::runtime::plaidml::Impl<ngraph::op::Not>::Registration register_not;
ngraph::runtime::plaidml::Impl<ngraph::op::Or>::Registration register_or;
namespace
{
Impl<op::And>::Registration register_and;
Impl<op::Not>::Registration register_not;
Impl<op::Or>::Registration register_or;
}
}
}
}
......@@ -20,10 +20,16 @@
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
#include "ngraph/runtime/plaidml/plaidml_translate.hpp"
// OneHot performs one-hot encoding along the requested axis.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::OneHot>::operator()()
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
// OneHot performs one-hot encoding along the requested axis.
template <>
void Impl<op::OneHot>::operator()()
{
check_inputs(1);
check_outputs(1);
......@@ -68,9 +74,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::OneHot>::operator()()
.add(builder::Input{op_input(), "I"}.add_dims("D", 0, in_shape.size()))
.add(builder::Input{static_cast<std::int64_t>(0), "Zero"})
.add(builder::Output{"O"})
.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"ZS"}
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
.add(
builder::UnaryContraction{"="}
.set(
builder::ContractionOutput{"ZS"}
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < out_shape.size(); ++idx)
{
if (idx == op().get_one_hot_axis())
......@@ -85,15 +94,19 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::OneHot>::operator()()
})
.add_indices("d", 0, out_shape.size()))
.set(builder::ContractionInput{"Zero"}))
.add(builder::Elementwise{"Idx",
"index(ZS, " + std::to_string(op().get_one_hot_axis()) + ")"})
.add(builder::Elementwise{
"Idx", "index(ZS, " + std::to_string(op().get_one_hot_axis()) + ")"})
.add(builder::Elementwise{"IS", "reshape(I, " + in_reshape.str() + ")"})
.add(builder::Elementwise{"OV", "IS == Idx ? 1 : 0"})
.add(builder::Elementwise{"O", tile_converter("OV", op().get_element_type())})
.add(builder::Elementwise{"O",
tile_converter("OV", op().get_element_type())})
.finalize());
}
}
namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::OneHot>::Registration register_one_hot;
namespace
{
Impl<op::OneHot>::Registration register_one_hot;
}
}
}
}
......@@ -20,10 +20,16 @@
#include "ngraph/runtime/plaidml/plaidml_convpool_formatter.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// AvgPool implements a batch average pooling operation.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::AvgPool>::operator()()
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
// AvgPool implements a batch average pooling operation.
template <>
void Impl<op::AvgPool>::operator()()
{
check_inputs(1);
check_outputs(1);
......@@ -92,12 +98,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::AvgPool>::operator()()
f.add(cpf.PoolContraction()).add(builder::Elementwise{"O", "S / Count"});
set_output(f.finalize());
}
}
// MaxPool implements a batch max pooling operation.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::MaxPool>::operator()()
{
// MaxPool implements a batch max pooling operation.
template <>
void Impl<op::MaxPool>::operator()()
{
check_inputs(1);
check_outputs(1);
......@@ -156,11 +162,11 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::MaxPool>::operator()()
.add(cpf.O_out_header())
.add(cpf.PoolContraction())
.finalize());
}
}
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::AvgPoolBackprop>::operator()()
{
template <>
void Impl<op::AvgPoolBackprop>::operator()()
{
check_inputs(1);
check_outputs(1);
......@@ -174,7 +180,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::AvgPoolBackprop>::operator()()
if (include_padding)
{
throw std::runtime_error("Include padding in average not yet implemented in PlaidML");
throw std::runtime_error(
"Include padding in average not yet implemented in PlaidML");
}
ngraph::CoordinateDiff pad_above;
......@@ -229,18 +236,19 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::AvgPoolBackprop>::operator()()
{
std::ostringstream s;
s << "XI" << i - 2;
ret.add(builder::Input{static_cast<std::int64_t>(forward_arg_shape[i]), s.str()});
ret.add(
builder::Input{static_cast<std::int64_t>(forward_arg_shape[i]), s.str()});
}
set_output(ret.add(cpf.Broadcast_Ones())
.add(cpf.Count())
.add(builder::Elementwise{"S", "DO / Count"})
.add(cpf.PoolContraction())
.finalize());
}
}
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::MaxPoolBackprop>::operator()()
{
template <>
void Impl<op::MaxPoolBackprop>::operator()()
{
check_inputs(2);
check_outputs(1);
......@@ -299,14 +307,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::MaxPoolBackprop>::operator()()
.add(cpf.PoolContraction())
.add(cpf.PoolDerivContraction())
.finalize());
}
}
namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::AvgPool>::Registration register_avg_pool;
ngraph::runtime::plaidml::Impl<ngraph::op::MaxPool>::Registration register_max_pool;
ngraph::runtime::plaidml::Impl<ngraph::op::AvgPoolBackprop>::Registration
register_avg_pool_backprop;
ngraph::runtime::plaidml::Impl<ngraph::op::MaxPoolBackprop>::Registration
register_max_pool_backprop;
namespace
{
Impl<op::AvgPool>::Registration register_avg_pool;
Impl<op::MaxPool>::Registration register_max_pool;
Impl<op::AvgPoolBackprop>::Registration register_avg_pool_backprop;
Impl<op::MaxPoolBackprop>::Registration register_max_pool_backprop;
}
}
}
}
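For orientation, the AvgPool contraction above accumulates a window sum S and divides by Count. A minimal 1-D reference with the same sum-over-window, divide-by-count structure follows; it ignores padding and batching, and its names and shapes are illustrative assumptions rather than the backend's formulation.

#include <cstddef>
#include <iostream>
#include <vector>

std::vector<double> avg_pool_1d(const std::vector<double>& in, std::size_t window, std::size_t stride)
{
    std::vector<double> out;
    for (std::size_t start = 0; start + window <= in.size(); start += stride)
    {
        double sum = 0.0;
        for (std::size_t i = start; i < start + window; ++i)
        {
            sum += in[i];
        }
        out.push_back(sum / static_cast<double>(window)); // divide by the window count
    }
    return out;
}

int main()
{
    for (double v : avg_pool_1d({1, 2, 3, 4, 5}, 3, 1))
    {
        std::cout << v << ' '; // 2 3 4
    }
    std::cout << '\n';
}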
......@@ -42,13 +42,10 @@ namespace ngraph
void build_reduction(const char* agg_op);
};
}
}
}
template <typename O>
void ngraph::runtime::plaidml::ReductionImpl<O>::build_reduction(const char* agg_op)
{
template <typename O>
void ReductionImpl<O>::build_reduction(const char* agg_op)
{
this->check_inputs(1);
this->check_outputs(1);
......@@ -68,81 +65,86 @@ void ngraph::runtime::plaidml::ReductionImpl<O>::build_reduction(const char* agg
this->start_tile_function()
.add(builder::Output{"O"})
.add(builder::Input{this->op_input(0), "I"}.add_dims("D", 1, in_dim_limit + 1))
.add(builder::UnaryContraction{agg_op}
.set(builder::ContractionOutput{"O"}
.add_indices(
[&](std::back_insert_iterator<std::list<std::string>> out) {
.add(builder::Input{this->op_input(0), "I"}.add_dims(
"D", 1, in_dim_limit + 1))
.add(
builder::UnaryContraction{agg_op}
.set(
builder::ContractionOutput{"O"}
.add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < out_idxs.size(); ++idx)
{
out = "d" + std::to_string(out_idxs[idx] + 1);
}
})
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < out_idxs.size(); ++idx)
{
out = "D" + std::to_string(out_idxs[idx] + 1);
}
}))
.set(builder::ContractionInput{"I"}.add_indices("d", 1, in_dim_limit + 1)))
.set(builder::ContractionInput{"I"}.add_indices(
"d", 1, in_dim_limit + 1)))
.finalize());
}
}
template <>
struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::Max>
{
using Type = ngraph::runtime::plaidml::ReductionImpl<ngraph::op::Max>;
};
template <>
struct ParentImpl<op::Max>
{
using Type = ReductionImpl<op::Max>;
};
template <>
struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::Min>
{
using Type = ngraph::runtime::plaidml::ReductionImpl<ngraph::op::Min>;
};
template <>
struct ParentImpl<op::Min>
{
using Type = ReductionImpl<op::Min>;
};
template <>
struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::Product>
{
using Type = ngraph::runtime::plaidml::ReductionImpl<ngraph::op::Product>;
};
template <>
struct ParentImpl<op::Product>
{
using Type = ReductionImpl<op::Product>;
};
template <>
struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::Reduce>
{
using Type = ngraph::runtime::plaidml::ReductionImpl<ngraph::op::Reduce>;
};
template <>
struct ParentImpl<op::Reduce>
{
using Type = ReductionImpl<op::Reduce>;
};
template <>
struct ngraph::runtime::plaidml::ParentImpl<ngraph::op::Sum>
{
using Type = ngraph::runtime::plaidml::ReductionImpl<ngraph::op::Sum>;
};
template <>
struct ParentImpl<op::Sum>
{
using Type = ReductionImpl<op::Sum>;
};
// Max reduces a tensor, taking the maximum along the specified axes.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Max>::operator()()
{
// Max reduces a tensor, taking the maximum along the specified axes.
template <>
void Impl<op::Max>::operator()()
{
build_reduction(">");
}
}
// Min reduces a tensor, taking the minimum along the specified axes.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Min>::operator()()
{
// Min reduces a tensor, taking the minimum along the specified axes.
template <>
void Impl<op::Min>::operator()()
{
build_reduction("<");
}
}
// Product reduces a tensor, taking the product along the specified axes.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Product>::operator()()
{
// Product reduces a tensor, taking the product along the specified axes.
template <>
void Impl<op::Product>::operator()()
{
build_reduction("*");
}
}
// Reduce reduces a tensor with an arbitrary user-supplied reduction operation.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()()
{
// Reduce reduces a tensor with an arbitrary user-supplied reduction operation.
template <>
void Impl<op::Reduce>::operator()()
{
check_inputs(2);
check_outputs(1);
......@@ -183,10 +185,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()()
start_tile_function()
.add(builder::Input{op_input(1), "I"})
.add(builder::Output{"O"})
.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"O"}
.add(
builder::UnaryContraction{"="}
.set(
builder::ContractionOutput{"O"}
.add_indices("d", 0, agg_dim_limit)
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (auto idx = 0; idx < agg_dim_limit; ++idx)
{
out = "1";
......@@ -205,9 +210,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()()
.add(builder::Output{"O"})
.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"O"}
.add_indices(
[&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < input_shape.size(); ++idx)
.add_indices([&](
std::back_insert_iterator<std::list<std::string>>
out) {
for (std::size_t idx = 0;
idx < input_shape.size();
++idx)
{
if (!op().get_reduction_axes().count(idx))
{
......@@ -215,9 +223,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()()
}
}
})
.add_dims(
[&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < input_shape.size(); ++idx)
.add_dims([&](
std::back_insert_iterator<std::list<std::string>>
out) {
for (std::size_t idx = 0;
idx < input_shape.size();
++idx)
{
if (!op().get_reduction_axes().count(idx))
{
......@@ -225,8 +236,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()()
}
}
}))
.set(builder::ContractionInput{"I"}.add_indices(
[&](std::back_insert_iterator<std::list<std::string>> out) {
.set(builder::ContractionInput{"I"}.add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < input_shape.size(); ++idx)
{
std::size_t cidx = 0;
......@@ -244,20 +255,23 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::operator()()
}
set_output(result);
}
}
// Sum reduces a tensor, summing the specified axes.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Sum>::operator()()
{
// Sum reduces a tensor, summing the specified axes.
template <>
void Impl<op::Sum>::operator()()
{
build_reduction("+");
}
}
namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::Max>::Registration register_max;
ngraph::runtime::plaidml::Impl<ngraph::op::Min>::Registration register_min;
ngraph::runtime::plaidml::Impl<ngraph::op::Product>::Registration register_product;
ngraph::runtime::plaidml::Impl<ngraph::op::Reduce>::Registration register_reduce;
ngraph::runtime::plaidml::Impl<ngraph::op::Sum>::Registration register_sum;
namespace
{
Impl<op::Max>::Registration register_max;
Impl<op::Min>::Registration register_min;
Impl<op::Product>::Registration register_product;
Impl<op::Reduce>::Registration register_reduce;
Impl<op::Sum>::Registration register_sum;
}
}
}
}
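build_reduction above is parameterized by the Tile aggregation operator (">" for Max, "<" for Min, "*" for Product, "+" for Sum and Reduce). A small standalone sketch of the same idea -- one reduction routine whose behaviour is selected by an aggregation token -- is shown below; the dispatch table and names are illustrative, not the actual Tile semantics.

#include <functional>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

double reduce(const std::vector<double>& values, const std::string& agg_op)
{
    if (values.empty())
    {
        throw std::runtime_error("empty input");
    }
    std::function<double(double, double)> f;
    if (agg_op == "+") { f = [](double a, double b) { return a + b; }; }        // sum
    else if (agg_op == "*") { f = [](double a, double b) { return a * b; }; }   // product
    else if (agg_op == ">") { f = [](double a, double b) { return a > b ? a : b; }; } // max
    else if (agg_op == "<") { f = [](double a, double b) { return a < b ? a : b; }; } // min
    else { throw std::runtime_error("unknown aggregation op: " + agg_op); }
    return std::accumulate(values.begin() + 1, values.end(), values.front(), f);
}

int main()
{
    std::vector<double> v{3, 1, 4, 1, 5};
    std::cout << reduce(v, "+") << ' ' << reduce(v, ">") << '\n'; // 14 5
}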
......@@ -19,10 +19,16 @@
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// ReplaceSlice replaces part of a tensor with another tensor.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::ReplaceSlice>::operator()()
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
// ReplaceSlice replaces part of a tensor with another tensor.
template <>
void Impl<op::ReplaceSlice>::operator()()
{
check_inputs(2);
check_outputs(1);
......@@ -43,11 +49,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ReplaceSlice>::operator()()
.add(builder::Input{op_input(0), "L"}.add_dims("D", 0, shape.size()))
.add(builder::Input{op_input(1), "S"}.add_dims("SD", 0, shape.size()))
.add(builder::Output{"O"})
.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"O"}
.add(
builder::UnaryContraction{"="}
.set(
builder::ContractionOutput{"O"}
.add_dims("D", 0, shape.size())
.add_indices(
[&](std::back_insert_iterator<std::list<std::string>> out) {
.add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < shape.size(); ++idx)
{
auto stride = op().get_strides()[idx];
......@@ -73,8 +81,10 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ReplaceSlice>::operator()()
out = didx.str();
}
}))
.set(builder::ContractionInput{"S"}.add_indices("d", 0, shape.size()))
.add_constraints([&](std::back_insert_iterator<std::list<std::string>> out) {
.set(builder::ContractionInput{"S"}.add_indices(
"d", 0, shape.size()))
.add_constraints(
[&](std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < shape.size(); ++idx)
{
out = "d" + std::to_string(idx) + " < " +
......@@ -84,9 +94,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::ReplaceSlice>::operator()()
})
.set_default("L"))
.finalize());
}
}
namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::ReplaceSlice>::Registration register_replace_slice;
namespace
{
Impl<op::ReplaceSlice>::Registration register_replace_slice;
}
}
}
}
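The contraction above writes the replacement tensor S into the strided sub-region selected by the lower bounds and strides, while set_default("L") keeps the original value everywhere else. A 1-D reference of that semantics, under the assumption that this is the intended behaviour, with illustrative names:

#include <cstddef>
#include <iostream>
#include <vector>

std::vector<int> replace_slice_1d(std::vector<int> base,
                                  const std::vector<int>& repl,
                                  std::size_t lower,
                                  std::size_t stride)
{
    for (std::size_t i = 0; i < repl.size(); ++i)
    {
        base[lower + i * stride] = repl[i]; // positions outside the slice keep the base value
    }
    return base;
}

int main()
{
    for (int v : replace_slice_1d({0, 1, 2, 3, 4, 5}, {9, 9}, 1, 2))
    {
        std::cout << v << ' ';
    }
    std::cout << '\n'; // 0 9 2 9 4 5
}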
......@@ -19,10 +19,16 @@
#include "ngraph/op/reverse.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// Reverse reverses the selected axes within a tensor.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Reverse>::operator()()
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
// Reverse reverses the selected axes within a tensor.
template <>
void Impl<op::Reverse>::operator()()
{
check_inputs(1);
check_outputs(1);
......@@ -35,8 +41,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reverse>::operator()()
.set(builder::ContractionOutput{"O"}
.add_indices("d", 0, shape.size())
.add_dims("D", 0, shape.size()))
.set(builder::ContractionInput{"I"}.add_indices(
[&](std::back_insert_iterator<std::list<std::string>> out) {
.set(builder::ContractionInput{"I"}.add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < shape.size(); ++idx)
{
auto sidx = std::to_string(idx);
......@@ -51,9 +57,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Reverse>::operator()()
}
})))
.finalize());
}
}
namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::Reverse>::Registration register_reverse;
namespace
{
Impl<op::Reverse>::Registration register_reverse;
}
}
}
}
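For each reversed axis, Reverse presumably remaps index d to D - 1 - d on that axis (the visible hunk cuts off before the exact index expression). A 1-D illustration of that remapping, with illustrative names:

#include <cstddef>
#include <iostream>
#include <vector>

std::vector<int> reverse_axis(const std::vector<int>& in)
{
    std::vector<int> out(in.size());
    for (std::size_t d = 0; d < in.size(); ++d)
    {
        out[d] = in[in.size() - 1 - d]; // read the mirrored position
    }
    return out;
}

int main()
{
    for (int v : reverse_axis({1, 2, 3, 4}))
    {
        std::cout << v << ' '; // 4 3 2 1
    }
    std::cout << '\n';
}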
......@@ -18,10 +18,16 @@
#include "ngraph/op/slice.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// Slice takes a sub-slice of a tensor.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Slice>::operator()()
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
// Slice takes a sub-slice of a tensor.
template <>
void Impl<op::Slice>::operator()()
{
check_inputs(1);
check_outputs(1);
NGRAPH_DEBUG << "Slice: low: " << op().get_lower_bounds();
......@@ -33,17 +39,21 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Slice>::operator()()
start_tile_function()
.add(builder::Input{op_input(), "I"}.add_dims("ID", 0, dim_limit))
.add(builder::Output{"O"})
.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"O"}
.add(
builder::UnaryContraction{"="}
.set(
builder::ContractionOutput{"O"}
.add_indices("od", 0, dim_limit)
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (std::size_t idx = 0; idx < dim_limit; ++idx)
{
std::ostringstream s;
std::size_t stride = op().get_strides()[idx];
std::ptrdiff_t trim_count =
op().get_lower_bounds()[idx] +
(shape[idx] - op().get_upper_bounds()[idx]) + 1 - stride;
(shape[idx] - op().get_upper_bounds()[idx]) +
1 - stride;
if ((stride != 1) && trim_count)
{
s << "(";
......@@ -96,9 +106,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Slice>::operator()()
}
})))
.finalize());
}
}
namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::Slice>::Registration register_slice;
namespace
{
Impl<op::Slice>::Registration register_slice;
}
}
}
}
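The dimension expression above builds the output extent of a strided slice from trim_count = lower + (D - upper) + 1 - stride; with integer division this appears to reduce to the usual ceil((upper - lower) / stride). A quick standalone check of that formula (the function name and values are illustrative):

#include <cstddef>
#include <iostream>

std::size_t slice_extent(std::size_t lower, std::size_t upper, std::size_t stride)
{
    return (upper - lower + stride - 1) / stride; // ceil((upper - lower) / stride)
}

int main()
{
    std::cout << slice_extent(1, 7, 2) << '\n'; // picks elements 1, 3, 5 -> 3
    std::cout << slice_extent(0, 5, 1) << '\n'; // 5
}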
......@@ -19,10 +19,16 @@
#include "ngraph/op/softmax.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// Softmax implements a standard ML softmax operation.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
// Softmax implements a standard ML softmax operation.
template <>
void Impl<op::Softmax>::operator()()
{
check_inputs(1);
check_outputs(1);
......@@ -30,7 +36,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
auto dim_limit = shape.size();
auto f = start_tile_function();
f.add(builder::Input{op_input(0), "I"}.add_dims("D", 0, dim_limit)).add(builder::Output{"O"});
f.add(builder::Input{op_input(0), "I"}.add_dims("D", 0, dim_limit))
.add(builder::Output{"O"});
bool reorder_needed = false;
bool saw_element = false;
......@@ -71,7 +78,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
{
f.add(builder::UnaryContraction{"="}
.set(builder::ContractionOutput{"RI"}
.add_dims([&](std::back_insert_iterator<std::list<std::string>> out) {
.add_dims([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (auto idx : group_idxs)
{
out = "D" + std::to_string(idx);
......@@ -81,7 +89,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
out = "D" + std::to_string(idx);
}
})
.add_indices([&](std::back_insert_iterator<std::list<std::string>> out) {
.add_indices([&](
std::back_insert_iterator<std::list<std::string>> out) {
for (auto idx : group_idxs)
{
out = "d" + std::to_string(idx);
......@@ -117,7 +126,8 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
{
// Take the softmax.
std::ostringstream softmax;
softmax << "builtin_softmax(" << input << ", " << groups << ", " << elements << ")";
softmax << "builtin_softmax(" << input << ", " << groups << ", " << elements
<< ")";
f.add(builder::Elementwise{output, softmax.str()});
}
......@@ -159,9 +169,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::operator()()
}
set_output(f.finalize());
}
}
namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::Softmax>::Registration register_softmax;
namespace
{
Impl<op::Softmax>::Registration register_softmax;
}
}
}
}
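The implementation above reshapes the tensor into groups x elements and delegates the actual computation to builtin_softmax. For readers unfamiliar with the operation, a generic numerically stable softmax over one group is sketched below; it is a reference only and does not claim to match what builtin_softmax does internally.

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

std::vector<double> softmax(std::vector<double> x)
{
    const double max_x = *std::max_element(x.begin(), x.end());
    double sum = 0.0;
    for (double& v : x)
    {
        v = std::exp(v - max_x); // subtract the max for numerical stability
        sum += v;
    }
    for (double& v : x)
    {
        v /= sum;
    }
    return x;
}

int main()
{
    for (double v : softmax({1.0, 2.0, 3.0}))
    {
        std::cout << v << ' '; // ~0.09 0.24 0.67
    }
    std::cout << '\n';
}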
......@@ -29,10 +29,16 @@
#include "ngraph/op/tanh.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"
// acos performs a simple elementwise arccos function.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Acos>::operator()()
namespace ngraph
{
namespace runtime
{
namespace plaidml
{
// acos performs a simple elementwise arccos function.
template <>
void Impl<op::Acos>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -40,12 +46,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Acos>::operator()()
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "acos(I)"})
.finalize());
}
}
// asin performs a simple elementwise arcsin function.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Asin>::operator()()
{
// asin performs a simple elementwise arcsin function.
template <>
void Impl<op::Asin>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -53,12 +59,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Asin>::operator()()
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "asin(I)"})
.finalize());
}
}
// atan performs a simple elementwise arctan function.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Atan>::operator()()
{
// atan performs a simple elementwise arctan function.
template <>
void Impl<op::Atan>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -66,12 +72,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Atan>::operator()()
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "atan(I)"})
.finalize());
}
}
// cos performs a simple elementwise cos function.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Cos>::operator()()
{
// cos performs a simple elementwise cos function.
template <>
void Impl<op::Cos>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -79,12 +85,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Cos>::operator()()
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "cos(I)"})
.finalize());
}
}
// cosh performs a simple elementwise hyperbolic cos function.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Cosh>::operator()()
{
// cosh performs a simple elementwise hyperbolic cos function.
template <>
void Impl<op::Cosh>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -92,12 +98,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Cosh>::operator()()
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "cosh(I)"})
.finalize());
}
}
// exp performs a simple elementwise natural exponential function.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Exp>::operator()()
{
// exp performs a simple elementwise natural exponential function.
template <>
void Impl<op::Exp>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -105,12 +111,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Exp>::operator()()
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "exp(I)"})
.finalize());
}
}
// log performs a simple elementwise natural logarithm function.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Log>::operator()()
{
// log performs a simple elementwise natural logarithm function.
template <>
void Impl<op::Log>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -118,12 +124,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Log>::operator()()
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "log(I)"})
.finalize());
}
}
// power performs a simple elementwise power function.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Power>::operator()()
{
// power performs a simple elementwise power function.
template <>
void Impl<op::Power>::operator()()
{
check_inputs(2);
check_outputs(1);
set_output(start_tile_function()
......@@ -132,12 +138,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Power>::operator()()
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "pow(I, E)"})
.finalize());
}
}
// sin performs a simple elementwise sin function.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Sin>::operator()()
{
// sin performs a simple elementwise sin function.
template <>
void Impl<op::Sin>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -145,12 +151,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Sin>::operator()()
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "sin(I)"})
.finalize());
}
}
// sinh performs a simple elementwise hyperbolic sin function.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Sinh>::operator()()
{
// sinh performs a simple elementwise hyperbolic sin function.
template <>
void Impl<op::Sinh>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -158,12 +164,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Sinh>::operator()()
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "sinh(I)"})
.finalize());
}
}
// sqrt performs a simple elementwise square root function.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Sqrt>::operator()()
{
// sqrt performs a simple elementwise square root function.
template <>
void Impl<op::Sqrt>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -171,12 +177,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Sqrt>::operator()()
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "sqrt(I)"})
.finalize());
}
}
// tan performs a simple elementwise tangent function.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Tan>::operator()()
{
// tan performs a simple elementwise tangent function.
template <>
void Impl<op::Tan>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -184,12 +190,12 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Tan>::operator()()
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "tan(I)"})
.finalize());
}
}
// tanh performs a simple elementwise hyperbolic tangent function.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::Tanh>::operator()()
{
// tanh performs a simple elementwise hyperbolic tangent function.
template <>
void Impl<op::Tanh>::operator()()
{
check_inputs(1);
check_outputs(1);
set_output(start_tile_function()
......@@ -197,21 +203,24 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::Tanh>::operator()()
.add(builder::Output{"O"})
.add(builder::Elementwise{"O", "tanh(I)"})
.finalize());
}
}
namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::Acos>::Registration register_acos;
ngraph::runtime::plaidml::Impl<ngraph::op::Asin>::Registration register_asin;
ngraph::runtime::plaidml::Impl<ngraph::op::Atan>::Registration register_atan;
ngraph::runtime::plaidml::Impl<ngraph::op::Cos>::Registration register_cos;
ngraph::runtime::plaidml::Impl<ngraph::op::Cosh>::Registration register_cosh;
ngraph::runtime::plaidml::Impl<ngraph::op::Exp>::Registration register_exp;
ngraph::runtime::plaidml::Impl<ngraph::op::Log>::Registration register_log;
ngraph::runtime::plaidml::Impl<ngraph::op::Power>::Registration register_power;
ngraph::runtime::plaidml::Impl<ngraph::op::Sin>::Registration register_sin;
ngraph::runtime::plaidml::Impl<ngraph::op::Sinh>::Registration register_sinh;
ngraph::runtime::plaidml::Impl<ngraph::op::Sqrt>::Registration register_sqrt;
ngraph::runtime::plaidml::Impl<ngraph::op::Tan>::Registration register_tan;
ngraph::runtime::plaidml::Impl<ngraph::op::Tanh>::Registration register_tanh;
namespace
{
Impl<op::Acos>::Registration register_acos;
Impl<op::Asin>::Registration register_asin;
Impl<op::Atan>::Registration register_atan;
Impl<op::Cos>::Registration register_cos;
Impl<op::Cosh>::Registration register_cosh;
Impl<op::Exp>::Registration register_exp;
Impl<op::Log>::Registration register_log;
Impl<op::Power>::Registration register_power;
Impl<op::Sin>::Registration register_sin;
Impl<op::Sinh>::Registration register_sinh;
Impl<op::Sqrt>::Registration register_sqrt;
Impl<op::Tan>::Registration register_tan;
Impl<op::Tanh>::Registration register_tanh;
}
}
}
}
......@@ -38,6 +38,8 @@ topk_2d_max_one # No plans to implement TopK
topk_2d_min_all # No plans to implement TopK
topk_2d_min_partial # No plans to implement TopK
topk_2d_min_one # No plans to implement TopK
topk_int64 # No plans to implement TopK
topk_5d_max_partial # No plans to implement TopK
# Tests that PlaidML might be able to run at some point.
backwards_maxpool_n2_c1_hw5_3x3_str2_max_pad1x2_2x3
......@@ -84,3 +86,5 @@ sum_3d_eliminate_zero_dim # Empty dims apparently should produce shape
dot_0_0 # Empty dims apparently should produce shaped 0s
dot_matrix_2x0_0x2 # Empty dims apparently should produce shaped 0s
dot_2x0_0 # Empty dims apparently should produce shaped 0s
numeric_float_nan
numeric_double_nan