From 29c7be683ef2e47563147af0d41130a5959ce508 Mon Sep 17 00:00:00 2001 From: xingmingyyj <135400902+xingmingyyj@users.noreply.github.com> Date: Wed, 20 Sep 2023 15:51:39 +0800 Subject: [PATCH] fix parse string (#57314) * fix parse string * fix parse string * fix string * fix string * fix string * fix string * fix codestyle * fix string * fix parse string --------- Co-authored-by: xingmingyyj --- paddle/pir/core/ir_printer.cc | 10 +++- paddle/pir/core/parser/ir_parser.cc | 28 ++++------ paddle/pir/core/parser/ir_parser.h | 2 +- paddle/pir/core/parser/lexer.cc | 28 +++++++--- paddle/pir/core/parser/lexer.h | 3 +- paddle/pir/core/parser/token.h | 2 +- test/cpp/pir/core/TestParserText.txt | 43 ++++++++++++--- test/cpp/pir/core/add_dialect_parser_test.cc | 2 +- test/cpp/pir/core/ir_parser_test.cc | 57 +++++++++++++++++--- 9 files changed, 130 insertions(+), 45 deletions(-) diff --git a/paddle/pir/core/ir_printer.cc b/paddle/pir/core/ir_printer.cc index 68a0eb99bc598..52c49be812104 100644 --- a/paddle/pir/core/ir_printer.cc +++ b/paddle/pir/core/ir_printer.cc @@ -87,7 +87,15 @@ void BasicIrPrinter::PrintAttribute(Attribute attr) { } if (auto s = attr.dyn_cast()) { - os << "(String)" << s.AsString(); + std::string s_val = s.AsString(); + std::string replacement = "\\\""; + std::string search = "\""; + size_t found = s_val.find(search); + while (found != std::string::npos) { + s_val.replace(found, search.length(), replacement); + found = s_val.find(search, found + replacement.length()); + } + os << "\"" << s_val << "\""; } else if (auto b = attr.dyn_cast()) { if (b.data()) { os << "true"; diff --git a/paddle/pir/core/parser/ir_parser.cc b/paddle/pir/core/parser/ir_parser.cc index 960ba9fd49610..ef235a6e9d8a8 100644 --- a/paddle/pir/core/parser/ir_parser.cc +++ b/paddle/pir/core/parser/ir_parser.cc @@ -24,23 +24,14 @@ IrParser::IrParser(IrContext* ctx, std::istream& is) { builder.reset(new Builder{ctx}); } -Token IrParser::ConsumeToken() { - auto token = lexer->ConsumeToken(); - return token; -} +Token IrParser::ConsumeToken() { return lexer->ConsumeToken(); } std::string IrParser::GetErrorLocationInfo() { return "The error occurred in line " + std::to_string(lexer->GetLine()) + ", column " + std::to_string(lexer->GetColumn()); } -Token IrParser::PeekToken() { - auto token = lexer->ConsumeToken(); - if (token.token_type_ != EOF_) { - lexer->Unget(token.val_.size()); - } - return token; -} +Token IrParser::PeekToken() { return lexer->PeekToken(); } void IrParser::ConsumeAToken(std::string expect_token_val) { std::string token_val = ConsumeToken().val_; @@ -128,14 +119,13 @@ Attribute IrParser::ParseAttribute() { auto parenthesis_token = ConsumeToken(); if (parenthesis_token.val_ == "true" || parenthesis_token.val_ == "false") { return builder->bool_attr(parenthesis_token.val_ == "true"); + } else if (parenthesis_token.token_type_ == STRING) { + std::string val = parenthesis_token.val_; + val = val.substr(1, val.size() - 2); + return builder->str_attr(val); } std::string attribute_type = PeekToken().val_; - if (attribute_type == "String") { - ConsumeAToken("String"); - ConsumeAToken(")"); - std::string val = ConsumeToken().val_; - return builder->str_attr(val); - } else if (attribute_type == "Float") { + if (attribute_type == "Float") { ConsumeAToken("Float"); ConsumeAToken(")"); std::string val = ConsumeToken().val_; @@ -216,7 +206,7 @@ Operation* IrParser::ParseOperation() { OpInfo opinfo = ParseOpInfo(); - std::vector inputs = ParseOprandList(); + std::vector inputs = ParseOperandList(); pir::AttributeMap attributeMap = ParseAttributeMap(); @@ -268,7 +258,7 @@ OpInfo IrParser::ParseOpInfo() { // OprandList := ValueList // ValueList := ValueId(,ValueId)* -std::vector IrParser::ParseOprandList() { +std::vector IrParser::ParseOperandList() { ConsumeAToken("("); std::vector inputs{}; Token ind_token = ConsumeToken(); diff --git a/paddle/pir/core/parser/ir_parser.h b/paddle/pir/core/parser/ir_parser.h index f345e28215f95..a28c1c99de553 100644 --- a/paddle/pir/core/parser/ir_parser.h +++ b/paddle/pir/core/parser/ir_parser.h @@ -51,7 +51,7 @@ class IrParser { std::vector ParseValueList(); - std::vector ParseOprandList(); + std::vector ParseOperandList(); AttributeMap ParseAttributeMap(); diff --git a/paddle/pir/core/parser/lexer.cc b/paddle/pir/core/parser/lexer.cc index c7f037de9927d..9bbfd7dbc804a 100644 --- a/paddle/pir/core/parser/lexer.cc +++ b/paddle/pir/core/parser/lexer.cc @@ -24,7 +24,7 @@ Token Lexer::ConsumeToken() { return *token; } else if (auto token = LexValueId()) { return *token; - } else if (auto token = LexOpName()) { + } else if (auto token = LexString()) { return *token; } else if (auto token = LexEOF()) { return *token; @@ -33,6 +33,16 @@ Token Lexer::ConsumeToken() { } } +Token Lexer::PeekToken() { + auto pos = is.tellg(); + auto token = ConsumeToken(); + if (is.eof()) { + is.clear(); + } + is.seekg(pos); + return token; +} + char Lexer::GetChar() { char c = is.get(); if (c == '\n') { @@ -160,19 +170,23 @@ std::unique_ptr Lexer::LexEOF() { } } -std::unique_ptr Lexer::LexOpName() { +std::unique_ptr Lexer::LexString() { if (is.peek() != '"') { return nullptr; } GetChar(); - std::string token_opname = ""; + std::string token_val = ""; while (is.peek() != '"') { - token_opname += GetChar(); + char c = GetChar(); + if (c == '\\' && is.peek() == '\"') { + c = GetChar(); + } + token_val += c; } GetChar(); - std::unique_ptr opname_token( - new Token{"\"" + token_opname + "\"", OPNAME}); - return opname_token; + std::unique_ptr string_token( + new Token{"\"" + token_val + "\"", STRING}); + return string_token; } bool Lexer::IsSpace(char c) { diff --git a/paddle/pir/core/parser/lexer.h b/paddle/pir/core/parser/lexer.h index 24694eb761317..4b1b94d57dc31 100644 --- a/paddle/pir/core/parser/lexer.h +++ b/paddle/pir/core/parser/lexer.h @@ -28,12 +28,13 @@ class Lexer { explicit Lexer(std::istream& is) : is(is) {} ~Lexer() = default; Token ConsumeToken(); + Token PeekToken(); std::unique_ptr LexIdentifer(); std::unique_ptr LexNumberOrArraow(); std::unique_ptr LexEndTagOrNullVal(); std::unique_ptr LexValueId(); std::unique_ptr LexEOF(); - std::unique_ptr LexOpName(); + std::unique_ptr LexString(); char GetChar(); void SkipWhitespace(); bool IsEndTag(char); diff --git a/paddle/pir/core/parser/token.h b/paddle/pir/core/parser/token.h index 78a20a691c8ac..fb62978a31582 100644 --- a/paddle/pir/core/parser/token.h +++ b/paddle/pir/core/parser/token.h @@ -22,7 +22,7 @@ enum Token_type { SDIGIT = 2, ENDTAG = 3, VALUEID = 4, - OPNAME = 5, + STRING = 5, ARRAOW = 6, NULL_ = 7, }; diff --git a/test/cpp/pir/core/TestParserText.txt b/test/cpp/pir/core/TestParserText.txt index 95c26c61501d1..38c2c9782eba1 100644 --- a/test/cpp/pir/core/TestParserText.txt +++ b/test/cpp/pir/core/TestParserText.txt @@ -1,19 +1,42 @@ //CHECK attribute -(String)sdfgs.sdsd +(Array)[" File \"train.py\", line 225, in ", +" main(args)", +" File \"train.py\", line 197, in main", +" lr_scheduler, args.profiler_options)", +" File \"/home/PaddleClas/ppcls/static/program.py\", line 397, in run", +" fetch_list=fetch_list)", +" File \"/home/dy2stUpgrade/Paddle/build/python/paddle/fluid/executor.py\", line 1440, in run", +" use_prune=use_prune,", +" File \"/home/dy2stUpgrade/Paddle/build/python/paddle/fluid/executor.py\", line 1635, in _run_impl", +" scope,", +" File \"/home/dy2stUpgrade/Paddle/build/python/paddle/fluid/executor.py\", line 801, in get_program_and_executor", +" scope,", +" File \"/home/dy2stUpgrade/Paddle/build/python/paddle/fluid/executor.py\", line 866, in _get_program_and_executor", +" use_fetch_v2=True,", +" File \"/home/dy2stUpgrade/Paddle/build/python/paddle/fluid/executor.py\", line 411, in _add_feed_fetch_ops", +" attrs={'col': i},", +" File \"/home/dy2stUpgrade/Paddle/build/python/paddle/fluid/framework.py\", line 4056, in append_op", +" attrs=kwargs.get(\"attrs\", None),", +" File \"/home/dy2stUpgrade/Paddle/build/python/paddle/fluid/framework.py\", line 2818, in __init__", +" for frame in traceback.extract_stack():"] +//END //CHECK type f32 +//END //CHECK type pd_op.tensor<256xf32> +//END //CHECK program { - (%0) = "builtin.get_parameter" () {parameter_name:(String)conv2d_0.w_0} : () -> pd_op.tensor<64x3x7x7xf32> - (%1) = "pd_op.feed" () {col:(Int32)0,is_persisable:(Array)[false],name:(String)data,stop_gradient:(Array)[true]} : () -> pd_op.tensor<-1x3x224x224xf32> - (%2) = "pd_op.conv2d" (%1, %0) {data_format:(String)NCHW,dilations:(Array)[(Int32)1,(Int32)1],groups:(Int32)1,is_persisable:(Array)[false],padding_algorithm:(String)EXPLICIT,paddings:(Array)[(Int32)3,(Int32)3],stop_gradient:(Array)[false],strides:(Array)[(Int32)2,(Int32)2]} : (pd_op.tensor<-1x3x224x224xf32>, pd_op.tensor<64x3x7x7xf32>) -> pd_op.tensor<-1x64x112x112xf32> + (%0) = "builtin.get_parameter" () {parameter_name:"conv2d_0.w_0"} : () -> pd_op.tensor<64x3x7x7xf32> + (%1) = "pd_op.feed" () {col:(Int32)0,is_persisable:(Array)[false],name:"data",stop_gradient:(Array)[true]} : () -> pd_op.tensor<-1x3x224x224xf32> + (%2) = "pd_op.conv2d" (%1, %0) {data_format:"NCHW",dilations:(Array)[(Int32)1,(Int32)1],groups:(Int32)1,is_persisable:(Array)[false],padding_algorithm:"EXPLICIT",paddings:(Array)[(Int32)3,(Int32)3],stop_gradient:(Array)[false],strides:(Array)[(Int32)2,(Int32)2]} : (pd_op.tensor<-1x3x224x224xf32>, pd_op.tensor<64x3x7x7xf32>) -> pd_op.tensor<-1x64x112x112xf32> } +//END //CHECK attribute (Array)[(pd_op.DataType)bool,(pd_op.DataType)float32,(pd_op.DataType)float64, @@ -21,23 +44,27 @@ pd_op.tensor<256xf32> (pd_op.DataType)Undefined,(pd_op.DataType)Undefined,(pd_op.DataType)Undefined, (pd_op.DataType)bfloat16,(pd_op.DataType)uint8,(pd_op.DataType)uint32,(pd_op.DataType)int8, (pd_op.DataType)uint16,(pd_op.DataType)int16,(pd_op.DataType)int32,(pd_op.DataType)uint64,(pd_op.DataType)int64] - +//END //CHECK attribute (Array)[(pd_op.Place)Place(gpu:0),(pd_op.Place)Place(gpu_pinned),(pd_op.Place)Place(gpu_pinned), (pd_op.Place)Place(xpu:0),(pd_op.Place)Place(ipu:0),(pd_op.Place)Place(:0),(pd_op.Place)Place(cpu)] - +//END //CHECK attribute (Array)[(pd_op.DataLayout)NHWC,(pd_op.DataLayout)STRIDED,(pd_op.DataLayout)NCHW,(pd_op.DataLayout)Undefined(AnyLayout), (pd_op.DataLayout)ONEDNN,(pd_op.DataLayout)SPARSE_COO,(pd_op.DataLayout)SPARSE_CSR,(pd_op.DataLayout)NDHWC,(pd_op.DataLayout)NCDHW, (pd_op.DataLayout)PSTRING_UNION] +//END //CHECK attribute -(Array)[(Double)1,(Int64)0,(String)1] +(Array)[(Double)1,(Int64)0,"1"] +//END //CHECK type vec[bf16,f64,b,i8,u8,i16,c64,c128] +//END //CHECK attribute -(String)1 +(Array)["\"","\\"","\\\"","\t\n\r",""] +//END diff --git a/test/cpp/pir/core/add_dialect_parser_test.cc b/test/cpp/pir/core/add_dialect_parser_test.cc index 88c2f31df23b5..d1754e0b438c7 100644 --- a/test/cpp/pir/core/add_dialect_parser_test.cc +++ b/test/cpp/pir/core/add_dialect_parser_test.cc @@ -100,7 +100,7 @@ TEST(IrParserTest, AddAttribute) { std::string op_str = " (%0) = \"builtin.get_parameter\" () " - "{parameter_name:(String)conv2d_0.w_0,test:(tp.char)a} : () -> " + "{parameter_name:\"conv2d_0.w_0\",test:(tp.char)a} : () -> " "pd_op.tensor<64x3x7x7xf32>"; std::stringstream ss; ss << op_str; diff --git a/test/cpp/pir/core/ir_parser_test.cc b/test/cpp/pir/core/ir_parser_test.cc index 1627c2a4982c7..7990d26e8afaf 100644 --- a/test/cpp/pir/core/ir_parser_test.cc +++ b/test/cpp/pir/core/ir_parser_test.cc @@ -60,28 +60,50 @@ class ParserTest { explicit ParserTest(std::ifstream& test_text) : test_text(test_text) {} TestTask* GetTestTask(); bool ConsumeTestTask(TestTask* test_task, pir::IrContext* ctx); + std::string Peek(const size_t len); + std::string Get(const size_t len); }; TestTask* ParserTest::GetTestTask() { + while (test_text.peek() == '\n' || test_text.peek() == ' ') { + test_text.get(); + } + if (test_text.peek() == EOF) { return nullptr; } - std::string test_info; - while (test_text.peek() != '/') { + + while (Peek(7) != "//CHECK" && test_text.peek() != EOF) { test_text.get(); } - while (test_text.peek() != ' ') { + + while (test_text.peek() != ' ' && test_text.peek() != EOF) { test_text.get(); } + test_text.get(); + std::string test_type_info; - while (test_text.peek() != '\n') { + while (test_text.peek() != '\n' && test_text.peek() != ' ' && + test_text.peek() != EOF) { test_type_info += test_text.get(); } - test_text.get(); - while (test_text.peek() != '/' && test_text.peek() != EOF) { + + while (test_text.peek() == '\n' || test_text.peek() == ' ') { + test_text.get(); + } + + std::string test_info; + while (Peek(5) != "//END" && test_text.peek() != EOF) { test_info += test_text.get(); } + + if (Peek(5) != "//END" || test_info.size() == 0) { + return nullptr; + } + + Get(5); + if (test_type_info == "attribute") { return new TestTask(AttributeTest, test_info); } else if (test_type_info == "type") { @@ -89,6 +111,7 @@ TestTask* ParserTest::GetTestTask() { } else if (test_type_info == "program") { return new TestTask(ProgramTest, test_info); } + return nullptr; } @@ -135,6 +158,28 @@ bool ParserTest::ConsumeTestTask(TestTask* test_task, pir::IrContext* ctx) { return true; } +std::string ParserTest::Peek(const size_t len) { + std::string str; + auto pos = test_text.tellg(); + str = Get(len); + if (test_text.eof()) { + test_text.clear(); + } + test_text.seekg(pos); + return str; +} + +std::string ParserTest::Get(const size_t len) { + std::string str; + for (size_t i = 0; i < len; i++) { + if (test_text.peek() == EOF) { + break; + } + str += test_text.get(); + } + return str; +} + TEST(IrParserTest, TestParserByFile) { pir::IrContext* ctx = pir::IrContext::Instance(); ctx->GetOrRegisterDialect();