Unverified Commit 8a569f27 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Cyphers/shape (#310)

* Add and use get_shape() and get_element_type() on Input/Output

* Fix Output

* Formatting.

* Format.

* Use reference

* Convolution.
parent 122db5ff
...@@ -32,11 +32,9 @@ ...@@ -32,11 +32,9 @@
using namespace ngraph; using namespace ngraph;
std::shared_ptr<Node> make_zero(const std::shared_ptr<const TensorViewType>& tensor_view_type) std::shared_ptr<Node> make_zero(const element::Type& element_type, const Shape& shape)
{ {
std::shared_ptr<Node> zero = std::shared_ptr<Node> zero = std::make_shared<op::Constant>(element_type, Shape{}, "0");
std::make_shared<op::Constant>(tensor_view_type->get_element_type(), Shape{}, "0");
const Shape& shape = tensor_view_type->get_shape();
if (shape.size() > 0) if (shape.size() > 0)
{ {
AxisSet axes; AxisSet axes;
...@@ -114,7 +112,8 @@ std::shared_ptr<Node> autodiff::Adjoints::get(const std::shared_ptr<Node>& x) ...@@ -114,7 +112,8 @@ std::shared_ptr<Node> autodiff::Adjoints::get(const std::shared_ptr<Node>& x)
auto adjoint_it = m_adjoint_map.find(x.get()); auto adjoint_it = m_adjoint_map.find(x.get());
if (m_adjoint_map.end() == adjoint_it) if (m_adjoint_map.end() == adjoint_it)
{ {
auto result = make_zero(x->get_outputs().at(0).get_tensor_view_type()); auto& output = x->get_outputs().at(0);
auto result = make_zero(output.get_element_type(), output.get_shape());
adjoint_it = m_adjoint_map.insert({x.get(), result}).first; adjoint_it = m_adjoint_map.insert({x.get(), result}).first;
} }
return adjoint_it->second; return adjoint_it->second;
...@@ -152,7 +151,8 @@ void autodiff::Adjoints::add_delta_to_slice(const std::shared_ptr<Node>& x, ...@@ -152,7 +151,8 @@ void autodiff::Adjoints::add_delta_to_slice(const std::shared_ptr<Node>& x,
auto adjoint_it = m_adjoint_map.find(x.get()); auto adjoint_it = m_adjoint_map.find(x.get());
if (m_adjoint_map.end() == adjoint_it) if (m_adjoint_map.end() == adjoint_it)
{ {
auto zeros = make_zero(x->get_outputs().at(0).get_tensor_view_type()); auto& output = x->get_outputs().at(0);
auto zeros = make_zero(output.get_element_type(), output.get_shape());
m_adjoint_map.insert({x.get(), m_adjoint_map.insert({x.get(),
std::make_shared<op::ReplaceSlice>( std::make_shared<op::ReplaceSlice>(
zeros, delta, lower_bounds, upper_bounds, strides)}); zeros, delta, lower_bounds, upper_bounds, strides)});
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "ngraph/descriptor/input.hpp" #include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp" #include "ngraph/descriptor/output.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/types/element_type.hpp"
using namespace ngraph; using namespace ngraph;
using namespace descriptor; using namespace descriptor;
...@@ -65,3 +66,13 @@ std::shared_ptr<const TensorViewType> Input::get_tensor_view_type() const ...@@ -65,3 +66,13 @@ std::shared_ptr<const TensorViewType> Input::get_tensor_view_type() const
{ {
return m_output->get_tensor_view()->get_tensor_view_type(); return m_output->get_tensor_view()->get_tensor_view_type();
} }
const Shape& Input::get_shape() const
{
return get_tensor_view_type()->get_shape();
}
const element::Type& Input::get_element_type() const
{
return get_tensor_view_type()->get_element_type();
}
...@@ -61,6 +61,7 @@ namespace ngraph ...@@ -61,6 +61,7 @@ namespace ngraph
void replace_output(Output& output); void replace_output(Output& output);
protected:
/// @return the tensor view for the connected output /// @return the tensor view for the connected output
std::shared_ptr<const TensorView> get_tensor_view() const; std::shared_ptr<const TensorView> get_tensor_view() const;
...@@ -70,6 +71,13 @@ namespace ngraph ...@@ -70,6 +71,13 @@ namespace ngraph
/// @return the tensor view type for the connected output /// @return the tensor view type for the connected output
std::shared_ptr<const TensorViewType> get_tensor_view_type() const; std::shared_ptr<const TensorViewType> get_tensor_view_type() const;
public:
/// @return the shape of the connected output
const Shape& get_shape() const;
/// @return the element type of the connected output
const element::Type& get_element_type() const;
protected: protected:
Node* m_node; // The node we are an input for Node* m_node; // The node we are an input for
size_t m_index; // Index into all input tensors size_t m_index; // Index into all input tensors
......
...@@ -52,3 +52,18 @@ Tensor& Output::get_tensor() ...@@ -52,3 +52,18 @@ Tensor& Output::get_tensor()
{ {
return m_tensor_view->get_tensor(); return m_tensor_view->get_tensor();
} }
std::shared_ptr<const TensorViewType> Output::get_tensor_view_type() const
{
return get_tensor_view()->get_tensor_view_type();
}
const Shape& Output::get_shape() const
{
return get_tensor_view_type()->get_shape();
}
const element::Type& Output::get_element_type() const
{
return get_tensor_view_type()->get_element_type();
}
...@@ -48,11 +48,16 @@ namespace ngraph ...@@ -48,11 +48,16 @@ namespace ngraph
const std::set<Input*>& get_inputs() const { return m_inputs; } const std::set<Input*>& get_inputs() const { return m_inputs; }
const Tensor& get_tensor() const; const Tensor& get_tensor() const;
Tensor& get_tensor(); Tensor& get_tensor();
/// @return the tensor view type for the connected output
std::shared_ptr<const TensorViewType> get_tensor_view_type() const protected:
{ /// @return the tensor view type for the output
return get_tensor_view()->get_tensor_view_type(); std::shared_ptr<const TensorViewType> get_tensor_view_type() const;
}
public:
/// @return the shape of the output
const Shape& get_shape() const;
/// @return the element type of the output
const element::Type& get_element_type() const;
protected: protected:
Node* m_node; Node* m_node;
......
...@@ -62,6 +62,11 @@ void Node::assert_value_type(const shared_ptr<const ValueType>& value_type) cons ...@@ -62,6 +62,11 @@ void Node::assert_value_type(const shared_ptr<const ValueType>& value_type) cons
} }
} }
void Node::set_value_type_checked(const element::Type& element_type, const Shape& shape)
{
set_value_type_checked(make_shared<TensorViewType>(element_type, shape));
}
void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type) void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type)
{ {
if (nullptr == m_value_type) if (nullptr == m_value_type)
......
...@@ -87,6 +87,7 @@ namespace ngraph ...@@ -87,6 +87,7 @@ namespace ngraph
// This is used when the framework specifies a value type for the value, and we // This is used when the framework specifies a value type for the value, and we
// independently compute what we thing the value type should be from the arguments. // independently compute what we thing the value type should be from the arguments.
void set_value_type_checked(const std::shared_ptr<const ValueType>& value_type); void set_value_type_checked(const std::shared_ptr<const ValueType>& value_type);
void set_value_type_checked(const element::Type& element_type, const Shape& shape);
bool is_parameter() const; bool is_parameter() const;
bool is_output() const; bool is_output() const;
......
...@@ -28,16 +28,15 @@ op::BinaryElementwise::BinaryElementwise( ...@@ -28,16 +28,15 @@ op::BinaryElementwise::BinaryElementwise(
const std::shared_ptr<Node>& arg1) const std::shared_ptr<Node>& arg1)
: RequiresTensorViewArgs(node_type, Nodes{arg0, arg1}) : RequiresTensorViewArgs(node_type, Nodes{arg0, arg1})
{ {
auto arg0_tensor_type = get_inputs().at(0).get_tensor_view_type(); auto& input_0 = get_inputs().at(0);
auto arg1_tensor_type = get_inputs().at(1).get_tensor_view_type(); auto& input_1 = get_inputs().at(1);
if (arg0_tensor_type->get_shape() != arg1_tensor_type->get_shape()) if (input_0.get_shape() != input_1.get_shape())
{ {
throw ngraph_error("Arguments must have the same tensor view shape"); throw ngraph_error("Arguments must have the same tensor view shape");
} }
const element::Type& result_element_type = element_type_function( const element::Type& result_element_type =
arg0_tensor_type->get_element_type(), arg1_tensor_type->get_element_type()); element_type_function(input_0.get_element_type(), input_1.get_element_type());
set_value_type_checked( set_value_type_checked(make_shared<TensorViewType>(result_element_type, input_0.get_shape()));
make_shared<TensorViewType>(result_element_type, arg0_tensor_type->get_shape()));
} }
...@@ -25,18 +25,17 @@ op::Broadcast::Broadcast(const std::shared_ptr<Node>& arg, ...@@ -25,18 +25,17 @@ op::Broadcast::Broadcast(const std::shared_ptr<Node>& arg,
, m_shape(shape) , m_shape(shape)
, m_broadcast_axes(broadcast_axes) , m_broadcast_axes(broadcast_axes)
{ {
auto arg_tensor_view_type = m_inputs.at(0).get_tensor_view_type(); auto& input = m_inputs.at(0);
vector<size_t> target_shape = m_shape; vector<size_t> target_shape = m_shape;
for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i) for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
{ {
target_shape.erase(target_shape.begin() + *i); target_shape.erase(target_shape.begin() + *i);
} }
if (Shape{target_shape} != arg_tensor_view_type->get_shape()) if (Shape{target_shape} != input.get_shape())
{ {
throw ngraph_error("Broadcast arg, shape, and axes are incompatible"); throw ngraph_error("Broadcast arg, shape, and axes are incompatible");
} }
set_value_type_checked( set_value_type_checked(make_shared<TensorViewType>(input.get_element_type(), m_shape));
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_shape));
} }
void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -30,47 +30,47 @@ op::Concat::Concat(const Nodes& args, size_t concatenation_axis) ...@@ -30,47 +30,47 @@ op::Concat::Concat(const Nodes& args, size_t concatenation_axis)
throw ngraph_error("At least one argument required"); throw ngraph_error("At least one argument required");
} }
auto arg0_tensor_view_type = get_inputs().at(0).get_tensor_view_type(); auto& input_0 = get_inputs().at(0);
auto arg0_shape = arg0_tensor_view_type->get_shape(); auto input_0_shape = input_0.get_shape();
if (m_concatenation_axis >= arg0_shape.size()) if (m_concatenation_axis >= input_0_shape.size())
{ {
throw ngraph_error("Concatenation axis is out of bounds"); throw ngraph_error("Concatenation axis is out of bounds");
} }
size_t concatenation_axis_length = arg0_shape.at(m_concatenation_axis); size_t concatenation_axis_length = input_0_shape.at(m_concatenation_axis);
auto& arg0_element_type = arg0_tensor_view_type->get_element_type(); auto& input_0_element_type = input_0.get_element_type();
for (auto i = 1; i < get_inputs().size(); i++) for (auto i = 1; i < get_inputs().size(); i++)
{ {
auto argi_tensor_view_type = get_inputs().at(i).get_tensor_view_type(); auto& input_i = get_inputs().at(i);
auto argi_shape = argi_tensor_view_type->get_shape(); auto input_i_shape = input_i.get_shape();
if (argi_shape.size() != arg0_shape.size()) if (input_i_shape.size() != input_0_shape.size())
{ {
throw ngraph_error("Arguments to concat do not have same rank"); throw ngraph_error("Arguments to concat do not have same rank");
} }
if (argi_tensor_view_type->get_element_type() != arg0_element_type) if (input_i.get_element_type() != input_0_element_type)
{ {
throw ngraph_error("Argument element types do not match"); throw ngraph_error("Argument element types do not match");
} }
for (auto j = 0; j < argi_shape.size(); j++) for (auto j = 0; j < input_i_shape.size(); j++)
{ {
if (j != m_concatenation_axis && arg0_shape.at(j) != argi_shape.at(j)) if (j != m_concatenation_axis && input_0_shape.at(j) != input_i_shape.at(j))
{ {
throw ngraph_error( throw ngraph_error(
"Arguments to concat do not have same dimension on a non-concatenation axis"); "Arguments to concat do not have same dimension on a non-concatenation axis");
} }
else if (j == m_concatenation_axis) else if (j == m_concatenation_axis)
{ {
concatenation_axis_length += argi_shape.at(j); concatenation_axis_length += input_i_shape.at(j);
} }
} }
} }
vector<size_t> concatenated_shape = arg0_shape; vector<size_t> concatenated_shape = input_0_shape;
concatenated_shape.at(m_concatenation_axis) = concatenation_axis_length; concatenated_shape.at(m_concatenation_axis) = concatenation_axis_length;
set_value_type_checked(make_shared<TensorViewType>(arg0_element_type, concatenated_shape)); set_value_type_checked(make_shared<TensorViewType>(input_0_element_type, concatenated_shape));
} }
void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta) void op::Concat::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta)
......
...@@ -26,11 +26,8 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, ...@@ -26,11 +26,8 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides) , m_window_dilation_strides(window_dilation_strides)
{ {
auto image_batch_tensor_view_type = get_inputs().at(0).get_tensor_view_type(); auto& image_batch_shape = get_inputs().at(0).get_shape();
auto& image_batch_shape = image_batch_tensor_view_type->get_shape(); auto& filters_shape = get_inputs().at(1).get_shape();
auto filters_tensor_view_type = get_inputs().at(1).get_tensor_view_type();
auto& filters_shape = filters_tensor_view_type->get_shape();
// //
// Make sure image_batch: NCiDi for some Di of rank>0, N != 0, Ci != 0. // Make sure image_batch: NCiDi for some Di of rank>0, N != 0, Ci != 0.
...@@ -157,8 +154,7 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch, ...@@ -157,8 +154,7 @@ op::Convolution::Convolution(const std::shared_ptr<Node>& image_batch,
result_shape[1] = m_output_channel_count; result_shape[1] = m_output_channel_count;
std::copy(m_output_image_shape.begin(), m_output_image_shape.end(), result_shape.begin() + 2); std::copy(m_output_image_shape.begin(), m_output_image_shape.end(), result_shape.begin() + 2);
set_value_type_checked(make_shared<TensorViewType>( set_value_type_checked(get_inputs().at(0).get_element_type(), result_shape);
image_batch_tensor_view_type->get_element_type(), result_shape));
} }
Strides default_strides(const std::shared_ptr<Node>& image_batch) Strides default_strides(const std::shared_ptr<Node>& image_batch)
......
...@@ -70,44 +70,44 @@ op::Dot::Dot(const std::shared_ptr<Node>& arg0, ...@@ -70,44 +70,44 @@ op::Dot::Dot(const std::shared_ptr<Node>& arg0,
: RequiresTensorViewArgs("Dot", {arg0, arg1}) : RequiresTensorViewArgs("Dot", {arg0, arg1})
, m_reduction_axes_count(reduction_axes_count) , m_reduction_axes_count(reduction_axes_count)
{ {
auto arg0_tensor_type = get_inputs().at(0).get_tensor_view_type(); auto& input_0 = get_inputs().at(0);
auto arg1_tensor_type = get_inputs().at(1).get_tensor_view_type(); auto& input_1 = get_inputs().at(1);
if (arg0_tensor_type->get_element_type() != arg1_tensor_type->get_element_type()) if (input_0.get_element_type() != input_1.get_element_type())
{ {
throw ngraph_error("Arguments to dot must have the same element type"); throw ngraph_error("Arguments to dot must have the same element type");
} }
vector<size_t> arg0_shape = arg0_tensor_type->get_shape(); Shape input_0_shape = input_0.get_shape();
vector<size_t> arg1_shape = arg1_tensor_type->get_shape(); Shape input_1_shape = input_1.get_shape();
if (reduction_axes_count > arg0_shape.size()) if (reduction_axes_count > input_0_shape.size())
{ {
throw ngraph_error("Dot has too many axes for arg0"); throw ngraph_error("Dot has too many axes for arg0");
} }
if (reduction_axes_count > arg1_shape.size()) if (reduction_axes_count > input_1_shape.size())
{ {
throw ngraph_error("Dot has too many axes for arg1"); throw ngraph_error("Dot has too many axes for arg1");
} }
for (size_t i = 0; i < reduction_axes_count; i++) for (size_t i = 0; i < reduction_axes_count; i++)
{ {
if (arg0_shape[arg0_shape.size() - reduction_axes_count + i] != arg1_shape[i]) if (input_0_shape[input_0_shape.size() - reduction_axes_count + i] != input_1_shape[i])
{ {
throw ngraph_error("Dot axes do not have same length"); throw ngraph_error("Dot axes do not have same length");
} }
} }
vector<size_t> result_shape(arg0_shape.size() + arg1_shape.size() - 2 * reduction_axes_count); Shape result_shape(input_0_shape.size() + input_1_shape.size() - 2 * reduction_axes_count);
std::copy(arg0_shape.begin(), arg0_shape.end() - reduction_axes_count, result_shape.begin()); std::copy(
std::copy(arg1_shape.begin() + reduction_axes_count, input_0_shape.begin(), input_0_shape.end() - reduction_axes_count, result_shape.begin());
arg1_shape.end(), std::copy(input_1_shape.begin() + reduction_axes_count,
result_shape.begin() + (arg0_shape.size() - reduction_axes_count)); input_1_shape.end(),
result_shape.begin() + (input_0_shape.size() - reduction_axes_count));
auto result_type = auto result_type = make_shared<TensorViewType>(input_0.get_element_type(), result_shape);
make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape);
set_value_type_checked(result_type); set_value_type_checked(result_type);
} }
......
...@@ -33,5 +33,6 @@ op::GetOutputElement::GetOutputElement(const std::shared_ptr<Node>& arg, size_t ...@@ -33,5 +33,6 @@ op::GetOutputElement::GetOutputElement(const std::shared_ptr<Node>& arg, size_t
throw ngraph_error("Indexing tuple beyond its size"); throw ngraph_error("Indexing tuple beyond its size");
} }
set_value_type_checked(arg->get_outputs().at(n).get_tensor_view_type()); auto& output = arg->get_outputs().at(n);
set_value_type_checked(output.get_element_type(), output.get_shape());
} }
...@@ -23,21 +23,21 @@ op::OneHot::OneHot(const std::shared_ptr<Node>& arg, const Shape& shape, size_t ...@@ -23,21 +23,21 @@ op::OneHot::OneHot(const std::shared_ptr<Node>& arg, const Shape& shape, size_t
, m_shape(shape) , m_shape(shape)
, m_one_hot_axis(one_hot_axis) , m_one_hot_axis(one_hot_axis)
{ {
auto arg_tensor_view_type = m_inputs.at(0).get_tensor_view_type(); auto& input = m_inputs.at(0);
auto& arg_element_type = arg_tensor_view_type->get_element_type(); auto& input_element_type = input.get_element_type();
if (one_hot_axis >= shape.size()) if (one_hot_axis >= shape.size())
{ {
throw ngraph_error("One-hot axis is out of bounds"); throw ngraph_error("One-hot axis is out of bounds");
} }
auto expected_arg_shape = shape; auto expected_input_shape = shape;
expected_arg_shape.erase(expected_arg_shape.begin() + one_hot_axis); expected_input_shape.erase(expected_input_shape.begin() + one_hot_axis);
if (arg_tensor_view_type->get_shape() != expected_arg_shape) if (input.get_shape() != expected_input_shape)
{ {
throw ngraph_error("One-hot argument shape is not compatible with desired output shape"); throw ngraph_error("One-hot argument shape is not compatible with desired output shape");
} }
set_value_type_checked(make_shared<TensorViewType>(arg_element_type, shape)); set_value_type_checked(make_shared<TensorViewType>(input_element_type, shape));
} }
...@@ -26,25 +26,24 @@ op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee, ...@@ -26,25 +26,24 @@ op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee,
, m_reduction_function(reduction_function) , m_reduction_function(reduction_function)
, m_reduction_axes(reduction_axes) , m_reduction_axes(reduction_axes)
{ {
auto arg_reductee_tensor_view_type = get_inputs().at(0).get_tensor_view_type(); auto& input_reductee = get_inputs().at(0);
auto arg_init_tensor_view_type = get_inputs().at(1).get_tensor_view_type(); auto& input_init = get_inputs().at(1);
if (arg_init_tensor_view_type->get_shape().size() != 0) if (input_init.get_shape().size() != 0)
{ {
throw ngraph_error("Argument for initial value is not a scalar"); throw ngraph_error("Argument for initial value is not a scalar");
} }
if (arg_init_tensor_view_type->get_element_type() != if (input_init.get_element_type() != input_reductee.get_element_type())
arg_reductee_tensor_view_type->get_element_type())
{ {
throw ngraph_error("Element types for reductee and initial values do not match"); throw ngraph_error("Element types for reductee and initial values do not match");
} }
auto arg_reductee_shape = arg_reductee_tensor_view_type->get_shape(); auto input_reductee_shape = input_reductee.get_shape();
for (auto axis : m_reduction_axes) for (auto axis : m_reduction_axes)
{ {
if (axis >= arg_reductee_shape.size()) if (axis >= input_reductee_shape.size())
{ {
throw ngraph_error("Reduction axis is out of bounds"); throw ngraph_error("Reduction axis is out of bounds");
} }
...@@ -52,11 +51,11 @@ op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee, ...@@ -52,11 +51,11 @@ op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee,
Shape result_shape; Shape result_shape;
for (size_t i = 0; i < arg_reductee_shape.size(); i++) for (size_t i = 0; i < input_reductee_shape.size(); i++)
{ {
if (m_reduction_axes.count(i) == 0) if (m_reduction_axes.count(i) == 0)
{ {
result_shape.push_back(arg_reductee_shape.at(i)); result_shape.push_back(input_reductee_shape.at(i));
} }
} }
...@@ -87,6 +86,6 @@ op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee, ...@@ -87,6 +86,6 @@ op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee,
throw ngraph_error("Return type from reduction function does not match expected"); throw ngraph_error("Return type from reduction function does not match expected");
} }
set_value_type_checked(make_shared<TensorViewType>( set_value_type_checked(
arg_reductee_tensor_view_type->get_element_type(), result_shape)); make_shared<TensorViewType>(input_reductee.get_element_type(), result_shape));
} }
...@@ -46,37 +46,37 @@ op::ReplaceSlice::ReplaceSlice(const std::shared_ptr<Node>& arg0, ...@@ -46,37 +46,37 @@ op::ReplaceSlice::ReplaceSlice(const std::shared_ptr<Node>& arg0,
void op::ReplaceSlice::check_args() void op::ReplaceSlice::check_args()
{ {
auto arg0_tensor_view_type = get_inputs().at(0).get_tensor_view_type(); auto& input_0 = get_inputs().at(0);
auto& arg0_shape = arg0_tensor_view_type->get_shape(); auto& input_0_shape = input_0.get_shape();
auto& arg0_element_type = arg0_tensor_view_type->get_element_type(); auto& input_0_element_type = input_0.get_element_type();
auto arg1_tensor_view_type = get_inputs().at(1).get_tensor_view_type(); auto& input_1 = get_inputs().at(1);
auto& arg1_shape = arg1_tensor_view_type->get_shape(); auto& input_1_shape = input_1.get_shape();
auto& arg1_element_type = arg1_tensor_view_type->get_element_type(); auto& input_1_element_type = input_1.get_element_type();
if (arg0_shape.size() != arg1_shape.size()) if (input_0_shape.size() != input_1_shape.size())
{ {
throw ngraph_error("Replace-slice argument ranks do not match"); throw ngraph_error("Replace-slice argument ranks do not match");
} }
if (arg0_element_type != arg1_element_type) if (input_0_element_type != input_1_element_type)
{ {
throw ngraph_error("Element types for replace-slice arguments do not match"); throw ngraph_error("Element types for replace-slice arguments do not match");
} }
if (m_lower_bounds.size() != arg0_shape.size()) if (m_lower_bounds.size() != input_0_shape.size())
{ {
throw ngraph_error( throw ngraph_error(
"Number of lower bounds provided for slice does not match number of input axes"); "Number of lower bounds provided for slice does not match number of input axes");
} }
if (m_upper_bounds.size() != arg0_shape.size()) if (m_upper_bounds.size() != input_0_shape.size())
{ {
throw ngraph_error( throw ngraph_error(
"Number of upper bounds provided for slice does not match number of input axes"); "Number of upper bounds provided for slice does not match number of input axes");
} }
if (m_strides.size() != arg0_shape.size()) if (m_strides.size() != input_0_shape.size())
{ {
throw ngraph_error( throw ngraph_error(
"Number of strides provided for slice does not match number of input axes"); "Number of strides provided for slice does not match number of input axes");
...@@ -84,9 +84,9 @@ void op::ReplaceSlice::check_args() ...@@ -84,9 +84,9 @@ void op::ReplaceSlice::check_args()
Shape slice_shape; Shape slice_shape;
for (size_t i = 0; i < arg0_shape.size(); i++) for (size_t i = 0; i < input_0_shape.size(); i++)
{ {
if (m_upper_bounds[i] > arg0_shape[i]) if (m_upper_bounds[i] > input_0_shape[i])
{ {
throw ngraph_error("Upper bound for slice is out of range"); throw ngraph_error("Upper bound for slice is out of range");
} }
...@@ -107,12 +107,12 @@ void op::ReplaceSlice::check_args() ...@@ -107,12 +107,12 @@ void op::ReplaceSlice::check_args()
slice_shape.push_back(slice_axis_size); slice_shape.push_back(slice_axis_size);
} }
if (arg1_shape != slice_shape) if (input_1_shape != slice_shape)
{ {
throw ngraph_error("Shape of replacement tensor does not match slice shape"); throw ngraph_error("Shape of replacement tensor does not match slice shape");
} }
set_value_type_checked(arg0_tensor_view_type); set_value_type_checked(input_0_element_type, input_0_shape);
} }
void op::ReplaceSlice::generate_adjoints(autodiff::Adjoints& adjoints, void op::ReplaceSlice::generate_adjoints(autodiff::Adjoints& adjoints,
...@@ -121,8 +121,8 @@ void op::ReplaceSlice::generate_adjoints(autodiff::Adjoints& adjoints, ...@@ -121,8 +121,8 @@ void op::ReplaceSlice::generate_adjoints(autodiff::Adjoints& adjoints,
auto x = get_inputs().at(0).get_output().get_node(); auto x = get_inputs().at(0).get_output().get_node();
auto& y_input = get_inputs().at(1); auto& y_input = get_inputs().at(1);
auto y = y_input.get_output().get_node(); auto y = y_input.get_output().get_node();
auto& y_element_type = y_input.get_tensor_view_type()->get_element_type(); auto& y_element_type = y_input.get_element_type();
auto y_shape = y_input.get_tensor_view_type()->get_shape(); auto y_shape = y_input.get_shape();
auto zeros_shaped_like_y = std::make_shared<op::Constant>(y_element_type, y_shape, "0"); auto zeros_shaped_like_y = std::make_shared<op::Constant>(y_element_type, y_shape, "0");
......
...@@ -27,16 +27,16 @@ op::Reshape::Reshape(const std::shared_ptr<Node>& arg, ...@@ -27,16 +27,16 @@ op::Reshape::Reshape(const std::shared_ptr<Node>& arg,
, m_input_order(input_order) , m_input_order(input_order)
, m_output_shape(output_shape) , m_output_shape(output_shape)
{ {
auto arg_tensor_view_type = get_inputs().at(0).get_tensor_view_type(); auto& input = get_inputs().at(0);
auto arg_shape = arg_tensor_view_type->get_shape(); auto input_shape = input.get_shape();
auto arg_rank = arg_shape.size(); auto input_rank = input_shape.size();
if (m_input_order.size() != arg_rank) if (m_input_order.size() != input_rank)
{ {
throw ngraph_error("Input axis order for reshape is not a permutation of argument's axes"); throw ngraph_error("Input axis order for reshape is not a permutation of argument's axes");
} }
for (size_t i = 0; i < arg_rank; i++) for (size_t i = 0; i < input_rank; i++)
{ {
auto it = std::find(std::begin(m_input_order), std::end(m_input_order), i); auto it = std::find(std::begin(m_input_order), std::end(m_input_order), i);
if (std::end(m_input_order) == it) if (std::end(m_input_order) == it)
...@@ -46,10 +46,10 @@ op::Reshape::Reshape(const std::shared_ptr<Node>& arg, ...@@ -46,10 +46,10 @@ op::Reshape::Reshape(const std::shared_ptr<Node>& arg,
} }
} }
size_t arg_shape_product = 1; size_t input_shape_product = 1;
for (auto i : arg_shape) for (auto i : input_shape)
{ {
arg_shape_product *= i; input_shape_product *= i;
} }
size_t output_shape_product = 1; size_t output_shape_product = 1;
...@@ -58,15 +58,14 @@ op::Reshape::Reshape(const std::shared_ptr<Node>& arg, ...@@ -58,15 +58,14 @@ op::Reshape::Reshape(const std::shared_ptr<Node>& arg,
output_shape_product *= i; output_shape_product *= i;
} }
if (arg_shape_product != output_shape_product) if (input_shape_product != output_shape_product)
{ {
throw ngraph_error( throw ngraph_error(
"Product of output shape dimensions does not match product of argument shape " "Product of output shape dimensions does not match product of argument shape "
"dimensions for reshape"); "dimensions for reshape");
} }
set_value_type_checked( set_value_type_checked(input.get_element_type(), m_output_shape);
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_output_shape));
} }
void op::Reshape::generate_adjoints(autodiff::Adjoints& adjoints, void op::Reshape::generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -29,25 +29,24 @@ op::Select::Select(const std::shared_ptr<Node>& arg0, ...@@ -29,25 +29,24 @@ op::Select::Select(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg2) const std::shared_ptr<Node>& arg2)
: RequiresTensorViewArgs("Select", Nodes{arg0, arg1, arg2}) : RequiresTensorViewArgs("Select", Nodes{arg0, arg1, arg2})
{ {
auto arg0_tensor_type = get_inputs().at(0).get_tensor_view_type(); auto& input_0 = get_inputs().at(0);
auto arg1_tensor_type = get_inputs().at(1).get_tensor_view_type(); auto& input_1 = get_inputs().at(1);
auto arg2_tensor_type = get_inputs().at(2).get_tensor_view_type(); auto& input_2 = get_inputs().at(2);
if (arg0_tensor_type->get_element_type() != element::Bool::element_type()) if (input_0.get_element_type() != element::Bool::element_type())
{ {
throw ngraph_error("Argument 0 for arithmetic operators must have boolean element type"); throw ngraph_error("Argument 0 for arithmetic operators must have boolean element type");
} }
if (arg0_tensor_type->get_shape() != arg1_tensor_type->get_shape() || if (input_0.get_shape() != input_1.get_shape() || input_0.get_shape() != input_2.get_shape())
arg0_tensor_type->get_shape() != arg2_tensor_type->get_shape())
{ {
throw ngraph_error("Arguments must have the same tensor view shape"); throw ngraph_error("Arguments must have the same shape");
} }
if (*arg1_tensor_type != *arg2_tensor_type) if (input_1.get_element_type() != input_2.get_element_type())
{ {
throw ngraph_error("Arguments 1 and 2 must have the same tensor view type"); throw ngraph_error("Arguments 1 and 2 must have the same element type");
} }
set_value_type_checked(arg1_tensor_type); set_value_type_checked(input_1.get_element_type(), input_1.get_shape());
} }
void ngraph::op::Select::generate_adjoints(autodiff::Adjoints& adjoints, void ngraph::op::Select::generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -42,22 +42,22 @@ op::Slice::Slice(const std::shared_ptr<Node>& arg, ...@@ -42,22 +42,22 @@ op::Slice::Slice(const std::shared_ptr<Node>& arg,
void op::Slice::check_args() void op::Slice::check_args()
{ {
auto arg_tensor_view_type = get_inputs().at(0).get_tensor_view_type(); auto& input = get_inputs().at(0);
auto& arg_shape = arg_tensor_view_type->get_shape(); auto& input_shape = input.get_shape();
if (m_lower_bounds.size() != arg_shape.size()) if (m_lower_bounds.size() != input_shape.size())
{ {
throw ngraph_error( throw ngraph_error(
"Number of lower bounds provided for slice does not match number of input axes"); "Number of lower bounds provided for slice does not match number of input axes");
} }
if (m_upper_bounds.size() != arg_shape.size()) if (m_upper_bounds.size() != input_shape.size())
{ {
throw ngraph_error( throw ngraph_error(
"Number of upper bounds provided for slice does not match number of input axes"); "Number of upper bounds provided for slice does not match number of input axes");
} }
if (m_strides.size() != arg_shape.size()) if (m_strides.size() != input_shape.size())
{ {
throw ngraph_error( throw ngraph_error(
"Number of strides provided for slice does not match number of input axes"); "Number of strides provided for slice does not match number of input axes");
...@@ -65,9 +65,9 @@ void op::Slice::check_args() ...@@ -65,9 +65,9 @@ void op::Slice::check_args()
Shape result_shape; Shape result_shape;
for (size_t i = 0; i < arg_shape.size(); i++) for (size_t i = 0; i < input_shape.size(); i++)
{ {
if (m_upper_bounds[i] > arg_shape[i]) if (m_upper_bounds[i] > input_shape[i])
{ {
throw ngraph_error("Upper bound for slice is out of range"); throw ngraph_error("Upper bound for slice is out of range");
} }
...@@ -88,8 +88,7 @@ void op::Slice::check_args() ...@@ -88,8 +88,7 @@ void op::Slice::check_args()
result_shape.push_back(result_axis_size); result_shape.push_back(result_axis_size);
} }
set_value_type_checked( set_value_type_checked(input.get_element_type(), result_shape);
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), result_shape));
} }
void op::Slice::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta) void op::Slice::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta)
......
...@@ -23,18 +23,18 @@ op::Sum::Sum(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes) ...@@ -23,18 +23,18 @@ op::Sum::Sum(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes)
: RequiresTensorViewArgs("Sum", {arg}) : RequiresTensorViewArgs("Sum", {arg})
, m_reduction_axes(reduction_axes) , m_reduction_axes(reduction_axes)
{ {
auto arg_tensor_view_type = get_inputs().at(0).get_tensor_view_type(); auto& input = get_inputs().at(0);
auto& arg_element_type = arg_tensor_view_type->get_element_type(); auto& input_element_type = input.get_element_type();
if (arg_element_type == element::Bool::element_type()) if (input_element_type == element::Bool::element_type())
{ {
throw ngraph_error("Argument for sum must have numeric element type"); throw ngraph_error("Argument for sum must have numeric element type");
} }
auto arg_shape = arg_tensor_view_type->get_shape(); auto input_shape = input.get_shape();
for (auto axis : m_reduction_axes) for (auto axis : m_reduction_axes)
{ {
if (axis >= arg_shape.size()) if (axis >= input_shape.size())
{ {
throw ngraph_error("Reduction axis for sum is out of bounds"); throw ngraph_error("Reduction axis for sum is out of bounds");
} }
...@@ -42,22 +42,21 @@ op::Sum::Sum(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes) ...@@ -42,22 +42,21 @@ op::Sum::Sum(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes)
Shape result_shape; Shape result_shape;
for (size_t i = 0; i < arg_shape.size(); i++) for (size_t i = 0; i < input_shape.size(); i++)
{ {
if (m_reduction_axes.count(i) == 0) if (m_reduction_axes.count(i) == 0)
{ {
result_shape.push_back(arg_shape.at(i)); result_shape.push_back(input_shape.at(i));
} }
} }
set_value_type_checked( set_value_type_checked(input.get_element_type(), result_shape);
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), result_shape));
} }
void op::Sum::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta) void op::Sum::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta)
{ {
auto x = get_inputs().at(0).get_output().get_node(); auto x = get_inputs().at(0).get_output().get_node();
auto& x_shape = get_inputs().at(0).get_tensor_view_type()->get_shape(); auto& x_shape = get_inputs().at(0).get_shape();
adjoints.add_delta(x, make_shared<op::Broadcast>(delta, x_shape, m_reduction_axes)); adjoints.add_delta(x, make_shared<op::Broadcast>(delta, x_shape, m_reduction_axes));
} }
...@@ -25,10 +25,8 @@ op::UnaryElementwise::UnaryElementwise( ...@@ -25,10 +25,8 @@ op::UnaryElementwise::UnaryElementwise(
const std::shared_ptr<Node>& arg) const std::shared_ptr<Node>& arg)
: RequiresTensorViewArgs(node_type, Nodes{arg}) : RequiresTensorViewArgs(node_type, Nodes{arg})
{ {
auto arg_tensor_type = get_inputs().at(0).get_tensor_view_type(); auto& input = get_inputs().at(0);
const element::Type& result_element_type = const element::Type& result_element_type = element_type_function(input.get_element_type());
element_type_function(arg_tensor_type->get_element_type());
set_value_type_checked( set_value_type_checked(result_element_type, input.get_shape());
make_shared<TensorViewType>(result_element_type, arg_tensor_type->get_shape()));
} }
...@@ -78,9 +78,8 @@ void runtime::interpreter::INT_CallFrame::call( ...@@ -78,9 +78,8 @@ void runtime::interpreter::INT_CallFrame::call(
if (!contains_key(tensor_map, name)) if (!contains_key(tensor_map, name))
{ {
// The output tensor is not in the tensor map so create a new tensor // The output tensor is not in the tensor map so create a new tensor
const Shape& shape = output.get_tensor_view_type()->get_shape(); const Shape& shape = output.get_shape();
const element::Type& element_type = const element::Type& element_type = output.get_element_type();
output.get_tensor_view_type()->get_element_type();
string tensor_name = output.get_tensor().get_name(); string tensor_name = output.get_tensor().get_name();
itv = make_shared<runtime::interpreter::INT_TensorView>( itv = make_shared<runtime::interpreter::INT_TensorView>(
element_type, shape, tensor_name); element_type, shape, tensor_name);
......
...@@ -258,9 +258,7 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func ...@@ -258,9 +258,7 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
#define REGISTER_NUMERIC_UNOP(op_class, instr_class) \ #define REGISTER_NUMERIC_UNOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \ REGISTER_TO_OP_MAP(op_class) \
{ \ { \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \ const element::Type& et = n->get_inputs().at(0).get_element_type(); \
n->get_inputs().at(0).get_tensor_view_type())) \
->get_element_type(); \
DO_ON_NUMERIC_TYPE(et, \ DO_ON_NUMERIC_TYPE(et, \
"Internal error: numeric unop has unhandled element type", \ "Internal error: numeric unop has unhandled element type", \
M_REGISTER_NUMERIC_UNOP, \ M_REGISTER_NUMERIC_UNOP, \
...@@ -270,9 +268,7 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func ...@@ -270,9 +268,7 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
#define REGISTER_LOGICAL_UNOP(op_class, instr_class) \ #define REGISTER_LOGICAL_UNOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \ REGISTER_TO_OP_MAP(op_class) \
{ \ { \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \ const element::Type& et = n->get_inputs().at(0).get_element_type(); \
n->get_inputs().at(0).get_tensor_view_type())) \
->get_element_type(); \
if (element::Bool::element_type() == et) \ if (element::Bool::element_type() == et) \
{ \ { \
ef->get_instructions()->push_back(make_shared<instr_class>(in[0], out[0])); \ ef->get_instructions()->push_back(make_shared<instr_class>(in[0], out[0])); \
...@@ -288,9 +284,7 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func ...@@ -288,9 +284,7 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
#define REGISTER_NUMERIC_BINOP(op_class, instr_class) \ #define REGISTER_NUMERIC_BINOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \ REGISTER_TO_OP_MAP(op_class) \
{ \ { \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \ const element::Type& et = n->get_inputs().at(0).get_element_type(); \
n->get_inputs().at(0).get_tensor_view_type())) \
->get_element_type(); \
DO_ON_NUMERIC_TYPE(et, \ DO_ON_NUMERIC_TYPE(et, \
"Internal error: numeric binop has unhandled element type", \ "Internal error: numeric binop has unhandled element type", \
M_REGISTER_NUMERIC_BINOP, \ M_REGISTER_NUMERIC_BINOP, \
...@@ -302,9 +296,7 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func ...@@ -302,9 +296,7 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
#define REGISTER_POLYMORPHIC_BINOP(op_class, instr_class) \ #define REGISTER_POLYMORPHIC_BINOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \ REGISTER_TO_OP_MAP(op_class) \
{ \ { \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \ const element::Type& et = n->get_inputs().at(0).get_element_type(); \
n->get_inputs().at(0).get_tensor_view_type())) \
->get_element_type(); \
DO_ON_ELEMENT_TYPE(et, \ DO_ON_ELEMENT_TYPE(et, \
"Internal error: polymorphic binop has unhandled element type", \ "Internal error: polymorphic binop has unhandled element type", \
M_REGISTER_POLYMORPHIC_BINOP, \ M_REGISTER_POLYMORPHIC_BINOP, \
...@@ -317,9 +309,7 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func ...@@ -317,9 +309,7 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
#define REGISTER_POLYMORPHIC_TERNOP(op_class, instr_class) \ #define REGISTER_POLYMORPHIC_TERNOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \ REGISTER_TO_OP_MAP(op_class) \
{ \ { \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \ const element::Type& et = n->get_inputs().at(1).get_element_type(); \
n->get_inputs().at(1).get_tensor_view_type())) \
->get_element_type(); \
DO_ON_ELEMENT_TYPE(et, \ DO_ON_ELEMENT_TYPE(et, \
"Internal error: polymorphic ternop has unhandled element type", \ "Internal error: polymorphic ternop has unhandled element type", \
M_REGISTER_POLYMORPHIC_TERNOP, \ M_REGISTER_POLYMORPHIC_TERNOP, \
...@@ -434,23 +424,17 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -434,23 +424,17 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{ {
auto broadcast = static_cast<const op::Broadcast*>(n); auto broadcast = static_cast<const op::Broadcast*>(n);
auto arg_tensor_type = dynamic_pointer_cast<const TensorViewType>( auto& input_shape = n->get_inputs().at(0).get_shape();
n->get_inputs().at(0).get_tensor_view_type()); auto& result = n->get_outputs().at(0);
assert(nullptr != arg_tensor_type); auto result_shape = result.get_shape();
auto arg_shape = arg_tensor_type->get_shape(); auto& result_element_type = result.get_element_type();
auto result_tensor_type =
dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
assert(nullptr != result_tensor_type);
auto result_shape = result_tensor_type->get_shape();
auto& result_element_type = result_tensor_type->get_element_type();
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type, PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Broadcast has unhandled element type", "Broadcast has unhandled element type",
instruction::BroadcastInstruction, instruction::BroadcastInstruction,
in[0], in[0],
out[0], out[0],
arg_shape, input_shape,
result_shape, result_shape,
broadcast->get_broadcast_axes()); broadcast->get_broadcast_axes());
}; };
...@@ -459,37 +443,30 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -459,37 +443,30 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{ {
auto concat = static_cast<const op::Concat*>(n); auto concat = static_cast<const op::Concat*>(n);
std::vector<Shape> arg_shapes; std::vector<Shape> input_shapes;
for (auto& arg : n->get_inputs()) for (auto& input : n->get_inputs())
{ {
auto arg_tensor_type = input_shapes.push_back(input.get_shape());
dynamic_pointer_cast<const TensorViewType>(arg.get_tensor_view_type());
assert(nullptr != arg_tensor_type);
arg_shapes.push_back(arg_tensor_type->get_shape());
} }
auto result_tensor_type = auto& result = n->get_outputs().at(0);
dynamic_pointer_cast<const TensorViewType>(n->get_value_type()); auto result_shape = result.get_shape();
assert(nullptr != result_tensor_type); auto& result_element_type = result.get_element_type();
auto result_shape = result_tensor_type->get_shape();
auto& result_element_type = result_tensor_type->get_element_type();
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type, PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Concat has unhandled element type", "Concat has unhandled element type",
instruction::ConcatInstruction, instruction::ConcatInstruction,
in, in,
out[0], out[0],
arg_shapes, input_shapes,
result_shape, result_shape,
concat->get_concatenation_axis()); concat->get_concatenation_axis());
}; };
REGISTER_TO_OP_MAP(op::Convert) REGISTER_TO_OP_MAP(op::Convert)
{ {
auto arg_tensor_type = n->get_inputs().at(0).get_tensor_view_type(); auto& arg_element_type = n->get_inputs().at(0).get_element_type();
auto& arg_element_type = arg_tensor_type->get_element_type();
auto result_tensor_type = auto result_tensor_type =
dynamic_pointer_cast<const TensorViewType>(n->get_value_type()); dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
assert(nullptr != result_tensor_type); assert(nullptr != result_tensor_type);
...@@ -546,17 +523,12 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -546,17 +523,12 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{ {
auto convolution = static_cast<const op::Convolution*>(n); auto convolution = static_cast<const op::Convolution*>(n);
auto arg0_tensor_type = n->get_inputs().at(0).get_tensor_view_type(); auto input_0_shape = n->get_inputs().at(0).get_shape();
auto arg0_shape = arg0_tensor_type->get_shape(); auto input_1_shape = n->get_inputs().at(1).get_shape();
auto arg1_tensor_type = n->get_inputs().at(1).get_tensor_view_type(); auto& result = n->get_outputs().at(0);
auto arg1_shape = arg1_tensor_type->get_shape(); auto& result_shape = result.get_shape();
auto& result_element_type = result.get_element_type();
auto result_tensor_type =
dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
assert(nullptr != result_tensor_type);
auto result_shape = result_tensor_type->get_shape();
auto& result_element_type = result_tensor_type->get_element_type();
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type, PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Convolution has unhandled element type", "Convolution has unhandled element type",
...@@ -564,8 +536,8 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -564,8 +536,8 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
in[0], in[0],
in[1], in[1],
out[0], out[0],
arg0_shape, input_0_shape,
arg1_shape, input_1_shape,
result_shape, result_shape,
convolution->get_window_movement_strides(), convolution->get_window_movement_strides(),
convolution->get_window_dilation_strides()); convolution->get_window_dilation_strides());
...@@ -577,34 +549,25 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -577,34 +549,25 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
assert(n->get_inputs().size() == 2); assert(n->get_inputs().size() == 2);
auto arg0_tensor_type = dynamic_pointer_cast<const TensorViewType>( auto& input_0 = n->get_inputs().at(0);
n->get_inputs().at(0).get_tensor_view_type()); auto& input_1 = n->get_inputs().at(1);
assert(nullptr != arg0_tensor_type);
auto arg1_tensor_type = dynamic_pointer_cast<const TensorViewType>( auto input_0_shape = input_0.get_shape();
n->get_inputs().at(1).get_tensor_view_type()); auto input_1_shape = input_1.get_shape();
assert(nullptr != arg1_tensor_type); auto& input_0_element_type = input_0.get_element_type();
auto arg0_shape = arg0_tensor_type->get_shape();
auto arg1_shape = arg1_tensor_type->get_shape();
auto& arg0_element_type = arg0_tensor_type->get_element_type();
auto reduction_axes_count = dot->get_reduction_axes_count(); auto reduction_axes_count = dot->get_reduction_axes_count();
auto result_tensor_type = auto result_shape = n->get_outputs().at(0).get_shape();
dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
assert(nullptr != result_tensor_type);
auto result_shape = result_tensor_type->get_shape();
PUSH_POLYMORPHIC_INSTRUCTION(arg0_element_type, PUSH_POLYMORPHIC_INSTRUCTION(input_0_element_type,
"Dot has unhandled element type", "Dot has unhandled element type",
instruction::DotInstruction, instruction::DotInstruction,
in[0], in[0],
in[1], in[1],
out[0], out[0],
arg0_shape, input_0_shape,
arg1_shape, input_1_shape,
result_shape, result_shape,
reduction_axes_count); reduction_axes_count);
}; };
...@@ -682,16 +645,10 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -682,16 +645,10 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
function_map.insert({reduction_function, external}); function_map.insert({reduction_function, external});
} }
auto arg_tensor_type = dynamic_pointer_cast<const TensorViewType>( auto input_shape = n->get_inputs().at(0).get_shape();
n->get_inputs().at(0).get_tensor_view_type()); auto& result = n->get_outputs().at(0);
assert(nullptr != arg_tensor_type); auto result_shape = result.get_shape();
auto arg_shape = arg_tensor_type->get_shape(); auto& result_element_type = result.get_element_type();
auto result_tensor_type =
dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
assert(nullptr != result_tensor_type);
auto result_shape = result_tensor_type->get_shape();
auto& result_element_type = result_tensor_type->get_element_type();
#define M(ET) \ #define M(ET) \
{ \ { \
...@@ -714,7 +671,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -714,7 +671,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
in[0], \ in[0], \
in[1], \ in[1], \
out[0], \ out[0], \
arg_shape, \ input_shape, \
result_shape, \ result_shape, \
reduce->get_reduction_axes(), \ reduce->get_reduction_axes(), \
reduce_handler); \ reduce_handler); \
...@@ -731,23 +688,18 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -731,23 +688,18 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{ {
auto sum = static_cast<const op::Sum*>(n); auto sum = static_cast<const op::Sum*>(n);
auto arg_tensor_type = dynamic_pointer_cast<const TensorViewType>( auto input_shape = n->get_inputs().at(0).get_shape();
n->get_inputs().at(0).get_tensor_view_type());
assert(nullptr != arg_tensor_type);
auto arg_shape = arg_tensor_type->get_shape();
auto result_tensor_type = auto& result = n->get_outputs().at(0);
dynamic_pointer_cast<const TensorViewType>(n->get_value_type()); auto result_shape = result.get_shape();
assert(nullptr != result_tensor_type); auto& result_element_type = result.get_element_type();
auto result_shape = result_tensor_type->get_shape();
auto& result_element_type = result_tensor_type->get_element_type();
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type, PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Sum has unhandled element type", "Sum has unhandled element type",
instruction::SumInstruction, instruction::SumInstruction,
in[0], in[0],
out[0], out[0],
arg_shape, input_shape,
result_shape, result_shape,
sum->get_reduction_axes()); sum->get_reduction_axes());
}; };
...@@ -756,23 +708,18 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -756,23 +708,18 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{ {
auto reshape = static_cast<const op::Reshape*>(n); auto reshape = static_cast<const op::Reshape*>(n);
auto arg_tensor_type = dynamic_pointer_cast<const TensorViewType>( auto input_shape = n->get_inputs().at(0).get_shape();
n->get_inputs().at(0).get_tensor_view_type());
assert(nullptr != arg_tensor_type);
auto arg_shape = arg_tensor_type->get_shape();
auto result_tensor_type = auto& result = n->get_outputs().at(0);
dynamic_pointer_cast<const TensorViewType>(n->get_value_type()); auto result_shape = result.get_shape();
assert(nullptr != result_tensor_type); auto& result_element_type = result.get_element_type();
auto result_shape = result_tensor_type->get_shape();
auto& result_element_type = result_tensor_type->get_element_type();
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type, PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Reshape has unhandled element type", "Reshape has unhandled element type",
instruction::ReshapeInstruction, instruction::ReshapeInstruction,
in[0], in[0],
out[0], out[0],
arg_shape, input_shape,
reshape->get_input_order(), reshape->get_input_order(),
result_shape); result_shape);
}; };
...@@ -781,28 +728,23 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -781,28 +728,23 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{ {
auto slice = static_cast<const op::Slice*>(n); auto slice = static_cast<const op::Slice*>(n);
auto arg_type = slice->get_inputs().at(0).get_tensor_view_type(); auto& input = slice->get_inputs().at(0);
auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_type); auto input_shape = input.get_shape();
assert(nullptr != arg_tensor_view_type); auto& input_element_type = input.get_element_type();
auto arg_shape = arg_tensor_view_type->get_shape();
auto& arg_element_type = arg_tensor_view_type->get_element_type();
auto result_type = slice->get_value_type(); auto result_shape = slice->get_outputs().at(0).get_shape();
auto result_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(result_type);
assert(nullptr != result_tensor_view_type);
auto result_shape = result_tensor_view_type->get_shape();
auto& lower_bounds = slice->get_lower_bounds(); auto& lower_bounds = slice->get_lower_bounds();
auto& upper_bounds = slice->get_upper_bounds(); auto& upper_bounds = slice->get_upper_bounds();
auto& strides = slice->get_strides(); auto& strides = slice->get_strides();
PUSH_POLYMORPHIC_INSTRUCTION(arg_element_type, PUSH_POLYMORPHIC_INSTRUCTION(input_element_type,
"Slice has unhandled element type", "Slice has unhandled element type",
runtime::ngvm::instruction::SliceInstruction, runtime::ngvm::instruction::SliceInstruction,
in[0], in[0],
out[0], out[0],
arg_shape, input_shape,
lower_bounds, lower_bounds,
upper_bounds, upper_bounds,
strides, strides,
...@@ -813,33 +755,23 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -813,33 +755,23 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{ {
auto replace_slice = static_cast<const op::ReplaceSlice*>(n); auto replace_slice = static_cast<const op::ReplaceSlice*>(n);
auto arg0_type = replace_slice->get_inputs().at(0).get_tensor_view_type(); auto& input_0_element_type = replace_slice->get_inputs().at(0).get_element_type();
auto arg0_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg0_type); auto input_1_shape = replace_slice->get_inputs().at(1).get_shape();
assert(nullptr != arg0_tensor_view_type);
auto& arg0_element_type = arg0_tensor_view_type->get_element_type();
auto arg1_type = replace_slice->get_inputs().at(1).get_tensor_view_type();
auto arg1_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg1_type);
assert(nullptr != arg1_tensor_view_type);
auto arg1_shape = arg1_tensor_view_type->get_shape();
auto result_type = replace_slice->get_value_type(); auto result_shape = replace_slice->get_outputs().at(0).get_shape();
auto result_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(result_type);
assert(nullptr != result_tensor_view_type);
auto result_shape = result_tensor_view_type->get_shape();
auto& lower_bounds = replace_slice->get_lower_bounds(); auto& lower_bounds = replace_slice->get_lower_bounds();
auto& upper_bounds = replace_slice->get_upper_bounds(); auto& upper_bounds = replace_slice->get_upper_bounds();
auto& strides = replace_slice->get_strides(); auto& strides = replace_slice->get_strides();
PUSH_POLYMORPHIC_INSTRUCTION(arg0_element_type, PUSH_POLYMORPHIC_INSTRUCTION(input_0_element_type,
"Replace-slice has unhandled element type", "Replace-slice has unhandled element type",
runtime::ngvm::instruction::ReplaceSliceInstruction, runtime::ngvm::instruction::ReplaceSliceInstruction,
in[0], in[0],
in[1], in[1],
out[0], out[0],
arg1_shape, input_1_shape,
lower_bounds, lower_bounds,
upper_bounds, upper_bounds,
strides, strides,
...@@ -850,23 +782,18 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -850,23 +782,18 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{ {
auto one_hot = static_cast<const op::OneHot*>(n); auto one_hot = static_cast<const op::OneHot*>(n);
auto arg_tensor_type = dynamic_pointer_cast<const TensorViewType>( auto input_shape = n->get_inputs().at(0).get_shape();
n->get_inputs().at(0).get_tensor_view_type());
assert(nullptr != arg_tensor_type);
auto arg_shape = arg_tensor_type->get_shape();
auto result_tensor_type = auto& result = n->get_outputs().at(0);
dynamic_pointer_cast<const TensorViewType>(n->get_value_type()); auto result_shape = result.get_shape();
assert(nullptr != result_tensor_type); auto& result_element_type = result.get_element_type();
auto result_shape = result_tensor_type->get_shape();
auto& result_element_type = result_tensor_type->get_element_type();
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type, PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"One-hot has unhandled element type", "One-hot has unhandled element type",
instruction::OneHotInstruction, instruction::OneHotInstruction,
in[0], in[0],
out[0], out[0],
arg_shape, input_shape,
result_shape, result_shape,
one_hot->get_one_hot_axis()); one_hot->get_one_hot_axis());
}; };
......
...@@ -559,7 +559,7 @@ TEST(type_prop, select_shape_mismatch_a) ...@@ -559,7 +559,7 @@ TEST(type_prop, select_shape_mismatch_a)
} }
catch (const ngraph_error& error) catch (const ngraph_error& error)
{ {
EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view shape")); EXPECT_EQ(error.what(), std::string("Arguments must have the same shape"));
} }
catch (...) catch (...)
{ {
...@@ -583,7 +583,7 @@ TEST(type_prop, select_shape_mismatch_b) ...@@ -583,7 +583,7 @@ TEST(type_prop, select_shape_mismatch_b)
} }
catch (const ngraph_error& error) catch (const ngraph_error& error)
{ {
EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view shape")); EXPECT_EQ(error.what(), std::string("Arguments must have the same shape"));
} }
catch (...) catch (...)
{ {
...@@ -607,7 +607,7 @@ TEST(type_prop, select_shape_mismatch_c) ...@@ -607,7 +607,7 @@ TEST(type_prop, select_shape_mismatch_c)
} }
catch (const ngraph_error& error) catch (const ngraph_error& error)
{ {
EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view shape")); EXPECT_EQ(error.what(), std::string("Arguments must have the same shape"));
} }
catch (...) catch (...)
{ {
...@@ -657,8 +657,7 @@ TEST(type_prop, select_elem_mismatch_bc) ...@@ -657,8 +657,7 @@ TEST(type_prop, select_elem_mismatch_bc)
} }
catch (const ngraph_error& error) catch (const ngraph_error& error)
{ {
EXPECT_EQ(error.what(), EXPECT_EQ(error.what(), std::string("Arguments 1 and 2 must have the same element type"));
std::string("Arguments 1 and 2 must have the same tensor view type"));
} }
catch (...) catch (...)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment