Unverified Commit 8a569f27 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Cyphers/shape (#310)

* Add and use get_shape() and get_element_type() on Input/Output

* Fix Output

* Formatting.

* Format.

* Use reference

* Convolution.
parent 122db5ff
......@@ -32,11 +32,9 @@
using namespace ngraph;
std::shared_ptr<Node> make_zero(const std::shared_ptr<const TensorViewType>& tensor_view_type)
std::shared_ptr<Node> make_zero(const element::Type& element_type, const Shape& shape)
{
std::shared_ptr<Node> zero =
std::make_shared<op::Constant>(tensor_view_type->get_element_type(), Shape{}, "0");
const Shape& shape = tensor_view_type->get_shape();
std::shared_ptr<Node> zero = std::make_shared<op::Constant>(element_type, Shape{}, "0");
if (shape.size() > 0)
{
AxisSet axes;
......@@ -114,7 +112,8 @@ std::shared_ptr<Node> autodiff::Adjoints::get(const std::shared_ptr<Node>& x)
auto adjoint_it = m_adjoint_map.find(x.get());
if (m_adjoint_map.end() == adjoint_it)
{
auto result = make_zero(x->get_outputs().at(0).get_tensor_view_type());
auto& output = x->get_outputs().at(0);
auto result = make_zero(output.get_element_type(), output.get_shape());
adjoint_it = m_adjoint_map.insert({x.get(), result}).first;
}
return adjoint_it->second;
......@@ -152,7 +151,8 @@ void autodiff::Adjoints::add_delta_to_slice(const std::shared_ptr<Node>& x,
auto adjoint_it = m_adjoint_map.find(x.get());
if (m_adjoint_map.end() == adjoint_it)
{
auto zeros = make_zero(x->get_outputs().at(0).get_tensor_view_type());
auto& output = x->get_outputs().at(0);
auto zeros = make_zero(output.get_element_type(), output.get_shape());
m_adjoint_map.insert({x.get(),
std::make_shared<op::ReplaceSlice>(
zeros, delta, lower_bounds, upper_bounds, strides)});
......
......@@ -15,6 +15,7 @@
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/node.hpp"
#include "ngraph/types/element_type.hpp"
using namespace ngraph;
using namespace descriptor;
......@@ -65,3 +66,13 @@ std::shared_ptr<const TensorViewType> Input::get_tensor_view_type() const
{
return m_output->get_tensor_view()->get_tensor_view_type();
}
const Shape& Input::get_shape() const
{
return get_tensor_view_type()->get_shape();
}
const element::Type& Input::get_element_type() const
{
return get_tensor_view_type()->get_element_type();
}
......@@ -61,6 +61,7 @@ namespace ngraph
void replace_output(Output& output);
protected:
/// @return the tensor view for the connected output
std::shared_ptr<const TensorView> get_tensor_view() const;
......@@ -70,6 +71,13 @@ namespace ngraph
/// @return the tensor view type for the connected output
std::shared_ptr<const TensorViewType> get_tensor_view_type() const;
public:
/// @return the shape of the connected output
const Shape& get_shape() const;
/// @return the element type of the connected output
const element::Type& get_element_type() const;
protected:
Node* m_node; // The node we are an input for
size_t m_index; // Index into all input tensors
......
......@@ -52,3 +52,18 @@ Tensor& Output::get_tensor()
{
return m_tensor_view->get_tensor();
}
std::shared_ptr<const TensorViewType> Output::get_tensor_view_type() const
{
return get_tensor_view()->get_tensor_view_type();
}
const Shape& Output::get_shape() const
{
return get_tensor_view_type()->get_shape();
}
const element::Type& Output::get_element_type() const
{
return get_tensor_view_type()->get_element_type();
}
......@@ -48,11 +48,16 @@ namespace ngraph
const std::set<Input*>& get_inputs() const { return m_inputs; }
const Tensor& get_tensor() const;
Tensor& get_tensor();
/// @return the tensor view type for the connected output
std::shared_ptr<const TensorViewType> get_tensor_view_type() const
{
return get_tensor_view()->get_tensor_view_type();
}
protected:
/// @return the tensor view type for the output
std::shared_ptr<const TensorViewType> get_tensor_view_type() const;
public:
/// @return the shape of the output
const Shape& get_shape() const;
/// @return the element type of the output
const element::Type& get_element_type() const;
protected:
Node* m_node;
......
......@@ -62,6 +62,11 @@ void Node::assert_value_type(const shared_ptr<const ValueType>& value_type) cons
}
}
void Node::set_value_type_checked(const element::Type& element_type, const Shape& shape)
{
set_value_type_checked(make_shared<TensorViewType>(element_type, shape));
}
void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type)
{
if (nullptr == m_value_type)
......
......@@ -87,6 +87,7 @@ namespace ngraph
// This is used when the framework specifies a value type for the value, and we
// independently compute what we thing the value type should be from the arguments.
void set_value_type_checked(const std::shared_ptr<const ValueType>& value_type);
void set_value_type_checked(const element::Type& element_type, const Shape& shape);
bool is_parameter() const;
bool is_output() const;
......
......@@ -28,16 +28,15 @@ op::BinaryElementwise::BinaryElementwise(
const std::shared_ptr<Node>& arg1)
: RequiresTensorViewArgs(node_type, Nodes{arg0, arg1})
{
auto arg0_tensor_type = get_inputs().at(0).get_tensor_view_type();
auto arg1_tensor_type = get_inputs().at(1).get_tensor_view_type();
if (arg0_tensor_type->get_shape() != arg1_tensor_type->get_shape())
auto& input_0 = get_inputs().at(0);
auto& input_1 = get_inputs().at(1);
if (input_0.get_shape() != input_1.get_shape())
{
throw ngraph_error("Arguments must have the same tensor view shape");
}
const element::Type& result_element_type = element_type_function(
arg0_tensor_type->get_element_type(), arg1_tensor_type->get_element_type());
const element::Type& result_element_type =
element_type_function(input_0.get_element_type(), input_1.get_element_type());
set_value_type_checked(
make_shared<TensorViewType>(result_element_type, arg0_tensor_type->get_shape()));
set_value_type_checked(make_shared<TensorViewType>(result_element_type, input_0.get_shape()));
}
......@@ -25,18 +25,17 @@ op::Broadcast::Broadcast(const std::shared_ptr<Node>& arg,
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
{
auto arg_tensor_view_type = m_inputs.at(0).get_tensor_view_type();
auto& input = m_inputs.at(0);
vector<size_t> target_shape = m_shape;
for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
{
target_shape.erase(target_shape.begin() + *i);
}
if (Shape{target_shape} != arg_tensor_view_type->get_shape())
if (Shape{target_shape} != input.get_shape())
{
throw ngraph_error("Broadcast arg, shape, and axes are incompatible");
}
set_value_type_checked(
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_shape));
set_value_type_checked(make_shared<TensorViewType>(input.get_element_type(), m_shape));
}
void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints,
......
......@@ -30,47 +30,47 @@ op::Concat::Concat(const Nodes& args, size_t concatenation_axis)
throw ngraph_error("At least one argument required");
}
auto arg0_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto arg0_shape = arg0_tensor_view_type->get_shape();
if (m_concatenation_axis >= arg0_shape.size())
auto& input_0 = get_inputs().at(0);
auto input_0_shape = input_0.get_shape();
if (m_concatenation_axis >= input_0_shape.size())
{
throw ngraph_error("Concatenation axis is out of bounds");
}
size_t concatenation_axis_length = arg0_shape.at(m_concatenation_axis);
auto& arg0_element_type = arg0_tensor_view_type->get_element_type();
size_t concatenation_axis_length = input_0_shape.at(m_concatenation_axis);
auto& input_0_element_type = input_0.get_element_type();
for (auto i = 1; i < get_inputs().size(); i++)
{
auto argi_tensor_view_type = get_inputs().at(i).get_tensor_view_type();
auto argi_shape = argi_tensor_view_type->get_shape();
if (argi_shape.size() != arg0_shape.size())
auto& input_i = get_inputs().at(i);
auto input_i_shape = input_i.get_shape();
if (input_i_shape.size() != input_0_shape.size())
{
throw ngraph_error("Arguments to concat do not have same rank");
}
if (argi_tensor_view_type->get_element_type() != arg0_element_type)
if (input_i.get_element_type() != input_0_element_type)
{
throw ngraph_error("Argument element types do not match");
}
for (auto j = 0; j < argi_shape.size(); j++)
for (auto j = 0; j < input_i_shape.size(); j++)
{
if (j != m_concatenation_axis && arg0_shape.at(j) != argi_shape.at(j))
if (j != m_concatenation_axis && input_0_shape.at(j) != input_i_shape.at(j))
{
throw ngraph_error(
"Arguments to concat do not have same dimension on a non-concatenation axis");
}
else if (j == m_concatenation_axis)
{
concatenation_axis_length += argi_shape.at(j);
concatenation_axis_length += input_i_shape.at(j);
}
}
}
vector<size_t> concatenated_shape = arg0_shape;
vector<size_t> concatenated_shape = input_0_shape;
concatenated_shape.at(m_concatenation_axis) = concatenation_axis_length;
set_value_type_checked(make_shared<TensorViewType>(arg0_element_type, concatenated_shape));
set_value_type_checked(make_shared<TensorViewType>(input_0_element_type, concatenated_shape));
}
void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta)
......
......@@ -26,11 +26,8 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
, m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides)
{
auto image_batch_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto& image_batch_shape = image_batch_tensor_view_type->get_shape();
auto filters_tensor_view_type = get_inputs().at(1).get_tensor_view_type();
auto& filters_shape = filters_tensor_view_type->get_shape();
auto& image_batch_shape = get_inputs().at(0).get_shape();
auto& filters_shape = get_inputs().at(1).get_shape();
//
// Make sure image_batch: NCiDi for some Di of rank>0, N != 0, Ci != 0.
......@@ -157,8 +154,7 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
result_shape[1] = m_output_channel_count;
std::copy(m_output_image_shape.begin(), m_output_image_shape.end(), result_shape.begin() + 2);
set_value_type_checked(make_shared<TensorViewType>(
image_batch_tensor_view_type->get_element_type(), result_shape));
set_value_type_checked(get_inputs().at(0).get_element_type(), result_shape);
}
Strides default_strides(const std::shared_ptr<Node>& image_batch)
......
......@@ -70,44 +70,44 @@ op::Dot::Dot(const std::shared_ptr<Node>& arg0,
: RequiresTensorViewArgs("Dot", {arg0, arg1})
, m_reduction_axes_count(reduction_axes_count)
{
auto arg0_tensor_type = get_inputs().at(0).get_tensor_view_type();
auto arg1_tensor_type = get_inputs().at(1).get_tensor_view_type();
auto& input_0 = get_inputs().at(0);
auto& input_1 = get_inputs().at(1);
if (arg0_tensor_type->get_element_type() != arg1_tensor_type->get_element_type())
if (input_0.get_element_type() != input_1.get_element_type())
{
throw ngraph_error("Arguments to dot must have the same element type");
}
vector<size_t> arg0_shape = arg0_tensor_type->get_shape();
vector<size_t> arg1_shape = arg1_tensor_type->get_shape();
Shape input_0_shape = input_0.get_shape();
Shape input_1_shape = input_1.get_shape();
if (reduction_axes_count > arg0_shape.size())
if (reduction_axes_count > input_0_shape.size())
{
throw ngraph_error("Dot has too many axes for arg0");
}
if (reduction_axes_count > arg1_shape.size())
if (reduction_axes_count > input_1_shape.size())
{
throw ngraph_error("Dot has too many axes for arg1");
}
for (size_t i = 0; i < reduction_axes_count; i++)
{
if (arg0_shape[arg0_shape.size() - reduction_axes_count + i] != arg1_shape[i])
if (input_0_shape[input_0_shape.size() - reduction_axes_count + i] != input_1_shape[i])
{
throw ngraph_error("Dot axes do not have same length");
}
}
vector<size_t> result_shape(arg0_shape.size() + arg1_shape.size() - 2 * reduction_axes_count);
Shape result_shape(input_0_shape.size() + input_1_shape.size() - 2 * reduction_axes_count);
std::copy(arg0_shape.begin(), arg0_shape.end() - reduction_axes_count, result_shape.begin());
std::copy(arg1_shape.begin() + reduction_axes_count,
arg1_shape.end(),
result_shape.begin() + (arg0_shape.size() - reduction_axes_count));
std::copy(
input_0_shape.begin(), input_0_shape.end() - reduction_axes_count, result_shape.begin());
std::copy(input_1_shape.begin() + reduction_axes_count,
input_1_shape.end(),
result_shape.begin() + (input_0_shape.size() - reduction_axes_count));
auto result_type =
make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape);
auto result_type = make_shared<TensorViewType>(input_0.get_element_type(), result_shape);
set_value_type_checked(result_type);
}
......
......@@ -33,5 +33,6 @@ op::GetOutputElement::GetOutputElement(const std::shared_ptr<Node>& arg, size_t
throw ngraph_error("Indexing tuple beyond its size");
}
set_value_type_checked(arg->get_outputs().at(n).get_tensor_view_type());
auto& output = arg->get_outputs().at(n);
set_value_type_checked(output.get_element_type(), output.get_shape());
}
......@@ -23,21 +23,21 @@ op::OneHot::OneHot(const std::shared_ptr<Node>& arg, const Shape& shape, size_t
, m_shape(shape)
, m_one_hot_axis(one_hot_axis)
{
auto arg_tensor_view_type = m_inputs.at(0).get_tensor_view_type();
auto& arg_element_type = arg_tensor_view_type->get_element_type();
auto& input = m_inputs.at(0);
auto& input_element_type = input.get_element_type();
if (one_hot_axis >= shape.size())
{
throw ngraph_error("One-hot axis is out of bounds");
}
auto expected_arg_shape = shape;
expected_arg_shape.erase(expected_arg_shape.begin() + one_hot_axis);
auto expected_input_shape = shape;
expected_input_shape.erase(expected_input_shape.begin() + one_hot_axis);
if (arg_tensor_view_type->get_shape() != expected_arg_shape)
if (input.get_shape() != expected_input_shape)
{
throw ngraph_error("One-hot argument shape is not compatible with desired output shape");
}
set_value_type_checked(make_shared<TensorViewType>(arg_element_type, shape));
set_value_type_checked(make_shared<TensorViewType>(input_element_type, shape));
}
......@@ -26,25 +26,24 @@ op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee,
, m_reduction_function(reduction_function)
, m_reduction_axes(reduction_axes)
{
auto arg_reductee_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto& input_reductee = get_inputs().at(0);
auto arg_init_tensor_view_type = get_inputs().at(1).get_tensor_view_type();
if (arg_init_tensor_view_type->get_shape().size() != 0)
auto& input_init = get_inputs().at(1);
if (input_init.get_shape().size() != 0)
{
throw ngraph_error("Argument for initial value is not a scalar");
}
if (arg_init_tensor_view_type->get_element_type() !=
arg_reductee_tensor_view_type->get_element_type())
if (input_init.get_element_type() != input_reductee.get_element_type())
{
throw ngraph_error("Element types for reductee and initial values do not match");
}
auto arg_reductee_shape = arg_reductee_tensor_view_type->get_shape();
auto input_reductee_shape = input_reductee.get_shape();
for (auto axis : m_reduction_axes)
{
if (axis >= arg_reductee_shape.size())
if (axis >= input_reductee_shape.size())
{
throw ngraph_error("Reduction axis is out of bounds");
}
......@@ -52,11 +51,11 @@ op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee,
Shape result_shape;
for (size_t i = 0; i < arg_reductee_shape.size(); i++)
for (size_t i = 0; i < input_reductee_shape.size(); i++)
{
if (m_reduction_axes.count(i) == 0)
{
result_shape.push_back(arg_reductee_shape.at(i));
result_shape.push_back(input_reductee_shape.at(i));
}
}
......@@ -87,6 +86,6 @@ op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee,
throw ngraph_error("Return type from reduction function does not match expected");
}
set_value_type_checked(make_shared<TensorViewType>(
arg_reductee_tensor_view_type->get_element_type(), result_shape));
set_value_type_checked(
make_shared<TensorViewType>(input_reductee.get_element_type(), result_shape));
}
......@@ -46,37 +46,37 @@ op::ReplaceSlice::ReplaceSlice(const std::shared_ptr<Node>& arg0,
void op::ReplaceSlice::check_args()
{
auto arg0_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto& arg0_shape = arg0_tensor_view_type->get_shape();
auto& arg0_element_type = arg0_tensor_view_type->get_element_type();
auto& input_0 = get_inputs().at(0);
auto& input_0_shape = input_0.get_shape();
auto& input_0_element_type = input_0.get_element_type();
auto arg1_tensor_view_type = get_inputs().at(1).get_tensor_view_type();
auto& arg1_shape = arg1_tensor_view_type->get_shape();
auto& arg1_element_type = arg1_tensor_view_type->get_element_type();
auto& input_1 = get_inputs().at(1);
auto& input_1_shape = input_1.get_shape();
auto& input_1_element_type = input_1.get_element_type();
if (arg0_shape.size() != arg1_shape.size())
if (input_0_shape.size() != input_1_shape.size())
{
throw ngraph_error("Replace-slice argument ranks do not match");
}
if (arg0_element_type != arg1_element_type)
if (input_0_element_type != input_1_element_type)
{
throw ngraph_error("Element types for replace-slice arguments do not match");
}
if (m_lower_bounds.size() != arg0_shape.size())
if (m_lower_bounds.size() != input_0_shape.size())
{
throw ngraph_error(
"Number of lower bounds provided for slice does not match number of input axes");
}
if (m_upper_bounds.size() != arg0_shape.size())
if (m_upper_bounds.size() != input_0_shape.size())
{
throw ngraph_error(
"Number of upper bounds provided for slice does not match number of input axes");
}
if (m_strides.size() != arg0_shape.size())
if (m_strides.size() != input_0_shape.size())
{
throw ngraph_error(
"Number of strides provided for slice does not match number of input axes");
......@@ -84,9 +84,9 @@ void op::ReplaceSlice::check_args()
Shape slice_shape;
for (size_t i = 0; i < arg0_shape.size(); i++)
for (size_t i = 0; i < input_0_shape.size(); i++)
{
if (m_upper_bounds[i] > arg0_shape[i])
if (m_upper_bounds[i] > input_0_shape[i])
{
throw ngraph_error("Upper bound for slice is out of range");
}
......@@ -107,12 +107,12 @@ void op::ReplaceSlice::check_args()
slice_shape.push_back(slice_axis_size);
}
if (arg1_shape != slice_shape)
if (input_1_shape != slice_shape)
{
throw ngraph_error("Shape of replacement tensor does not match slice shape");
}
set_value_type_checked(arg0_tensor_view_type);
set_value_type_checked(input_0_element_type, input_0_shape);
}
void op::ReplaceSlice::generate_adjoints(autodiff::Adjoints& adjoints,
......@@ -121,8 +121,8 @@ void op::ReplaceSlice::generate_adjoints(autodiff::Adjoints& adjoints,
auto x = get_inputs().at(0).get_output().get_node();
auto& y_input = get_inputs().at(1);
auto y = y_input.get_output().get_node();
auto& y_element_type = y_input.get_tensor_view_type()->get_element_type();
auto y_shape = y_input.get_tensor_view_type()->get_shape();
auto& y_element_type = y_input.get_element_type();
auto y_shape = y_input.get_shape();
auto zeros_shaped_like_y = std::make_shared<op::Constant>(y_element_type, y_shape, "0");
......
......@@ -27,16 +27,16 @@ op::Reshape::Reshape(const std::shared_ptr<Node>& arg,
, m_input_order(input_order)
, m_output_shape(output_shape)
{
auto arg_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto arg_shape = arg_tensor_view_type->get_shape();
auto arg_rank = arg_shape.size();
auto& input = get_inputs().at(0);
auto input_shape = input.get_shape();
auto input_rank = input_shape.size();
if (m_input_order.size() != arg_rank)
if (m_input_order.size() != input_rank)
{
throw ngraph_error("Input axis order for reshape is not a permutation of argument's axes");
}
for (size_t i = 0; i < arg_rank; i++)
for (size_t i = 0; i < input_rank; i++)
{
auto it = std::find(std::begin(m_input_order), std::end(m_input_order), i);
if (std::end(m_input_order) == it)
......@@ -46,10 +46,10 @@ op::Reshape::Reshape(const std::shared_ptr<Node>& arg,
}
}
size_t arg_shape_product = 1;
for (auto i : arg_shape)
size_t input_shape_product = 1;
for (auto i : input_shape)
{
arg_shape_product *= i;
input_shape_product *= i;
}
size_t output_shape_product = 1;
......@@ -58,15 +58,14 @@ op::Reshape::Reshape(const std::shared_ptr<Node>& arg,
output_shape_product *= i;
}
if (arg_shape_product != output_shape_product)
if (input_shape_product != output_shape_product)
{
throw ngraph_error(
"Product of output shape dimensions does not match product of argument shape "
"dimensions for reshape");
}
set_value_type_checked(
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_output_shape));
set_value_type_checked(input.get_element_type(), m_output_shape);
}
void op::Reshape::generate_adjoints(autodiff::Adjoints& adjoints,
......
......@@ -29,25 +29,24 @@ op::Select::Select(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg2)
: RequiresTensorViewArgs("Select", Nodes{arg0, arg1, arg2})
{
auto arg0_tensor_type = get_inputs().at(0).get_tensor_view_type();
auto arg1_tensor_type = get_inputs().at(1).get_tensor_view_type();
auto arg2_tensor_type = get_inputs().at(2).get_tensor_view_type();
auto& input_0 = get_inputs().at(0);
auto& input_1 = get_inputs().at(1);
auto& input_2 = get_inputs().at(2);
if (arg0_tensor_type->get_element_type() != element::Bool::element_type())
if (input_0.get_element_type() != element::Bool::element_type())
{
throw ngraph_error("Argument 0 for arithmetic operators must have boolean element type");
}
if (arg0_tensor_type->get_shape() != arg1_tensor_type->get_shape() ||
arg0_tensor_type->get_shape() != arg2_tensor_type->get_shape())
if (input_0.get_shape() != input_1.get_shape() || input_0.get_shape() != input_2.get_shape())
{
throw ngraph_error("Arguments must have the same tensor view shape");
throw ngraph_error("Arguments must have the same shape");
}
if (*arg1_tensor_type != *arg2_tensor_type)
if (input_1.get_element_type() != input_2.get_element_type())
{
throw ngraph_error("Arguments 1 and 2 must have the same tensor view type");
throw ngraph_error("Arguments 1 and 2 must have the same element type");
}
set_value_type_checked(arg1_tensor_type);
set_value_type_checked(input_1.get_element_type(), input_1.get_shape());
}
void ngraph::op::Select::generate_adjoints(autodiff::Adjoints& adjoints,
......
......@@ -42,22 +42,22 @@ op::Slice::Slice(const std::shared_ptr<Node>& arg,
void op::Slice::check_args()
{
auto arg_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto& arg_shape = arg_tensor_view_type->get_shape();
auto& input = get_inputs().at(0);
auto& input_shape = input.get_shape();
if (m_lower_bounds.size() != arg_shape.size())
if (m_lower_bounds.size() != input_shape.size())
{
throw ngraph_error(
"Number of lower bounds provided for slice does not match number of input axes");
}
if (m_upper_bounds.size() != arg_shape.size())
if (m_upper_bounds.size() != input_shape.size())
{
throw ngraph_error(
"Number of upper bounds provided for slice does not match number of input axes");
}
if (m_strides.size() != arg_shape.size())
if (m_strides.size() != input_shape.size())
{
throw ngraph_error(
"Number of strides provided for slice does not match number of input axes");
......@@ -65,9 +65,9 @@ void op::Slice::check_args()
Shape result_shape;
for (size_t i = 0; i < arg_shape.size(); i++)
for (size_t i = 0; i < input_shape.size(); i++)
{
if (m_upper_bounds[i] > arg_shape[i])
if (m_upper_bounds[i] > input_shape[i])
{
throw ngraph_error("Upper bound for slice is out of range");
}
......@@ -88,8 +88,7 @@ void op::Slice::check_args()
result_shape.push_back(result_axis_size);
}
set_value_type_checked(
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), result_shape));
set_value_type_checked(input.get_element_type(), result_shape);
}
void op::Slice::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta)
......
......@@ -23,18 +23,18 @@ op::Sum::Sum(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes)
: RequiresTensorViewArgs("Sum", {arg})
, m_reduction_axes(reduction_axes)
{
auto arg_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto& arg_element_type = arg_tensor_view_type->get_element_type();
if (arg_element_type == element::Bool::element_type())
auto& input = get_inputs().at(0);
auto& input_element_type = input.get_element_type();
if (input_element_type == element::Bool::element_type())
{
throw ngraph_error("Argument for sum must have numeric element type");
}
auto arg_shape = arg_tensor_view_type->get_shape();
auto input_shape = input.get_shape();
for (auto axis : m_reduction_axes)
{
if (axis >= arg_shape.size())
if (axis >= input_shape.size())
{
throw ngraph_error("Reduction axis for sum is out of bounds");
}
......@@ -42,22 +42,21 @@ op::Sum::Sum(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes)
Shape result_shape;
for (size_t i = 0; i < arg_shape.size(); i++)
for (size_t i = 0; i < input_shape.size(); i++)
{
if (m_reduction_axes.count(i) == 0)
{
result_shape.push_back(arg_shape.at(i));
result_shape.push_back(input_shape.at(i));
}
}
set_value_type_checked(
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), result_shape));
set_value_type_checked(input.get_element_type(), result_shape);
}
void op::Sum::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta)
{
auto x = get_inputs().at(0).get_output().get_node();
auto& x_shape = get_inputs().at(0).get_tensor_view_type()->get_shape();
auto& x_shape = get_inputs().at(0).get_shape();
adjoints.add_delta(x, make_shared<op::Broadcast>(delta, x_shape, m_reduction_axes));
}
......@@ -25,10 +25,8 @@ op::UnaryElementwise::UnaryElementwise(
const std::shared_ptr<Node>& arg)
: RequiresTensorViewArgs(node_type, Nodes{arg})
{
auto arg_tensor_type = get_inputs().at(0).get_tensor_view_type();
const element::Type& result_element_type =
element_type_function(arg_tensor_type->get_element_type());
auto& input = get_inputs().at(0);
const element::Type& result_element_type = element_type_function(input.get_element_type());
set_value_type_checked(
make_shared<TensorViewType>(result_element_type, arg_tensor_type->get_shape()));
set_value_type_checked(result_element_type, input.get_shape());
}
......@@ -78,9 +78,8 @@ void runtime::interpreter::INT_CallFrame::call(
if (!contains_key(tensor_map, name))
{
// The output tensor is not in the tensor map so create a new tensor
const Shape& shape = output.get_tensor_view_type()->get_shape();
const element::Type& element_type =
output.get_tensor_view_type()->get_element_type();
const Shape& shape = output.get_shape();
const element::Type& element_type = output.get_element_type();
string tensor_name = output.get_tensor().get_name();
itv = make_shared<runtime::interpreter::INT_TensorView>(
element_type, shape, tensor_name);
......
......@@ -559,7 +559,7 @@ TEST(type_prop, select_shape_mismatch_a)
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view shape"));
EXPECT_EQ(error.what(), std::string("Arguments must have the same shape"));
}
catch (...)
{
......@@ -583,7 +583,7 @@ TEST(type_prop, select_shape_mismatch_b)
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view shape"));
EXPECT_EQ(error.what(), std::string("Arguments must have the same shape"));
}
catch (...)
{
......@@ -607,7 +607,7 @@ TEST(type_prop, select_shape_mismatch_c)
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view shape"));
EXPECT_EQ(error.what(), std::string("Arguments must have the same shape"));
}
catch (...)
{
......@@ -657,8 +657,7 @@ TEST(type_prop, select_elem_mismatch_bc)
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(),
std::string("Arguments 1 and 2 must have the same tensor view type"));
EXPECT_EQ(error.what(), std::string("Arguments 1 and 2 must have the same element type"));
}
catch (...)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment