Unverified Commit 35646d4f authored by Fenglei's avatar Fenglei Committed by GitHub

Merge branch 'master' into tfl/gpu_framework_codegen

parents 2cd593aa 59bdd6ee
...@@ -70,6 +70,18 @@ public: ...@@ -70,6 +70,18 @@ public:
std::string generate_temporary_name(std::string prefix = "tempvar"); std::string generate_temporary_name(std::string prefix = "tempvar");
void block_begin()
{
*this << "{\n";
indent++;
}
void block_end()
{
indent--;
*this << "}\n";
}
private: private:
std::stringstream m_ss; std::stringstream m_ss;
bool m_pending_indent; bool m_pending_indent;
......
...@@ -2161,10 +2161,9 @@ namespace ngraph ...@@ -2161,10 +2161,9 @@ namespace ngraph
writer << "memory::dims " << var << "{" << dims << "};\n"; writer << "memory::dims " << var << "{" << dims << "};\n";
}; };
writer << "{\n"; writer.block_begin();
writer.indent++; writer << "try\n";
writer << "try {\n"; writer.block_begin();
writer.indent++;
writer << "engine cpu_engine = engine(engine::cpu, 0);\n"; writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
emit_memory_desc("data_desc", join(arg0_shape), elem_type, "nchw"); emit_memory_desc("data_desc", join(arg0_shape), elem_type, "nchw");
emit_memory_desc("delta_desc", join(arg1_shape), elem_type, "nchw"); emit_memory_desc("delta_desc", join(arg1_shape), elem_type, "nchw");
...@@ -2194,15 +2193,13 @@ namespace ngraph ...@@ -2194,15 +2193,13 @@ namespace ngraph
"result);\n" "result);\n"
"stream s = stream(stream::kind::eager);\n" "stream s = stream(stream::kind::eager);\n"
"s.submit({bwd_weights}).wait();\n"; "s.submit({bwd_weights}).wait();\n";
writer.indent--; writer.block_end();
writer << "} catch (const mkldnn::error& e) {\n"; writer << "catch (const mkldnn::error& e)\n";
writer.indent++; writer.block_begin();
writer << "throw ngraph::ngraph_error(\"MKLDNN ERROR (\" + std::to_string(" writer << "throw ngraph::ngraph_error(\"MKLDNN ERROR (\" + std::to_string("
"e.status) + \"): \" + e.message);\n"; "e.status) + \"): \" + e.message);\n";
writer.indent--; writer.block_end();
writer << "}\n"; writer.block_end();
writer.indent--;
writer << "}\n";
} }
else else
{ {
...@@ -2275,10 +2272,9 @@ namespace ngraph ...@@ -2275,10 +2272,9 @@ namespace ngraph
writer << "memory::dims " << var << "{" << dims << "};\n"; writer << "memory::dims " << var << "{" << dims << "};\n";
}; };
writer << "{\n"; writer.block_begin();
writer.indent++; writer << "try\n";
writer << "try {\n"; writer.block_begin();
writer.indent++;
writer << "engine cpu_engine = engine(engine::cpu, 0);\n"; writer << "engine cpu_engine = engine(engine::cpu, 0);\n";
emit_memory_desc("weight_desc", join(arg0_shape), elem_type, "oihw"); emit_memory_desc("weight_desc", join(arg0_shape), elem_type, "oihw");
emit_memory_desc("delta_desc", join(arg1_shape), elem_type, "nchw"); emit_memory_desc("delta_desc", join(arg1_shape), elem_type, "nchw");
...@@ -2307,15 +2303,13 @@ namespace ngraph ...@@ -2307,15 +2303,13 @@ namespace ngraph
"result);\n" "result);\n"
"stream s = stream(stream::kind::eager);\n" "stream s = stream(stream::kind::eager);\n"
"s.submit({bwd_data}).wait();\n"; "s.submit({bwd_data}).wait();\n";
writer.indent--; writer.block_end();
writer << "} catch (const mkldnn::error& e) {\n"; writer << "catch (const mkldnn::error& e)\n";
writer.indent++; writer.block_begin();
writer << "throw ngraph::ngraph_error(\"MKLDNN ERROR (\" + std::to_string(" writer << "throw ngraph::ngraph_error(\"MKLDNN ERROR (\" + std::to_string("
"e.status) + \"): \" + e.message);\n"; "e.status) + \"): \" + e.message);\n";
writer.indent--; writer.block_end();
writer << "}\n"; writer.block_end();
writer.indent--;
writer << "}\n";
} }
else else
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment