Merge Tensor and LoDTensor class #6798

Closed
tonyyang-svail opened this issue Dec 20, 2017 · 2 comments

@tonyyang-svail

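In the protocol buffer definition, Tensor and LoDTensor are the same type. There is no separate `TENSOR` entry, so every tensor variable is declared as `LOD_TENSOR`: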

```protobuf
enum VarType {
  LOD_TENSOR = 1;
  SELECTED_ROWS = 2;
  FEED_MINIBATCH = 3;
  FETCH_LIST = 4;
  STEP_SCOPES = 5;
  LOD_RANK_TABLE = 6;
  LOD_TENSOR_ARRAY = 7;
}
```

The C++ classes should be consistent with this: a single tensor class, not separate Tensor and LoDTensor classes.

@jacquesqiao
Member

Why do we need to merge these two classes? Or, put the other way, why would we need to keep both?

@tonyyang-svail
Author

tonyyang-svail commented Dec 20, 2017

@jacquesqiao In the current implementation:

```cpp
const Tensor* x = context.Input<Tensor>("X");
```

Calls

```cpp
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;
```

Calls

```cpp
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
  auto* var = InputVar(name);
  return var == nullptr ? nullptr : GetTensorFromVar(var);
}
```

Calls

```cpp
static const Tensor* GetTensorFromVar(const Variable* var) {
  const Tensor* t = nullptr;
  if (var->IsType<LoDTensor>()) {
    t = &(var->Get<LoDTensor>());
  } else if (var->IsType<SelectedRows>()) {
    t = &(var->Get<SelectedRows>().value());
  } else {
    PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
  }
  return t;
}
```

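In other words, even though the caller writes `context.Input<Tensor>("X")`, the object actually stored in the variable is a `LoDTensor` (or the tensor inside a `SelectedRows`), and the caller simply receives it upcast to `const Tensor*`. The type the code asks for is not the type it gets. This is fundamentally wrong.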
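To make the mismatch concrete, here is a minimal C++17 sketch. `Tensor`, `LoDTensor`, and `Variable` are hypothetical, simplified stand-ins, not Paddle's real classes. Only the fact that LoDTensor derives from Tensor (which the upcast in `GetTensorFromVar` requires) comes from the code above. The `std::any` storage, the exact-type behavior of `IsType<T>()`, and the `GetMutable` accessor are assumptions made for illustration.

```cpp
#include <any>
#include <cstddef>
#include <stdexcept>
#include <typeinfo>
#include <vector>

// Hypothetical stand-ins; only "LoDTensor derives from Tensor" is from the issue.
struct Tensor {
  std::vector<float> data;
};

struct LoDTensor : Tensor {
  std::vector<std::size_t> lod;  // sequence offsets, e.g. {0, 2, 5}
};

// Simplified type-erased variable (assumption): it stores exactly one object,
// and IsType<T>() is true only when T is the exact stored type.
class Variable {
 public:
  template <typename T>
  bool IsType() const { return holder_.type() == typeid(T); }

  template <typename T>
  const T& Get() const {
    const T* p = std::any_cast<T>(&holder_);
    if (p == nullptr) throw std::runtime_error("type mismatch");
    return *p;
  }

  template <typename T>
  T* GetMutable() {
    if (!IsType<T>()) holder_.emplace<T>();  // replaces any existing object
    return std::any_cast<T>(&holder_);
  }

 private:
  std::any holder_;
};

// Mirrors the LoDTensor branch of GetTensorFromVar (SelectedRows omitted).
const Tensor* GetTensorFromVar(const Variable* var) {
  if (var->IsType<LoDTensor>()) {
    return &var->Get<LoDTensor>();  // implicit upcast: LoDTensor* -> Tensor*
  }
  throw std::runtime_error("Variable type must be LoDTensor.");
}

// Stores a LoDTensor (the default var type), then reads it back through
// the Tensor-typed path, as Input<Tensor>("X") does.
bool InputIsReallyLoDTensor() {
  Variable var;
  LoDTensor* lt = var.GetMutable<LoDTensor>();
  lt->data = {1, 2, 3, 4, 5};
  lt->lod = {0, 2, 5};

  const Tensor* t = GetTensorFromVar(&var);
  // Same object, but the caller only sees its Tensor part; the variable
  // never holds a Tensor, so IsType<Tensor>() is false.
  return t == lt && !var.IsType<Tensor>() && t->data.size() == 5;
}
```

`InputIsReallyLoDTensor()` returns `true`: the `Tensor*` handed to the caller is the very LoDTensor object, while the variable itself would answer "no" if asked whether it holds a `Tensor`.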

We should remove the Tensor class to make the code less error-prone. For example, suppose someone calls `var->Mutable<Tensor>` on a variable whose underlying type is LoDTensor: the call would return a brand-new tensor instead of the existing one. This is VERY likely, since LoDTensor is the default variable type.
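The `Mutable<Tensor>` hazard can be sketched the same way (C++17). The classes are again hypothetical, simplified stand-ins. That a type mismatch makes the mutable accessor construct a fresh object, modeled here with `std::any::emplace`, is an assumption based on the behavior described above, not a copy of Paddle's implementation.

```cpp
#include <any>
#include <cstddef>
#include <typeinfo>
#include <vector>

// Hypothetical stand-ins; only "LoDTensor derives from Tensor" is from the issue.
struct Tensor {
  std::vector<float> data;
};

struct LoDTensor : Tensor {
  std::vector<std::size_t> lod;
};

// Simplified Variable (assumption): GetMutable<T>() creates a fresh T
// whenever the stored object is not exactly a T.
class Variable {
 public:
  template <typename T>
  bool IsType() const { return holder_.type() == typeid(T); }

  template <typename T>
  T* GetMutable() {
    if (!IsType<T>()) holder_.emplace<T>();  // the old object is destroyed here
    return std::any_cast<T>(&holder_);
  }

 private:
  std::any holder_;
};

struct MutableResult {
  std::size_t size_seen;  // elements visible through the returned Tensor*
  bool still_lod_tensor;  // whether the variable still holds a LoDTensor
};

// A variable holding a LoDTensor (the default var type) is accessed as Tensor.
MutableResult MutableTensorOnLoDTensorVar() {
  Variable var;
  LoDTensor* lt = var.GetMutable<LoDTensor>();
  lt->data = {1, 2, 3};
  lt->lod = {0, 1, 3};

  // The exact-type check fails (LoDTensor is not Tensor), so a new, empty
  // Tensor replaces the LoDTensor; `lt` now dangles and must not be used.
  Tensor* t = var.GetMutable<Tensor>();
  return {t->data.size(), var.IsType<LoDTensor>()};
}
```

In this model the returned tensor is empty (`size_seen == 0`) and the variable no longer holds a LoDTensor, so the data and LoD written earlier are silently lost. With a single tensor class, there would be no second type for the accessor to mismatch against.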
