Commit 2a35b6bf authored by Scott Cyphers's avatar Scott Cyphers

A few binary elementwise ops

parent 8fd94713
......@@ -15,6 +15,7 @@ set (SRC
tree.cpp
util.cpp
log.cpp
ops/binary_elementwise_builtin.cpp
ops/broadcast.cpp
ops/concatenate.cpp
ops/convert.cpp
......@@ -24,6 +25,7 @@ set (SRC
ops/op.cpp
ops/parameter.cpp
ops/tuple.cpp
ops/unary_elementwise_builtin.cpp
types/element_type.cpp
types/type.cpp
ngraph/node.cpp
......
......@@ -24,12 +24,18 @@
#include "function.hpp"
#include "node.hpp"
#include "op.hpp"
#include "ops/add.hpp"
#include "ops/broadcast.hpp"
#include "ops/ceiling.hpp"
#include "ops/concatenate.hpp"
#include "ops/constant.hpp"
#include "ops/convert.hpp"
#include "ops/divide.hpp"
#include "ops/dot.hpp"
#include "ops/floor.hpp"
#include "ops/multiply.hpp"
#include "ops/parameter.hpp"
#include "ops/subtract.hpp"
#include "ops/tuple.hpp"
#include "shape.hpp"
#include "type.hpp"
......@@ -61,10 +61,6 @@ namespace ngraph
{
public:
virtual std::string description() const override { return "Builtin"; }
/// Name of the builtin op, for debugging and logging.
// TODO: Implement for each op. This enables graphs to be built for now.
virtual void propagate_types() override {}
protected:
Builtin(const std::vector<std::shared_ptr<Node>>& args)
......@@ -73,58 +69,60 @@ namespace ngraph
}
};
class Abs : public Builtin
/// Index ops create a new way to index the same tensor elements
class IndexBuiltin : public Builtin
{
public:
Abs(const std::shared_ptr<Node>& arg0)
: Builtin({arg0})
protected:
IndexBuiltin(const std::shared_ptr<Node>& arg)
: Builtin(Nodes{arg})
{
}
virtual std::string get_op_class_name() const override { return "Abs"; }
//virtual void propagate_types() override;
};
class Add : public Builtin
/// Operations where the same element function is applied to each element
/// Op(X)[I] = op(X[I])
class UnaryElementwiseBuiltin : public Builtin
{
public:
Add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
protected:
UnaryElementwiseBuiltin(const std::shared_ptr<Node> arg)
: Builtin(Nodes{arg})
{
}
virtual std::string get_op_class_name() const override { return "Add"; }
//virtual void propagate_types() override;
public:
virtual void propagate_types() override;
};
class Ceiling : public Builtin
/// Op(X, Y)[I] = op(X[I], Y[I])
class BinaryElementwiseBuiltin : public Builtin
{
public:
Ceiling(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
protected:
BinaryElementwiseBuiltin(std::shared_ptr<Node> arg0, std::shared_ptr<Node> arg1)
: Builtin(Nodes{arg0, arg1})
{
}
virtual std::string get_op_class_name() const override { return "Ceiling"; }
//virtual void propagate_types() override;
public:
virtual void propagate_types() override;
};
class Divide : public Builtin
class Abs : public UnaryElementwiseBuiltin
{
public:
Divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
Abs(const std::shared_ptr<Node>& arg0)
: UnaryElementwiseBuiltin({arg0})
{
}
virtual std::string get_op_class_name() const override { return "Divide"; }
virtual std::string get_op_class_name() const override { return "Abs"; }
//virtual void propagate_types() override;
};
class Equal : public Builtin
class Equal : public BinaryElementwiseBuiltin
{
public:
Equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
......@@ -132,11 +130,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class Exp : public Builtin
class Exp : public UnaryElementwiseBuiltin
{
public:
Exp(const std::shared_ptr<Node>& arg0)
: Builtin({arg0})
: UnaryElementwiseBuiltin(arg0)
{
}
......@@ -144,23 +142,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class Floor : public Builtin
{
public:
Floor(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
{
}
virtual std::string get_op_class_name() const override { return "Floor"; }
//virtual void propagate_types() override;
};
class Greater : public Builtin
class Greater : public BinaryElementwiseBuiltin
{
public:
Greater(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
......@@ -168,11 +154,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class Less : public Builtin
class Less : public BinaryElementwiseBuiltin
{
public:
Less(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
......@@ -180,11 +166,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class Log : public Builtin
class Log : public UnaryElementwiseBuiltin
{
public:
Log(const std::shared_ptr<Node>& arg0)
: Builtin({arg0})
: UnaryElementwiseBuiltin(arg0)
{
}
......@@ -192,11 +178,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class Maximum : public Builtin
class Maximum : public BinaryElementwiseBuiltin
{
public:
Maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
......@@ -204,11 +190,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class Minimum : public Builtin
class Minimum : public BinaryElementwiseBuiltin
{
public:
Minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
......@@ -216,23 +202,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class Multiply : public Builtin
{
public:
Multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
{
}
virtual std::string get_op_class_name() const override { return "Multiply"; }
//virtual void propagate_types() override;
};
class Negative : public Builtin
class Negative : public UnaryElementwiseBuiltin
{
public:
Negative(const std::shared_ptr<Node>& arg0)
: Builtin({arg0})
: UnaryElementwiseBuiltin(arg0)
{
}
......@@ -240,11 +214,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class Power : public Builtin
class Power : public BinaryElementwiseBuiltin
{
public:
Power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
......@@ -252,11 +226,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class Remainder : public Builtin
class Remainder : public BinaryElementwiseBuiltin
{
public:
Remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
......@@ -264,11 +238,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class Reshape : public Builtin
class Reshape : public IndexBuiltin
{
public:
Reshape(const std::shared_ptr<Node>& arg0, const Shape& shape)
: Builtin({arg0})
: IndexBuiltin(arg0)
, m_shape(shape)
{
}
......@@ -278,17 +252,5 @@ namespace ngraph
protected:
Shape m_shape;
};
class Subtract : public Builtin
{
public:
Subtract(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
{
}
virtual std::string get_op_class_name() const override { return "Subtract"; }
//virtual void propagate_types() override;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Add : public BinaryElementwiseBuiltin
{
public:
Add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Add"; }
};
}
}
......@@ -18,7 +18,7 @@ namespace ngraph
{
namespace op
{
class Broadcast : public Builtin
class Broadcast : public IndexBuiltin
{
public:
///
......@@ -30,7 +30,7 @@ namespace ngraph
Broadcast(const std::shared_ptr<Node>& arg,
const Shape& shape,
const AxisSet& broadcast_axes)
: Builtin({arg})
: IndexBuiltin(arg)
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
{
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Ceiling : public BinaryElementwiseBuiltin
{
public:
Ceiling(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Ceiling"; }
};
}
}
......@@ -18,11 +18,11 @@ namespace ngraph
{
namespace op
{
class Convert : public Builtin
class Convert : public UnaryElementwiseBuiltin
{
public:
Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type)
: Builtin({arg})
: UnaryElementwiseBuiltin({arg})
, m_element_type(element_type)
{
}
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Divide : public BinaryElementwiseBuiltin
{
public:
Divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Divide"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Floor : public BinaryElementwiseBuiltin
{
public:
Floor(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Floor"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Multiply : public BinaryElementwiseBuiltin
{
public:
Multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Multiply"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Subtract : public BinaryElementwiseBuiltin
{
public:
Subtract(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Subtract"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph::op;
void BinaryElementwiseBuiltin::propagate_types()
{
if (m_arguments.size() != 2)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg0_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type());
auto arg1_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->get_value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments must be tensor views");
}
if (*arg0_tensor_type != *arg1_tensor_type)
{
throw ngraph_error("Arguments must have the same tensor view type");
}
set_value_type_checked(arg0_tensor_type);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph::op;
void UnaryElementwiseBuiltin::propagate_types()
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type());
if (nullptr == arg_tensor_type)
{
throw ngraph_error("Argument must be tensor view");
}
set_value_type_checked(arg_tensor_type);
}
......@@ -20,6 +20,11 @@
using namespace std;
using namespace ngraph;
void test_binary_bad_arguments_tuple(const shared_ptr<Node>& node);
void test_binary_bad_arguments_views(const shared_ptr<Node>& node);
void test_binary_good_arguments(const shared_ptr<Node>& node);
void test_binary(shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr<Node>& y));
TEST(type_prop, broadcast_deduce)
{
// Deduce type
......@@ -84,3 +89,103 @@ TEST(type_prop, broadcast_bad_arguments)
}
}
void test_binary_bad_arguments_tuple(const shared_ptr<Node>& node)
{
try
{
node->propagate_types();
FAIL() << "Tuple argument to add not detected.";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Arguments must be tensor views"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
void test_binary_bad_arguments_views(const shared_ptr<Node>& node)
{
try
{
node->propagate_types();
FAIL() << "Tuple argument to add not detected.";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view type"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
void test_binary_good_arguments(const shared_ptr<Node>& node)
{
node->propagate_types();
EXPECT_EQ(*node->get_value_type(), *node->get_arguments()[0]->get_value_type());
}
void test_binary(shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr<Node>& y))
{
// Check for bad arguments
auto tp0_param = make_shared<op::Parameter>(make_shared<TupleType>());
auto tp1_param = make_shared<op::Parameter>(make_shared<TupleType>());
auto tv0_2_4_param_0 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_2_4_param_1 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_4_2_param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 2}));
test_binary_bad_arguments_tuple(f(tp0_param, tp1_param));
test_binary_bad_arguments_tuple(f(tp0_param, tv0_2_4_param_0));
test_binary_bad_arguments_tuple(f(tv0_2_4_param_0, tp0_param));
test_binary_bad_arguments_views(f(tv0_2_4_param_0, tv0_4_2_param));
test_binary_good_arguments(f(tv0_2_4_param_0, tv0_2_4_param_1));
}
TEST(type_prop, add_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Add>(x, y);
});
}
TEST(type_prop, ceiling_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Ceiling>(x, y);
});
}
TEST(type_prop, divide_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Divide>(x, y);
});
}
TEST(type_prop, floor_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Floor>(x, y);
});
}
TEST(type_prop, multiply_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Multiply>(x, y);
});
}
TEST(type_prop, subtract_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Subtract>(x, y);
});
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment