Commit 2fef5308 authored by Yixing Lao's avatar Yixing Lao Committed by GitHub

Merge pull request #50 from NervanaSystems/avijit/add_dynamic_plugin_support

Initial implementation of the plugin library
parents 8d57ce68 f65fdf82
......@@ -21,6 +21,7 @@ set (SRC
tree.cpp
util.cpp
log.cpp
ngraph.cpp
transformers/axes.cpp
transformers/exop.cpp
......@@ -40,6 +41,12 @@ include_directories(
add_library(ngraph SHARED ${SRC})
if ( APPLE )
set_property( TARGET ngraph PROPERTY PREFIX "lib" )
set_property( TARGET ngraph PROPERTY OUTPUT_NAME "ngraph.so" )
set_property( TARGET ngraph PROPERTY SUFFIX "" )
endif()
#-----------------------------------------------------------------------------------------------
# Installation logic...
#-----------------------------------------------------------------------------------------------
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph.hpp"
#include "log.hpp"
NGraph* create_ngraph_object()
{
return new NGraph();
}
void destroy_ngraph_object(NGraph* pObj)
{
delete pObj;
}
void NGraph::add_params(const std::vector<std::string>& paramList)
{
INFO << "Adding parameters";
m_params.insert(m_params.end(), paramList.begin(), paramList.end());
}
const std::vector<std::string>& NGraph::get_params() const
{
return m_params;
}
#pragma once
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <string>
#include <vector>
class NGraph
{
public:
void add_params(const std::vector<std::string>& paramList);
const std::vector<std::string>& get_params() const;
std::string get_name() const { return "NGraph Implementation Object"; }
private:
std::vector<std::string> m_params;
};
// Factory methods
extern "C" NGraph* create_ngraph_object();
extern "C" void destroy_ngraph_object(NGraph* pObj);
// FUnction pointers to the factory methods
typedef NGraph* (*CreateNGraphObjPfn)();
typedef void (*DestroyNGraphObjPfn)(NGraph*);
......@@ -31,13 +31,17 @@ set (SRC
uuid.cpp
names.cpp
strides.cpp
ngraph.cpp
)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DCURDIR=\\\"${CMAKE_CURRENT_SOURCE_DIR}\\\"")
add_executable(unit-test ${SRC})
target_link_libraries(unit-test ngraph pthread libgtest)
target_link_libraries(unit-test ${CMAKE_DL_LIBS})
add_dependencies(unit-test ngraph libgtest)
add_custom_target(check
......
......@@ -16,6 +16,7 @@
#include <iostream>
#include "gtest/gtest.h"
#include "log.hpp"
using namespace std;
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <vector>
#include <string>
#include <sstream>
#include <memory>
#include <dlfcn.h>
#include "gtest/gtest.h"
#include "ngraph.hpp"
#include "log.hpp"
using namespace std;
TEST(NGraph, loadTest)
{
// load the triangle library
void* ngraphImplLib = dlopen("../src/libngraph.so", RTLD_LAZY);
if (!ngraphImplLib)
{
std::cerr << "Cannot load library: " << dlerror() << '\n';
ASSERT_FALSE(true);
}
// reset errors
dlerror();
// Get the symbols
auto createPfn =
reinterpret_cast<CreateNGraphObjPfn>(dlsym(ngraphImplLib, "create_ngraph_object"));
ASSERT_FALSE(createPfn == nullptr);
auto destroyPfn =
reinterpret_cast<DestroyNGraphObjPfn>(dlsym(ngraphImplLib, "destroy_ngraph_object"));
ASSERT_FALSE(destroyPfn == nullptr);
NGraph* nGraphObj = createPfn();
INFO << "Call a method on the Object";
ASSERT_EQ("NGraph Implementation Object", nGraphObj->get_name());
INFO << "Object Name: " << nGraphObj->get_name();
// Add some parameters
const vector<string> TEST_PARAMS = {"param-1", "param-2", "param-3"};
nGraphObj->add_params(TEST_PARAMS);
// Get the list of params
auto& storedParams = nGraphObj->get_params();
EXPECT_EQ(TEST_PARAMS.size(), storedParams.size());
for (int i = 0; i < TEST_PARAMS.size(); i++)
{
EXPECT_EQ(TEST_PARAMS[i], storedParams[i]);
}
INFO << "Destroy the NGraph Object";
destroyPfn(nGraphObj);
dlclose(ngraphImplLib);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment