Commit b9a599a1 authored by Adam Procter's avatar Adam Procter

wip

parent beb8c442
......@@ -142,6 +142,8 @@ set (SRC
op/experimental/dyn_broadcast.hpp
op/experimental/dyn_pad.cpp
op/experimental/dyn_pad.hpp
op/experimental/dyn_replace_slice.cpp
op/experimental/dyn_replace_slice.hpp
op/experimental/dyn_reshape.cpp
op/experimental/dyn_reshape.hpp
op/experimental/dyn_slice.cpp
......
......@@ -89,6 +89,7 @@
#include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/dyn_replace_slice.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
......
......@@ -17,6 +17,7 @@
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
#include <memory>
......@@ -42,142 +43,6 @@ op::DynSlice::DynSlice(const shared_ptr<Node>& arg,
constructor_validate_and_infer_types();
}
Shape op::DynSlice::compute_output_shape() const
{
auto input_shape = get_input_partial_shape(0).to_shape();
auto lower_bounds = dynamic_pointer_cast<op::Constant>(get_argument(1));
auto upper_bounds = dynamic_pointer_cast<op::Constant>(get_argument(2));
auto strides = dynamic_pointer_cast<op::Constant>(get_argument(3));
if (lower_bounds && upper_bounds && strides)
{
auto lb = lower_bounds->get_vector<int64_t>();
auto ub = upper_bounds->get_vector<int64_t>();
auto str = strides->get_vector<int64_t>();
int max_dims = input_shape.size() + m_new_axis.size();
if (lb.size() && ub.size())
{
NODE_VALIDATION_CHECK(
this,
lb.size() == ub.size(),
"Lower bounds and Upper bounds needs to have same number of values");
}
if (lb.size() && str.size())
{
NODE_VALIDATION_CHECK(this,
lb.size() == str.size(),
"Lower bounds and strides needs to have same number of values");
}
if (ub.size() && str.size())
{
NODE_VALIDATION_CHECK(this,
ub.size() == str.size(),
"Upper bounds and strides needs to have same number of values");
}
int bounds_size =
lb.size() ? lb.size() : (ub.size() ? ub.size() : (str.size() ? str.size() : 0));
NODE_VALIDATION_CHECK(
this, m_ellipsis_mask.size() <= 1, "Ellipsis mask cannot specify more than one axis");
int ellipsis_pos1 = m_ellipsis_mask.size() ? *m_ellipsis_mask.begin() : max_dims;
int ellipsis_pos2 = max_dims;
bounds_size -= ellipsis_pos1;
if (bounds_size > 0 && (max_dims - bounds_size) > ellipsis_pos1)
{
ellipsis_pos2 = max_dims - bounds_size;
}
std::vector<int> begin_dms(max_dims, 0);
std::vector<int> end_dms(max_dims, -1);
std::vector<int> stride_dms(max_dims, 1);
int i, j, k, bj, ej, sj;
Shape out_dims;
for (i = 0, j = 0, k = 0, bj = 0, ej = 0, sj = 0; i < max_dims; i++)
{
if (i >= ellipsis_pos1 && i < ellipsis_pos2)
{
if (m_new_axis.find(i) == m_new_axis.end())
{
end_dms[i] = end_dms[i] >= 0 ? end_dms[i] : input_shape[j++] + end_dms[i];
}
else
{
end_dms[i] = begin_dms[i];
}
out_dims.push_back(
static_cast<int>(ceil(static_cast<float>(abs(end_dms[i] - begin_dms[i]) + 1) /
static_cast<float>(abs(stride_dms[i])))));
k = ellipsis_pos1;
continue;
}
stride_dms[i] = (str.size() > sj && str[sj] != 0) ? str[sj++] : 1;
// Use lower_bounds if mask is not set
if (m_lower_bounds_mask.find(j) == m_lower_bounds_mask.end())
{
begin_dms[i] = lb.size() > bj ? lb[bj] : (stride_dms[i] > 0 ? 0 : -1);
}
else
{
begin_dms[i] = stride_dms[i] > 0 ? 0 : -1;
}
bj++;
begin_dms[i] = begin_dms[i] >= 0 ? begin_dms[i] : input_shape[j] + begin_dms[i];
// Clipping 'begin'
begin_dms[i] =
(begin_dms[i] < 0) ? 0 : (begin_dms[i] >= input_shape[j] ? input_shape[j] - 1
: begin_dms[i]);
// Use upper_bounds if mask is not set
if (m_upper_bounds_mask.find(j) == m_upper_bounds_mask.end())
{
int end_dms_tmp =
ub.size() > ej ? (stride_dms[i] > 0 ? ub[ej] - 1 : ub[ej] + 1) : end_dms[i];
end_dms[i] = ub.size() > ej ? end_dms_tmp : (stride_dms[i] > 0 ? -1 : 0);
}
else
{
end_dms[i] = stride_dms[i] > 0 ? -1 : 0;
}
ej++;
end_dms[i] = end_dms[i] >= 0 ? end_dms[i] : input_shape[j] + end_dms[i];
// Clipping 'end'
end_dms[i] = (end_dms[i] < 0) ? 0 : (end_dms[i] >= input_shape[j] ? input_shape[j] - 1
: end_dms[i]);
if (m_new_axis.find(i) == m_new_axis.end())
{
j++;
}
else
{
end_dms[i] = 0;
}
if (m_shrink_axis.find(k) != m_shrink_axis.end())
{
end_dms[i] = begin_dms[i];
}
else
{
out_dims.push_back(
static_cast<int>(ceil(static_cast<float>(abs(end_dms[i] - begin_dms[i]) + 1) /
static_cast<float>(abs(stride_dms[i])))));
}
k++;
}
return out_dims;
}
return Shape{};
}
void op::DynSlice::validate_and_infer_types()
{
auto lower_bounds_et = get_input_element_type(1);
......@@ -219,17 +84,24 @@ void op::DynSlice::validate_and_infer_types()
set_input_is_relevant_to_shape(2);
set_input_is_relevant_to_shape(3);
if (get_input_partial_shape(0).is_static())
auto lower_bounds = dynamic_pointer_cast<op::Constant>(get_argument(1));
auto upper_bounds = dynamic_pointer_cast<op::Constant>(get_argument(2));
auto strides = dynamic_pointer_cast<op::Constant>(get_argument(3));
if (lower_bounds && upper_bounds && strides)
{
auto shape = compute_output_shape();
if (shape != Shape{})
{
set_output_type(0, get_input_element_type(0), shape);
}
else
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(arg_shape.rank()));
}
set_output_type(0,
get_input_element_type(0),
infer_slice_shape(this,
get_input_partial_shape(0),
lower_bounds->get_vector<int64_t>(),
upper_bounds->get_vector<int64_t>(),
strides->get_vector<int64_t>(),
m_lower_bounds_mask,
m_upper_bounds_mask,
m_new_axis,
m_shrink_axis,
m_ellipsis_mask));
}
else
{
......
......@@ -614,3 +614,165 @@ void ngraph::infer_auto_padding(const Shape& image_shape,
padding_above.push_back(pad_type == op::PadType::SAME_UPPER ? padding_rhs : padding_lhs);
}
}
PartialShape ngraph::infer_slice_shape(const Node* node,
const PartialShape& input_shape,
const std::vector<int64_t>& lb,
const std::vector<int64_t>& ub,
const std::vector<int64_t>& str,
const AxisSet& lb_mask,
const AxisSet& ub_mask,
const AxisSet& new_axis,
const AxisSet& shrink_axis,
const AxisSet& ellipsis_mask)
{
// TODO(amprocte): double-check that these checks are needed.
if (lb.size() && ub.size())
{
NODE_VALIDATION_CHECK(node,
lb.size() == ub.size(),
"Lower bounds and Upper bounds needs to have same number of values");
}
if (lb.size() && str.size())
{
NODE_VALIDATION_CHECK(node,
lb.size() == str.size(),
"Lower bounds and strides needs to have same number of values");
}
if (ub.size() && str.size())
{
NODE_VALIDATION_CHECK(node,
ub.size() == str.size(),
"Upper bounds and strides needs to have same number of values");
}
if (input_shape.rank().is_dynamic())
{
return PartialShape::dynamic();
}
int max_dims = size_t(input_shape.rank()) + new_axis.size();
int bounds_size =
lb.size() ? lb.size() : (ub.size() ? ub.size() : (str.size() ? str.size() : 0));
int ellipsis_pos1 = ellipsis_mask.size() ? *ellipsis_mask.begin() : max_dims;
int ellipsis_pos2 = max_dims;
bounds_size -= ellipsis_pos1;
if (bounds_size > 0 && (max_dims - bounds_size) > ellipsis_pos1)
{
ellipsis_pos2 = max_dims - bounds_size;
}
std::vector<Dimension> begin_dms(max_dims, 0);
std::vector<Dimension> end_dms(max_dims, -1);
std::vector<Dimension> stride_dms(max_dims, 1);
int i, j, k, bj, ej, sj;
std::vector<Dimension> out_dims;
for (i = 0, j = 0, k = 0, bj = 0, ej = 0, sj = 0; i < max_dims; i++)
{
if (i >= ellipsis_pos1 && i < ellipsis_pos2)
{
if (new_axis.find(i) == new_axis.end())
{
end_dms[i] = end_dms[i].is_static() && int64_t(end_dms[i]) >= 0
? end_dms[i]
: input_shape[j++] + end_dms[i];
}
else
{
end_dms[i] = begin_dms[i];
}
out_dims.push_back(
(end_dms[i].is_dynamic() || begin_dms[i].is_dynamic() || stride_dms[i].is_dynamic())
? Dimension::dynamic()
: static_cast<int64_t>(ceil(
static_cast<float>(abs(int64_t(end_dms[i]) - int64_t(begin_dms[i])) + 1) /
static_cast<float>(abs(int64_t(stride_dms[i]))))));
k = ellipsis_pos1;
continue;
}
stride_dms[i] = (str.size() > sj && str[sj] != 0) ? str[sj++] : 1;
// Use lower_bounds if mask is not set
if (lb_mask.find(j) == lb_mask.end())
{
begin_dms[i] = lb.size() > bj ? lb[bj] : (stride_dms[i].is_dynamic()
? Dimension::dynamic()
: (int64_t(stride_dms[i]) > 0 ? 0 : -1));
}
else
{
begin_dms[i] = stride_dms[i].is_dynamic() ? Dimension::dynamic()
: (int64_t(stride_dms[i]) > 0 ? 0 : -1);
}
bj++;
begin_dms[i] = (begin_dms[i].is_static() && int64_t(begin_dms[i]) >= 0)
? begin_dms[i]
: input_shape[j] + begin_dms[i];
// Clipping 'begin'
begin_dms[i] = (begin_dms[i].is_static() && int64_t(begin_dms[i]) < 0)
? 0
: (begin_dms[i].is_static() && input_shape[j].is_static() &&
int64_t(begin_dms[i]) >= int64_t(input_shape[j])
? input_shape[j] - 1
: begin_dms[i]);
// Use upper_bounds if mask is not set
if (ub_mask.find(j) == ub_mask.end())
{
Dimension end_dms_tmp =
ub.size() > ej
? (stride_dms[i].is_static() && int64_t(stride_dms[i]) > 0 ? ub[ej] - 1
: ub[ej] + 1)
: end_dms[i];
end_dms[i] = ub.size() > ej
? end_dms_tmp
: (stride_dms[i].is_static() && int64_t(stride_dms[i]) > 0 ? -1 : 0);
}
else
{
end_dms[i] = stride_dms[i].is_static() && int64_t(stride_dms[i]) > 0 ? -1 : 0;
}
ej++;
end_dms[i] = end_dms[i].is_static() && int64_t(end_dms[i]) >= 0
? end_dms[i]
: input_shape[j] + end_dms[i];
// Clipping 'end'
end_dms[i] = (end_dms[i].is_static() && int64_t(end_dms[i]) < 0)
? 0
: (end_dms[i].is_static() && input_shape[j].is_static() &&
int64_t(end_dms[i]) >= int64_t(input_shape[j])
? input_shape[j] - 1
: end_dms[i]);
if (new_axis.find(i) == new_axis.end())
{
j++;
}
else
{
end_dms[i] = 0;
}
if (shrink_axis.find(k) != shrink_axis.end())
{
end_dms[i] = begin_dms[i];
}
else
{
out_dims.push_back(
end_dms[i].is_dynamic() || begin_dms[i].is_dynamic() || stride_dms[i].is_dynamic()
? Dimension::dynamic()
: static_cast<int64_t>(ceil(
static_cast<float>(abs(int64_t(end_dms[i]) - int64_t(begin_dms[i])) + 1) /
static_cast<float>(abs(int64_t(stride_dms[i]))))));
}
k++;
}
return out_dims;
}
......@@ -92,4 +92,15 @@ namespace ngraph
const op::PadType pad_type,
CoordinateDiff& padding_above,
CoordinateDiff& padding_below);
PartialShape infer_slice_shape(const Node* node,
const PartialShape& input_shape,
const std::vector<int64_t>& lb,
const std::vector<int64_t>& ub,
const std::vector<int64_t>& str,
const AxisSet& lb_mask,
const AxisSet& ub_mask,
const AxisSet& new_axis,
const AxisSet& shrink_mask,
const AxisSet& ellipsis_mask);
}
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment