Commit 0c4b5917 authored by gaurides's avatar gaurides Committed by Scott Cyphers

Codegen support for Dropout (#3075)

* Codegen support for Dropout

* Different implementation
parent 90ca4d87
......@@ -4011,7 +4011,56 @@ namespace ngraph
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Dropout)
{
throw ngraph_error("Not yet implemented");
auto dropout = static_cast<const ngraph::op::Dropout*>(node);
size_t ncr = ngraph::runtime::cpu::executor::GetCPUExecutor().get_num_cores();
writer.block_begin();
writer << "bool training = static_cast<bool>(" << args[1].get_name() << "[0]);\n";
writer << "bool use_seed = " << to_string(dropout->get_use_seed()) << ";\n";
writer << "int32_t seed = use_seed ? " << to_string(dropout->get_seed())
<< " : rand();\n";
writer << "double keep_prob = static_cast<double>(" << args[4].get_name()
<< "[0]);\n";
writer << "size_t count = " << args[0].get_size() << ";\n";
writer << "size_t nthr = " << to_string(ncr) << ";\n";
//writer << "size_t nthr = " << to_string(ngraph::runtime::cpu::executor::GetCPUExecutor().get_num_cores()) << ";\n";
writer << "size_t chunk_size = (count + nthr - 1) / nthr;\n";
writer << "std::vector<std::minstd_rand> vmsr(nthr);\n";
writer << "for (size_t i = 0; i < nthr; i++)\n\
{\n\
std::minstd_rand msr;\n\
msr.seed(seed+i);\n\
vmsr[i] = msr;\n\
}\n";
writer << "double dropout_prob = 1 - keep_prob;\n";
writer << "std::uniform_real_distribution<> gen(0, 1);\n";
writer << "#pragma omp parallel num_threads(nthr)\n";
writer << "{\n";
writer << "size_t tid = omp_get_thread_num();\n";
writer << "std::minstd_rand msr;\n msr.seed(seed+tid);\n";
writer << "size_t idx_start = tid * chunk_size;\n";
writer << "size_t idx_end = std::min(idx_start + chunk_size, count);\n";
writer << "for (size_t i = idx_start; i < idx_end; i++)\n";
writer << "{\n";
writer << " //out[i] = training ? static_cast<T>(bd(gen)) : "
"static_cast<float>(1);\n";
writer << " //out0[i] = training ? input[i] : static_cast<float>(1);\n";
writer << " if (static_cast<float>(gen(msr)) < dropout_prob)\n";
writer << " {\n";
writer << " " << out[0].get_name() << "[i] = 0;\n";
writer << " " << out[1].get_name() << "[i] = 0;\n";
writer << " }\n";
writer << " else\n";
writer << " {\n";
writer << " " << out[1].get_name() << "[i] = 1;\n";
writer << " " << out[0].get_name() << "[i] = " << args[0].get_name()
<< "[i] / static_cast<float>(keep_prob);\n";
writer << " }\n";
writer << "}\n"; // for loop ends
writer << "}\n"; //#pragma ends
writer.block_end();
}
template <>
......
......@@ -270,7 +270,8 @@ namespace ngraph
size_t nelems,
bool training,
const double value,
const std::vector<std::minstd_rand>& vmsr);
const std::vector<std::minstd_rand>& vmsr,
const bool use_seed);
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment