Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tensor draft for review #2645

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions paddle/framework/tensor.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#include "paddle/framework/tensor.h"

namespace paddle {
namespace framework {

template <typename T>
const T* Tensor::Data() const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Template should implement in header.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

PADDLE_ASSERT(holder_ != nullptr);
PADDLE_ASSERT(holder_->Place() == place_);
PADDLE_ASSERT(holder_->Size() >= product(dims_) * sizeof(T));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to check the type of holder_ here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is worth discussing, I think.

return static_cast<const T*>(holder->Ptr());
}

bool Tensor::NeedReset() const {
return (holder_ == nullptr || holder_->Place() != place_ ||
holder_->Size() < product(dims_) * sizeof(T));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no T. Do there need a template? And again, do we need to check the type of holder_ here?

}

template <typename T, typename = std::enable_if<std::is_pod<T>::value>::type>
T* Tensor::MutableData() {
if (NeedReset()) {
holder_.reset(new PlaceholderImpl(place_, product(dims_) * sizeof(T)));
}
return static_cast<T*>(holder_->Ptr());
}

template <typename T, typename = std::enable_if<std::is_pod<T>::value>::type>
T* Tensor::MutableData(const DDim& dims) {
dims_ = dims;
return MutableData<T>();
}

template <typename T, typename = std::enable_if<std::is_pod<T>::value>::type>
T* Tensor::MutableData(const DDim& dims, const Place& place) {
dims_ = dims;
place_ = place;
return MutableData<T>();
}

int Tensor::Rank() const { return arity(dims_); }

int Tensor::Numel() const { return product(dims_); }

void Tensor::Resize(const DDim& dims) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my design, I don't see that we need Resize. Indeed, I followed @qingqing01 's suggestion and removed the ability of the constructor to set the size.

Copy link
Collaborator Author

@JiayiFeng JiayiFeng Jun 29, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean that if users want to change tensor's size, they can invoke mutable_data(DDim) directly?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

dims_ = dims;
return;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why add so many return at the end of a void func()

Copy link
Collaborator Author

@JiayiFeng JiayiFeng Jun 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just my personal habit... I'm used to adding return; to mark all return points of a void function.
Of course, they can be removed, if they break common code style or might cause confusion.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please strictly follow a common style as we work as a team. Let's remove all unnecessary code.

}

void Tensor::Reshape(const DDim& dims) {
if (product(dims) != product(dims_)) {
// TODO: error: "Reshape() can not change tensor's numel".
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can not reshape change tensor's numel?

In each mini-batch training, each output of Op should be Reshape, because the shape of input tensor could be changed.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Users can use Resize() to change tensor's numel.

}
dims_ = dims;
return;
}

void Tensor::ShareData(const Tensor& src) {
if (src.NeedReset()) {
// TODO: error: "Src tensor need to be reseted before calling ShareData()".
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the meaning of reset? Does it actually mean resize?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NeedReset means the underlying memory block needs to be re-allocated.

}
holder_ = src.Holder();
dims_ = src.Dims();
place_ = src.Place();
return;
}

template <typename T>
void Tensor::CopyFrom(const Tensor& src) {
if ((void*)&src == (void*)this) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe need more checking? Is CopyFrom always success?

Copy link
Collaborator Author

@JiayiFeng JiayiFeng Jun 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some checking in Tensor::Data(). I can't find out more potential problems for the moment... Do you have any idea?

return;
}
int len = product(src.Dims());
T* src_ptr = src.Data<T>();
T* dst_ptr = MutableData<T>(src.Dims());
for (int i = 0; i < len; ++i) {
dst_ptr[i] = src_ptr[i];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not correct for tensor with GPUPlace.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's just for CPU at present. GPU part will be implemented later.

}
return;
}

const std::shared_ptr<Tensor::Placeholder>& Tensor::Holder() const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that users don't need to and shouldn't know the concept "holder", which is defined as a private type inside Tensor in my design.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

return holder_;
}

const DDim& Tensor::Dims() const { return dims_; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put these simple functions into the header file, the compiler can do inline optimization.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


const paddle::platform::Place& Tensor::Place() const { return place_; }

template <typename T>
bool Tensor::IsType() const {
return typeid(T) == holder_.TypeInfo();
}

} // namespace framework
} // namespace paddle
94 changes: 94 additions & 0 deletions paddle/framework/tensor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#pragma once

#include <memory>
#include <type_traits>
#include <typeinfo>
#include "paddle/framework/ddim.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/place.h"

namespace paddle {
namespace framework {

class Tensor {
using paddle::platform::Place;
using paddle::platform::get_place;

public:
explicit Tensor(DDim dims) : dims_(dims), place_(get_place()) {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please be aware that I followed @qingqing01 's suggestion 3 days ago #2611 (review) and removed Tensor::Tensor, Tensor::place_, Tensor::dims_ in my PR #2611 so to make a concise and consistent syntax.

explicit Tensor(DDim dims, Place place) : dims_(dims), place_(place) {}

Tensor& operator=(const Tensor& src) = delete;

template <typename T>
const T* Data() const;

bool NeedReset() const;

// must be POD types
template <typename T, typename = std::enable_if<std::is_pod<T>::value>::type>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I followed @qingqing01 's suggestion and removed multiple signatures of mutable_data.

T* MutableData();

template <typename T, typename = std::enable_if<std::is_pod<T>::value>::type>
T* MutableData(const DDim& dims);

template <typename T, typename = std::enable_if<std::is_pod<T>::value>::type>
T* MutableData(const DDim& dims, const Place& place);

int Rank() const;

int Numel() const;

void Resize(const DDim& dims);

void Reshape(const DDim& dims);

void ShareData(const Tensor& src);

template <typename T>
void CopyFrom(const Tensor& src);

const std::shared_ptr<Placeholder>& Holder() const;

const DDim& Dims() const;

const paddle::platform::Place& Place() const;

template <typename T>
bool IsType() const;

// Placeholder hides type T, so it doesn't appear as a template
struct Placeholder {
virtual ~Placeholder() {}
virtual std::type_info TypeInfo() const = 0;
virtual void* Ptr() const = 0;
virtual Place Place() const = 0;
virtual size_t Size() const = 0;
};

private:
template <typename T>
struct PlaceholderImpl : public Placeholder {
PlaceholderImpl(Place place, size_t size)
: ptr_(paddle::memory::Alloc(place, size),
paddle::memory::Deleter(place)),
place_(place),
size_(size) {}

virtual std::type_info TypeInfo() const { return typeid(T); }
virtual void* Ptr() const { return static_cast<void*>(ptr_.get()); }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe Ptr() should not be added in interface, it should return T* like

struct Placeholder {
  virtual ~Placeholder() {}  // for rtti
};

template <typename T>
struct PlaceholderImpl : public Placeholder {
  T* Ptr() const { return ptr_.get(); }
};

template <typename T>
const T* Data() const {
  auto holder = std::dynamic_pointer_cast<PlaceholderImpl<T>>(holder_);
  ASSERT(holder != nullptr);
  return holder->Ptr();
}

Using static_cast everywhere make compiler cannot check type for us.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See PR #2647

virtual size_t Size() const { return size_; }
virtual Place Place() const { return place_; }

std::unique_ptr<T, paddle::memory::Deleter> ptr_;
Place place_; // record the place of ptr_.
size_t size_; // size of the memory block.
};

std::shared_ptr<Placeholder> holder_; // holds the memory block if allocated.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following @qingqing01 's suggestion, dims_ and place should be removed.

DDim dims_; // could be smallers than the holder_->Size().
Place place_;
};

} // namespace framework
} // namespace paddle