Commit 9cc2ce34 authored by Ilya Churaev's avatar Ilya Churaev Committed by Scott Cyphers

Port PR 4156 and 4184 to master (#4219)

* Added FP32 to FP16 conversion transformation

* Update src/ngraph/pass/convert_fp32_to_fp16.cpp
Co-Authored-By: 's avatarRobert Kimball <robert.kimball@intel.com>

* Update src/ngraph/pass/convert_fp32_to_fp16.hpp
Co-Authored-By: 's avatarRobert Kimball <robert.kimball@intel.com>

* Added NGRAPH_API for fp32 to fp16 conversion
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
parent e52a32e4
......@@ -496,6 +496,8 @@ set (SRC
pass/constant_folding.cpp
pass/constant_folding.hpp
pass/constant_to_broadcast.cpp
pass/convert_fp32_to_fp16.hpp
pass/convert_fp32_to_fp16.cpp
pass/core_fusion.cpp
pass/core_fusion.hpp
pass/cse.cpp
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/convert_fp32_to_fp16.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/constant.hpp"
using namespace std;
using namespace ngraph;
void pass::ConvertFP32ToFP16::convert_constants_precision()
{
auto constant =
std::make_shared<ngraph::op::Constant>(element::f32, Shape{1}, std::vector<float>{0});
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto constant = std::dynamic_pointer_cast<ngraph::op::Constant>(m.get_match_root());
if (!constant)
{
return false;
}
if (constant->get_element_type() == element::f32)
{
auto data = constant->get_vector<float>();
std::vector<ngraph::float16> new_data(data.size());
for (size_t i = 0; i < data.size(); ++i)
{
new_data[i] = ngraph::float16(data[i]);
}
auto new_const = std::make_shared<ngraph::op::Constant>(
element::f16, constant->get_shape(), new_data);
new_const->set_friendly_name(constant->get_friendly_name());
ngraph::replace_node(constant, new_const);
return true;
}
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(constant, "ConvertFP32ToFP16");
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
void pass::ConvertFP32ToFP16::convert_parameters_precision()
{
auto constant = std::make_shared<ngraph::op::Parameter>(element::f32, Shape{1});
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto parameter = std::dynamic_pointer_cast<ngraph::op::Parameter>(m.get_match_root());
if (parameter && parameter->get_element_type() == element::f32)
{
parameter->set_element_type(element::f16);
return true;
}
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(constant, "ConvertFP32ToFP16");
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph
{
namespace pass
{
class ConvertFP32ToFP16;
} // namespace pass
} // namespace ngraph
class NGRAPH_API ngraph::pass::ConvertFP32ToFP16 : public ngraph::pass::GraphRewrite
{
public:
ConvertFP32ToFP16()
: GraphRewrite()
{
convert_constants_precision();
convert_parameters_precision();
}
private:
void convert_constants_precision();
void convert_parameters_precision();
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment