Unverified Commit 02559274 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge branch 'master' into bob/backend_api3

parents d607c6f4 13b4966b
......@@ -325,6 +325,12 @@ namespace ngraph
max_freezed_output_conv_1,
min_freezed_output_conv_2,
max_freezed_output_conv_2);
if (output_et == element::u8)
{
// Need to multiply by two to account for u8 requantization_scale
auto two = make_constant(element::f32, sum_scale->get_shape(), 2.0f);
sum_scale = two * sum_scale;
}
if (bias->get_element_type() != element::i32)
{
......
......@@ -548,3 +548,19 @@ bool ngraph::is_valid_rank(const std::shared_ptr<Node>& node, std::vector<size_t
}
return false;
}
bool ngraph::compare_constants(const std::shared_ptr<Node>& n1, const std::shared_ptr<Node>& n2)
{
if (!(n1->is_constant() && n2->is_constant()))
{
return false;
}
if (static_pointer_cast<op::Constant>(n1)->get_value_strings() !=
static_pointer_cast<op::Constant>(n2)->get_value_strings())
{
return false;
}
return true;
}
......@@ -309,6 +309,8 @@ namespace ngraph
bool is_one(std::shared_ptr<Node> reduce_constant);
bool compare_constants(const std::shared_ptr<Node>& n1, const std::shared_ptr<Node>& n2);
// Returns true if `node` is live in the graph i.e. a result op
// transitively uses this `node`
bool is_used(Node* node);
......
......@@ -334,7 +334,7 @@ namespace ngraph
}
if (old_pops.kind(i) == mkldnn::primitive::kind::sum)
{
new_pops.append_sum(2 * dyn_post_op_scales[0]);
new_pops.append_sum(dyn_post_op_scales[0]);
}
}
conv_attr.set_post_ops(new_pops);
......
......@@ -1102,6 +1102,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
REGISTER_KNOBBED_PASS(ReshapeElimination, false, ngraph::pass);
REGISTER_KNOBBED_PASS(CoreFusion, true, ngraph::pass);
REGISTER_KNOBBED_PASS(CPUFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUQuantFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUHorizontalFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUCollapseDims, true, runtime::cpu::pass);
#if defined(NGRAPH_HALIDE)
......
......@@ -186,21 +186,14 @@ namespace ngraph
ops.append_sum(1.f);
}
if (std::is_same<OP, ngraph::op::QuantizedConvolutionBiasAdd>())
if (std::is_same<OP, ngraph::op::QuantizedConvolutionBiasAdd>() ||
std::is_same<OP, ngraph::op::QuantizedConvolutionBiasSignedAdd>())
{
auto sum_scale_val =
extract_scale_value<ngraph::op::QuantizedConvolutionBiasAdd>(node, 5);
ops.append_sum(sum_scale_val[0]);
}
if (std::is_same<OP, ngraph::op::QuantizedConvolutionBiasSignedAdd>())
{
auto sum_scale_val =
extract_scale_value<ngraph::op::QuantizedConvolutionBiasSignedAdd>(node,
5);
ops.append_sum(2.0 * sum_scale_val[0]);
}
if (has_relu<OP>(node))
{
const float ops_scale = 1.f;
......@@ -740,21 +733,14 @@ namespace ngraph
ops.append_sum(1.f);
}
if (std::is_same<OP, ngraph::op::QuantizedConvolutionBiasAdd>())
if (std::is_same<OP, ngraph::op::QuantizedConvolutionBiasAdd>() ||
std::is_same<OP, ngraph::op::QuantizedConvolutionBiasSignedAdd>())
{
auto sum_scale_val =
extract_scale_value<ngraph::op::QuantizedConvolutionBiasAdd>(node, 5);
ops.append_sum(sum_scale_val[0]);
}
if (std::is_same<OP, ngraph::op::QuantizedConvolutionBiasSignedAdd>())
{
auto sum_scale_val =
extract_scale_value<ngraph::op::QuantizedConvolutionBiasSignedAdd>(node,
5);
ops.append_sum(2.0 * sum_scale_val[0]);
}
if (has_relu<OP>(node))
{
const float ops_scale = 1.f;
......
......@@ -99,11 +99,27 @@ namespace ngraph
{
return false;
}
if (node->get_input_element_type(0) != element::f32)
// Data
if (node->get_input_element_type(0) != element::f32 &&
node->get_input_element_type(0) != element::i8 &&
node->get_input_element_type(0) != element::u8)
{
return false;
}
// Weights
if (node->get_input_element_type(1) != element::f32 &&
node->get_input_element_type(1) != element::i8)
{
return false;
}
// Outputs
if (node->get_output_element_type(0) != element::f32 &&
node->get_output_element_type(0) != element::i8 &&
node->get_output_element_type(0) != element::u8 &&
node->get_output_element_type(0) != element::i32)
{
return false;
}
return true;
}
}
......
This diff is collapsed.
......@@ -28,6 +28,7 @@ namespace ngraph
namespace pass
{
class CPUFusion;
class CPUQuantFusion;
}
}
}
......@@ -101,3 +102,21 @@ private:
void construct_update_slice();
void construct_fuse_lstm_recurrent_state();
};
class CPU_BACKEND_API ngraph::runtime::cpu::pass::CPUQuantFusion : public ngraph::pass::GraphRewrite
{
public:
CPUQuantFusion()
: GraphRewrite()
{
construct_qconv_relu(true);
construct_qconv_relu(false);
construct_qconvb_add();
construct_dq_q();
}
private:
void construct_qconv_relu(bool with_bias);
void construct_dq_q();
void construct_qconvb_add();
};
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment