Commit 59bdd6ee authored by Louis Feng's avatar Louis Feng Committed by Robert Kimball

Adding block_begin and block_end to CodeWriter (#524)

* added block_begin and block_end util to CodeWriter.

* refactored convolution backprop to use block_begin() and block_end()
parent 4e29c153
...@@ -70,6 +70,18 @@ public: ...@@ -70,6 +70,18 @@ public:
std::string generate_temporary_name(std::string prefix = "tempvar"); std::string generate_temporary_name(std::string prefix = "tempvar");
void block_begin()
{
*this << "{\n";
indent++;
}
void block_end()
{
indent--;
*this << "}\n";
}
private: private:
std::stringstream m_ss; std::stringstream m_ss;
bool m_pending_indent; bool m_pending_indent;
......
...@@ -2161,10 +2161,9 @@ namespace ngraph ...@@ -2161,10 +2161,9 @@ namespace ngraph
writer << "memory::dims " << var << "{" << dims << "};\n"; writer << "memory::dims " << var << "{" << dims << "};\n";
}; };
writer << "{\n"; writer.block_begin();
writer.indent++; writer << "try\n";
writer << "try {\n"; writer.block_begin();
writer.indent++;
writer << "engine cpu_engine = engine(engine::cpu, 0);\n"; writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
emit_memory_desc("data_desc", join(arg0_shape), elem_type, "nchw"); emit_memory_desc("data_desc", join(arg0_shape), elem_type, "nchw");
emit_memory_desc("delta_desc", join(arg1_shape), elem_type, "nchw"); emit_memory_desc("delta_desc", join(arg1_shape), elem_type, "nchw");
...@@ -2194,15 +2193,13 @@ namespace ngraph ...@@ -2194,15 +2193,13 @@ namespace ngraph
"result);\n" "result);\n"
"stream s = stream(stream::kind::eager);\n" "stream s = stream(stream::kind::eager);\n"
"s.submit({bwd_weights}).wait();\n"; "s.submit({bwd_weights}).wait();\n";
writer.indent--; writer.block_end();
writer << "} catch (const mkldnn::error& e) {\n"; writer << "catch (const mkldnn::error& e)\n";
writer.indent++; writer.block_begin();
writer << "throw ngraph::ngraph_error(\"MKLDNN ERROR (\" + std::to_string(" writer << "throw ngraph::ngraph_error(\"MKLDNN ERROR (\" + std::to_string("
"e.status) + \"): \" + e.message);\n"; "e.status) + \"): \" + e.message);\n";
writer.indent--; writer.block_end();
writer << "}\n"; writer.block_end();
writer.indent--;
writer << "}\n";
} }
else else
{ {
...@@ -2275,10 +2272,9 @@ namespace ngraph ...@@ -2275,10 +2272,9 @@ namespace ngraph
writer << "memory::dims " << var << "{" << dims << "};\n"; writer << "memory::dims " << var << "{" << dims << "};\n";
}; };
writer << "{\n"; writer.block_begin();
writer.indent++; writer << "try\n";
writer << "try {\n"; writer.block_begin();
writer.indent++;
writer << "engine cpu_engine = engine(engine::cpu, 0);\n"; writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
emit_memory_desc("weight_desc", join(arg0_shape), elem_type, "oihw"); emit_memory_desc("weight_desc", join(arg0_shape), elem_type, "oihw");
emit_memory_desc("delta_desc", join(arg1_shape), elem_type, "nchw"); emit_memory_desc("delta_desc", join(arg1_shape), elem_type, "nchw");
...@@ -2307,15 +2303,13 @@ namespace ngraph ...@@ -2307,15 +2303,13 @@ namespace ngraph
"result);\n" "result);\n"
"stream s = stream(stream::kind::eager);\n" "stream s = stream(stream::kind::eager);\n"
"s.submit({bwd_data}).wait();\n"; "s.submit({bwd_data}).wait();\n";
writer.indent--; writer.block_end();
writer << "} catch (const mkldnn::error& e) {\n"; writer << "catch (const mkldnn::error& e)\n";
writer.indent++; writer.block_begin();
writer << "throw ngraph::ngraph_error(\"MKLDNN ERROR (\" + std::to_string(" writer << "throw ngraph::ngraph_error(\"MKLDNN ERROR (\" + std::to_string("
"e.status) + \"): \" + e.message);\n"; "e.status) + \"): \" + e.message);\n";
writer.indent--; writer.block_end();
writer << "}\n"; writer.block_end();
writer.indent--;
writer << "}\n";
} }
else else
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment