Commit e8e762dd authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge pull request #87 from NervanaSystems/cyphers/binary

A few binary elementwise ops
parents df3bbefd f6637801
...@@ -15,6 +15,7 @@ set (SRC ...@@ -15,6 +15,7 @@ set (SRC
tree.cpp tree.cpp
util.cpp util.cpp
log.cpp log.cpp
ops/binary_elementwise_builtin.cpp
ops/broadcast.cpp ops/broadcast.cpp
ops/concatenate.cpp ops/concatenate.cpp
ops/convert.cpp ops/convert.cpp
...@@ -24,6 +25,7 @@ set (SRC ...@@ -24,6 +25,7 @@ set (SRC
ops/op.cpp ops/op.cpp
ops/parameter.cpp ops/parameter.cpp
ops/tuple.cpp ops/tuple.cpp
ops/unary_elementwise_builtin.cpp
types/element_type.cpp types/element_type.cpp
types/type.cpp types/type.cpp
ngraph/node.cpp ngraph/node.cpp
......
...@@ -24,12 +24,18 @@ ...@@ -24,12 +24,18 @@
#include "function.hpp" #include "function.hpp"
#include "node.hpp" #include "node.hpp"
#include "op.hpp" #include "op.hpp"
#include "ops/add.hpp"
#include "ops/broadcast.hpp" #include "ops/broadcast.hpp"
#include "ops/ceiling.hpp"
#include "ops/concatenate.hpp" #include "ops/concatenate.hpp"
#include "ops/constant.hpp" #include "ops/constant.hpp"
#include "ops/convert.hpp" #include "ops/convert.hpp"
#include "ops/divide.hpp"
#include "ops/dot.hpp" #include "ops/dot.hpp"
#include "ops/floor.hpp"
#include "ops/multiply.hpp"
#include "ops/parameter.hpp" #include "ops/parameter.hpp"
#include "ops/subtract.hpp"
#include "ops/tuple.hpp" #include "ops/tuple.hpp"
#include "shape.hpp" #include "shape.hpp"
#include "type.hpp" #include "type.hpp"
...@@ -61,10 +61,6 @@ namespace ngraph ...@@ -61,10 +61,6 @@ namespace ngraph
{ {
public: public:
virtual std::string description() const override { return "Builtin"; } virtual std::string description() const override { return "Builtin"; }
/// Name of the builtin op, for debugging and logging.
// TODO: Implement for each op. This enables graphs to be built for now.
virtual void propagate_types() override {}
protected: protected:
Builtin(const std::vector<std::shared_ptr<Node>>& args) Builtin(const std::vector<std::shared_ptr<Node>>& args)
...@@ -73,58 +69,60 @@ namespace ngraph ...@@ -73,58 +69,60 @@ namespace ngraph
} }
}; };
class Abs : public Builtin /// Index ops create a new way to index the same tensor elements
class IndexBuiltin : public Builtin
{ {
public: protected:
Abs(const std::shared_ptr<Node>& arg0) IndexBuiltin(const std::shared_ptr<Node>& arg)
: Builtin({arg0}) : Builtin(Nodes{arg})
{ {
} }
virtual std::string get_op_class_name() const override { return "Abs"; }
//virtual void propagate_types() override;
}; };
class Add : public Builtin /// Operations where the same element function is applied to each element
/// Op(X)[I] = op(X[I])
class UnaryElementwiseBuiltin : public Builtin
{ {
public: protected:
Add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) UnaryElementwiseBuiltin(const std::shared_ptr<Node>& arg)
: Builtin({arg0, arg1}) : Builtin(Nodes{arg})
{ {
} }
virtual std::string get_op_class_name() const override { return "Add"; }
//virtual void propagate_types() override; public:
virtual void propagate_types() override;
}; };
class Ceiling : public Builtin /// Op(X, Y)[I] = op(X[I], Y[I])
class BinaryElementwiseBuiltin : public Builtin
{ {
public: protected:
Ceiling(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) BinaryElementwiseBuiltin(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1}) : Builtin(Nodes{arg0, arg1})
{ {
} }
virtual std::string get_op_class_name() const override { return "Ceiling"; } public:
//virtual void propagate_types() override; virtual void propagate_types() override;
}; };
class Divide : public Builtin class Abs : public UnaryElementwiseBuiltin
{ {
public: public:
Divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Abs(const std::shared_ptr<Node>& arg0)
: Builtin({arg0, arg1}) : UnaryElementwiseBuiltin({arg0})
{ {
} }
virtual std::string get_op_class_name() const override { return "Divide"; } virtual std::string get_op_class_name() const override { return "Abs"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Equal : public Builtin class Equal : public BinaryElementwiseBuiltin
{ {
public: public:
Equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1}) : BinaryElementwiseBuiltin(arg0, arg1)
{ {
} }
...@@ -132,11 +130,11 @@ namespace ngraph ...@@ -132,11 +130,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Exp : public Builtin class Exp : public UnaryElementwiseBuiltin
{ {
public: public:
Exp(const std::shared_ptr<Node>& arg0) Exp(const std::shared_ptr<Node>& arg0)
: Builtin({arg0}) : UnaryElementwiseBuiltin(arg0)
{ {
} }
...@@ -144,23 +142,11 @@ namespace ngraph ...@@ -144,23 +142,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Floor : public Builtin class Greater : public BinaryElementwiseBuiltin
{
public:
Floor(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
{
}
virtual std::string get_op_class_name() const override { return "Floor"; }
//virtual void propagate_types() override;
};
class Greater : public Builtin
{ {
public: public:
Greater(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Greater(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1}) : BinaryElementwiseBuiltin(arg0, arg1)
{ {
} }
...@@ -168,11 +154,11 @@ namespace ngraph ...@@ -168,11 +154,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Less : public Builtin class Less : public BinaryElementwiseBuiltin
{ {
public: public:
Less(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Less(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1}) : BinaryElementwiseBuiltin(arg0, arg1)
{ {
} }
...@@ -180,11 +166,11 @@ namespace ngraph ...@@ -180,11 +166,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Log : public Builtin class Log : public UnaryElementwiseBuiltin
{ {
public: public:
Log(const std::shared_ptr<Node>& arg0) Log(const std::shared_ptr<Node>& arg0)
: Builtin({arg0}) : UnaryElementwiseBuiltin(arg0)
{ {
} }
...@@ -192,11 +178,11 @@ namespace ngraph ...@@ -192,11 +178,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Maximum : public Builtin class Maximum : public BinaryElementwiseBuiltin
{ {
public: public:
Maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1}) : BinaryElementwiseBuiltin(arg0, arg1)
{ {
} }
...@@ -204,11 +190,11 @@ namespace ngraph ...@@ -204,11 +190,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Minimum : public Builtin class Minimum : public BinaryElementwiseBuiltin
{ {
public: public:
Minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1}) : BinaryElementwiseBuiltin(arg0, arg1)
{ {
} }
...@@ -216,23 +202,11 @@ namespace ngraph ...@@ -216,23 +202,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Multiply : public Builtin class Negative : public UnaryElementwiseBuiltin
{
public:
Multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
{
}
virtual std::string get_op_class_name() const override { return "Multiply"; }
//virtual void propagate_types() override;
};
class Negative : public Builtin
{ {
public: public:
Negative(const std::shared_ptr<Node>& arg0) Negative(const std::shared_ptr<Node>& arg0)
: Builtin({arg0}) : UnaryElementwiseBuiltin(arg0)
{ {
} }
...@@ -240,11 +214,11 @@ namespace ngraph ...@@ -240,11 +214,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Power : public Builtin class Power : public BinaryElementwiseBuiltin
{ {
public: public:
Power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1}) : BinaryElementwiseBuiltin(arg0, arg1)
{ {
} }
...@@ -252,11 +226,11 @@ namespace ngraph ...@@ -252,11 +226,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Remainder : public Builtin class Remainder : public BinaryElementwiseBuiltin
{ {
public: public:
Remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1}) : BinaryElementwiseBuiltin(arg0, arg1)
{ {
} }
...@@ -264,11 +238,11 @@ namespace ngraph ...@@ -264,11 +238,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class Reshape : public Builtin class Reshape : public IndexBuiltin
{ {
public: public:
Reshape(const std::shared_ptr<Node>& arg0, const Shape& shape) Reshape(const std::shared_ptr<Node>& arg0, const Shape& shape)
: Builtin({arg0}) : IndexBuiltin(arg0)
, m_shape(shape) , m_shape(shape)
{ {
} }
...@@ -278,17 +252,5 @@ namespace ngraph ...@@ -278,17 +252,5 @@ namespace ngraph
protected: protected:
Shape m_shape; Shape m_shape;
}; };
class Subtract : public Builtin
{
public:
Subtract(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
{
}
virtual std::string get_op_class_name() const override { return "Subtract"; }
//virtual void propagate_types() override;
};
} }
} }
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Add : public BinaryElementwiseBuiltin
{
public:
Add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Add"; }
};
}
}
...@@ -18,7 +18,7 @@ namespace ngraph ...@@ -18,7 +18,7 @@ namespace ngraph
{ {
namespace op namespace op
{ {
class Broadcast : public Builtin class Broadcast : public IndexBuiltin
{ {
public: public:
/// ///
...@@ -30,7 +30,7 @@ namespace ngraph ...@@ -30,7 +30,7 @@ namespace ngraph
Broadcast(const std::shared_ptr<Node>& arg, Broadcast(const std::shared_ptr<Node>& arg,
const Shape& shape, const Shape& shape,
const AxisSet& broadcast_axes) const AxisSet& broadcast_axes)
: Builtin({arg}) : IndexBuiltin(arg)
, m_shape(shape) , m_shape(shape)
, m_broadcast_axes(broadcast_axes) , m_broadcast_axes(broadcast_axes)
{ {
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Ceiling : public BinaryElementwiseBuiltin
{
public:
Ceiling(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Ceiling"; }
};
}
}
...@@ -18,11 +18,11 @@ namespace ngraph ...@@ -18,11 +18,11 @@ namespace ngraph
{ {
namespace op namespace op
{ {
class Convert : public Builtin class Convert : public UnaryElementwiseBuiltin
{ {
public: public:
Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type) Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type)
: Builtin({arg}) : UnaryElementwiseBuiltin({arg})
, m_element_type(element_type) , m_element_type(element_type)
{ {
} }
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Divide : public BinaryElementwiseBuiltin
{
public:
Divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Divide"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Floor : public BinaryElementwiseBuiltin
{
public:
Floor(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Floor"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Multiply : public BinaryElementwiseBuiltin
{
public:
Multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Multiply"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Subtract : public BinaryElementwiseBuiltin
{
public:
Subtract(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Subtract"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph::op;
void BinaryElementwiseBuiltin::propagate_types()
{
if (m_arguments.size() != 2)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg0_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type());
auto arg1_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->get_value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments must be tensor views");
}
if (*arg0_tensor_type != *arg1_tensor_type)
{
throw ngraph_error("Arguments must have the same tensor view type");
}
set_value_type_checked(arg0_tensor_type);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph::op;
void UnaryElementwiseBuiltin::propagate_types()
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type());
if (nullptr == arg_tensor_type)
{
throw ngraph_error("Argument must be tensor view");
}
set_value_type_checked(arg_tensor_type);
}
...@@ -20,6 +20,11 @@ ...@@ -20,6 +20,11 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
void test_binary_bad_arguments_tuple(const shared_ptr<Node>& node);
void test_binary_bad_arguments_views(const shared_ptr<Node>& node);
void test_binary_good_arguments(const shared_ptr<Node>& node);
void test_binary(shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr<Node>& y));
TEST(type_prop, broadcast_deduce) TEST(type_prop, broadcast_deduce)
{ {
// Deduce type // Deduce type
...@@ -52,6 +57,7 @@ TEST(type_prop, broadcast_deduce_incorrect) ...@@ -52,6 +57,7 @@ TEST(type_prop, broadcast_deduce_incorrect)
try try
{ {
bc->propagate_types(); bc->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Deduced type should disagree with specified type"; FAIL() << "Deduced type should disagree with specified type";
} }
catch (const ngraph_error& error) catch (const ngraph_error& error)
...@@ -72,6 +78,7 @@ TEST(type_prop, broadcast_bad_arguments) ...@@ -72,6 +78,7 @@ TEST(type_prop, broadcast_bad_arguments)
try try
{ {
bc->propagate_types(); bc->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Tuple argument to broadcast not detected."; FAIL() << "Tuple argument to broadcast not detected.";
} }
catch (const ngraph_error& error) catch (const ngraph_error& error)
...@@ -84,3 +91,105 @@ TEST(type_prop, broadcast_bad_arguments) ...@@ -84,3 +91,105 @@ TEST(type_prop, broadcast_bad_arguments)
} }
} }
void test_binary_bad_arguments_tuple(const shared_ptr<Node>& node)
{
try
{
node->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Tuple argument not detected.";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Arguments must be tensor views"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
void test_binary_bad_arguments_views(const shared_ptr<Node>& node)
{
try
{
node->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected.";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view type"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
void test_binary_good_arguments(const shared_ptr<Node>& node)
{
node->propagate_types();
EXPECT_EQ(*node->get_value_type(), *node->get_arguments()[0]->get_value_type());
}
void test_binary(shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr<Node>& y))
{
// Check for bad arguments
auto tp0_param = make_shared<op::Parameter>(make_shared<TupleType>());
auto tp1_param = make_shared<op::Parameter>(make_shared<TupleType>());
auto tv0_2_4_param_0 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_2_4_param_1 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_4_2_param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 2}));
test_binary_bad_arguments_tuple(f(tp0_param, tp1_param));
test_binary_bad_arguments_tuple(f(tp0_param, tv0_2_4_param_0));
test_binary_bad_arguments_tuple(f(tv0_2_4_param_0, tp0_param));
test_binary_bad_arguments_views(f(tv0_2_4_param_0, tv0_4_2_param));
test_binary_good_arguments(f(tv0_2_4_param_0, tv0_2_4_param_1));
}
TEST(type_prop, add_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Add>(x, y);
});
}
TEST(type_prop, ceiling_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Ceiling>(x, y);
});
}
TEST(type_prop, divide_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Divide>(x, y);
});
}
TEST(type_prop, floor_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Floor>(x, y);
});
}
TEST(type_prop, multiply_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Multiply>(x, y);
});
}
TEST(type_prop, subtract_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Subtract>(x, y);
});
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment