tensor draft for review #2645

Closed
wants to merge 8 commits into from
137 changes: 137 additions & 0 deletions paddle/framework/tensor.h
@@ -0,0 +1,137 @@
#pragma once

#include <memory>
#include <type_traits>
#include <typeinfo>
#include "paddle/framework/ddim.h"
#include "paddle/framework/enforce.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/place.h"

namespace paddle {
namespace framework {

class Tensor {
using Place = paddle::platform::Place;

public:
explicit Tensor(DDim dims) : dims_(dims), place_(paddle::platform::get_place()) {}
Collaborator:

Please be aware that I followed @qingqing01 's suggestion 3 days ago #2611 (review) and removed Tensor::Tensor, Tensor::place_, and Tensor::dims_ in my PR #2611 to keep the syntax concise and consistent.

explicit Tensor(DDim dims, Place place) : dims_(dims), place_(place) {}

Tensor& operator=(const Tensor& src) = delete;

template <typename T>
const T* data() const {
PADDLE_ENFORCE(holder_ != nullptr);
PADDLE_ENFORCE(holder_->Place() == place_);
PADDLE_ENFORCE(holder_->Size() >= product(dims_) * sizeof(T));
return static_cast<const T*>(holder_->Ptr());
}

template <typename T>
bool NeedReset() const {
Contributor:

I think @hedaoyuan 's comment means these functions can be inlined by using the inline keyword.

Contributor:

I think these functions will be inlined even without the inline keyword, since functions defined inside the class body are implicitly inline.

Collaborator:

NeedReset is not something that should be exposed.

return (holder_ == nullptr || holder_->Place() != place_ ||
holder_->Size() < product(dims_) * sizeof(T));
}

// must be POD types
template <typename T, typename = typename std::enable_if<std::is_pod<T>::value>::type>
Collaborator:

Also, I followed @qingqing01 's suggestion and removed multiple signatures of mutable_data.

T* mutable_data() {
if (NeedReset<T>()) {
holder_.reset(new PlaceholderImpl<T>(place_, product(dims_) * sizeof(T)));
}
return static_cast<T*>(holder_->Ptr());
}

template <typename T, typename = typename std::enable_if<std::is_pod<T>::value>::type>
T* mutable_data(const DDim& dims) {
dims_ = dims;
return mutable_data<T>();
}

template <typename T, typename = typename std::enable_if<std::is_pod<T>::value>::type>
T* mutable_data(const DDim& dims, const Place& place) {
dims_ = dims;
place_ = place;
return mutable_data<T>();
}

int Rank() const { return arity(dims_); }
Collaborator:

This should be removed following @qingqing01 's suggestion.


int Numel() const { return product(dims_); }
Collaborator:

This should be removed too.


void Resize(const DDim& dims) { dims_ = dims; }

void Reshape(const DDim& dims) {
Collaborator:

There is no need for Reshape, because we can simply call mutable_data(new_dims, new_place).

PADDLE_ENFORCE(product(dims) == product(dims_),
"Reshape() can not change tensor's numel!");
dims_ = dims;
}

template <typename T>
void ShareData(const Tensor& src) {
PADDLE_ENFORCE(!src.NeedReset<T>(),
"Src tensor must be initialized before calling ShareData().");
holder_ = src.holder_;
dims_ = src.dims_;
place_ = src.place_;
}

template <typename T>
void CopyFrom(const Tensor& src) {
Collaborator:

Why do we need CopyFrom?

Contributor:

We may need CopyFrom when the src Tensor is on the GPU: we may need to copy a GPU Tensor to the CPU and then do something with it, such as writing it to a file.

@JiayiFeng (Collaborator, author) commented on Jun 29, 2017:

I agree with @Xreki .

@wangkuiyi (Collaborator) commented on Jun 29, 2017:

Sounds like what we need is a method Serialize that fills in a protobuf message, instead of a Copy or CopyFrom?

void Tensor::Serialize(TensorProto* proto) {
  if (is_gpu_place(Place())) {
    Tensor cpu_tensor;
    cudaMemcpy(cpu_tensor.mutable_data(Size(), CPUPlace()), data(), Size(),
               cudaMemcpyDeviceToHost);
    cpu_tensor.Serialize(proto);
  } else {
    // fill in proto
  }
}

if (&src == this) {
return;
}
int len = product(src.Dims());
const T* src_ptr = src.data<T>();
T* dst_ptr = mutable_data<T>(src.Dims());
for (int i = 0; i < len; ++i) {
Contributor:

Why not use memcpy?

dst_ptr[i] = src_ptr[i];
}
}

const DDim& Dims() const { return dims_; }

const paddle::platform::Place& Place() const { return place_; }

template <typename T>
bool IsType() const {
Collaborator:

IsType should be removed because we can interpret a Tensor as any type -- just call mutable_data<T>(dims, place).

return typeid(T) == holder_->TypeInfo();
}

private:
// Placeholder hides type T, so it doesn't appear as a template
struct Placeholder {
virtual ~Placeholder() {}
virtual const std::type_info& TypeInfo() const = 0;
virtual void* Ptr() const = 0;
virtual Place Place() const = 0;
virtual size_t Size() const = 0;
};

template <typename T>
struct PlaceholderImpl : public Placeholder {
PlaceholderImpl(Place place, size_t size)
: ptr_(paddle::memory::Alloc(place, size),
paddle::memory::Deleter(place)),
place_(place),
size_(size) {}

virtual const std::type_info& TypeInfo() const { return typeid(T); }
virtual void* Ptr() const { return static_cast<void*>(ptr_.get()); }
Collaborator:

Maybe Ptr() should not be part of the Placeholder interface; it should return T*, like:

struct Placeholder {
  virtual ~Placeholder() {}  // for rtti
};

template <typename T>
struct PlaceholderImpl : public Placeholder {
  T* Ptr() const { return ptr_.get(); }
};

template <typename T>
const T* Data() const {
  auto holder = std::dynamic_pointer_cast<PlaceholderImpl<T>>(holder_);
  ASSERT(holder != nullptr);
  return holder->Ptr();
}

Using static_cast everywhere means the compiler cannot check types for us.

Collaborator:

See PR #2647

virtual size_t Size() const { return size_; }
virtual Place Place() const { return place_; }

std::unique_ptr<T, paddle::memory::Deleter> ptr_;
paddle::platform::Place place_; // record the place of ptr_.
size_t size_; // size of the memory block.
};

std::shared_ptr<Placeholder> holder_; // holds the memory block if allocated.
Collaborator:

Following @qingqing01 's suggestion, dims_ and place_ should be removed.

DDim dims_; // may be smaller than holder_->Size().
paddle::platform::Place place_;
};

} // namespace framework
} // namespace paddle