Commit c19a5319 authored by Adam Procter's avatar Adam Procter Committed by GitHub

Elementwise ops with non-uniform element types for args/results (#105)

* Implemented binops/unops with different return type from args

* Implemented generic bases class for comparison and arith ops, changed current comparison ops to subclass it

* Added a TraitedType for booleans

* Moved arith and comparison ops to separate header files from `op.h`
parent 36cc0317
...@@ -34,7 +34,9 @@ set (SRC ...@@ -34,7 +34,9 @@ set (SRC
ngraph/runtime/eigen/tensor_view.cpp ngraph/runtime/eigen/tensor_view.cpp
ngraph/shape.cpp ngraph/shape.cpp
ngraph/visualize.cpp ngraph/visualize.cpp
ops/binary_elementwise_arithmetic.cpp
ops/binary_elementwise_builtin.cpp ops/binary_elementwise_builtin.cpp
ops/binary_elementwise_comparison.cpp
ops/broadcast.cpp ops/broadcast.cpp
ops/concatenate.cpp ops/concatenate.cpp
ops/constant.cpp ops/constant.cpp
...@@ -44,6 +46,7 @@ set (SRC ...@@ -44,6 +46,7 @@ set (SRC
ops/op.cpp ops/op.cpp
ops/parameter.cpp ops/parameter.cpp
ops/tuple.cpp ops/tuple.cpp
ops/unary_elementwise_arithmetic.cpp
ops/unary_elementwise_builtin.cpp ops/unary_elementwise_builtin.cpp
tree.cpp tree.cpp
types/element_type.cpp types/element_type.cpp
......
...@@ -104,6 +104,9 @@ namespace ngraph ...@@ -104,6 +104,9 @@ namespace ngraph
} }
}; };
NGRAPH_DEFINE_TRAITED_TYPE_NAME(bool)
using Bool = TraitedType<bool>;
NGRAPH_DEFINE_TRAITED_TYPE_NAME(float) NGRAPH_DEFINE_TRAITED_TYPE_NAME(float)
using Float32 = TraitedType<float>; using Float32 = TraitedType<float>;
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op.hpp" #include "ngraph/op.hpp"
#include "ngraph/ops/abs.hpp"
#include "ngraph/ops/add.hpp" #include "ngraph/ops/add.hpp"
#include "ngraph/ops/broadcast.hpp" #include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/ceiling.hpp" #include "ngraph/ops/ceiling.hpp"
...@@ -39,9 +40,19 @@ ...@@ -39,9 +40,19 @@
#include "ngraph/ops/convert.hpp" #include "ngraph/ops/convert.hpp"
#include "ngraph/ops/divide.hpp" #include "ngraph/ops/divide.hpp"
#include "ngraph/ops/dot.hpp" #include "ngraph/ops/dot.hpp"
#include "ngraph/ops/equal.hpp"
#include "ngraph/ops/exp.hpp"
#include "ngraph/ops/floor.hpp" #include "ngraph/ops/floor.hpp"
#include "ngraph/ops/greater.hpp"
#include "ngraph/ops/less.hpp"
#include "ngraph/ops/log.hpp"
#include "ngraph/ops/maximum.hpp"
#include "ngraph/ops/minimum.hpp"
#include "ngraph/ops/multiply.hpp" #include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/negative.hpp"
#include "ngraph/ops/parameter.hpp" #include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/power.hpp"
#include "ngraph/ops/remainder.hpp"
#include "ngraph/ops/subtract.hpp" #include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/tuple.hpp" #include "ngraph/ops/tuple.hpp"
#include "ngraph/runtime/eigen/tensor_view.hpp" #include "ngraph/runtime/eigen/tensor_view.hpp"
......
...@@ -79,6 +79,21 @@ namespace ngraph ...@@ -79,6 +79,21 @@ namespace ngraph
} }
}; };
class Reshape : public IndexBuiltin
{
public:
Reshape(const std::shared_ptr<Node>& arg0, const Shape& shape)
: IndexBuiltin(arg0)
, m_shape(shape)
{
}
virtual std::string get_op_class_name() const override { return "Reshape"; }
//virtual void propagate_types() override;
protected:
Shape m_shape;
};
/// Operations where the same element function is applied to each element /// Operations where the same element function is applied to each element
/// Op(X)[I] = op(X[I]) /// Op(X)[I] = op(X[I])
class UnaryElementwiseBuiltin : public Builtin class UnaryElementwiseBuiltin : public Builtin
...@@ -88,11 +103,24 @@ namespace ngraph ...@@ -88,11 +103,24 @@ namespace ngraph
: Builtin(Nodes{arg}) : Builtin(Nodes{arg})
{ {
} }
virtual const element::Type& propagate_element_types(
const element::Type& arg_element_type) const = 0;
public: public:
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
class UnaryElementwiseArithmetic : public UnaryElementwiseBuiltin
{
protected:
UnaryElementwiseArithmetic(const std::shared_ptr<Node>& arg)
: UnaryElementwiseBuiltin({arg})
{
}
virtual const element::Type& propagate_element_types(
const element::Type& arg_element_type) const final override;
};
/// Op(X, Y)[I] = op(X[I], Y[I]) /// Op(X, Y)[I] = op(X[I], Y[I])
class BinaryElementwiseBuiltin : public Builtin class BinaryElementwiseBuiltin : public Builtin
{ {
...@@ -102,156 +130,45 @@ namespace ngraph ...@@ -102,156 +130,45 @@ namespace ngraph
: Builtin(Nodes{arg0, arg1}) : Builtin(Nodes{arg0, arg1})
{ {
} }
virtual const element::Type& propagate_element_types(
const element::Type& arg0_element_type,
const element::Type& arg1_element_type) const = 0;
public: public:
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
class Abs : public UnaryElementwiseBuiltin class BinaryElementwiseComparison : public BinaryElementwiseBuiltin
{
public:
Abs(const std::shared_ptr<Node>& arg0)
: UnaryElementwiseBuiltin({arg0})
{
}
virtual std::string get_op_class_name() const override { return "Abs"; }
//virtual void propagate_types() override;
};
class Equal : public BinaryElementwiseBuiltin
{ {
public: public:
Equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) BinaryElementwiseComparison(
const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1) : BinaryElementwiseBuiltin(arg0, arg1)
{ {
} }
virtual std::string get_op_class_name() const override { return "Equal"; } virtual std::string get_op_class_name() const override { return "BinaryElementwiseComparison"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
virtual const element::Type& propagate_element_types(
const element::Type& arg0_element_type,
const element::Type& arg1_element_type) const override;
}; };
class Exp : public UnaryElementwiseBuiltin class BinaryElementwiseArithmetic : public BinaryElementwiseBuiltin
{ {
public: public:
Exp(const std::shared_ptr<Node>& arg0) BinaryElementwiseArithmetic(
: UnaryElementwiseBuiltin(arg0) const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
{
}
virtual std::string get_op_class_name() const override { return "Exp"; }
//virtual void propagate_types() override;
};
class Greater : public BinaryElementwiseBuiltin
{
public:
Greater(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Greater"; }
//virtual void propagate_types() override;
};
class Less : public BinaryElementwiseBuiltin
{
public:
Less(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1) : BinaryElementwiseBuiltin(arg0, arg1)
{ {
} }
virtual std::string get_op_class_name() const override { return "Less"; } virtual std::string get_op_class_name() const override { return "BinaryElementwiseArithmetic"; }
//virtual void propagate_types() override;
};
class Log : public UnaryElementwiseBuiltin
{
public:
Log(const std::shared_ptr<Node>& arg0)
: UnaryElementwiseBuiltin(arg0)
{
}
virtual std::string get_op_class_name() const override { return "Log"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; virtual const element::Type& propagate_element_types(
const element::Type& arg0_element_type,
class Maximum : public BinaryElementwiseBuiltin const element::Type& arg1_element_type)
{ const final override;
public:
Maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Max"; }
//virtual void propagate_types() override;
};
class Minimum : public BinaryElementwiseBuiltin
{
public:
Minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Min"; }
//virtual void propagate_types() override;
};
class Negative : public UnaryElementwiseBuiltin
{
public:
Negative(const std::shared_ptr<Node>& arg0)
: UnaryElementwiseBuiltin(arg0)
{
}
virtual std::string get_op_class_name() const override { return "Negative"; }
//virtual void propagate_types() override;
};
class Power : public BinaryElementwiseBuiltin
{
public:
Power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Power"; }
//virtual void propagate_types() override;
};
class Remainder : public BinaryElementwiseBuiltin
{
public:
Remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Remainder"; }
//virtual void propagate_types() override;
};
class Reshape : public IndexBuiltin
{
public:
Reshape(const std::shared_ptr<Node>& arg0, const Shape& shape)
: IndexBuiltin(arg0)
, m_shape(shape)
{
}
virtual std::string get_op_class_name() const override { return "Reshape"; }
//virtual void propagate_types() override;
protected:
Shape m_shape;
}; };
} }
} }
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Abs : public UnaryElementwiseArithmetic
{
public:
Abs(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
}
virtual std::string get_op_class_name() const override { return "Abs"; }
};
}
}
...@@ -18,11 +18,11 @@ namespace ngraph ...@@ -18,11 +18,11 @@ namespace ngraph
{ {
namespace op namespace op
{ {
class Add : public BinaryElementwiseBuiltin class Add : public BinaryElementwiseArithmetic
{ {
public: public:
Add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1) : BinaryElementwiseArithmetic(arg0, arg1)
{ {
} }
virtual std::string get_op_class_name() const override { return "Add"; } virtual std::string get_op_class_name() const override { return "Add"; }
......
...@@ -18,11 +18,11 @@ namespace ngraph ...@@ -18,11 +18,11 @@ namespace ngraph
{ {
namespace op namespace op
{ {
class Ceiling : public BinaryElementwiseBuiltin class Ceiling : public UnaryElementwiseArithmetic
{ {
public: public:
Ceiling(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Ceiling(const std::shared_ptr<Node>& arg)
: BinaryElementwiseBuiltin(arg0, arg1) : UnaryElementwiseArithmetic(arg)
{ {
} }
......
...@@ -18,11 +18,11 @@ namespace ngraph ...@@ -18,11 +18,11 @@ namespace ngraph
{ {
namespace op namespace op
{ {
class Divide : public BinaryElementwiseBuiltin class Divide : public BinaryElementwiseArithmetic
{ {
public: public:
Divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1) : BinaryElementwiseArithmetic(arg0, arg1)
{ {
} }
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Equal : public BinaryElementwiseComparison
{
public:
Equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseComparison(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Equal"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Exp : public UnaryElementwiseArithmetic
{
public:
Exp(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
}
virtual std::string get_op_class_name() const override { return "Exp"; }
};
}
}
...@@ -18,11 +18,11 @@ namespace ngraph ...@@ -18,11 +18,11 @@ namespace ngraph
{ {
namespace op namespace op
{ {
class Floor : public BinaryElementwiseBuiltin class Floor : public UnaryElementwiseArithmetic
{ {
public: public:
Floor(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Floor(const std::shared_ptr<Node>& arg)
: BinaryElementwiseBuiltin(arg0, arg1) : UnaryElementwiseArithmetic(arg)
{ {
} }
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Greater : public BinaryElementwiseComparison
{
public:
Greater(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseComparison(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Greater"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Less : public BinaryElementwiseComparison
{
public:
Less(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseComparison(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Less"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Log : public UnaryElementwiseArithmetic
{
public:
Log(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
}
virtual std::string get_op_class_name() const override { return "Log"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Maximum : public BinaryElementwiseArithmetic
{
public:
Maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Maximum"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Minimum : public BinaryElementwiseArithmetic
{
public:
Minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Minimum"; }
};
}
}
...@@ -18,11 +18,11 @@ namespace ngraph ...@@ -18,11 +18,11 @@ namespace ngraph
{ {
namespace op namespace op
{ {
class Multiply : public BinaryElementwiseBuiltin class Multiply : public BinaryElementwiseArithmetic
{ {
public: public:
Multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1) : BinaryElementwiseArithmetic(arg0, arg1)
{ {
} }
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Negative : public UnaryElementwiseArithmetic
{
public:
Negative(const std::shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
}
virtual std::string get_op_class_name() const override { return "Negative"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Power : public BinaryElementwiseArithmetic
{
public:
Power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Power"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Remainder : public BinaryElementwiseArithmetic
{
public:
Remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseArithmetic(arg0, arg1)
{
}
virtual std::string get_op_class_name() const override { return "Remainder"; }
};
}
}
...@@ -18,11 +18,11 @@ namespace ngraph ...@@ -18,11 +18,11 @@ namespace ngraph
{ {
namespace op namespace op
{ {
class Subtract : public BinaryElementwiseBuiltin class Subtract : public BinaryElementwiseArithmetic
{ {
public: public:
Subtract(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Subtract(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1) : BinaryElementwiseArithmetic(arg0, arg1)
{ {
} }
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
const element::Type& BinaryElementwiseArithmetic::propagate_element_types(
const element::Type& arg0_element_type,
const element::Type& arg1_element_type) const
{
if (arg0_element_type != arg1_element_type)
{
throw ngraph_error("Arguments must have the same tensor view element type");
}
if (arg0_element_type == element::Bool::element_type())
{
throw ngraph_error("Operands for arithmetic operators must have numeric element type");
}
return arg0_element_type;
}
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "log.hpp" #include "log.hpp"
using namespace std; using namespace std;
using namespace ngraph;
using namespace ngraph::op; using namespace ngraph::op;
void BinaryElementwiseBuiltin::propagate_types() void BinaryElementwiseBuiltin::propagate_types()
...@@ -35,10 +36,16 @@ void BinaryElementwiseBuiltin::propagate_types() ...@@ -35,10 +36,16 @@ void BinaryElementwiseBuiltin::propagate_types()
{ {
throw ngraph_error("Arguments must be tensor views"); throw ngraph_error("Arguments must be tensor views");
} }
if (*arg0_tensor_type != *arg1_tensor_type) if (arg0_tensor_type->get_shape() != arg1_tensor_type->get_shape())
{ {
throw ngraph_error("Arguments must have the same tensor view type"); throw ngraph_error("Arguments must have the same tensor view shape");
} }
set_value_type_checked(arg0_tensor_type); const element::Type& result_element_type =
propagate_element_types(arg0_tensor_type->get_element_type(),
arg1_tensor_type->get_element_type());
set_value_type_checked(make_shared<TensorViewType>(result_element_type,
arg0_tensor_type->get_shape()));
} }
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
const element::Type& BinaryElementwiseComparison::propagate_element_types(
const element::Type& arg0_element_type,
const element::Type& arg1_element_type) const
{
if (arg0_element_type != arg1_element_type)
{
throw ngraph_error("Arguments must have the same tensor view element type");
}
return element::Bool::element_type();
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
const element::Type& UnaryElementwiseArithmetic::propagate_element_types(
const element::Type& arg_element_type) const
{
if (arg_element_type == element::Bool::element_type())
{
throw ngraph_error("Operands for arithmetic operators must have numeric element type");
}
return arg_element_type;
}
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph;
using namespace ngraph::op; using namespace ngraph::op;
void UnaryElementwiseBuiltin::propagate_types() void UnaryElementwiseBuiltin::propagate_types()
...@@ -33,5 +34,9 @@ void UnaryElementwiseBuiltin::propagate_types() ...@@ -33,5 +34,9 @@ void UnaryElementwiseBuiltin::propagate_types()
throw ngraph_error("Argument must be tensor view"); throw ngraph_error("Argument must be tensor view");
} }
set_value_type_checked(arg_tensor_type); const element::Type& result_element_type =
propagate_element_types(arg_tensor_type->get_element_type());
set_value_type_checked(make_shared<TensorViewType>(result_element_type,
arg_tensor_type->get_shape()));
} }
...@@ -252,7 +252,7 @@ void test_binary_bad_arguments_tuple(const shared_ptr<Node>& node) ...@@ -252,7 +252,7 @@ void test_binary_bad_arguments_tuple(const shared_ptr<Node>& node)
} }
} }
void test_binary_bad_arguments_views(const shared_ptr<Node>& node) void test_binary_bad_arguments_view_shapes(const shared_ptr<Node>& node)
{ {
try try
{ {
...@@ -262,7 +262,25 @@ void test_binary_bad_arguments_views(const shared_ptr<Node>& node) ...@@ -262,7 +262,25 @@ void test_binary_bad_arguments_views(const shared_ptr<Node>& node)
} }
catch (const ngraph_error& error) catch (const ngraph_error& error)
{ {
EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view type")); EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view shape"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
void test_binary_bad_arguments_view_element_types(const shared_ptr<Node>& node)
{
try
{
node->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected.";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view element type"));
} }
catch (...) catch (...)
{ {
...@@ -285,13 +303,16 @@ void test_binary(shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr ...@@ -285,13 +303,16 @@ void test_binary(shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4})); make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_2_4_param_1 = make_shared<op::Parameter>( auto tv0_2_4_param_1 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4})); make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_2_4_param_2 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Int32::element_type(), Shape{2, 4}));
auto tv0_4_2_param = make_shared<op::Parameter>( auto tv0_4_2_param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 2})); make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 2}));
test_binary_bad_arguments_tuple(f(tp0_param, tp1_param)); test_binary_bad_arguments_tuple(f(tp0_param, tp1_param));
test_binary_bad_arguments_tuple(f(tp0_param, tv0_2_4_param_0)); test_binary_bad_arguments_tuple(f(tp0_param, tv0_2_4_param_0));
test_binary_bad_arguments_tuple(f(tv0_2_4_param_0, tp0_param)); test_binary_bad_arguments_tuple(f(tv0_2_4_param_0, tp0_param));
test_binary_bad_arguments_views(f(tv0_2_4_param_0, tv0_4_2_param)); test_binary_bad_arguments_view_shapes(f(tv0_2_4_param_0, tv0_4_2_param));
test_binary_bad_arguments_view_element_types(f(tv0_2_4_param_0, tv0_2_4_param_2));
test_binary_good_arguments(f(tv0_2_4_param_0, tv0_2_4_param_1)); test_binary_good_arguments(f(tv0_2_4_param_0, tv0_2_4_param_1));
} }
...@@ -302,13 +323,6 @@ TEST(type_prop, add_bad_arguments) ...@@ -302,13 +323,6 @@ TEST(type_prop, add_bad_arguments)
}); });
} }
TEST(type_prop, ceiling_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Ceiling>(x, y);
});
}
TEST(type_prop, divide_bad_arguments) TEST(type_prop, divide_bad_arguments)
{ {
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> { test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
...@@ -316,13 +330,6 @@ TEST(type_prop, divide_bad_arguments) ...@@ -316,13 +330,6 @@ TEST(type_prop, divide_bad_arguments)
}); });
} }
TEST(type_prop, floor_bad_arguments)
{
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
return make_shared<op::Floor>(x, y);
});
}
TEST(type_prop, multiply_bad_arguments) TEST(type_prop, multiply_bad_arguments)
{ {
test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> { test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
...@@ -336,3 +343,59 @@ TEST(type_prop, subtract_bad_arguments) ...@@ -336,3 +343,59 @@ TEST(type_prop, subtract_bad_arguments)
return make_shared<op::Subtract>(x, y); return make_shared<op::Subtract>(x, y);
}); });
} }
TEST(type_prop, comparison_good)
{
auto tv0_2_4_param_0 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_2_4_param_1 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto eq = make_shared<op::Equal>(tv0_2_4_param_0,tv0_2_4_param_1);
auto expected_type = TensorViewType(element::Bool::element_type(), Shape{2, 4});
eq->propagate_types();
EXPECT_EQ(*eq->get_value_type(),expected_type);
}
TEST(type_prop, binary_arithmetic_bad_argument_element_types)
{
auto tv0_2_4_param_0 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Bool::element_type(), Shape{2, 4}));
auto tv0_2_4_param_1 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Bool::element_type(), Shape{2, 4}));
auto bc = make_shared<op::Add>(tv0_2_4_param_0,tv0_2_4_param_1);
try
{
bc->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Operands for arithmetic operators must have numeric element type"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, unary_arithmetic_bad_argument_element_types)
{
auto tv0_2_4_param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Bool::element_type(), Shape{2, 4}));
auto bc = make_shared<op::Negative>(tv0_2_4_param);
try
{
bc->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Operands for arithmetic operators must have numeric element type"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment