Commit 1138d1b9 authored by Adam Procter's avatar Adam Procter Committed by GitHub

MNIST elementwise ops (#137)

* Add elementwise ops to runtime

* equals operator implementation at graph and runtime level

* "Select" operator (elementwise ?:) at graph and runtime level

* A few file renames (`notequal` -> `not_equal`, `lessthan` -> `less_than`)

* Macro wrapping boilerplate for op registration in `external_function.cpp`
  (this covers everything except `op::Parameter`).
parent 8021a00e
......@@ -29,6 +29,7 @@ set (SRC
ops/dot.cpp
ops/op.cpp
ops/parameter.cpp
ops/select.cpp
ops/tuple.cpp
ops/unary_elementwise_arithmetic.cpp
ops/unary_elementwise_builtin.cpp
......@@ -47,7 +48,6 @@ set (SRC
runtime/call_frame.cpp
runtime/external_function.cpp
shape.cpp
shape.cpp
types/element_type.cpp
types/type.cpp
util.cpp
......
......@@ -48,10 +48,12 @@
#include "ngraph/ops/minimum.hpp"
#include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/negative.hpp"
#include "ngraph/ops/not_equal.hpp"
#include "ngraph/ops/op.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/power.hpp"
#include "ngraph/ops/remainder.hpp"
#include "ngraph/ops/select.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/runtime/external_function.hpp"
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class NotEqual : public BinaryElementwiseComparison
{
public:
NotEqual(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BinaryElementwiseComparison(arg0, arg1)
{
}
virtual std::string description() const override { return "NotEqual"; }
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/ngraph.hpp"
#include "ngraph/log.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
// Type propagation for the elementwise Select (ternary ?:) operator.
//
// Checks, in order:
//   - exactly three arguments;
//   - all three arguments are tensor views;
//   - argument 0 (the selector) has boolean element type;
//   - all three arguments share the same shape;
//   - arguments 1 and 2 have identical tensor view types.
// The result type is the (shared) type of arguments 1 and 2.
void Select::propagate_types()
{
    if (m_arguments.size() != 3)
    {
        throw ngraph_error("Wrong number of arguments.");
    }

    auto arg0_tensor_type =
        dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type());
    auto arg1_tensor_type =
        dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->get_value_type());
    auto arg2_tensor_type =
        dynamic_pointer_cast<TensorViewType>(m_arguments.at(2)->get_value_type());
    if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type || nullptr == arg2_tensor_type)
    {
        throw ngraph_error("Arguments must be tensor views");
    }

    // BUG FIX: the message previously said "arithmetic operators" — a
    // copy-paste from the arithmetic type checker. This check belongs to
    // select's boolean selector argument.
    if (arg0_tensor_type->get_element_type() != element::Bool::element_type())
    {
        throw ngraph_error("Argument 0 of select must have boolean element type");
    }
    if (arg0_tensor_type->get_shape() != arg1_tensor_type->get_shape()
        || arg0_tensor_type->get_shape() != arg2_tensor_type->get_shape())
    {
        throw ngraph_error("Arguments must have the same tensor view shape");
    }
    if (*arg1_tensor_type != *arg2_tensor_type)
    {
        throw ngraph_error("Arguments 1 and 2 must have the same tensor view type");
    }

    set_value_type_checked(arg1_tensor_type);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace op
{
class Select : public Builtin
{
public:
Select(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
const std::shared_ptr<Node>& arg2)
: Builtin(Nodes{arg0, arg1, arg2})
{
}
virtual std::string description() const override { return "Select"; }
virtual void propagate_types() override;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
    namespace runtime
    {
        namespace eigen
        {
            /// Elementwise absolute value: out <- |arg|.
            /// T is a pointer-like handle to a tensor view; &* yields the
            /// raw pointer expected by get_map/set_map.
            template <typename T>
            void abs(T arg, T out)
            {
                set_map(&*out, Eigen::abs(get_map(&*arg)));
            }

            /// Instruction: reads tensor slot m_arg from the call frame and
            /// writes its elementwise absolute value into slot m_out.
            template <typename ET>
            class AbsInstruction : public Instruction
            {
            public:
                AbsInstruction(size_t arg, size_t out)
                    : m_arg(arg)
                    , m_out(out)
                {
                }

                virtual void execute(CallFrame& call_frame) const override
                {
                    auto input  = call_frame.get_parameterized_tensor<ET>(m_arg);
                    auto output = call_frame.get_parameterized_tensor<ET>(m_out);
                    runtime::eigen::abs(input, output);
                }

            protected:
                size_t m_arg;
                size_t m_out;
            };
        }
    }
}
......@@ -26,15 +26,9 @@ namespace ngraph
namespace eigen
{
template <typename T>
void add(T* arg0, T* arg1, T* out)
void add(T arg0, T arg1, T out)
{
set_map(out, get_map(arg0) + get_map(arg1));
}
template <typename T>
void add(std::shared_ptr<T>& arg0, std::shared_ptr<T>& arg1, std::shared_ptr<T>& out)
{
add(&*arg0, &*arg1, &*out);
set_map(&*out, get_map(&*arg0) + get_map(&*arg1));
}
template <typename ET>
......@@ -50,7 +44,8 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
add(call_frame.get_parameterized_tensor<ET>(m_arg0),
runtime::eigen::add(
call_frame.get_parameterized_tensor<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_out));
}
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
    namespace runtime
    {
        namespace eigen
        {
            /// Elementwise division: out <- arg0 / arg1.
            /// T is a pointer-like handle to a tensor view.
            template <typename T>
            void divide(T arg0, T arg1, T out)
            {
                set_map(&*out, get_map(&*arg0) / get_map(&*arg1));
            }

            /// Instruction: slot m_out <- slot m_arg0 / slot m_arg1,
            /// elementwise.
            template <typename ET>
            class DivideInstruction : public Instruction
            {
            public:
                DivideInstruction(size_t arg0, size_t arg1, size_t out)
                    : m_arg0(arg0)
                    , m_arg1(arg1)
                    , m_out(out)
                {
                }

                virtual void execute(CallFrame& call_frame) const override
                {
                    auto numerator   = call_frame.get_parameterized_tensor<ET>(m_arg0);
                    auto denominator = call_frame.get_parameterized_tensor<ET>(m_arg1);
                    auto quotient    = call_frame.get_parameterized_tensor<ET>(m_out);
                    runtime::eigen::divide(numerator, denominator, quotient);
                }

            protected:
                size_t m_arg0;
                size_t m_arg1;
                size_t m_out;
            };
        }
    }
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
    namespace runtime
    {
        namespace eigen
        {
            /// Elementwise equality: out <- (arg0 == arg1).
            /// TI is the input tensor handle type, TO the (boolean) output
            /// handle type; results are stored as char, the storage type of
            /// element::Bool.
            template <typename TI, typename TO>
            void equal(TI arg0, TI arg1, TO out)
            {
                auto mask = get_map(&*arg0) == get_map(&*arg1);
                set_map(&*out, mask.template cast<char>());
            }

            /// Instruction: boolean slot m_out <- (slot m_arg0 == slot m_arg1).
            template <typename ET>
            class EqualInstruction : public Instruction
            {
            public:
                EqualInstruction(size_t arg0, size_t arg1, size_t out)
                    : m_arg0(arg0)
                    , m_arg1(arg1)
                    , m_out(out)
                {
                }

                virtual void execute(CallFrame& call_frame) const override
                {
                    auto lhs    = call_frame.get_parameterized_tensor<ET>(m_arg0);
                    auto rhs    = call_frame.get_parameterized_tensor<ET>(m_arg1);
                    auto result = call_frame.get_parameterized_tensor<element::Bool>(m_out);
                    runtime::eigen::equal(lhs, rhs, result);
                }

            protected:
                size_t m_arg0;
                size_t m_arg1;
                size_t m_out;
            };
        }
    }
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
    namespace runtime
    {
        namespace eigen
        {
            /// Elementwise comparison: out <- (arg0 < arg1).
            /// Results are stored as char, the storage type of element::Bool.
            template <typename TI, typename TO>
            void less_than(TI arg0, TI arg1, TO out)
            {
                auto mask = get_map(&*arg0) < get_map(&*arg1);
                set_map(&*out, mask.template cast<char>());
            }

            /// Instruction: boolean slot m_out <- (slot m_arg0 < slot m_arg1).
            template <typename ET>
            class LessThanInstruction : public Instruction
            {
            public:
                LessThanInstruction(size_t arg0, size_t arg1, size_t out)
                    : m_arg0(arg0)
                    , m_arg1(arg1)
                    , m_out(out)
                {
                }

                virtual void execute(CallFrame& call_frame) const override
                {
                    auto lhs    = call_frame.get_parameterized_tensor<ET>(m_arg0);
                    auto rhs    = call_frame.get_parameterized_tensor<ET>(m_arg1);
                    auto result = call_frame.get_parameterized_tensor<element::Bool>(m_out);
                    runtime::eigen::less_than(lhs, rhs, result);
                }

            protected:
                size_t m_arg0;
                size_t m_arg1;
                size_t m_out;
            };
        }
    }
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
    namespace runtime
    {
        namespace eigen
        {
            /// Elementwise natural logarithm: out <- ln(arg).
            template <typename T>
            void log(T arg, T out)
            {
                set_map(&*out, Eigen::log(get_map(&*arg)));
            }

            /// Instruction: slot m_out <- ln(slot m_arg), elementwise.
            template <typename ET>
            class LogInstruction : public Instruction
            {
            public:
                LogInstruction(size_t arg, size_t out)
                    : m_arg(arg)
                    , m_out(out)
                {
                }

                virtual void execute(CallFrame& call_frame) const override
                {
                    auto input  = call_frame.get_parameterized_tensor<ET>(m_arg);
                    auto output = call_frame.get_parameterized_tensor<ET>(m_out);
                    runtime::eigen::log(input, output);
                }

            protected:
                size_t m_arg;
                size_t m_out;
            };
        }
    }
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
    namespace runtime
    {
        namespace eigen
        {
            /// Elementwise maximum: out <- max(arg0, arg1).
            /// T is a pointer-like handle to a tensor view; &* yields the
            /// raw pointer expected by get_map/set_map.
            template <typename T>
            void maximum(T arg0, T arg1, T out)
            {
                // BUG FIX: 'out' was previously passed to set_map() without
                // the &* dereference, unlike every other elementwise kernel
                // in this directory; set_map expects the raw tensor-view
                // pointer, while 'out' is a pointer-like handle.
                set_map(&*out, get_map(&*arg0).max(get_map(&*arg1)));
            }

            /// Instruction: slot m_out <- max(slot m_arg0, slot m_arg1),
            /// elementwise.
            template <typename ET>
            class MaximumInstruction : public Instruction
            {
            public:
                MaximumInstruction(size_t arg0, size_t arg1, size_t out)
                    : m_arg0(arg0)
                    , m_arg1(arg1)
                    , m_out(out)
                {
                }

                virtual void execute(CallFrame& call_frame) const override
                {
                    runtime::eigen::maximum(
                        call_frame.get_parameterized_tensor<ET>(m_arg0),
                        call_frame.get_parameterized_tensor<ET>(m_arg1),
                        call_frame.get_parameterized_tensor<ET>(m_out));
                }

            protected:
                size_t m_arg0;
                size_t m_arg1;
                size_t m_out;
            };
        }
    }
}
......@@ -25,17 +25,9 @@ namespace ngraph
namespace eigen
{
template <typename T>
void multiply(T* arg0, T* arg1, T* out)
void multiply(T arg0, T arg1, T out)
{
set_map(out, get_map(arg0) * get_map(arg1));
}
template <typename T>
void multiply(std::shared_ptr<T>& arg0,
std::shared_ptr<T>& arg1,
std::shared_ptr<T>& out)
{
multiply(&*arg0, &*arg1, &*out);
set_map(&*out, get_map(&*arg0) * get_map(&*arg1));
}
template <typename ET>
......@@ -51,7 +43,8 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
multiply(call_frame.get_parameterized_tensor<ET>(m_arg0),
runtime::eigen::multiply(
call_frame.get_parameterized_tensor<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_out));
}
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
    namespace runtime
    {
        namespace eigen
        {
            /// Elementwise negation: out <- -arg.
            template <typename T>
            void negate(T arg, T out)
            {
                set_map(&*out, -(get_map(&*arg)));
            }

            /// Instruction: slot m_out <- -(slot m_arg), elementwise.
            template <typename ET>
            class NegateInstruction : public Instruction
            {
            public:
                NegateInstruction(size_t arg, size_t out)
                    : m_arg(arg)
                    , m_out(out)
                {
                }

                virtual void execute(CallFrame& call_frame) const override
                {
                    auto input  = call_frame.get_parameterized_tensor<ET>(m_arg);
                    auto output = call_frame.get_parameterized_tensor<ET>(m_out);
                    runtime::eigen::negate(input, output);
                }

            protected:
                size_t m_arg;
                size_t m_out;
            };
        }
    }
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
    namespace runtime
    {
        namespace eigen
        {
            /// Elementwise inequality: out <- (arg0 != arg1).
            /// Results are stored as char, the storage type of element::Bool.
            template <typename TI, typename TO>
            void not_equal(TI arg0, TI arg1, TO out)
            {
                auto mask = get_map(&*arg0) != get_map(&*arg1);
                set_map(&*out, mask.template cast<char>());
            }

            /// Instruction: boolean slot m_out <- (slot m_arg0 != slot m_arg1).
            template <typename ET>
            class NotEqualInstruction : public Instruction
            {
            public:
                NotEqualInstruction(size_t arg0, size_t arg1, size_t out)
                    : m_arg0(arg0)
                    , m_arg1(arg1)
                    , m_out(out)
                {
                }

                virtual void execute(CallFrame& call_frame) const override
                {
                    auto lhs    = call_frame.get_parameterized_tensor<ET>(m_arg0);
                    auto rhs    = call_frame.get_parameterized_tensor<ET>(m_arg1);
                    auto result = call_frame.get_parameterized_tensor<element::Bool>(m_out);
                    runtime::eigen::not_equal(lhs, rhs, result);
                }

            protected:
                size_t m_arg0;
                size_t m_arg1;
                size_t m_out;
            };
        }
    }
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
    namespace runtime
    {
        namespace eigen
        {
            /// Elementwise selection: out[i] <- arg0[i] ? arg1[i] : arg2[i].
            /// TA is the boolean selector handle type, TB the value tensor
            /// handle type.
            template <typename TA, typename TB>
            void select(TA arg0, TB arg1, TB arg2, TB out)
            {
                set_map(&*out, get_map(&*arg0).select(get_map(&*arg1), get_map(&*arg2)));
            }

            /// Instruction: slot m_out picks elements from slot m_arg1 where
            /// the boolean slot m_arg0 is true, otherwise from slot m_arg2.
            template <typename ET>
            class SelectInstruction : public Instruction
            {
            public:
                SelectInstruction(size_t arg0, size_t arg1, size_t arg2, size_t out)
                    : m_arg0(arg0)
                    , m_arg1(arg1)
                    , m_arg2(arg2)
                    , m_out(out)
                {
                }

                virtual void execute(CallFrame& call_frame) const override
                {
                    auto selector  = call_frame.get_parameterized_tensor<element::Bool>(m_arg0);
                    auto when_true = call_frame.get_parameterized_tensor<ET>(m_arg1);
                    auto when_false = call_frame.get_parameterized_tensor<ET>(m_arg2);
                    auto result    = call_frame.get_parameterized_tensor<ET>(m_out);
                    runtime::eigen::select(selector, when_true, when_false, result);
                }

            protected:
                size_t m_arg0;
                size_t m_arg1;
                size_t m_arg2;
                size_t m_out;
            };
        }
    }
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
    namespace runtime
    {
        namespace eigen
        {
            /// Elementwise subtraction: out <- arg0 - arg1.
            template <typename T>
            void subtract(T arg0, T arg1, T out)
            {
                set_map(&*out, get_map(&*arg0) - get_map(&*arg1));
            }

            /// Instruction: slot m_out <- slot m_arg0 - slot m_arg1,
            /// elementwise.
            template <typename ET>
            class SubtractInstruction : public Instruction
            {
            public:
                SubtractInstruction(size_t arg0, size_t arg1, size_t out)
                    : m_arg0(arg0)
                    , m_arg1(arg1)
                    , m_out(out)
                {
                }

                virtual void execute(CallFrame& call_frame) const override
                {
                    auto minuend    = call_frame.get_parameterized_tensor<ET>(m_arg0);
                    auto subtrahend = call_frame.get_parameterized_tensor<ET>(m_arg1);
                    auto difference = call_frame.get_parameterized_tensor<ET>(m_out);
                    runtime::eigen::subtract(minuend, subtrahend, difference);
                }

            protected:
                size_t m_arg0;
                size_t m_arg1;
                size_t m_out;
            };
        }
    }
}
......@@ -22,14 +22,34 @@
#include "ngraph/descriptor/output.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/ops/abs.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/divide.hpp"
#include "ngraph/ops/equal.hpp"
#include "ngraph/ops/less.hpp"
#include "ngraph/ops/log.hpp"
#include "ngraph/ops/maximum.hpp"
#include "ngraph/ops/multiply.hpp"
#include "ngraph/ops/negative.hpp"
#include "ngraph/ops/not_equal.hpp"
#include "ngraph/ops/select.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/runtime/eigen/add.hpp"
#include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/eigen/abs.hpp"
#include "ngraph/runtime/eigen/add.hpp"
#include "ngraph/runtime/eigen/divide.hpp"
#include "ngraph/runtime/eigen/equal.hpp"
#include "ngraph/runtime/eigen/less_than.hpp"
#include "ngraph/runtime/eigen/log.hpp"
#include "ngraph/runtime/eigen/maximum.hpp"
#include "ngraph/runtime/eigen/multiply.hpp"
#include "ngraph/runtime/eigen/negate.hpp"
#include "ngraph/runtime/eigen/not_equal.hpp"
#include "ngraph/runtime/eigen/return.hpp"
#include "ngraph/runtime/eigen/select.hpp"
#include "ngraph/runtime/eigen/subtract.hpp"
#include "ngraph/runtime/utils.hpp"
using namespace std;
......@@ -44,6 +64,22 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
{
}
#define REGISTER_INSTRUCTION(op_class,instr_class,...) \
op_map[type_index(typeid(op_class))] = [](Node * n, \
ExternalFunction* ef, \
const std::vector<size_t>& in, \
const std::vector<size_t>& out) { \
ef->get_instructions()->push_back( \
make_shared<instr_class>(__VA_ARGS__)); \
}
#define REGISTER_UNOP(op_class,instr_class) \
REGISTER_INSTRUCTION(op_class,instr_class,in[0],out[0])
#define REGISTER_BINOP(op_class,instr_class) \
REGISTER_INSTRUCTION(op_class,instr_class,in[0],in[1],out[0])
#define REGISTER_TERNOP(op_class,instr_class) \
REGISTER_INSTRUCTION(op_class,instr_class,in[0],in[1],in[2],out[0])
// Define code generators for handled ops.
std::unordered_map<std::type_index,
std::function<void(ngraph::Node*,
......@@ -61,24 +97,20 @@ std::unordered_map<std::type_index,
op_map;
if (!initialized)
{
op_map[type_index(typeid(op::Add))] = [](Node* n,
ExternalFunction* ef,
const std::vector<size_t>& in,
const std::vector<size_t>& out) {
ef->get_instructions()->push_back(
make_shared<runtime::eigen::AddInstruction<element::Float32>>(
in[0], in[1], out[0]));
};
op_map[type_index(typeid(op::Multiply))] = [](Node* n,
ExternalFunction* ef,
const std::vector<size_t>& in,
const std::vector<size_t>& out) {
ef->get_instructions()->push_back(
make_shared<runtime::eigen::MultiplyInstruction<element::Float32>>(
in[0], in[1], out[0]));
};
REGISTER_UNOP (op::Abs, runtime::eigen::AbsInstruction<element::Float32>);
REGISTER_BINOP (op::Add, runtime::eigen::AddInstruction<element::Float32>);
REGISTER_BINOP (op::Divide, runtime::eigen::DivideInstruction<element::Float32>);
REGISTER_BINOP (op::Equal, runtime::eigen::EqualInstruction<element::Float32>);
REGISTER_BINOP (op::Less, runtime::eigen::LessThanInstruction<element::Float32>);
REGISTER_UNOP (op::Log, runtime::eigen::LogInstruction<element::Float32>);
REGISTER_BINOP (op::Maximum, runtime::eigen::MaximumInstruction<element::Float32>);
REGISTER_BINOP (op::Multiply,runtime::eigen::MultiplyInstruction<element::Float32>);
REGISTER_UNOP (op::Negative,runtime::eigen::NegateInstruction<element::Float32>);
REGISTER_BINOP (op::NotEqual,runtime::eigen::NotEqualInstruction<element::Float32>);
REGISTER_TERNOP(op::Select, runtime::eigen::SelectInstruction<element::Float32>);
REGISTER_BINOP (op::Subtract,runtime::eigen::SubtractInstruction<element::Float32>);
// Parameter, as a "runtime no-op", is a special case.
op_map[type_index(typeid(op::Parameter))] = [](Node* n,
ExternalFunction* ef,
const std::vector<size_t>& in,
......
......@@ -104,8 +104,8 @@ namespace ngraph
}
};
NGRAPH_DEFINE_TRAITED_TYPE_NAME(bool)
using Bool = TraitedType<bool>;
NGRAPH_DEFINE_TRAITED_TYPE_NAME(char)
using Bool = TraitedType<char>;
NGRAPH_DEFINE_TRAITED_TYPE_NAME(float)
using Float32 = TraitedType<float>;
......
......@@ -48,3 +48,208 @@ TEST(execute, test_abc)
(*cf)({a, c, b}, {result});
ASSERT_EQ((vector<float>{50, 72, 98, 128}), result->get_vector());
}
// Abs over a 2x2 float tensor: negatives flip sign, zero is unchanged.
TEST(execute, test_abs)
{
    auto shape = Shape{2, 2};
    auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto f = make_shared<Function>(make_shared<op::Abs>(A), op::Parameters{A});

    auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
    auto cf = external->make_call_frame();

    // Input and output tensors.
    auto input = ngraph::runtime::make_tensor<element::Float32>(shape);
    *input = vector<float>{1, -2, 0, -4.8f};
    auto output = ngraph::runtime::make_tensor<element::Float32>(shape);

    (*cf)({input}, {output});
    ASSERT_EQ((vector<float>{1, 2, 0, 4.8f}), output->get_vector());
}
// Elementwise divide on 2x2 floats; every quotient comes out to 2.
TEST(execute, test_divide)
{
    auto shape = Shape{2, 2};
    auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto B = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto f = make_shared<Function>(make_shared<op::Divide>(A, B), op::Parameters{A, B});

    auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
    auto cf = external->make_call_frame();

    // Numerator, denominator, and result tensors.
    auto numer = ngraph::runtime::make_tensor<element::Float32>(shape);
    *numer = vector<float>{2, 4, 8, 16};
    auto denom = ngraph::runtime::make_tensor<element::Float32>(shape);
    *denom = vector<float>{1, 2, 4, 8};
    auto quotient = ngraph::runtime::make_tensor<element::Float32>(shape);

    (*cf)({numer, denom}, {quotient});
    ASSERT_EQ((vector<float>{2, 2, 2, 2}), quotient->get_vector());
}
// Elementwise equality on 2x2x2 floats; result is a boolean (char) tensor.
TEST(execute, test_equal)
{
    auto shape = Shape{2, 2, 2};
    auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto B = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto f = make_shared<Function>(make_shared<op::Equal>(A, B), op::Parameters{A, B});

    auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
    auto cf = external->make_call_frame();

    // Inputs agree at positions 0, 1, 5, 6 only.
    auto lhs = ngraph::runtime::make_tensor<element::Float32>(shape);
    *lhs = vector<float>{1, 8, -8, 17, -0.5, 0, 1, 1};
    auto rhs = ngraph::runtime::make_tensor<element::Float32>(shape);
    *rhs = vector<float>{1, 8, 4, 8, 0, 0, 1, 1.5};
    auto mask = ngraph::runtime::make_tensor<element::Bool>(shape);

    (*cf)({lhs, rhs}, {mask});
    ASSERT_EQ((vector<char>{1, 1, 0, 0, 0, 1, 1, 0}), mask->get_vector());
}
// Elementwise less-than on 2x2x2 floats; result is a boolean (char) tensor.
TEST(execute, test_lessthan)
{
    auto shape = Shape{2, 2, 2};
    auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto B = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto f = make_shared<Function>(make_shared<op::Less>(A, B), op::Parameters{A, B});

    auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
    auto cf = external->make_call_frame();

    // lhs < rhs at positions 2, 4, 7 only.
    auto lhs = ngraph::runtime::make_tensor<element::Float32>(shape);
    *lhs = vector<float>{1, 8, -8, 17, -0.5, 0.5, 2, 1};
    auto rhs = ngraph::runtime::make_tensor<element::Float32>(shape);
    *rhs = vector<float>{1, 2, 4, 8, 0, 0, 1, 1.5};
    auto mask = ngraph::runtime::make_tensor<element::Bool>(shape);

    (*cf)({lhs, rhs}, {mask});
    ASSERT_EQ((vector<char>{0, 0, 1, 0, 1, 0, 0, 1}), mask->get_vector());
}
// Natural log of e^1..e^8 recovers 1..8 elementwise.
TEST(execute, test_log)
{
    auto shape = Shape{2, 2, 2};
    auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto f = make_shared<Function>(make_shared<op::Log>(A), op::Parameters{A});

    auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
    auto cf = external->make_call_frame();

    // Input: exponentials whose logs are the integers 1..8.
    auto input = ngraph::runtime::make_tensor<element::Float32>(shape);
    *input = vector<float>{expf(1), expf(2), expf(3), expf(4), expf(5), expf(6), expf(7), expf(8)};
    auto output = ngraph::runtime::make_tensor<element::Float32>(shape);

    (*cf)({input}, {output});
    ASSERT_EQ((vector<float>{1, 2, 3, 4, 5, 6, 7, 8}), output->get_vector());
}
// Elementwise maximum of two 2x2x2 float tensors.
TEST(execute, test_maximum)
{
    auto shape = Shape{2, 2, 2};
    auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto B = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto f = make_shared<Function>(make_shared<op::Maximum>(A, B), op::Parameters{A, B});

    auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
    auto cf = external->make_call_frame();

    // Inputs mix negatives, zeros, and fractions.
    auto lhs = ngraph::runtime::make_tensor<element::Float32>(shape);
    *lhs = vector<float>{1, 8, -8, 17, -0.5, 0.5, 2, 1};
    auto rhs = ngraph::runtime::make_tensor<element::Float32>(shape);
    *rhs = vector<float>{1, 2, 4, 8, 0, 0, 1, 1.5};
    auto output = ngraph::runtime::make_tensor<element::Float32>(shape);

    (*cf)({lhs, rhs}, {output});
    ASSERT_EQ((vector<float>{1, 8, 4, 17, 0, 0.5, 2, 1.5}), output->get_vector());
}
// Elementwise negation of a 2x3 float tensor; zero stays zero.
TEST(execute, test_negative)
{
    auto shape = Shape{2, 3};
    auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto f = make_shared<Function>(make_shared<op::Negative>(A), op::Parameters{A});

    auto external = make_shared<ngraph::runtime::ExternalFunction>(f);
    auto cf = external->make_call_frame();

    // Input and output tensors.
    auto input = ngraph::runtime::make_tensor<element::Float32>(shape);
    *input = vector<float>{1, -2, 0, -4.8f, 8.6f, -8.6f};
    auto output = ngraph::runtime::make_tensor<element::Float32>(shape);

    (*cf)({input}, {output});
    ASSERT_EQ((vector<float>{-1, 2, 0, 4.8f, -8.6f, 8.6f}), output->get_vector());
}
TEST(execute, test_notequal)
{
    // Elementwise != on two float tensors; the result is a boolean (char) tensor.
    auto shape = Shape{2, 2, 2};
    auto lhs = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto rhs = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto fn = make_shared<Function>(make_shared<op::NotEqual>(lhs, rhs), op::Parameters{lhs, rhs});
    auto ext = make_shared<ngraph::runtime::ExternalFunction>(fn);
    auto frame = ext->make_call_frame();

    // Input/output tensors.
    auto arg0 = ngraph::runtime::make_tensor<element::Float32>(shape);
    *arg0 = vector<float>{1, 8, -8, 17, -0.5, 0, 1, 1};
    auto arg1 = ngraph::runtime::make_tensor<element::Float32>(shape);
    *arg1 = vector<float>{1, 8, 4, 8, 0, 0, 1, 1.5};
    auto out = ngraph::runtime::make_tensor<element::Bool>(shape);

    (*frame)({arg0, arg1}, {out});
    ASSERT_EQ((vector<char>{0, 0, 1, 1, 1, 0, 0, 1}), out->get_vector());
}
TEST(execute, test_select)
{
    // Elementwise select (ternary ?:): boolean mask picks from the second or third input.
    auto shape = Shape{2, 2, 2};
    auto mask = make_shared<op::Parameter>(element::Bool::element_type(), shape);
    auto on_true = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto on_false = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto fn = make_shared<Function>(make_shared<op::Select>(mask, on_true, on_false),
                                    op::Parameters{mask, on_true, on_false});
    auto ext = make_shared<ngraph::runtime::ExternalFunction>(fn);
    auto frame = ext->make_call_frame();

    // Input/output tensors: where the mask is 1 take from arg1, otherwise from arg2.
    auto arg0 = ngraph::runtime::make_tensor<element::Bool>(shape);
    *arg0 = vector<char>{0, 1, 1, 0, 0, 1, 0, 1};
    auto arg1 = ngraph::runtime::make_tensor<element::Float32>(shape);
    *arg1 = vector<float>{1, 2, 3, 4, 5, 6, 7, 8};
    auto arg2 = ngraph::runtime::make_tensor<element::Float32>(shape);
    *arg2 = vector<float>{11, 12, 13, 14, 15, 16, 17, 18};
    auto out = ngraph::runtime::make_tensor<element::Float32>(shape);

    (*frame)({arg0, arg1, arg2}, {out});
    ASSERT_EQ((vector<float>{11, 2, 3, 14, 15, 6, 17, 8}), out->get_vector());
}
TEST(execute, test_subtract)
{
    // Elementwise subtraction on a 2x2 tensor: {2,4,8,16} - {1,2,4,8}.
    auto shape = Shape{2, 2};
    auto lhs = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto rhs = make_shared<op::Parameter>(element::Float32::element_type(), shape);
    auto fn = make_shared<Function>(make_shared<op::Subtract>(lhs, rhs), op::Parameters{lhs, rhs});
    auto ext = make_shared<ngraph::runtime::ExternalFunction>(fn);
    auto frame = ext->make_call_frame();

    // Input/output tensors.
    auto arg0 = ngraph::runtime::make_tensor<element::Float32>(shape);
    *arg0 = vector<float>{2, 4, 8, 16};
    auto arg1 = ngraph::runtime::make_tensor<element::Float32>(shape);
    *arg1 = vector<float>{1, 2, 4, 8};
    auto out = ngraph::runtime::make_tensor<element::Float32>(shape);

    (*frame)({arg0, arg1}, {out});
    ASSERT_EQ((vector<float>{1, 2, 4, 8}), out->get_vector());
}
......@@ -399,3 +399,159 @@ TEST(type_prop, unary_arithmetic_bad_argument_element_types)
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, select_deduce)
{
    // Select(bool, f32, f32) with matching 2x4 shapes should deduce a 2x4 f32 result.
    auto cond_param = make_shared<op::Parameter>(
        make_shared<TensorViewType>(element::Bool::element_type(), Shape{2, 4}));
    auto then_param = make_shared<op::Parameter>(
        make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
    auto else_param = make_shared<op::Parameter>(
        make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
    auto sel = make_shared<op::Select>(cond_param, then_param, else_param);
    sel->propagate_types();
    auto deduced_type = sel->get_value_type();
    ASSERT_EQ(*deduced_type, TensorViewType(element::Float32::element_type(), Shape{2, 4}));
}
TEST(type_prop, select_deduce_correct)
{
    // As select_deduce, but with the value type pre-set: propagation must accept
    // the (correct) pre-assigned type and leave it unchanged.
    auto cond_param = make_shared<op::Parameter>(
        make_shared<TensorViewType>(element::Bool::element_type(), Shape{2, 4}));
    auto then_param = make_shared<op::Parameter>(
        make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
    auto else_param = make_shared<op::Parameter>(
        make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
    auto sel = make_shared<op::Select>(cond_param, then_param, else_param);
    sel->set_value_type(
        make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
    sel->propagate_types();
    auto deduced_type = sel->get_value_type();
    ASSERT_EQ(*deduced_type, TensorViewType(element::Float32::element_type(), Shape{2, 4}));
}
TEST(type_prop, select_shape_mismatch_a)
{
    // The selector (argument 0) has shape 3x5 while the value arguments are 2x4;
    // type propagation must reject the shape mismatch.
    auto cond_param = make_shared<op::Parameter>(
        make_shared<TensorViewType>(element::Bool::element_type(), Shape{3, 5}));
    auto then_param = make_shared<op::Parameter>(
        make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
    auto else_param = make_shared<op::Parameter>(
        make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
    auto sel = make_shared<op::Select>(cond_param, then_param, else_param);
    try
    {
        sel->propagate_types();
        // Should have thrown, so fail if it didn't.
        // (Fixed copy-pasted message: this test checks shapes, not element types.)
        FAIL() << "Did not detect mismatched shapes for select operator";
    }
    catch (const ngraph_error& error)
    {
        EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view shape"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
TEST(type_prop, select_shape_mismatch_b)
{
    // The "true" value input (argument 1) has shape 3x5 while the others are 2x4;
    // type propagation must reject the shape mismatch.
    auto cond_param = make_shared<op::Parameter>(
        make_shared<TensorViewType>(element::Bool::element_type(), Shape{2, 4}));
    auto then_param = make_shared<op::Parameter>(
        make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 5}));
    auto else_param = make_shared<op::Parameter>(
        make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
    auto sel = make_shared<op::Select>(cond_param, then_param, else_param);
    try
    {
        sel->propagate_types();
        // Should have thrown, so fail if it didn't.
        // (Fixed copy-pasted message: this test checks shapes, not element types.)
        FAIL() << "Did not detect mismatched shapes for select operator";
    }
    catch (const ngraph_error& error)
    {
        EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view shape"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
TEST(type_prop, select_shape_mismatch_c)
{
    // The "false" value input (argument 2) has shape 3x5 while the others are 2x4;
    // type propagation must reject the shape mismatch.
    auto cond_param = make_shared<op::Parameter>(
        make_shared<TensorViewType>(element::Bool::element_type(), Shape{2, 4}));
    auto then_param = make_shared<op::Parameter>(
        make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
    auto else_param = make_shared<op::Parameter>(
        make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 5}));
    auto sel = make_shared<op::Select>(cond_param, then_param, else_param);
    try
    {
        sel->propagate_types();
        // Should have thrown, so fail if it didn't.
        // (Fixed copy-pasted message: this test checks shapes, not element types.)
        FAIL() << "Did not detect mismatched shapes for select operator";
    }
    catch (const ngraph_error& error)
    {
        EXPECT_EQ(error.what(), std::string("Arguments must have the same tensor view shape"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
TEST(type_prop, select_elem_mismatch_a)
{
    // The selector (argument 0) is Float32 rather than Bool; type propagation
    // must reject the non-boolean selector element type.
    auto cond_param = make_shared<op::Parameter>(
        make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
    auto then_param = make_shared<op::Parameter>(
        make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
    auto else_param = make_shared<op::Parameter>(
        make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
    auto sel = make_shared<op::Select>(cond_param, then_param, else_param);
    try
    {
        sel->propagate_types();
        // Should have thrown, so fail if it didn't.
        FAIL() << "Did not detect incorrect element types for arithmetic operator";
    }
    catch (const ngraph_error& error)
    {
        EXPECT_EQ(error.what(), std::string("Argument 0 for arithmetic operators must have boolean element type"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
TEST(type_prop, select_elem_mismatch_bc)
{
    // The two value inputs disagree on element type (Float32 vs Int32);
    // type propagation must reject the mismatch between arguments 1 and 2.
    auto cond_param = make_shared<op::Parameter>(
        make_shared<TensorViewType>(element::Bool::element_type(), Shape{2, 4}));
    auto then_param = make_shared<op::Parameter>(
        make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
    auto else_param = make_shared<op::Parameter>(
        make_shared<TensorViewType>(element::Int32::element_type(), Shape{2, 4}));
    auto sel = make_shared<op::Select>(cond_param, then_param, else_param);
    try
    {
        sel->propagate_types();
        // Should have thrown, so fail if it didn't.
        FAIL() << "Did not detect incorrect element types for arithmetic operator";
    }
    catch (const ngraph_error& error)
    {
        EXPECT_EQ(error.what(), std::string("Arguments 1 and 2 must have the same tensor view type"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment