Commit 513f8de6 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

CTCGreedyDecoder layer op (#2965)

* Added CTCGreedyDecoder layer op

* Added comment on seq_len validation checks
parent cf5e3623
...@@ -168,6 +168,8 @@ set (SRC ...@@ -168,6 +168,8 @@ set (SRC
op/experimental/quantized_dot_bias.hpp op/experimental/quantized_dot_bias.hpp
op/experimental/transpose.cpp op/experimental/transpose.cpp
op/experimental/transpose.hpp op/experimental/transpose.hpp
op/experimental/layers/ctc_greedy_decoder.cpp
op/experimental/layers/ctc_greedy_decoder.hpp
op/experimental/layers/detection_output.cpp op/experimental/layers/detection_output.cpp
op/experimental/layers/detection_output.hpp op/experimental/layers/detection_output.hpp
op/experimental/layers/interpolate.cpp op/experimental/layers/interpolate.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ctc_greedy_decoder.hpp"
using namespace std;
using namespace ngraph;
op::CTCGreedyDecoder::CTCGreedyDecoder(const shared_ptr<Node>& input,
const std::shared_ptr<Node>& seq_len,
const bool ctc_merge_repeated)
: Op("CTCGreedyDecoder", check_single_output_args({input, seq_len}))
, m_ctc_merge_repeated(ctc_merge_repeated)
{
constructor_validate_and_infer_types();
}
void op::CTCGreedyDecoder::validate_and_infer_types()
{
auto input_et = get_input_element_type(0);
if (get_input_partial_shape(0).is_static())
{
Shape input_shape = get_input_partial_shape(0).to_shape();
NODE_VALIDATION_CHECK(this,
input_shape.size() >= 3,
"CTCGreedyDecoder expects 3 or more dimensions for input. Got ",
input_shape.size());
// TODO: Add more validation checks for seq_len
set_output_type(0, input_et, Shape{input_shape[1], input_shape[0], 1, 1});
}
else
{
set_output_type(0, input_et, PartialShape::dynamic());
}
}
shared_ptr<Node> op::CTCGreedyDecoder::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<CTCGreedyDecoder>(new_args.at(0), new_args.at(1), m_ctc_merge_repeated);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
class CTCGreedyDecoder : public Op
{
public:
/// \brief Constructs a CTCGreedyDecoder operation
///
/// \param input Logits on which greedy decoding is performed
/// \param seq_len Sequence lengths
/// \param ctc_merge_repeated Whether to merge repeated labels
CTCGreedyDecoder(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& seq_len,
const bool ctc_merge_repeated);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool get_ctc_merge_repeated() const { return m_ctc_merge_repeated; }
private:
bool m_ctc_merge_repeated;
};
}
}
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/op/experimental/layers/ctc_greedy_decoder.hpp"
#include "ngraph/op/experimental/layers/detection_output.hpp" #include "ngraph/op/experimental/layers/detection_output.hpp"
#include "ngraph/op/experimental/layers/interpolate.hpp" #include "ngraph/op/experimental/layers/interpolate.hpp"
#include "ngraph/op/experimental/layers/prior_box.hpp" #include "ngraph/op/experimental/layers/prior_box.hpp"
...@@ -31,6 +32,14 @@ ...@@ -31,6 +32,14 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
TEST(type_prop_layers, ctc_greedy_decoder)
{
auto input = make_shared<op::Parameter>(element::f32, Shape{88, 2, 48, 1});
auto seq_len = make_shared<op::Parameter>(element::f32, Shape{88, 2});
auto op = make_shared<op::CTCGreedyDecoder>(input, seq_len, false);
ASSERT_EQ(op->get_shape(), (Shape{2, 88, 1, 1}));
}
TEST(type_prop_layers, detection_output) TEST(type_prop_layers, detection_output)
{ {
auto box_logits = make_shared<op::Parameter>(element::f32, Shape{4, 1, 5, 5}); auto box_logits = make_shared<op::Parameter>(element::f32, Shape{4, 1, 5, 5});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment