update recordio reader
Yancey1989 committed Jul 26, 2017
1 parent 89a24c5 commit 06d631f
Showing 1 changed file with 24 additions and 9 deletions.
33 changes: 24 additions & 9 deletions doc/usage_cn.md
@@ -69,7 +69,7 @@ scp -r my_training_data_dir/ user@tunnel-server:/mnt/hdfs_mulan/idl/idl-dl/mydir
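After a training job is submitted, each training node mounts HDFS under `/pfs/[datacenter_name]/home/[username]/`, so the training program can read the training data from this path and start training.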

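### Preprocessing training data with [RecordIO](https://github.com/PaddlePaddle/recordio)
-Users can preprocess the data into RecordIO format locally and then upload it to the cluster for training.
+Users need to preprocess the data into RecordIO format locally and then upload it to the cluster for training.
- Preprocess the data with the RecordIO library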
```python
import paddle.v2.dataset as dataset
@@ -97,15 +97,30 @@ dataset.convert(output_path = "./dataset",
- Write a reader that reads files in RecordIO format
```python
 import cPickle as pickle
-def cluster_creator(filename):
-    import recordio
+import recordio
+import glob
+import sys
+def recordio_reader(filepath, parallelism, trainer_id):
+    # sample filepath as "/pfs/dlnel/home/yanxu05@baidu.com/dataset/uci_housing/uci_housing_train*"
     def reader():
-        r = recordio.reader("./dataset/uci_housing_train*")
-        while True:
-            d = r.read()
-            if not d:
-                break
-            yield pickle.loads(d)
+        if trainer_id >= parallelism:
+            sys.stdout.write("invalid trainer_id: %d\n" % trainer_id)
+            return
+        files = glob.glob(filepath)
+        files.sort()
+        my_file_list = []
+        for idx, f in enumerate(files):
+            if idx % parallelism == trainer_id:
+                my_file_list.append(f)
+
+        for fn in my_file_list:
+            r = recordio.reader(fn)
+            while True:
+                d = r.read()
+                if not d:
+                    break
+                yield pickle.loads(d)
+
     return reader
```
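
The updated reader splits the matched files across trainers round-robin: the file at sorted index `idx` goes to the trainer for which `idx % parallelism == trainer_id`. The sketch below isolates only that assignment step so it can be checked without a cluster. It is an illustration, not part of the commit: it targets Python 3, uses only the standard library, does not call `recordio`, and the names `shard_files` and `demo` are hypothetical. Unlike the reader in the diff, which prints a message and yields nothing for an out-of-range `trainer_id`, this sketch raises `ValueError`.

```python
import glob
import os
import tempfile


def shard_files(pattern, parallelism, trainer_id):
    """Return the sorted files matching `pattern` assigned to `trainer_id`.

    Hypothetical helper: mirrors the round-robin rule
    `idx % parallelism == trainer_id` used by the reader in the diff.
    """
    if parallelism <= 0:
        raise ValueError("parallelism must be positive, got %d" % parallelism)
    if not 0 <= trainer_id < parallelism:
        raise ValueError("invalid trainer_id: %d" % trainer_id)
    # Sorting gives every trainer the same view of the file order,
    # so the shards are disjoint and together cover all files.
    files = sorted(glob.glob(pattern))
    return [f for idx, f in enumerate(files) if idx % parallelism == trainer_id]


def demo():
    """Create five empty placeholder files and shard them across two trainers."""
    with tempfile.TemporaryDirectory() as tmp:
        for i in range(5):
            # Placeholder file names; real RecordIO shard names may differ.
            open(os.path.join(tmp, "uci_housing_train-%05d" % i), "w").close()
        pattern = os.path.join(tmp, "uci_housing_train*")
        return [
            [os.path.basename(f) for f in shard_files(pattern, 2, t)]
            for t in range(2)
        ]


if __name__ == "__main__":
    for trainer_id, shard in enumerate(demo()):
        print(trainer_id, shard)
```

With five files and two trainers, trainer 0 receives the files at sorted indices 0, 2 and 4, and trainer 1 receives those at indices 1 and 3. When there are fewer files than trainers, some trainers receive an empty list.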
