Unverified Commit a84a8d3b authored by Ewa Tusień's avatar Ewa Tusień Committed by GitHub

Merge branch 'master' into etusien/gelu

parents a9399c21 5b59c095
......@@ -20,8 +20,8 @@ set(MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git)
set(MLIR_REPO_URL https://github.com/tensorflow/mlir.git)
# Change these commit IDs to move to latest stable versions
set(MLIR_LLVM_COMMIT_ID c0cad98)
set(MLIR_COMMIT_ID 82d5084)
set(MLIR_LLVM_COMMIT_ID a2a6f85)
set(MLIR_COMMIT_ID 26c683c)
set(MLIR_PROJECT_ROOT ${CMAKE_CURRENT_BINARY_DIR}/mlir_project)
set(MLIR_LLVM_ROOT ${MLIR_PROJECT_ROOT}/llvm-projects)
set(MLIR_SOURCE_DIR ${MLIR_LLVM_ROOT}/llvm/projects/mlir)
......
......@@ -21,9 +21,9 @@ from a framework on a CPU, GPU, or ASIC; it can also be used with an
*Interpreter* mode, which is primarily intended for testing, to analyze a
program, or to help a framework developer customize targeted solutions.
nGraph also provides a way to use the advanced tensor compiler PlaidML
as a backend; you can learn more about this backend and how to build it
from source in our documentation: :ref:`ngraph_plaidml_backend`.
.. nGraph also provides a way to use the advanced tensor compiler PlaidML
.. as a backend; you can learn more about this backend and how to build it
.. from source in our documentation: :ref:`ngraph_plaidml_backend`.
.. csv-table::
:header: "Backend", "Current nGraph support", "Future nGraph support"
......
......@@ -50,6 +50,7 @@
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/TargetSelect.h>
#include <mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h>
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h>
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h>
#include <mlir/ExecutionEngine/ExecutionEngine.h>
......@@ -108,11 +109,16 @@ void MLIRCompiler::run(std::vector<void*>& external_tensors)
cleanup();
}
unsigned MLIRCompiler::get_mem_mgr_arg_id(mlir::FuncOp& func)
{
return func.getNumArguments() - 1;
}
// Creates an MLIR module and function with nGraph dialect ops from the input CompiledKernel.
void MLIRCompiler::build_ng_dialect_module()
{
// initialize an empty module
m_module = make_unique<mlir::Module>(&m_context);
m_module = mlir::ModuleOp::create(mlir::UnknownLoc::get(&m_context));
TypeList args_type_list, result_type_list;
......@@ -133,15 +139,14 @@ void MLIRCompiler::build_ng_dialect_module()
}
auto func_type = mlir::FunctionType::get(args_type_list, result_type_list, &m_context);
auto function =
make_unique<mlir::Function>(mlir::UnknownLoc::get(&m_context), "main", func_type);
function->addEntryBlock();
auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(&m_context), "main", func_type);
function.addEntryBlock();
// populate Tensor->Value maps
int i = 0;
for (auto input : kernel_inputs)
{
mlir::Value* arg = function->getArgument(i);
mlir::Value* arg = function.getArgument(i);
TensorInfo tensor_info{arg};
m_tensor_to_value_map.insert(
TensorToInfo(input->get_output_tensor_ptr().get(), tensor_info));
......@@ -149,9 +154,9 @@ void MLIRCompiler::build_ng_dialect_module()
}
// create builder
m_builder = llvm::make_unique<mlir::OpBuilder>(function->getBody());
m_builder = llvm::make_unique<mlir::OpBuilder>(function.getBody());
build_ng_dialect();
m_module->getFunctions().push_back(function.release());
m_module->push_back(function);
if (failed(m_module->verify()))
{
NGRAPH_CHECK(false, "Invalid module after lowering to NG dialect");
......@@ -260,19 +265,21 @@ void MLIRCompiler::lower_ng_dialect()
NGRAPH_CHECK(m_module, "MLIR module is not ready.");
// Lower Standard dialect to LLVM dialect.
// TODO: Do this via PassManager
mlir::LLVMTypeConverter llvm_converter(&m_context);
OwningRewritePatternList patterns;
mlir::populateLoopToStdConversionPatterns(patterns, &m_context);
mlir::populateStdToLLVMConversionPatterns(llvm_converter, patterns);
mlir::ConversionTarget target(m_context);
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
auto result = applyConversionPatterns(*m_module, target, llvm_converter, std::move(patterns));
target.addLegalOp<mlir::ModuleOp, mlir::ModuleTerminatorOp>();
target.addDynamicallyLegalOp<mlir::FuncOp>(
[&](mlir::FuncOp op) { return llvm_converter.isSignatureLegal(op.getType()); });
auto result = applyFullConversion(*m_module, target, std::move(patterns), &llvm_converter);
NGRAPH_CHECK(succeeded(result), "Standard to LLVM dialect conversion failed");
dump_mlir_module("LLVM-IR Dialect Dump:");
// Lower to LLVM BC and optimize
// Initialize LLVM targets.
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
......@@ -509,8 +516,8 @@ void MLIRCompiler::bind_arguments(std::vector<void*>& external_tensors)
{
NGRAPH_CHECK(m_module, "MLIR module is not ready.");
mlir::Function* func = m_module->getNamedFunction("main");
NGRAPH_CHECK(func && !func->getBlocks().empty(), "Function not found");
mlir::FuncOp func = m_module->lookupSymbol<mlir::FuncOp>("main");
NGRAPH_CHECK(func && !func.getBlocks().empty(), "Function not found");
// Set external arguments
NGRAPH_CHECK(m_compiled_kernel, "No compiled kernel set for compiler");
......
......@@ -77,10 +77,7 @@ namespace ngraph
/// Returns the memory manager used by this sub-graph compiler.
MLIRMemMgr& get_mem_mgr() { return m_mem_mgr; }
/// Returns memory manager pointer argument ID in call interface.
unsigned get_mem_mgr_arg_id(mlir::Function* func)
{
return func->getNumArguments() - 1;
}
unsigned get_mem_mgr_arg_id(mlir::FuncOp& func);
private:
struct TensorInfo
......@@ -147,7 +144,7 @@ namespace ngraph
// compilation.
mlir::MLIRContext m_context;
std::unique_ptr<mlir::Module> m_module;
mlir::OwningModuleRef m_module;
std::unique_ptr<mlir::OpBuilder> m_builder;
std::unique_ptr<mlir::ExecutionEngine> m_engine;
......
......@@ -74,7 +74,7 @@ namespace
\
PatternMatchResult matchAndRewrite(Operation* op, \
ArrayRef<Value*> operands, \
PatternRewriter& rewriter) const override; \
ConversionPatternRewriter& rewriter) const override; \
};
#include "op_lowerers.inc"
......@@ -117,14 +117,15 @@ namespace
SmallVector<Value*, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter);
Value* createTempTensor(Type type, PatternRewriter& rewriter);
mlir::Function* getCallDecl(StringRef name,
ArrayRef<Type> args,
ArrayRef<Type> output,
PatternRewriter& rewriter);
mlir::FuncOp getCallDecl(StringRef name,
ArrayRef<Type> args,
ArrayRef<Type> output,
PatternRewriter& rewriter);
/// Inserts dealloc Ops for each temporary allocated by AllocOp
void insertDeallocs(PatternRewriter& rewriter);
NGraphTypeConverter& getTypeConverter() { return typeConverter; }
private:
/// Collect a set of patterns to convert from the nGraph dialect to Affine dialect.
void populateNGraphToAffineConversionPatterns(OwningRewritePatternList& patterns);
......@@ -150,6 +151,9 @@ namespace
// Create type converter and initialize conversion patterns.
NGraphTypeConverter converter;
OwningRewritePatternList patterns;
// Add default FuncOp type conversion. It replaces the incoming FuncOp with a *new* one
// with the converted types.
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), typeConverter);
populateNGraphToAffineConversionPatterns(patterns);
// Create target that defines legal ops for nGraph dialect to be lowered to.
......@@ -157,14 +161,18 @@ namespace
// TODO: Remove NGFakeInputOp. We need to set NGFakeInputOp as legal op because we generate
// it as part of the lowering to affine/standard.
target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
target.addLegalOp<NGFakeInputOp>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp, NGFakeInputOp>();
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
// FuncOp is legal only if types have been converted to Std types.
return typeConverter.isSignatureLegal(op.getType());
});
// capture output values by looking for the Return and grabbing the values
// the order of the returned values matches the order of the lowered func signature for
// results. This is used to find the arg_id that a defined value maps to if it is an output
findOutputValues();
if (failed(applyConversionPatterns(getModule(), target, converter, std::move(patterns))))
if (failed(applyFullConversion(getModule(), target, std::move(patterns), &converter)))
{
emitError(mlir::UnknownLoc::get(&getContext()), "Error lowering nGraph dialect\n");
signalPassFailure();
......@@ -187,13 +195,13 @@ namespace
void DialectLoweringPass::findOutputValues()
{
// get original function
auto f = getModule().getNamedFunction("main");
auto f = getModule().lookupSymbol<mlir::FuncOp>("main");
SmallVector<Value*, 4> outputList;
unsigned outputCount = 0;
// we find out output values by looking at returned values
// any return should return all outputs of the subgraph
f->walk<NGReturnOp>([this, &outputCount](NGReturnOp ret) {
f.walk<NGReturnOp>([this, &outputCount](NGReturnOp ret) {
for (unsigned i = 0; i < ret.getNumOperands(); i++)
{
auto outputValue = ret.getOperand(i);
......@@ -280,9 +288,9 @@ namespace
void DialectLoweringPass::processFakeInstrs()
{
auto context = getModule().getContext();
auto f = getModule().getNamedFunction("main");
mlir::Block* entryBlock = &*(f->begin());
auto oldFuncType = f->getType();
auto f = getModule().lookupSymbol<mlir::FuncOp>("main");
mlir::Block* entryBlock = &*(f.begin());
auto oldFuncType = f.getType();
ArrayRef<mlir::Type> ipArgs = oldFuncType.getInputs();
ArrayRef<mlir::Type> opArgs = oldFuncType.getResults();
SmallVector<mlir::Type, 4> allArgs;
......@@ -304,7 +312,7 @@ namespace
entryBlock->addArgument(indexType);
// update type
auto newFuncType = mlir::FunctionType::get(allArgs, {}, context);
f->setType(newFuncType);
f.setType(newFuncType);
// RAUW fake outputs with result values
unsigned i = 0;
......@@ -327,13 +335,13 @@ namespace
/// by nGraph op semantics.
void DialectLoweringPass::insertNoAliasArgAttrs()
{
auto func = getModule().getNamedFunction("main");
auto func = getModule().lookupSymbol<mlir::FuncOp>("main");
unsigned int argIdx = 0;
for (auto* arg : func->getArguments())
for (auto* arg : func.getArguments())
{
if (arg->getType().isa<MemRefType>())
{
func->setArgAttr(argIdx, "llvm.noalias", BoolAttr::get(true, &getContext()));
func.setArgAttr(argIdx, "llvm.noalias", BoolAttr::get(true, &getContext()));
}
++argIdx;
......@@ -348,21 +356,19 @@ namespace
}
}
mlir::Function* DialectLoweringPass::getCallDecl(StringRef name,
ArrayRef<Type> args,
ArrayRef<Type> output,
PatternRewriter& rewriter)
mlir::FuncOp DialectLoweringPass::getCallDecl(StringRef name,
ArrayRef<Type> args,
ArrayRef<Type> output,
PatternRewriter& rewriter)
{
auto callBackFuncPtr = getModule().getNamedFunction(name);
if (callBackFuncPtr == nullptr)
auto callBackFunc = getModule().lookupSymbol<mlir::FuncOp>(name);
if (!callBackFunc)
{
auto callBackType = rewriter.getFunctionType(args, output);
auto callBackFunc =
llvm::make_unique<mlir::Function>(rewriter.getUnknownLoc(), name, callBackType);
callBackFuncPtr = callBackFunc.get();
getModule().getFunctions().push_back(callBackFunc.release());
auto callBackFunc = mlir::FuncOp::create(rewriter.getUnknownLoc(), name, callBackType);
getModule().push_back(callBackFunc);
}
return callBackFuncPtr;
return callBackFunc;
}
// NGDialect converters
......@@ -394,15 +400,15 @@ namespace
return mlir::IntegerType::get(1 /* width */, boolType.getContext());
}
NGRAPH_CHECK(false, "Unsupported type to lower");
// Do not assert/NGRAPH_CHECK here. Type convertion infra expects `convertType` to return
// the input type if the type is not supported.
return type;
}
#define REWRITER(OP) \
PatternMatchResult OP##Conversion::matchAndRewrite( \
Operation* op, ArrayRef<Value*> operands, PatternRewriter& rewriter) const
Operation* op, ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) const
// ADD
REWRITER(NGAddOp)
{
lower_binary_elementwise<mlir::NGAddOp>(op, operands, rewriter, pass);
......
......@@ -478,6 +478,8 @@ set (SRC
shape.hpp
shape_util.cpp
shape_util.hpp
slice_plan.cpp
slice_plan.hpp
specialize_function.cpp
specialize_function.hpp
state/rng_state.cpp
......
......@@ -151,59 +151,58 @@ namespace ngraph
/// Return a pointer to the node that produces the wrapped value.
/// If no additional reshape or broadcast op was needed, simply return \p node.
static std::shared_ptr<Node>
add_required_ops(const std::shared_ptr<Node>& node,
const ngraph::Shape& node_shape_after_possible_reshaping,
const ngraph::AxisSet& node_broadcast_axes,
const ngraph::Shape& node_final_shape)
add_required_ops(const Output<Node>& value,
const ngraph::Shape& shape_after_possible_reshaping,
const ngraph::AxisSet& broadcast_axes,
const ngraph::Shape& final_shape)
{
std::shared_ptr<Node> return_node{node};
Output<Node> return_value{value};
if (node->get_shape() != node_shape_after_possible_reshaping)
if (value.get_shape() != shape_after_possible_reshaping)
{
// tell reshape to examine input dimensions in order
ngraph::AxisVector order = ngraph::get_default_order(node->get_shape());
return_node = std::make_shared<ngraph::op::Reshape>(
return_node, order, node_shape_after_possible_reshaping);
ngraph::AxisVector order = ngraph::get_default_order(value.get_shape());
return_value = std::make_shared<ngraph::op::Reshape>(
return_value, order, shape_after_possible_reshaping);
}
if (node_final_shape != node_shape_after_possible_reshaping)
if (final_shape != shape_after_possible_reshaping)
{
return_node = std::make_shared<ngraph::op::Broadcast>(
return_node, node_final_shape, node_broadcast_axes);
return_value = std::make_shared<ngraph::op::Broadcast>(
return_value, final_shape, broadcast_axes);
}
return return_node;
return return_value.get_node_shared_ptr();
}
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>
numpy_broadcast(const std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>& args)
numpy_broadcast(const std::pair<Output<Node>, Output<Node>>& args)
{
NGRAPH_CHECK(args.first);
NGRAPH_CHECK(args.second);
NGRAPH_CHECK(args.first.get_node());
NGRAPH_CHECK(args.second.get_node());
const ngraph::Shape& arg1_in_shape = args.first->get_shape();
const ngraph::Shape& arg2_in_shape = args.second->get_shape();
const ngraph::Shape& arg1_in_shape = args.first.get_shape();
const ngraph::Shape& arg2_in_shape = args.second.get_shape();
// Handle the trivial case...
if (arg1_in_shape == arg2_in_shape)
{
return args;
return make_pair(args.first.as_single_output_node(),
args.second.as_single_output_node());
}
Autobroadcast_plan plan =
compute_shapes_and_broadcast_axes(arg1_in_shape, arg2_in_shape);
std::shared_ptr<Node> arg1_out =
add_required_ops(args.first,
plan.m_arg1_shape_after_possible_reshaping,
plan.m_arg1_broadcast_axes,
plan.m_final_shape);
std::shared_ptr<Node> arg2_out =
add_required_ops(args.second,
plan.m_arg2_shape_after_possible_reshaping,
plan.m_arg2_broadcast_axes,
plan.m_final_shape);
auto arg1_out = add_required_ops(args.first,
plan.m_arg1_shape_after_possible_reshaping,
plan.m_arg1_broadcast_axes,
plan.m_final_shape);
auto arg2_out = add_required_ops(args.second,
plan.m_arg2_shape_after_possible_reshaping,
plan.m_arg2_broadcast_axes,
plan.m_final_shape);
return {arg1_out, arg2_out};
}
......
......@@ -42,7 +42,7 @@ namespace ngraph
static std::string error_str(const ngraph::Shape& shape1, const ngraph::Shape& shape2);
};
/// \brief Wrap two graph nodes, if necessary, to obtain values with identical shapes,
/// \brief Wrap two graph values, if necessary, to obtain values with identical shapes,
/// using NumPy's auto-broadcast rules.
///
/// The elements in the std::pair returned by this function correspond to those supplied
......@@ -71,7 +71,7 @@ namespace ngraph
///
/// \exception ngraph::builder::autobroadcast_incompatible_shapes
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>
numpy_broadcast(const std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>& args);
numpy_broadcast(const std::pair<Output<Node>, Output<Node>>& args);
/// Create a new \p NodeType node, and any additional nodes required to simulate NumPy-style autobroadcast
/// semantics. Intended for binary operations such as "Add".
......@@ -87,11 +87,10 @@ namespace ngraph
/// \exception ngraph::builder::autobroadcast_incompatible_shapes
template <typename NodeType>
std::shared_ptr<NodeType>
make_with_numpy_broadcast(const std::shared_ptr<Node>& operand1_reshapeable,
const std::shared_ptr<Node>& operand2_reshapeable)
make_with_numpy_broadcast(const Output<Node>& operand1_reshapeable,
const Output<Node>& operand2_reshapeable)
{
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> shaped_op1_op2 =
numpy_broadcast({operand1_reshapeable, operand2_reshapeable});
auto shaped_op1_op2 = numpy_broadcast({operand1_reshapeable, operand2_reshapeable});
return std::make_shared<NodeType>(shaped_op1_op2.first, shaped_op1_op2.second);
}
......@@ -112,16 +111,13 @@ namespace ngraph
///
/// \exception ngraph::builder::autobroadcast_incompatible_shapes
template <typename NodeType>
std::shared_ptr<NodeType>
make_with_numpy_broadcast(const std::shared_ptr<Node>& operand1,
const std::shared_ptr<Node>& operand2_reshapeable,
const std::shared_ptr<Node>& operand3_reshapeable)
std::shared_ptr<Node> make_with_numpy_broadcast(const Output<Node>& operand1,
const Output<Node>& operand2_reshapeable,
const Output<Node>& operand3_reshapeable)
{
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> shaped_op2_op3 =
numpy_broadcast({operand2_reshapeable, operand3_reshapeable});
auto shaped_op2_op3 = numpy_broadcast({operand2_reshapeable, operand3_reshapeable});
return std::make_shared<NodeType>(
operand1, shaped_op2_op3.first, shaped_op2_op3.second);
}
} // namespace builder
} // namespace ngraph
......@@ -34,18 +34,18 @@ namespace ngraph
{
namespace detail
{
shared_ptr<Node> lp_norm(const shared_ptr<Node>& node,
shared_ptr<Node> lp_norm(const Output<Node>& value,
size_t p_norm,
const AxisSet& reduction_axes,
float bias)
{
// In general "entrywise" lp-norm for matrix `A` is defined as following double sum:
// ||A||_p = ||vec(A)||_p = [sum_{i=1}^m sum_{j=1}^n abs(a_{i,j})^p]^{1/p}
shared_ptr<Node> abs_values{make_shared<op::Abs>(node)};
shared_ptr<Node> abs_values{make_shared<op::Abs>(value)};
shared_ptr<Node> p_node = op::Constant::create(
node->get_element_type(),
node->get_shape(),
vector<float>(shape_size(node->get_shape()), static_cast<float>(p_norm)));
value.get_element_type(),
value.get_shape(),
vector<float>(shape_size(value.get_shape()), static_cast<float>(p_norm)));
// Get inner part of equation: abs_values^p_node, then sum over reduction_axes.
shared_ptr<Node> values{make_shared<op::Power>(abs_values, p_node)};
......@@ -68,26 +68,26 @@ namespace ngraph
}
}
shared_ptr<Node> l0_norm(const shared_ptr<Node>& node, const AxisSet& reduction_axes)
shared_ptr<Node> l0_norm(const Output<Node>& value, const AxisSet& reduction_axes)
{
// L0 norm returns number of elements different from zero.
shared_ptr<Node> zero_node{
op::Constant::create(node->get_element_type(),
node->get_shape(),
vector<float>(shape_size(node->get_shape()), 0.f))};
op::Constant::create(value.get_element_type(),
value.get_shape(),
vector<float>(shape_size(value.get_shape()), 0.f))};
// Convert bool values to input node data type.
shared_ptr<Node> non_zero_values = make_shared<op::Convert>(
make_shared<op::NotEqual>(node, zero_node), node->get_element_type());
make_shared<op::NotEqual>(value, zero_node), value.get_element_type());
return make_shared<op::Sum>(non_zero_values, reduction_axes);
}
shared_ptr<Node>
l1_norm(const shared_ptr<Node>& node, const AxisSet& reduction_axes, float bias)
l1_norm(const Output<Node>& value, const AxisSet& reduction_axes, float bias)
{
shared_ptr<Node> values{
make_shared<op::Sum>(make_shared<op::Abs>(node), reduction_axes)};
make_shared<op::Sum>(make_shared<op::Abs>(value), reduction_axes)};
shared_ptr<Node> bias_node{
op::Constant::create(values->get_element_type(),
......@@ -98,9 +98,9 @@ namespace ngraph
}
shared_ptr<Node>
l2_norm(const shared_ptr<Node>& node, const AxisSet& reduction_axes, float bias)
l2_norm(const Output<Node>& value, const AxisSet& reduction_axes, float bias)
{
shared_ptr<Node> values{make_shared<op::Sum>(node * node, reduction_axes)};
shared_ptr<Node> values{make_shared<op::Sum>(value * value, reduction_axes)};
shared_ptr<Node> bias_node{
op::Constant::create(values->get_element_type(),
......@@ -110,7 +110,7 @@ namespace ngraph
return {make_shared<op::Sqrt>(values + bias_node)};
}
shared_ptr<Node> lp_norm(const shared_ptr<Node>& node,
shared_ptr<Node> lp_norm(const Output<Node>& value,
const AxisSet& reduction_axes,
size_t p_norm,
float bias)
......@@ -118,22 +118,22 @@ namespace ngraph
// The number of non-zero elements
if (p_norm == 0)
{
return l0_norm(node, reduction_axes);
return l0_norm(value, reduction_axes);
}
// sum of absolute values.
else if (p_norm == 1)
{
return l1_norm(node, reduction_axes, bias);
return l1_norm(value, reduction_axes, bias);
}
// sqrt of sum of squares - Euclidean norm
else if (p_norm == 2)
{
return l2_norm(node, reduction_axes, bias);
return l2_norm(value, reduction_axes, bias);
}
// generic case
else
{
return detail::lp_norm(node, p_norm, reduction_axes, bias);
return detail::lp_norm(value, p_norm, reduction_axes, bias);
}
}
......
......@@ -25,62 +25,57 @@ namespace ngraph
{
namespace builder
{
/// \brief Creates node which calculates L-0 norm of input tensor.
/// \brief Calculates L-0 norm of input tensor.
///
/// \note The L-0 norm represents the cardinality of elements different
/// from zero. This actually is not a "true" norm.
///
/// \param[in] node The input tensor node.
/// \param[in] value The input tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
///
/// \return Node which calculates L-0 norm values.
/// \return L-0 norm of value.
///
std::shared_ptr<Node> l0_norm(const std::shared_ptr<Node>& node,
const AxisSet& reduction_axes);
std::shared_ptr<Node> l0_norm(const Output<Node>& value, const AxisSet& reduction_axes);
/// \brief Creates node which calculates L-1 norm of input tensor.
/// \brief Calculates L-1 norm of a value.
///
/// \note The L-1 norm represents the sum of absolute values.
///
/// \param[in] node The input tensor node.
/// \param[in] value The input tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
/// \param[in] bias The bias added to the calculated sum.
///
/// \return Node which calculates L-1 norm values.
/// \return L-1 norm of value.
///
std::shared_ptr<Node> l1_norm(const std::shared_ptr<Node>& node,
const AxisSet& reduction_axes,
float bias = 0.f);
std::shared_ptr<Node>
l1_norm(const Output<Node>& value, const AxisSet& reduction_axes, float bias = 0.f);
/// \brief Calculates L-2 norm of input tensor.
///
/// \note The L-2 norm represents the square root of sum of squares of each
/// individual element.
///
/// \param[in] node The input tensor node.
/// \param[in] value The input tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
/// \param[in] bias The bias added to the calculated sum.
///
/// \return Node which calculates L-2 norm values.
/// \return L-2 norm of value.
///
std::shared_ptr<Node> l2_norm(const std::shared_ptr<Node>& node,
const AxisSet& reduction_axes,
float bias = 0.f);
std::shared_ptr<Node>
l2_norm(const Output<Node>& value, const AxisSet& reduction_axes, float bias = 0.f);
/// \brief Creates node which calculates L-p norm on input tensor.
///
/// \param[in] node The input nGraph tensor.
/// \param[in] value The input tensor.
/// \param[in] reduction_axes The axes along which we calculate norm.
/// \param[in] p_norm The p norm to calculate.
/// \param[in] bias The bias added to the calculated sum.
///
/// \return Node which calculates L-p norm.
/// \return L-p norm of value.
///
std::shared_ptr<Node> lp_norm(const std::shared_ptr<Node>& node,
std::shared_ptr<Node> lp_norm(const Output<Node>& value,
const AxisSet& reduction_axes,
std::size_t p_norm = 2,
float bias = 0.f);
} // namespace builder
} // namespace ngraph
......@@ -37,9 +37,9 @@ namespace ngraph
namespace builder
{
std::shared_ptr<Node> numpy_transpose(const std::shared_ptr<Node>& node, AxisVector order)
std::shared_ptr<Node> numpy_transpose(const Output<Node>& value, AxisVector order)
{
auto in_shape = node->get_shape();
auto in_shape = value.get_shape();
// default, reverse the order of the axes
if (order.size() == 0)
{
......@@ -74,7 +74,7 @@ namespace ngraph
out_shape.push_back(in_shape[order[i]]);
// do the reshaping with the order
return std::make_shared<ngraph::op::Reshape>(node, order, out_shape);
return std::make_shared<ngraph::op::Reshape>(value, order, out_shape);
}
} // namespace builder
......
......@@ -45,7 +45,6 @@ namespace ngraph
/// | Type | Description |
/// | ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d_{n-1},\dots,d_0)]\textit{ or }E[d_{order[0]},\dots,d_{order[n-1]}]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the axes reordered via Numpy Transpose rules |
std::shared_ptr<Node> numpy_transpose(const std::shared_ptr<Node>& node,
AxisVector order = {});
std::shared_ptr<Node> numpy_transpose(const Output<Node>& value, AxisVector order = {});
} // namespace builder
} // namespace ngraph
......@@ -42,8 +42,7 @@ namespace ngraph
return N;
}
std::shared_ptr<Node> l2_norm(const std::shared_ptr<Node>& node,
const AxisSet& reduction_axes)
std::shared_ptr<Node> l2_norm(const Output<Node>& node, const AxisSet& reduction_axes)
{
auto x2 = node * node;
auto x2sum = std::make_shared<op::Sum>(x2, reduction_axes);
......@@ -51,19 +50,19 @@ namespace ngraph
return std::make_shared<op::Sqrt>(x2sum);
}
std::shared_ptr<Node> mean(const std::shared_ptr<Node>& node, const AxisSet& reduction_axes)
std::shared_ptr<Node> mean(const Output<Node>& value, const AxisSet& reduction_axes)
{
auto xsum = std::make_shared<op::Sum>(node, reduction_axes);
auto xsum = std::make_shared<op::Sum>(value, reduction_axes);
auto N = get_num_elements(node->get_shape(), reduction_axes);
const auto& et = node->get_element_type();
auto N = get_num_elements(value.get_shape(), reduction_axes);
const auto& et = value.get_element_type();
auto divisor = op::Constant::create(et, xsum->get_shape(), {N});
return xsum / divisor;
}
std::shared_ptr<Node> std_dev(const std::shared_ptr<Node>& node,
std::shared_ptr<Node> std_dev(const Output<Node>& node,
const AxisSet& reduction_axes,
const bool bessel_correction)
{
......@@ -74,13 +73,13 @@ namespace ngraph
// The second might be more numerically stable/easier to pattern match
// It also requires adding a broadcast op, and would probably be slower
// TODO(mbrookhart): Switch to E[(X-\mu)^2]?
std::shared_ptr<Node> variance(const std::shared_ptr<Node>& node,
std::shared_ptr<Node> variance(const Output<Node>& value,
const AxisSet& reduction_axes,
const bool bessel_correction)
{
std::shared_ptr<Node> mu = mean(node, reduction_axes);
std::shared_ptr<Node> mu = mean(value, reduction_axes);
auto reshape = node->get_shape();
auto reshape = value.get_shape();
for (auto i : reduction_axes)
{
reshape[i] = 1;
......@@ -90,21 +89,21 @@ namespace ngraph
mu = std::make_shared<op::Reshape>(mu, order, reshape);
std::shared_ptr<Node> diff = make_with_numpy_broadcast<op::Subtract>(node, mu);
Output<Node> diff = make_with_numpy_broadcast<op::Subtract>(value, mu);
diff = std::make_shared<op::Sum>(diff * diff, reduction_axes);
const auto& et = node->get_element_type();
auto N = get_num_elements(node->get_shape(), reduction_axes);
const auto& et = value.get_element_type();
auto N = get_num_elements(value.get_shape(), reduction_axes);
if (bessel_correction)
{
auto N1const = op::Constant::create(et, diff->get_shape(), {N - 1});
auto N1const = op::Constant::create(et, diff.get_shape(), {N - 1});
return diff / N1const;
}
else
{
auto Nconst = op::Constant::create(et, diff->get_shape(), {N});
auto Nconst = op::Constant::create(et, diff.get_shape(), {N});
return diff / Nconst;
}
}
......
......@@ -35,7 +35,7 @@ namespace ngraph
///
/// | | Type | Description |
/// | ---------------- | --------------------------------- | ----------------------------------------------------------------------------------------------------- |
/// | `node` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | An input tensor of any shape
/// | `value` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | An input tensor of any shape
/// | `reduction_axes` | AxesSet | The axes to eliminate through reduction (0 indexed). |
///
/// ## Output
......@@ -43,8 +43,7 @@ namespace ngraph
/// | Type | Description |
/// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
/// | \f$E[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by reduction. |
std::shared_ptr<Node> l2_norm(const std::shared_ptr<Node>& node,
const AxisSet& reduction_axes);
std::shared_ptr<Node> l2_norm(const Output<Node>& value, const AxisSet& reduction_axes);
/// \brief Sum-based Mean of a Tensor.
///
......@@ -66,8 +65,7 @@ namespace ngraph
/// | Type | Description |
/// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
/// | \f$E[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by reduction. |
std::shared_ptr<Node> mean(const std::shared_ptr<Node>& node,
const AxisSet& reduction_axes);
std::shared_ptr<Node> mean(const Output<Node>& node, const AxisSet& reduction_axes);
/// \brief Sum-based Standard Deviation of a Tensor.
///
......@@ -85,7 +83,7 @@ namespace ngraph
///
/// | | Type | Description |
/// | ------------------- | --------------------------------- | ----------------------------------------------------------------------------------------------------- |
/// | `node` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | An input tensor of any shape
/// | `value` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | An input tensor of any shape
/// | `reduction_axes` | AxesSet | The axes to eliminate through reduction (0 indexed). |
/// | `bessel_correction` | bool (default = false) | Enable Bessel's correction to std_dev for Small sample sizes |
///
......@@ -94,7 +92,7 @@ namespace ngraph
/// | Type | Description |
/// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
/// | \f$E[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by reduction. |
std::shared_ptr<Node> std_dev(const std::shared_ptr<Node>& node,
std::shared_ptr<Node> std_dev(const Output<Node>& value,
const AxisSet& reduction_axes,
const bool bessel_correction = false);
......@@ -114,7 +112,7 @@ namespace ngraph
///
/// | | Type | Description |
/// | ------------------- | --------------------------------- | ----------------------------------------------------------------------------------------------------- |
/// | `node` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | An input tensor of any shape
/// | `value | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | An input tensor of any shape
/// | `reduction_axes` | AxesSet | The axes to eliminate through reduction (0 indexed). |
/// | `bessel_correction` | bool (default = false) | Enable Bessel's correction to std_dev for Small sample sizes |
///
......@@ -123,7 +121,7 @@ namespace ngraph
/// | Type | Description |
/// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
/// | \f$E[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by reduction. |
std::shared_ptr<Node> variance(const std::shared_ptr<Node>& node,
std::shared_ptr<Node> variance(const Output<Node>& value,
const AxisSet& reduction_axes,
const bool bessel_correction = false);
......
......@@ -27,14 +27,14 @@
using namespace ngraph;
using namespace std;
shared_ptr<Node> builder::reshape(const shared_ptr<Node>& node, const Shape& shape)
shared_ptr<Node> builder::reshape(const Output<Node>& value, const Shape& shape)
{
return make_shared<op::Reshape>(node, get_default_order(node->get_shape().size()), shape);
return make_shared<op::Reshape>(value, get_default_order(value.get_shape().size()), shape);
}
shared_ptr<Node> builder::reorder_axes(const shared_ptr<Node>& node, vector<size_t> axes_order = {})
shared_ptr<Node> builder::reorder_axes(const Output<Node>& value, vector<size_t> axes_order)
{
Shape out_shape = node->get_shape();
Shape out_shape = value.get_shape();
if (axes_order.empty())
{
axes_order.resize(out_shape.size());
......@@ -44,25 +44,25 @@ shared_ptr<Node> builder::reorder_axes(const shared_ptr<Node>& node, vector<size
{
for (size_t i = 0; i < axes_order.size(); ++i)
{
out_shape[i] = node->get_shape().at(axes_order.at(i));
out_shape[i] = value.get_shape().at(axes_order.at(i));
}
}
auto axis_vector = AxisVector{begin(axes_order), end(axes_order)};
return make_shared<op::Reshape>(node, axis_vector, out_shape);
return make_shared<op::Reshape>(value, axis_vector, out_shape);
}
shared_ptr<Node> builder::transpose(const shared_ptr<Node>& node)
shared_ptr<Node> builder::transpose(const Output<Node>& value)
{
vector<size_t> axes_order(node->get_shape().size());
vector<size_t> axes_order(value.get_shape().size());
iota(begin(axes_order), end(axes_order), 0);
reverse(begin(axes_order), end(axes_order));
return builder::reorder_axes(node, axes_order);
return builder::reorder_axes(value, axes_order);
}
shared_ptr<Node> builder::flatten(const shared_ptr<Node>& node, int axis)
shared_ptr<Node> builder::flatten(const Output<Node>& value, int axis)
{
auto data_shape = node->get_shape();
auto data_shape = value.get_shape();
// First dimension of output tensor is the product of [d_0, ... d_{axis-1}] dimensions of input tensor.
// The last dimension is the product of the rest of input tensor dimensions: [d_{axis}, ..., d_n]
......@@ -73,5 +73,5 @@ shared_ptr<Node> builder::flatten(const shared_ptr<Node>& node, int axis)
accumulate(next(begin(data_shape), axis), end(data_shape), 1UL, multiplies<size_t>());
return make_shared<op::Reshape>(
node, get_default_order(data_shape.size()), Shape{first_dim_size, last_dim_size});
value, get_default_order(data_shape.size()), Shape{first_dim_size, last_dim_size});
}
......@@ -27,37 +27,37 @@ namespace ngraph
{
namespace builder
{
/// \brief Change shape of input tensor.
/// \brief Change shape of a value
///
/// \param[in] node The node producing the tensor to be reshaped.
/// \param[in] shape The new shape for input tensor.
/// \param[in] value The value to be reshaped.
/// \param[in] shape The new shape.
///
/// \return The node representing a Reshape operation.
/// \return The reshaped value.
///
std::shared_ptr<Node> reshape(const std::shared_ptr<Node>& node, const Shape& shape);
std::shared_ptr<Node> reshape(const Output<Node>& value, const Shape& shape);
/// \brief Permute axes according to specified axes_order parameter.
///
/// \param node The node which axes we want to permute.
/// \param axes_order The permutation of node tensor axes.
/// \param value The vlaue whose axes we want to permute.
/// \param axes_order The permutation of axes.
///
/// \return: New node with permuted axes.
std::shared_ptr<Node> reorder_axes(const std::shared_ptr<Node>& node,
std::vector<std::size_t> axes_order);
/// \return: Value with permuted axes.
std::shared_ptr<Node> reorder_axes(const Output<Node>& value,
std::vector<size_t> axes_order = {});
/// \brief Return transposed tensor (with axes in reversed order).
/// \brief Return transposed vlaue (with axes in reversed order).
///
/// \param node Input tensor we want to transpose
/// \param value Value to transpose.
///
/// \return: New node with reversed dimensions.
std::shared_ptr<Node> transpose(const std::shared_ptr<Node>& node);
/// \return: Value with reversed dimensions.
std::shared_ptr<Node> transpose(const Output<Node>& value);
/// \brief Flatten the input tensor into a 2D matrix.
/// \brief Flatten a value into a 2D matrix.
///
/// \param node The tensor to be flattened.
/// \param value The tensor to be flattened.
/// \param axis The axis dividing shape.
///
/// \return The new node will be a 2D matrix representing the flattened input node.
std::shared_ptr<Node> flatten(const std::shared_ptr<Node>& node, int axis);
/// \return The new value will be a 2D matrix representing the flattened input node.
std::shared_ptr<Node> flatten(const Output<Node>& value, int axis);
} // namespace builder
} // namespace ngraph
......@@ -21,31 +21,31 @@ using namespace ngraph;
namespace
{
inline std::size_t get_valid_array_index(std::size_t idx, std::size_t axis_size)
inline size_t get_valid_array_index(size_t idx, size_t axis_size)
{
return std::min(idx, axis_size);
}
std::shared_ptr<op::Slice> make_ng_slice(const std::shared_ptr<ngraph::Node>& node,
const std::vector<std::size_t>& axes,
const std::vector<std::size_t>& starts,
const std::vector<std::size_t>& ends)
std::shared_ptr<op::Slice> make_ng_slice(const Output<Node>& output,
const std::vector<size_t>& axes,
const std::vector<size_t>& starts,
const std::vector<size_t>& ends)
{
std::vector<std::size_t> upper_bounds{node->get_shape()};
std::vector<std::size_t> lower_bounds(upper_bounds.size());
for (std::size_t index{0}; index < axes.size(); ++index)
std::vector<size_t> upper_bounds{output.get_shape()};
std::vector<size_t> lower_bounds(upper_bounds.size());
for (size_t index{0}; index < axes.size(); ++index)
{
std::size_t axis{axes.at(index)};
size_t axis{axes.at(index)};
lower_bounds.at(axis) =
get_valid_array_index(starts.at(index), node->get_shape().at(axis));
get_valid_array_index(starts.at(index), output.get_shape().at(axis));
upper_bounds.at(axis) =
get_valid_array_index(ends.at(index), node->get_shape().at(axis));
get_valid_array_index(ends.at(index), output.get_shape().at(axis));
}
return std::make_shared<op::Slice>(node, lower_bounds, upper_bounds);
return std::make_shared<op::Slice>(output, lower_bounds, upper_bounds);
}
}
NodeVector builder::split(const std::shared_ptr<ngraph::Node>& node,
NodeVector builder::split(const Output<ngraph::Node>& value,
const std::vector<size_t>& length_parts,
size_t axis)
{
......@@ -54,21 +54,21 @@ NodeVector builder::split(const std::shared_ptr<ngraph::Node>& node,
for (const auto& length_part : length_parts)
{
size_t end_index{start_index + length_part};
outputs.push_back(make_ng_slice(node, {axis}, {start_index}, {end_index}));
outputs.push_back(make_ng_slice(value, {axis}, {start_index}, {end_index}));
start_index = end_index;
}
return outputs;
}
NodeVector builder::split(const std::shared_ptr<ngraph::Node>& node, size_t split_parts, int axis)
NodeVector builder::split(const Output<Node>& value, size_t split_parts, int axis)
{
size_t axis_to_split{static_cast<size_t>(axis)};
if (axis < 0)
{
axis_to_split = node->get_shape().size() + axis;
axis_to_split = value.get_shape().size() + axis;
}
size_t length_axis_to_split{node->get_shape().at(axis_to_split)};
size_t length_axis_to_split{value.get_shape().at(axis_to_split)};
std::vector<size_t> length_parts(split_parts, length_axis_to_split / split_parts);
return split(node, length_parts, axis_to_split);
return split(value, length_parts, axis_to_split);
}
......@@ -23,22 +23,22 @@ namespace ngraph
{
namespace builder
{
/// \brief Split node on specified axis into multiple parts.
/// \brief Split value on specified axis into multiple parts.
///
/// \param[in] node The input node.
/// \param[in] value The value to be split.
/// \param[in] length_parts The vector defining the lengths of each split part.
/// \param[in] axis The axis we split input node on. Default value is zero axis.
///
/// \return The vector containing multiple nodes we split input node into.
///
NodeVector split(const std::shared_ptr<ngraph::Node>& node,
const std::vector<std::size_t>& length_parts,
std::size_t axis = 0);
NodeVector split(const Output<Node>& value,
const std::vector<size_t>& length_parts,
size_t axis = 0);
/// \brief Split node on specified axis into multiple parts.
///
/// \param[in] node The input node.
/// \param[in] split_parts The number of parts we want to split input node at given
/// \param[in] value The value to split.
/// \param[in] split_parts The number of parts we want to split output at given
/// axis. The length of the axis to split must be divisible by
/// this value.
/// \param[in] axis The axis we split input node on. Default value is zero axis.
......@@ -49,7 +49,6 @@ namespace ngraph
///
/// \return The vector containing multiple nodes we split input node into.
///
NodeVector
split(const std::shared_ptr<ngraph::Node>& node, std::size_t split_parts, int axis = 0);
NodeVector split(const Output<Node>& value, size_t split_parts, int axis = 0);
} // namespace builder
} // namespace ngraph
......@@ -14,17 +14,16 @@
// limitations under the License.
//*****************************************************************************
#include "null_node.hpp"
#include <string>
#include "ngraph/node.hpp"
#include "null_node.hpp"
namespace ngraph
{
namespace onnx_import
{
NullNode::NullNode()
: ngraph::Node("NullNode", {}, 0)
{
}
const std::string NullNode::type_name{"NullNode"};
std::shared_ptr<Node> NullNode::copy_with_new_args(const NodeVector& new_args) const
{
......
......@@ -36,7 +36,10 @@ namespace ngraph
class NullNode : public ngraph::Node
{
public:
NullNode();
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
NullNode() = default;
bool is_null() const final override { return true; }
virtual std::shared_ptr<Node>
......
......@@ -22,8 +22,10 @@
using namespace std;
using namespace ngraph;
op::Clamp::Clamp(const shared_ptr<Node>& data, const double min, const double max)
: FusedOp("Clamp", {data})
const string op::Clamp::type_name{"Clamp"};
op::Clamp::Clamp(const Output<Node>& data, const double min, const double max)
: FusedOp({data})
, m_min{min}
, m_max{max}
{
......
......@@ -32,12 +32,15 @@ namespace ngraph
class Clamp : public ngraph::op::util::FusedOp
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a Clamp node.
///
/// \param data - Node producing the input tensor
/// \param min - the lower bound of the <min;max> range
/// \param max - the upper bound of the <min;max> range
Clamp(const std::shared_ptr<ngraph::Node>& data, const double min, const double max);
Clamp(const Output<ngraph::Node>& data, const double min, const double max);
void pre_validate_and_infer_types() override;
......
......@@ -74,7 +74,8 @@ NodeVector op::Gemm::decompose_op() const
C = std::make_shared<ngraph::op::Multiply>(beta_node, C);
// alpha * A' * B' + beta * C
NodeVector broadcasted_nodes = ngraph::op::numpy_style_broadcast({a_dot_b, C});
OutputVector broadcasted_nodes =
ngraph::op::numpy_style_broadcast_values(OutputVector{a_dot_b, C});
// The input tensor `C` should be "unidirectionally broadcastable" to the `a_dot_b` tensor.
// Numpy style broadcast is bidirectional, so we only use the second output from broadcasting.
return {std::make_shared<ngraph::op::Add>(a_dot_b, broadcasted_nodes.at(1))};
......
......@@ -37,34 +37,34 @@ op::PRelu::PRelu(const shared_ptr<Node>& data, const shared_ptr<Node>& slope)
NodeVector op::PRelu::decompose_op() const
{
auto data = get_argument(0);
auto data_shape = data->get_shape();
auto slope = get_argument(1);
auto slope_shape = slope->get_shape();
auto data = input(0).get_source_output();
auto data_shape = data.get_shape();
auto slope = input(1).get_source_output();
auto slope_shape = slope.get_shape();
if ((slope_shape.size() == 1) && (slope_shape.at(0) != 1))
{
auto it = std::find(std::begin(data_shape), std::end(data_shape), slope_shape.at(0));
auto index = std::distance(std::begin(data_shape), it);
slope = make_broadcast_node(slope, data->get_shape(), index);
slope = make_broadcast_node(slope, data.get_shape(), index);
}
else if (data_shape != slope_shape)
{
slope = numpy_style_broadcast({slope, data})[0];
slope = numpy_style_broadcast_values({slope, data})[0];
}
// x < 0 => f(x) = x * slope
// x >= 0 => f(x) = x
std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data->get_shape());
data.get_element_type(), ngraph::Shape{}, std::vector<double>{0});
zero_node = make_broadcast_node(zero_node, data.get_shape());
std::shared_ptr<ngraph::Node> negative_map = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::Less>(data, zero_node), data->get_element_type());
std::make_shared<ngraph::op::Less>(data, zero_node), data.get_element_type());
std::shared_ptr<ngraph::Node> positive_map = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::Greater>(data, zero_node), data->get_element_type());
std::make_shared<ngraph::op::Greater>(data, zero_node), data.get_element_type());
slope = negative_map * slope + positive_map;
......
......@@ -31,12 +31,12 @@ op::ScaleShift::ScaleShift(const std::shared_ptr<ngraph::Node>& data,
NodeVector op::ScaleShift::decompose_op() const
{
auto data = get_argument(0);
auto scale = get_argument(1);
auto shift = get_argument(2);
auto data = input(0).get_source_output();
auto scale = input(1).get_source_output();
auto shift = input(2).get_source_output();
// broadcast all data
auto broadcasted_nodes = numpy_style_broadcast({data, scale, shift});
auto broadcasted_nodes = numpy_style_broadcast_values({data, scale, shift});
data = broadcasted_nodes[0];
scale = broadcasted_nodes[1];
shift = broadcasted_nodes[2];
......
......@@ -32,10 +32,10 @@ op::SquaredDifference::SquaredDifference(const shared_ptr<Node>& x1, const share
NodeVector op::SquaredDifference::decompose_op() const
{
const auto x1 = get_argument(0);
const auto x2 = get_argument(1);
const auto x1 = input(0).get_source_output();
const auto x2 = input(1).get_source_output();
const auto broadcasted = numpy_style_broadcast({x1, x2});
const auto broadcasted = numpy_style_broadcast_values({x1, x2});
const auto difference = broadcasted.at(0) - broadcasted.at(1);
......
......@@ -84,8 +84,7 @@ void op::Softmax::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
auto order = ngraph::get_default_order(zsum->get_shape());
auto zreshape = make_shared<op::Reshape>(zsum, order, shape);
auto adjoint =
z - builder::make_with_numpy_broadcast<op::Multiply>(shared_from_this(), zreshape);
auto adjoint = z - builder::make_with_numpy_broadcast<op::Multiply>(output(0), zreshape);
auto x = get_argument(0);
adjoints.add_delta(x, adjoint);
......
......@@ -104,6 +104,19 @@ static std::pair<ngraph::Shape, std::vector<ngraph::Shape>>
return get_numpy_broadcast_shapes(input_shapes);
}
static std::pair<ngraph::Shape, std::vector<ngraph::Shape>>
get_numpy_broadcast_shapes(const ngraph::OutputVector& values)
{
std::vector<ngraph::Shape> input_shapes;
for (const auto& input : values)
{
input_shapes.push_back(input.get_shape());
}
return get_numpy_broadcast_shapes(input_shapes);
}
/// \brief Broadcast input node.
///
/// \note The source shape does not have to be the actual shape of input node. However
......@@ -112,21 +125,21 @@ static std::pair<ngraph::Shape, std::vector<ngraph::Shape>>
/// The ranks of source_shape and output_shape must be equal. This means that the
/// source_shape has to be padded with ones for this operation.
///
/// \param[in] node The input Node to be broadcast.
/// \param[in] value The input Node to be broadcast.
/// \param[in] output_shape The output shape.
/// \param[in] source_shape The source shape from which we want to broadcast input node.
///
/// \return The broadcasted Node.
///
static std::shared_ptr<ngraph::Node>
broadcast_node_numpy_style(const std::shared_ptr<ngraph::Node>& node,
broadcast_node_numpy_style(const ngraph::Output<ngraph::Node>& value,
const ngraph::Shape& output_shape,
const ngraph::Shape& source_shape)
{
// If node already has the required shape, return original node
if (output_shape == node->get_shape())
if (output_shape == value.get_shape())
{
return node;
return value.as_single_output_node();
}
if (source_shape.size() != output_shape.size())
......@@ -153,16 +166,35 @@ static std::shared_ptr<ngraph::Node>
}
// Remove axes which have length of 1 from source shape
auto broadcasted_node = std::make_shared<ngraph::op::Reshape>(
node, ngraph::get_default_order(node->get_shape()), squeezed_shape);
ngraph::Output<ngraph::Node> broadcasted_value = std::make_shared<ngraph::op::Reshape>(
value, ngraph::get_default_order(value.get_shape()), squeezed_shape);
return std::make_shared<ngraph::op::Broadcast>(broadcasted_node, output_shape, broadcast_axes);
return std::make_shared<ngraph::op::Broadcast>(broadcasted_value, output_shape, broadcast_axes);
}
namespace ngraph
{
namespace op
{
OutputVector numpy_style_broadcast_values(const OutputVector& values)
{
if (values.size() <= 1)
{
return values;
}
// find the output tensor's shape, then broadcast all inputs so that they are compatible
auto bcast_shapes = get_numpy_broadcast_shapes(values);
OutputVector broadcasted_inputs;
for (std::size_t i = 0; i < values.size(); ++i)
{
broadcasted_inputs.push_back(broadcast_node_numpy_style(
values[i], bcast_shapes.first, bcast_shapes.second[i]));
}
return broadcasted_inputs;
}
NodeVector numpy_style_broadcast(const NodeVector& inputs)
{
if (inputs.size() <= 1)
......@@ -176,19 +208,17 @@ namespace ngraph
NodeVector broadcasted_inputs;
for (std::size_t i = 0; i < inputs.size(); ++i)
{
const std::shared_ptr<ngraph::Node> input_node = inputs[i];
broadcasted_inputs.push_back(broadcast_node_numpy_style(
inputs[i], bcast_shapes.first, bcast_shapes.second[i]));
}
return broadcasted_inputs;
}
std::shared_ptr<ngraph::Node>
numpy_style_broadcast(const std::shared_ptr<ngraph::Node>& input_node,
const Shape& shape)
std::shared_ptr<ngraph::Node> numpy_style_broadcast(const Output<ngraph::Node>& value,
const Shape& shape)
{
auto bcast_shape = get_numpy_broadcast_shapes({input_node->get_shape(), shape});
return broadcast_node_numpy_style(input_node, bcast_shape.first, bcast_shape.second[0]);
auto bcast_shape = get_numpy_broadcast_shapes({value.get_shape(), shape});
return broadcast_node_numpy_style(value, bcast_shape.first, bcast_shape.second[0]);
}
NodeVector
......@@ -227,6 +257,42 @@ namespace ngraph
broadcast_node_numpy_style(right, right_output_shape, right_full_shape)};
}
OutputVector
numpy_style_broadcast_values_for_matmul_operation(const Output<ngraph::Node>& left,
const Output<ngraph::Node>& right)
{
const auto& left_shape = left.get_shape();
const auto& right_shape = right.get_shape();
// Broadcast only _stack of matrices_ axes.
const auto& numpy_shapes = get_numpy_broadcast_shapes(
{Shape{std::begin(left_shape), std::next(std::end(left_shape), -2)},
Shape{std::begin(right_shape), std::next(std::end(right_shape), -2)}});
// Prepare tensors output shapes with broadcasted _stack of matrices_ axes.
auto left_output_shape = numpy_shapes.first;
auto right_output_shape = numpy_shapes.first;
// Append the last two axes original dimensions.
left_output_shape.insert(std::end(left_output_shape),
std::next(std::begin(left_shape), left_shape.size() - 2),
std::end(left_shape));
right_output_shape.insert(std::end(right_output_shape),
std::next(std::begin(right_shape), right_shape.size() - 2),
std::end(right_shape));
auto left_full_shape = numpy_shapes.second.at(0);
auto right_full_shape = numpy_shapes.second.at(1);
// Append the last two axes original dimensions.
left_full_shape.insert(std::end(left_full_shape),
std::next(std::begin(left_shape), left_shape.size() - 2),
std::end(left_shape));
right_full_shape.insert(std::end(right_full_shape),
std::next(std::begin(right_shape), right_shape.size() - 2),
std::end(right_shape));
return {broadcast_node_numpy_style(left, left_output_shape, left_full_shape),
broadcast_node_numpy_style(right, right_output_shape, right_full_shape)};
}
NodeVector
legacy_style_broadcast_for_binary_operation(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right,
......@@ -288,6 +354,67 @@ namespace ngraph
return {left, broadcast_right};
}
OutputVector
legacy_style_broadcast_values_for_binary_operation(const Output<ngraph::Node>& left,
const Output<ngraph::Node>& right,
size_t start_match_axis)
{
const auto& left_shape = left.get_shape();
const auto& right_shape = right.get_shape();
bool dimensions_identical = (left_shape == right_shape);
if (dimensions_identical)
{
return {left, right};
}
// Prepare new shape of right operand for broadcasting
// Remove dimensions with length=1 from back
auto new_right_shape = right_shape;
for (int dimension = new_right_shape.size() - 1; dimension >= 0; --dimension)
{
if (new_right_shape[dimension] == 1)
{
new_right_shape.pop_back();
}
else
{
break;
}
}
// Find first dimensions at front with length different from 1
std::size_t num_ones = 0;
for (std::size_t dimension : new_right_shape)
{
if (dimension == 1)
{
++num_ones;
}
else
{
break;
}
}
// Remove dimensions with length=1 from front
new_right_shape.erase(std::begin(new_right_shape),
std::next(std::begin(new_right_shape), num_ones));
auto reshape_right = std::make_shared<ngraph::op::Reshape>(
right, ngraph::get_default_order(right_shape), new_right_shape);
// Move broadcast start axis parameter to right
start_match_axis += num_ones;
auto broadcast_right = std::make_shared<ngraph::op::Broadcast>(
reshape_right,
left_shape,
calculate_broadcast_axes(left_shape, new_right_shape, start_match_axis));
return {left, broadcast_right};
}
AxisSet calculate_broadcast_axes(const Shape& output_shape,
const Shape& input_shape,
std::size_t start_match_axis)
......
......@@ -32,19 +32,45 @@ namespace ngraph
/// \param inputs Original list of inputs
///
/// \return Numpy-style broadcasted list of nodes.
NodeVector numpy_style_broadcast(const NodeVector& inputs);
NodeVector numpy_style_broadcast(const NodeVector& inputs)
NGRAPH_DEPRECATED("Replace with numpy_style_value_broadcast");
/// \brief Cast shape of an input node to the requested output shape using NumPy's broadcasting rules
/// \brief Cast shape of all input nodes for an element-wise operation that requires shape-compatibility
///
/// \param values Original list of inputs
///
/// \return Numpy-style broadcasted list of nodes.
OutputVector numpy_style_broadcast_values(const OutputVector& values);
/// \brief Cast shape of an output to the requested output shape using NumPy's broadcasting rules
///
/// \param input_node original input node
/// \param value original value
/// \param shape requested output shape
///
/// \return Broadcast node.
std::shared_ptr<ngraph::Node>
numpy_style_broadcast(const std::shared_ptr<ngraph::Node>& input_node,
const Shape& shape);
/// \return Broadcast output.
std::shared_ptr<Node> numpy_style_broadcast(const Output<Node>& value, const Shape& shape);
/// \brief Cast shape of two outputs to make them compatible for an element-wise binary operation.
///
/// If necessary the right-hand-side argument will be broadcast to match the shape
/// of left-hand-side argument. The starting of the mutually equal shape is
/// specified by the argument "start_match_axis", and if it is not set,
/// suffix matching is assumed.
///
/// This style of broadcast was used in ONNX Op sets prior to version 7, where it was
/// replaced by numpy-style broadcasting.
///
/// \param left Node which contain input of binary op.
/// \param right Node which contain input of binary op.
/// \param start_match_axis position in shape denoting start of the mutually equal shape
///
/// \return Left and right node after broadcasting.
NodeVector legacy_style_broadcast_for_binary_operation(const std::shared_ptr<Node>& left,
const std::shared_ptr<Node>& right,
size_t start_match_axis)
NGRAPH_DEPRECATED("Replace with legacy_style_value_broadcast_for_binary_operation");
/// \brief Cast shape of two nodes to make them compatible for an element-wise binary operation.
/// \brief Cast shape of two outputs to make them compatible for an element-wise binary operation.
///
/// If necessary the right-hand-side argument will be broadcast to match the shape
/// of left-hand-side argument. The starting of the mutually equal shape is
......@@ -59,10 +85,9 @@ namespace ngraph
/// \param start_match_axis position in shape denoting start of the mutually equal shape
///
/// \return Left and right node after broadcasting.
NodeVector
legacy_style_broadcast_for_binary_operation(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right,
std::size_t start_match_axis);
OutputVector legacy_style_broadcast_values_for_binary_operation(const Output<Node>& left,
const Output<Node>& right,
size_t start_match_axis);
/// \brief Broadcast shape of two nodes to make them compatible for a matrix multiplication.
///
......@@ -76,9 +101,24 @@ namespace ngraph
///
/// \return The vector containing both nodes broadcasted.
///
NodeVector
numpy_style_broadcast_for_matmul_operation(const std::shared_ptr<ngraph::Node>& left,
const std::shared_ptr<ngraph::Node>& right);
NodeVector numpy_style_broadcast_for_matmul_operation(const std::shared_ptr<Node>& left,
const std::shared_ptr<Node>& right)
NGRAPH_DEPRECATED("Replace with numpy_style_broadcast_value_for_matmul_operation.");
/// \brief Broadcast shape of two nodes to make them compatible for a matrix multiplication.
///
/// \note This function is reflecting broadcasting behaviour of NumPy's `matmul` operation
/// (https://docs.scipy.org/doc/numpy/reference/generated/numpy.matmul.html)
/// This mean that only \"stack of matrices\" axes are bidirectionally broadcasted.
/// The last two dimension are left untouched.
///
/// \param[in] left The Node providing data for the left-hand side of matrix multiplication.
/// \param[in] right The Node providing data for the right-hand side of matrix multiplication.
///
/// \return The vector containing both outputs broadcasted.
///
OutputVector numpy_style_broadcast_values_for_matmul_operation(const Output<Node>& left,
const Output<Node>& right);
/// \brief Generate a list of broadcast axes.
///
......@@ -118,22 +158,21 @@ namespace ngraph
output_shape, input_shape, output_shape.size() - input_shape.size());
}
inline std::shared_ptr<ngraph::Node>
make_broadcast_node(const std::shared_ptr<ngraph::Node>& node, ngraph::Shape new_shape)
inline std::shared_ptr<Node> make_broadcast_node(const Output<Node>& output,
Shape new_shape)
{
return std::make_shared<ngraph::op::Broadcast>(
node, new_shape, calculate_broadcast_axes(new_shape, node->get_shape()));
return std::make_shared<op::Broadcast>(
output, new_shape, calculate_broadcast_axes(new_shape, output.get_shape()));
}
inline std::shared_ptr<ngraph::Node>
make_broadcast_node(const std::shared_ptr<ngraph::Node>& node,
const ngraph::Shape& new_shape,
std::size_t start_match_axis)
inline std::shared_ptr<Node> make_broadcast_node(const Output<Node>& value,
const Shape& new_shape,
std::size_t start_match_axis)
{
return std::make_shared<ngraph::op::Broadcast>(
node,
return std::make_shared<op::Broadcast>(
value,
new_shape,
calculate_broadcast_axes(new_shape, node->get_shape(), start_match_axis));
calculate_broadcast_axes(new_shape, value.get_shape(), start_match_axis));
}
} // namespace op
} // namespace ngraph
......@@ -69,32 +69,29 @@ op::util::ActivationFunction op::util::RNNCellBase::get_activation_function(size
return afunc;
}
shared_ptr<Node> op::util::RNNCellBase::add(const shared_ptr<Node>& lhs,
const shared_ptr<Node>& rhs)
shared_ptr<Node> op::util::RNNCellBase::add(const Output<Node>& lhs, const Output<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
auto args = op::numpy_style_broadcast_values({lhs, rhs});
return {make_shared<op::Add>(args.at(0), args.at(1))};
}
shared_ptr<Node> op::util::RNNCellBase::sub(const shared_ptr<Node>& lhs,
const shared_ptr<Node>& rhs)
shared_ptr<Node> op::util::RNNCellBase::sub(const Output<Node>& lhs, const Output<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
auto args = op::numpy_style_broadcast_values({lhs, rhs});
return {make_shared<op::Subtract>(args.at(0), args.at(1))};
}
shared_ptr<Node> op::util::RNNCellBase::mul(const shared_ptr<Node>& lhs,
const shared_ptr<Node>& rhs)
shared_ptr<Node> op::util::RNNCellBase::mul(const Output<Node>& lhs, const Output<Node>& rhs)
{
auto args = op::numpy_style_broadcast({lhs, rhs});
auto args = op::numpy_style_broadcast_values({lhs, rhs});
return {make_shared<op::Multiply>(args.at(0), args.at(1))};
}
shared_ptr<Node> op::util::RNNCellBase::clip(const shared_ptr<Node>& data) const
shared_ptr<Node> op::util::RNNCellBase::clip(const Output<Node>& data) const
{
if (m_clip == 0.f)
{
return data;
return data.as_single_output_node();
}
return make_shared<op::Clamp>(data, -m_clip, m_clip);
......
......@@ -81,8 +81,7 @@ namespace ngraph
///
/// \return Node with element-wise add operation.
///
static std::shared_ptr<Node> add(const std::shared_ptr<Node>& lhs,
const std::shared_ptr<Node>& rhs);
static std::shared_ptr<Node> add(const Output<Node>& lhs, const Output<Node>& rhs);
///
/// \brief Creates node with element-wise subtract operation with numpy broadcasting.
///
......@@ -91,8 +90,7 @@ namespace ngraph
///
/// \return Node with element-wise subtract operation.
///
static std::shared_ptr<Node> sub(const std::shared_ptr<Node>& lhs,
const std::shared_ptr<Node>& rhs);
static std::shared_ptr<Node> sub(const Output<Node>& lhs, const Output<Node>& rhs);
///
/// \brief Creates node with element-wise multiply operation with numpy broadcasting.
///
......@@ -101,8 +99,7 @@ namespace ngraph
///
/// \return Node with element-wise multiply operation.
///
static std::shared_ptr<Node> mul(const std::shared_ptr<Node>& lhs,
const std::shared_ptr<Node>& rhs);
static std::shared_ptr<Node> mul(const Output<Node>& lhs, const Output<Node>& rhs);
///
/// \brief Creates node with element-wise clip operation with numpy broadcasting.
///
......@@ -110,7 +107,7 @@ namespace ngraph
///
/// \return Node with element-wise clip operation.
///
std::shared_ptr<Node> clip(const std::shared_ptr<Node>& data) const;
std::shared_ptr<Node> clip(const Output<Node>& data) const;
private:
const std::size_t m_hidden_size;
......
This diff is collapsed.
......@@ -44,7 +44,12 @@ public:
REVERSE,
PRODUCT,
SUM,
CONCAT
CONCAT,
GATHER,
SLICE,
DYN_SLICE,
DYN_RESHAPE,
TRANSPOSE
};
ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap())
......@@ -64,6 +69,11 @@ public:
construct_constant_product();
construct_constant_sum();
construct_constant_concat();
construct_constant_gather();
construct_constant_slice();
construct_constant_dyn_slice();
construct_constant_dyn_reshape();
construct_constant_transpose();
}
//this allows to specify the order in which matchers will be run
......@@ -90,6 +100,11 @@ public:
case CFTransformations::PRODUCT: construct_constant_product(); break;
case CFTransformations::SUM: construct_constant_sum(); break;
case CFTransformations::CONCAT: construct_constant_concat(); break;
case CFTransformations::GATHER: construct_constant_gather(); break;
case CFTransformations::SLICE: construct_constant_slice(); break;
case CFTransformations::DYN_SLICE: construct_constant_dyn_slice(); break;
case CFTransformations::DYN_RESHAPE: construct_constant_dyn_reshape(); break;
case CFTransformations::TRANSPOSE: construct_constant_transpose(); break;
}
}
}
......@@ -108,6 +123,11 @@ private:
void construct_constant_product();
void construct_constant_sum();
void construct_constant_concat();
void construct_constant_gather();
void construct_constant_slice();
void construct_constant_dyn_slice();
void construct_constant_dyn_reshape();
void construct_constant_transpose();
ngraph::BuildNodeExecutorMap m_cfmap;
};
This diff is collapsed.
......@@ -17,6 +17,7 @@
#pragma once
#include <map>
#include <string>
namespace ngraph
{
......
......@@ -550,6 +550,12 @@ namespace ngraph
BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::sign);
}
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Not)
{
BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::logical_not);
}
#define TI(x) type_index(typeid(x))
BuildOpMap& GetGlobalBuildDispatcher()
......@@ -627,6 +633,7 @@ namespace ngraph
REGISTER_CF_BUILDER(And);
REGISTER_CF_BUILDER(Or);
REGISTER_CF_BUILDER(Sign);
REGISTER_CF_BUILDER(Not);
}
}
}
......@@ -36,12 +36,14 @@ namespace ngraph
}
}
const std::string ngraph::runtime::plaidml::op::Convolution::type_name{"PlaidMLConvolution"};
ngraph::runtime::plaidml::op::Convolution::Convolution(std::shared_ptr<ngraph::op::Convolution> src,
const NodeVector& args,
const OutputVector& args,
AxisVector data_axes,
AxisVector filters_axes,
AxisVector output_axes)
: Op{"PlaidMLConvolution", args}
: Op{args}
, m_src{std::move(src)}
, m_data_axes{std::move(data_axes)}
, m_filters_axes{std::move(filters_axes)}
......@@ -69,16 +71,19 @@ std::shared_ptr<ngraph::Node>
throw ngraph_error{"PlaidMLConvolution requires two inputs (data and filters)"};
}
return std::make_shared<Convolution>(
m_src, new_args, m_data_axes, m_filters_axes, m_output_axes);
m_src, as_output_vector(new_args), m_data_axes, m_filters_axes, m_output_axes);
}
const std::string ngraph::runtime::plaidml::op::ConvolutionBackpropData::type_name{
"PlaidMLConvolutionBackpropData"};
ngraph::runtime::plaidml::op::ConvolutionBackpropData::ConvolutionBackpropData(
std::shared_ptr<ngraph::op::ConvolutionBackpropData> src,
const NodeVector& args,
const OutputVector& args,
AxisVector filters_axes,
AxisVector output_axes,
AxisVector data_axes)
: Op{"PlaidMLConvolutionBackpropData", args}
: Op{args}
, m_src{std::move(src)}
, m_filters_axes{std::move(filters_axes)}
, m_output_axes{std::move(output_axes)}
......@@ -107,16 +112,19 @@ std::shared_ptr<ngraph::Node>
throw ngraph_error{"PlaidMLConvolutionBackpropData requires two inputs (data and output)"};
}
return std::make_shared<ConvolutionBackpropData>(
m_src, new_args, m_filters_axes, m_output_axes, m_data_axes);
m_src, as_output_vector(new_args), m_filters_axes, m_output_axes, m_data_axes);
}
const std::string ngraph::runtime::plaidml::op::ConvolutionBackpropFilters::type_name{
"PlaidMLConvolutionBackpropFilters"};
ngraph::runtime::plaidml::op::ConvolutionBackpropFilters::ConvolutionBackpropFilters(
std::shared_ptr<ngraph::op::ConvolutionBackpropFilters> src,
const NodeVector& args,
const OutputVector& args,
AxisVector data_axes,
AxisVector output_axes,
AxisVector filters_axes)
: Op{"PlaidMLConvolutionBackpropFilters", args}
: Op{args}
, m_src{std::move(src)}
, m_data_axes{std::move(data_axes)}
, m_output_axes{std::move(output_axes)}
......@@ -146,7 +154,7 @@ std::shared_ptr<ngraph::Node>
"PlaidMLConvolutionBackpropFilters requires two inputs (filters and output)"};
}
return std::make_shared<ConvolutionBackpropFilters>(
m_src, new_args, m_data_axes, m_output_axes, m_filters_axes);
m_src, as_output_vector(new_args), m_data_axes, m_output_axes, m_filters_axes);
}
// Convolution implements a standard ML convolultion, with optional striding, padding, and dilation.
......
......@@ -39,8 +39,11 @@ namespace ngraph
class ngraph::runtime::plaidml::op::Convolution final : public ngraph::op::Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Convolution(std::shared_ptr<ngraph::op::Convolution> src,
const NodeVector& args,
const OutputVector& args,
AxisVector data_axes,
AxisVector filters_axes,
AxisVector output_axes);
......@@ -63,8 +66,11 @@ private:
class ngraph::runtime::plaidml::op::ConvolutionBackpropData final : public ngraph::op::Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
ConvolutionBackpropData(std::shared_ptr<ngraph::op::ConvolutionBackpropData> src,
const NodeVector& args,
const OutputVector& args,
AxisVector filters_axes,
AxisVector output_axes,
AxisVector data_axes);
......@@ -87,8 +93,11 @@ private:
class ngraph::runtime::plaidml::op::ConvolutionBackpropFilters final : public ngraph::op::Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
ConvolutionBackpropFilters(std::shared_ptr<ngraph::op::ConvolutionBackpropFilters> src,
const NodeVector& args,
const OutputVector& args,
AxisVector data_axes,
AxisVector output_axes,
AxisVector filters_axes);
......
......@@ -30,9 +30,11 @@ namespace ngraph
}
}
ngraph::runtime::plaidml::op::ImplicitBroadcast::ImplicitBroadcast(std::shared_ptr<Node> input,
const std::string ngraph::runtime::plaidml::op::ImplicitBroadcast::type_name{"ImplicitBroadcast"};
ngraph::runtime::plaidml::op::ImplicitBroadcast::ImplicitBroadcast(const Output<Node>& input,
const Shape& shape)
: Op{"ImplicitBroadcast", {input}}
: Op{{input}}
, m_shape{shape}
{
constructor_validate_and_infer_types();
......
......@@ -40,7 +40,10 @@ namespace ngraph
class ngraph::runtime::plaidml::op::ImplicitBroadcast final : public ngraph::op::Op
{
public:
ImplicitBroadcast(std::shared_ptr<Node> input, const Shape& shape);
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
ImplicitBroadcast(const Output<Node>& input, const Shape& shape);
void validate_and_infer_types() final;
......
......@@ -28,22 +28,24 @@ namespace ngraph
}
}
ngraph::runtime::plaidml::op::Replicate::Replicate(std::shared_ptr<Node> arg,
const std::string ngraph::runtime::plaidml::op::Replicate::type_name{"Replicate"};
ngraph::runtime::plaidml::op::Replicate::Replicate(const Output<Node>& arg,
std::size_t replication_axis,
std::size_t replication_count)
: Op{"Replicate", NodeVector{arg}}
, m_replication_axes(arg->get_shape().size(), 1)
: Op{{arg}}
, m_replication_axes(arg.get_shape().size(), 1)
{
m_replication_axes.at(replication_axis) = replication_count;
constructor_validate_and_infer_types();
}
ngraph::runtime::plaidml::op::Replicate::Replicate(std::shared_ptr<Node> arg,
ngraph::runtime::plaidml::op::Replicate::Replicate(const Output<Node>& arg,
std::vector<std::size_t> replication_axes)
: Op{"Replicate", NodeVector{arg}}
: Op{{arg}}
, m_replication_axes(std::move(replication_axes))
{
if (arg->get_shape().size() != m_replication_axes.size())
if (arg.get_shape().size() != m_replication_axes.size())
{
throw ngraph_error{"Replicate requires compatible axes dimensions"};
}
......
......@@ -39,11 +39,12 @@ namespace ngraph
class ngraph::runtime::plaidml::op::Replicate final : public ngraph::op::Op
{
public:
Replicate(std::shared_ptr<Node> arg,
std::size_t replication_axis,
std::size_t replication_count);
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Replicate(const Output<Node>& arg, std::size_t replication_axis, std::size_t replication_count);
Replicate(std::shared_ptr<Node> arg, std::vector<std::size_t> replication_axes);
Replicate(const Output<Node>& arg, std::vector<std::size_t> replication_axes);
void validate_and_infer_types() final;
......
......@@ -30,9 +30,11 @@ namespace ngraph
}
}
const std::string ngraph::runtime::plaidml::op::Winograd::type_name{"Winograd"};
ngraph::runtime::plaidml::op::Winograd::Winograd(std::shared_ptr<plaidml::op::Convolution> conv,
const NodeVector& args)
: Op{"Winograd", args}
const OutputVector& args)
: Op{args}
, m_conv{std::move(conv)}
{
constructor_validate_and_infer_types();
......@@ -50,7 +52,7 @@ std::shared_ptr<ngraph::Node>
{
throw ngraph_error{"Winograd requires five inputs (data, filters, A, B, and G)"};
}
return std::make_shared<Winograd>(m_conv, new_args);
return std::make_shared<Winograd>(m_conv, as_output_vector(new_args));
}
void ngraph::runtime::plaidml::ImplWinograd::Apply()
......
......@@ -38,7 +38,10 @@ namespace ngraph
class ngraph::runtime::plaidml::op::Winograd final : public ngraph::op::Op
{
public:
Winograd(std::shared_ptr<Convolution> conv, const NodeVector& args);
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Winograd(std::shared_ptr<Convolution> conv, const OutputVector& args);
void validate_and_infer_types() final;
......
......@@ -97,7 +97,7 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
{
replace_node(target,
std::make_shared<plaidml::op::Convolution>(conv,
NodeVector{lhs, rhs},
OutputVector{lhs, rhs},
std::move(lhs_axes),
std::move(rhs_axes),
std::move(out_axes)));
......@@ -113,7 +113,7 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
replace_node(
target,
std::make_shared<plaidml::op::ConvolutionBackpropData>(conv_bp_data,
NodeVector{lhs, rhs},
OutputVector{lhs, rhs},
std::move(lhs_axes),
std::move(rhs_axes),
std::move(out_axes)));
......@@ -126,13 +126,13 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
std::dynamic_pointer_cast<ngraph::op::ConvolutionBackpropFilters>(node);
if (conv_bp_filters)
{
replace_node(
target,
std::make_shared<plaidml::op::ConvolutionBackpropFilters>(conv_bp_filters,
NodeVector{lhs, rhs},
std::move(lhs_axes),
std::move(rhs_axes),
std::move(out_axes)));
replace_node(target,
std::make_shared<plaidml::op::ConvolutionBackpropFilters>(
conv_bp_filters,
OutputVector{lhs, rhs},
std::move(lhs_axes),
std::move(rhs_axes),
std::move(out_axes)));
return true;
}
}
......
......@@ -113,7 +113,11 @@ ngraph::runtime::plaidml::pass::Winograd::Winograd()
auto callback = [](pattern::Matcher& m) {
auto conv = std::static_pointer_cast<plaidml::op::Convolution>(m.get_match_root());
NodeVector args = conv->get_arguments();
OutputVector args;
for (auto input : conv->inputs())
{
args.push_back(input.get_source_output());
}
std::shared_ptr<ngraph::op::Constant> a;
std::shared_ptr<ngraph::op::Constant> b;
std::shared_ptr<ngraph::op::Constant> g;
......
......@@ -43,8 +43,8 @@ namespace ngraph
// out' = out[out_index] # rank(out') == rank(params')
// gather_nd(params', indices'', out')
template <typename T, typename U>
void gather(T* params,
U* indices,
void gather(const T* params,
const U* indices,
T* out,
const Shape& params_shape,
const Shape& indices_shape,
......@@ -148,13 +148,14 @@ namespace ngraph
auto out_outer_coord_iter = out_outer_transform.begin();
for (const Coordinate& params_outer_coord : params_outer_transform)
{
T* params_prime = &params[params_outer_transform.index(params_outer_coord)];
const T* params_prime =
&params[params_outer_transform.index(params_outer_coord)];
T* out_outer = &out[out_outer_transform.index(*out_outer_coord_iter)];
auto out_inner_coord_iter = out_inner_transform.begin();
for (const Coordinate& indices_outer_coord : indices_outer_transform)
{
U* indices_prime =
const U* indices_prime =
&indices[indices_outer_transform.index(indices_outer_coord)];
T* out_prime = &out_outer[out_inner_transform.index(*out_inner_coord_iter)];
gather_nd<T, U>(params_prime,
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include "ngraph/check.hpp"
#include "ngraph/slice_plan.hpp"
using namespace ngraph;
SlicePlan ngraph::make_slice_plan(const Shape& input_shape,
const std::vector<int64_t>& begins,
const std::vector<int64_t>& ends,
const std::vector<int64_t>& strides,
const AxisSet& lower_bounds_mask,
const AxisSet& upper_bounds_mask,
const AxisSet& new_axis_mask,
const AxisSet& shrink_axis_mask,
const AxisSet& ellipsis_mask)
{
NGRAPH_CHECK(begins.size() == ends.size());
NGRAPH_CHECK(ends.size() == strides.size());
size_t num_slice_indices = begins.size();
size_t num_real_axes = 0;
size_t num_shrink_axes = 0;
size_t num_new_axes = 0;
bool ellipsis_found = false;
// Make a pass over the original slices to make sure there is at most one
// ellipsis, and to count up the number of shrink axes, the number of
// "newaxis"es, and the number of "real" axes (axes that are not newaxis
// and are not the ellipsis).
for (size_t i = 0; i < num_slice_indices; i++)
{
if (ellipsis_mask.count(i))
{
NGRAPH_CHECK(!ellipsis_found);
ellipsis_found = true;
}
else if (new_axis_mask.count(i))
{
num_new_axes++;
}
else
{
if (shrink_axis_mask.count(i))
{
num_shrink_axes++;
}
num_real_axes++;
}
}
NGRAPH_CHECK(num_real_axes <= input_shape.size(),
"num_real_axes=",
num_real_axes,
", input_shape=",
input_shape);
// Figure out how many axes need to be inserted when the ellipsis (which
// may be an implicit ellipsis at the end) is expanded.
size_t ellipsis_size = input_shape.size() - num_real_axes;
// Initialize our slice plan.
SlicePlan p;
p.begins = std::vector<int64_t>(num_real_axes + ellipsis_size);
p.ends = std::vector<int64_t>(num_real_axes + ellipsis_size);
p.strides = std::vector<int64_t>(num_real_axes + ellipsis_size);
p.reshape_in_shape = Shape(num_real_axes + ellipsis_size);
p.reshape_out_shape = Shape(num_new_axes + num_real_axes + ellipsis_size - num_shrink_axes);
p.reverse_axes = AxisSet{};
// Begin a maddeningly delicate loop to desugar the original slice.
//
// * i_in is iterating over the axes of the input shape, which are also the axes of
// p.reshape_in_shape.
// * i_out is iterating over the axes of p.reshape_out_shape
size_t i_in = 0;
size_t i_out = 0;
// If no actual ellipsis exists, there is an "implicit" one at the end,
// which we will handle after the loop. So the logic is wrapped up here,
// allowing it to be used both during and after the loop.
auto expand_ellipsis = [&]() {
for (size_t i = 0; i < ellipsis_size; i++)
{
p.begins[i_in] = 0;
p.ends[i_in] = int64_t(input_shape[i_in]);
p.strides[i_in] = 1;
p.reshape_in_shape[i_in] = input_shape[i_in];
p.reshape_out_shape[i_out] = input_shape[i_in];
i_in++;
i_out++;
}
};
for (size_t i = 0; i < num_slice_indices; i++)
{
// If this is a "newaxis", then reshape_out_shape will have a 1 here,
// but reshape_in_shape will not.
if (new_axis_mask.count(i))
{
p.reshape_out_shape[i_out] = 1;
i_out++;
}
// If this is a "shrunken" axis, then reshape_in_shape will have a 1
// here, but reshape_out_shape will not.
else if (shrink_axis_mask.count(i))
{
int64_t begin = begins[i];
// Note that clipping is not used for "shrunken" axes: an
// out-of-bounds index is an error.
NGRAPH_CHECK(begin >= -(int64_t(input_shape[i_in])) &&
begin < int64_t(input_shape[i_in]));
if (begin < 0)
{
begin += int64_t(input_shape[i_in]);
}
p.begins[i_in] = begin;
p.ends[i_in] = begin + 1;
p.strides[i_in] = 1;
p.reshape_in_shape[i_in] = 1;
i_in++;
}
// If this is the ellipsis, expand it.
else if (ellipsis_mask.count(i))
{
expand_ellipsis();
}
// In other cases, we have a nice, ordinary (begin:end:stride) slice.
// We need to adjust for begin/end being masked, and begin/end/stride
// being negative or out of bounds.
else
{
bool is_reverse = strides[i] < 0;
// Adjust the beginning for from-the-right indexing, and clip.
int64_t real_begin = begins[i];
if (lower_bounds_mask.count(i))
{
real_begin = (is_reverse ? int64_t(input_shape[i_in] - 1) : 0);
}
else if (real_begin < 0)
{
real_begin += int64_t(input_shape[i_in]);
}
int64_t max_real_begin = int64_t(input_shape[i_in]) - (is_reverse ? 1 : 0);
real_begin = std::max(int64_t(0), std::min(max_real_begin, real_begin));
// Adjust the ending for from-the-right indexing, and clip.
int64_t real_end = ends[i];
if (upper_bounds_mask.count(i))
{
real_end = (is_reverse ? -1 : int64_t(input_shape[i_in]));
}
else if (real_end < 0)
{
real_end += int64_t(input_shape[i_in]);
}
int64_t min_real_end = (is_reverse ? -1 : 0);
real_end = std::max(min_real_end, std::min(int64_t(input_shape[i_in]), real_end));
// Ensure stride is not zero, and adjust it for backwards slicing.
NGRAPH_CHECK(strides[i] != 0);
int64_t real_stride = std::abs(strides[i]);
// Adjust for reversal if needed. This isn't quite as simple as swapping begin and
// end, due to striding; we have to adjust the end point to be the _actual_ leftmost
// element, in cases where the stride does not evenly divide the span between begin
// and end.
if (is_reverse)
{
real_end += std::max(int64_t(0), real_begin - real_end - 1) % real_stride;
std::swap(real_begin, real_end);
real_begin++;
real_end++;
p.reverse_axes.insert(i_out);
}
// nGraph's slice op does not like it when end < begin, so we truncate for that case
// here.
if (real_end < real_begin)
{
real_end = real_begin;
}
// Compute output dimension.
size_t dim = (real_end <= real_begin
? 0
: size_t(real_end - real_begin - 1) / size_t(real_stride) + 1);
p.reshape_in_shape[i_in] = dim;
p.reshape_out_shape[i_out] = dim;
// Set up the begin/end/stride.
p.begins[i_in] = real_begin;
p.ends[i_in] = real_end;
p.strides[i_in] = real_stride;
i_in++;
i_out++;
}
}
// If there was no ellipsis explicitly given, there is an implicit one at
// the end (it might encompass zero axes, but that's fine).
if (!ellipsis_found)
{
expand_ellipsis();
}
return p;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <set>
#include "ngraph/axis_set.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
//
// In various places, like ConstantFolding and DynElimination, it is
// useful to transform DynSlice by converting it to a sequence of ops:
//
// Slice (to do the basic slicing)
// |
// v
// Reshape (non-transposing, to handle shrinks)
// |
// v
// Reverse (to emulate backwards stride)
//
// (The Reshape, Reverse, or both may be omitted if they would just be
// identities.)
//
// A SlicePlan is used to collect parameters for these ops.
//
struct SlicePlan
{
// Parameters for the Slice
std::vector<int64_t> begins;
std::vector<int64_t> ends;
std::vector<int64_t> strides;
// Shapes coming into, and going out of, the Reshape.
Shape reshape_in_shape;
Shape reshape_out_shape;
// Parameters for the Reverse
AxisSet reverse_axes;
};
SlicePlan make_slice_plan(const Shape& input_shape,
const std::vector<int64_t>& begins,
const std::vector<int64_t>& ends,
const std::vector<int64_t>& strides,
const AxisSet& lower_bounds_mask,
const AxisSet& upper_bounds_mask,
const AxisSet& new_axis_mask,
const AxisSet& shrink_axis_mask,
const AxisSet& ellipsis_mask);
}
......@@ -459,6 +459,30 @@ TEST(constant_folding, const_concat)
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_not)
{
auto constant =
op::Constant::create(element::boolean, Shape{2, 3}, vector<char>{0, 1, 0, 0, 1, 1});
auto logical_not = make_shared<op::Not>(constant);
auto f = make_shared<Function>(logical_not, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Not>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<char>();
vector<char> values_expected{1, 0, 1, 1, 0, 0};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_equal)
{
auto constant0 =
......@@ -715,6 +739,158 @@ TEST(constant_folding, const_floor)
ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, const_gather)
{
auto constant_data = op::Constant::create(
element::f32,
Shape{2, 5},
vector<float>{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f});
auto constant_indices =
op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 3, 2, 2});
size_t gather_axis = 1;
auto gather = make_shared<op::Gather>(constant_data, constant_indices, gather_axis);
auto f = make_shared<Function>(gather, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Gather>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<float>();
vector<float> values_expected{1.0f, 4.0f, 3.0f, 3.0f, 6.0f, 9.0f, 8.0f, 8.0f};
ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, const_slice)
{
Shape shape_in{16};
vector<int> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
auto constant = make_shared<op::Constant>(element::i32, shape_in, values_in);
auto slice = make_shared<op::Slice>(constant, Coordinate{2}, Coordinate{15}, Strides{3});
auto f = make_shared<Function>(slice, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Slice>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int>();
vector<int> sliced_values{3, 6, 9, 12, 15};
ASSERT_EQ(sliced_values, values_out);
}
TEST(constant_folding, const_dyn_slice)
{
Shape shape_in{16};
vector<int> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
auto constant_data = make_shared<op::Constant>(element::i32, shape_in, values_in);
vector<int> values_lb{2};
auto constant_lb = make_shared<op::Constant>(element::i64, Shape{1}, values_lb);
vector<int> values_ub{15};
auto constant_ub = make_shared<op::Constant>(element::i64, Shape{1}, values_ub);
vector<int> values_strides{3};
auto constant_strides = make_shared<op::Constant>(element::i64, Shape{1}, values_strides);
auto dyn_slice = make_shared<op::DynSlice>(constant_data,
constant_lb,
constant_ub,
constant_strides,
AxisSet{},
AxisSet{},
AxisSet{},
AxisSet{},
AxisSet{});
auto f = make_shared<Function>(dyn_slice, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::DynSlice>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int>();
vector<int> sliced_values{3, 6, 9, 12, 15};
ASSERT_EQ(sliced_values, values_out);
}
TEST(constant_folding, constant_dyn_reshape)
{
Shape shape_in{2, 4};
vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
Shape shape_shape{3};
vector<int64_t> values_shape{2, 4, 1};
auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in);
auto constant_shape = make_shared<op::Constant>(element::i64, shape_shape, values_shape);
auto dyn_reshape = make_shared<op::DynReshape>(constant_in, constant_shape);
auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::DynReshape>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<float>();
ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, constant_transpose)
{
Shape shape_in{2, 4};
vector<double> values_in{0, 1, 2, 3, 4, 5, 6, 7};
Shape shape_perm{2};
vector<int64_t> values_perm{1, 0};
auto constant_in = make_shared<op::Constant>(element::f64, shape_in, values_in);
auto constant_perm = make_shared<op::Constant>(element::i64, shape_perm, values_perm);
auto transpose = make_shared<op::Transpose>(constant_in, constant_perm);
auto f = make_shared<Function>(transpose, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Transpose>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<double>();
vector<double> values_permute{0, 4, 1, 5, 2, 6, 3, 7};
ASSERT_TRUE(test::all_close_f(values_permute, values_out, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, pass_property)
{
auto pass = std::make_shared<ngraph::pass::ConstantFolding>();
......
......@@ -1218,11 +1218,12 @@ TEST(cpu_test, constant_unary_binary)
auto logical_or = make_shared<op::Or>(i, j);
auto ceil = make_shared<op::Ceiling>(k);
auto floor = make_shared<op::Floor>(k);
auto logical_not = make_shared<op::Not>(j);
auto func = make_shared<Function>(
NodeVector{add, sub, mul, divn, min, max, absn,
neg, sqrt, relu, sign, equal, not_equal, greater,
greater_eq, less, less_eq, logical_and, logical_or, ceil, floor},
NodeVector{add, sub, mul, divn, min, max, absn, neg,
sqrt, relu, sign, equal, not_equal, greater, greater_eq, less,
less_eq, logical_and, logical_or, ceil, floor, logical_not},
ParameterVector{});
auto func_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{});
......@@ -1253,6 +1254,7 @@ TEST(cpu_test, constant_unary_binary)
ASSERT_EQ(count_ops_of_type<op::Or>(func), 0);
ASSERT_EQ(count_ops_of_type<op::Ceiling>(func), 0);
ASSERT_EQ(count_ops_of_type<op::Floor>(func), 0);
ASSERT_EQ(count_ops_of_type<op::Not>(func), 0);
//expected values
vector<int> add_expected{2, 4, 6, 8};
......@@ -1275,6 +1277,7 @@ TEST(cpu_test, constant_unary_binary)
vector<char> or_expected{0, 1, 1, 1};
vector<float> ceil_expected{0.0f, 0.0f, -1.0f, 3.0f};
vector<float> floor_expected{-1.0f, 0.0f, -2.0f, 2.0f};
vector<char> not_expected{1, 0, 1, 0};
ASSERT_EQ(get_result_constant<int>(func, 0), add_expected);
ASSERT_EQ(get_result_constant<int>(func, 1), sub_expected);
......@@ -1299,6 +1302,7 @@ TEST(cpu_test, constant_unary_binary)
get_result_constant<float>(func, 19), ceil_expected, MIN_FLOAT_TOLERANCE_BITS));
ASSERT_TRUE(test::all_close_f(
get_result_constant<float>(func, 20), floor_expected, MIN_FLOAT_TOLERANCE_BITS));
ASSERT_EQ(get_result_constant<char>(func, 21), not_expected);
ASSERT_ANY_THROW(pass_manager.run_passes(func_error));
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment