Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #28 from sneakerkg/master
Browse files Browse the repository at this point in the history
Local IO
  • Loading branch information
sneakerkg committed Aug 28, 2015
2 parents 2ad67a3 + 25e3363 commit 5e9b53b
Show file tree
Hide file tree
Showing 19 changed files with 1,019 additions and 160 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,3 @@ Debug
.dir-locals.el
__pycache__
*.pkl
*
9 changes: 5 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ endif
#BIN = test/test_threaded_engine test/api_registry_test
OBJ = narray_function_cpu.o
# add threaded engine after it is done
OBJCXX11 = reshape_cpu.o engine.o narray.o c_api.o operator.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o graph_executor.o softmax_cpu.o elementwise_sum_cpu.o pooling_cpu.o convolution_cpu.o
OBJCXX11 = reshape_cpu.o engine.o narray.o c_api.o operator.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o graph_executor.o softmax_cpu.o elementwise_sum_cpu.o pooling_cpu.o convolution_cpu.o io.o iter_mnist.o
CUOBJ =
SLIB = lib/libmxnet.so
ALIB = lib/libmxnet.a
Expand Down Expand Up @@ -105,12 +105,13 @@ convolution_cpu.o: src/operator/convolution.cc
convolution_gpu.o: src/operator/convolution.cu
reshape_cpu.o: src/operator/reshape.cc
reshape_gpu.o: src/operator/reshape.cu
io.o: src/io/io.cc
iter_mnist.o: src/io/iter_mnist.cc

lib/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ)
lib/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ)
lib/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ) $(LIB_DEP)
lib/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ) $(LIB_DEP)

test/test_storage: test/test_storage.cc lib/libmxnet.a
#test/test_threaded_engine: test/test_threaded_engine.cc api/libmxnet.a

$(BIN) :
$(CXX) $(CFLAGS) -std=c++0x -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS)
Expand Down
12 changes: 12 additions & 0 deletions doc/python/io.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Python IO API
===================
MXNet handles IO for you by implementing data iterators.
A data iterator behaves like an iterable class in Python: you can traverse the data using a for loop.


IO API Reference
----------------------
```eval_rst
.. automodule:: mxnet.io
:members:
```
2 changes: 0 additions & 2 deletions include/mxnet/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ typedef mshadow::TBlob TBlob;
namespace dmlc {
// Add a few patches to support TShape in dmlc/parameter.
DMLC_DECLARE_TYPE_NAME(mxnet::TShape, "Shape(tuple)");
DMLC_DECLARE_TYPE_NAME(uint32_t, "unsigned int");


namespace parameter {
template<>
Expand Down
169 changes: 149 additions & 20 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ typedef void *SymbolHandle;
typedef void *AtomicSymbolHandle;
/*! \brief handle to an Executor */
typedef void *ExecutorHandle;
/*! \brief handle to a DataIter creator */
typedef void *DataIterCreator;
/*! \brief handle to a DataIterator */
typedef void *DataIterHandle;
/*!
Expand Down Expand Up @@ -452,49 +454,176 @@ MXNET_DLL int MXExecutorBind(SymbolHandle symbol_handle,
// Part 5: IO Interface
//--------------------------------------------
/*!
* \brief create an data iterator from configs string
* \param cfg config string that contains the
* configuration about the iterator
* \param out the handle to the iterator
* \brief List all the available iterator entries
* \param out_size the size of returned iterators
* \param out_array the output iterator entries
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXIOCreateFromConfig(const char *cfg,
DataIterHandle *out);
MXNET_DLL int MXListDataIters(mx_uint *out_size,
DataIterCreator **out_array);
/*!
* \brief move iterator to next position
* \param handle the handle to iterator
* \param out return value of next
* \brief Create an iterator from a creator, initialized with the given parameters;
* num_param is the array size of the passed-in keys and vals
* \param handle of the iterator creator
* \param num_param number of parameter
* \param keys parameter keys
* \param vals parameter values
* \param out resulting iterator
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXIONext(DataIterHandle handle,
int *out);
MXNET_DLL int MXDataIterCreateIter(DataIterCreator handle,
int num_param,
const char **keys,
const char **vals,
DataIterHandle *out);
/*!
* \brief call iterator.BeforeFirst
* \param handle the handle to iterator
* \brief Get the detailed information about data iterator.
* \param creator the DataIterCreator.
* \param name The returned name of the creator.
* \param description The returned description of the data iterator.
* \param num_args Number of arguments.
* \param arg_names Name of the arguments.
* \param arg_type_infos Type information about the arguments.
* \param arg_descriptions Description information about the arguments.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXIOBeforeFirst(DataIterHandle handle);
MXNET_DLL int MXDataIterGetIterInfo(AtomicSymbolCreator creator,
const char **name,
const char **description,
mx_uint *num_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions);
/*!
* \brief free the handle to the IO module
* \brief Free the handle to the IO module
* \param handle the handle pointer to the data iterator
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXIOFree(DataIterHandle handle);
MXNET_DLL int MXDataIterFree(DataIterHandle handle);
/*!
* \brief get the name of iterator entry
* \param iter iterator entry
* \param out_name the name of the iterator
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDataIterGetName(DataIterCreator iter,
const char **out_name);
/*!
* \brief Create an iterator from a creator, initialized with the given parameters;
* num_param is the array size of the passed-in keys and vals
* \param handle of the iterator creator
* \param num_param number of parameter
* \param keys parameter keys
* \param vals parameter values
* \param out resulting iterator
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDataIterCreateIter(DataIterCreator handle,
int num_param,
const char **keys,
const char **vals,
DataIterHandle *out);
/*!
* \brief Get the detailed information about data iterator.
* \param creator the DataIterCreator.
* \param name The returned name of the creator.
* \param description The returned description of the symbol.
* \param num_args Number of arguments.
* \param arg_names Name of the arguments.
* \param arg_type_infos Type informations about the arguments.
* \param arg_descriptions Description information about the arguments.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDataIterGetIterInfo(AtomicSymbolCreator creator,
const char **name,
const char **description,
mx_uint *num_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions);
/*!
* \brief Free the handle to the IO module
* \param handle the handle pointer to the data iterator
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDataIterFree(DataIterHandle handle);
/*!
* \brief Get the name of iterator entry
* \param iter iterator entry
* \param out_name the name of the iterator
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDataIterGetName(DataIterCreator iter,
const char **out_name);
/*!
* \brief Create an iterator from a creator, initialized with the given parameters;
* num_param is the array size of the passed-in keys and vals
* \param handle of the iterator creator
* \param num_param number of parameter
* \param keys parameter keys
* \param vals parameter values
* \param out resulting iterator
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDataIterCreateIter(DataIterCreator handle,
int num_param,
const char **keys,
const char **vals,
DataIterHandle *out);
/*!
* \brief Get the detailed information about data iterator.
* \param creator the DataIterCreator.
* \param name The returned name of the creator.
* \param description The returned description of the symbol.
* \param num_args Number of arguments.
* \param arg_names Name of the arguments.
* \param arg_type_infos Type informations about the arguments.
* \param arg_descriptions Description information about the arguments.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDataIterGetIterInfo(AtomicSymbolCreator creator,
const char **name,
const char **description,
mx_uint *num_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions);
/*!
* \brief Free the handle to the IO module
* \param handle the handle pointer to the data iterator
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDataIterFree(DataIterHandle handle);
/*!
* \brief Move iterator to next position
* \param handle the handle to iterator
* \param out return value of next
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDataIterNext(DataIterHandle handle,
int *out);
/*!
* \brief Call iterator.BeforeFirst to reset the iterator
* \param handle the handle to iterator
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDataIterBeforeFirst(DataIterHandle handle);

/*!
* \brief get the handle to the NArray of underlying data
* \brief Get the handle to the NArray of underlying data
* \param handle the handle pointer to the data iterator
* \param out handle to underlying data NArray
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXIOGetData(DataIterHandle handle,
MXNET_DLL int MXDataIterGetData(DataIterHandle handle,
NArrayHandle *out);
/*!
* \brief get the handle to the NArray of underlying label
* \brief Get the handle to the NArray of underlying label
* \param handle the handle pointer to the data iterator
* \param out the handle to underlying label NArray
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXIOGetLabel(DataIterHandle handle,
MXNET_DLL int MXDataIterGetLabel(DataIterHandle handle,
NArrayHandle *out);

#endif // MXNET_C_API_H_
113 changes: 113 additions & 0 deletions include/mxnet/io.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*!
* Copyright (c) 2015 by Contributors
* \file io.h
* \brief mxnet io data structure and data iterator
*/
#ifndef MXNET_IO_H_
#define MXNET_IO_H_
#include <dmlc/data.h>
#include <dmlc/registry.h>
#include <vector>
#include <string>
#include <utility>
#include "./base.h"

namespace mxnet {
/*!
 * \brief Abstract data iterator interface.
 *  Extends dmlc::DataIter<DType> with key-value based initialization and a
 *  list of names associated with the data the iterator produces.
 * \tparam DType data type yielded by the iterator
 */
template<typename DType>
class IIterator : public dmlc::DataIter<DType> {
 public:
  /*!
   * \brief set the parameters and init iter
   * \param kwargs key-value pairs
   */
  virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) = 0;
  /*! \brief reset the iterator */
  virtual void BeforeFirst(void) = 0;
  /*!
   * \brief move to next item
   * \return presumably true when a next item is available; confirm against dmlc::DataIter
   */
  virtual bool Next(void) = 0;
  /*! \brief get current data */
  virtual const DType &Value(void) const = 0;
  /*! \brief virtual destructor, so implementations can be deleted through this interface */
  virtual ~IIterator(void) {}
  /*! \brief store the name of each data, it could be used for making NArrays */
  std::vector<std::string> data_names;
  /*!
   * \brief append a name to data_names
   * \param data_name the name to append; taken by const reference so that
   *  the only copy made is the one stored in data_names
   */
  inline void SetDataName(const std::string &data_name) {
    data_names.push_back(data_name);
  }
};  // class IIterator

/*!
 * \brief One data instance: an index, the blobs that make up its content,
 *  and an extra string payload.
 */
struct DataInst {
  /*! \brief unique id of this instance */
  unsigned int index;
  /*! \brief blobs holding the content of this instance */
  std::vector<TBlob> data;
  /*! \brief extra data to be fed to the network */
  std::string extra_data;
};  // struct DataInst

/*!
* \brief a standard batch of data commonly used by iterator
* a databatch contains multiple TBlobs. Each Tblobs has
* a name stored in a map. There's no different between
* data and label, how we use them is to see the DNN implementation.
*/
struct DataBatch {
public:
/*! \brief unique id for instance, can be NULL, sometimes is useful */
unsigned *inst_index;
/*! \brief number of instance */
mshadow::index_t batch_size;
/*! \brief number of padding elements in this batch,
this is used to indicate the last elements in the batch are only padded up to match the batch, and should be discarded */
mshadow::index_t num_batch_padd;
public:
/*! \brief content of dense data, if this DataBatch is dense */
std::vector<TBlob> data;
/*! \brief extra data to be fed to the network */
std::string extra_data;
public:
/*! \brief constructor */
DataBatch(void) {
inst_index = NULL;
batch_size = 0; num_batch_padd = 0;
}
/*! \brief giving name to the data */
void Naming(std::vector<std::string> names);
}; // struct DataBatch

/*!
 * \brief Factory function type for data iterators: a pointer to a function
 *  that takes no arguments and returns an IIterator<DataBatch>*.
 */
typedef IIterator<DataBatch> *(*DataIteratorFactory)();
/*!
 * \brief Registry entry for DataIterator factory functions.
 *  Derives from dmlc::FunctionRegEntryBase with DataIteratorFactory as the
 *  registered function type and adds no members of its own.
 *  MXNET_REGISTER_IO_ITER creates entries of this type.
 */
struct DataIteratorReg
: public dmlc::FunctionRegEntryBase<DataIteratorReg,
DataIteratorFactory> {
};
//--------------------------------------------------------------
// The following part is the API registration of iterators
//--------------------------------------------------------------
/*!
 * \brief Macro to register Iterators
 *  Expands to a static factory function that returns `new DataIteratorType`,
 *  followed by a DMLC_REGISTRY_REGISTER entry for `name` whose body is set to
 *  that factory. Every mxnet type is fully qualified, so the macro does not
 *  depend on being expanded inside namespace mxnet.
 *
 * \param name the name the iterator is registered under
 * \param DataIteratorType the iterator class the factory instantiates
 *
 * \code
 * // example of registering a mnist iterator
 * MXNET_REGISTER_IO_ITER(MNIST, MNISTIterator)
 * .describe("Mnist data iterator");
 *
 * \endcode
 */
#define MXNET_REGISTER_IO_ITER(name, DataIteratorType)                                     \
  static ::mxnet::IIterator< ::mxnet::DataBatch>* __create__ ## DataIteratorType ## __() { \
    return new DataIteratorType;                                                           \
  }                                                                                        \
  DMLC_REGISTRY_REGISTER(::mxnet::DataIteratorReg, DataIteratorReg, name)                  \
  .set_body(__create__ ## DataIteratorType ## __)
} // namespace mxnet
#endif // MXNET_IO_H_
1 change: 1 addition & 0 deletions python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .base import MXNetError
from . import narray
from . import symbol
from . import io

__version__ = "0.1.0"

Expand Down
3 changes: 2 additions & 1 deletion python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def _load_lib():
SymbolCreatorHandle = ctypes.c_void_p
SymbolHandle = ctypes.c_void_p
ExecutorHandle = ctypes.c_void_p

DataIterCreatorHandle = ctypes.c_void_p
DataIterHandle = ctypes.c_void_p
#----------------------------
# helper function definition
#----------------------------
Expand Down
Loading

0 comments on commit 5e9b53b

Please sign in to comment.