Commit df845963 authored by Robert Kimball's avatar Robert Kimball Committed by Scott Cyphers

add random init for all input types, not just float (#799)

* add random init for all input types, not just float

* remove debug
parent 2db236b7
......@@ -114,13 +114,90 @@ void print_times(const multimap<size_t, string>& timing)
}
}
static default_random_engine s_random_engine;
template <typename T>
void init_int_tv(shared_ptr<runtime::TensorView> tv, T min, T max)
{
uniform_int_distribution<T> dist(min, max);
std::vector<T> vec = read_vector<T>(tv);
for (T& element : vec)
{
element = dist(s_random_engine);
}
write_vector(tv, vec);
}
template <typename T>
void init_real_tv(shared_ptr<runtime::TensorView> tv, T min, T max)
{
uniform_real_distribution<T> dist(min, max);
std::vector<T> vec = read_vector<T>(tv);
for (T& element : vec)
{
element = dist(s_random_engine);
}
write_vector(tv, vec);
}
static void random_init(shared_ptr<runtime::TensorView> tv)
{
element::Type et = tv->get_tensor().get_element_type();
if (et == element::boolean)
{
init_int_tv<char>(tv, 0, 1);
}
else if (et == element::f32)
{
init_real_tv<float>(tv, -1, 1);
}
else if (et == element::f64)
{
init_real_tv<double>(tv, -1, 1);
}
else if (et == element::i8)
{
init_int_tv<int8_t>(tv, -1, 1);
}
else if (et == element::i16)
{
init_int_tv<int16_t>(tv, -1, 1);
}
else if (et == element::i32)
{
init_int_tv<int32_t>(tv, -1, 1);
}
else if (et == element::i64)
{
init_int_tv<int64_t>(tv, -1, 1);
}
else if (et == element::u8)
{
init_int_tv<uint8_t>(tv, 0, 1);
}
else if (et == element::u16)
{
init_int_tv<uint16_t>(tv, 0, 1);
}
else if (et == element::u32)
{
init_int_tv<uint32_t>(tv, 0, 1);
}
else if (et == element::u64)
{
init_int_tv<uint64_t>(tv, 0, 1);
}
else
{
throw runtime_error("unsupported type");
}
}
void run_benchmark(shared_ptr<Function> f,
const string& backend_name,
size_t iterations,
bool timing_detail)
{
test::Uniform<float> rng{-1, 1, 0};
stopwatch timer;
timer.start();
auto manager = runtime::Manager::get(backend_name);
......@@ -137,7 +214,7 @@ void run_benchmark(shared_ptr<Function> f,
{
auto tensor =
backend->make_primary_tensor_view(param->get_element_type(), param->get_shape());
rng.initialize(tensor);
random_init(tensor);
args.push_back(tensor);
}
vector<shared_ptr<runtime::TensorView>> results;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment