
[PIR] Support optional input and output for pir api #57492

Merged: 11 commits, Sep 22, 2023
Conversation

@0x45f (Contributor) commented Sep 19, 2023

PR types: Others

PR changes: APIs

Description

[PIR] Support optional input and output for the PIR API. The generated Python C API and C++ API for merged_adam_ below show how an optional input (master_param) and the corresponding optional output (master_param_out) are handled.

// python c api
PyObject *static_api_merged_adam_(PyObject *self, PyObject *args, PyObject *kwargs) {
    try {
        VLOG(6) << "Add merged_adam_ op into program";
        VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);

        // Get Value from args
        PyObject *param_obj = PyTuple_GET_ITEM(args, 0);
        auto param = CastPyArg2VectorOfValue(param_obj, "merged_adam_", 0);
        PyObject *grad_obj = PyTuple_GET_ITEM(args, 1);
        auto grad = CastPyArg2VectorOfValue(grad_obj, "merged_adam_", 1);
        PyObject *learning_rate_obj = PyTuple_GET_ITEM(args, 2);
        auto learning_rate = CastPyArg2VectorOfValue(learning_rate_obj, "merged_adam_", 2);
        PyObject *moment1_obj = PyTuple_GET_ITEM(args, 3);
        auto moment1 = CastPyArg2VectorOfValue(moment1_obj, "merged_adam_", 3);
        PyObject *moment2_obj = PyTuple_GET_ITEM(args, 4);
        auto moment2 = CastPyArg2VectorOfValue(moment2_obj, "merged_adam_", 4);
        PyObject *beta1_pow_obj = PyTuple_GET_ITEM(args, 5);
        auto beta1_pow = CastPyArg2VectorOfValue(beta1_pow_obj, "merged_adam_", 5);
        PyObject *beta2_pow_obj = PyTuple_GET_ITEM(args, 6);
        auto beta2_pow = CastPyArg2VectorOfValue(beta2_pow_obj, "merged_adam_", 6);
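        // Optional input: master_param may be None on the Python side, so it
        // is parsed into a paddle::optional instead of a plain vector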
        PyObject *master_param_obj = PyTuple_GET_ITEM(args, 7);
        auto master_param = CastPyArg2OptionalVectorOfValue(master_param_obj, "merged_adam_", 7);

        // Parse Attributes
        PyObject *beta1_obj = PyTuple_GET_ITEM(args, 8);
        PyObject *beta2_obj = PyTuple_GET_ITEM(args, 9);
        PyObject *epsilon_obj = PyTuple_GET_ITEM(args, 10);
        PyObject *multi_precision_obj = PyTuple_GET_ITEM(args, 11);
        PyObject *use_global_beta_pow_obj = PyTuple_GET_ITEM(args, 12);

        // Check for mutable attrs
        pir::Value beta1;
        pir::Value beta2;
        pir::Value epsilon;

        if (PyObject_CheckIROpResult(beta1_obj)) {
            beta1 = CastPyArg2Value(beta1_obj, "merged_adam_", 8);
        } else {
            float beta1_tmp = CastPyArg2Float(beta1_obj, "merged_adam_", 8);
            beta1 = paddle::dialect::full(std::vector<int64_t>{1}, beta1_tmp, phi::DataType::FLOAT32, phi::CPUPlace());
        }
        if (PyObject_CheckIROpResult(beta2_obj)) {
            beta2 = CastPyArg2Value(beta2_obj, "merged_adam_", 9);
        } else {
            float beta2_tmp = CastPyArg2Float(beta2_obj, "merged_adam_", 9);
            beta2 = paddle::dialect::full(std::vector<int64_t>{1}, beta2_tmp, phi::DataType::FLOAT32, phi::CPUPlace());
        }
        if (PyObject_CheckIROpResult(epsilon_obj)) {
            epsilon = CastPyArg2Value(epsilon_obj, "merged_adam_", 10);
        } else {
            float epsilon_tmp = CastPyArg2Float(epsilon_obj, "merged_adam_", 10);
            epsilon = paddle::dialect::full(std::vector<int64_t>{1}, epsilon_tmp, phi::DataType::FLOAT32, phi::CPUPlace());
        }
        bool multi_precision = CastPyArg2Boolean(multi_precision_obj, "merged_adam_", 11);
        bool use_global_beta_pow = CastPyArg2Boolean(use_global_beta_pow_obj, "merged_adam_", 12);

        // Call ir static api
        auto static_api_out = paddle::dialect::merged_adam_(param, grad, learning_rate, moment1, moment2, beta1_pow, beta2_pow, master_param, beta1, beta2, epsilon, multi_precision, use_global_beta_pow);
        return ToPyObject(static_api_out);
    } catch (...) {
        ThrowExceptionToPython(std::current_exception());
        return nullptr;
    }
}
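
The optional input above is parsed with CastPyArg2OptionalVectorOfValue rather than the plain vector cast. A minimal sketch of how such a helper could behave (the body and exact signature here are assumptions for illustration, not Paddle's actual implementation):

// Hypothetical sketch: map Py_None to an empty optional, otherwise reuse
// the non-optional parser (assumed signature)
paddle::optional<std::vector<pir::Value>> CastPyArg2OptionalVectorOfValue(
        PyObject *obj, const std::string &op_type, size_t arg_pos) {
    if (obj == nullptr || obj == Py_None) {
        return paddle::none;  // input not provided
    }
    return paddle::make_optional<std::vector<pir::Value>>(
        CastPyArg2VectorOfValue(obj, op_type, arg_pos));
}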


// c++ api
std::tuple<std::vector<pir::OpResult>, std::vector<pir::OpResult>, std::vector<pir::OpResult>, std::vector<pir::OpResult>, std::vector<pir::OpResult>, paddle::optional<std::vector<pir::OpResult>>> merged_adam_(std::vector<pir::Value> param, std::vector<pir::Value> grad, std::vector<pir::Value> learning_rate, std::vector<pir::Value> moment1, std::vector<pir::Value> moment2, std::vector<pir::Value> beta1_pow, std::vector<pir::Value> beta2_pow, paddle::optional<std::vector<pir::Value>> master_param, float beta1, float beta2, float epsilon, bool multi_precision, bool use_global_beta_pow){
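    // Optional input: an absent master_param is represented by an empty
    // (default-constructed) pir::Value; a present one is packed into a single
    // value with pir::CombineOp so the op always sees one operand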
    paddle::optional<pir::Value> optional_master_param;
    if (!master_param) {
        optional_master_param = paddle::make_optional<pir::Value>(pir::Value());
    } else {
        auto optional_master_param_combine_op = APIBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(master_param.get());
        optional_master_param = paddle::make_optional<pir::Value>(optional_master_param_combine_op.out());
    }
    auto param_combine_op = APIBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(param);
    auto grad_combine_op = APIBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(grad);
    auto learning_rate_combine_op = APIBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(learning_rate);
    auto moment1_combine_op = APIBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(moment1);
    auto moment2_combine_op = APIBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(moment2);
    auto beta1_pow_combine_op = APIBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(beta1_pow);
    auto beta2_pow_combine_op = APIBuilder::Instance().GetBuilder()->Build<pir::CombineOp>(beta2_pow);
    paddle::dialect::MergedAdam_Op merged_adam__op = APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::MergedAdam_Op>(param_combine_op.out(), grad_combine_op.out(), learning_rate_combine_op.out(), moment1_combine_op.out(), moment2_combine_op.out(), beta1_pow_combine_op.out(), beta2_pow_combine_op.out(), optional_master_param.get(), beta1, beta2, epsilon, multi_precision, use_global_beta_pow);
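    // Optional output: result(5) is only split back into a vector when the op
    // actually produced a master_param_out; otherwise the optional stays empty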
    paddle::optional<std::vector<pir::OpResult>> optional_master_param_out;
    if (!IsEmptyOpResult(merged_adam__op.result(5))) {
        auto optional_master_param_out_slice_op = APIBuilder::Instance().GetBuilder()->Build<pir::SplitOp>(merged_adam__op.result(5));
        optional_master_param_out = paddle::make_optional<std::vector<pir::OpResult>>(optional_master_param_out_slice_op.outputs());
    }
    auto param_out_split_op = APIBuilder::Instance().GetBuilder()->Build<pir::SplitOp>(merged_adam__op.result(0));
    auto moment1_out_split_op = APIBuilder::Instance().GetBuilder()->Build<pir::SplitOp>(merged_adam__op.result(1));
    auto moment2_out_split_op = APIBuilder::Instance().GetBuilder()->Build<pir::SplitOp>(merged_adam__op.result(2));
    auto beta1_pow_out_split_op = APIBuilder::Instance().GetBuilder()->Build<pir::SplitOp>(merged_adam__op.result(3));
    auto beta2_pow_out_split_op = APIBuilder::Instance().GetBuilder()->Build<pir::SplitOp>(merged_adam__op.result(4));
    return std::make_tuple(param_out_split_op.outputs(), moment1_out_split_op.outputs(), moment2_out_split_op.outputs(), beta1_pow_out_split_op.outputs(), beta2_pow_out_split_op.outputs(), optional_master_param_out);
}
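
Whether the optional output exists is decided by IsEmptyOpResult. A possible implementation, assuming an absent optional output is a default-constructed OpResult (a sketch, not necessarily the check Paddle generates):

// Hypothetical sketch: treat a null impl or unset type as "no output"
bool IsEmptyOpResult(const pir::OpResult &op_result) {
    return !op_result.impl() || !op_result.type();
}

Callers of merged_adam_ then receive a paddle::optional<std::vector<pir::OpResult>> in the returned tuple and must check it before use, mirroring how the optional input was passed in.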

Pcard-67164

@paddle-bot commented Sep 19, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

Aurelius84 previously approved these changes Sep 20, 2023
Aurelius84 previously approved these changes Sep 21, 2023
changeyoung98 previously approved these changes Sep 21, 2023
Charles-hit previously approved these changes Sep 21, 2023
0x45f merged commit 3b65af2 into develop Sep 22, 2023
27 checks passed
0x45f deleted the support-optional-api branch September 22, 2023 08:11
Frida-a pushed a commit to Frida-a/Paddle that referenced this pull request Oct 14, 2023
jiahy0825 pushed a commit to jiahy0825/Paddle that referenced this pull request Oct 16, 2023
danleifeng pushed a commit to danleifeng/Paddle that referenced this pull request Nov 14, 2023