Unverified Commit f7a22b95 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Merge pull request #278 from NervanaSystems/aprocter/de-eigenize-partial

Work so far on de-Eigenization
parents 2f0a33c3 18998c41
......@@ -14,8 +14,9 @@
set (SRC
autodiff/adjoints.cpp
builder/autobroadcast.cpp
builder/reduce_ops.cpp
builder/numpy_transpose.cpp
builder/reduce_ops.cpp
coordinate_iterator.cpp
descriptor/input.cpp
descriptor/layout/dense_tensor_view_layout.cpp
descriptor/layout/tensor_view_layout.cpp
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <cstdio>
#include <iostream>
#include <vector>
#include "ngraph/common.hpp"
#include "ngraph/coordinate_iterator.hpp"
#include "ngraph/except.hpp"
using namespace ngraph;
CoordinateIterator::CoordinateIterator(const Shape& space_shape,
const Strides& strides,
const Coordinate& window_outer_corner,
const Coordinate& window_inner_corner)
: m_space_shape(space_shape)
, m_strides(strides)
, m_window_outer_corner(window_outer_corner)
, m_window_inner_corner(window_inner_corner)
, m_current_coordinate(window_inner_corner)
{
if (space_shape.size() != window_inner_corner.size())
{
throw ngraph_error("Coordinate iterator inner corner rank does not match space shape rank");
}
if (space_shape.size() != window_outer_corner.size())
{
throw ngraph_error("Coordinate iterator outer corner rank does not match space shape rank");
}
if (space_shape.size() != strides.size())
{
throw ngraph_error("Coordinate iterator stride rank does not match space shape rank");
}
for (size_t i = 0; i < space_shape.size(); i++)
{
if (window_inner_corner[i] > window_outer_corner[i])
{
throw ngraph_error("Coordinate iterator inner corner is outside outer corner");
}
if (window_inner_corner[i] >= m_space_shape[i])
{
throw ngraph_error("Coordinate iterator inner corner is out of bounds");
}
if (window_outer_corner[i] > m_space_shape[i])
{
throw ngraph_error("Coordinate iterator outer corner is out of bounds");
}
if (m_strides[i] == 0)
{
throw ngraph_error("Coordinate iterator stride is zero");
}
}
}
CoordinateIterator::CoordinateIterator(const Shape& space_shape)
: CoordinateIterator(space_shape,
Strides(space_shape.size(), 1),
space_shape,
Coordinate(space_shape.size(), 0))
{
}
CoordinateIterator::CoordinateIterator(const Shape& space_shape, const Strides& strides)
: CoordinateIterator(space_shape, strides, space_shape, Coordinate(space_shape.size(), 0))
{
}
size_t CoordinateIterator::get_current_index() const
{
size_t index = 0;
size_t stride = 1;
for (size_t i = m_space_shape.size(); i-- > 0;)
{
index += m_current_coordinate[i] * stride;
stride *= m_space_shape[i];
}
return index;
}
bool CoordinateIterator::increment()
{
bool overflow = true;
for (size_t i = m_space_shape.size(); i-- > 0;)
{
m_current_coordinate[i] += m_strides[i];
if (m_current_coordinate[i] >= m_window_outer_corner[i])
{
m_current_coordinate[i] = m_window_inner_corner[i];
}
else
{
overflow = false;
break;
}
}
return !overflow;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cstdio>
#include <iostream>
#include <vector>
#include "ngraph/common.hpp"
namespace ngraph
{
class CoordinateIterator
{
public:
CoordinateIterator(const Shape& space_shape,
const Strides& strides,
const Coordinate& window_outer_corner,
const Coordinate& window_inner_corner);
CoordinateIterator(const Shape& space_shape);
CoordinateIterator(const Shape& space_shape, const Strides& strides);
Coordinate get_current_coordinate() const { return m_current_coordinate; }
size_t get_current_index() const;
bool increment();
private:
const Shape m_space_shape;
const Strides m_strides;
const Coordinate m_window_outer_corner;
const Coordinate m_window_inner_corner;
Coordinate m_current_coordinate;
};
}
......@@ -45,6 +45,7 @@
#include "ngraph/builder/numpy_transpose.hpp"
#include "ngraph/builder/reduce_ops.hpp"
#include "ngraph/common.hpp"
#include "ngraph/coordinate_iterator.hpp"
#include "ngraph/descriptor/buffer.hpp"
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void abs(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
// TODO: generic "abs" doesn't work here for some reason.
out[i] = (arg[i] < 0 ? -arg[i] : arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void acos(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::acos(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void add(T* arg0, T* arg1, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] + arg1[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void asin(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::asin(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void atan(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::atan(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
#include "ngraph/common.hpp"
#include "ngraph/coordinate_iterator.hpp"
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void broadcast(T* arg,
T* out,
const Shape& in_shape,
const Shape& out_shape,
const AxisSet& broadcast_axes)
{
// For the outer loop we will walk over the entire input shape.
CoordinateIterator arg_iter(in_shape);
do
{
// For the inner loop we will walk across the entire axis for the new broadcast axes, and stay put at the current arg position for the existing axes.
Coordinate arg_coordinate = arg_iter.get_current_coordinate();
Strides out_strides(out_shape.size(), 1);
Coordinate out_outer_corner(out_shape.size());
Coordinate out_inner_corner(out_shape.size());
size_t arg_pos = 0;
for (size_t i = 0; i < out_shape.size(); i++)
{
if (broadcast_axes.find(i) == broadcast_axes.end())
{
// This is an existing axis.
out_outer_corner[i] = arg_coordinate[arg_pos];
out_inner_corner[i] = arg_coordinate[arg_pos];
arg_pos++;
}
else
{
// This is a new broadcast axis.
out_outer_corner[i] = out_shape[i];
out_inner_corner[i] = 0;
}
}
CoordinateIterator out_iter(
out_shape, out_strides, out_outer_corner, out_inner_corner);
do
{
out[out_iter.get_current_index()] = arg[arg_iter.get_current_index()];
} while (out_iter.increment());
} while (arg_iter.increment());
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void ceiling(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::ceil(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename TI, typename TO>
void convert(TI* arg, TO* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = static_cast<TO>(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void copy(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void cos(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::cos(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void cosh(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::cosh(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void divide(T* arg0, T* arg1, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] / arg1[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wfloat-equal"
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void equal(T* arg0,
T* arg1,
char* out,
size_t count) // TODO: using char for bool, is this right?
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] == arg1[i];
}
}
}
}
}
#pragma clang diagnostic pop
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void exp(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::exp(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void floor(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::floor(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void greater(T* arg0,
T* arg1,
char* out,
size_t count) // TODO: using char for bool, is this right?
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] > arg1[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void greater_eq(T* arg0,
T* arg1,
char* out,
size_t count) // TODO: using char for bool, is this right?
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] >= arg1[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void less(T* arg0,
T* arg1,
char* out,
size_t count) // TODO: using char for bool, is this right?
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] < arg1[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void less_eq(T* arg0,
T* arg1,
char* out,
size_t count) // TODO: using char for bool, is this right?
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] <= arg1[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void log(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::log(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void maximum(T* arg0, T* arg1, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] > arg1[i] ? arg0[i] : arg1[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void minimum(T* arg0, T* arg1, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] < arg1[i] ? arg0[i] : arg1[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void multiply(T* arg0, T* arg1, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] * arg1[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void negate(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = -arg[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
void logical_not(char* arg,
char* out,
size_t count) // TODO: using char for bool, is this right?
{
for (size_t i = 0; i < count; i++)
{
out[i] = !(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wfloat-equal"
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void not_equal(T* arg0,
T* arg1,
char* out,
size_t count) // TODO: using char for bool, is this right?
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] != arg1[i];
}
}
}
}
}
#pragma clang diagnostic pop
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void power(T* arg0, T* arg1, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::pow(arg0[i], arg1[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void select(char* arg0,
T* arg1,
T* arg2,
T* out,
size_t count) // TODO: using char for bool, is this right?
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] ? arg1[i] : arg2[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void sign(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = (arg[i] < 0 ? -1 : (arg[i] > 0 ? 1 : 0));
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void sin(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::sin(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void sinh(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::sinh(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void sqrt(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::sqrt(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void subtract(T* arg0, T* arg1, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] - arg1[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void tan(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::tan(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void tanh(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::tanh(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace eigen
{
template <typename ET>
class BroadcastVectorRowwiseInstruction : public Instruction
{
public:
BroadcastVectorRowwiseInstruction(const TensorViewInfo& arg,
const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenMatrix<ET>(call_frame, m_out).rowwise() =
EigenVector<ET>(call_frame, m_arg).transpose();
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace eigen
{
template <typename ET>
class GreaterThanInstruction : public Instruction
{
public:
GreaterThanInstruction(TensorViewInfo arg0,
TensorViewInfo arg1,
TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) >
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace eigen
{
template <typename ET>
class LessThanInstruction : public Instruction
{
public:
LessThanInstruction(TensorViewInfo arg0,
TensorViewInfo arg1,
TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) <
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -70,64 +70,64 @@
#include "ngraph/ops/tuple.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/runtime/ngvm/eigen/abs.hpp"
#include "ngraph/runtime/ngvm/eigen/acos.hpp"
#include "ngraph/runtime/ngvm/eigen/add.hpp"
#include "ngraph/runtime/ngvm/eigen/asin.hpp"
#include "ngraph/runtime/ngvm/eigen/atan.hpp"
#include "ngraph/runtime/ngvm/eigen/broadcast_scalar.hpp"
#include "ngraph/runtime/ngvm/eigen/broadcast_vector_colwise.hpp"
#include "ngraph/runtime/ngvm/eigen/broadcast_vector_rowwise.hpp"
#include "ngraph/runtime/ngvm/eigen/call.hpp"
#include "ngraph/runtime/ngvm/eigen/ceiling.hpp"
#include "ngraph/runtime/ngvm/eigen/concat_matrix.hpp"
#include "ngraph/runtime/ngvm/eigen/concat_vector.hpp"
#include "ngraph/runtime/ngvm/eigen/constant.hpp"
#include "ngraph/runtime/ngvm/eigen/convert.hpp"
#include "ngraph/runtime/ngvm/eigen/copy.hpp"
#include "ngraph/runtime/ngvm/eigen/cos.hpp"
#include "ngraph/runtime/ngvm/eigen/cosh.hpp"
#include "ngraph/runtime/ngvm/eigen/divide.hpp"
#include "ngraph/runtime/ngvm/eigen/dot.hpp"
#include "ngraph/runtime/ngvm/eigen/equal.hpp"
#include "ngraph/runtime/ngvm/eigen/exp.hpp"
#include "ngraph/runtime/ngvm/eigen/floor.hpp"
#include "ngraph/runtime/ngvm/eigen/greater_eq.hpp"
#include "ngraph/runtime/ngvm/eigen/greater_than.hpp"
#include "ngraph/runtime/ngvm/eigen/less_eq.hpp"
#include "ngraph/runtime/ngvm/eigen/less_than.hpp"
#include "ngraph/runtime/ngvm/eigen/log.hpp"
#include "ngraph/runtime/ngvm/eigen/matrix_mult.hpp"
#include "ngraph/runtime/ngvm/eigen/matrix_slice.hpp"
#include "ngraph/runtime/ngvm/eigen/matrix_transpose.hpp"
#include "ngraph/runtime/ngvm/eigen/matrix_vector_product.hpp"
#include "ngraph/runtime/ngvm/eigen/maximum.hpp"
#include "ngraph/runtime/ngvm/eigen/minimum.hpp"
#include "ngraph/runtime/ngvm/eigen/multiply.hpp"
#include "ngraph/runtime/ngvm/eigen/negate.hpp"
#include "ngraph/runtime/ngvm/eigen/not.hpp"
#include "ngraph/runtime/ngvm/eigen/not_equal.hpp"
#include "ngraph/runtime/ngvm/eigen/power.hpp"
#include "ngraph/runtime/ngvm/eigen/reduce_matrix_columns.hpp"
#include "ngraph/runtime/ngvm/eigen/reduce_matrix_rows.hpp"
#include "ngraph/runtime/ngvm/eigen/reduce_to_scalar.hpp"
#include "ngraph/runtime/ngvm/eigen/replace_matrix_slice.hpp"
#include "ngraph/runtime/ngvm/eigen/replace_vector_slice.hpp"
#include "ngraph/runtime/ngvm/eigen/return.hpp"
#include "ngraph/runtime/ngvm/eigen/scalar_tensor_product.hpp"
#include "ngraph/runtime/ngvm/eigen/select.hpp"
#include "ngraph/runtime/ngvm/eigen/sign.hpp"
#include "ngraph/runtime/ngvm/eigen/sin.hpp"
#include "ngraph/runtime/ngvm/eigen/sinh.hpp"
#include "ngraph/runtime/ngvm/eigen/sqrt.hpp"
#include "ngraph/runtime/ngvm/eigen/subtract.hpp"
#include "ngraph/runtime/ngvm/eigen/sum_matrix_columns.hpp"
#include "ngraph/runtime/ngvm/eigen/sum_matrix_rows.hpp"
#include "ngraph/runtime/ngvm/eigen/sum_to_scalar.hpp"
#include "ngraph/runtime/ngvm/eigen/tan.hpp"
#include "ngraph/runtime/ngvm/eigen/tanh.hpp"
#include "ngraph/runtime/ngvm/eigen/vector_slice.hpp"
#include "ngraph/runtime/ngvm/external_function.hpp"
#include "ngraph/runtime/ngvm/instruction/abs.hpp"
#include "ngraph/runtime/ngvm/instruction/acos.hpp"
#include "ngraph/runtime/ngvm/instruction/add.hpp"
#include "ngraph/runtime/ngvm/instruction/asin.hpp"
#include "ngraph/runtime/ngvm/instruction/atan.hpp"
#include "ngraph/runtime/ngvm/instruction/broadcast.hpp"
#include "ngraph/runtime/ngvm/instruction/call.hpp"
#include "ngraph/runtime/ngvm/instruction/ceiling.hpp"
#include "ngraph/runtime/ngvm/instruction/constant.hpp"
#include "ngraph/runtime/ngvm/instruction/convert.hpp"
#include "ngraph/runtime/ngvm/instruction/copy.hpp"
#include "ngraph/runtime/ngvm/instruction/copy_by_index.hpp"
#include "ngraph/runtime/ngvm/instruction/cos.hpp"
#include "ngraph/runtime/ngvm/instruction/cosh.hpp"
#include "ngraph/runtime/ngvm/instruction/divide.hpp"
#include "ngraph/runtime/ngvm/instruction/equal.hpp"
#include "ngraph/runtime/ngvm/instruction/exp.hpp"
#include "ngraph/runtime/ngvm/instruction/floor.hpp"
#include "ngraph/runtime/ngvm/instruction/greater.hpp"
#include "ngraph/runtime/ngvm/instruction/greater_eq.hpp"
#include "ngraph/runtime/ngvm/instruction/less.hpp"
#include "ngraph/runtime/ngvm/instruction/less_eq.hpp"
#include "ngraph/runtime/ngvm/instruction/log.hpp"
#include "ngraph/runtime/ngvm/instruction/maximum.hpp"
#include "ngraph/runtime/ngvm/instruction/minimum.hpp"
#include "ngraph/runtime/ngvm/instruction/multiply.hpp"
#include "ngraph/runtime/ngvm/instruction/negate.hpp"
#include "ngraph/runtime/ngvm/instruction/not.hpp"
#include "ngraph/runtime/ngvm/instruction/not_equal.hpp"
#include "ngraph/runtime/ngvm/instruction/power.hpp"
#include "ngraph/runtime/ngvm/instruction/return.hpp"
#include "ngraph/runtime/ngvm/instruction/select.hpp"
#include "ngraph/runtime/ngvm/instruction/sign.hpp"
#include "ngraph/runtime/ngvm/instruction/sin.hpp"
#include "ngraph/runtime/ngvm/instruction/sinh.hpp"
#include "ngraph/runtime/ngvm/instruction/sqrt.hpp"
#include "ngraph/runtime/ngvm/instruction/subtract.hpp"
#include "ngraph/runtime/ngvm/instruction/tan.hpp"
#include "ngraph/runtime/ngvm/instruction/tanh.hpp"
#include "ngraph/runtime/utils.hpp"
#include "ngraph/util.hpp"
......@@ -230,51 +230,12 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
} \
}
#define DO_ON_SIGNED_NUMERIC_TYPE(et, err_msg, macro, ...) \
{ \
if (et == element::Float32::element_type()) \
{ \
macro(element::Float32, ##__VA_ARGS__); \
} \
else if (et == element::Int8::element_type()) \
{ \
macro(element::Int8, ##__VA_ARGS__); \
} \
else if (et == element::Int32::element_type()) \
{ \
macro(element::Int32, ##__VA_ARGS__); \
} \
else if (et == element::Int64::element_type()) \
{ \
macro(element::Int64, ##__VA_ARGS__); \
} \
else \
{ \
throw ngraph_error(err_msg); \
} \
}
#define REGISTER_INSTRUCTION(op_class, instr_class, ...) \
REGISTER_TO_OP_MAP(op_class) \
{ \
ef->get_instructions()->push_back(make_shared<instr_class>(__VA_ARGS__)); \
}
#define M_REGISTER_SIGNED_NUMERIC_UNOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], out[0]));
#define REGISTER_SIGNED_NUMERIC_UNOP(op_class, instr_class) \
REGISTER_TO_OP_MAP(op_class) \
{ \
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>( \
n->get_arguments().at(0)->get_value_type())) \
->get_element_type(); \
DO_ON_SIGNED_NUMERIC_TYPE( \
et, \
"Internal error: signed numeric unop has unhandled element type", \
M_REGISTER_SIGNED_NUMERIC_UNOP, \
instr_class); \
}
#define M_REGISTER_NUMERIC_UNOP(T, instr_class) \
ef->get_instructions()->push_back(make_shared<instr_class<T>>(in[0], out[0]));
#define REGISTER_NUMERIC_UNOP(op_class, instr_class) \
......@@ -363,7 +324,7 @@ std::vector<typename ET::type>
{ \
REGISTER_INSTRUCTION( \
op::ParameterizedConstant<T>, \
eigen::ConstantInstruction<T>, \
instruction::ConstantInstruction<T>, \
std::vector<T::type>{ \
get_vector<T>(dynamic_cast<const op::ParameterizedConstant<T>*>(n)->get_value())}, \
out[0]); \
......@@ -388,32 +349,31 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
static OpMap op_map;
if (!initialized)
{
REGISTER_NUMERIC_UNOP(op::Acos, eigen::AcosInstruction);
REGISTER_NUMERIC_UNOP(op::Asin, eigen::AsinInstruction);
REGISTER_NUMERIC_UNOP(op::Atan, eigen::AtanInstruction);
REGISTER_NUMERIC_UNOP(op::Ceiling, eigen::CeilingInstruction);
REGISTER_NUMERIC_UNOP(op::Cos, eigen::CosInstruction);
REGISTER_NUMERIC_UNOP(op::Cosh, eigen::CoshInstruction);
REGISTER_NUMERIC_UNOP(op::Exp, eigen::ExpInstruction);
REGISTER_NUMERIC_UNOP(op::Floor, eigen::FloorInstruction);
REGISTER_NUMERIC_UNOP(op::Log, eigen::LogInstruction);
REGISTER_NUMERIC_UNOP(op::Negative, eigen::NegateInstruction);
REGISTER_NUMERIC_UNOP(op::Sign, eigen::SignInstruction);
REGISTER_NUMERIC_UNOP(op::Sin, eigen::SinInstruction);
REGISTER_NUMERIC_UNOP(op::Sinh, eigen::SinhInstruction);
REGISTER_NUMERIC_UNOP(op::Sqrt, eigen::SqrtInstruction);
REGISTER_NUMERIC_UNOP(op::Tan, eigen::TanInstruction);
REGISTER_NUMERIC_UNOP(op::Tanh, eigen::TanhInstruction);
REGISTER_SIGNED_NUMERIC_UNOP(op::Abs, eigen::AbsInstruction);
REGISTER_NUMERIC_BINOP(op::Add, eigen::AddInstruction);
REGISTER_NUMERIC_BINOP(op::Divide, eigen::DivideInstruction);
REGISTER_NUMERIC_BINOP(op::Maximum, eigen::MaximumInstruction);
REGISTER_NUMERIC_BINOP(op::Minimum, eigen::MinimumInstruction);
REGISTER_NUMERIC_BINOP(op::Multiply, eigen::MultiplyInstruction);
REGISTER_NUMERIC_BINOP(op::Power, eigen::PowerInstruction);
REGISTER_NUMERIC_BINOP(op::Subtract, eigen::SubtractInstruction);
REGISTER_NUMERIC_UNOP(op::Abs, instruction::AbsInstruction);
REGISTER_NUMERIC_UNOP(op::Acos, instruction::AcosInstruction);
REGISTER_NUMERIC_UNOP(op::Asin, instruction::AsinInstruction);
REGISTER_NUMERIC_UNOP(op::Atan, instruction::AtanInstruction);
REGISTER_NUMERIC_UNOP(op::Ceiling, instruction::CeilingInstruction);
REGISTER_NUMERIC_UNOP(op::Cos, instruction::CosInstruction);
REGISTER_NUMERIC_UNOP(op::Cosh, instruction::CoshInstruction);
REGISTER_NUMERIC_UNOP(op::Exp, instruction::ExpInstruction);
REGISTER_NUMERIC_UNOP(op::Floor, instruction::FloorInstruction);
REGISTER_NUMERIC_UNOP(op::Log, instruction::LogInstruction);
REGISTER_NUMERIC_UNOP(op::Negative, instruction::NegateInstruction);
REGISTER_NUMERIC_UNOP(op::Sign, instruction::SignInstruction);
REGISTER_NUMERIC_UNOP(op::Sin, instruction::SinInstruction);
REGISTER_NUMERIC_UNOP(op::Sinh, instruction::SinhInstruction);
REGISTER_NUMERIC_UNOP(op::Sqrt, instruction::SqrtInstruction);
REGISTER_NUMERIC_UNOP(op::Tan, instruction::TanInstruction);
REGISTER_NUMERIC_UNOP(op::Tanh, instruction::TanhInstruction);
REGISTER_NUMERIC_BINOP(op::Add, instruction::AddInstruction);
REGISTER_NUMERIC_BINOP(op::Divide, instruction::DivideInstruction);
REGISTER_NUMERIC_BINOP(op::Maximum, instruction::MaximumInstruction);
REGISTER_NUMERIC_BINOP(op::Minimum, instruction::MinimumInstruction);
REGISTER_NUMERIC_BINOP(op::Multiply, instruction::MultiplyInstruction);
REGISTER_NUMERIC_BINOP(op::Power, instruction::PowerInstruction);
REGISTER_NUMERIC_BINOP(op::Subtract, instruction::SubtractInstruction);
REGISTER_TO_OP_MAP(op::Constant)
{
......@@ -424,7 +384,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
auto c_value_strings = c->get_value_strings();
#define M_REGISTER_POLYMORPHIC_CONSTANT(ET) \
ef->get_instructions()->push_back(make_shared<eigen::ConstantInstruction<ET>>( \
ef->get_instructions()->push_back(make_shared<instruction::ConstantInstruction<ET>>( \
parse_string<typename ET::type>(c_value_strings), out[0]));
DO_ON_ELEMENT_TYPE(c_element_type,
......@@ -432,14 +392,16 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
M_REGISTER_POLYMORPHIC_CONSTANT);
};
REGISTER_POLYMORPHIC_BINOP(op::Equal, eigen::EqualInstruction);
REGISTER_POLYMORPHIC_BINOP(op::NotEqual, eigen::NotEqualInstruction);
REGISTER_POLYMORPHIC_BINOP(op::Greater, eigen::GreaterThanInstruction);
REGISTER_POLYMORPHIC_BINOP(op::GreaterEq, eigen::GreaterEqInstruction);
REGISTER_POLYMORPHIC_BINOP(op::Less, eigen::LessThanInstruction);
REGISTER_POLYMORPHIC_BINOP(op::LessEq, eigen::LessEqInstruction);
REGISTER_POLYMORPHIC_BINOP(op::Equal, instruction::EqualInstruction);
REGISTER_POLYMORPHIC_BINOP(op::NotEqual, instruction::NotEqualInstruction);
REGISTER_POLYMORPHIC_BINOP(op::Greater, instruction::GreaterInstruction);
REGISTER_POLYMORPHIC_BINOP(op::GreaterEq, instruction::GreaterEqInstruction);
REGISTER_POLYMORPHIC_BINOP(op::Less, instruction::LessInstruction);
REGISTER_POLYMORPHIC_BINOP(op::LessEq, instruction::LessEqInstruction);
REGISTER_POLYMORPHIC_TERNOP(op::Select, eigen::SelectInstruction);
REGISTER_LOGICAL_UNOP(op::Not, instruction::NotInstruction);
REGISTER_POLYMORPHIC_TERNOP(op::Select, instruction::SelectInstruction);
REGISTER_CONSTANT_INSTRUCTIONS(element::Bool);
REGISTER_CONSTANT_INSTRUCTIONS(element::Float32);
......@@ -450,8 +412,6 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
REGISTER_CONSTANT_INSTRUCTIONS(element::UInt32);
REGISTER_CONSTANT_INSTRUCTIONS(element::UInt64);
REGISTER_LOGICAL_UNOP(op::Not, eigen::NotInstruction);
REGISTER_TO_OP_MAP(op::Broadcast)
{
auto broadcast = static_cast<const op::Broadcast*>(n);
......@@ -459,60 +419,22 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
auto arg_tensor_type = dynamic_pointer_cast<const TensorViewType>(
n->get_arguments().at(0)->get_value_type());
assert(nullptr != arg_tensor_type);
auto arg_shape = arg_tensor_type->get_shape();
auto result_tensor_type =
dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
assert(nullptr != result_tensor_type);
auto arg_shape = arg_tensor_type->get_shape();
auto result_shape = result_tensor_type->get_shape();
auto& result_element_type = result_tensor_type->get_element_type();
if (broadcast->get_broadcast_axes().empty())
{
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Broadcast has unhandled element type",
eigen::CopyInstruction,
in[0].get_index(),
out[0].get_index());
}
else if (arg_shape.size() == 0)
{
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Broadcast has unhandled element type",
eigen::BroadcastScalarInstruction,
in[0],
out[0]);
}
else if (arg_shape.size() == 1 && result_shape.size() == 2)
{
if (broadcast->get_broadcast_axes() == AxisSet{1})
{
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Broadcast has unhandled element type",
eigen::BroadcastVectorColwiseInstruction,
in[0],
out[0]);
}
else if (broadcast->get_broadcast_axes() == AxisSet{0})
{
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Broadcast has unhandled element type",
eigen::BroadcastVectorRowwiseInstruction,
in[0],
out[0]);
}
else
{
throw ngraph_error(
"Internal error: axis set for vector-matrix broadcast is neither {0} nor "
"{1}");
}
}
else
{
throw ngraph_error("Broadcast not implemented for rank>2 in VM yet");
}
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Broadcast has unhandled element type",
instruction::BroadcastInstruction,
in[0],
out[0],
arg_shape,
result_shape,
broadcast->get_broadcast_axes());
};
REGISTER_TO_OP_MAP(op::Concat)
......@@ -571,7 +493,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
result_element_type == (TO::element_type())) \
{ \
ef->get_instructions()->push_back( \
make_shared<eigen::ConvertInstruction<TI, TO>>(in[0], out[0])); \
make_shared<instruction::ConvertInstruction<TI, TO>>(in[0], out[0])); \
}
// End hacky macro
......@@ -697,9 +619,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"GetTupleElement has unhandled element type",
eigen::CopyInstruction,
in.at(get_tuple_element->get_n()).get_index(),
out.at(0).get_index());
instruction::CopyInstruction,
in[get_tuple_element->get_n()],
out[0]);
};
// Tuple will be spliced out, with the users of out connected to the corresponding in's source, but, for now, we need to copy.
......@@ -710,9 +632,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
auto& et = in.at(i).get_tensor_view_layout()->get_element_type();
PUSH_POLYMORPHIC_INSTRUCTION(et,
"Tuple has unhandled element type",
eigen::CopyInstruction,
in.at(i).get_index(),
out.at(i).get_index());
instruction::CopyInstruction,
in[i],
out[i]);
}
};
......@@ -734,7 +656,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
}
ef->get_instructions()->push_back(
make_shared<eigen::CallInstruction>(external, in, out));
make_shared<instruction::CallInstruction>(external, in, out));
};
REGISTER_TO_OP_MAP(op::Reduce)
......@@ -778,9 +700,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{
PUSH_POLYMORPHIC_INSTRUCTION(f_result_element_type,
"Reduce has unhandled element type",
runtime::ngvm::eigen::CopyInstruction,
in.at(0).get_index(),
out.at(0).get_index());
runtime::ngvm::instruction::CopyInstruction,
in[0],
out[0]);
}
// Behavior for zero-size axes bears some explanation here. XLA's reduce
// operator provides an "base" element (usually, but not necessarily,
......@@ -815,9 +737,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{
PUSH_POLYMORPHIC_INSTRUCTION(f_result_element_type,
"Reduce has unhandled element type",
runtime::ngvm::eigen::CopyInstruction,
in.at(1).get_index(),
out.at(0).get_index());
runtime::ngvm::instruction::CopyInstruction,
in[1],
out[0]);
}
else
{
......@@ -902,9 +824,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{
PUSH_POLYMORPHIC_INSTRUCTION(s_element_type,
"Sum has unhandled element type",
runtime::ngvm::eigen::CopyInstruction,
in.at(0).get_index(),
out.at(0).get_index());
runtime::ngvm::instruction::CopyInstruction,
in[0],
out[0]);
}
// Full reduction? Then sum to scalar.
else if ((arg_rank == 1 && reduction_axes == AxisSet{0}) ||
......@@ -969,9 +891,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Reshape has unhandled element type",
runtime::ngvm::eigen::CopyInstruction,
in.at(0).get_index(),
out.at(0).get_index());
runtime::ngvm::instruction::CopyInstruction,
in[0],
out[0]);
}
// If there *is* a layout change in the 2D case, we transpose the input.
else if (arg_rank == 2)
......@@ -1018,9 +940,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{
PUSH_POLYMORPHIC_INSTRUCTION(arg_element_type,
"Slice has unhandled element type",
runtime::ngvm::eigen::CopyInstruction,
in.at(0).get_index(),
out.at(0).get_index());
runtime::ngvm::instruction::CopyInstruction,
in[0],
out[0]);
}
else if (arg_rank == 1)
{
......@@ -1085,9 +1007,9 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{
PUSH_POLYMORPHIC_INSTRUCTION(arg0_element_type,
"Replace-slice has unhandled element type",
runtime::ngvm::eigen::CopyInstruction,
in.at(1).get_index(),
out.at(0).get_index());
runtime::ngvm::instruction::CopyInstruction,
in[1],
out[0]);
}
else if (arg0_rank == 1)
{
......@@ -1182,9 +1104,10 @@ void ExternalFunction::compile(FunctionMap& function_map)
assert(nullptr != result_tensor_type);
auto& result_element_type = result_tensor_type->get_element_type();
auto ef = this;
// TODO: This is the one case where we can't use the new CopyInstruction that takes in a TensorViewInfo. (At least, I can't figure out how to do it.)
PUSH_POLYMORPHIC_INSTRUCTION(result_element_type,
"Copy has unhandled element type",
eigen::CopyInstruction,
instruction::CopyByIndexInstruction,
prev_index_it->second,
index);
}
......@@ -1240,7 +1163,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
}
m_instructions->insert(
m_instructions->end(), input_output_copies.begin(), input_output_copies.end());
m_instructions->push_back(make_shared<eigen::ReturnInstruction>());
m_instructions->push_back(make_shared<instruction::ReturnInstruction>());
m_is_compiled = true;
if (m_release_function)
{
......
......@@ -14,11 +14,11 @@
#pragma once
#include "ngraph/runtime/kernel/abs.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph
{
......@@ -26,7 +26,7 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class AbsInstruction : public Instruction
......@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) =
Eigen::abs(EigenArray1d<ET>(call_frame, m_arg));
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::abs<typename ET::type>(arg, out, count);
}
protected:
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
......@@ -15,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/acos.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -26,13 +26,13 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class AcosInstruction : public Instruction
{
public:
AcosInstruction(TensorViewInfo arg, TensorViewInfo out)
AcosInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
......@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET, fmt::V>(call_frame, m_out) =
EigenArray1d<ET, fmt::V>(call_frame, m_arg).acos();
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::acos<typename ET::type>(arg, out, count);
}
protected:
......
......@@ -14,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/add.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -25,7 +26,7 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class AddInstruction : public Instruction
......@@ -42,8 +43,13 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_arg0) +
EigenArray1d<ET>(call_frame, m_arg1);
typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::add<typename ET::type>(arg0, arg1, out, count);
}
protected:
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
......@@ -15,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/asin.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -26,13 +26,13 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class AsinInstruction : public Instruction
{
public:
AsinInstruction(TensorViewInfo arg, TensorViewInfo out)
AsinInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
......@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET, fmt::V>(call_frame, m_out) =
EigenArray1d<ET, fmt::V>(call_frame, m_arg).asin();
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::asin<typename ET::type>(arg, out, count);
}
protected:
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
......@@ -15,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/atan.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -26,13 +26,13 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class AtanInstruction : public Instruction
{
public:
AtanInstruction(TensorViewInfo arg, TensorViewInfo out)
AtanInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
......@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET, fmt::V>(call_frame, m_out) =
EigenArray1d<ET, fmt::V>(call_frame, m_arg).atan();
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::atan<typename ET::type>(arg, out, count);
}
protected:
......
......@@ -14,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/broadcast.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -24,30 +25,41 @@ namespace ngraph
namespace runtime
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class BroadcastVectorColwiseInstruction : public Instruction
class BroadcastInstruction : public Instruction
{
public:
BroadcastVectorColwiseInstruction(const TensorViewInfo& arg,
const TensorViewInfo& out)
BroadcastInstruction(const TensorViewInfo& arg,
const TensorViewInfo& out,
const Shape& arg_shape,
const Shape& out_shape,
const AxisSet& broadcast_axes)
: m_arg(arg)
, m_out(out)
, m_arg_shape(arg_shape)
, m_out_shape(out_shape)
, m_broadcast_axes(broadcast_axes)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenMatrix<ET>(call_frame, m_out).colwise() =
EigenVector<ET>(call_frame, m_arg);
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
kernel::broadcast<typename ET::type>(
arg, out, m_arg_shape, m_out_shape, m_broadcast_axes);
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
Shape m_arg_shape;
Shape m_out_shape;
AxisSet m_broadcast_axes;
};
}
}
......
......@@ -17,7 +17,6 @@
#include <memory>
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/external_function.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
......@@ -28,7 +27,7 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
class CallInstruction : public Instruction
{
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
......@@ -15,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/ceiling.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -26,13 +26,13 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class CeilingInstruction : public Instruction
{
public:
CeilingInstruction(TensorViewInfo arg, TensorViewInfo out)
CeilingInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
......@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET, fmt::V>(call_frame, m_out) =
EigenArray1d<ET, fmt::V>(call_frame, m_arg).ceil();
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::ceiling<typename ET::type>(arg, out, count);
}
protected:
......
......@@ -15,7 +15,6 @@
#pragma once
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
......@@ -26,7 +25,7 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class ConstantInstruction : public Instruction
......
......@@ -14,11 +14,11 @@
#pragma once
#include "ngraph/runtime/kernel/convert.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph
{
......@@ -26,7 +26,7 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ETI, typename ETO>
class ConvertInstruction : public Instruction
......@@ -40,9 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ETO>(call_frame, m_out) =
EigenArray1d<ETI>(call_frame, m_arg)
.template cast<typename ETO::type>();
typename ETI::type* arg = get_tensor_data_ptr<ETI>(call_frame, m_arg);
typename ETO::type* out = get_tensor_data_ptr<ETO>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::convert<typename ETI::type, typename ETO::type>(arg, out, count);
}
protected:
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/kernel/copy.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace instruction
{
template <typename ET>
class CopyInstruction : public Instruction
{
public:
CopyInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::copy<typename ET::type>(arg, out, count);
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -17,7 +17,6 @@
#include <cassert>
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
......@@ -27,16 +26,16 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
/// @brief Copies a tensor from in to out.
template <typename ET>
class CopyInstruction : public Instruction
class CopyByIndexInstruction : public Instruction
{
public:
/// @param in Index of input tensor in call frame.
/// @param out Index of output tensor in call frame.
CopyInstruction(size_t in, size_t out)
CopyByIndexInstruction(size_t in, size_t out)
: m_in(in)
, m_out(out)
{
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
......@@ -15,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/cos.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -26,13 +26,13 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class CosInstruction : public Instruction
{
public:
CosInstruction(TensorViewInfo arg, TensorViewInfo out)
CosInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
......@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET, fmt::V>(call_frame, m_out) =
EigenArray1d<ET, fmt::V>(call_frame, m_arg).cos();
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::cos<typename ET::type>(arg, out, count);
}
protected:
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
......@@ -15,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/cosh.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -26,13 +26,13 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class CoshInstruction : public Instruction
{
public:
CoshInstruction(TensorViewInfo arg, TensorViewInfo out)
CoshInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
......@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET, fmt::V>(call_frame, m_out) =
EigenArray1d<ET, fmt::V>(call_frame, m_arg).cosh();
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::cosh<typename ET::type>(arg, out, count);
}
protected:
......
......@@ -14,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/divide.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -25,7 +26,7 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class DivideInstruction : public Instruction
......@@ -42,8 +43,13 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_arg0) /
EigenArray1d<ET>(call_frame, m_arg1);
typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::divide<typename ET::type>(arg0, arg1, out, count);
}
protected:
......
......@@ -14,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/equal.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -25,13 +26,15 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class EqualInstruction : public Instruction
{
public:
EqualInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out)
EqualInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
......@@ -40,10 +43,14 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) ==
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
char* out = get_tensor_data_ptr<element::Bool>(
call_frame, m_out); // FIXME: temporarily char not bool
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::equal<typename ET::type>(arg0, arg1, out, count);
}
protected:
......
......@@ -14,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/exp.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -25,13 +26,13 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class ExpInstruction : public Instruction
{
public:
ExpInstruction(TensorViewInfo arg, TensorViewInfo out)
ExpInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
......@@ -39,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET, fmt::V>(call_frame, m_out) =
EigenArray1d<ET, fmt::V>(call_frame, m_arg).exp();
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::exp<typename ET::type>(arg, out, count);
}
protected:
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
......@@ -15,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/floor.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -26,13 +26,13 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class FloorInstruction : public Instruction
{
public:
FloorInstruction(TensorViewInfo arg, TensorViewInfo out)
FloorInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
......@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET, fmt::V>(call_frame, m_out) =
EigenArray1d<ET, fmt::V>(call_frame, m_arg).floor();
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::floor<typename ET::type>(arg, out, count);
}
protected:
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/kernel/greater.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace instruction
{
template <typename ET>
class GreaterInstruction : public Instruction
{
public:
GreaterInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
char* out = get_tensor_data_ptr<element::Bool>(
call_frame, m_out); // FIXME: temporarily char not bool
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::greater<typename ET::type>(arg0, arg1, out, count);
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -14,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/greater_eq.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -25,23 +26,15 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename TI, typename TO>
void greater_eq(TI arg0, TI arg1, TO out)
{
auto result_as_float = get_map_array(&*arg0) <= get_map_array(&*arg1);
auto result_as_char = result_as_float.template cast<char>();
set_map_array(&*out, result_as_char);
}
template <typename ET>
class GreaterEqInstruction : public Instruction
{
public:
GreaterEqInstruction(TensorViewInfo arg0,
TensorViewInfo arg1,
TensorViewInfo out)
GreaterEqInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
......@@ -50,10 +43,14 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) >=
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
char* out = get_tensor_data_ptr<element::Bool>(
call_frame, m_out); // FIXME: temporarily char not bool
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::greater_eq<typename ET::type>(arg0, arg1, out, count);
}
protected:
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/kernel/less.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace instruction
{
template <typename ET>
class LessInstruction : public Instruction
{
public:
LessInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
char* out = get_tensor_data_ptr<element::Bool>(
call_frame, m_out); // FIXME: temporarily char not bool
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::less<typename ET::type>(arg0, arg1, out, count);
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -14,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/less_eq.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -25,13 +26,15 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class LessEqInstruction : public Instruction
{
public:
LessEqInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out)
LessEqInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
......@@ -40,10 +43,14 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) <=
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
char* out = get_tensor_data_ptr<element::Bool>(
call_frame, m_out); // FIXME: temporarily char not bool
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::less_eq<typename ET::type>(arg0, arg1, out, count);
}
protected:
......
......@@ -14,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/log.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -25,13 +26,13 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class LogInstruction : public Instruction
{
public:
LogInstruction(TensorViewInfo arg, TensorViewInfo out)
LogInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
......@@ -39,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET, fmt::V>(call_frame, m_out) =
Eigen::log(EigenArray1d<ET, fmt::V>(call_frame, m_arg));
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::log<typename ET::type>(arg, out, count);
}
protected:
......
......@@ -14,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/maximum.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -25,13 +26,15 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class MaximumInstruction : public Instruction
{
public:
MaximumInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out)
MaximumInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
......@@ -40,9 +43,13 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) =
EigenArray1d<ET>(call_frame, m_arg0)
.max(EigenArray1d<ET>(call_frame, m_arg1));
typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::maximum<typename ET::type>(arg0, arg1, out, count);
}
protected:
......
......@@ -14,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/minimum.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -25,13 +26,15 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class MinimumInstruction : public Instruction
{
public:
MinimumInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out)
MinimumInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
......@@ -40,9 +43,13 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) =
EigenArray1d<ET>(call_frame, m_arg0)
.min(EigenArray1d<ET>(call_frame, m_arg1));
typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::minimum<typename ET::type>(arg0, arg1, out, count);
}
protected:
......
......@@ -14,9 +14,11 @@
#pragma once
#include "ngraph/runtime/kernel/multiply.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
......@@ -24,15 +26,15 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class MultiplyInstruction : public Instruction
{
public:
MultiplyInstruction(TensorViewInfo arg0,
TensorViewInfo arg1,
TensorViewInfo out)
MultiplyInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
......@@ -41,8 +43,13 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_arg0) *
EigenArray1d<ET>(call_frame, m_arg1);
typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::multiply<typename ET::type>(arg0, arg1, out, count);
}
protected:
......
......@@ -14,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/negate.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -25,13 +26,13 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class NegateInstruction : public Instruction
{
public:
NegateInstruction(TensorViewInfo arg, TensorViewInfo out)
NegateInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
......@@ -39,7 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) = -EigenArray1d<ET>(call_frame, m_arg);
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::negate<typename ET::type>(arg, out, count);
}
protected:
......
......@@ -14,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/not.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -25,12 +26,12 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
class NotInstruction : public Instruction
{
public:
NotInstruction(TensorViewInfo arg, TensorViewInfo out)
NotInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
......@@ -38,13 +39,14 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
// This is a bit frustrating. We have to cast the Eigen
// matrix to a real bool, negate that, then cast that
// back to our storage representation (ultimately char).
EigenArray1d<element::Bool>(call_frame, m_out) =
(!(EigenArray1d<element::Bool>(call_frame, m_arg)
.template cast<bool>()))
.template cast<element::Bool::type>();
char* arg = get_tensor_data_ptr<element::Bool>(
call_frame, m_arg); // FIXME: temporarily char not bool
char* out = get_tensor_data_ptr<element::Bool>(
call_frame, m_out); // FIXME: temporarily char not bool
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::logical_not(arg, out, count);
}
protected:
......
......@@ -14,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/not_equal.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -25,15 +26,15 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class NotEqualInstruction : public Instruction
{
public:
NotEqualInstruction(TensorViewInfo arg0,
TensorViewInfo arg1,
TensorViewInfo out)
NotEqualInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
......@@ -42,10 +43,14 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) !=
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
char* out = get_tensor_data_ptr<element::Bool>(
call_frame, m_out); // FIXME: temporarily char not bool
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::not_equal<typename ET::type>(arg0, arg1, out, count);
}
protected:
......
......@@ -14,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/power.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -25,7 +26,7 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class PowerInstruction : public Instruction
......@@ -42,9 +43,13 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) =
EigenArray1d<ET>(call_frame, m_arg0)
.pow(EigenArray1d<ET>(call_frame, m_arg1));
typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::power<typename ET::type>(arg0, arg1, out, count);
}
protected:
......
......@@ -23,7 +23,7 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
class ReturnInstruction : public Instruction
{
......
......@@ -14,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/select.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -25,16 +26,16 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class SelectInstruction : public Instruction
{
public:
SelectInstruction(TensorViewInfo arg0,
TensorViewInfo arg1,
TensorViewInfo arg2,
TensorViewInfo out)
SelectInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& arg2,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_arg2(arg2)
......@@ -44,10 +45,15 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) =
EigenArray1d<element::Bool>(call_frame, m_arg0)
.select(EigenArray1d<ET>(call_frame, m_arg1),
EigenArray1d<ET>(call_frame, m_arg2));
char* arg0 = get_tensor_data_ptr<element::Bool>(
call_frame, m_arg0); // FIXME: temporarily char not bool
typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
typename ET::type* arg2 = get_tensor_data_ptr<ET>(call_frame, m_arg2);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::select<typename ET::type>(arg0, arg1, arg2, out, count);
}
protected:
......
......@@ -14,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/sign.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -25,13 +26,13 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class SignInstruction : public Instruction
{
public:
SignInstruction(TensorViewInfo arg, TensorViewInfo out)
SignInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
......@@ -39,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET, fmt::V>(call_frame, m_out) =
EigenArray1d<ET, fmt::V>(call_frame, m_arg).sign();
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::sign<typename ET::type>(arg, out, count);
}
protected:
......
......@@ -14,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/sin.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -25,13 +26,13 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class SinInstruction : public Instruction
{
public:
SinInstruction(TensorViewInfo arg, TensorViewInfo out)
SinInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
......@@ -39,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET, fmt::V>(call_frame, m_out) =
EigenArray1d<ET, fmt::V>(call_frame, m_arg).sin();
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::sin<typename ET::type>(arg, out, count);
}
protected:
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
......@@ -15,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/sinh.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -26,13 +26,13 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class SinhInstruction : public Instruction
{
public:
SinhInstruction(TensorViewInfo arg, TensorViewInfo out)
SinhInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
......@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET, fmt::V>(call_frame, m_out) =
EigenArray1d<ET, fmt::V>(call_frame, m_arg).sinh();
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::sinh<typename ET::type>(arg, out, count);
}
protected:
......
......@@ -14,11 +14,11 @@
#pragma once
#include "ngraph/runtime/kernel/sqrt.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph
{
......@@ -26,7 +26,7 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class SqrtInstruction : public Instruction
......@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) =
Eigen::sqrt(EigenArray1d<ET>(call_frame, m_arg));
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::sqrt<typename ET::type>(arg, out, count);
}
protected:
......
......@@ -14,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/subtract.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -25,15 +26,15 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class SubtractInstruction : public Instruction
{
public:
SubtractInstruction(TensorViewInfo arg0,
TensorViewInfo arg1,
TensorViewInfo out)
SubtractInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
......@@ -42,8 +43,13 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_arg0) -
EigenArray1d<ET>(call_frame, m_arg1);
typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::subtract<typename ET::type>(arg0, arg1, out, count);
}
protected:
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
......@@ -15,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/tan.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -26,13 +26,13 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class TanInstruction : public Instruction
{
public:
TanInstruction(TensorViewInfo arg, TensorViewInfo out)
TanInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
......@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET, fmt::V>(call_frame, m_out) =
EigenArray1d<ET, fmt::V>(call_frame, m_arg).tan();
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::tan<typename ET::type>(arg, out, count);
}
protected:
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
......@@ -15,9 +14,10 @@
#pragma once
#include "ngraph/runtime/kernel/tanh.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
......@@ -26,13 +26,13 @@ namespace ngraph
{
namespace ngvm
{
namespace eigen
namespace instruction
{
template <typename ET>
class TanhInstruction : public Instruction
{
public:
TanhInstruction(TensorViewInfo arg, TensorViewInfo out)
TanhInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
......@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET, fmt::V>(call_frame, m_out) =
EigenArray1d<ET, fmt::V>(call_frame, m_arg).tanh();
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::tanh<typename ET::type>(arg, out, count);
}
protected:
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <Eigen/Dense>
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph
{
namespace runtime
{
class TensorViewInfo;
namespace ngvm
{
class CallFrame;
template <typename ET>
typename ET::type* get_tensor_data_ptr(CallFrame& call_frame,
const TensorViewInfo& tensor_view_info)
{
return call_frame.get_tensor_view_data<ET>(tensor_view_info.get_index());
}
size_t get_tensor_element_count(CallFrame& call_frame,
const TensorViewInfo& tensor_view_info)
{
return tensor_view_info
.get_layout<ngraph::descriptor::layout::DenseTensorViewLayout>()
->get_size();
}
}
}
}
......@@ -23,9 +23,10 @@ include_directories(
set (SRC
autodiff.cpp
build_graph.cpp
builder.cpp
builder_autobroadcast.cpp
build_graph.cpp
coordinate_iterator.cpp
copy.cpp
eigen.cpp
element_type.cpp
......
......@@ -1513,6 +1513,78 @@ TEST(${BACKEND_NAME}, broadcast_vector_rowwise_int64)
result->get_vector<element::Int64::type>());
}
TEST(DISABLED_${BACKEND_NAME}, broadcast_matrix_0)
{
auto shape_a = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
auto shape_r = Shape{2, 2, 2};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto f = make_shared<Function>(
make_shared<op::Broadcast>(A, shape_r, AxisSet{0}), rt, op::Parameters{A});
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = backend->make_primary_tensor_view(element::Float32::element_type(), shape_a);
copy_data(a, vector<element::Float32::type>{1, 2, 3, 4});
auto result = backend->make_primary_tensor_view(element::Float32::element_type(), shape_r);
cf->call({a}, {result});
ASSERT_EQ((vector<element::Float32::type>{1, 2, 3, 4, 1, 2, 3, 4}),
result->get_vector<element::Float32::type>());
}
TEST(DISABLED_${BACKEND_NAME}, broadcast_matrix_1)
{
auto shape_a = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
auto shape_r = Shape{2, 2, 2};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto f = make_shared<Function>(
make_shared<op::Broadcast>(A, shape_r, AxisSet{1}), rt, op::Parameters{A});
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = backend->make_primary_tensor_view(element::Float32::element_type(), shape_a);
copy_data(a, vector<element::Float32::type>{1, 2, 3, 4});
auto result = backend->make_primary_tensor_view(element::Float32::element_type(), shape_r);
cf->call({a}, {result});
ASSERT_EQ((vector<element::Float32::type>{1, 2, 1, 2, 3, 4, 3, 4}),
result->get_vector<element::Float32::type>());
}
TEST(DISABLED_${BACKEND_NAME}, broadcast_matrix_2)
{
auto shape_a = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
auto shape_r = Shape{2, 2, 2};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto f = make_shared<Function>(
make_shared<op::Broadcast>(A, shape_r, AxisSet{2}), rt, op::Parameters{A});
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = backend->make_primary_tensor_view(element::Float32::element_type(), shape_a);
copy_data(a, vector<element::Float32::type>{1, 2, 3, 4});
auto result = backend->make_primary_tensor_view(element::Float32::element_type(), shape_r);
cf->call({a}, {result});
ASSERT_EQ((vector<element::Float32::type>{1, 1, 2, 2, 3, 3, 4, 4}),
result->get_vector<element::Float32::type>());
}
TEST(${BACKEND_NAME}, convert_int32_float32)
{
auto shape = Shape{2, 2};
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include <memory>
using namespace std;
using namespace ngraph;
TEST(coordinate_iterator, construct)
{
Shape space_shape{2, 3, 5, 6};
Strides strides{1, 1, 1, 1};
Coordinate window_outer_corner{2, 3, 5, 6};
Coordinate window_inner_corner{0, 0, 0, 0};
auto ci = CoordinateIterator(space_shape, strides, window_outer_corner, window_inner_corner);
}
TEST(coordinate_iterator, construct_defaults)
{
Shape space_shape{2, 3, 5, 6};
Strides strides{2, 2, 2, 1};
auto ci = CoordinateIterator(space_shape, strides);
}
TEST(coordinate_iterator, construct_defaults_stride)
{
Shape space_shape{2, 3, 5, 6};
auto ci = CoordinateIterator(space_shape);
}
TEST(coordinate_iterator, construct_bad_outer_oob)
{
Shape space_shape{2, 3, 5, 6};
Strides strides{1, 1, 1, 1};
Coordinate window_outer_corner{2, 4, 5, 6};
Coordinate window_inner_corner{0, 0, 0, 0};
EXPECT_ANY_THROW({
auto ci =
CoordinateIterator(space_shape, strides, window_outer_corner, window_inner_corner);
});
}
TEST(coordinate_iterator, construct_bad_inner_oob)
{
Shape space_shape{2, 3, 5, 6};
Strides strides{1, 1, 1, 1};
Coordinate window_outer_corner{2, 3, 5, 6};
Coordinate window_inner_corner{0, 3, 0, 0};
EXPECT_ANY_THROW({
auto ci =
CoordinateIterator(space_shape, strides, window_outer_corner, window_inner_corner);
});
}
TEST(coordinate_iterator, construct_bad_inner_outside_outer)
{
Shape space_shape{2, 3, 5, 6};
Strides strides{1, 1, 1, 1};
Coordinate window_outer_corner{2, 1, 5, 6};
Coordinate window_inner_corner{0, 2, 0, 0};
EXPECT_ANY_THROW({
auto ci =
CoordinateIterator(space_shape, strides, window_outer_corner, window_inner_corner);
});
}
TEST(coordinate_iterator, construct_bad_zero_stride)
{
Shape space_shape{2, 3, 5, 6};
Strides strides{1, 0, 1, 1};
Coordinate window_outer_corner{2, 3, 5, 6};
Coordinate window_inner_corner{0, 0, 0, 0};
EXPECT_ANY_THROW({
auto ci =
CoordinateIterator(space_shape, strides, window_outer_corner, window_inner_corner);
});
}
TEST(coordinate_iterator, cover_count_defaults)
{
Shape space_shape{2, 3, 5, 6};
auto ci = CoordinateIterator(space_shape);
size_t count = 0;
size_t expected_index = 0;
do
{
count++;
EXPECT_EQ(ci.get_current_index(), expected_index);
expected_index++;
} while (ci.increment());
EXPECT_EQ(count, 2 * 3 * 5 * 6);
}
TEST(coordinate_iterator, cover_count_stride_2)
{
Shape space_shape{2, 3, 5, 6};
Strides strides{1, 1, 1, 2};
auto ci = CoordinateIterator(space_shape, strides);
size_t count = 0;
size_t expected_index = 0;
do
{
count++;
EXPECT_EQ(ci.get_current_index(), expected_index);
expected_index += 2;
} while (ci.increment());
EXPECT_EQ(count, 2 * 3 * 5 * 6 / 2);
}
#define CEIL_DIV(x, y) (1 + (((x)-1) / (y)))
TEST(coordinate_iterator, cover_count_stride_uneven)
{
Shape space_shape{2, 3, 5, 6};
Strides strides{1, 2, 2, 3};
auto ci = CoordinateIterator(space_shape, strides);
size_t count = 0;
do
{
count++;
} while (ci.increment());
EXPECT_EQ(count, CEIL_DIV(2, 1) * CEIL_DIV(3, 2) * CEIL_DIV(5, 2) * CEIL_DIV(6, 3));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment