From 4fb5c13f8a2e16489350bc9aa1ce104a5583b9dc Mon Sep 17 00:00:00 2001
From: Thunderbrook
Date: Thu, 2 Jun 2022 23:01:24 +0800
Subject: [PATCH] adapt uint64 and iterate edge table

---
 paddle/fluid/framework/data_feed.cu           | 26 +++++------
 paddle/fluid/framework/data_feed.h            |  8 ++--
 paddle/fluid/framework/data_set.cc            | 16 ++++++-
 paddle/fluid/framework/data_set.h             |  8 ++--
 .../fluid/framework/fleet/ps_gpu_wrapper.cc   | 43 +++++++++++--------
 5 files changed, 62 insertions(+), 39 deletions(-)

diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu
index 5cc511d7222d11..b801e55f442522 100644
--- a/paddle/fluid/framework/data_feed.cu
+++ b/paddle/fluid/framework/data_feed.cu
@@ -359,7 +359,7 @@ __global__ void GraphFillFirstStepKernel(int *prefix_sum, int *sampleidx2row,
 }
 
 // Fill sample_res to the stepth column of walk
-void GraphDataGenerator::FillOneStep(int64_t *d_start_ids, int64_t *walk,
+void GraphDataGenerator::FillOneStep(uint64_t *d_start_ids, uint64_t *walk,
                                      int len, NeighborSampleResult &sample_res,
                                      int cur_degree, int step,
                                      int *len_per_row) {
@@ -469,8 +469,8 @@ int GraphDataGenerator::FillWalkBuf(std::shared_ptr<phi::Allocation> d_walk) {
     size_t device_key_size = h_device_keys_[type_index]->size();
     VLOG(2) << "type: " << node_type << " size: " << device_key_size
             << " start: " << start;
-    int64_t *d_type_keys =
-        reinterpret_cast<int64_t *>(d_device_keys_[type_index]->ptr());
+    uint64_t *d_type_keys =
+        reinterpret_cast<uint64_t *>(d_device_keys_[type_index]->ptr());
     int tmp_len = start + once_sample_startid_len_ > device_key_size
                       ? device_key_size - start
                       : once_sample_startid_len_;
@@ -492,7 +492,7 @@
     uint64_t *cur_walk = walk + i;
 
     NeighborSampleQuery q;
-    q.initialize(gpuid_, path[0], (int64_t)(d_type_keys + start), walk_degree_,
+    q.initialize(gpuid_, path[0], (uint64_t)(d_type_keys + start), walk_degree_,
                  tmp_len);
     auto sample_res = gpu_graph_ptr->graph_neighbor_sample_v3(q, false);
 
@@ -518,11 +518,11 @@
        break;
      }
      auto sample_key_mem = sample_res.actual_val_mem;
-      int64_t *sample_keys_ptr =
-          reinterpret_cast<int64_t *>(sample_key_mem->ptr());
+      uint64_t *sample_keys_ptr =
+          reinterpret_cast<uint64_t *>(sample_key_mem->ptr());
      int edge_type_id = path[(step - 1) % path_len];
      VLOG(2) << "sample edge type: " << edge_type_id << " step: " << step;
-      q.initialize(gpuid_, edge_type_id, (int64_t)sample_keys_ptr, 1,
+      q.initialize(gpuid_, edge_type_id, (uint64_t)sample_keys_ptr, 1,
                    sample_res.total_sample_size);
 
      sample_res = gpu_graph_ptr->graph_neighbor_sample_v3(q, false);
@@ -588,11 +588,11 @@ void GraphDataGenerator::AllocResource(const paddle::platform::Place &place,
      VLOG(3) << "h_device_keys_[" << i << "][" << j
              << "] = " << (*(h_device_keys_[i]))[j];
    }
-    auto buf = memory::AllocShared(place_,
-                                   h_device_keys_[i]->size() * sizeof(int64_t));
+    auto buf = memory::AllocShared(
+        place_, h_device_keys_[i]->size() * sizeof(uint64_t));
    d_device_keys_.push_back(buf);
    CUDA_CHECK(cudaMemcpyAsync(buf->ptr(), h_device_keys_[i]->data(),
-                               h_device_keys_[i]->size() * sizeof(int64_t),
+                               h_device_keys_[i]->size() * sizeof(uint64_t),
                                cudaMemcpyHostToDevice, stream_));
  }
  // h_device_keys_ = h_device_keys;
@@ -610,10 +610,10 @@
                  (once_max_sample_keynum + 1) * sizeof(int), stream_);
  cursor_ = 0;
  jump_rows_ = 0;
-  d_walk_ = memory::AllocShared(place_, buf_size_ * sizeof(int64_t));
-  cudaMemsetAsync(d_walk_->ptr(), 0, buf_size_ * sizeof(int64_t), stream_);
+  d_walk_ = memory::AllocShared(place_, buf_size_ * sizeof(uint64_t));
+  cudaMemsetAsync(d_walk_->ptr(), 0, buf_size_ * sizeof(uint64_t), stream_);
  d_sample_keys_ =
-      memory::AllocShared(place_, once_max_sample_keynum * sizeof(int64_t));
+      memory::AllocShared(place_, once_max_sample_keynum * sizeof(uint64_t));
  d_sampleidx2rows_.push_back(
      memory::AllocShared(place_, once_max_sample_keynum * sizeof(int)));
 
diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h
index ce1d55faad445d..2fc0d242198cc1 100644
--- a/paddle/fluid/framework/data_feed.h
+++ b/paddle/fluid/framework/data_feed.h
@@ -895,11 +895,11 @@ class GraphDataGenerator {
  int AcquireInstance(BufState* state);
  int GenerateBatch();
  int FillWalkBuf(std::shared_ptr<phi::Allocation> d_walk);
-  void FillOneStep(int64_t* start_ids, int64_t* walk, int len,
+  void FillOneStep(uint64_t* start_ids, uint64_t* walk, int len,
                   NeighborSampleResult& sample_res, int cur_degree, int step,
                   int* len_per_row);
  int FillInsBuf();
-  void SetDeviceKeys(std::vector<int64_t>* device_keys, int type) {
+  void SetDeviceKeys(std::vector<uint64_t>* device_keys, int type) {
    type_to_index_[type] = h_device_keys_.size();
    h_device_keys_.push_back(device_keys);
  }
@@ -913,7 +913,7 @@
  // start ids
  // int64_t* device_keys_;
  // size_t device_key_size_;
-  std::vector<std::vector<int64_t>*> h_device_keys_;
+  std::vector<std::vector<uint64_t>*> h_device_keys_;
  std::unordered_map<int, int> type_to_index_;
  // point to device_keys_
  size_t cursor_;
@@ -1018,7 +1018,7 @@
  virtual void SetParseLogKey(bool parse_logkey) {}
  virtual void SetEnablePvMerge(bool enable_pv_merge) {}
  virtual void SetCurrentPhase(int current_phase) {}
-  virtual void SetDeviceKeys(std::vector<int64_t>* device_keys, int type) {
+  virtual void SetDeviceKeys(std::vector<uint64_t>* device_keys, int type) {
    gpu_graph_data_generator_.SetDeviceKeys(device_keys, type);
  }
  virtual void SetGpuGraphMode(int gpu_graph_mode) {
diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc
index e4608356a4751b..8e3462edfde882 100644
--- a/paddle/fluid/framework/data_set.cc
+++ b/paddle/fluid/framework/data_set.cc
@@ -429,7 +429,7 @@ void MultiSlotDataset::PrepareTrain() {
 
 template <typename T>
 void DatasetImpl<T>::SetGraphDeviceKeys(
-    const std::vector<int64_t>& h_device_keys) {
+    const std::vector<uint64_t>& h_device_keys) {
  // for (size_t i = 0; i < gpu_graph_device_keys_.size(); i++) {
  //   gpu_graph_device_keys_[i].clear();
  // }
@@ -452,6 +452,7 @@ void DatasetImpl<T>::LoadIntoMemory() {
    graph_all_type_total_keys_.clear();
    auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
    auto node_to_id = gpu_graph_ptr->feature_to_id;
+    auto edge_to_id = gpu_graph_ptr->edge_to_id;
    graph_all_type_total_keys_.resize(node_to_id.size());
    int cnt = 0;
    for (auto& iter : node_to_id) {
@@ -474,6 +475,19 @@
      }
      cnt++;
    }
+    // FIX: trick for iterate edge table
+    for (auto& iter : edge_to_id) {
+      int edge_idx = iter.second;
+      auto gpu_graph_device_keys =
+          gpu_graph_ptr->get_all_id(0, edge_idx, thread_num_);
+      for (size_t i = 0; i < gpu_graph_device_keys.size(); i++) {
+        VLOG(1) << "edge type: " << edge_idx << ", gpu_graph_device_keys[" << i
+                << "] = " << gpu_graph_device_keys[i].size();
+        for (size_t j = 0; j < gpu_graph_device_keys[i].size(); j++) {
+          gpu_graph_total_keys_.push_back(gpu_graph_device_keys[i][j]);
+        }
+      }
+    }
 
  } else {
    for (int64_t i = 0; i < thread_num_; ++i) {
diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h
index 89fc881045a44c..0d326d3fd1364a 100644
--- a/paddle/fluid/framework/data_set.h
+++ b/paddle/fluid/framework/data_set.h
@@ -269,7 +269,9 @@ class DatasetImpl : public Dataset {
      return multi_consume_channel_;
    }
  }
-  std::vector<int64_t>& GetGpuGraphTotalKeys() { return gpu_graph_total_keys_; }
+  std::vector<uint64_t>& GetGpuGraphTotalKeys() {
+    return gpu_graph_total_keys_;
+  }
  Channel<T>& GetInputChannelRef() { return input_channel_; }
 
 protected:
@@ -331,8 +333,8 @@
  bool enable_heterps_ = false;
  int gpu_graph_mode_ = 1;
  // std::vector<std::vector<int64_t>> gpu_graph_device_keys_;
-  std::vector<std::vector<std::vector<int64_t>>> graph_all_type_total_keys_;
-  std::vector<int64_t> gpu_graph_total_keys_;
+  std::vector<std::vector<std::vector<uint64_t>>> graph_all_type_total_keys_;
+  std::vector<uint64_t> gpu_graph_total_keys_;
 };
 
 // use std::vector<MultiSlotType> or Record as data type
diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
index 43a24f2c40dc8e..cf9fb14bb9b9cd 100644
--- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
@@ -115,7 +115,7 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
  std::vector<std::thread> threads;
 
  // data should be in input channel
-
+
  thread_dim_keys_.resize(thread_keys_thread_num_);
  for (int i = 0; i < thread_keys_thread_num_; i++) {
    thread_dim_keys_[i].resize(thread_keys_shard_num_);
@@ -130,21 +130,23 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
  size_t begin = 0;
 
  std::string data_set_name = std::string(typeid(*dataset_).name());
-
-  VLOG(0) <<"gpu_graph_mode_:" << gpu_graph_mode_;
+
+  VLOG(0) << "gpu_graph_mode_:" << gpu_graph_mode_;
  if (!gpu_graph_mode_) {
    if (data_set_name.find("SlotRecordDataset") != std::string::npos) {
      VLOG(0) << "ps_gpu_wrapper use SlotRecordDataset";
      SlotRecordDataset* dataset = dynamic_cast<SlotRecordDataset*>(dataset_);
      auto input_channel = dataset->GetInputChannel();
-      VLOG(0) << "psgpu wrapperinputslotchannle size: " << input_channel->Size();
+      VLOG(0) << "psgpu wrapperinputslotchannle size: "
+              << input_channel->Size();
      const std::deque<SlotRecord>& vec_data = input_channel->GetData();
      total_len = vec_data.size();
      len_per_thread = total_len / thread_keys_thread_num_;
      remain = total_len % thread_keys_thread_num_;
      VLOG(0) << "total len: " << total_len;
-      auto gen_dynamic_mf_func = [this](const std::deque<SlotRecord>& total_data,
-                                        int begin_index, int end_index, int i) {
+      auto gen_dynamic_mf_func = [this](
+          const std::deque<SlotRecord>& total_data, int begin_index,
+          int end_index, int i) {
        for (auto iter = total_data.begin() + begin_index;
             iter != total_data.begin() + end_index; iter++) {
          const auto& ins = *iter;
@@ -157,7 +159,8 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
            int shard_id = feasign_v[j] % thread_keys_shard_num_;
            int dim_id = slot_index_vec_[slot_idx];
            if (feasign_v[j] != 0) {
-              this->thread_dim_keys_[i][shard_id][dim_id].insert(feasign_v[j]);
+              this->thread_dim_keys_[i][shard_id][dim_id].insert(
+                  feasign_v[j]);
            }
          }
        }
@@ -165,8 +168,8 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
      };
      for (int i = 0; i < thread_keys_thread_num_; i++) {
        threads.push_back(
-           std::thread(gen_dynamic_mf_func, std::ref(vec_data), begin,
-                       begin + len_per_thread + (i < remain ? 1 : 0), i));
+            std::thread(gen_dynamic_mf_func, std::ref(vec_data), begin,
+                        begin + len_per_thread + (i < remain ? 1 : 0), i));
        begin += len_per_thread + (i < remain ? 1 : 0);
      }
 
@@ -174,7 +177,8 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
        t.join();
      }
      timeline.Pause();
-      VLOG(0) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds.";
+      VLOG(0) << "GpuPs build task cost " << timeline.ElapsedSec()
+              << " seconds.";
    } else {
      CHECK(data_set_name.find("MultiSlotDataset") != std::string::npos);
      VLOG(0) << "ps_gpu_wrapper use MultiSlotDataset";
@@ -208,18 +212,19 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
        t.join();
      }
      timeline.Pause();
-      VLOG(0) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds.";
+      VLOG(0) << "GpuPs build task cost " << timeline.ElapsedSec()
+              << " seconds.";
    }
  } else {
    VLOG(0) << "PreBuild in GpuGraph mode";
    SlotRecordDataset* dataset = dynamic_cast<SlotRecordDataset*>(dataset_);
-    const std::vector<int64_t>& vec_data = dataset->GetGpuGraphTotalKeys();
+    const std::vector<uint64_t>& vec_data = dataset->GetGpuGraphTotalKeys();
    total_len = vec_data.size();
    len_per_thread = total_len / thread_keys_thread_num_;
    VLOG(0) << "GpuGraphTotalKeys: " << total_len;
    remain = total_len % thread_keys_thread_num_;
-    auto gen_graph_data_func = [this](const std::vector<int64_t>& total_data,
-                                      int begin_index, int end_index, int i) {
+    auto gen_graph_data_func = [this](const std::vector<uint64_t>& total_data,
+                                      int begin_index, int end_index, int i) {
      for (auto iter = total_data.begin() + begin_index;
           iter != total_data.begin() + end_index; iter++) {
        uint64_t cur_key = *iter;
@@ -227,10 +232,11 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
        this->thread_keys_[i][shard_id].insert(cur_key);
      }
    };
-    auto gen_graph_dynamic_mf_func = [this](const std::vector<int64_t>& total_data,
-                                            int begin_index, int end_index, int i) {
+    auto gen_graph_dynamic_mf_func = [this](
+        const std::vector<uint64_t>& total_data, int begin_index, int end_index,
+        int i) {
      for (auto iter = total_data.begin() + begin_index;
-          iter != total_data.begin() + end_index; iter++) {
+           iter != total_data.begin() + end_index; iter++) {
        uint64_t cur_key = *iter;
        int shard_id = cur_key % thread_keys_shard_num_;
        // int dim_id = slot_index_vec_[slot_idx];
@@ -895,7 +901,8 @@ void PSGPUWrapper::EndPass() {
      auto& device_keys = this->current_task_->device_dim_keys_[i][j];
      size_t len = device_keys.size();
      int mf_dim = this->index_dim_vec_[j];
-      VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim << " key_len :" << len;
+      VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim
+              << " key_len :" << len;
      size_t feature_value_size =
          TYPEALIGN(8, sizeof(FeatureValue) + ((mf_dim + 1) * sizeof(float)));
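
Note: a minimal standalone sketch (not part of the patch) of why the key type switch matters for sharding. PreBuildTask shards keys with `cur_key % thread_keys_shard_num_`, and graph node ids can occupy the full 64-bit range; read as int64_t, an id with the top bit set is negative, and in C++ the remainder of a negative operand is non-positive, so the computed shard index would be invalid as a container subscript. The names kShardNum and raw_key below are illustrative only.

// Illustrative only: signed vs. unsigned remainder when sharding 64-bit keys.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t kShardNum = 37;  // stand-in for thread_keys_shard_num_
  const uint64_t raw_key = 0x8000000000000001ULL;  // id with the top bit set

  // Reinterpreted as signed, the key is negative; since C++11, division
  // truncates toward zero, so the remainder is negative too.
  int64_t signed_key = static_cast<int64_t>(raw_key);
  std::cout << "signed shard:   " << signed_key % kShardNum << "\n";  // <= 0

  // As an unsigned key the remainder always lies in [0, kShardNum).
  std::cout << "unsigned shard: "
            << raw_key % static_cast<uint64_t>(kShardNum) << "\n";
  return 0;
}

The same reasoning covers the keys gathered by the new edge_to_id loop in DatasetImpl::LoadIntoMemory: they are appended to gpu_graph_total_keys_ and later sharded by PreBuildTask, so they too must stay unsigned end to end.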