Unverified Commit 8a569f27 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Cyphers/shape (#310)

* Add and use get_shape() and get_element_type() on Input/Output

* Fix Output

* Formatting.

* Format.

* Use reference

* Convolution.
parent 122db5ff
...@@ -32,11 +32,9 @@ ...@@ -32,11 +32,9 @@
using namespace ngraph; using namespace ngraph;
std::shared_ptr<Node> make_zero(const std::shared_ptr<const TensorViewType>& tensor_view_type) std::shared_ptr<Node> make_zero(const element::Type& element_type, const Shape& shape)
{ {
std::shared_ptr<Node> zero = std::shared_ptr<Node> zero = std::make_shared<op::Constant>(element_type, Shape{}, "0");
std::make_shared<op::Constant>(tensor_view_type->get_element_type(), Shape{}, "0");
const Shape& shape = tensor_view_type->get_shape();
if (shape.size() > 0) if (shape.size() > 0)
{ {
AxisSet axes; AxisSet axes;
...@@ -114,7 +112,8 @@ std::shared_ptr<Node> autodiff::Adjoints::get(const std::shared_ptr<Node>& x) ...@@ -114,7 +112,8 @@ std::shared_ptr<Node> autodiff::Adjoints::get(const std::shared_ptr<Node>& x)
auto adjoint_it = m_adjoint_map.find(x.get()); auto adjoint_it = m_adjoint_map.find(x.get());
if (m_adjoint_map.end() == adjoint_it) if (m_adjoint_map.end() == adjoint_it)
{ {
auto result = make_zero(x->get_outputs().at(0).get_tensor_view_type()); auto& output = x->get_outputs().at(0);
auto result = make_zero(output.get_element_type(), output.get_shape());
adjoint_it = m_adjoint_map.insert({x.get(), result}).first; adjoint_it = m_adjoint_map.insert({x.get(), result}).first;
} }
return adjoint_it->second; return adjoint_it->second;
...@@ -152,7 +151,8 @@ void autodiff::Adjoints::add_delta_to_slice(const std::shared_ptr<Node>& x, ...@@ -152,7 +151,8 @@ void autodiff::Adjoints::add_delta_to_slice(const std::shared_ptr<Node>& x,
auto adjoint_it = m_adjoint_map.find(x.get()); auto adjoint_it = m_adjoint_map.find(x.get());
if (m_adjoint_map.end() == adjoint_it) if (m_adjoint_map.end() == adjoint_it)
{ {
auto zeros = make_zero(x->get_outputs().at(0).get_tensor_view_type()); auto& output = x->get_outputs().at(0);
auto zeros = make_zero(output.get_element_type(), output.get_shape());
m_adjoint_map.insert({x.get(), m_adjoint_map.insert({x.get(),
std::make_shared<op::ReplaceSlice>( std::make_shared<op::ReplaceSlice>(
zeros, delta, lower_bounds, upper_bounds, strides)}); zeros, delta, lower_bounds, upper_bounds, strides)});
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "ngraph/descriptor/input.hpp" #include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp" #include "ngraph/descriptor/output.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/types/element_type.hpp"
using namespace ngraph; using namespace ngraph;
using namespace descriptor; using namespace descriptor;
...@@ -65,3 +66,13 @@ std::shared_ptr<const TensorViewType> Input::get_tensor_view_type() const ...@@ -65,3 +66,13 @@ std::shared_ptr<const TensorViewType> Input::get_tensor_view_type() const
{ {
return m_output->get_tensor_view()->get_tensor_view_type(); return m_output->get_tensor_view()->get_tensor_view_type();
} }
const Shape& Input::get_shape() const
{
return get_tensor_view_type()->get_shape();
}
const element::Type& Input::get_element_type() const
{
return get_tensor_view_type()->get_element_type();
}
...@@ -61,6 +61,7 @@ namespace ngraph ...@@ -61,6 +61,7 @@ namespace ngraph
void replace_output(Output& output); void replace_output(Output& output);
protected:
/// @return the tensor view for the connected output /// @return the tensor view for the connected output
std::shared_ptr<const TensorView> get_tensor_view() const; std::shared_ptr<const TensorView> get_tensor_view() const;
...@@ -70,6 +71,13 @@ namespace ngraph ...@@ -70,6 +71,13 @@ namespace ngraph
/// @return the tensor view type for the connected output /// @return the tensor view type for the connected output
std::shared_ptr<const TensorViewType> get_tensor_view_type() const; std::shared_ptr<const TensorViewType> get_tensor_view_type() const;
public:
/// @return the shape of the connected output
const Shape& get_shape() const;
/// @return the element type of the connected output
const element::Type& get_element_type() const;
protected: protected:
Node* m_node; // The node we are an input for Node* m_node; // The node we are an input for
size_t m_index; // Index into all input tensors size_t m_index; // Index into all input tensors
......
...@@ -52,3 +52,18 @@ Tensor& Output::get_tensor() ...@@ -52,3 +52,18 @@ Tensor& Output::get_tensor()
{ {
return m_tensor_view->get_tensor(); return m_tensor_view->get_tensor();
} }
std::shared_ptr<const TensorViewType> Output::get_tensor_view_type() const
{
return get_tensor_view()->get_tensor_view_type();
}
const Shape& Output::get_shape() const
{
return get_tensor_view_type()->get_shape();
}
const element::Type& Output::get_element_type() const
{
return get_tensor_view_type()->get_element_type();
}
...@@ -48,11 +48,16 @@ namespace ngraph ...@@ -48,11 +48,16 @@ namespace ngraph
const std::set<Input*>& get_inputs() const { return m_inputs; } const std::set<Input*>& get_inputs() const { return m_inputs; }
const Tensor& get_tensor() const; const Tensor& get_tensor() const;
Tensor& get_tensor(); Tensor& get_tensor();
/// @return the tensor view type for the connected output
std::shared_ptr<const TensorViewType> get_tensor_view_type() const protected:
{ /// @return the tensor view type for the output
return get_tensor_view()->get_tensor_view_type(); std::shared_ptr<const TensorViewType> get_tensor_view_type() const;
}
public:
/// @return the shape of the output
const Shape& get_shape() const;
/// @return the element type of the output
const element::Type& get_element_type() const;
protected: protected:
Node* m_node; Node* m_node;
......
...@@ -62,6 +62,11 @@ void Node::assert_value_type(const shared_ptr<const ValueType>& value_type) cons ...@@ -62,6 +62,11 @@ void Node::assert_value_type(const shared_ptr<const ValueType>& value_type) cons
} }
} }
void Node::set_value_type_checked(const element::Type& element_type, const Shape& shape)
{
set_value_type_checked(make_shared<TensorViewType>(element_type, shape));
}
void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type) void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type)
{ {
if (nullptr == m_value_type) if (nullptr == m_value_type)
......
...@@ -87,6 +87,7 @@ namespace ngraph ...@@ -87,6 +87,7 @@ namespace ngraph
// This is used when the framework specifies a value type for the value, and we // This is used when the framework specifies a value type for the value, and we
// independently compute what we thing the value type should be from the arguments. // independently compute what we thing the value type should be from the arguments.
void set_value_type_checked(const std::shared_ptr<const ValueType>& value_type); void set_value_type_checked(const std::shared_ptr<const ValueType>& value_type);
void set_value_type_checked(const element::Type& element_type, const Shape& shape);
bool is_parameter() const; bool is_parameter() const;
bool is_output() const; bool is_output() const;
......
...@@ -28,16 +28,15 @@ op::BinaryElementwise::BinaryElementwise( ...@@ -28,16 +28,15 @@ op::BinaryElementwise::BinaryElementwise(
const std::shared_ptr<Node>& arg1) const std::shared_ptr<Node>& arg1)
: RequiresTensorViewArgs(node_type, Nodes{arg0, arg1}) : RequiresTensorViewArgs(node_type, Nodes{arg0, arg1})
{ {
auto arg0_tensor_type = get_inputs().at(0).get_tensor_view_type(); auto& input_0 = get_inputs().at(0);
auto arg1_tensor_type = get_inputs().at(1).get_tensor_view_type(); auto& input_1 = get_inputs().at(1);
if (arg0_tensor_type->get_shape() != arg1_tensor_type->get_shape()) if (input_0.get_shape() != input_1.get_shape())
{ {
throw ngraph_error("Arguments must have the same tensor view shape"); throw ngraph_error("Arguments must have the same tensor view shape");
} }
const element::Type& result_element_type = element_type_function( const element::Type& result_element_type =
arg0_tensor_type->get_element_type(), arg1_tensor_type->get_element_type()); element_type_function(input_0.get_element_type(), input_1.get_element_type());
set_value_type_checked( set_value_type_checked(make_shared<TensorViewType>(result_element_type, input_0.get_shape()));
make_shared<TensorViewType>(result_element_type, arg0_tensor_type->get_shape()));
} }
...@@ -25,18 +25,17 @@ op::Broadcast::Broadcast(const std::shared_ptr<Node>& arg, ...@@ -25,18 +25,17 @@ op::Broadcast::Broadcast(const std::shared_ptr<Node>& arg,
, m_shape(shape) , m_shape(shape)
, m_broadcast_axes(broadcast_axes) , m_broadcast_axes(broadcast_axes)
{ {
auto arg_tensor_view_type = m_inputs.at(0).get_tensor_view_type(); auto& input = m_inputs.at(0);
vector<size_t> target_shape = m_shape; vector<size_t> target_shape = m_shape;
for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i) for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
{ {
target_shape.erase(target_shape.begin() + *i); target_shape.erase(target_shape.begin() + *i);
} }
if (Shape{target_shape} != arg_tensor_view_type->get_shape()) if (Shape{target_shape} != input.get_shape())
{ {
throw ngraph_error("Broadcast arg, shape, and axes are incompatible"); throw ngraph_error("Broadcast arg, shape, and axes are incompatible");
} }
set_value_type_checked( set_value_type_checked(make_shared<TensorViewType>(input.get_element_type(), m_shape));
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_shape));
} }
void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -30,47 +30,47 @@ op::Concat::Concat(const Nodes& args, size_t concatenation_axis) ...@@ -30,47 +30,47 @@ op::Concat::Concat(const Nodes& args, size_t concatenation_axis)
throw ngraph_error("At least one argument required"); throw ngraph_error("At least one argument required");
} }
auto arg0_tensor_view_type = get_inputs().at(0).get_tensor_view_type(); auto& input_0 = get_inputs().at(0);
auto arg0_shape = arg0_tensor_view_type->get_shape(); auto input_0_shape = input_0.get_shape();
if (m_concatenation_axis >= arg0_shape.size()) if (m_concatenation_axis >= input_0_shape.size())
{ {
throw ngraph_error("Concatenation axis is out of bounds"); throw ngraph_error("Concatenation axis is out of bounds");
} }
size_t concatenation_axis_length = arg0_shape.at(m_concatenation_axis); size_t concatenation_axis_length = input_0_shape.at(m_concatenation_axis);
auto& arg0_element_type = arg0_tensor_view_type->get_element_type(); auto& input_0_element_type = input_0.get_element_type();
for (auto i = 1; i < get_inputs().size(); i++) for (auto i = 1; i < get_inputs().size(); i++)
{ {
auto argi_tensor_view_type = get_inputs().at(i).get_tensor_view_type(); auto& input_i = get_inputs().at(i);
auto argi_shape = argi_tensor_view_type->get_shape(); auto input_i_shape = input_i.get_shape();
if (argi_shape.size() != arg0_shape.size()) if (input_i_shape.size() != input_0_shape.size())
{ {
throw ngraph_error("Arguments to concat do not have same rank"); throw ngraph_error("Arguments to concat do not have same rank");
} }
if (argi_tensor_view_type->get_element_type() != arg0_element_type) if (input_i.get_element_type() != input_0_element_type)
{ {
throw ngraph_error("Argument element types do not match"); throw ngraph_error("Argument element types do not match");
} }
for (auto j = 0; j < argi_shape.size(); j++) for (auto j = 0; j < input_i_shape.size(); j++)
{ {
if (j != m_concatenation_axis && arg0_shape.at(j) != argi_shape.at(j)) if (j != m_concatenation_axis && input_0_shape.at(j) != input_i_shape.at(j))
{ {
throw ngraph_error( throw ngraph_error(
"Arguments to concat do not have same dimension on a non-concatenation axis"); "Arguments to concat do not have same dimension on a non-concatenation axis");
} }
else if (j == m_concatenation_axis) else if (j == m_concatenation_axis)
{ {
concatenation_axis_length += argi_shape.at(j); concatenation_axis_length += input_i_shape.at(j);
} }
} }
} }
vector<size_t> concatenated_shape = arg0_shape; vector<size_t> concatenated_shape = input_0_shape;
concatenated_shape.at(m_concatenation_axis) = concatenation_axis_length; concatenated_shape.at(m_concatenation_axis) = concatenation_axis_length;
set_value_type_checked(make_shared<TensorViewType>(arg0_element_type, concatenated_shape)); set_value_type_checked(make_shared<TensorViewType>(input_0_element_type, concatenated_shape));
} }
void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta) void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta)
......
...@@ -26,11 +26,8 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, ...@@ -26,11 +26,8 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides) , m_window_dilation_strides(window_dilation_strides)
{ {
auto image_batch_tensor_view_type = get_inputs().at(0).get_tensor_view_type(); auto& image_batch_shape = get_inputs().at(0).get_shape();
auto& image_batch_shape = image_batch_tensor_view_type->get_shape(); auto& filters_shape = get_inputs().at(1).get_shape();
auto filters_tensor_view_type = get_inputs().at(1).get_tensor_view_type();
auto& filters_shape = filters_tensor_view_type->get_shape();
// //
// Make sure image_batch: NCiDi for some Di of rank>0, N != 0, Ci != 0. // Make sure image_batch: NCiDi for some Di of rank>0, N != 0, Ci != 0.
...@@ -157,8 +154,7 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, ...@@ -157,8 +154,7 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
result_shape[1] = m_output_channel_count; result_shape[1] = m_output_channel_count;
std::copy(m_output_image_shape.begin(), m_output_image_shape.end(), result_shape.begin() + 2); std::copy(m_output_image_shape.begin(), m_output_image_shape.end(), result_shape.begin() + 2);
set_value_type_checked(make_shared<TensorViewType>( set_value_type_checked(get_inputs().at(0).get_element_type(), result_shape);
image_batch_tensor_view_type->get_element_type(), result_shape));
} }
Strides default_strides(const std::shared_ptr<Node>& image_batch) Strides default_strides(const std::shared_ptr<Node>& image_batch)
......
...@@ -70,44 +70,44 @@ op::Dot::Dot(const std::shared_ptr<Node>& arg0, ...@@ -70,44 +70,44 @@ op::Dot::Dot(const std::shared_ptr<Node>& arg0,
: RequiresTensorViewArgs("Dot", {arg0, arg1}) : RequiresTensorViewArgs("Dot", {arg0, arg1})
, m_reduction_axes_count(reduction_axes_count) , m_reduction_axes_count(reduction_axes_count)
{ {
auto arg0_tensor_type = get_inputs().at(0).get_tensor_view_type(); auto& input_0 = get_inputs().at(0);
auto arg1_tensor_type = get_inputs().at(1).get_tensor_view_type(); auto& input_1 = get_inputs().at(1);
if (arg0_tensor_type->get_element_type() != arg1_tensor_type->get_element_type()) if (input_0.get_element_type() != input_1.get_element_type())
{ {
throw ngraph_error("Arguments to dot must have the same element type"); throw ngraph_error("Arguments to dot must have the same element type");
} }
vector<size_t> arg0_shape = arg0_tensor_type->get_shape(); Shape input_0_shape = input_0.get_shape();
vector<size_t> arg1_shape = arg1_tensor_type->get_shape(); Shape input_1_shape = input_1.get_shape();
if (reduction_axes_count > arg0_shape.size()) if (reduction_axes_count > input_0_shape.size())
{ {
throw ngraph_error("Dot has too many axes for arg0"); throw ngraph_error("Dot has too many axes for arg0");
} }
if (reduction_axes_count > arg1_shape.size()) if (reduction_axes_count > input_1_shape.size())
{ {
throw ngraph_error("Dot has too many axes for arg1"); throw ngraph_error("Dot has too many axes for arg1");
} }
for (size_t i = 0; i < reduction_axes_count; i++) for (size_t i = 0; i < reduction_axes_count; i++)
{ {
if (arg0_shape[arg0_shape.size() - reduction_axes_count + i] != arg1_shape[i]) if (input_0_shape[input_0_shape.size() - reduction_axes_count + i] != input_1_shape[i])
{ {
throw ngraph_error("Dot axes do not have same length"); throw ngraph_error("Dot axes do not have same length");
} }
} }
vector<size_t> result_shape(arg0_shape.size() + arg1_shape.size() - 2 * reduction_axes_count); Shape result_shape(input_0_shape.size() + input_1_shape.size() - 2 * reduction_axes_count);
std::copy(arg0_shape.begin(), arg0_shape.end() - reduction_axes_count, result_shape.begin()); std::copy(
std::copy(arg1_shape.begin() + reduction_axes_count, input_0_shape.begin(), input_0_shape.end() - reduction_axes_count, result_shape.begin());
arg1_shape.end(), std::copy(input_1_shape.begin() + reduction_axes_count,
result_shape.begin() + (arg0_shape.size() - reduction_axes_count)); input_1_shape.end(),
result_shape.begin() + (input_0_shape.size() - reduction_axes_count));
auto result_type = auto result_type = make_shared<TensorViewType>(input_0.get_element_type(), result_shape);
make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape);
set_value_type_checked(result_type); set_value_type_checked(result_type);
} }
......
...@@ -33,5 +33,6 @@ op::GetOutputElement::GetOutputElement(const std::shared_ptr<Node>& arg, size_t ...@@ -33,5 +33,6 @@ op::GetOutputElement::GetOutputElement(const std::shared_ptr<Node>& arg, size_t
throw ngraph_error("Indexing tuple beyond its size"); throw ngraph_error("Indexing tuple beyond its size");
} }
set_value_type_checked(arg->get_outputs().at(n).get_tensor_view_type()); auto& output = arg->get_outputs().at(n);
set_value_type_checked(output.get_element_type(), output.get_shape());
} }
...@@ -23,21 +23,21 @@ op::OneHot::OneHot(const std::shared_ptr<Node>& arg, const Shape& shape, size_t ...@@ -23,21 +23,21 @@ op::OneHot::OneHot(const std::shared_ptr<Node>& arg, const Shape& shape, size_t
, m_shape(shape) , m_shape(shape)
, m_one_hot_axis(one_hot_axis) , m_one_hot_axis(one_hot_axis)
{ {
auto arg_tensor_view_type = m_inputs.at(0).get_tensor_view_type(); auto& input = m_inputs.at(0);
auto& arg_element_type = arg_tensor_view_type->get_element_type(); auto& input_element_type = input.get_element_type();
if (one_hot_axis >= shape.size()) if (one_hot_axis >= shape.size())
{ {
throw ngraph_error("One-hot axis is out of bounds"); throw ngraph_error("One-hot axis is out of bounds");
} }
auto expected_arg_shape = shape; auto expected_input_shape = shape;
expected_arg_shape.erase(expected_arg_shape.begin() + one_hot_axis); expected_input_shape.erase(expected_input_shape.begin() + one_hot_axis);
if (arg_tensor_view_type->get_shape() != expected_arg_shape) if (input.get_shape() != expected_input_shape)
{ {
throw ngraph_error("One-hot argument shape is not compatible with desired output shape"); throw ngraph_error("One-hot argument shape is not compatible with desired output shape");
} }
set_value_type_checked(make_shared<TensorViewType>(arg_element_type, shape)); set_value_type_checked(make_shared<TensorViewType>(input_element_type, shape));
} }
...@@ -26,25 +26,24 @@ op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee, ...@@ -26,25 +26,24 @@ op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee,
, m_reduction_function(reduction_function) , m_reduction_function(reduction_function)
, m_reduction_axes(reduction_axes) , m_reduction_axes(reduction_axes)
{ {
auto arg_reductee_tensor_view_type = get_inputs().at(0).get_tensor_view_type(); auto& input_reductee = get_inputs().at(0);
auto arg_init_tensor_view_type = get_inputs().at(1).get_tensor_view_type(); auto& input_init = get_inputs().at(1);
if (arg_init_tensor_view_type->get_shape().size() != 0) if (input_init.get_shape().size() != 0)
{ {
throw ngraph_error("Argument for initial value is not a scalar"); throw ngraph_error("Argument for initial value is not a scalar");
} }
if (arg_init_tensor_view_type->get_element_type() != if (input_init.get_element_type() != input_reductee.get_element_type())
arg_reductee_tensor_view_type->get_element_type())
{ {
throw ngraph_error("Element types for reductee and initial values do not match"); throw ngraph_error("Element types for reductee and initial values do not match");
} }
auto arg_reductee_shape = arg_reductee_tensor_view_type->get_shape(); auto input_reductee_shape = input_reductee.get_shape();
for (auto axis : m_reduction_axes) for (auto axis : m_reduction_axes)
{ {
if (axis >= arg_reductee_shape.size()) if (axis >= input_reductee_shape.size())
{ {
throw ngraph_error("Reduction axis is out of bounds"); throw ngraph_error("Reduction axis is out of bounds");
} }
...@@ -52,11 +51,11 @@ op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee, ...@@ -52,11 +51,11 @@ op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee,
Shape result_shape; Shape result_shape;
for (size_t i = 0; i < arg_reductee_shape.size(); i++) for (size_t i = 0; i < input_reductee_shape.size(); i++)
{ {
if (m_reduction_axes.count(i) == 0) if (m_reduction_axes.count(i) == 0)
{ {
result_shape.push_back(arg_reductee_shape.at(i)); result_shape.push_back(input_reductee_shape.at(i));
} }
} }
...@@ -87,6 +86,6 @@ op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee, ...@@ -87,6 +86,6 @@ op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee,
throw ngraph_error("Return type from reduction function does not match expected"); throw ngraph_error("Return type from reduction function does not match expected");
} }
set_value_type_checked(make_shared<TensorViewType>( set_value_type_checked(
arg_reductee_tensor_view_type->get_element_type(), result_shape)); make_shared<TensorViewType>(input_reductee.get_element_type(), result_shape));
} }
...@@ -46,37 +46,37 @@ op::ReplaceSlice::ReplaceSlice(const std::shared_ptr<Node>& arg0, ...@@ -46,37 +46,37 @@ op::ReplaceSlice::ReplaceSlice(const std::shared_ptr<Node>& arg0,
void op::ReplaceSlice::check_args() void op::ReplaceSlice::check_args()
{ {
auto arg0_tensor_view_type = get_inputs().at(0).get_tensor_view_type(); auto& input_0 = get_inputs().at(0);
auto& arg0_shape = arg0_tensor_view_type->get_shape(); auto& input_0_shape = input_0.get_shape();
auto& arg0_element_type = arg0_tensor_view_type->get_element_type(); auto& input_0_element_type = input_0.get_element_type();
auto arg1_tensor_view_type = get_inputs().at(1).get_tensor_view_type(); auto& input_1 = get_inputs().at(1);
auto& arg1_shape = arg1_tensor_view_type->get_shape(); auto& input_1_shape = input_1.get_shape();
auto& arg1_element_type = arg1_tensor_view_type->get_element_type(); auto& input_1_element_type = input_1.get_element_type();
if (arg0_shape.size() != arg1_shape.size()) if (input_0_shape.size() != input_1_shape.size())
{ {
throw ngraph_error("Replace-slice argument ranks do not match"); throw ngraph_error("Replace-slice argument ranks do not match");
} }
if (arg0_element_type != arg1_element_type) if (input_0_element_type != input_1_element_type)
{ {
throw ngraph_error("Element types for replace-slice arguments do not match"); throw ngraph_error("Element types for replace-slice arguments do not match");
} }
if (m_lower_bounds.size() != arg0_shape.size()) if (m_lower_bounds.size() != input_0_shape.size())
{ {
throw ngraph_error( throw ngraph_error(
"Number of lower bounds provided for slice does not match number of input axes"); "Number of lower bounds provided for slice does not match number of input axes");
} }
if (m_upper_bounds.size() != arg0_shape.size()) if (m_upper_bounds.size() != input_0_shape.size())
{ {
throw ngraph_error( throw ngraph_error(
"Number of upper bounds provided for slice does not match number of input axes"); "Number of upper bounds provided for slice does not match number of input axes");
} }
if (m_strides.size() != arg0_shape.size()) if (m_strides.size() != input_0_shape.size())
{ {
throw ngraph_error( throw ngraph_error(
"Number of strides provided for slice does not match number of input axes"); "Number of strides provided for slice does not match number of input axes");
...@@ -84,9 +84,9 @@ void op::ReplaceSlice::check_args() ...@@ -84,9 +84,9 @@ void op::ReplaceSlice::check_args()
Shape slice_shape; Shape slice_shape;
for (size_t i = 0; i < arg0_shape.size(); i++) for (size_t i = 0; i < input_0_shape.size(); i++)
{ {
if (m_upper_bounds[i] > arg0_shape[i]) if (m_upper_bounds[i] > input_0_shape[i])
{ {
throw ngraph_error("Upper bound for slice is out of range"); throw ngraph_error("Upper bound for slice is out of range");
} }
...@@ -107,12 +107,12 @@ void op::ReplaceSlice::check_args() ...@@ -107,12 +107,12 @@ void op::ReplaceSlice::check_args()
slice_shape.push_back(slice_axis_size); slice_shape.push_back(slice_axis_size);
} }
if (arg1_shape != slice_shape) if (input_1_shape != slice_shape)
{ {
throw ngraph_error("Shape of replacement tensor does not match slice shape"); throw ngraph_error("Shape of replacement tensor does not match slice shape");
} }
set_value_type_checked(arg0_tensor_view_type); set_value_type_checked(input_0_element_type, input_0_shape);
} }
void op::ReplaceSlice::generate_adjoints(autodiff::Adjoints& adjoints, void op::ReplaceSlice::generate_adjoints(autodiff::Adjoints& adjoints,
...@@ -121,8 +121,8 @@ void op::ReplaceSlice::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -121,8 +121,8 @@ void op::ReplaceSlice::generate_adjoints(autodiff::Adjoints& adjoints,
auto x = get_inputs().at(0).get_output().get_node(); auto x = get_inputs().at(0).get_output().get_node();
auto& y_input = get_inputs().at(1); auto& y_input = get_inputs().at(1);
auto y = y_input.get_output().get_node(); auto y = y_input.get_output().get_node();
auto& y_element_type = y_input.get_tensor_view_type()->get_element_type(); auto& y_element_type = y_input.get_element_type();
auto y_shape = y_input.get_tensor_view_type()->get_shape(); auto y_shape = y_input.get_shape();
auto zeros_shaped_like_y = std::make_shared<op::Constant>(y_element_type, y_shape, "0"); auto zeros_shaped_like_y = std::make_shared<op::Constant>(y_element_type, y_shape, "0");
......
...@@ -27,16 +27,16 @@ op::Reshape::Reshape(const std::shared_ptr<Node>& arg, ...@@ -27,16 +27,16 @@ op::Reshape::Reshape(const std::shared_ptr<Node>& arg,
, m_input_order(input_order) , m_input_order(input_order)
, m_output_shape(output_shape) , m_output_shape(output_shape)
{ {
auto arg_tensor_view_type = get_inputs().at(0).get_tensor_view_type(); auto& input = get_inputs().at(0);
auto arg_shape = arg_tensor_view_type->get_shape(); auto input_shape = input.get_shape();
auto arg_rank = arg_shape.size(); auto input_rank = input_shape.size();
if (m_input_order.size() != arg_rank) if (m_input_order.size() != input_rank)
{ {
throw ngraph_error("Input axis order for reshape is not a permutation of argument's axes"); throw ngraph_error("Input axis order for reshape is not a permutation of argument's axes");
} }
for (size_t i = 0; i < arg_rank; i++) for (size_t i = 0; i < input_rank; i++)
{ {
auto it = std::find(std::begin(m_input_order), std::end(m_input_order), i); auto it = std::find(std::begin(m_input_order), std::end(m_input_order), i);
if (std::end(m_input_order) == it) if (std::end(m_input_order) == it)
...@@ -46,10 +46,10 @@ op::Reshape::Reshape(const std::shared_ptr<Node>& arg, ...@@ -46,10 +46,10 @@ op::Reshape::Reshape(const std::shared_ptr<Node>& arg,
} }
} }
size_t arg_shape_product = 1; size_t input_shape_product = 1;
for (auto i : arg_shape) for (auto i : input_shape)
{ {
arg_shape_product *= i; input_shape_product *= i;
} }
size_t output_shape_product = 1; size_t output_shape_product = 1;
...@@ -58,15 +58,14 @@ op::Reshape::Reshape(const std::shared_ptr<Node>& arg, ...@@ -58,15 +58,14 @@ op::Reshape::Reshape(const std::shared_ptr<Node>& arg,
output_shape_product *= i; output_shape_product *= i;
} }
if (arg_shape_product != output_shape_product) if (input_shape_product != output_shape_product)
{ {
throw ngraph_error( throw ngraph_error(
"Product of output shape dimensions does not match product of argument shape " "Product of output shape dimensions does not match product of argument shape "
"dimensions for reshape"); "dimensions for reshape");
} }
set_value_type_checked( set_value_type_checked(input.get_element_type(), m_output_shape);
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_output_shape));
} }
void op::Reshape::generate_adjoints(autodiff::Adjoints& adjoints, void op::Reshape::generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -29,25 +29,24 @@ op::Select::Select(const std::shared_ptr<Node>& arg0, ...@@ -29,25 +29,24 @@ op::Select::Select(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg2) const std::shared_ptr<Node>& arg2)
: RequiresTensorViewArgs("Select", Nodes{arg0, arg1, arg2}) : RequiresTensorViewArgs("Select", Nodes{arg0, arg1, arg2})
{ {
auto arg0_tensor_type = get_inputs().at(0).get_tensor_view_type(); auto& input_0 = get_inputs().at(0);
auto arg1_tensor_type = get_inputs().at(1).get_tensor_view_type(); auto& input_1 = get_inputs().at(1);
auto arg2_tensor_type = get_inputs().at(2).get_tensor_view_type(); auto& input_2 = get_inputs().at(2);
if (arg0_tensor_type->get_element_type() != element::Bool::element_type()) if (input_0.get_element_type() != element::Bool::element_type())
{ {
throw ngraph_error("Argument 0 for arithmetic operators must have boolean element type"); throw ngraph_error("Argument 0 for arithmetic operators must have boolean element type");
} }
if (arg0_tensor_type->get_shape() != arg1_tensor_type->get_shape() || if (input_0.get_shape() != input_1.get_shape() || input_0.get_shape() != input_2.get_shape())
arg0_tensor_type->get_shape() != arg2_tensor_type->get_shape())
{ {
throw ngraph_error("Arguments must have the same tensor view shape"); throw ngraph_error("Arguments must have the same shape");
} }
if (*arg1_tensor_type != *arg2_tensor_type) if (input_1.get_element_type() != input_2.get_element_type())
{ {
throw ngraph_error("Arguments 1 and 2 must have the same tensor view type"); throw ngraph_error("Arguments 1 and 2 must have the same element type");
} }
set_value_type_checked(arg1_tensor_type); set_value_type_checked(input_1.get_element_type(), input_1.get_shape());
} }
void ngraph::op::Select::generate_adjoints(autodiff::Adjoints& adjoints, void ngraph::op::Select::generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -42,22 +42,22 @@ op::Slice::Slice(const std::shared_ptr<Node>& arg, ...@@ -42,22 +42,22 @@ op::Slice::Slice(const std::shared_ptr<Node>& arg,
void op::Slice::check_args() void op::Slice::check_args()
{ {
auto arg_tensor_view_type = get_inputs().at(0).get_tensor_view_type(); auto& input = get_inputs().at(0);
auto& arg_shape = arg_tensor_view_type->get_shape(); auto& input_shape = input.get_shape();
if (m_lower_bounds.size() != arg_shape.size()) if (m_lower_bounds.size() != input_shape.size())
{ {
throw ngraph_error( throw ngraph_error(
"Number of lower bounds provided for slice does not match number of input axes"); "Number of lower bounds provided for slice does not match number of input axes");
} }
if (m_upper_bounds.size() != arg_shape.size()) if (m_upper_bounds.size() != input_shape.size())
{ {
throw ngraph_error( throw ngraph_error(
"Number of upper bounds provided for slice does not match number of input axes"); "Number of upper bounds provided for slice does not match number of input axes");
} }
if (m_strides.size() != arg_shape.size()) if (m_strides.size() != input_shape.size())
{ {
throw ngraph_error( throw ngraph_error(
"Number of strides provided for slice does not match number of input axes"); "Number of strides provided for slice does not match number of input axes");
...@@ -65,9 +65,9 @@ void op::Slice::check_args() ...@@ -65,9 +65,9 @@ void op::Slice::check_args()
Shape result_shape; Shape result_shape;
for (size_t i = 0; i < arg_shape.size(); i++) for (size_t i = 0; i < input_shape.size(); i++)
{ {
if (m_upper_bounds[i] > arg_shape[i]) if (m_upper_bounds[i] > input_shape[i])
{ {
throw ngraph_error("Upper bound for slice is out of range"); throw ngraph_error("Upper bound for slice is out of range");
} }
...@@ -88,8 +88,7 @@ void op::Slice::check_args() ...@@ -88,8 +88,7 @@ void op::Slice::check_args()
result_shape.push_back(result_axis_size); result_shape.push_back(result_axis_size);
} }
set_value_type_checked( set_value_type_checked(input.get_element_type(), result_shape);
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), result_shape));
} }
void op::Slice::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta) void op::Slice::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta)
......
...@@ -23,18 +23,18 @@ op::Sum::Sum(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes) ...@@ -23,18 +23,18 @@ op::Sum::Sum(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes)
: RequiresTensorViewArgs("Sum", {arg}) : RequiresTensorViewArgs("Sum", {arg})
, m_reduction_axes(reduction_axes) , m_reduction_axes(reduction_axes)
{ {
auto arg_tensor_view_type = get_inputs().at(0).get_tensor_view_type(); auto& input = get_inputs().at(0);
auto& arg_element_type = arg_tensor_view_type->get_element_type(); auto& input_element_type = input.get_element_type();
if (arg_element_type == element::Bool::element_type()) if (input_element_type == element::Bool::element_type())
{ {
throw ngraph_error("Argument for sum must have numeric element type"); throw ngraph_error("Argument for sum must have numeric element type");
} }
auto arg_shape = arg_tensor_view_type->get_shape(); auto input_shape = input.get_shape();
for (auto axis : m_reduction_axes) for (auto axis : m_reduction_axes)
{ {
if (axis >= arg_shape.size()) if (axis >= input_shape.size())
{ {
throw ngraph_error("Reduction axis for sum is out of bounds"); throw ngraph_error("Reduction axis for sum is out of bounds");
} }
...@@ -42,22 +42,21 @@ op::Sum::Sum(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes) ...@@ -42,22 +42,21 @@ op::Sum::Sum(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes)
Shape result_shape; Shape result_shape;
for (size_t i = 0; i < arg_shape.size(); i++) for (size_t i = 0; i < input_shape.size(); i++)
{ {
if (m_reduction_axes.count(i) == 0) if (m_reduction_axes.count(i) == 0)
{ {
result_shape.push_back(arg_shape.at(i)); result_shape.push_back(input_shape.at(i));
} }
} }
set_value_type_checked( set_value_type_checked(input.get_element_type(), result_shape);
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), result_shape));
} }
void op::Sum::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta) void op::Sum::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta)
{ {
auto x = get_inputs().at(0).get_output().get_node(); auto x = get_inputs().at(0).get_output().get_node();
auto& x_shape = get_inputs().at(0).get_tensor_view_type()->get_shape(); auto& x_shape = get_inputs().at(0).get_shape();
adjoints.add_delta(x, make_shared<op::Broadcast>(delta, x_shape, m_reduction_axes)); adjoints.add_delta(x, make_shared<op::Broadcast>(delta, x_shape, m_reduction_axes));
} }
...@@ -25,10 +25,8 @@ op::UnaryElementwise::UnaryElementwise( ...@@ -25,10 +25,8 @@ op::UnaryElementwise::UnaryElementwise(
const std::shared_ptr<Node>& arg) const std::shared_ptr<Node>& arg)
: RequiresTensorViewArgs(node_type, Nodes{arg}) : RequiresTensorViewArgs(node_type, Nodes{arg})
{ {
auto arg_tensor_type = get_inputs().at(0).get_tensor_view_type(); auto& input = get_inputs().at(0);
const element::Type& result_element_type = const element::Type& result_element_type = element_type_function(input.get_element_type());
element_type_function(arg_tensor_type->get_element_type());
set_value_type_checked( set_value_type_checked(result_element_type, input.get_shape());
make_shared<TensorViewType>(result_element_type, arg_tensor_type->get_shape()));
} }
...@@ -78,9 +78,8 @@ void runtime::interpreter::INT_CallFrame::call( ...@@ -78,9 +78,8 @@ void runtime::interpreter::INT_CallFrame::call(
if (!contains_key(tensor_map, name)) if (!contains_key(tensor_map, name))
{ {
// The output tensor is not in the tensor map so create a new tensor // The output tensor is not in the tensor map so create a new tensor
const Shape& shape = output.get_tensor_view_type()->get_shape(); const Shape& shape = output.get_shape();
const element::Type& element_type = const element::Type& element_type = output.get_element_type();
output.get_tensor_view_type()->get_element_type();
string tensor_name = output.get_tensor().get_name(); string tensor_name = output.get_tensor().get_name();
itv = make_shared<runtime::interpreter::INT_TensorView>( itv = make_shared<runtime::interpreter::INT_TensorView>(
element_type, shape, tensor_name); element_type, shape, tensor_name);
......
...@@ -559,7 +559,7 @@ TEST(type_prop, select_shape_mismatch_a) ...@@ -559,7 +559,7 @@ TEST(type_prop, select_shape_mismatch_a)
} }
catch (const ngraph_error& error) catch (const ngraph_error& error)
{ {
EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view shape")); EXPECT_EQ(error.what(), std::string("Arguments must have the same shape"));
} }
catch (...) catch (...)
{ {
...@@ -583,7 +583,7 @@ TEST(type_prop, select_shape_mismatch_b) ...@@ -583,7 +583,7 @@ TEST(type_prop, select_shape_mismatch_b)
} }
catch (const ngraph_error& error) catch (const ngraph_error& error)
{ {
EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view shape")); EXPECT_EQ(error.what(), std::string("Arguments must have the same shape"));
} }
catch (...) catch (...)
{ {
...@@ -607,7 +607,7 @@ TEST(type_prop, select_shape_mismatch_c) ...@@ -607,7 +607,7 @@ TEST(type_prop, select_shape_mismatch_c)
} }
catch (const ngraph_error& error) catch (const ngraph_error& error)
{ {
EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view shape")); EXPECT_EQ(error.what(), std::string("Arguments must have the same shape"));
} }
catch (...) catch (...)
{ {
...@@ -657,8 +657,7 @@ TEST(type_prop, select_elem_mismatch_bc) ...@@ -657,8 +657,7 @@ TEST(type_prop, select_elem_mismatch_bc)
} }
catch (const ngraph_error& error) catch (const ngraph_error& error)
{ {
EXPECT_EQ(error.what(), EXPECT_EQ(error.what(), std::string("Arguments 1 and 2 must have the same element type"));
std::string("Arguments 1 and 2 must have the same tensor view type"));
} }
catch (...) catch (...)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment