Unverified Commit d77ace68 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Change CPU backend to use PassConfig for codegen enable (#2826)

* Change CPU backend to use PassConfig instead of backdoor for codegen

* move is_codegen decl to output dex only flag

* revert change
parent bf865efd
...@@ -23,8 +23,7 @@ using namespace std; ...@@ -23,8 +23,7 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
// TODO: Add file-based configuration support // TODO: Add file-based configuration support
pass::PassConfig::PassConfig(pass::CompilationMode mode) pass::PassConfig::PassConfig()
: m_compilation_mode(mode)
{ {
/** /**
* Parses the semi-colon separated environment string passed through NGRAPH_PASS_ENABLES * Parses the semi-colon separated environment string passed through NGRAPH_PASS_ENABLES
...@@ -81,30 +80,32 @@ pass::PassConfig::PassConfig(pass::CompilationMode mode) ...@@ -81,30 +80,32 @@ pass::PassConfig::PassConfig(pass::CompilationMode mode)
} }
} }
void pass::PassConfig::set_pass_enable(string name, bool enable) void pass::PassConfig::set_pass_enable(const string& name, bool enable)
{ {
m_pass_enables[name] = enable; m_pass_enables[name] = enable;
} }
bool pass::PassConfig::get_pass_enable(string name) bool pass::PassConfig::get_pass_enable(const string& name) const
{ {
if (m_pass_enables.find(name) == m_pass_enables.end()) auto it = m_pass_enables.find(name);
if (it != m_pass_enables.end())
{ {
return false; return it->second;
} }
return m_pass_enables[name]; return false;
} }
void pass::PassConfig::set_pass_attribute(string name, bool enable) void pass::PassConfig::set_pass_attribute(const string& name, bool enable)
{ {
m_pass_attributes[name] = enable; m_pass_attributes[name] = enable;
} }
bool pass::PassConfig::get_pass_attribute(string name) bool pass::PassConfig::get_pass_attribute(const string& name) const
{ {
if (m_pass_attributes.find(name) == m_pass_attributes.end()) auto it = m_pass_attributes.find(name);
if (it != m_pass_attributes.end())
{ {
return false; return it->second;
} }
return m_pass_attributes[name]; return false;
} }
...@@ -23,27 +23,21 @@ namespace ngraph ...@@ -23,27 +23,21 @@ namespace ngraph
namespace pass namespace pass
{ {
class PassConfig; class PassConfig;
enum class CompilationMode
{
DEX, // Fast compilation using precompiled kernels
CODEGEN // Slower compilation for potentially faster code
};
} }
} }
class ngraph::pass::PassConfig class ngraph::pass::PassConfig
{ {
public: public:
PassConfig(CompilationMode mode = CompilationMode::DEX); PassConfig();
const std::map<std::string, bool>& get_enables() { return m_pass_enables; } const std::map<std::string, bool>& get_enables() const { return m_pass_enables; }
void set_pass_enable(std::string name, bool enable); void set_pass_enable(const std::string& name, bool enable);
bool get_pass_enable(std::string name); bool get_pass_enable(const std::string& name) const;
const std::map<std::string, bool>& get_pass_attributes() { return m_pass_attributes; } const std::map<std::string, bool>& get_pass_attributes() const { return m_pass_attributes; }
void set_pass_attribute(std::string name, bool enable); void set_pass_attribute(const std::string& name, bool enable);
bool get_pass_attribute(std::string name); bool get_pass_attribute(const std::string& name) const;
CompilationMode get_compilation_mode() const { return m_compilation_mode; }
private: private:
std::map<std::string, bool> m_pass_enables; std::map<std::string, bool> m_pass_enables;
std::map<std::string, bool> m_pass_attributes; std::map<std::string, bool> m_pass_attributes;
CompilationMode m_compilation_mode;
}; };
...@@ -1776,18 +1776,29 @@ size_t runtime::cpu::CPU_ExternalFunction::get_buffer_index(const std::string& n ...@@ -1776,18 +1776,29 @@ size_t runtime::cpu::CPU_ExternalFunction::get_buffer_index(const std::string& n
} }
} }
bool runtime::cpu::CPU_ExternalFunction::is_codegen(const ngraph::pass::PassConfig& pc)
{
auto attrs = pc.get_pass_attributes();
auto it = attrs.find("CODEGEN");
if (it != attrs.end())
{
return it->second;
}
return false;
}
shared_ptr<ngraph::runtime::cpu::CPU_CallFrame> shared_ptr<ngraph::runtime::cpu::CPU_CallFrame>
runtime::cpu::CPU_ExternalFunction::make_call_frame(ngraph::pass::PassConfig& pass_config) runtime::cpu::CPU_ExternalFunction::make_call_frame(ngraph::pass::PassConfig& pass_config)
{ {
#if defined(NGRAPH_DEX_ONLY) #if defined(NGRAPH_DEX_ONLY)
if (pass_config.get_compilation_mode() == ngraph::pass::CompilationMode::CODEGEN) if (is_codegen(pass_config))
{ {
NGRAPH_WARN << "CPU Backend: Requested unsupported compilation mode (CODEGEN). Falling " NGRAPH_WARN << "CPU Backend: Requested unsupported compilation mode (CODEGEN). Falling "
"back to DEX instead"; "back to DEX instead";
} }
#else #else
// Override DEX if pass_config requests CODEGEN // Override DEX if pass_config requests CODEGEN
if (pass_config.get_compilation_mode() == ngraph::pass::CompilationMode::CODEGEN) if (is_codegen(pass_config))
{ {
m_direct_execution = false; m_direct_execution = false;
} }
......
...@@ -225,6 +225,7 @@ namespace ngraph ...@@ -225,6 +225,7 @@ namespace ngraph
// so they don't get freed before we are done with them // so they don't get freed before we are done with them
std::vector<std::shared_ptr<Node>> m_active_constants; std::vector<std::shared_ptr<Node>> m_active_constants;
#endif #endif
static bool is_codegen(const ngraph::pass::PassConfig& pc);
std::unordered_set<descriptor::Tensor*>& std::unordered_set<descriptor::Tensor*>&
get_tensor_set(descriptor::Tensor* output_tensor); get_tensor_set(descriptor::Tensor* output_tensor);
......
...@@ -42,7 +42,8 @@ TEST(cpu_codegen, abc) ...@@ -42,7 +42,8 @@ TEST(cpu_codegen, abc)
copy_data(b, test::NDArray<float, 2>({{5, 6}, {7, 8}}).get_vector()); copy_data(b, test::NDArray<float, 2>({{5, 6}, {7, 8}}).get_vector());
copy_data(c, test::NDArray<float, 2>({{9, 10}, {11, 12}}).get_vector()); copy_data(c, test::NDArray<float, 2>({{9, 10}, {11, 12}}).get_vector());
ngraph::pass::PassConfig pass_config{ngraph::pass::CompilationMode::CODEGEN}; ngraph::pass::PassConfig pass_config;
pass_config.set_pass_attribute("CODEGEN", true);
auto handle = backend->compile(f, pass_config); auto handle = backend->compile(f, pass_config);
handle->call_with_validate({result}, {a, b, c}); handle->call_with_validate({result}, {a, b, c});
EXPECT_TRUE(test::all_close_f(read_vector<float>(result), EXPECT_TRUE(test::all_close_f(read_vector<float>(result),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment