-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
tensor draft for review #2645
tensor draft for review #2645
Changes from 3 commits
8912719
271ee6d
db4347e
b73b8a0
ebf06ad
67fe709
d73ee9e
6366d12
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
#include "paddle/framework/tensor.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
int Tensor::Rank() const { return arity(dims_); } | ||
|
||
int Tensor::Numel() const { return product(dims_); } | ||
|
||
void Tensor::Resize(const DDim& dims) { | ||
dims_ = dims; | ||
return; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why add so many There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's just my personal habit... I'm used to adding There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please strictly follow a common style as we work as a team. Let's remove all unnecessary code. |
||
} | ||
|
||
void Tensor::Reshape(const DDim& dims) { | ||
if (product(dims) != product(dims_)) { | ||
// TODO: error: "Reshape() can not change tensor's numel". | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why can not reshape change tensor's numel? In each mini-batch training, each output of Op should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Users can use |
||
} | ||
dims_ = dims; | ||
return; | ||
} | ||
|
||
const std::shared_ptr<Tensor::Placeholder>& Tensor::Holder() const { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that users don't need to and shouldn't know the concept "holder", which is defined as a private type inside Tensor in my design. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
return holder_; | ||
} | ||
|
||
const DDim& Tensor::Dims() const { return dims_; } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Put these simple functions into the header file, the compiler can do inline optimization. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
const paddle::platform::Place& Tensor::Place() const { return place_; } | ||
|
||
} // namespace framework | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
#pragma once | ||
|
||
#include <memory> | ||
#include <type_traits> | ||
#include <typeinfo> | ||
#include "paddle/framework/ddim.h" | ||
#include "paddle/platform/assert.h" | ||
#include "paddle/platform/place.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
class Tensor { | ||
using paddle::platform::Place; | ||
using paddle::platform::get_place; | ||
|
||
public: | ||
explicit Tensor(DDim dims) : dims_(dims), place_(get_place()) {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please be aware that I followed @qingqing01 's suggestion 3 days ago #2611 (review) and removed |
||
explicit Tensor(DDim dims, Place place) : dims_(dims), place_(place) {} | ||
|
||
Tensor& operator=(const Tensor& src) = delete; | ||
|
||
template <typename T> | ||
const T* Data() const { | ||
PADDLE_ASSERT(holder_ != nullptr); | ||
PADDLE_ASSERT(holder_->Place() == place_); | ||
PADDLE_ASSERT(holder_->Size() >= product(dims_) * sizeof(T)); | ||
return static_cast<const T*>(holder->Ptr()); | ||
} | ||
|
||
template <typename T> | ||
bool NeedReset() const { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think @hedaoyuan 's comment means these functions can be inlined by using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think there is no inline keyword that will also be inline. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NeedReset is not something should be exposed. |
||
return (holder_ == nullptr || holder_->Place() != place_ || | ||
holder_->Size() < product(dims_) * sizeof(T)); | ||
} | ||
|
||
// must be POD types | ||
template <typename T, typename = std::enable_if<std::is_pod<T>::value>::type> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, I followed @qingqing01 's suggestion and removed multiple signatures of mutable_data. |
||
T* MutableData() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please follow Google C++ style guide carefully -- there is a reason there of using the name There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Yes, it' ok to using By the way, do we need to rename functions in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it is a good idea. |
||
if (NeedReset<T>()) { | ||
holder_.reset(new PlaceholderImpl(place_, product(dims_) * sizeof(T))); | ||
} | ||
return static_cast<T*>(holder_->Ptr()); | ||
} | ||
|
||
template <typename T, typename = std::enable_if<std::is_pod<T>::value>::type> | ||
T* MutableData(const DDim& dims) { | ||
dims_ = dims; | ||
return MutableData<T>(); | ||
} | ||
|
||
template <typename T, typename = std::enable_if<std::is_pod<T>::value>::type> | ||
T* MutableData(const DDim& dims, const Place& place) { | ||
dims_ = dims; | ||
place_ = place; | ||
return MutableData<T>(); | ||
} | ||
|
||
int Rank() const; | ||
|
||
int Numel() const; | ||
|
||
void Resize(const DDim& dims); | ||
|
||
void Reshape(const DDim& dims); | ||
|
||
template <typename T> | ||
void ShareData(const Tensor& src) { | ||
if (src.NeedReset<T>()) { | ||
// TODO: error: "Src tensor need to be reseted before calling | ||
// ShareData()". | ||
} | ||
holder_ = src.Holder(); | ||
dims_ = src.Dims(); | ||
place_ = src.Place(); | ||
return; | ||
} | ||
|
||
template <typename T> | ||
void CopyFrom(const Tensor& src) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why we need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We may need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree with @Xreki . There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds like what we need is a method void Tensor::Serialize(TensorProto* proto) {
if (is_gpu_place(Place()) {
Tensor cpu_tensor;
cudaMemcpy(cpu_tensor.mutable_data(Size(), CPUPlace()), data(), Size());
cpu_tensor.Serialize(proto);
} else {
// fill in proto
}
} |
||
if ((void*)&src == (void*)this) { | ||
return; | ||
} | ||
int len = product(src.Dims()); | ||
T* src_ptr = src.Data<T>(); | ||
T* dst_ptr = MutableData<T>(src.Dims()); | ||
for (int i = 0; i < len; ++i) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not use memcpy? |
||
dst_ptr[i] = src_ptr[i]; | ||
} | ||
return; | ||
} | ||
|
||
const std::shared_ptr<Placeholder>& Holder() const; | ||
|
||
const DDim& Dims() const; | ||
|
||
const paddle::platform::Place& Place() const; | ||
|
||
template <typename T> | ||
bool IsType() const { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IsType should be removed because we can interpret a Tensor of any type -- just call |
||
return typeid(T) == holder_.TypeInfo(); | ||
} | ||
|
||
// Placeholder hides type T, so it doesn't appear as a template | ||
struct Placeholder { | ||
virtual ~Placeholder() {} | ||
virtual std::type_info TypeInfo() const = 0; | ||
virtual void* Ptr() const = 0; | ||
virtual Place Place() const = 0; | ||
virtual size_t Size() const = 0; | ||
}; | ||
|
||
private: | ||
template <typename T> | ||
struct PlaceholderImpl : public Placeholder { | ||
PlaceholderImpl(Place place, size_t size) | ||
: ptr_(paddle::memory::Alloc(place, size), | ||
paddle::memory::Deleter(place)), | ||
place_(place), | ||
size_(size) {} | ||
|
||
virtual std::type_info TypeInfo() const { return typeid(T); } | ||
virtual void* Ptr() const { return static_cast<void*>(ptr_.get()); } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe Ptr() should not be added in interface, it should return T* like struct Placeholder {
virtual ~Placeholder() {} // for rtti
};
template <typename T>
struct PlaceholderImpl : public Placeholder {
T* Ptr() const { return ptr_.get(); }
};
template <typename T>
const T* Data() const {
auto holder = std::dynamic_pointer_cast<PlaceholderImpl<T>>(holder_);
ASSERT(holder != nullptr);
return holder->Ptr();
} Using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See PR #2647 |
||
virtual size_t Size() const { return size_; } | ||
virtual Place Place() const { return place_; } | ||
|
||
std::unique_ptr<T, paddle::memory::Deleter> ptr_; | ||
Place place_; // record the place of ptr_. | ||
size_t size_; // size of the memory block. | ||
}; | ||
|
||
std::shared_ptr<Placeholder> holder_; // holds the memory block if allocated. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Following @qingqing01 's suggestion, |
||
DDim dims_; // could be smallers than the holder_->Size(). | ||
Place place_; | ||
}; | ||
|
||
} // namespace framework | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In my design, I don't see that we need
Resize
. Indeed, I followed @qingqing01 's suggestion and removed the ability of the constructor to set the size.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean that if users want to change tensor's size, they can invoke
mutable_data(DDim)
directly?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes.