Unverified Commit 2775b0bf authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

make CPU emit functions static so they can be called by other backends (#376)

parent c682fbf4
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -23,7 +23,8 @@ ...@@ -23,7 +23,8 @@
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp" #include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
#define EMITTER_DECL(E) \ #define EMITTER_DECL(E) \
E(const ngraph::Node* n, \ E(codegen::CodeWriter& writer, \
const ngraph::Node* n, \
const std::vector<ngraph::runtime::cpu::TensorViewWrapper>& args, \ const std::vector<ngraph::runtime::cpu::TensorViewWrapper>& args, \
const std::vector<ngraph::runtime::cpu::TensorViewWrapper>& out) const std::vector<ngraph::runtime::cpu::TensorViewWrapper>& out)
...@@ -35,79 +36,67 @@ namespace ngraph ...@@ -35,79 +36,67 @@ namespace ngraph
{ {
class CPU_Emitter class CPU_Emitter
{ {
protected:
codegen::CodeWriter m_out;
bool m_use_ref_kernels;
public: public:
CPU_Emitter() static void EMITTER_DECL(EmitNop);
: m_out() static void EMITTER_DECL(EmitAdd);
, m_use_ref_kernels(std::getenv("NGRAPH_CPU_USE_REF_KERNELS") != nullptr) static void EMITTER_DECL(EmitDot);
{ static void EMITTER_DECL(EmitMultiply);
} static void EMITTER_DECL(EmitGetOutputElement);
std::string get_code() { return m_out.get_code(); } static void EMITTER_DECL(EmitXLAGetTupleElement);
codegen::CodeWriter& get_code_writer() { return m_out; } static void EMITTER_DECL(EmitTuple);
void EMITTER_DECL(EmitNop); static void EMITTER_DECL(EmitAbs);
void EMITTER_DECL(EmitAdd); static void EMITTER_DECL(EmitConcat);
void EMITTER_DECL(EmitDot); static void EMITTER_DECL(EmitDivide);
void EMITTER_DECL(EmitMultiply); static void EMITTER_DECL(EmitEqual);
void EMITTER_DECL(EmitGetOutputElement); static void EMITTER_DECL(EmitGreater);
void EMITTER_DECL(EmitXLAGetTupleElement); static void EMITTER_DECL(EmitGreaterEq);
void EMITTER_DECL(EmitTuple); static void EMITTER_DECL(EmitLess);
void EMITTER_DECL(EmitAbs); static void EMITTER_DECL(EmitLessEq);
void EMITTER_DECL(EmitConcat); static void EMITTER_DECL(EmitLog);
void EMITTER_DECL(EmitDivide); static void EMITTER_DECL(EmitMaximum);
void EMITTER_DECL(EmitEqual); static void EMITTER_DECL(EmitMinimum);
void EMITTER_DECL(EmitGreater); static void EMITTER_DECL(EmitNegative);
void EMITTER_DECL(EmitGreaterEq); static void EMITTER_DECL(EmitNotEqual);
void EMITTER_DECL(EmitLess); static void EMITTER_DECL(EmitSelect);
void EMITTER_DECL(EmitLessEq); static void EMITTER_DECL(EmitSubtract);
void EMITTER_DECL(EmitLog); static void EMITTER_DECL(EmitBroadcast);
void EMITTER_DECL(EmitMaximum); static void EMITTER_DECL(EmitConvert);
void EMITTER_DECL(EmitMinimum); static void EMITTER_DECL(EmitConstant);
void EMITTER_DECL(EmitNegative); static void EMITTER_DECL(EmitReshape);
void EMITTER_DECL(EmitNotEqual); static void EMITTER_DECL(EmitFunctionCall);
void EMITTER_DECL(EmitSelect); static void EMITTER_DECL(EmitReduce);
void EMITTER_DECL(EmitSubtract); static void EMITTER_DECL(EmitSign);
void EMITTER_DECL(EmitBroadcast); static void EMITTER_DECL(EmitSlice);
void EMITTER_DECL(EmitConvert); static void EMITTER_DECL(EmitSum);
void EMITTER_DECL(EmitConstant); static void EMITTER_DECL(EmitExp);
void EMITTER_DECL(EmitReshape); static void EMITTER_DECL(EmitSin);
void EMITTER_DECL(EmitFunctionCall); static void EMITTER_DECL(EmitSinh);
void EMITTER_DECL(EmitReduce); static void EMITTER_DECL(EmitCos);
void EMITTER_DECL(EmitSign); static void EMITTER_DECL(EmitCosh);
void EMITTER_DECL(EmitSlice); static void EMITTER_DECL(EmitTan);
void EMITTER_DECL(EmitSum); static void EMITTER_DECL(EmitTanh);
void EMITTER_DECL(EmitExp); static void EMITTER_DECL(EmitAsin);
void EMITTER_DECL(EmitSin); static void EMITTER_DECL(EmitAcos);
void EMITTER_DECL(EmitSinh); static void EMITTER_DECL(EmitAtan);
void EMITTER_DECL(EmitCos); static void EMITTER_DECL(EmitPower);
void EMITTER_DECL(EmitCosh); static void EMITTER_DECL(EmitReplaceSlice);
void EMITTER_DECL(EmitTan); static void EMITTER_DECL(EmitOneHot);
void EMITTER_DECL(EmitTanh); static void EMITTER_DECL(EmitFloor);
void EMITTER_DECL(EmitAsin); static void EMITTER_DECL(EmitCeiling);
void EMITTER_DECL(EmitAcos); static void EMITTER_DECL(EmitSqrt);
void EMITTER_DECL(EmitAtan); static void EMITTER_DECL(EmitConvolution);
void EMITTER_DECL(EmitPower); static void EMITTER_DECL(EmitNot);
void EMITTER_DECL(EmitReplaceSlice); static void EMITTER_DECL(EmitMaxPool);
void EMITTER_DECL(EmitOneHot); static void EMITTER_DECL(EmitReverse);
void EMITTER_DECL(EmitFloor); static void EMITTER_DECL(EmitReduceWindow);
void EMITTER_DECL(EmitCeiling);
void EMITTER_DECL(EmitSqrt);
void EMITTER_DECL(EmitConvolution);
void EMITTER_DECL(EmitNot);
void EMITTER_DECL(EmitMaxPool);
void EMITTER_DECL(EmitReverse);
void EMITTER_DECL(EmitReduceWindow);
private: private:
void generate_call(const std::vector<TensorViewWrapper>& args, static std::string emit_vector(const TensorViewWrapper&,
const std::vector<TensorViewWrapper>& out, const std::string& name = "");
std::shared_ptr<Function> function); static std::string emit_array1d(const TensorViewWrapper&,
const std::string& name = "");
std::string emit_vector(const TensorViewWrapper&, const std::string& name = ""); static std::string emit_matrix(const TensorViewWrapper&,
std::string emit_array1d(const TensorViewWrapper&, const std::string& name = ""); const std::string& name = "");
std::string emit_matrix(const TensorViewWrapper&, const std::string& name = "");
}; };
} }
} }
......
...@@ -216,8 +216,7 @@ void runtime::cpu::CPU_ExternalFunction::compile() ...@@ -216,8 +216,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
pass_manager.register_pass<pass::DumpSorted>(dump_filename); pass_manager.register_pass<pass::DumpSorted>(dump_filename);
pass_manager.run_passes(m_function); pass_manager.run_passes(m_function);
CPU_Emitter emitter; codegen::CodeWriter writer;
codegen::CodeWriter& writer = emitter.get_code_writer();
writer += writer +=
R"(// Generated by the NGraph CPU backend R"(// Generated by the NGraph CPU backend
...@@ -432,7 +431,7 @@ using namespace ngraph::runtime; ...@@ -432,7 +431,7 @@ using namespace ngraph::runtime;
writer << "\n)\n"; writer << "\n)\n";
writer << "{\n"; writer << "{\n";
writer.indent++; writer.indent++;
handler->second(&emitter, &n, in, out); handler->second(writer, &n, in, out);
writer.indent--; writer.indent--;
writer << "}\n"; writer << "}\n";
} }
...@@ -632,7 +631,7 @@ using namespace ngraph::runtime; ...@@ -632,7 +631,7 @@ using namespace ngraph::runtime;
} }
if (func_name.empty()) if (func_name.empty())
{ {
handler->second(&emitter, node.get(), in, out); handler->second(writer, node.get(), in, out);
} }
else else
{ {
......
...@@ -38,7 +38,7 @@ namespace ngraph ...@@ -38,7 +38,7 @@ namespace ngraph
class CPU_Emitter; class CPU_Emitter;
class CPU_CallFrame; class CPU_CallFrame;
using OpFunction = std::function<void(CPU_Emitter*, using OpFunction = std::function<void(codegen::CodeWriter&,
const ngraph::Node*, const ngraph::Node*,
const std::vector<TensorViewWrapper>& inputs, const std::vector<TensorViewWrapper>& inputs,
const std::vector<TensorViewWrapper>& outputs)>; const std::vector<TensorViewWrapper>& outputs)>;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment