Commit 8c50b179 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

New checks for Function's constructor (#292)

* Remove unnecessary dependency on codegen in interpreter

* Check for incorrect return type and undeclared parameters in Function's constructor

* Address review comments

* Add scarier error message when the result node has null return type (should never happen)
* Add new constructor for Function that doesn't require the return type, and unit test for same
parent 025a1b92
......@@ -29,12 +29,42 @@ Function::Function(const std::shared_ptr<Node>& result,
: m_result(result)
, m_parameters(parameters)
, m_name(name)
, m_result_type(result_type)
, m_ordered_ops_valid(false)
, m_temporary_pool_size(0)
, m_instance_id(m_next_instance_id.fetch_add(1))
{
traverse_nodes(this, [&](shared_ptr<Node> node) { m_ops.push_back(node); });
if (nullptr == result->get_value_type())
{
throw ngraph_error("Internal nGraph error: result->get_value_type() == nullptr");
}
if (nullptr != result_type && (*result_type != *(result->get_value_type())))
{
throw ngraph_error("Function result node's value type does not match declared return type");
}
traverse_nodes(this, [&](shared_ptr<Node> node) {
m_ops.push_back(node);
std::shared_ptr<op::Parameter> p = std::dynamic_pointer_cast<op::Parameter>(node);
if (nullptr != p)
{
auto it = std::find_if(parameters.begin(),
parameters.end(),
[p](std::shared_ptr<op::Parameter> q) { return (p == q); });
if (it == parameters.end())
{
throw ngraph_error("Function references undeclared parameter");
}
}
});
}
Function::Function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<op::Parameter>>& parameters,
const std::string& name)
: Function(result, nullptr, parameters, name)
{
}
void Function::set_ordered_ops(const std::list<shared_ptr<Node>>& ordered_ops)
......
......@@ -39,13 +39,20 @@ namespace ngraph
const std::vector<std::shared_ptr<op::Parameter>>& parameters,
const std::string& name = "");
Function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<op::Parameter>>& parameters,
const std::string& name = "");
std::shared_ptr<Node> get_result() { return m_result; }
std::shared_ptr<const Node> get_result() const { return m_result; }
const std::vector<std::shared_ptr<op::Parameter>>& get_parameters() const
{
return m_parameters;
}
std::shared_ptr<const ValueType> get_result_type() const { return m_result_type; }
std::shared_ptr<const ValueType> get_result_type() const
{
return m_result->get_value_type();
}
std::string get_name() const;
void set_name(const std::string& name);
std::list<std::shared_ptr<Node>>& get_ops();
......
......@@ -16,7 +16,6 @@
#include <memory>
#include "ngraph/codegen/execution_engine.hpp"
#include "ngraph/runtime/manager.hpp"
namespace ngraph
......@@ -32,9 +31,6 @@ namespace ngraph
/// @brief Transformer for the interpreted backend
class INT_Manager : public Manager
{
protected:
ngraph::codegen::ExecutionEngine exec_state;
public:
virtual std::shared_ptr<Backend> allocate_backend() override;
......
......@@ -537,7 +537,7 @@ TEST(${BACKEND_NAME}, dot_matrix_3x2_2x0)
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
auto shape_b = Shape{2, 0};
auto B = make_shared<op::Parameter>(element::Float32::element_type(), shape_b);
auto shape_r = Shape{0, 0};
auto shape_r = Shape{3, 0};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto f = make_shared<Function>(make_shared<op::Dot>(A, B), rt, op::Parameters{A, B});
......
......@@ -33,8 +33,7 @@ TEST(build_graph, build_simple)
ASSERT_EQ(dot->get_arguments()[0], arg2);
ASSERT_EQ(dot->get_arguments()[1], arg0);
auto result_type =
make_shared<TensorViewType>(element::Float32::element_type(), Shape{10, 32, 7});
auto result_type = make_shared<TensorViewType>(element::Float32::element_type(), Shape{32, 3});
auto cluster_0 =
make_shared<Function>(dot, result_type, op::Parameters{arg0, arg1, arg2, arg3});
......@@ -119,3 +118,91 @@ TEST(build_graph, tensor)
TEST(build_graph, arg_inverse)
{
}
// Check functions with undeclared parameters
TEST(build_graph, function_undeclared_parameters)
{
// Function with 4 parameters
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{7, 3});
auto arg1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{3});
auto arg2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32, 7});
auto arg3 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32, 7});
auto broadcast_1 = make_shared<op::Broadcast>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto b1 = make_shared<op::Broadcast>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto dot = make_shared<op::Dot>(arg2, arg0);
ASSERT_EQ(dot->get_arguments()[0], arg2);
ASSERT_EQ(dot->get_arguments()[1], arg0);
auto result_type = make_shared<TensorViewType>(element::Float32::element_type(), Shape{32, 3});
try
{
auto f = make_shared<Function>(dot, result_type, op::Parameters{arg0, arg1, arg3});
// Should have thrown, so fail if it didn't
FAIL() << "Undeclared parameter not detected.";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Function references undeclared parameter"));
}
catch (...)
{
FAIL() << "Function construction failed for unexpected reason";
}
}
// Check functions with incorrect declared return types
TEST(build_graph, function_incorrect_return_type)
{
// Function with 4 parameters
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{7, 3});
auto arg1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{3});
auto arg2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32, 7});
auto arg3 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32, 7});
auto broadcast_1 = make_shared<op::Broadcast>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto b1 = make_shared<op::Broadcast>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto dot = make_shared<op::Dot>(arg2, arg0);
ASSERT_EQ(dot->get_arguments()[0], arg2);
ASSERT_EQ(dot->get_arguments()[1], arg0);
auto incorrect_result_type =
make_shared<TensorViewType>(element::Int32::element_type(), Shape{32, 3});
try
{
auto f = make_shared<Function>(
dot, incorrect_result_type, op::Parameters{arg0, arg1, arg2, arg3});
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect result type not detected.";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(
error.what(),
std::string("Function result node's value type does not match declared return type"));
}
catch (...)
{
FAIL() << "Function construction failed for unexpected reason";
}
}
// Check functions with no declared return type
TEST(build_graph, function_no_declared_return_type)
{
// Function with 4 parameters
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{7, 3});
auto arg1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{3});
auto arg2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32, 7});
auto arg3 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32, 7});
auto broadcast_1 = make_shared<op::Broadcast>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto b1 = make_shared<op::Broadcast>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto dot = make_shared<op::Dot>(arg2, arg0);
ASSERT_EQ(dot->get_arguments()[0], arg2);
ASSERT_EQ(dot->get_arguments()[1], arg0);
auto f = make_shared<Function>(dot, op::Parameters{arg0, arg1, arg2, arg3});
auto f_rt = f->get_result_type();
ASSERT_EQ(*f_rt, TensorViewType(element::Float32::element_type(), Shape{32, 3}));
}
......@@ -37,8 +37,8 @@ namespace ng = ngraph;
TEST(liveness, constant)
{
auto shape = Shape{1};
auto c = make_shared<op::Constant>(element::i32, Shape{}, "5");
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto c = make_shared<op::Constant>(element::i32, shape, "5");
auto rt = make_shared<TensorViewType>(element::i32, shape);
auto f = make_shared<Function>(make_shared<op::Negative>(c), rt, op::Parameters{});
pass::Manager pass_manager;
......
......@@ -229,8 +229,8 @@ TEST(memory_layout, constant)
pass_manager.register_pass<pass::DumpSorted>(dump_file);
auto shape = Shape{1};
auto c = make_shared<op::Constant>(element::i32, Shape{}, "5");
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape);
auto c = make_shared<op::Constant>(element::i32, shape, "5");
auto rt = make_shared<TensorViewType>(element::i32, shape);
auto f = make_shared<Function>(make_shared<op::Negative>(c), rt, op::Parameters{});
pass_manager.run_passes(f);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment