Commit 7d2d0850 authored by Scott Cyphers's avatar Scott Cyphers

Add convert, update op list.

parent 070b958e
......@@ -20,6 +20,7 @@ set (SRC
log.cpp
ops/broadcast.cpp
ops/concatenate.cpp
ops/convert.cpp
ops/constant.cpp
ops/dot.cpp
ops/function.cpp
......
......@@ -26,6 +26,7 @@
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convert.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/tuple.hpp"
......
......@@ -26,7 +26,6 @@ namespace ngraph
{
Node::ptr abs(const Node::ptr& arg);
Node::ptr add(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr candidate();
Node::ptr ceiling(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr convert();
//Node::ptr convolution();
......@@ -34,11 +33,13 @@ namespace ngraph
Node::ptr equal(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr exponential(const Node::ptr& arg0);
Node::ptr floor(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr get();
//Node::ptr get_tuple_element();
Node::ptr greater(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr greater_equal(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr less(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr less_equal(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr log(const Node::ptr& arg0);
//Node::ptr logical();
//Node::ptr logical(); and, or, not
Node::ptr maximum(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr minimum(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr multiply(const Node::ptr& arg0, const Node::ptr& arg1);
......@@ -46,11 +47,13 @@ namespace ngraph
//Node::ptr pad();
Node::ptr power(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr reduce();
// Node::ptr reduce_window();
Node::ptr remainder(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr reshape(const Node::ptr& arg0, const Shape& shape);
//Node::ptr reverse();
//Node::ptr rng();
//Node::ptr select();
//Node::ptr select_scatter();
//Node::ptr slice();
Node::ptr subtract(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr transpose();
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
namespace ngraph
{
class ConvertOp : public BuiltinOp
{
public:
ConvertOp(const Node::ptr& arg, const ngraph::element::Type& element_type)
: BuiltinOp({arg})
, m_element_type(element_type)
{
}
virtual std::string op_name() const override { return "convert"; }
virtual void propagate_types() override;
protected:
const ngraph::element::Type& m_element_type;
};
namespace op
{
std::shared_ptr<ngraph::ConvertOp> convert(const Node::ptr& arg, const ngraph::element::Type& element_type);
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
void ConvertOp::propagate_types()
{
throw ngraph_error("NIY");
}
shared_ptr<ConvertOp> op::convert(const Node::ptr& arg, const element::Type& element_type)
{
return make_shared<ConvertOp>(arg, element_type);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment