Unverified Commit 02559274 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge branch 'master' into bob/backend_api3

parents d607c6f4 13b4966b
...@@ -325,6 +325,12 @@ namespace ngraph ...@@ -325,6 +325,12 @@ namespace ngraph
max_freezed_output_conv_1, max_freezed_output_conv_1,
min_freezed_output_conv_2, min_freezed_output_conv_2,
max_freezed_output_conv_2); max_freezed_output_conv_2);
if (output_et == element::u8)
{
// Need to multiply by two to account for u8 requantization_scale
auto two = make_constant(element::f32, sum_scale->get_shape(), 2.0f);
sum_scale = two * sum_scale;
}
if (bias->get_element_type() != element::i32) if (bias->get_element_type() != element::i32)
{ {
......
...@@ -548,3 +548,19 @@ bool ngraph::is_valid_rank(const std::shared_ptr<Node>& node, std::vector<size_t ...@@ -548,3 +548,19 @@ bool ngraph::is_valid_rank(const std::shared_ptr<Node>& node, std::vector<size_t
} }
return false; return false;
} }
bool ngraph::compare_constants(const std::shared_ptr<Node>& n1, const std::shared_ptr<Node>& n2)
{
if (!(n1->is_constant() && n2->is_constant()))
{
return false;
}
if (static_pointer_cast<op::Constant>(n1)->get_value_strings() !=
static_pointer_cast<op::Constant>(n2)->get_value_strings())
{
return false;
}
return true;
}
...@@ -309,6 +309,8 @@ namespace ngraph ...@@ -309,6 +309,8 @@ namespace ngraph
bool is_one(std::shared_ptr<Node> reduce_constant); bool is_one(std::shared_ptr<Node> reduce_constant);
bool compare_constants(const std::shared_ptr<Node>& n1, const std::shared_ptr<Node>& n2);
// Returns true if `node` is live in the graph i.e. a result op // Returns true if `node` is live in the graph i.e. a result op
// transitively uses this `node` // transitively uses this `node`
bool is_used(Node* node); bool is_used(Node* node);
......
...@@ -334,7 +334,7 @@ namespace ngraph ...@@ -334,7 +334,7 @@ namespace ngraph
} }
if (old_pops.kind(i) == mkldnn::primitive::kind::sum) if (old_pops.kind(i) == mkldnn::primitive::kind::sum)
{ {
new_pops.append_sum(2 * dyn_post_op_scales[0]); new_pops.append_sum(dyn_post_op_scales[0]);
} }
} }
conv_attr.set_post_ops(new_pops); conv_attr.set_post_ops(new_pops);
......
...@@ -1102,6 +1102,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma ...@@ -1102,6 +1102,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
REGISTER_KNOBBED_PASS(ReshapeElimination, false, ngraph::pass); REGISTER_KNOBBED_PASS(ReshapeElimination, false, ngraph::pass);
REGISTER_KNOBBED_PASS(CoreFusion, true, ngraph::pass); REGISTER_KNOBBED_PASS(CoreFusion, true, ngraph::pass);
REGISTER_KNOBBED_PASS(CPUFusion, true, runtime::cpu::pass); REGISTER_KNOBBED_PASS(CPUFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUQuantFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUHorizontalFusion, true, runtime::cpu::pass); REGISTER_KNOBBED_PASS(CPUHorizontalFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUCollapseDims, true, runtime::cpu::pass); REGISTER_KNOBBED_PASS(CPUCollapseDims, true, runtime::cpu::pass);
#if defined(NGRAPH_HALIDE) #if defined(NGRAPH_HALIDE)
......
...@@ -186,21 +186,14 @@ namespace ngraph ...@@ -186,21 +186,14 @@ namespace ngraph
ops.append_sum(1.f); ops.append_sum(1.f);
} }
if (std::is_same<OP, ngraph::op::QuantizedConvolutionBiasAdd>()) if (std::is_same<OP, ngraph::op::QuantizedConvolutionBiasAdd>() ||
std::is_same<OP, ngraph::op::QuantizedConvolutionBiasSignedAdd>())
{ {
auto sum_scale_val = auto sum_scale_val =
extract_scale_value<ngraph::op::QuantizedConvolutionBiasAdd>(node, 5); extract_scale_value<ngraph::op::QuantizedConvolutionBiasAdd>(node, 5);
ops.append_sum(sum_scale_val[0]); ops.append_sum(sum_scale_val[0]);
} }
if (std::is_same<OP, ngraph::op::QuantizedConvolutionBiasSignedAdd>())
{
auto sum_scale_val =
extract_scale_value<ngraph::op::QuantizedConvolutionBiasSignedAdd>(node,
5);
ops.append_sum(2.0 * sum_scale_val[0]);
}
if (has_relu<OP>(node)) if (has_relu<OP>(node))
{ {
const float ops_scale = 1.f; const float ops_scale = 1.f;
...@@ -740,21 +733,14 @@ namespace ngraph ...@@ -740,21 +733,14 @@ namespace ngraph
ops.append_sum(1.f); ops.append_sum(1.f);
} }
if (std::is_same<OP, ngraph::op::QuantizedConvolutionBiasAdd>()) if (std::is_same<OP, ngraph::op::QuantizedConvolutionBiasAdd>() ||
std::is_same<OP, ngraph::op::QuantizedConvolutionBiasSignedAdd>())
{ {
auto sum_scale_val = auto sum_scale_val =
extract_scale_value<ngraph::op::QuantizedConvolutionBiasAdd>(node, 5); extract_scale_value<ngraph::op::QuantizedConvolutionBiasAdd>(node, 5);
ops.append_sum(sum_scale_val[0]); ops.append_sum(sum_scale_val[0]);
} }
if (std::is_same<OP, ngraph::op::QuantizedConvolutionBiasSignedAdd>())
{
auto sum_scale_val =
extract_scale_value<ngraph::op::QuantizedConvolutionBiasSignedAdd>(node,
5);
ops.append_sum(2.0 * sum_scale_val[0]);
}
if (has_relu<OP>(node)) if (has_relu<OP>(node))
{ {
const float ops_scale = 1.f; const float ops_scale = 1.f;
......
...@@ -99,11 +99,27 @@ namespace ngraph ...@@ -99,11 +99,27 @@ namespace ngraph
{ {
return false; return false;
} }
if (node->get_input_element_type(0) != element::f32) // Data
if (node->get_input_element_type(0) != element::f32 &&
node->get_input_element_type(0) != element::i8 &&
node->get_input_element_type(0) != element::u8)
{
return false;
}
// Weights
if (node->get_input_element_type(1) != element::f32 &&
node->get_input_element_type(1) != element::i8)
{
return false;
}
// Outputs
if (node->get_output_element_type(0) != element::f32 &&
node->get_output_element_type(0) != element::i8 &&
node->get_output_element_type(0) != element::u8 &&
node->get_output_element_type(0) != element::i32)
{ {
return false; return false;
} }
return true; return true;
} }
} }
......
This diff is collapsed.
...@@ -28,6 +28,7 @@ namespace ngraph ...@@ -28,6 +28,7 @@ namespace ngraph
namespace pass namespace pass
{ {
class CPUFusion; class CPUFusion;
class CPUQuantFusion;
} }
} }
} }
...@@ -101,3 +102,21 @@ private: ...@@ -101,3 +102,21 @@ private:
void construct_update_slice(); void construct_update_slice();
void construct_fuse_lstm_recurrent_state(); void construct_fuse_lstm_recurrent_state();
}; };
class CPU_BACKEND_API ngraph::runtime::cpu::pass::CPUQuantFusion : public ngraph::pass::GraphRewrite
{
public:
CPUQuantFusion()
: GraphRewrite()
{
construct_qconv_relu(true);
construct_qconv_relu(false);
construct_qconvb_add();
construct_dq_q();
}
private:
void construct_qconv_relu(bool with_bias);
void construct_dq_q();
void construct_qconvb_add();
};
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment