Commit c2b0b066 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

Faster ReverseSequence CPU implementation (#981)

* clean up, rename

*  remove commented code; add comments
parent 95ab987e
......@@ -2828,15 +2828,76 @@ namespace ngraph
void CPU_Emitter::EMITTER_DECL(ngraph::op::ReverseSequence)
{
auto rs = static_cast<const ngraph::op::ReverseSequence*>(node);
string iv_prefix{"i"};
size_t ibi = rs->get_batch_axis();
string bi = iv_prefix + std::to_string(ibi);
string si = iv_prefix + std::to_string(rs->get_sequence_axis());
auto arg_shape = args[0].get_shape();
writer << "reference::reverse_sequence<" << out[0].get_type() << ","
<< args[1].get_type() << ">(" << args[0].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(arg_shape) << "},\n";
writer << " " << rs->get_batch_axis() << ",\n";
writer << " " << rs->get_sequence_axis() << ",\n";
writer << " " << args[1].get_name() << ");\n";
//iterate over seq_lengths make sure indices aren't out of bounds and normalize
writer << "std::vector<size_t> norm_seq_lengths (" << arg_shape.at(ibi) << ");\n";
writer << emit_for_lt(iv_prefix, ibi, arg_shape.at(ibi));
writer.block_begin();
writer << "auto seq_len = static_cast<size_t>(" << args[1].get_name() << "[" << bi
<< "]);\n";
writer << "if (seq_len > " << arg_shape.at(rs->get_sequence_axis()) << ")\n";
writer.block_begin();
writer << "throw \"One of the elements of sequence lengths is greater than "
"sequence axis\";\n";
writer.block_end();
writer << "if (seq_len == 0)\n";
writer.block_begin();
writer << "norm_seq_lengths[" << bi << "] = 1;\n";
writer.block_end();
writer << " else \n";
writer.block_begin();
writer << "norm_seq_lengths[" << bi << "] = seq_len;\n";
writer.block_end();
writer.block_end();
std::vector<std::string> sdims;
for (auto d : arg_shape)
{
sdims.push_back(std::to_string(d));
}
//convert input and output into multidimensional arrays
auto isdims = emit_indices(sdims);
writer << args[0].get_type() << "(&src)" << isdims << " = *reinterpret_cast<"
<< args[0].get_type() << " (*)" << isdims << ">(" << args[0].get_name()
<< ");\n";
writer << args[0].get_type() << "(&dst)" << isdims << " = *reinterpret_cast<"
<< args[0].get_type() << " (*)" << isdims << ">(" << out[0].get_name()
<< ");\n";
//reverse sequence
std::vector<std::string> source_indices;
for (size_t i = 0; i < arg_shape.size(); i++)
{
writer << emit_for_lt(iv_prefix, i, arg_shape.at(i));
writer.block_begin();
source_indices.push_back(iv_prefix + std::to_string(i));
}
writer << "auto seq_len = norm_seq_lengths[" << bi << "];\n";
writer << "size_t sequence_index = " << si << " < seq_len "
<< "? seq_len - " << si << " - 1"
<< ": " << si << ";\n";
std::vector<std::string> output_indices(source_indices);
output_indices.at(rs->get_sequence_axis()) = "sequence_index";
writer << "dst" << emit_indices(output_indices) << " = "
<< "src" << emit_indices(source_indices) << ";\n";
for (size_t i = 0; i < arg_shape.size(); i++)
{
writer.block_end();
}
}
template <>
......@@ -3832,8 +3893,27 @@ static string format_name(const string& name)
return rc;
}
string runtime::cpu::CPU_Emitter::emit_vector(const runtime::cpu::TensorViewWrapper& tvi,
const string& name)
std::string runtime::cpu::CPU_Emitter::emit_indices(const std::vector<std::string> indices)
{
stringstream ss;
for (auto i : indices)
{
ss << "[" << i << "]";
}
return ss.str();
}
std::string
runtime::cpu::CPU_Emitter::emit_for_lt(const std::string& prefix, size_t index, size_t to)
{
stringstream ss;
auto iv = prefix + std::to_string(index);
ss << "for (size_t " << iv << " = 0 ; " << iv << " < " << to << "; " << iv << "++)\n";
return ss.str();
}
std::string runtime::cpu::CPU_Emitter::emit_vector(const runtime::cpu::TensorViewWrapper& tvi,
const string& name)
{
stringstream ss;
......
......@@ -72,6 +72,9 @@ namespace ngraph
const std::string& name = "");
static std::string emit_matrix(const TensorViewWrapper&,
const std::string& name = "");
static std::string emit_for_lt(const std::string& prefix, size_t index, size_t to);
static std::string emit_indices(const std::vector<std::string> indices);
};
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment