Unverified Commit f7a22b95 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Merge pull request #278 from NervanaSystems/aprocter/de-eigenize-partial

Work so far on de-Eigenization
parents 2f0a33c3 18998c41
...@@ -14,8 +14,9 @@ ...@@ -14,8 +14,9 @@
set (SRC set (SRC
autodiff/adjoints.cpp autodiff/adjoints.cpp
builder/autobroadcast.cpp builder/autobroadcast.cpp
builder/reduce_ops.cpp
builder/numpy_transpose.cpp builder/numpy_transpose.cpp
builder/reduce_ops.cpp
coordinate_iterator.cpp
descriptor/input.cpp descriptor/input.cpp
descriptor/layout/dense_tensor_view_layout.cpp descriptor/layout/dense_tensor_view_layout.cpp
descriptor/layout/tensor_view_layout.cpp descriptor/layout/tensor_view_layout.cpp
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <cstdio>
#include <iostream>
#include <vector>
#include "ngraph/common.hpp"
#include "ngraph/coordinate_iterator.hpp"
#include "ngraph/except.hpp"
using namespace ngraph;
CoordinateIterator::CoordinateIterator(const Shape& space_shape,
const Strides& strides,
const Coordinate& window_outer_corner,
const Coordinate& window_inner_corner)
: m_space_shape(space_shape)
, m_strides(strides)
, m_window_outer_corner(window_outer_corner)
, m_window_inner_corner(window_inner_corner)
, m_current_coordinate(window_inner_corner)
{
if (space_shape.size() != window_inner_corner.size())
{
throw ngraph_error("Coordinate iterator inner corner rank does not match space shape rank");
}
if (space_shape.size() != window_outer_corner.size())
{
throw ngraph_error("Coordinate iterator outer corner rank does not match space shape rank");
}
if (space_shape.size() != strides.size())
{
throw ngraph_error("Coordinate iterator stride rank does not match space shape rank");
}
for (size_t i = 0; i < space_shape.size(); i++)
{
if (window_inner_corner[i] > window_outer_corner[i])
{
throw ngraph_error("Coordinate iterator inner corner is outside outer corner");
}
if (window_inner_corner[i] >= m_space_shape[i])
{
throw ngraph_error("Coordinate iterator inner corner is out of bounds");
}
if (window_outer_corner[i] > m_space_shape[i])
{
throw ngraph_error("Coordinate iterator outer corner is out of bounds");
}
if (m_strides[i] == 0)
{
throw ngraph_error("Coordinate iterator stride is zero");
}
}
}
CoordinateIterator::CoordinateIterator(const Shape& space_shape)
: CoordinateIterator(space_shape,
Strides(space_shape.size(), 1),
space_shape,
Coordinate(space_shape.size(), 0))
{
}
CoordinateIterator::CoordinateIterator(const Shape& space_shape, const Strides& strides)
: CoordinateIterator(space_shape, strides, space_shape, Coordinate(space_shape.size(), 0))
{
}
size_t CoordinateIterator::get_current_index() const
{
size_t index = 0;
size_t stride = 1;
for (size_t i = m_space_shape.size(); i-- > 0;)
{
index += m_current_coordinate[i] * stride;
stride *= m_space_shape[i];
}
return index;
}
bool CoordinateIterator::increment()
{
bool overflow = true;
for (size_t i = m_space_shape.size(); i-- > 0;)
{
m_current_coordinate[i] += m_strides[i];
if (m_current_coordinate[i] >= m_window_outer_corner[i])
{
m_current_coordinate[i] = m_window_inner_corner[i];
}
else
{
overflow = false;
break;
}
}
return !overflow;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cstdio>
#include <iostream>
#include <vector>
#include "ngraph/common.hpp"
namespace ngraph
{
class CoordinateIterator
{
public:
CoordinateIterator(const Shape& space_shape,
const Strides& strides,
const Coordinate& window_outer_corner,
const Coordinate& window_inner_corner);
CoordinateIterator(const Shape& space_shape);
CoordinateIterator(const Shape& space_shape, const Strides& strides);
Coordinate get_current_coordinate() const { return m_current_coordinate; }
size_t get_current_index() const;
bool increment();
private:
const Shape m_space_shape;
const Strides m_strides;
const Coordinate m_window_outer_corner;
const Coordinate m_window_inner_corner;
Coordinate m_current_coordinate;
};
}
...@@ -45,6 +45,7 @@ ...@@ -45,6 +45,7 @@
#include "ngraph/builder/numpy_transpose.hpp" #include "ngraph/builder/numpy_transpose.hpp"
#include "ngraph/builder/reduce_ops.hpp" #include "ngraph/builder/reduce_ops.hpp"
#include "ngraph/common.hpp" #include "ngraph/common.hpp"
#include "ngraph/coordinate_iterator.hpp"
#include "ngraph/descriptor/buffer.hpp" #include "ngraph/descriptor/buffer.hpp"
#include "ngraph/descriptor/input.hpp" #include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp" #include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void abs(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
// TODO: generic "abs" doesn't work here for some reason.
out[i] = (arg[i] < 0 ? -arg[i] : arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void acos(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::acos(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void add(T* arg0, T* arg1, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] + arg1[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void asin(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::asin(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void atan(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::atan(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
#include "ngraph/common.hpp"
#include "ngraph/coordinate_iterator.hpp"
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void broadcast(T* arg,
T* out,
const Shape& in_shape,
const Shape& out_shape,
const AxisSet& broadcast_axes)
{
// For the outer loop we will walk over the entire input shape.
CoordinateIterator arg_iter(in_shape);
do
{
// For the inner loop we will walk across the entire axis for the new broadcast axes, and stay put at the current arg position for the existing axes.
Coordinate arg_coordinate = arg_iter.get_current_coordinate();
Strides out_strides(out_shape.size(), 1);
Coordinate out_outer_corner(out_shape.size());
Coordinate out_inner_corner(out_shape.size());
size_t arg_pos = 0;
for (size_t i = 0; i < out_shape.size(); i++)
{
if (broadcast_axes.find(i) == broadcast_axes.end())
{
// This is an existing axis.
out_outer_corner[i] = arg_coordinate[arg_pos];
out_inner_corner[i] = arg_coordinate[arg_pos];
arg_pos++;
}
else
{
// This is a new broadcast axis.
out_outer_corner[i] = out_shape[i];
out_inner_corner[i] = 0;
}
}
CoordinateIterator out_iter(
out_shape, out_strides, out_outer_corner, out_inner_corner);
do
{
out[out_iter.get_current_index()] = arg[arg_iter.get_current_index()];
} while (out_iter.increment());
} while (arg_iter.increment());
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void ceiling(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::ceil(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename TI, typename TO>
void convert(TI* arg, TO* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = static_cast<TO>(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void copy(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void cos(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::cos(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void cosh(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::cosh(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void divide(T* arg0, T* arg1, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] / arg1[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wfloat-equal"
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void equal(T* arg0,
T* arg1,
char* out,
size_t count) // TODO: using char for bool, is this right?
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] == arg1[i];
}
}
}
}
}
#pragma clang diagnostic pop
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void exp(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::exp(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void floor(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::floor(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void greater(T* arg0,
T* arg1,
char* out,
size_t count) // TODO: using char for bool, is this right?
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] > arg1[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void greater_eq(T* arg0,
T* arg1,
char* out,
size_t count) // TODO: using char for bool, is this right?
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] >= arg1[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void less(T* arg0,
T* arg1,
char* out,
size_t count) // TODO: using char for bool, is this right?
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] < arg1[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void less_eq(T* arg0,
T* arg1,
char* out,
size_t count) // TODO: using char for bool, is this right?
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] <= arg1[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void log(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::log(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void maximum(T* arg0, T* arg1, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] > arg1[i] ? arg0[i] : arg1[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void minimum(T* arg0, T* arg1, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] < arg1[i] ? arg0[i] : arg1[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void multiply(T* arg0, T* arg1, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] * arg1[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void negate(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = -arg[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
void logical_not(char* arg,
char* out,
size_t count) // TODO: using char for bool, is this right?
{
for (size_t i = 0; i < count; i++)
{
out[i] = !(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wfloat-equal"
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void not_equal(T* arg0,
T* arg1,
char* out,
size_t count) // TODO: using char for bool, is this right?
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] != arg1[i];
}
}
}
}
}
#pragma clang diagnostic pop
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void power(T* arg0, T* arg1, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::pow(arg0[i], arg1[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void select(char* arg0,
T* arg1,
T* arg2,
T* out,
size_t count) // TODO: using char for bool, is this right?
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] ? arg1[i] : arg2[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void sign(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = (arg[i] < 0 ? -1 : (arg[i] > 0 ? 1 : 0));
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void sin(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::sin(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void sinh(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::sinh(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void sqrt(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::sqrt(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void subtract(T* arg0, T* arg1, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg0[i] - arg1[i];
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void tan(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::tan(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cmath>
namespace ngraph
{
namespace runtime
{
namespace kernel
{
template <typename T>
void tanh(T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = std::tanh(arg[i]);
}
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace eigen
{
template <typename ET>
class BroadcastVectorRowwiseInstruction : public Instruction
{
public:
BroadcastVectorRowwiseInstruction(const TensorViewInfo& arg,
const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenMatrix<ET>(call_frame, m_out).rowwise() =
EigenVector<ET>(call_frame, m_arg).transpose();
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace eigen
{
template <typename ET>
class GreaterThanInstruction : public Instruction
{
public:
GreaterThanInstruction(TensorViewInfo arg0,
TensorViewInfo arg1,
TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) >
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace eigen
{
template <typename ET>
class LessThanInstruction : public Instruction
{
public:
LessThanInstruction(TensorViewInfo arg0,
TensorViewInfo arg1,
TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) <
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
...@@ -14,11 +14,11 @@ ...@@ -14,11 +14,11 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/abs.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class AbsInstruction : public Instruction class AbsInstruction : public Instruction
...@@ -40,8 +40,12 @@ namespace ngraph ...@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET>(call_frame, m_out) = typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
Eigen::abs(EigenArray1d<ET>(call_frame, m_arg)); typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::abs<typename ET::type>(arg, out, count);
} }
protected: protected:
......
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
...@@ -15,9 +14,10 @@ ...@@ -15,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/acos.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -26,13 +26,13 @@ namespace ngraph ...@@ -26,13 +26,13 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class AcosInstruction : public Instruction class AcosInstruction : public Instruction
{ {
public: public:
AcosInstruction(TensorViewInfo arg, TensorViewInfo out) AcosInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg) : m_arg(arg)
, m_out(out) , m_out(out)
{ {
...@@ -40,8 +40,12 @@ namespace ngraph ...@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET, fmt::V>(call_frame, m_out) = typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
EigenArray1d<ET, fmt::V>(call_frame, m_arg).acos(); typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::acos<typename ET::type>(arg, out, count);
} }
protected: protected:
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/add.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -25,7 +26,7 @@ namespace ngraph ...@@ -25,7 +26,7 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class AddInstruction : public Instruction class AddInstruction : public Instruction
...@@ -42,8 +43,13 @@ namespace ngraph ...@@ -42,8 +43,13 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_arg0) + typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
EigenArray1d<ET>(call_frame, m_arg1); typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::add<typename ET::type>(arg0, arg1, out, count);
} }
protected: protected:
......
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
...@@ -15,9 +14,10 @@ ...@@ -15,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/asin.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -26,13 +26,13 @@ namespace ngraph ...@@ -26,13 +26,13 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class AsinInstruction : public Instruction class AsinInstruction : public Instruction
{ {
public: public:
AsinInstruction(TensorViewInfo arg, TensorViewInfo out) AsinInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg) : m_arg(arg)
, m_out(out) , m_out(out)
{ {
...@@ -40,8 +40,12 @@ namespace ngraph ...@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET, fmt::V>(call_frame, m_out) = typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
EigenArray1d<ET, fmt::V>(call_frame, m_arg).asin(); typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::asin<typename ET::type>(arg, out, count);
} }
protected: protected:
......
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
...@@ -15,9 +14,10 @@ ...@@ -15,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/atan.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -26,13 +26,13 @@ namespace ngraph ...@@ -26,13 +26,13 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class AtanInstruction : public Instruction class AtanInstruction : public Instruction
{ {
public: public:
AtanInstruction(TensorViewInfo arg, TensorViewInfo out) AtanInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg) : m_arg(arg)
, m_out(out) , m_out(out)
{ {
...@@ -40,8 +40,12 @@ namespace ngraph ...@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET, fmt::V>(call_frame, m_out) = typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
EigenArray1d<ET, fmt::V>(call_frame, m_arg).atan(); typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::atan<typename ET::type>(arg, out, count);
} }
protected: protected:
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/broadcast.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -24,30 +25,41 @@ namespace ngraph ...@@ -24,30 +25,41 @@ namespace ngraph
namespace runtime namespace runtime
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class BroadcastVectorColwiseInstruction : public Instruction class BroadcastInstruction : public Instruction
{ {
public: public:
BroadcastVectorColwiseInstruction(const TensorViewInfo& arg, BroadcastInstruction(const TensorViewInfo& arg,
const TensorViewInfo& out) const TensorViewInfo& out,
const Shape& arg_shape,
const Shape& out_shape,
const AxisSet& broadcast_axes)
: m_arg(arg) : m_arg(arg)
, m_out(out) , m_out(out)
, m_arg_shape(arg_shape)
, m_out_shape(out_shape)
, m_broadcast_axes(broadcast_axes)
{ {
} }
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenMatrix<ET>(call_frame, m_out).colwise() = typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
EigenVector<ET>(call_frame, m_arg); typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
kernel::broadcast<typename ET::type>(
arg, out, m_arg_shape, m_out_shape, m_broadcast_axes);
} }
protected: protected:
TensorViewInfo m_arg; TensorViewInfo m_arg;
TensorViewInfo m_out; TensorViewInfo m_out;
Shape m_arg_shape;
Shape m_out_shape;
AxisSet m_broadcast_axes;
}; };
} }
} }
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
#include <memory> #include <memory>
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/external_function.hpp" #include "ngraph/runtime/ngvm/external_function.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
...@@ -28,7 +27,7 @@ namespace ngraph ...@@ -28,7 +27,7 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
class CallInstruction : public Instruction class CallInstruction : public Instruction
{ {
......
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
...@@ -15,9 +14,10 @@ ...@@ -15,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/ceiling.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -26,13 +26,13 @@ namespace ngraph ...@@ -26,13 +26,13 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class CeilingInstruction : public Instruction class CeilingInstruction : public Instruction
{ {
public: public:
CeilingInstruction(TensorViewInfo arg, TensorViewInfo out) CeilingInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg) : m_arg(arg)
, m_out(out) , m_out(out)
{ {
...@@ -40,8 +40,12 @@ namespace ngraph ...@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET, fmt::V>(call_frame, m_out) = typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
EigenArray1d<ET, fmt::V>(call_frame, m_arg).ceil(); typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::ceiling<typename ET::type>(arg, out, count);
} }
protected: protected:
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#pragma once #pragma once
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp" #include "ngraph/runtime/tensor_view_info.hpp"
...@@ -26,7 +25,7 @@ namespace ngraph ...@@ -26,7 +25,7 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class ConstantInstruction : public Instruction class ConstantInstruction : public Instruction
......
...@@ -14,11 +14,11 @@ ...@@ -14,11 +14,11 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/convert.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ETI, typename ETO> template <typename ETI, typename ETO>
class ConvertInstruction : public Instruction class ConvertInstruction : public Instruction
...@@ -40,9 +40,12 @@ namespace ngraph ...@@ -40,9 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ETO>(call_frame, m_out) = typename ETI::type* arg = get_tensor_data_ptr<ETI>(call_frame, m_arg);
EigenArray1d<ETI>(call_frame, m_arg) typename ETO::type* out = get_tensor_data_ptr<ETO>(call_frame, m_out);
.template cast<typename ETO::type>();
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::convert<typename ETI::type, typename ETO::type>(arg, out, count);
} }
protected: protected:
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/kernel/copy.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace instruction
{
template <typename ET>
class CopyInstruction : public Instruction
{
public:
CopyInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::copy<typename ET::type>(arg, out, count);
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
}
}
}
}
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
#include <cassert> #include <cassert>
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
...@@ -27,16 +26,16 @@ namespace ngraph ...@@ -27,16 +26,16 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
/// @brief Copies a tensor from in to out. /// @brief Copies a tensor from in to out.
template <typename ET> template <typename ET>
class CopyInstruction : public Instruction class CopyByIndexInstruction : public Instruction
{ {
public: public:
/// @param in Index of input tensor in call frame. /// @param in Index of input tensor in call frame.
/// @param out Index of output tensor in call frame. /// @param out Index of output tensor in call frame.
CopyInstruction(size_t in, size_t out) CopyByIndexInstruction(size_t in, size_t out)
: m_in(in) : m_in(in)
, m_out(out) , m_out(out)
{ {
......
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
...@@ -15,9 +14,10 @@ ...@@ -15,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/cos.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -26,13 +26,13 @@ namespace ngraph ...@@ -26,13 +26,13 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class CosInstruction : public Instruction class CosInstruction : public Instruction
{ {
public: public:
CosInstruction(TensorViewInfo arg, TensorViewInfo out) CosInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg) : m_arg(arg)
, m_out(out) , m_out(out)
{ {
...@@ -40,8 +40,12 @@ namespace ngraph ...@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET, fmt::V>(call_frame, m_out) = typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
EigenArray1d<ET, fmt::V>(call_frame, m_arg).cos(); typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::cos<typename ET::type>(arg, out, count);
} }
protected: protected:
......
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
...@@ -15,9 +14,10 @@ ...@@ -15,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/cosh.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -26,13 +26,13 @@ namespace ngraph ...@@ -26,13 +26,13 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class CoshInstruction : public Instruction class CoshInstruction : public Instruction
{ {
public: public:
CoshInstruction(TensorViewInfo arg, TensorViewInfo out) CoshInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg) : m_arg(arg)
, m_out(out) , m_out(out)
{ {
...@@ -40,8 +40,12 @@ namespace ngraph ...@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET, fmt::V>(call_frame, m_out) = typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
EigenArray1d<ET, fmt::V>(call_frame, m_arg).cosh(); typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::cosh<typename ET::type>(arg, out, count);
} }
protected: protected:
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/divide.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -25,7 +26,7 @@ namespace ngraph ...@@ -25,7 +26,7 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class DivideInstruction : public Instruction class DivideInstruction : public Instruction
...@@ -42,8 +43,13 @@ namespace ngraph ...@@ -42,8 +43,13 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_arg0) / typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
EigenArray1d<ET>(call_frame, m_arg1); typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::divide<typename ET::type>(arg0, arg1, out, count);
} }
protected: protected:
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/equal.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -25,13 +26,15 @@ namespace ngraph ...@@ -25,13 +26,15 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class EqualInstruction : public Instruction class EqualInstruction : public Instruction
{ {
public: public:
EqualInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out) EqualInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0) : m_arg0(arg0)
, m_arg1(arg1) , m_arg1(arg1)
, m_out(out) , m_out(out)
...@@ -40,10 +43,14 @@ namespace ngraph ...@@ -40,10 +43,14 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<element::Bool>(call_frame, m_out) = typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
(EigenArray1d<ET>(call_frame, m_arg0) == typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
EigenArray1d<ET>(call_frame, m_arg1)) char* out = get_tensor_data_ptr<element::Bool>(
.template cast<char>(); call_frame, m_out); // FIXME: temporarily char not bool
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::equal<typename ET::type>(arg0, arg1, out, count);
} }
protected: protected:
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/exp.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -25,13 +26,13 @@ namespace ngraph ...@@ -25,13 +26,13 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class ExpInstruction : public Instruction class ExpInstruction : public Instruction
{ {
public: public:
ExpInstruction(TensorViewInfo arg, TensorViewInfo out) ExpInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg) : m_arg(arg)
, m_out(out) , m_out(out)
{ {
...@@ -39,8 +40,12 @@ namespace ngraph ...@@ -39,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET, fmt::V>(call_frame, m_out) = typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
EigenArray1d<ET, fmt::V>(call_frame, m_arg).exp(); typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::exp<typename ET::type>(arg, out, count);
} }
protected: protected:
......
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
...@@ -15,9 +14,10 @@ ...@@ -15,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/floor.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -26,13 +26,13 @@ namespace ngraph ...@@ -26,13 +26,13 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class FloorInstruction : public Instruction class FloorInstruction : public Instruction
{ {
public: public:
FloorInstruction(TensorViewInfo arg, TensorViewInfo out) FloorInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg) : m_arg(arg)
, m_out(out) , m_out(out)
{ {
...@@ -40,8 +40,12 @@ namespace ngraph ...@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET, fmt::V>(call_frame, m_out) = typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
EigenArray1d<ET, fmt::V>(call_frame, m_arg).floor(); typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::floor<typename ET::type>(arg, out, count);
} }
protected: protected:
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/kernel/greater.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace instruction
{
template <typename ET>
class GreaterInstruction : public Instruction
{
public:
GreaterInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
char* out = get_tensor_data_ptr<element::Bool>(
call_frame, m_out); // FIXME: temporarily char not bool
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::greater<typename ET::type>(arg0, arg1, out, count);
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/greater_eq.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -25,23 +26,15 @@ namespace ngraph ...@@ -25,23 +26,15 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename TI, typename TO>
void greater_eq(TI arg0, TI arg1, TO out)
{
auto result_as_float = get_map_array(&*arg0) <= get_map_array(&*arg1);
auto result_as_char = result_as_float.template cast<char>();
set_map_array(&*out, result_as_char);
}
template <typename ET> template <typename ET>
class GreaterEqInstruction : public Instruction class GreaterEqInstruction : public Instruction
{ {
public: public:
GreaterEqInstruction(TensorViewInfo arg0, GreaterEqInstruction(const TensorViewInfo& arg0,
TensorViewInfo arg1, const TensorViewInfo& arg1,
TensorViewInfo out) const TensorViewInfo& out)
: m_arg0(arg0) : m_arg0(arg0)
, m_arg1(arg1) , m_arg1(arg1)
, m_out(out) , m_out(out)
...@@ -50,10 +43,14 @@ namespace ngraph ...@@ -50,10 +43,14 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<element::Bool>(call_frame, m_out) = typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
(EigenArray1d<ET>(call_frame, m_arg0) >= typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
EigenArray1d<ET>(call_frame, m_arg1)) char* out = get_tensor_data_ptr<element::Bool>(
.template cast<char>(); call_frame, m_out); // FIXME: temporarily char not bool
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::greater_eq<typename ET::type>(arg0, arg1, out, count);
} }
protected: protected:
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/kernel/less.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace instruction
{
template <typename ET>
class LessInstruction : public Instruction
{
public:
LessInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
char* out = get_tensor_data_ptr<element::Bool>(
call_frame, m_out); // FIXME: temporarily char not bool
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::less<typename ET::type>(arg0, arg1, out, count);
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/less_eq.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -25,13 +26,15 @@ namespace ngraph ...@@ -25,13 +26,15 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class LessEqInstruction : public Instruction class LessEqInstruction : public Instruction
{ {
public: public:
LessEqInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out) LessEqInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0) : m_arg0(arg0)
, m_arg1(arg1) , m_arg1(arg1)
, m_out(out) , m_out(out)
...@@ -40,10 +43,14 @@ namespace ngraph ...@@ -40,10 +43,14 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<element::Bool>(call_frame, m_out) = typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
(EigenArray1d<ET>(call_frame, m_arg0) <= typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
EigenArray1d<ET>(call_frame, m_arg1)) char* out = get_tensor_data_ptr<element::Bool>(
.template cast<char>(); call_frame, m_out); // FIXME: temporarily char not bool
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::less_eq<typename ET::type>(arg0, arg1, out, count);
} }
protected: protected:
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/log.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -25,13 +26,13 @@ namespace ngraph ...@@ -25,13 +26,13 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class LogInstruction : public Instruction class LogInstruction : public Instruction
{ {
public: public:
LogInstruction(TensorViewInfo arg, TensorViewInfo out) LogInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg) : m_arg(arg)
, m_out(out) , m_out(out)
{ {
...@@ -39,8 +40,12 @@ namespace ngraph ...@@ -39,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET, fmt::V>(call_frame, m_out) = typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
Eigen::log(EigenArray1d<ET, fmt::V>(call_frame, m_arg)); typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::log<typename ET::type>(arg, out, count);
} }
protected: protected:
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/maximum.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -25,13 +26,15 @@ namespace ngraph ...@@ -25,13 +26,15 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class MaximumInstruction : public Instruction class MaximumInstruction : public Instruction
{ {
public: public:
MaximumInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out) MaximumInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0) : m_arg0(arg0)
, m_arg1(arg1) , m_arg1(arg1)
, m_out(out) , m_out(out)
...@@ -40,9 +43,13 @@ namespace ngraph ...@@ -40,9 +43,13 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET>(call_frame, m_out) = typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
EigenArray1d<ET>(call_frame, m_arg0) typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
.max(EigenArray1d<ET>(call_frame, m_arg1)); typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::maximum<typename ET::type>(arg0, arg1, out, count);
} }
protected: protected:
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/minimum.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -25,13 +26,15 @@ namespace ngraph ...@@ -25,13 +26,15 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class MinimumInstruction : public Instruction class MinimumInstruction : public Instruction
{ {
public: public:
MinimumInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out) MinimumInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0) : m_arg0(arg0)
, m_arg1(arg1) , m_arg1(arg1)
, m_out(out) , m_out(out)
...@@ -40,9 +43,13 @@ namespace ngraph ...@@ -40,9 +43,13 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET>(call_frame, m_out) = typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
EigenArray1d<ET>(call_frame, m_arg0) typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
.min(EigenArray1d<ET>(call_frame, m_arg1)); typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::minimum<typename ET::type>(arg0, arg1, out, count);
} }
protected: protected:
......
...@@ -14,9 +14,11 @@ ...@@ -14,9 +14,11 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/multiply.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -24,15 +26,15 @@ namespace ngraph ...@@ -24,15 +26,15 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class MultiplyInstruction : public Instruction class MultiplyInstruction : public Instruction
{ {
public: public:
MultiplyInstruction(TensorViewInfo arg0, MultiplyInstruction(const TensorViewInfo& arg0,
TensorViewInfo arg1, const TensorViewInfo& arg1,
TensorViewInfo out) const TensorViewInfo& out)
: m_arg0(arg0) : m_arg0(arg0)
, m_arg1(arg1) , m_arg1(arg1)
, m_out(out) , m_out(out)
...@@ -41,8 +43,13 @@ namespace ngraph ...@@ -41,8 +43,13 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_arg0) * typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
EigenArray1d<ET>(call_frame, m_arg1); typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::multiply<typename ET::type>(arg0, arg1, out, count);
} }
protected: protected:
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/negate.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -25,13 +26,13 @@ namespace ngraph ...@@ -25,13 +26,13 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class NegateInstruction : public Instruction class NegateInstruction : public Instruction
{ {
public: public:
NegateInstruction(TensorViewInfo arg, TensorViewInfo out) NegateInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg) : m_arg(arg)
, m_out(out) , m_out(out)
{ {
...@@ -39,7 +40,12 @@ namespace ngraph ...@@ -39,7 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET>(call_frame, m_out) = -EigenArray1d<ET>(call_frame, m_arg); typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::negate<typename ET::type>(arg, out, count);
} }
protected: protected:
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/not.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -25,12 +26,12 @@ namespace ngraph ...@@ -25,12 +26,12 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
class NotInstruction : public Instruction class NotInstruction : public Instruction
{ {
public: public:
NotInstruction(TensorViewInfo arg, TensorViewInfo out) NotInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg) : m_arg(arg)
, m_out(out) , m_out(out)
{ {
...@@ -38,13 +39,14 @@ namespace ngraph ...@@ -38,13 +39,14 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
// This is a bit frustrating. We have to cast the Eigen char* arg = get_tensor_data_ptr<element::Bool>(
// matrix to a real bool, negate that, then cast that call_frame, m_arg); // FIXME: temporarily char not bool
// back to our storage representation (ultimately char). char* out = get_tensor_data_ptr<element::Bool>(
EigenArray1d<element::Bool>(call_frame, m_out) = call_frame, m_out); // FIXME: temporarily char not bool
(!(EigenArray1d<element::Bool>(call_frame, m_arg)
.template cast<bool>())) size_t count = get_tensor_element_count(call_frame, m_arg);
.template cast<element::Bool::type>();
kernel::logical_not(arg, out, count);
} }
protected: protected:
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/not_equal.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -25,15 +26,15 @@ namespace ngraph ...@@ -25,15 +26,15 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class NotEqualInstruction : public Instruction class NotEqualInstruction : public Instruction
{ {
public: public:
NotEqualInstruction(TensorViewInfo arg0, NotEqualInstruction(const TensorViewInfo& arg0,
TensorViewInfo arg1, const TensorViewInfo& arg1,
TensorViewInfo out) const TensorViewInfo& out)
: m_arg0(arg0) : m_arg0(arg0)
, m_arg1(arg1) , m_arg1(arg1)
, m_out(out) , m_out(out)
...@@ -42,10 +43,14 @@ namespace ngraph ...@@ -42,10 +43,14 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<element::Bool>(call_frame, m_out) = typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
(EigenArray1d<ET>(call_frame, m_arg0) != typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
EigenArray1d<ET>(call_frame, m_arg1)) char* out = get_tensor_data_ptr<element::Bool>(
.template cast<char>(); call_frame, m_out); // FIXME: temporarily char not bool
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::not_equal<typename ET::type>(arg0, arg1, out, count);
} }
protected: protected:
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/power.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -25,7 +26,7 @@ namespace ngraph ...@@ -25,7 +26,7 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class PowerInstruction : public Instruction class PowerInstruction : public Instruction
...@@ -42,9 +43,13 @@ namespace ngraph ...@@ -42,9 +43,13 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET>(call_frame, m_out) = typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
EigenArray1d<ET>(call_frame, m_arg0) typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
.pow(EigenArray1d<ET>(call_frame, m_arg1)); typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::power<typename ET::type>(arg0, arg1, out, count);
} }
protected: protected:
......
...@@ -23,7 +23,7 @@ namespace ngraph ...@@ -23,7 +23,7 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
class ReturnInstruction : public Instruction class ReturnInstruction : public Instruction
{ {
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/select.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -25,16 +26,16 @@ namespace ngraph ...@@ -25,16 +26,16 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class SelectInstruction : public Instruction class SelectInstruction : public Instruction
{ {
public: public:
SelectInstruction(TensorViewInfo arg0, SelectInstruction(const TensorViewInfo& arg0,
TensorViewInfo arg1, const TensorViewInfo& arg1,
TensorViewInfo arg2, const TensorViewInfo& arg2,
TensorViewInfo out) const TensorViewInfo& out)
: m_arg0(arg0) : m_arg0(arg0)
, m_arg1(arg1) , m_arg1(arg1)
, m_arg2(arg2) , m_arg2(arg2)
...@@ -44,10 +45,15 @@ namespace ngraph ...@@ -44,10 +45,15 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET>(call_frame, m_out) = char* arg0 = get_tensor_data_ptr<element::Bool>(
EigenArray1d<element::Bool>(call_frame, m_arg0) call_frame, m_arg0); // FIXME: temporarily char not bool
.select(EigenArray1d<ET>(call_frame, m_arg1), typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
EigenArray1d<ET>(call_frame, m_arg2)); typename ET::type* arg2 = get_tensor_data_ptr<ET>(call_frame, m_arg2);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::select<typename ET::type>(arg0, arg1, arg2, out, count);
} }
protected: protected:
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/sign.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -25,13 +26,13 @@ namespace ngraph ...@@ -25,13 +26,13 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class SignInstruction : public Instruction class SignInstruction : public Instruction
{ {
public: public:
SignInstruction(TensorViewInfo arg, TensorViewInfo out) SignInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg) : m_arg(arg)
, m_out(out) , m_out(out)
{ {
...@@ -39,8 +40,12 @@ namespace ngraph ...@@ -39,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET, fmt::V>(call_frame, m_out) = typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
EigenArray1d<ET, fmt::V>(call_frame, m_arg).sign(); typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::sign<typename ET::type>(arg, out, count);
} }
protected: protected:
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/sin.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -25,13 +26,13 @@ namespace ngraph ...@@ -25,13 +26,13 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class SinInstruction : public Instruction class SinInstruction : public Instruction
{ {
public: public:
SinInstruction(TensorViewInfo arg, TensorViewInfo out) SinInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg) : m_arg(arg)
, m_out(out) , m_out(out)
{ {
...@@ -39,8 +40,12 @@ namespace ngraph ...@@ -39,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET, fmt::V>(call_frame, m_out) = typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
EigenArray1d<ET, fmt::V>(call_frame, m_arg).sin(); typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::sin<typename ET::type>(arg, out, count);
} }
protected: protected:
......
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
...@@ -15,9 +14,10 @@ ...@@ -15,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/sinh.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -26,13 +26,13 @@ namespace ngraph ...@@ -26,13 +26,13 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class SinhInstruction : public Instruction class SinhInstruction : public Instruction
{ {
public: public:
SinhInstruction(TensorViewInfo arg, TensorViewInfo out) SinhInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg) : m_arg(arg)
, m_out(out) , m_out(out)
{ {
...@@ -40,8 +40,12 @@ namespace ngraph ...@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET, fmt::V>(call_frame, m_out) = typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
EigenArray1d<ET, fmt::V>(call_frame, m_arg).sinh(); typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::sinh<typename ET::type>(arg, out, count);
} }
protected: protected:
......
...@@ -14,11 +14,11 @@ ...@@ -14,11 +14,11 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/sqrt.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class SqrtInstruction : public Instruction class SqrtInstruction : public Instruction
...@@ -40,8 +40,12 @@ namespace ngraph ...@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET>(call_frame, m_out) = typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
Eigen::sqrt(EigenArray1d<ET>(call_frame, m_arg)); typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::sqrt<typename ET::type>(arg, out, count);
} }
protected: protected:
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/subtract.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -25,15 +26,15 @@ namespace ngraph ...@@ -25,15 +26,15 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class SubtractInstruction : public Instruction class SubtractInstruction : public Instruction
{ {
public: public:
SubtractInstruction(TensorViewInfo arg0, SubtractInstruction(const TensorViewInfo& arg0,
TensorViewInfo arg1, const TensorViewInfo& arg1,
TensorViewInfo out) const TensorViewInfo& out)
: m_arg0(arg0) : m_arg0(arg0)
, m_arg1(arg1) , m_arg1(arg1)
, m_out(out) , m_out(out)
...@@ -42,8 +43,13 @@ namespace ngraph ...@@ -42,8 +43,13 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_arg0) - typename ET::type* arg0 = get_tensor_data_ptr<ET>(call_frame, m_arg0);
EigenArray1d<ET>(call_frame, m_arg1); typename ET::type* arg1 = get_tensor_data_ptr<ET>(call_frame, m_arg1);
typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg0);
kernel::subtract<typename ET::type>(arg0, arg1, out, count);
} }
protected: protected:
......
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
...@@ -15,9 +14,10 @@ ...@@ -15,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/tan.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -26,13 +26,13 @@ namespace ngraph ...@@ -26,13 +26,13 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class TanInstruction : public Instruction class TanInstruction : public Instruction
{ {
public: public:
TanInstruction(TensorViewInfo arg, TensorViewInfo out) TanInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg) : m_arg(arg)
, m_out(out) , m_out(out)
{ {
...@@ -40,8 +40,12 @@ namespace ngraph ...@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET, fmt::V>(call_frame, m_out) = typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
EigenArray1d<ET, fmt::V>(call_frame, m_arg).tan(); typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::tan<typename ET::type>(arg, out, count);
} }
protected: protected:
......
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc. // Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
...@@ -15,9 +14,10 @@ ...@@ -15,9 +14,10 @@
#pragma once #pragma once
#include "ngraph/runtime/kernel/tanh.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/ngvm/utils.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
...@@ -26,13 +26,13 @@ namespace ngraph ...@@ -26,13 +26,13 @@ namespace ngraph
{ {
namespace ngvm namespace ngvm
{ {
namespace eigen namespace instruction
{ {
template <typename ET> template <typename ET>
class TanhInstruction : public Instruction class TanhInstruction : public Instruction
{ {
public: public:
TanhInstruction(TensorViewInfo arg, TensorViewInfo out) TanhInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg) : m_arg(arg)
, m_out(out) , m_out(out)
{ {
...@@ -40,8 +40,12 @@ namespace ngraph ...@@ -40,8 +40,12 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET, fmt::V>(call_frame, m_out) = typename ET::type* arg = get_tensor_data_ptr<ET>(call_frame, m_arg);
EigenArray1d<ET, fmt::V>(call_frame, m_arg).tanh(); typename ET::type* out = get_tensor_data_ptr<ET>(call_frame, m_out);
size_t count = get_tensor_element_count(call_frame, m_arg);
kernel::tanh<typename ET::type>(arg, out, count);
} }
protected: protected:
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <Eigen/Dense>
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph
{
namespace runtime
{
class TensorViewInfo;
namespace ngvm
{
class CallFrame;
template <typename ET>
typename ET::type* get_tensor_data_ptr(CallFrame& call_frame,
const TensorViewInfo& tensor_view_info)
{
return call_frame.get_tensor_view_data<ET>(tensor_view_info.get_index());
}
size_t get_tensor_element_count(CallFrame& call_frame,
const TensorViewInfo& tensor_view_info)
{
return tensor_view_info
.get_layout<ngraph::descriptor::layout::DenseTensorViewLayout>()
->get_size();
}
}
}
}
...@@ -23,9 +23,10 @@ include_directories( ...@@ -23,9 +23,10 @@ include_directories(
set (SRC set (SRC
autodiff.cpp autodiff.cpp
build_graph.cpp
builder.cpp builder.cpp
builder_autobroadcast.cpp builder_autobroadcast.cpp
build_graph.cpp
coordinate_iterator.cpp
copy.cpp copy.cpp
eigen.cpp eigen.cpp
element_type.cpp element_type.cpp
......
...@@ -1513,6 +1513,78 @@ TEST(${BACKEND_NAME}, broadcast_vector_rowwise_int64) ...@@ -1513,6 +1513,78 @@ TEST(${BACKEND_NAME}, broadcast_vector_rowwise_int64)
result->get_vector<element::Int64::type>()); result->get_vector<element::Int64::type>());
} }
TEST(DISABLED_${BACKEND_NAME}, broadcast_matrix_0)
{
auto shape_a = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
auto shape_r = Shape{2, 2, 2};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto f = make_shared<Function>(
make_shared<op::Broadcast>(A, shape_r, AxisSet{0}), rt, op::Parameters{A});
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = backend->make_primary_tensor_view(element::Float32::element_type(), shape_a);
copy_data(a, vector<element::Float32::type>{1, 2, 3, 4});
auto result = backend->make_primary_tensor_view(element::Float32::element_type(), shape_r);
cf->call({a}, {result});
ASSERT_EQ((vector<element::Float32::type>{1, 2, 3, 4, 1, 2, 3, 4}),
result->get_vector<element::Float32::type>());
}
TEST(DISABLED_${BACKEND_NAME}, broadcast_matrix_1)
{
auto shape_a = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
auto shape_r = Shape{2, 2, 2};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto f = make_shared<Function>(
make_shared<op::Broadcast>(A, shape_r, AxisSet{1}), rt, op::Parameters{A});
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = backend->make_primary_tensor_view(element::Float32::element_type(), shape_a);
copy_data(a, vector<element::Float32::type>{1, 2, 3, 4});
auto result = backend->make_primary_tensor_view(element::Float32::element_type(), shape_r);
cf->call({a}, {result});
ASSERT_EQ((vector<element::Float32::type>{1, 2, 1, 2, 3, 4, 3, 4}),
result->get_vector<element::Float32::type>());
}
TEST(DISABLED_${BACKEND_NAME}, broadcast_matrix_2)
{
auto shape_a = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
auto shape_r = Shape{2, 2, 2};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto f = make_shared<Function>(
make_shared<op::Broadcast>(A, shape_r, AxisSet{2}), rt, op::Parameters{A});
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = backend->make_primary_tensor_view(element::Float32::element_type(), shape_a);
copy_data(a, vector<element::Float32::type>{1, 2, 3, 4});
auto result = backend->make_primary_tensor_view(element::Float32::element_type(), shape_r);
cf->call({a}, {result});
ASSERT_EQ((vector<element::Float32::type>{1, 1, 2, 2, 3, 3, 4, 4}),
result->get_vector<element::Float32::type>());
}
TEST(${BACKEND_NAME}, convert_int32_float32) TEST(${BACKEND_NAME}, convert_int32_float32)
{ {
auto shape = Shape{2, 2}; auto shape = Shape{2, 2};
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include <memory>
using namespace std;
using namespace ngraph;
TEST(coordinate_iterator, construct)
{
Shape space_shape{2, 3, 5, 6};
Strides strides{1, 1, 1, 1};
Coordinate window_outer_corner{2, 3, 5, 6};
Coordinate window_inner_corner{0, 0, 0, 0};
auto ci = CoordinateIterator(space_shape, strides, window_outer_corner, window_inner_corner);
}
TEST(coordinate_iterator, construct_defaults)
{
Shape space_shape{2, 3, 5, 6};
Strides strides{2, 2, 2, 1};
auto ci = CoordinateIterator(space_shape, strides);
}
TEST(coordinate_iterator, construct_defaults_stride)
{
Shape space_shape{2, 3, 5, 6};
auto ci = CoordinateIterator(space_shape);
}
TEST(coordinate_iterator, construct_bad_outer_oob)
{
Shape space_shape{2, 3, 5, 6};
Strides strides{1, 1, 1, 1};
Coordinate window_outer_corner{2, 4, 5, 6};
Coordinate window_inner_corner{0, 0, 0, 0};
EXPECT_ANY_THROW({
auto ci =
CoordinateIterator(space_shape, strides, window_outer_corner, window_inner_corner);
});
}
TEST(coordinate_iterator, construct_bad_inner_oob)
{
Shape space_shape{2, 3, 5, 6};
Strides strides{1, 1, 1, 1};
Coordinate window_outer_corner{2, 3, 5, 6};
Coordinate window_inner_corner{0, 3, 0, 0};
EXPECT_ANY_THROW({
auto ci =
CoordinateIterator(space_shape, strides, window_outer_corner, window_inner_corner);
});
}
TEST(coordinate_iterator, construct_bad_inner_outside_outer)
{
Shape space_shape{2, 3, 5, 6};
Strides strides{1, 1, 1, 1};
Coordinate window_outer_corner{2, 1, 5, 6};
Coordinate window_inner_corner{0, 2, 0, 0};
EXPECT_ANY_THROW({
auto ci =
CoordinateIterator(space_shape, strides, window_outer_corner, window_inner_corner);
});
}
TEST(coordinate_iterator, construct_bad_zero_stride)
{
Shape space_shape{2, 3, 5, 6};
Strides strides{1, 0, 1, 1};
Coordinate window_outer_corner{2, 3, 5, 6};
Coordinate window_inner_corner{0, 0, 0, 0};
EXPECT_ANY_THROW({
auto ci =
CoordinateIterator(space_shape, strides, window_outer_corner, window_inner_corner);
});
}
TEST(coordinate_iterator, cover_count_defaults)
{
Shape space_shape{2, 3, 5, 6};
auto ci = CoordinateIterator(space_shape);
size_t count = 0;
size_t expected_index = 0;
do
{
count++;
EXPECT_EQ(ci.get_current_index(), expected_index);
expected_index++;
} while (ci.increment());
EXPECT_EQ(count, 2 * 3 * 5 * 6);
}
TEST(coordinate_iterator, cover_count_stride_2)
{
Shape space_shape{2, 3, 5, 6};
Strides strides{1, 1, 1, 2};
auto ci = CoordinateIterator(space_shape, strides);
size_t count = 0;
size_t expected_index = 0;
do
{
count++;
EXPECT_EQ(ci.get_current_index(), expected_index);
expected_index += 2;
} while (ci.increment());
EXPECT_EQ(count, 2 * 3 * 5 * 6 / 2);
}
#define CEIL_DIV(x, y) (1 + (((x)-1) / (y)))
TEST(coordinate_iterator, cover_count_stride_uneven)
{
Shape space_shape{2, 3, 5, 6};
Strides strides{1, 2, 2, 3};
auto ci = CoordinateIterator(space_shape, strides);
size_t count = 0;
do
{
count++;
} while (ci.increment());
EXPECT_EQ(count, CEIL_DIV(2, 1) * CEIL_DIV(3, 2) * CEIL_DIV(5, 2) * CEIL_DIV(6, 3));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment