Unverified Commit d77ace68 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Change CPU backend to use PassConfig for codegen enable (#2826)

* Change CPU backend to use PassConfig instead of backdoor for codegen

* move is_codegen decl to output dex only flag

* revert change
parent bf865efd
......@@ -23,8 +23,7 @@ using namespace std;
using namespace ngraph;
// TODO: Add file-based configuration support
pass::PassConfig::PassConfig(pass::CompilationMode mode)
: m_compilation_mode(mode)
pass::PassConfig::PassConfig()
{
/**
* Parses the semi-colon separated environment string passed through NGRAPH_PASS_ENABLES
......@@ -81,30 +80,32 @@ pass::PassConfig::PassConfig(pass::CompilationMode mode)
}
}
void pass::PassConfig::set_pass_enable(string name, bool enable)
void pass::PassConfig::set_pass_enable(const string& name, bool enable)
{
m_pass_enables[name] = enable;
}
bool pass::PassConfig::get_pass_enable(string name)
bool pass::PassConfig::get_pass_enable(const string& name) const
{
if (m_pass_enables.find(name) == m_pass_enables.end())
auto it = m_pass_enables.find(name);
if (it != m_pass_enables.end())
{
return false;
return it->second;
}
return m_pass_enables[name];
return false;
}
void pass::PassConfig::set_pass_attribute(string name, bool enable)
void pass::PassConfig::set_pass_attribute(const string& name, bool enable)
{
m_pass_attributes[name] = enable;
}
bool pass::PassConfig::get_pass_attribute(string name)
bool pass::PassConfig::get_pass_attribute(const string& name) const
{
if (m_pass_attributes.find(name) == m_pass_attributes.end())
auto it = m_pass_attributes.find(name);
if (it != m_pass_attributes.end())
{
return false;
return it->second;
}
return m_pass_attributes[name];
return false;
}
......@@ -23,27 +23,21 @@ namespace ngraph
namespace pass
{
class PassConfig;
enum class CompilationMode
{
DEX, // Fast compilation using precompiled kernels
CODEGEN // Slower compilation for potentially faster code
};
}
}
class ngraph::pass::PassConfig
{
public:
PassConfig(CompilationMode mode = CompilationMode::DEX);
const std::map<std::string, bool>& get_enables() { return m_pass_enables; }
void set_pass_enable(std::string name, bool enable);
bool get_pass_enable(std::string name);
const std::map<std::string, bool>& get_pass_attributes() { return m_pass_attributes; }
void set_pass_attribute(std::string name, bool enable);
bool get_pass_attribute(std::string name);
CompilationMode get_compilation_mode() const { return m_compilation_mode; }
PassConfig();
const std::map<std::string, bool>& get_enables() const { return m_pass_enables; }
void set_pass_enable(const std::string& name, bool enable);
bool get_pass_enable(const std::string& name) const;
const std::map<std::string, bool>& get_pass_attributes() const { return m_pass_attributes; }
void set_pass_attribute(const std::string& name, bool enable);
bool get_pass_attribute(const std::string& name) const;
private:
std::map<std::string, bool> m_pass_enables;
std::map<std::string, bool> m_pass_attributes;
CompilationMode m_compilation_mode;
};
......@@ -1776,18 +1776,29 @@ size_t runtime::cpu::CPU_ExternalFunction::get_buffer_index(const std::string& n
}
}
bool runtime::cpu::CPU_ExternalFunction::is_codegen(const ngraph::pass::PassConfig& pc)
{
auto attrs = pc.get_pass_attributes();
auto it = attrs.find("CODEGEN");
if (it != attrs.end())
{
return it->second;
}
return false;
}
shared_ptr<ngraph::runtime::cpu::CPU_CallFrame>
runtime::cpu::CPU_ExternalFunction::make_call_frame(ngraph::pass::PassConfig& pass_config)
{
#if defined(NGRAPH_DEX_ONLY)
if (pass_config.get_compilation_mode() == ngraph::pass::CompilationMode::CODEGEN)
if (is_codegen(pass_config))
{
NGRAPH_WARN << "CPU Backend: Requested unsupported compilation mode (CODEGEN). Falling "
"back to DEX instead";
}
#else
// Override DEX if pass_config requests CODEGEN
if (pass_config.get_compilation_mode() == ngraph::pass::CompilationMode::CODEGEN)
if (is_codegen(pass_config))
{
m_direct_execution = false;
}
......
......@@ -225,6 +225,7 @@ namespace ngraph
// so they don't get freed before we are done with them
std::vector<std::shared_ptr<Node>> m_active_constants;
#endif
static bool is_codegen(const ngraph::pass::PassConfig& pc);
std::unordered_set<descriptor::Tensor*>&
get_tensor_set(descriptor::Tensor* output_tensor);
......
......@@ -42,7 +42,8 @@ TEST(cpu_codegen, abc)
copy_data(b, test::NDArray<float, 2>({{5, 6}, {7, 8}}).get_vector());
copy_data(c, test::NDArray<float, 2>({{9, 10}, {11, 12}}).get_vector());
ngraph::pass::PassConfig pass_config{ngraph::pass::CompilationMode::CODEGEN};
ngraph::pass::PassConfig pass_config;
pass_config.set_pass_attribute("CODEGEN", true);
auto handle = backend->compile(f, pass_config);
handle->call_with_validate({result}, {a, b, c});
EXPECT_TRUE(test::all_close_f(read_vector<float>(result),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment