-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Move one hot to phi #39876
Move one hot to phi #39876
Conversation
Thanks for your contribution! |
… move_one_hot_to_phi
… move_one_hot_to_phi
auto out_dims_vec = phi::vectorize(x_dims); | ||
out_dims_vec.push_back(depth); | ||
auto out_dims = phi::make_ddim(out_dims_vec); | ||
out->set_dims(out_dims); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
} | ||
}; | ||
out->mutable_data<T>(dev_ctx.GetPlace()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
使用dev_ctx.Alloc?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
} | ||
}; | ||
out->mutable_data<T>(dev_ctx.GetPlace()); | ||
paddle::framework::VisitDataType( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里感觉可以用PD_VISIT_ALL_TYPES
替换framework::VisitDataType
,这样就不用转成proto::VarType了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/funcs/math_function.h" | ||
#include "paddle/phi/kernels/one_hot_kernel.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one_hot_kernel.h按照规范推荐放在开头
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里好像并未按建议修改
void apply() const { | ||
auto* p_in_data = in_->data<InT>(); | ||
auto numel = in_->numel(); | ||
auto* p_out_data = out_->mutable_data<OutT>(ctx_.GetPlace()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ctx.Alloc?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
|
||
#pragma once | ||
|
||
#include "paddle/phi/common/scalar.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
scalar.h好像没有用到
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -477,6 +477,22 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, | |||
"Unsupported attribute type is received when call " | |||
"InferShapeFunctor.")); | |||
} | |||
} else if (ctx->HasInput(attr_name)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个分支在前面好像有了?和前面合并一下?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
既然scalar去不掉的话,我们是否有必要单独为depth增加这几处分支
}; | ||
out->mutable_data<T>(dev_ctx.GetPlace()); | ||
paddle::framework::VisitDataType( | ||
static_cast<paddle::framework::proto::VarType::Type>(dtype), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这两个值不相等的吧,如果想用这个,core/utils/data_type.h也有phi的VisitDataType,但这里确实有点乱了,需要整理
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
out->Resize(out_dims); | ||
} | ||
dev_ctx.template Alloc<T>(out); | ||
paddle::framework::VisitDataType( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
}; | ||
|
||
template <typename T, typename Context> | ||
void OneHotKernel(const Context& dev_ctx, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
命名也直接叫Raw?其实有raw的,也得同时注册下非row的kernel,我们自己的话最好迁全一点
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OneHot是不是应该直接调用OneHotRaw?然后参考math_kernel中的那些,直接在kernels根目录下加kernel
paddle/phi/ops/compat/one_hot_sig.cc
Outdated
|
||
} // namespace phi | ||
|
||
PD_REGISTER_BASE_KERNEL_NAME(one_hot_v2, one_hot_raw); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里映射到raw不太好,我们有个隐含的原则,有xxx_raw,必有xxx,不能只有raw kernel,不然raw kernel就应该直接注册为非raw的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
… move_one_hot_to_phi
… move_one_hot_to_phi
… move_one_hot_to_phi
… move_one_hot_to_phi
… move_one_hot_to_phi
… move_one_hot_to_phi
void OneHotRawKernel(const Context& dev_ctx, | ||
const DenseTensor& x, | ||
int32_t depth, | ||
int dtype, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以直接使用DataType
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
… move_one_hot_to_phi
… move_one_hot_to_phi
… move_one_hot_to_phi
… move_one_hot_to_phi
… into move_one_hot_to_phi
… into move_one_hot_to_phi
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/funcs/math_function.h" | ||
#include "paddle/phi/kernels/one_hot_kernel.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里好像并未按建议修改
}; | ||
|
||
template <typename T, typename Context> | ||
void OneHotKernel(const Context& dev_ctx, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OneHot是不是应该直接调用OneHotRaw?然后参考math_kernel中的那些,直接在kernels根目录下加kernel
} | ||
|
||
template <typename T, typename Context> | ||
void OneHotKernel(const Context& dev_ctx, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
@@ -477,6 +477,22 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, | |||
"Unsupported attribute type is received when call " | |||
"InferShapeFunctor.")); | |||
} | |||
} else if (ctx->HasInput(attr_name)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
既然scalar去不掉的话,我们是否有必要单独为depth增加这几处分支
… move_one_hot_to_phi
… move_one_hot_to_phi
PR types
Breaking changes
PR changes
OPs
Describe
move one hot to phi