Unverified Commit f6fe6aca authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Restore constant folding for DynReshape until users are converted to … (#4164)

* Restore constant folding for DynReshape until users are converted to v1 Reshape

* Disbale test when no serialization
Co-authored-by: 's avatarbaojun <32073718+baojun-nervana@users.noreply.github.com>
parent ef553de3
......@@ -17,6 +17,7 @@
#include <numeric>
#include "constant_folding.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/type/element_type.hpp"
......@@ -24,9 +25,9 @@
using namespace std;
using namespace ngraph;
template <class T>
template <typename T, typename R>
shared_ptr<op::Constant> fold_constant_dyn_reshape(shared_ptr<op::Constant> constant_data,
shared_ptr<op::v1::Reshape> dyn_reshape)
R dyn_reshape)
{
const Shape& out_shape = dyn_reshape->get_shape();
runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(T));
......@@ -44,18 +45,102 @@ shared_ptr<op::Constant> fold_constant_dyn_reshape(shared_ptr<op::Constant> cons
return make_shared<op::Constant>(dyn_reshape->get_element_type(), out_shape, data_ptr);
}
template <typename R>
std::shared_ptr<Node> do_fold(R dyn_reshape_match, shared_ptr<op::Constant> constant_data_match)
{
std::shared_ptr<Node> replacement;
auto type = dyn_reshape_match->get_element_type();
switch (type)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false,
"Encountered 'undefined' element type in constant_dyn_reshape_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_dyn_reshape_callback");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_dyn_reshape_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_dyn_reshape<char>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::bf16:
replacement = fold_constant_dyn_reshape<bfloat16>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::f16:
replacement = fold_constant_dyn_reshape<float16>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::f32:
replacement = fold_constant_dyn_reshape<float>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::f64:
replacement = fold_constant_dyn_reshape<double>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::i8:
replacement = fold_constant_dyn_reshape<int8_t>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::i16:
replacement = fold_constant_dyn_reshape<int16_t>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::i32:
replacement = fold_constant_dyn_reshape<int32_t>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::i64:
replacement = fold_constant_dyn_reshape<int64_t>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::u8:
replacement = fold_constant_dyn_reshape<uint8_t>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::u16:
replacement = fold_constant_dyn_reshape<uint16_t>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::u32:
replacement = fold_constant_dyn_reshape<uint32_t>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::u64:
replacement = fold_constant_dyn_reshape<uint64_t>(constant_data_match, dyn_reshape_match);
break;
}
return replacement;
}
void pass::ConstantFolding::construct_constant_dyn_reshape()
{
auto constant_data_label = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
auto constant_shape_label =
make_shared<pattern::op::Label>(element::i64, Shape{1}, pattern::has_class<op::Constant>());
auto dyn_reshape =
auto reshape_v1 =
make_shared<op::v1::Reshape>(constant_data_label, constant_shape_label, false);
auto dyn_reshape =
make_shared<op::v0::DynReshape>(constant_data_label, constant_shape_label, false);
// Note: No need to capture or consider constant_shape_label, because
// shape propagation will have transferred the info to dyn_reshape's
// output.
auto constant_reshape_v1_callback = [constant_data_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_reshape_v1_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_data_match =
static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
auto match_root = m.get_match_root();
NGRAPH_CHECK(revalidate_and_ensure_static(match_root));
shared_ptr<Node> replacement;
replacement =
do_fold(static_pointer_cast<op::v1::Reshape>(match_root), constant_data_match);
replace_node(m.get_match_root(), replacement);
return true;
};
auto reshape_v1_matcher =
make_shared<pattern::Matcher>(reshape_v1, "ConstantFolding.ConstantReshapev1");
this->add_matcher(
reshape_v1_matcher, constant_reshape_v1_callback, PassProperty::CHANGE_DYNAMIC_STATE);
auto constant_dyn_reshape_callback = [constant_data_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_dyn_reshape_callback against node = "
<< m.get_match_root()->get_name();
......@@ -64,75 +149,11 @@ void pass::ConstantFolding::construct_constant_dyn_reshape()
auto constant_data_match =
static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
auto dyn_reshape_match = static_pointer_cast<op::v1::Reshape>(m.get_match_root());
NGRAPH_CHECK(revalidate_and_ensure_static(dyn_reshape_match));
std::shared_ptr<Node> replacement;
auto type = dyn_reshape_match->get_element_type();
switch (type)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false,
"Encountered 'undefined' element type in constant_dyn_reshape_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in constant_dyn_reshape_callback");
break;
case element::Type_t::u1:
NGRAPH_CHECK(false, "Encountered 'u1' element type in constant_dyn_reshape_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_dyn_reshape<char>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::bf16:
replacement =
fold_constant_dyn_reshape<bfloat16>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::f16:
replacement =
fold_constant_dyn_reshape<float16>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::f32:
replacement = fold_constant_dyn_reshape<float>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::f64:
replacement = fold_constant_dyn_reshape<double>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::i8:
replacement = fold_constant_dyn_reshape<int8_t>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::i16:
replacement =
fold_constant_dyn_reshape<int16_t>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::i32:
replacement =
fold_constant_dyn_reshape<int32_t>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::i64:
replacement =
fold_constant_dyn_reshape<int64_t>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::u8:
replacement =
fold_constant_dyn_reshape<uint8_t>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::u16:
replacement =
fold_constant_dyn_reshape<uint16_t>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::u32:
replacement =
fold_constant_dyn_reshape<uint32_t>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::u64:
replacement =
fold_constant_dyn_reshape<uint64_t>(constant_data_match, dyn_reshape_match);
break;
}
auto match_root = m.get_match_root();
NGRAPH_CHECK(revalidate_and_ensure_static(match_root));
shared_ptr<Node> replacement;
replacement =
do_fold(static_pointer_cast<op::v0::DynReshape>(match_root), constant_data_match);
replace_node(m.get_match_root(), replacement);
return true;
};
......
......@@ -14,10 +14,13 @@
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/dyn_elimination.hpp"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/constant_folding.hpp"
#include "ngraph/pass/dyn_elimination.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "util/all_close_f.hpp"
#include "util/test_tools.hpp"
......@@ -266,3 +269,25 @@ TEST(dyn_elimination, range_f64)
ASSERT_TRUE(test::all_close_f(
vals, vector<double>{-0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75}));
}
#ifndef NGRAPH_JSON_DISABLE
TEST(dyn_elimination, paddlepaddle_transpose)
{
string model = "paddlepaddle/transpose.json";
const string json_path = file_util::path_join(SERIALIZED_ZOO, model);
const string json_string = file_util::read_file_to_string(json_path);
shared_ptr<Function> f = ngraph::deserialize(json_string);
vector<element::Type> arg_element_types = {element::f64, element::f64};
vector<PartialShape> arg_shapes = {{3, 4}, {4, 3}};
std::vector<void*> arg_value_base_pointers = {nullptr, nullptr};
auto clone = specialize_function(f, arg_element_types, arg_shapes, arg_value_base_pointers);
pass::Manager passes;
passes.register_pass<pass::ConstantFolding>();
passes.register_pass<pass::DynElimination>();
passes.register_pass<pass::Opset0Downgrade>(); // Converts dynamic v1 variants to v0 ops
passes.set_per_pass_validation(false);
passes.run_passes(clone);
}
#endif
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment