Commit 9d9f42bc authored by Robert Kimball's avatar Robert Kimball

update to latest ngraph API

parent bdee1ec5
...@@ -15,10 +15,10 @@ ...@@ -15,10 +15,10 @@
# ****************************************************************************** # ******************************************************************************
if (NGRAPH_GENERIC_CPU_ENABLE) if (NGRAPH_GENERIC_CPU_ENABLE)
find_package(OpenMP) # find_package(OpenMP)
if (OPENMP_FOUND) # if (OPENMP_FOUND)
add_compile_options(${OpenMP_CXX_FLAGS}) # add_compile_options(${OpenMP_CXX_FLAGS})
endif() # endif()
add_library(gcpu_backend SHARED gcpu_backend.cpp gcpu_executable.cpp node_wrapper.cpp) add_library(gcpu_backend SHARED gcpu_backend.cpp gcpu_executable.cpp node_wrapper.cpp)
if(NGRAPH_LIB_VERSIONING_ENABLE) if(NGRAPH_LIB_VERSIONING_ENABLE)
set_target_properties(gcpu_backend PROPERTIES set_target_properties(gcpu_backend PROPERTIES
......
...@@ -72,7 +72,6 @@ ...@@ -72,7 +72,6 @@
#include "ngraph/runtime/generic_cpu/kernel/broadcast.hpp" #include "ngraph/runtime/generic_cpu/kernel/broadcast.hpp"
#include "ngraph/runtime/generic_cpu/kernel/dot.hpp" #include "ngraph/runtime/generic_cpu/kernel/dot.hpp"
#include "ngraph/runtime/generic_cpu/kernel/reshape.hpp" #include "ngraph/runtime/generic_cpu/kernel/reshape.hpp"
#include "ngraph/runtime/generic_cpu/kernel/result.hpp"
#include "ngraph/runtime/generic_cpu/node_wrapper.hpp" #include "ngraph/runtime/generic_cpu/node_wrapper.hpp"
#include "ngraph/runtime/host_tensor.hpp" #include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/abs.hpp" #include "ngraph/runtime/reference/abs.hpp"
...@@ -484,11 +483,11 @@ private: ...@@ -484,11 +483,11 @@ private:
Shape in_shape = node.get_input_shape(0); Shape in_shape = node.get_input_shape(0);
Shape out_shape = node.get_output_shape(0); Shape out_shape = node.get_output_shape(0);
AxisSet broadcast_axes = broadcast->get_broadcast_axes(); AxisSet broadcast_axes = broadcast->get_broadcast_axes();
reference::broadcast<T>(args[0]->get_data_ptr<const T>(), kernel::broadcast<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
in_shape, in_shape,
out_shape, out_shape,
broadcast_axes); broadcast_axes);
break; break;
} }
case OP_TYPEID::BroadcastDistributed: case OP_TYPEID::BroadcastDistributed:
...@@ -737,13 +736,13 @@ private: ...@@ -737,13 +736,13 @@ private:
{ {
const op::Dot* dot = static_cast<const op::Dot*>(&node); const op::Dot* dot = static_cast<const op::Dot*>(&node);
reference::dot(args[0]->get_data_ptr<const T>(), kernel::dot(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(), args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_input_shape(1), node.get_input_shape(1),
node.get_output_shape(0), node.get_output_shape(0),
dot->get_reduction_axes_count()); dot->get_reduction_axes_count());
break; break;
} }
case OP_TYPEID::DynReshape: case OP_TYPEID::DynReshape:
...@@ -1343,11 +1342,11 @@ private: ...@@ -1343,11 +1342,11 @@ private:
case OP_TYPEID::Reshape: case OP_TYPEID::Reshape:
{ {
const op::Reshape* reshape = static_cast<const op::Reshape*>(&node); const op::Reshape* reshape = static_cast<const op::Reshape*>(&node);
reference::reshape(args[0]->get_data_ptr<const T>(), kernel::reshape(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
reshape->get_input_order(), reshape->get_input_order(),
node.get_output_shape(0)); node.get_output_shape(0));
break; break;
} }
case OP_TYPEID::Result: case OP_TYPEID::Result:
......
...@@ -140,6 +140,91 @@ namespace ngraph ...@@ -140,6 +140,91 @@ namespace ngraph
} }
} }
template <typename T>
void broadcast_5d(const T* in,
                  T* out,
                  const Shape& in_shape,
                  const Shape& out_shape,
                  const AxisSet& broadcast_axes)
{
    // Broadcast into a rank-5 output: every output coordinate copies the
    // input element selected by the coordinate of the first axis that is
    // NOT in broadcast_axes. NOTE(review): this assumes the input is
    // effectively 1-D along that single non-broadcast axis -- confirm
    // against the dispatching broadcast() routine.
    (void)in_shape; // unused; kept so all broadcast_Nd kernels share one signature

    size_t index[5];
    // Source of the input coordinate. Defaults to a constant zero so a
    // scalar input (every axis broadcast) reads in[0] instead of
    // dereferencing a null pointer.
    const size_t zero = 0;
    const size_t* in_index = &zero;
    for (size_t i = 0; i < 5; i++)
    {
        if (broadcast_axes.count(i) == 0)
        {
            in_index = &index[i];
            break;
        }
    }
    // Row-major strides of the output shape, hoisted out of the loops.
    const size_t s3 = out_shape[4];
    const size_t s2 = out_shape[3] * s3;
    const size_t s1 = out_shape[2] * s2;
    const size_t s0 = out_shape[1] * s1;
    for (index[0] = 0; index[0] < out_shape[0]; ++index[0])
    {
        for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
        {
            for (index[2] = 0; index[2] < out_shape[2]; ++index[2])
            {
                for (index[3] = 0; index[3] < out_shape[3]; ++index[3])
                {
                    for (index[4] = 0; index[4] < out_shape[4]; ++index[4])
                    {
                        out[index[0] * s0 + index[1] * s1 + index[2] * s2 +
                            index[3] * s3 + index[4]] = in[*in_index];
                    }
                }
            }
        }
    }
}
template <typename T>
void broadcast_6d(const T* in,
                  T* out,
                  const Shape& in_shape,
                  const Shape& out_shape,
                  const AxisSet& broadcast_axes)
{
    // Broadcast into a rank-6 output: every output coordinate copies the
    // input element selected by the coordinate of the first axis that is
    // NOT in broadcast_axes. NOTE(review): this assumes the input is
    // effectively 1-D along that single non-broadcast axis -- confirm
    // against the dispatching broadcast() routine.
    (void)in_shape; // unused; kept so all broadcast_Nd kernels share one signature

    size_t index[6];
    // Source of the input coordinate. Defaults to a constant zero so a
    // scalar input (every axis broadcast) reads in[0] instead of
    // dereferencing a null pointer.
    const size_t zero = 0;
    const size_t* in_index = &zero;
    for (size_t i = 0; i < 6; i++)
    {
        if (broadcast_axes.count(i) == 0)
        {
            in_index = &index[i];
            break;
        }
    }
    // Row-major strides of the output shape, hoisted out of the loops.
    const size_t s4 = out_shape[5];
    const size_t s3 = out_shape[4] * s4;
    const size_t s2 = out_shape[3] * s3;
    const size_t s1 = out_shape[2] * s2;
    const size_t s0 = out_shape[1] * s1;
    for (index[0] = 0; index[0] < out_shape[0]; ++index[0])
    {
        for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
        {
            for (index[2] = 0; index[2] < out_shape[2]; ++index[2])
            {
                for (index[3] = 0; index[3] < out_shape[3]; ++index[3])
                {
                    for (index[4] = 0; index[4] < out_shape[4]; ++index[4])
                    {
                        for (index[5] = 0; index[5] < out_shape[5]; ++index[5])
                        {
                            out[index[0] * s0 + index[1] * s1 + index[2] * s2 +
                                index[3] * s3 + index[4] * s4 + index[5]] =
                                in[*in_index];
                        }
                    }
                }
            }
        }
    }
}
template <typename T> template <typename T>
void broadcast(const T* in, void broadcast(const T* in,
T* out, T* out,
...@@ -167,6 +252,16 @@ namespace ngraph ...@@ -167,6 +252,16 @@ namespace ngraph
case 4: case 4:
broadcast_4d<T>(in, out, in_shape, out_shape, broadcast_axes); broadcast_4d<T>(in, out, in_shape, out_shape, broadcast_axes);
break; break;
case 5:
broadcast_5d<T>(in, out, in_shape, out_shape, broadcast_axes);
break;
case 6:
broadcast_6d<T>(in, out, in_shape, out_shape, broadcast_axes);
break;
default:
runtime::reference::broadcast<T>(
in, out, in_shape, out_shape, broadcast_axes);
break;
} }
} }
else else
......
...@@ -244,10 +244,7 @@ namespace ngraph ...@@ -244,10 +244,7 @@ namespace ngraph
case 4: reshape_in4<T>(in, out, in_shape, in_axis_order, out_shape); break; case 4: reshape_in4<T>(in, out, in_shape, in_axis_order, out_shape); break;
case 5: reshape_in5<T>(in, out, in_shape, in_axis_order, out_shape); break; case 5: reshape_in5<T>(in, out, in_shape, in_axis_order, out_shape); break;
case 6: reshape_in6<T>(in, out, in_shape, in_axis_order, out_shape); break; case 6: reshape_in6<T>(in, out, in_shape, in_axis_order, out_shape); break;
default: default: reference::reshape(in, out, in_shape, in_axis_order, out_shape); break;
NGRAPH_INFO << "reference::reshape";
reference::reshape(in, out, in_shape, in_axis_order, out_shape);
break;
} }
} }
} }
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <cstddef> // size_t
#include <cstring> // std::memcpy -- previously relied on a transitive include
#include <type_traits>

namespace ngraph
{
    namespace runtime
    {
        namespace gcpu
        {
            namespace kernel
            {
                /// \brief Result kernel: copy the argument buffer to the output
                ///        buffer unchanged.
                /// \param arg   Source buffer holding `count` elements.
                /// \param out   Destination buffer with room for `count` elements.
                /// \param count Number of elements (not bytes) to copy.
                template <typename T>
                void result(const T* arg, T* out, size_t count)
                {
                    // memcpy is only defined for trivially copyable types.
                    static_assert(std::is_trivially_copyable<T>::value,
                                  "result kernel requires a trivially copyable type");
                    std::memcpy(out, arg, sizeof(T) * count);
                }
            }
        }
    }
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment