Commit a8bc57cb authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Extend compile API to take in PassConfig object (#2516)

* Extend compile API to take in PassConfig object

* fix override warning

* remove extra semicolon

* cmake fixes to support cases where include_path has multiple directories

* Help pybind find the overloaded compile methods

* Limit compile-api exposed through PyBind (#2530)

* clang-format

* Remove setter for compilation mode to prevent post-init changes. Add compile-mode warning.

* Removed pass_config include

* fix merge
parent b2259d88
......@@ -23,6 +23,13 @@
namespace py = pybind11;
// Python-facing wrapper: compile a Function on the given backend with
// performance-counter collection turned off.
static std::shared_ptr<ngraph::runtime::Executable> compile(ngraph::runtime::Backend* self,
                                                            std::shared_ptr<ngraph::Function> func)
{
    return self->compile(func, /*enable_performance_data=*/false);
}
void regclass_pyngraph_runtime_Backend(py::module m)
{
py::class_<ngraph::runtime::Backend, std::unique_ptr<ngraph::runtime::Backend>> backend(
......@@ -34,8 +41,5 @@ void regclass_pyngraph_runtime_Backend(py::module m)
(std::shared_ptr<ngraph::runtime::Tensor>(ngraph::runtime::Backend::*)(
const ngraph::element::Type&, const ngraph::Shape&)) &
ngraph::runtime::Backend::create_tensor);
backend.def("compile",
(std::shared_ptr<ngraph::runtime::Executable>(ngraph::runtime::Backend::*)(
std::shared_ptr<ngraph::Function>)) &
ngraph::runtime::Backend::compile);
backend.def("compile", &compile);
}
......@@ -54,7 +54,10 @@ if(NGRAPH_DISTRIBUTED_ENABLE)
list(APPEND HEADER_SEARCH_DEFINES MLSL_HEADER_PATH="${MLSL_INCLUDE_DIR}")
elseif(NGRAPH_DISTRIBUTED_OMPI_ENABLE)
find_package(MPI REQUIRED)
add_definitions(-DMPI_HEADER_PATH="${MPI_PATH}")
# MPI_C_INCLUDE_PATH might have a list of directories separated by a semicolon
# Escape the semicolon to prevent cmake from interpreting the string as a list
string(REPLACE ";" "\\\;" MPI_C_INCLUDE_PATH_ESCAPED "${MPI_C_INCLUDE_PATH}")
list(APPEND HEADER_SEARCH_DEFINES MPI_HEADER_PATH="${MPI_C_INCLUDE_PATH_ESCAPED}")
else()
message(FATAL_ERROR "Distributed Library not supported/mentioned")
endif()
......
......@@ -22,7 +22,8 @@
using namespace ngraph;
// TODO: Add file-based configuration support
ngraph::pass::PassConfig::PassConfig()
ngraph::pass::PassConfig::PassConfig(ngraph::pass::CompilationMode mode)
: m_compilation_mode(mode)
{
/**
* Parses the semi-colon separated environment string passed through NGRAPH_PASS_ENABLES
......
......@@ -23,21 +23,27 @@ namespace ngraph
namespace pass
{
class PassConfig;
enum class CompilationMode
{
DEX, // Fast compilation using precompiled kernels
CODEGEN // Slower compilation for potentially faster code
};
}
}
/// \brief Configuration object controlling how a Function is compiled.
///
/// Holds per-pass enable/attribute flags (parsed from environment variables by
/// the constructor defined in pass_config.cpp) and the compilation mode. The
/// mode is fixed at construction time; there is intentionally no setter, so it
/// cannot be changed after initialization.
class ngraph::pass::PassConfig
{
public:
    /// \param mode Compilation mode to use; defaults to CompilationMode::DEX.
    // NOTE: the stale zero-argument `PassConfig();` declaration was removed —
    // together with this defaulted-parameter constructor it made the
    // zero-argument construction `PassConfig{}` ambiguous.
    PassConfig(CompilationMode mode = CompilationMode::DEX);
    /// Map of pass name -> enabled flag.
    const std::map<std::string, bool>& get_enables() { return m_pass_enables; }
    void set_pass_enable(std::string name, bool enable);
    bool get_pass_enable(std::string name);
    /// Map of pass attribute name -> flag (e.g. "CPUMemoryAssignment::ReuseMemory").
    const std::map<std::string, bool>& get_pass_attributes() { return m_pass_attributes; }
    void set_pass_attribute(std::string name, bool enable);
    bool get_pass_attribute(std::string name);
    /// Mode chosen at construction; read by backends to pick DEX vs CODEGEN.
    CompilationMode get_compilation_mode() const { return m_compilation_mode; }

private:
    std::map<std::string, bool> m_pass_enables;
    std::map<std::string, bool> m_pass_attributes;
    CompilationMode m_compilation_mode;
};
......@@ -39,6 +39,14 @@ vector<string> runtime::Backend::get_registered_devices()
return BackendManager::get_registered_backends();
}
/// \brief Default implementation of the PassConfig-aware compile overload.
/// \param func The function to compile
/// \param pass_config Compilation options; intentionally ignored here.
///        Backends that understand PassConfig (e.g. the CPU backend) override
///        this overload to honor it.
/// \param enable_performance_data Forwarded to the two-argument compile().
/// \returns The Executable produced by the backend's two-argument compile().
std::shared_ptr<runtime::Executable>
runtime::Backend::compile(std::shared_ptr<Function> func,
ngraph::pass::PassConfig& pass_config,
bool enable_performance_data)
{
// Delegate to the pure-virtual two-argument overload so existing backends
// keep working without changes; pass_config is deliberately unused here.
return compile(func, enable_performance_data);
}
bool runtime::Backend::is_supported(const Node& node) const
{
// The default behavior is that a backend does not support any ops. If this is not the case
......
......@@ -19,6 +19,7 @@
#include <memory>
#include "ngraph/function.hpp"
#include "ngraph/pass/pass_config.hpp"
#include "ngraph/runtime/performance_counter.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
......@@ -84,6 +85,14 @@ public:
virtual std::shared_ptr<Executable> compile(std::shared_ptr<Function> func,
bool enable_performance_data = false) = 0;
/// \brief Compiles a Function.
/// \param func The function to compile
/// \param pass_config Configuration object for defining compilation options
/// \returns compiled function or nullptr on failure
virtual std::shared_ptr<Executable> compile(std::shared_ptr<Function> func,
ngraph::pass::PassConfig& pass_config,
bool enable_performance_data = false);
/// \brief Test if a backend is capable of supporting an op
/// \param node is the op to test.
/// \returns true if the op is supported, false otherwise.
......
......@@ -55,7 +55,7 @@ namespace ngraph
std::shared_ptr<ngraph::runtime::Executable>
compile(std::shared_ptr<Function> func,
ngraph::pass::PassConfig& pass_config,
bool enable_performance_counters = false);
bool enable_performance_counters = false) override;
void remove_compiled_function(std::shared_ptr<Executable> exec) override;
......
......@@ -1702,7 +1702,18 @@ void*& runtime::cpu::CPU_ExternalFunction::get_tensor_data(const std::string& na
shared_ptr<ngraph::runtime::cpu::CPU_CallFrame>
runtime::cpu::CPU_ExternalFunction::make_call_frame(ngraph::pass::PassConfig& pass_config)
{
#if !defined(NGRAPH_DEX_ONLY)
#if defined(NGRAPH_DEX_ONLY)
if (pass_config.get_compilation_mode() == ngraph::pass::CompilationMode::CODEGEN)
{
NGRAPH_WARN << "CPU Backend: Requested unsupported compilation mode (CODEGEN). Falling "
"back to DEX instead";
}
#else
// Override DEX if pass_config requests CODEGEN
if (pass_config.get_compilation_mode() == ngraph::pass::CompilationMode::CODEGEN)
{
m_direct_execution = false;
}
if (!m_is_compiled && !m_direct_execution)
{
compile(pass_config);
......
......@@ -727,12 +727,10 @@ TEST(cpu_test, memory_reuse_mxnet_densenet121)
ngraph::pass::PassConfig pass_config;
pass_config.set_pass_attribute("CPUMemoryAssignment::ReuseMemory", true);
auto cpu_backend = std::unique_ptr<runtime::cpu::CPU_Backend>(
static_cast<runtime::cpu::CPU_Backend*>(backend.release()));
auto cpu_f_new_reuse = make_function(file_name);
shared_ptr<runtime::Executable> handle = cpu_backend->compile(cpu_f_new_reuse, pass_config);
shared_ptr<runtime::Executable> handle = backend->compile(cpu_f_new_reuse, pass_config);
for (auto it = 0; it < 2; it++)
{
handle->call_with_validate(result_tensors, arg_tensors);
......@@ -891,6 +889,41 @@ TEST(cpu_test, convert_inplace)
EXPECT_EQ((vector<int8_t>{1, 2, 3, -2}), read_vector<int8_t>(result));
}
TEST(cpu_test, abc_codegen)
{
    // Build f = (A + B) * C and run it on the CPU backend with the
    // compilation mode explicitly set to CODEGEN via a PassConfig.
    Shape shape{2, 2};
    auto A = make_shared<op::Parameter>(element::f32, shape);
    auto B = make_shared<op::Parameter>(element::f32, shape);
    auto C = make_shared<op::Parameter>(element::f32, shape);
    auto f = make_shared<Function>((A + B) * C, ParameterVector{A, B, C});

    auto backend = runtime::Backend::create("CPU");

    // Backend-allocated tensors for the three inputs and the output.
    auto a = backend->create_tensor(element::f32, shape);
    auto b = backend->create_tensor(element::f32, shape);
    auto c = backend->create_tensor(element::f32, shape);
    auto result = backend->create_tensor(element::f32, shape);

    copy_data(a, test::NDArray<float, 2>({{1, 2}, {3, 4}}).get_vector());
    copy_data(b, test::NDArray<float, 2>({{5, 6}, {7, 8}}).get_vector());
    copy_data(c, test::NDArray<float, 2>({{9, 10}, {11, 12}}).get_vector());

    // Request CODEGEN mode (DEX is the PassConfig default).
    ngraph::pass::PassConfig pass_config{ngraph::pass::CompilationMode::CODEGEN};
    auto handle = backend->compile(f, pass_config);

    // Evaluate the same compiled function with the arguments bound in
    // three different orders; (a+b)*c == (b+a)*c, while (a+c)*b differs.
    handle->call_with_validate({result}, {a, b, c});
    EXPECT_EQ(read_vector<float>(result),
              (test::NDArray<float, 2>({{54, 80}, {110, 144}})).get_vector());

    handle->call_with_validate({result}, {b, a, c});
    EXPECT_EQ(read_vector<float>(result),
              (test::NDArray<float, 2>({{54, 80}, {110, 144}})).get_vector());

    handle->call_with_validate({result}, {a, c, b});
    EXPECT_EQ(read_vector<float>(result),
              (test::NDArray<float, 2>({{50, 72}, {98, 128}})).get_vector());
}
TEST(cpu_test, rotated_pooling)
{
auto make_f = [&](bool is_4d, bool avgpool) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment