Commit 8ba7b324 authored by Scott Cyphers's avatar Scott Cyphers

Finish broadcast type propagation, add tests.

parent 961b4e0a
......@@ -94,6 +94,7 @@ namespace ngraph
size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&);
protected:
Nodes m_arguments;
......
......@@ -19,6 +19,10 @@ using namespace ngraph::op;
void Broadcast::propagate_types()
{
if (m_arguments.size() != 1){
throw ngraph_error("Wrong number of arguments.");
}
auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type)
{
......
......@@ -24,6 +24,7 @@ set (SRC
element_type.cpp
uuid.cpp
topological_sort.cpp
type_prop.cpp
op.cpp
)
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include <memory>
using namespace std;
using namespace ngraph;
TEST(type_prop, broadcast_deduce)
{
// Deduce type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 4});
auto bc = make_shared<op::Broadcast>(param, Shape{2, 3, 4}, AxisSet{1});
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 3, 4}));
}
TEST(type_prop, broadcast_deduce_correct)
{
// Check deduced type against correctly specified type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 4});
auto bc = make_shared<op::Broadcast>(param, Shape{2, 3, 4}, AxisSet{1});
bc->set_value_type(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 3, 4}));
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 3, 4}));
}
TEST(type_prop, broadcast_deduce_incorrect)
{
// Check deduced type against incorrectly specified type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 4});
auto bc = make_shared<op::Broadcast>(param, Shape{2, 4, 3}, AxisSet{1});
bc->set_value_type(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 3, 4}));
try
{
bc->propagate_types();
FAIL() << "Deduced type should disagree with specified type";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Broadcast arg, shape, and axes are incompatible"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, broadcast_bad_arguments)
{
// Check for bad arguments
auto param = make_shared<op::Parameter>(make_shared<TupleType>());
auto bc = make_shared<op::Broadcast>(param, Shape{2, 4, 3}, AxisSet{1});
try
{
bc->propagate_types();
FAIL() << "Tuple argument to broadcast not detected.";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Argument to broadcast is not a tensor view"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment