Commit b8de3b7d authored by Chris Sullivan's avatar Chris Sullivan Committed by Scott Cyphers

Revert changes to gpu shape and update (#1354)

* GPUShape(int32_t) -> NVShape(uint32_2), NVDiff(int32_t)

* Update code merged from master.

* Add nvshape.hpp and nvdiff.hpp.
parent e5e8d03c
This diff is collapsed.
This diff is collapsed.
...@@ -475,7 +475,7 @@ namespace ngraph ...@@ -475,7 +475,7 @@ namespace ngraph
auto axis = concat->get_concatenation_axis(); auto axis = concat->get_concatenation_axis();
std::vector<std::string> dtypes; std::vector<std::string> dtypes;
std::vector<GPUShape> input_shapes; std::vector<NVShape> input_shapes;
for (auto arg : args) for (auto arg : args)
{ {
dtypes.push_back(arg.get_type()); dtypes.push_back(arg.get_type());
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <cstdio>
#include <stdexcept>
#include <vector>
#include "ngraph/coordinate_diff.hpp"
namespace ngraph
{
class Shape;
/// \brief Shape for a tensor resident on GPU.
class NVDiff : public std::vector<int32_t>
{
public:
NVDiff(const std::initializer_list<int32_t>& axis_lengths)
: std::vector<int32_t>(axis_lengths)
{
}
NVDiff(const std::vector<int32_t>& axis_lengths)
: std::vector<int32_t>(axis_lengths)
{
}
NVDiff(const NVDiff& axis_lengths)
: std::vector<int32_t>(axis_lengths)
{
}
explicit NVDiff(size_t n, int32_t initial_value = 0)
: std::vector<int32_t>(n, initial_value)
{
}
template <class InputIterator>
NVDiff(InputIterator first, InputIterator last)
: std::vector<int32_t>(first, last)
{
}
NVDiff() {}
NVDiff& operator=(const NVDiff& v)
{
static_cast<std::vector<int32_t>*>(this)->operator=(v);
return *this;
}
NVDiff& operator=(NVDiff&& v)
{
static_cast<std::vector<int32_t>*>(this)->operator=(v);
return *this;
}
NVDiff(const CoordinateDiff& coord)
{
for (auto const& dim : coord)
{
if (std::abs(dim) >> 32 != 0)
{
throw std::runtime_error(
"Request for CoordinateDiff which exceed the bitwidth available for "
"NVDiffs (32)");
}
this->push_back(static_cast<int32_t>(dim));
}
}
};
}
...@@ -31,129 +31,114 @@ namespace ngraph ...@@ -31,129 +31,114 @@ namespace ngraph
{ {
class Shape; class Shape;
/// \brief Shape for a tensor resident on GPU. /// \brief Shape for a tensor resident on GPU.
class GPUShape : public std::vector<int32_t> class NVShape : public std::vector<uint32_t>
{ {
public: public:
GPUShape(const std::initializer_list<int32_t>& axis_lengths) NVShape(const std::initializer_list<uint32_t>& axis_lengths)
: std::vector<int32_t>(axis_lengths) : std::vector<uint32_t>(axis_lengths)
{ {
} }
GPUShape(const std::vector<int32_t>& axis_lengths) NVShape(const std::vector<uint32_t>& axis_lengths)
: std::vector<int32_t>(axis_lengths) : std::vector<uint32_t>(axis_lengths)
{ {
} }
GPUShape(const GPUShape& axis_lengths) NVShape(const NVShape& axis_lengths)
: std::vector<int32_t>(axis_lengths) : std::vector<uint32_t>(axis_lengths)
{ {
} }
explicit GPUShape(size_t n, int32_t initial_value = 0) explicit NVShape(size_t n, uint32_t initial_value = 0)
: std::vector<int32_t>(n, initial_value) : std::vector<uint32_t>(n, initial_value)
{ {
} }
template <class InputIterator> template <class InputIterator>
GPUShape(InputIterator first, InputIterator last) NVShape(InputIterator first, InputIterator last)
: std::vector<int32_t>(first, last) : std::vector<uint32_t>(first, last)
{ {
} }
GPUShape() {} NVShape() {}
GPUShape& operator=(const GPUShape& v) NVShape& operator=(const NVShape& v)
{ {
static_cast<std::vector<int32_t>*>(this)->operator=(v); static_cast<std::vector<uint32_t>*>(this)->operator=(v);
return *this; return *this;
} }
GPUShape& operator=(GPUShape&& v) NVShape& operator=(NVShape&& v)
{ {
static_cast<std::vector<int32_t>*>(this)->operator=(v); static_cast<std::vector<uint32_t>*>(this)->operator=(v);
return *this; return *this;
} }
GPUShape(const std::vector<size_t>& vec) NVShape(const std::vector<size_t>& vec)
{ {
for (size_t const& size : vec) for (size_t const& size : vec)
{ {
if (size >> 32 != 0) if (size >> 32 != 0)
{ {
throw std::runtime_error( throw std::runtime_error(
"Request exceeds the bitwidth available for GPUShapes (32)"); "Request exceeds the bitwidth available for NVShapes (32)");
} }
this->push_back(static_cast<int32_t>(size)); this->push_back(static_cast<uint32_t>(size));
} }
} }
GPUShape(const Shape& shape) NVShape(const Shape& shape)
{ {
for (size_t const& size : shape) for (size_t const& size : shape)
{ {
if (size >> 32 != 0) if (size >> 32 != 0)
{ {
throw std::runtime_error( throw std::runtime_error(
"Request for Shape which exceeds the bitwidth available for GPUShapes " "Request for Shape which exceeds the bitwidth available for NVShapes "
"(32)"); "(32)");
} }
this->push_back(static_cast<int32_t>(size)); this->push_back(static_cast<uint32_t>(size));
} }
} }
GPUShape(const Strides& strides) NVShape(const Strides& strides)
{ {
for (size_t const& size : strides) for (size_t const& size : strides)
{ {
if (size >> 32 != 0) if (size >> 32 != 0)
{ {
throw std::runtime_error( throw std::runtime_error(
"Request for Strides which exceed the bitwidth available for GPUShapes " "Request for Strides which exceed the bitwidth available for NVShapes "
"(32)"); "(32)");
} }
this->push_back(static_cast<int32_t>(size)); this->push_back(static_cast<uint32_t>(size));
} }
} }
GPUShape(const Coordinate& coord) NVShape(const Coordinate& coord)
{ {
for (size_t const& size : coord) for (size_t const& size : coord)
{ {
if (size >> 32 != 0) if (size >> 32 != 0)
{ {
throw std::runtime_error( throw std::runtime_error(
"Request for Coordinate which exceed the bitwidth available for GPUShapes " "Request for Coordinate which exceed the bitwidth available for NVShapes "
"(32)"); "(32)");
} }
this->push_back(static_cast<int32_t>(size)); this->push_back(static_cast<uint32_t>(size));
} }
} }
GPUShape(const CoordinateDiff& coord) NVShape(const AxisVector& vec)
{
for (auto const& dim : coord)
{
if (dim > 0 && dim >> 32 != 0)
{
throw std::runtime_error(
"Request for CoordinateDiff which exceed the bitwidth available for "
"GPUShapes "
"(32)");
}
this->push_back(static_cast<int32_t>(dim));
}
}
GPUShape(const AxisVector& vec)
{ {
for (auto const& size : vec) for (auto const& size : vec)
{ {
if (size >> 32 != 0) if (size >> 32 != 0)
{ {
throw std::runtime_error( throw std::runtime_error(
"Request for axis vector which exceed the bitwidth available for GPUShapes " "Request for axis vector which exceed the bitwidth available for NVShapes "
"(32)"); "(32)");
} }
this->push_back(static_cast<int32_t>(size)); this->push_back(static_cast<uint32_t>(size));
} }
} }
}; };
......
...@@ -20,15 +20,15 @@ ...@@ -20,15 +20,15 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp" #include "ngraph/runtime/gpu/gpu_primitive_emitter.hpp"
#include "ngraph/runtime/gpu/gpu_shape.hpp"
#include "ngraph/runtime/gpu/gpu_util.hpp" #include "ngraph/runtime/gpu/gpu_util.hpp"
#include "ngraph/runtime/gpu/nvshape.hpp"
using namespace ngraph; using namespace ngraph;
TEST(gpu_test, gpu_shape_from_64bit_shape) TEST(gpu_test, gpu_shape_from_64bit_shape)
{ {
Shape shape{1UL << 33}; Shape shape{1UL << 33};
ASSERT_ANY_THROW([](GPUShape s) {}(shape);); ASSERT_ANY_THROW([](NVShape s) {}(shape););
} }
TEST(gpu_test, memory_manager_unallocated) TEST(gpu_test, memory_manager_unallocated)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment