-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support numpy dense #207
Support numpy dense #207
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,9 +18,16 @@ limitations under the License. */ | |
#include <stdlib.h> | ||
#include <unordered_set> | ||
#include <list> | ||
#include <Python.h> | ||
#include <numpy/numpyconfig.h> | ||
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION | ||
#include <numpy/ndarrayobject.h> | ||
|
||
#include "DataProvider.h" | ||
|
||
#include "paddle/utils/PythonUtil.h" | ||
#include "paddle/utils/Locks.h" | ||
#include "paddle/utils/Stat.h" | ||
|
||
namespace paddle { | ||
|
||
|
@@ -202,7 +209,10 @@ class PyDataProvider2 : public DataProvider { | |
PyDataProvider2(const DataConfig& config, | ||
const ModelConfig& modelConfig, | ||
bool useGpu) | ||
:DataProvider(config, useGpu), callingContextCreated_(2) { | ||
:DataProvider(config, useGpu), | ||
callingContextCreated_(2) { | ||
if (PyArray_API == NULL) | ||
import_array(); | ||
auto& args = config.load_data_args(); | ||
PyObjectPtr kwargs = PyObjectPtr(PyDict_New()); | ||
if (!args.empty()) { | ||
|
@@ -454,6 +464,7 @@ class PyDataProvider2 : public DataProvider { | |
std::condition_variable pushCV_; | ||
std::condition_variable pullCV_; | ||
std::mutex mtx_; | ||
|
||
ThreadBarrier callingContextCreated_; | ||
std::unique_ptr<IPyDataProviderCache> cache_; | ||
|
||
|
@@ -496,8 +507,8 @@ class PyDataProvider2 : public DataProvider { | |
* Resetting the PyDataProvider. May start reading thread here. | ||
*/ | ||
virtual void reset() { | ||
DataProvider::reset(); | ||
resetImpl(true); | ||
DataProvider::reset(); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. DataProvider::reset() should be invoked at the end of reset(), because DataProvider::reset() immediately invokes getNextBatchInternal in another thread. |
||
} | ||
|
||
/** | ||
|
@@ -518,6 +529,7 @@ class PyDataProvider2 : public DataProvider { | |
* Loading a batch of data. | ||
*/ | ||
int64_t getNextBatchInternal(int64_t size_, DataBatch *batch) { | ||
REGISTER_TIMER("PyDP2.getNextBatchInternal") | ||
CHECK_GE(size_, 0); | ||
size_t size = (size_t) size_; | ||
if (loadThread_) { // loading from thread should wait for data pool ready. | ||
|
@@ -698,10 +710,22 @@ class DenseScanner: public IFieldScanner { | |
*/ | ||
virtual void fill(Argument &argument, PyObject *obj) { | ||
real* dat = argument.value->getData() + height_ * headerPtr_->dim; | ||
py::SequenceHelper s(obj); | ||
// TODO(yuyang18): Here we can use AVX or SSE to accelerate memory copy. | ||
for (size_t i=0; i < headerPtr_->dim; ++i) { | ||
dat[i] = (real) s.getDouble(i); | ||
if (PyArray_Check(obj)) { | ||
auto dtype = PyArray_DTYPE((PyArrayObject*)obj); | ||
if (dtype->type == 'f' && dtype->elsize == sizeof(real)) { | ||
real * data = (real*)PyArray_DATA((PyArrayObject*)obj); | ||
auto sz = PyArray_SIZE((PyArrayObject*)obj); | ||
std::copy(data, data + sz, dat); | ||
} else { | ||
LOG(FATAL) << "You should yield float" << sizeof(real) * 8 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Only numpy arrays whose element size matches `real` are supported. Maybe I need to implement other formats by casting, but that will be slow. |
||
<< " array"; | ||
} | ||
} else { | ||
py::SequenceHelper s(obj); | ||
// TODO(yuyang18): Here we can use AVX or SSE to accelerate memory copy. | ||
for (size_t i=0; i < headerPtr_->dim; ++i) { | ||
dat[i] = (real) s.getDouble(i); | ||
} | ||
} | ||
++height_; | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -84,6 +84,7 @@ def py_data2(files, load_data_module, load_data_object, load_data_args, | |
data.load_data_module = load_data_module | ||
data.load_data_object = load_data_object | ||
data.load_data_args = load_data_args | ||
data.async_load_data = True | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Always enable double buffering. Maybe we should change it later. |
||
return data | ||
data_cls = py_data2 | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PyDataProvider2 always uses double buffering, and PyDataProvider2 is the only interface for providing data in the open-source release, so remove the noisy line here.