-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BugFix] Fix bug for binary_cross_entropy_with_logits loss #54869
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3410,5 +3410,61 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row, | |
out_count->set_dims({-1}); | ||
out_count->set_dtype(DataType::INT32); | ||
} | ||
|
||
void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 函数位置按字典序放置 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
const MetaTensor& label, | ||
const MetaTensor& pos_weight, | ||
bool normalize, | ||
int ignore_index, | ||
MetaTensor* out, | ||
MetaConfig config) { | ||
auto x_dims = x.dims(); | ||
auto labels_dims = label.dims(); | ||
int rank = x_dims.size(); | ||
PADDLE_ENFORCE_EQ(rank, | ||
labels_dims.size(), | ||
phi::errors::InvalidArgument( | ||
"Input(X) and Input(Label) shall have the same rank." | ||
"But received: the rank of Input(X) is [%d], " | ||
"the rank of Input(Label) is [%d].", | ||
rank, | ||
labels_dims.size())); | ||
|
||
bool check = true; | ||
if ((!config.is_runtime) && | ||
(phi::product(x_dims) <= 0 || phi::product(labels_dims) <= 0)) { | ||
check = false; | ||
} | ||
|
||
if (check) { | ||
PADDLE_ENFORCE_EQ( | ||
phi::slice_ddim(x_dims, 0, rank), | ||
phi::slice_ddim(labels_dims, 0, rank), | ||
phi::errors::InvalidArgument( | ||
"Input(X) and Input(Label) shall have the same shape " | ||
"except the last dimension. But received: the shape of " | ||
"Input(X) is [%s], the shape of Input(Label) is [%s].", | ||
x_dims, | ||
labels_dims)); | ||
|
||
if (pos_weight) { | ||
auto weight_dims = pos_weight.dims(); | ||
PADDLE_ENFORCE_EQ( | ||
phi::slice_ddim(weight_dims, 0, rank), | ||
phi::slice_ddim(labels_dims, 0, rank), | ||
phi::errors::InvalidArgument( | ||
"Input(PosWeight) and Input(Label) shall have the same shape " | ||
"But received: the shape of Input(PosWeight) is [%s], " | ||
"the shape of Input(Label) is [%s].", | ||
weight_dims, | ||
labels_dims)); | ||
} | ||
} | ||
|
||
out->set_dims(x_dims); | ||
out->set_dtype(x.dtype()); | ||
out->share_lod(x); | ||
} | ||
|
||
} // namespace phi | ||
PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta); |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -646,4 +646,12 @@ void MoeInferMeta(const MetaTensor& x, | |
const std::string& act_type, | ||
MetaTensor* out); | ||
|
||
void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
const MetaTensor& label, | ||
const MetaTensor& pos_weight, | ||
bool normalize, | ||
int ignore_index, | ||
MetaTensor* out, | ||
MetaConfig config = MetaConfig()); | ||
|
||
} // namespace phi |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个不用改,这是为了兼容原先的写法的,新增参数,不需要兼容旧的,后续新IR重构后原先驼峰式的命名写法都会删除
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done