Commit 2ddcb3be authored by Tim Zerrell

Clamp explicitly in PlaidML quantize

parent c93a09a4
......@@ -181,9 +181,46 @@ void ngraph::runtime::plaidml::ImplQuantize::Apply()
}
builder::Elementwise Rounded{"Rounded", ""};
builder::Elementwise Clamped{"Clamped", ""};
builder::Elementwise O{"O", ""};
std::ostringstream formula;
std::string low_int;
// Inclusive bounds of the quantized output type. They are only computed and
// only read when the output type is at most 32 bits wide.
int64_t q_min = 0;
int64_t q_max = 0;
// Accumulates the elementwise expression that clamps "Uncast" into the
// representable range of the output type.
std::ostringstream clamp_formula;
if (type.size() > 4)
{
    // PlaidML doesn't support quantization clamping for types wider than 32 bits
    if (!type.is_signed())
    {
        // Still guard against negative values for unsigned outputs.
        clamp_formula << "Uncast < 0 ? 0 : Uncast";
    }
    else
    {
        clamp_formula << "Uncast";
    }
}
else
{
    // Shift a 64-bit one rather than a plain `int` literal: for 4-byte types
    // the shift count reaches 31 (signed) or 32 (unsigned), which overflows
    // or is undefined for a 32-bit `int` and produced wrong bounds.
    const size_t bit_width = 8 * type.size();
    if (type.is_signed())
    {
        q_max = (int64_t{1} << (bit_width - 1)) - 1;
        q_min = -q_max - 1;
    }
    else
    {
        q_max = (int64_t{1} << bit_width) - 1;
        q_min = 0;
    }
    clamp_formula << "Uncast < " << q_min << " ? " << q_min << " : "
                  << "(Uncast > " << q_max << " ? " << q_max << " : Uncast)";
}
Clamped.set_rhs(clamp_formula.str());
std::ostringstream round_formula;
std::string lower_rounded_int;
switch (round_mode)
{
case ngraph::op::Quantize::RoundMode::ROUND_DOWN: Rounded.set_rhs("floor(Frac)"); break;
......@@ -208,16 +245,16 @@ void ngraph::runtime::plaidml::ImplQuantize::Apply()
break;
case ngraph::op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN:
// This is ugly, but it produces correct output
low_int = cast_to_output_type("ceil(Frac - 0.5)");
formula << "2 * (" << low_int << " / 2) == " << low_int
<< " ? ceil(Frac - 0.5) : floor(Frac + 0.5)";
Rounded.set_rhs(formula.str());
lower_rounded_int = cast_to_output_type("ceil(Frac - 0.5)");
round_formula << "2 * (" << lower_rounded_int << " / 2) == " << lower_rounded_int
<< " ? ceil(Frac - 0.5) : floor(Frac + 0.5)";
Rounded.set_rhs(round_formula.str());
break;
default:
throw std::runtime_error("Requested quantize round mode not yet implemented in PlaidML");
}
O.set_rhs(cast_to_output_type("Uncast"));
O.set_rhs(cast_to_output_type("Clamped"));
builder::ContractionInput scale_recip_input{"SRecip"};
builder::ContractionInput zp_input{"Z"};
......@@ -247,6 +284,7 @@ void ngraph::runtime::plaidml::ImplQuantize::Apply()
.set_lhs(
builder::ContractionInput{"Rounded"}.add_indices("i", 0, input_shape.size()))
.set_rhs(zp_input))
.add(Clamped)
.add(O);
set_output(f.finalize());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment