Commit c19a5319 authored by Adam Procter's avatar Adam Procter Committed by GitHub

Elementwise ops with non-uniform element types for args/results (#105)

* Implemented binops/unops with different return type from args

* Implemented generic bases class for comparison and arith ops, changed current comparison ops to subclass it

* Added a TraitedType for booleans

* Moved arith and comparison ops to separate header files from `op.h`
parent 36cc0317
......@@ -34,7 +34,9 @@ set (SRC
ngraph/runtime/eigen/tensor_view.cpp
ngraph/shape.cpp
ngraph/visualize.cpp
ops/binary_elementwise_arithmetic.cpp
ops/binary_elementwise_builtin.cpp
ops/binary_elementwise_comparison.cpp
ops/broadcast.cpp
ops/concatenate.cpp
ops/constant.cpp
......@@ -44,6 +46,7 @@ set (SRC
ops/op.cpp
ops/parameter.cpp
ops/tuple.cpp
ops/unary_elementwise_arithmetic.cpp
ops/unary_elementwise_builtin.cpp
tree.cpp
types/element_type.cpp
......
......@@ -104,6 +104,9 @@ namespace ngraph
}
};
NGRAPH_DEFINE_TRAITED_TYPE_NAME(bool)
using Bool = TraitedType<bool>;
NGRAPH_DEFINE_TRAITED_TYPE_NAME(float)
using Float32 = TraitedType<float>;
......
......@@ -31,6 +31,7 @@
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op.hpp"
#include "ngraph/ops/abs.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/ceiling.hpp"
......@@ -39,9 +40,19 @@
#include "ngraph/ops/convert.hpp"
#include "ngraph/ops/divide.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/equal.hpp"
#include "ngraph/ops/exp.hpp"
#include "ngraph/ops/floor.hpp"
#include "ngraph/ops/greater.hpp"
#include "ngraph/ops/less.hpp"
#include "ngraph/ops/log.hpp"
#include "ngraph/ops/maximum.hpp"
#include "ngraph/ops/minimum.hpp"
#include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/negative.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/power.hpp"
#include "ngraph/ops/remainder.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/runtime/eigen/tensor_view.hpp"
......
......@@ -79,6 +79,21 @@ namespace ngraph
}
};
class Reshape : public IndexBuiltin
{
public:
Reshape(const std::shared_ptr<Node>& arg0, const Shape& shape)
: IndexBuiltin(arg0)
, m_shape(shape)
{
}
virtual std::string get_op_class_name() const override { return "Reshape"; }
//virtual void propagate_types() override;
protected:
Shape m_shape;
};
/// Operations where the same element function is applied to each element
/// Op(X)[I] = op(X[I])
class UnaryElementwiseBuiltin : public Builtin
......@@ -88,11 +103,24 @@ namespace ngraph
: Builtin(Nodes{arg})
{
}
virtual const element::Type& propagate_element_types(
const element::Type& arg_element_type) const = 0;
public:
virtual void propagate_types() override;
};
class UnaryElementwiseArithmetic : public UnaryElementwiseBuiltin
{
protected:
UnaryElementwiseArithmetic(const std::shared_ptr<Node>& arg)
: UnaryElementwiseBuiltin({arg})
{
}
virtual const element::Type& propagate_element_types(
const element::Type& arg_element_type) const final override;
};
/// Op(X, Y)[I] = op(X[I], Y[I])
class BinaryElementwiseBuiltin : public Builtin
{
......@@ -102,156 +130,45 @@ namespace ngraph
: Builtin(Nodes{arg0, arg1})
{
}
virtual const element::Type& propagate_element_types(
const element::Type& arg0_element_type,
const element::Type& arg1_element_type) const = 0;
public:
virtual void propagate_types() override;
};
class Abs : public UnaryElementwiseBuiltin
{
public:
Abs(const std::shared_ptr<Node>& arg0)
: UnaryElementwiseBuiltin({arg0})
{
}
virtual std::string get_op_class_name() const override { return "Abs"; }
//virtual void propagate_types() override;
};
class Equal : public BinaryElementwiseBuiltin
class BinaryElementwiseComparison : public BinaryElementwiseBuiltin
{
public:
Equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
BinaryElementwiseComparison(
const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Equal"; }
virtual std::string get_op_class_name() const override { return "BinaryElementwiseComparison"; }
//virtual void propagate_types() override;
virtual const element::Type& propagate_element_types(
const element::Type& arg0_element_type,
const element::Type& arg1_element_type) const override;
};
class Exp : public UnaryElementwiseBuiltin
class BinaryElementwiseArithmetic : public BinaryElementwiseBuiltin
{
public:
Exp(const std::shared_ptr<Node>& arg0)
: UnaryElementwiseBuiltin(arg0)
{
}
virtual std::string get_op_class_name() const override { return "Exp"; }
//virtual void propagate_types() override;
};
class Greater : public BinaryElementwiseBuiltin
{
public:
Greater(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Greater"; }
//virtual void propagate_types() override;
};
class Less : public BinaryElementwiseBuiltin
{
public:
Less(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
BinaryElementwiseArithmetic(
const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Less"; }
//virtual void propagate_types() override;
};
class Log : public UnaryElementwiseBuiltin
{
public:
Log(const std::shared_ptr<Node>& arg0)
: UnaryElementwiseBuiltin(arg0)
{
}
virtual std::string get_op_class_name() const override { return "Log"; }
virtual std::string get_op_class_name() const override { return "BinaryElementwiseArithmetic"; }
//virtual void propagate_types() override;
};
class Maximum : public BinaryElementwiseBuiltin
{
public:
Maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Max"; }
//virtual void propagate_types() override;
};
class Minimum : public BinaryElementwiseBuiltin
{
public:
Minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Min"; }
//virtual void propagate_types() override;
};
class Negative : public UnaryElementwiseBuiltin
{
public:
Negative(const std::shared_ptr<Node>& arg0)
: UnaryElementwiseBuiltin(arg0)
{
}
virtual std::string get_op_class_name() const override { return "Negative"; }
//virtual void propagate_types() override;
};
class Power : public BinaryElementwiseBuiltin
{
public:
Power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Power"; }
//virtual void propagate_types() override;
};
class Remainder : public BinaryElementwiseBuiltin
{
public:
Remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Remainder"; }
//virtual void propagate_types() override;
};
class Reshape : public IndexBuiltin
{
public:
Reshape(const std::shared_ptr<Node>& arg0, const Shape& shape)
: IndexBuiltin(arg0)
, m_shape(shape)
{
}
virtual std::string get_op_class_name() const override { return "Reshape"; }
//virtual void propagate_types() override;
protected:
Shape m_shape;
virtual const element::Type& propagate_element_types(
const element::Type& arg0_element_type,
const element::Type& arg1_element_type)
const final override;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Abs : public UnaryElementwiseArithmetic
{
public:
Abs(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
}
virtual std::string get_op_class_name() const override { return "Abs"; }
};
}
}
......@@ -18,11 +18,11 @@ namespace ngraph
{
namespace op
{
class Add : public BinaryElementwiseBuiltin
class Add : public BinaryElementwiseArithmetic
{
public:
Add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Add"; }
......
......@@ -18,11 +18,11 @@ namespace ngraph
{
namespace op
{
class Ceiling : public BinaryElementwiseBuiltin
class Ceiling : public UnaryElementwiseArithmetic
{
public:
Ceiling(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
Ceiling(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
}
......
......@@ -18,11 +18,11 @@ namespace ngraph
{
namespace op
{
class Divide : public BinaryElementwiseBuiltin
class Divide : public BinaryElementwiseArithmetic
{
public:
Divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Equal : public BinaryElementwiseComparison
{
public:
Equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseComparison(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Equal"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Exp : public UnaryElementwiseArithmetic
{
public:
Exp(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
}
virtual std::string get_op_class_name() const override { return "Exp"; }
};
}
}
......@@ -18,11 +18,11 @@ namespace ngraph
{
namespace op
{
class Floor : public BinaryElementwiseBuiltin
class Floor : public UnaryElementwiseArithmetic
{
public:
Floor(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
Floor(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
}
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Greater : public BinaryElementwiseComparison
{
public:
Greater(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseComparison(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Greater"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Less : public BinaryElementwiseComparison
{
public:
Less(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseComparison(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Less"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Log : public UnaryElementwiseArithmetic
{
public:
Log(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
}
virtual std::string get_op_class_name() const override { return "Log"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Maximum : public BinaryElementwiseArithmetic
{
public:
Maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Maximum"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Minimum : public BinaryElementwiseArithmetic
{
public:
Minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Minimum"; }
};
}
}
......@@ -18,11 +18,11 @@ namespace ngraph
{
namespace op
{
class Multiply : public BinaryElementwiseBuiltin
class Multiply : public BinaryElementwiseArithmetic
{
public:
Multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Negative : public UnaryElementwiseArithmetic
{
public:
Negative(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
}
virtual std::string get_op_class_name() const override { return "Negative"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Power : public BinaryElementwiseArithmetic
{
public:
Power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Power"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Remainder : public BinaryElementwiseArithmetic
{
public:
Remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Remainder"; }
};
}
}
......@@ -18,11 +18,11 @@ namespace ngraph
{
namespace op
{
class Subtract : public BinaryElementwiseBuiltin
class Subtract : public BinaryElementwiseArithmetic
{
public:
Subtract(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
const element::Type& BinaryElementwiseArithmetic::propagate_element_types(
const element::Type& arg0_element_type,
const element::Type& arg1_element_type) const
{
if (arg0_element_type != arg1_element_type)
{
throw ngraph_error("Arguments must have the same tensor view element type");
}
if (arg0_element_type == element::Bool::element_type())
{
throw ngraph_error("Operands for arithmetic operators must have numeric element type");
}
return arg0_element_type;
}
......@@ -18,6 +18,7 @@
#include "log.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
void BinaryElementwiseBuiltin::propagate_types()
......@@ -35,10 +36,16 @@ void BinaryElementwiseBuiltin::propagate_types()
{
throw ngraph_error("Arguments must be tensor views");
}
if (*arg0_tensor_type != *arg1_tensor_type)
if (arg0_tensor_type->get_shape() != arg1_tensor_type->get_shape())
{
throw ngraph_error("Arguments must have the same tensor view type");
throw ngraph_error("Arguments must have the same tensor view shape");
}
set_value_type_checked(arg0_tensor_type);
const element::Type& result_element_type =
propagate_element_types(arg0_tensor_type->get_element_type(),
arg1_tensor_type->get_element_type());
set_value_type_checked(make_shared<TensorViewType>(result_element_type,
arg0_tensor_type->get_shape()));
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
const element::Type& BinaryElementwiseComparison::propagate_element_types(
const element::Type& arg0_element_type,
const element::Type& arg1_element_type) const
{
if (arg0_element_type != arg1_element_type)
{
throw ngraph_error("Arguments must have the same tensor view element type");
}
return element::Bool::element_type();
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
const element::Type& UnaryElementwiseArithmetic::propagate_element_types(
const element::Type& arg_element_type) const
{
if (arg_element_type == element::Bool::element_type())
{
throw ngraph_error("Operands for arithmetic operators must have numeric element type");
}
return arg_element_type;
}
......@@ -17,6 +17,7 @@
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
void UnaryElementwiseBuiltin::propagate_types()
......@@ -33,5 +34,9 @@ void UnaryElementwiseBuiltin::propagate_types()
throw ngraph_error("Argument must be tensor view");
}
set_value_type_checked(arg_tensor_type);
const element::Type& result_element_type =
propagate_element_types(arg_tensor_type->get_element_type());
set_value_type_checked(make_shared<TensorViewType>(result_element_type,
arg_tensor_type->get_shape()));
}
......@@ -252,7 +252,7 @@ void test_binary_bad_arguments_tuple(const shared_ptr<Node>& node)
}
}
void test_binary_bad_arguments_views(const shared_ptr<Node>& node)
void test_binary_bad_arguments_view_shapes(const shared_ptr<Node>& node)
{
try
{
......@@ -262,7 +262,25 @@ void test_binary_bad_arguments_views(const shared_ptr<Node>& node)
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view type"));
EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view shape"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
void test_binary_bad_arguments_view_element_types(const shared_ptr<Node>& node)
{
try
{
node->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected.";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view element type"));
}
catch (...)
{
......@@ -285,13 +303,16 @@ void test_binary(shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_2_4_param_1 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_2_4_param_2 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Int32::element_type(), Shape{2, 4}));
auto tv0_4_2_param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 2}));
test_binary_bad_arguments_tuple(f(tp0_param, tp1_param));
test_binary_bad_arguments_tuple(f(tp0_param, tv0_2_4_param_0));
test_binary_bad_arguments_tuple(f(tv0_2_4_param_0, tp0_param));
test_binary_bad_arguments_views(f(tv0_2_4_param_0, tv0_4_2_param));
test_binary_bad_arguments_view_shapes(f(tv0_2_4_param_0, tv0_4_2_param));
test_binary_bad_arguments_view_element_types(f(tv0_2_4_param_0, tv0_2_4_param_2));
test_binary_good_arguments(f(tv0_2_4_param_0, tv0_2_4_param_1));
}
......@@ -302,13 +323,6 @@ TEST(type_prop, add_bad_arguments)
});
}
TEST(type_prop, ceiling_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Ceiling>(x, y);
});
}
TEST(type_prop, divide_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
......@@ -316,13 +330,6 @@ TEST(type_prop, divide_bad_arguments)
});
}
TEST(type_prop, floor_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Floor>(x, y);
});
}
TEST(type_prop, multiply_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
......@@ -336,3 +343,59 @@ TEST(type_prop, subtract_bad_arguments)
return make_shared<op::Subtract>(x, y);
});
}
TEST(type_prop, comparison_good)
{
auto tv0_2_4_param_0 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_2_4_param_1 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto eq = make_shared<op::Equal>(tv0_2_4_param_0,tv0_2_4_param_1);
auto expected_type = TensorViewType(element::Bool::element_type(), Shape{2, 4});
eq->propagate_types();
EXPECT_EQ(*eq->get_value_type(),expected_type);
}
TEST(type_prop, binary_arithmetic_bad_argument_element_types)
{
auto tv0_2_4_param_0 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Bool::element_type(), Shape{2, 4}));
auto tv0_2_4_param_1 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Bool::element_type(), Shape{2, 4}));
auto bc = make_shared<op::Add>(tv0_2_4_param_0,tv0_2_4_param_1);
try
{
bc->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Operands for arithmetic operators must have numeric element type"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, unary_arithmetic_bad_argument_element_types)
{
auto tv0_2_4_param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Bool::element_type(), Shape{2, 4}));
auto bc = make_shared<op::Negative>(tv0_2_4_param);
try
{
bc->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Operands for arithmetic operators must have numeric element type"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment