Unverified Commit 2775b0bf authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

make CPU emit functions static so they can be called by other backends (#376)

parent c682fbf4
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -23,7 +23,8 @@
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
#define EMITTER_DECL(E) \
E(const ngraph::Node* n, \
E(codegen::CodeWriter& writer, \
const ngraph::Node* n, \
const std::vector<ngraph::runtime::cpu::TensorViewWrapper>& args, \
const std::vector<ngraph::runtime::cpu::TensorViewWrapper>& out)
......@@ -35,79 +36,67 @@ namespace ngraph
{
class CPU_Emitter
{
protected:
codegen::CodeWriter m_out;
bool m_use_ref_kernels;
public:
CPU_Emitter()
: m_out()
, m_use_ref_kernels(std::getenv("NGRAPH_CPU_USE_REF_KERNELS") != nullptr)
{
}
std::string get_code() { return m_out.get_code(); }
codegen::CodeWriter& get_code_writer() { return m_out; }
void EMITTER_DECL(EmitNop);
void EMITTER_DECL(EmitAdd);
void EMITTER_DECL(EmitDot);
void EMITTER_DECL(EmitMultiply);
void EMITTER_DECL(EmitGetOutputElement);
void EMITTER_DECL(EmitXLAGetTupleElement);
void EMITTER_DECL(EmitTuple);
void EMITTER_DECL(EmitAbs);
void EMITTER_DECL(EmitConcat);
void EMITTER_DECL(EmitDivide);
void EMITTER_DECL(EmitEqual);
void EMITTER_DECL(EmitGreater);
void EMITTER_DECL(EmitGreaterEq);
void EMITTER_DECL(EmitLess);
void EMITTER_DECL(EmitLessEq);
void EMITTER_DECL(EmitLog);
void EMITTER_DECL(EmitMaximum);
void EMITTER_DECL(EmitMinimum);
void EMITTER_DECL(EmitNegative);
void EMITTER_DECL(EmitNotEqual);
void EMITTER_DECL(EmitSelect);
void EMITTER_DECL(EmitSubtract);
void EMITTER_DECL(EmitBroadcast);
void EMITTER_DECL(EmitConvert);
void EMITTER_DECL(EmitConstant);
void EMITTER_DECL(EmitReshape);
void EMITTER_DECL(EmitFunctionCall);
void EMITTER_DECL(EmitReduce);
void EMITTER_DECL(EmitSign);
void EMITTER_DECL(EmitSlice);
void EMITTER_DECL(EmitSum);
void EMITTER_DECL(EmitExp);
void EMITTER_DECL(EmitSin);
void EMITTER_DECL(EmitSinh);
void EMITTER_DECL(EmitCos);
void EMITTER_DECL(EmitCosh);
void EMITTER_DECL(EmitTan);
void EMITTER_DECL(EmitTanh);
void EMITTER_DECL(EmitAsin);
void EMITTER_DECL(EmitAcos);
void EMITTER_DECL(EmitAtan);
void EMITTER_DECL(EmitPower);
void EMITTER_DECL(EmitReplaceSlice);
void EMITTER_DECL(EmitOneHot);
void EMITTER_DECL(EmitFloor);
void EMITTER_DECL(EmitCeiling);
void EMITTER_DECL(EmitSqrt);
void EMITTER_DECL(EmitConvolution);
void EMITTER_DECL(EmitNot);
void EMITTER_DECL(EmitMaxPool);
void EMITTER_DECL(EmitReverse);
void EMITTER_DECL(EmitReduceWindow);
static void EMITTER_DECL(EmitNop);
static void EMITTER_DECL(EmitAdd);
static void EMITTER_DECL(EmitDot);
static void EMITTER_DECL(EmitMultiply);
static void EMITTER_DECL(EmitGetOutputElement);
static void EMITTER_DECL(EmitXLAGetTupleElement);
static void EMITTER_DECL(EmitTuple);
static void EMITTER_DECL(EmitAbs);
static void EMITTER_DECL(EmitConcat);
static void EMITTER_DECL(EmitDivide);
static void EMITTER_DECL(EmitEqual);
static void EMITTER_DECL(EmitGreater);
static void EMITTER_DECL(EmitGreaterEq);
static void EMITTER_DECL(EmitLess);
static void EMITTER_DECL(EmitLessEq);
static void EMITTER_DECL(EmitLog);
static void EMITTER_DECL(EmitMaximum);
static void EMITTER_DECL(EmitMinimum);
static void EMITTER_DECL(EmitNegative);
static void EMITTER_DECL(EmitNotEqual);
static void EMITTER_DECL(EmitSelect);
static void EMITTER_DECL(EmitSubtract);
static void EMITTER_DECL(EmitBroadcast);
static void EMITTER_DECL(EmitConvert);
static void EMITTER_DECL(EmitConstant);
static void EMITTER_DECL(EmitReshape);
static void EMITTER_DECL(EmitFunctionCall);
static void EMITTER_DECL(EmitReduce);
static void EMITTER_DECL(EmitSign);
static void EMITTER_DECL(EmitSlice);
static void EMITTER_DECL(EmitSum);
static void EMITTER_DECL(EmitExp);
static void EMITTER_DECL(EmitSin);
static void EMITTER_DECL(EmitSinh);
static void EMITTER_DECL(EmitCos);
static void EMITTER_DECL(EmitCosh);
static void EMITTER_DECL(EmitTan);
static void EMITTER_DECL(EmitTanh);
static void EMITTER_DECL(EmitAsin);
static void EMITTER_DECL(EmitAcos);
static void EMITTER_DECL(EmitAtan);
static void EMITTER_DECL(EmitPower);
static void EMITTER_DECL(EmitReplaceSlice);
static void EMITTER_DECL(EmitOneHot);
static void EMITTER_DECL(EmitFloor);
static void EMITTER_DECL(EmitCeiling);
static void EMITTER_DECL(EmitSqrt);
static void EMITTER_DECL(EmitConvolution);
static void EMITTER_DECL(EmitNot);
static void EMITTER_DECL(EmitMaxPool);
static void EMITTER_DECL(EmitReverse);
static void EMITTER_DECL(EmitReduceWindow);
private:
void generate_call(const std::vector<TensorViewWrapper>& args,
const std::vector<TensorViewWrapper>& out,
std::shared_ptr<Function> function);
std::string emit_vector(const TensorViewWrapper&, const std::string& name = "");
std::string emit_array1d(const TensorViewWrapper&, const std::string& name = "");
std::string emit_matrix(const TensorViewWrapper&, const std::string& name = "");
static std::string emit_vector(const TensorViewWrapper&,
const std::string& name = "");
static std::string emit_array1d(const TensorViewWrapper&,
const std::string& name = "");
static std::string emit_matrix(const TensorViewWrapper&,
const std::string& name = "");
};
}
}
......
......@@ -216,8 +216,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
pass_manager.register_pass<pass::DumpSorted>(dump_filename);
pass_manager.run_passes(m_function);
CPU_Emitter emitter;
codegen::CodeWriter& writer = emitter.get_code_writer();
codegen::CodeWriter writer;
writer +=
R"(// Generated by the NGraph CPU backend
......@@ -432,7 +431,7 @@ using namespace ngraph::runtime;
writer << "\n)\n";
writer << "{\n";
writer.indent++;
handler->second(&emitter, &n, in, out);
handler->second(writer, &n, in, out);
writer.indent--;
writer << "}\n";
}
......@@ -632,7 +631,7 @@ using namespace ngraph::runtime;
}
if (func_name.empty())
{
handler->second(&emitter, node.get(), in, out);
handler->second(writer, node.get(), in, out);
}
else
{
......
......@@ -38,7 +38,7 @@ namespace ngraph
class CPU_Emitter;
class CPU_CallFrame;
using OpFunction = std::function<void(CPU_Emitter*,
using OpFunction = std::function<void(codegen::CodeWriter&,
const ngraph::Node*,
const std::vector<TensorViewWrapper>& inputs,
const std::vector<TensorViewWrapper>& outputs)>;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment