Commit d37359aa authored by Scott Cyphers's avatar Scott Cyphers

Only worry about call graph for now.

parent 28f13818
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <algorithm>
......@@ -6,90 +20,7 @@
#include "values/type.hpp"
namespace ngraph {
class ValueDescriptor
{
public:
using ptr_t = std::shared_ptr<ValueDescriptor>;
virtual ValueType::ptr_t value_type() const = 0;
};
class TensorDescriptor
{
public:
using ptr_t = std::shared_ptr<TensorDescriptor>;
TensorDescriptor(const ElementType& element_type)
: m_element_type(element_type)
{}
protected:
const ElementType& m_element_type;
};
class TensorLayoutDescriptor
{
public:
using ptr_t = std::shared_ptr<TensorLayoutDescriptor>;
};
class TensorViewDescriptor : public ValueDescriptor
namespace ngraph
{
public:
using ptr_t = std::shared_ptr<TensorViewDescriptor>;
TensorViewDescriptor(const TensorViewType::ptr_t& type)
: m_type(type)
{}
TensorViewDescriptor(const ElementType& element_type, const Shape& shape)
: TensorViewDescriptor(TensorViewType::make(element_type, shape))
{}
static ptr_t make(const TensorViewType::ptr_t& type){
return ptr_t::make_shared(type);
}
static ptr_t make(const ElementType& element_type, const Shape& shape){
return ptr_t::make_shared(element_type, shape);
}
ValueType::ptr_t value_type() const override {
return m_type;
}
protected:
TensorViewType::ptr_t m_type;
TensorDescriptor::ptr_t m_tensor_descriptor;
TensorLayoutDescriptor::ptr_t m_tensor_layout_descriptor;
};
class TupleDescriptor : public ValueDescriptor
{
public:
using ptr_t = std::shared_ptr<TupleDescriptor>;
TupleDescriptor(const std::vector<ValueDescriptor::ptr_t>& elements)
: m_element_descriptors(elements)
{
std::vector<ValueType::ptr_t> types;
for(auto elt : elements){
types.push_back(elt->value_type());
}
m_type = TupleType::make(types);
}
static ptr_t make(const std::vector<ValueDescriptor::ptr_t>& elements){
return ptr_t::make_shared(elements);
}
ValueType::ptr_t value_type() const override {
return m_type;
}
protected:
TupleType::ptr_t m_type;
std::vector<ValueDescriptor::ptr_t> m_element_descriptors;
};
} // End of NGRAPH
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "values/function.hpp"
using namespace std;
using namespace ngraph;
Parameter::ptr_t Parameter::make(Function& function, size_t index, const ValueType::ptr_t& output_type){
return ptr_t::make_shared(function, index, output_type);
}
Function::ptr_t Function::make(const ValueType::ptr_t& return_type, const std::vector<ValueType::ptr_t>& argument_types){
return ptr_t::make_shared(return_type, argument_types);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "values/descriptor.hpp"
#include "values/node.hpp"
#include "values/op.hpp"
#include "values/type.hpp"
namespace ngraph {
class Function;
class Parameter : public Op
namespace ngraph
{
public:
using ptr_t = std::shared_ptr<Parameter>;
static ptr_t make(Function& function, size_t index, const ValueType::ptr_t& output_type);
class Function;
Parameter(Function& function, size_t index, const ValueType::ptr_t& output_type)
: Op({}, output_type)
class Parameter : public Node
{
public:
Parameter(Function& function, size_t index, const std::shared_ptr<ValueType>& type)
: Node(type)
, m_function(function)
, m_index(index)
{}
protected:
protected:
Function& m_function;
size_t m_index;
};
};
class Function
{
public:
using ptr_t = std::shared_ptr<Function>;
class Result {
public:
void type(const std::shared_ptr<ValueType>& t){
m_type = t;
}
Function(const ValueType::ptr_t& return_type,
const std::vector<ValueType::ptr_t>& argument_types)
: m_return_type(return_type)
, m_argument_types(argument_types)
{
size_t i = 0;
for (auto argument_type : argument_types){
m_parameters.push_back(Parameter::make(*this, i++, argument_type));
void type(const ElementType& element_type, const Shape& shape){
m_type = std::make_shared<TensorViewType>(element_type, shape);
}
std::shared_ptr<ValueType> type() const {
return m_type;
}
protected:
std::shared_ptr<ValueType> m_type;
};
static ptr_t make(const ValueType::ptr_t& return_type,
const std::vector<ValueType::ptr_t>& argument_types);
class Function
{
public:
Function(size_t n_parameters)
: m_parameters(n_parameters)
{}
Parameter::ptr_t parameter(size_t i){
return m_parameters[i];
Result *result(){
return &m_result;
}
protected:
std::vector<Parameter::ptr_t> m_parameters;
std::vector<std::shared_ptr<ValueType>> m_argument_types;
std::shared_ptr<ValueType> m_return_type;
};
std::shared_ptr<Parameter> parameter(size_t i){
return m_parameters[i];
}
protected:
std::vector<std::shared_ptr<Parameter>> m_parameters;
Result m_result;
};
} // end namespace ngraph
\ No newline at end of file
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <vector>
#include "values/type.hpp"
namespace ngraph
{
class Node
{
public:
Node(std::shared_ptr<ValueType> type=0)
: m_type(type)
{}
virtual ~Node(){}
virtual std::vector<std::shared_ptr<Node>> dependents() {
return m_parameters;
}
void type(const std::shared_ptr<ValueType>& t){
m_type = t;
}
void type(const ElementType& element_type, const Shape& shape){
m_type = std::make_shared<TensorViewType>(element_type, shape);
}
std::shared_ptr<ValueType> type() const {
return m_type;
}
protected:
std::vector<std::shared_ptr<Node>> m_parameters;
std::shared_ptr<ValueType> m_type;
};
}
\ No newline at end of file
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
......@@ -5,49 +19,13 @@
#include "values/descriptor.hpp"
#include "values/type.hpp"
namespace ngraph {
class Op
{
public:
using ptr_t = std::shared_ptr<Op>;
protected:
Op(const std::vector<ptr_t>& inputs, const ValueType::ptr_t output_type)
: m_inputs(inputs)
, m_output_type(output_type)
{}
std::vector<ptr_t> m_inputs;
ValueType::ptr_t m_output_type;
};
class Broadcast : public Op
namespace ngraph
{
public:
using ptr_t = std::shared_ptr<Broadcast>;
Broadcast(const Op::ptr_t& x, std::vector<size_t> dims)
: Op({x}, 0)
, m_dims(dims)
{}
public:
static ptr_t make(const Op::ptr_t& x, std::vector<size_t> dims){
return ptr_t::make_shared(x, dims);
}
protected:
std::vector<size_t> m_dims;
};
class Tuple : public Op
{
public:
Tuple(const std::vector<ptr_t>& inputs)
: Op(inputs, 0)
class Call : public Node
{
}
};
protected:
std::vector<std::shared_ptr<Node>> m_args;
};
} // end of namespace ngraph
\ No newline at end of file
}
\ No newline at end of file
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
......@@ -7,67 +21,47 @@
namespace ngraph {
class TensorViewDescriptor;
class TupleDescriptor;
using value_size_t = size_t;
class Shape
{
public:
Shape(const std::initializer_list<value_size_t>& sizes)
class Shape
{
public:
Shape(const std::initializer_list<size_t>& sizes)
: m_sizes(sizes)
{}
protected:
std::vector<value_size_t> m_sizes;
};
// Base type for ngraph values
class ValueType
{
public:
using ptr_t = std::shared_ptr<ValueType>;
};
protected:
std::vector<size_t> m_sizes;
};
class TensorViewType : public ValueType
{
public:
using ptr_t = std::shared_ptr<TensorViewType>;
using descriptor_t = TensorViewDescriptor;
// ValueType is
// TensorViewType
// | TupleType(ValueType[])
class ValueType
{
};
class TensorViewType : public ValueType
{
public:
TensorViewType(const ElementType& element_type, const Shape& shape)
: m_element_type(element_type)
, m_shape(shape)
{}
static ptr_t make(const ElementType& element_type, const Shape& shape){
return ptr_t::make_shared(element_type, shape);
}
protected:
protected:
TensorViewType(const TensorViewType&) = delete;
const ElementType& m_element_type;
Shape m_shape;
};
};
class TupleType : public ValueType
{
public:
using ptr_t = std::shared_ptr<TupleType>;
using descriptor_t = TupleDescriptor;
class TupleType : public ValueType
{
public:
TupleType(const std::vector<ValueType::ptr_t>& element_types)
TupleType(const std::vector<std::shared_ptr<ValueType>>& element_types)
: m_element_types(element_types)
{}
static ptr_t make(const std::vector<ValueType::ptr_t>& element_types){
return ptr_t::make_shared(element_types);
}
protected:
// Is this name too similar to TensorViewType.to m_element_type?
std::vector<ValueType::ptr_t> m_element_types;
};
} // End of ngraph
\ No newline at end of file
protected:
std::vector<std::shared_ptr<ValueType>> m_element_types;
};
}
\ No newline at end of file
#include "values/descriptor.hpp"
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "values/type.hpp"
#include "values/function.hpp"
using namespace std;
......@@ -6,15 +20,20 @@ using namespace ngraph;
void build_simple_graph()
{
auto cluster_0 = Function::make(
TensorViewType::make(element_type_float, Shape({32, 3})),
{TensorViewType::make(element_type_float, Shape({7, 3})),
TensorViewType::make(element_type_float, Shape({3})),
TensorViewType::make(element_type_float, Shape({32, 7})),
TensorViewType::make(element_type_float, Shape({32, 7}))
});
// Function with 4 parameters
auto cluster_0 = make_shared<Function>(4);
cluster_0->result()->type(element_type_float, Shape {32, 3});
cluster_0->parameter(0)->type(element_type_float, Shape {Shape {7, 3}});
cluster_0->parameter(1)->type(element_type_float, Shape {Shape {3}});
cluster_0->parameter(2)->type(element_type_float, Shape {Shape {32, 7}});
cluster_0->parameter(3)->type(element_type_float, Shape {Shape {32, 7}});
auto arg3 = cluster_0->parameter(3);
auto broadcast_1 = Broadcast::make(arg3, {1});
// call broadcast op on arg3, broadcasting on axis 1.
//auto broadcast_1 = op::broadcast(arg3, 1);
auto arg2 = cluster_0->parameter(2);
auto arg0 = cluster_0->parameter(0);
// call dot op
//auto dot = op::dot(arg2, arg0);
// Function returns tuple of dot and broadcast_1.
//cluster_0.result->value(op::tuple(dot, broadcast_1));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment