Skip to content

Commit

Permalink
fix parse string (PaddlePaddle#57314)
Browse files Browse the repository at this point in the history
* fix parse string

* fix parse string

* fix string

* fix string

* fix string

* fix string

* fix codestyle

* fix string

* fix parse string

---------

Co-authored-by: xingmingyyj <zxm_3719@163.com>
  • Loading branch information
2 people authored and Frida-a committed Oct 14, 2023
1 parent 8995502 commit 29c7be6
Show file tree
Hide file tree
Showing 9 changed files with 130 additions and 45 deletions.
10 changes: 9 additions & 1 deletion paddle/pir/core/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,15 @@ void BasicIrPrinter::PrintAttribute(Attribute attr) {
}

if (auto s = attr.dyn_cast<StrAttribute>()) {
os << "(String)" << s.AsString();
std::string s_val = s.AsString();
std::string replacement = "\\\"";
std::string search = "\"";
size_t found = s_val.find(search);
while (found != std::string::npos) {
s_val.replace(found, search.length(), replacement);
found = s_val.find(search, found + replacement.length());
}
os << "\"" << s_val << "\"";
} else if (auto b = attr.dyn_cast<BoolAttribute>()) {
if (b.data()) {
os << "true";
Expand Down
28 changes: 9 additions & 19 deletions paddle/pir/core/parser/ir_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,14 @@ IrParser::IrParser(IrContext* ctx, std::istream& is) {
builder.reset(new Builder{ctx});
}

Token IrParser::ConsumeToken() {
auto token = lexer->ConsumeToken();
return token;
}
Token IrParser::ConsumeToken() { return lexer->ConsumeToken(); }

std::string IrParser::GetErrorLocationInfo() {
return "The error occurred in line " + std::to_string(lexer->GetLine()) +
", column " + std::to_string(lexer->GetColumn());
}

Token IrParser::PeekToken() {
auto token = lexer->ConsumeToken();
if (token.token_type_ != EOF_) {
lexer->Unget(token.val_.size());
}
return token;
}
Token IrParser::PeekToken() { return lexer->PeekToken(); }

void IrParser::ConsumeAToken(std::string expect_token_val) {
std::string token_val = ConsumeToken().val_;
Expand Down Expand Up @@ -128,14 +119,13 @@ Attribute IrParser::ParseAttribute() {
auto parenthesis_token = ConsumeToken();
if (parenthesis_token.val_ == "true" || parenthesis_token.val_ == "false") {
return builder->bool_attr(parenthesis_token.val_ == "true");
} else if (parenthesis_token.token_type_ == STRING) {
std::string val = parenthesis_token.val_;
val = val.substr(1, val.size() - 2);
return builder->str_attr(val);
}
std::string attribute_type = PeekToken().val_;
if (attribute_type == "String") {
ConsumeAToken("String");
ConsumeAToken(")");
std::string val = ConsumeToken().val_;
return builder->str_attr(val);
} else if (attribute_type == "Float") {
if (attribute_type == "Float") {
ConsumeAToken("Float");
ConsumeAToken(")");
std::string val = ConsumeToken().val_;
Expand Down Expand Up @@ -216,7 +206,7 @@ Operation* IrParser::ParseOperation() {

OpInfo opinfo = ParseOpInfo();

std::vector<Value> inputs = ParseOprandList();
std::vector<Value> inputs = ParseOperandList();

pir::AttributeMap attributeMap = ParseAttributeMap();

Expand Down Expand Up @@ -268,7 +258,7 @@ OpInfo IrParser::ParseOpInfo() {

// OprandList := ValueList
// ValueList := ValueId(,ValueId)*
std::vector<Value> IrParser::ParseOprandList() {
std::vector<Value> IrParser::ParseOperandList() {
ConsumeAToken("(");
std::vector<Value> inputs{};
Token ind_token = ConsumeToken();
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/parser/ir_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class IrParser {

std::vector<std::string> ParseValueList();

std::vector<Value> ParseOprandList();
std::vector<Value> ParseOperandList();

AttributeMap ParseAttributeMap();

Expand Down
28 changes: 21 additions & 7 deletions paddle/pir/core/parser/lexer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Token Lexer::ConsumeToken() {
return *token;
} else if (auto token = LexValueId()) {
return *token;
} else if (auto token = LexOpName()) {
} else if (auto token = LexString()) {
return *token;
} else if (auto token = LexEOF()) {
return *token;
Expand All @@ -33,6 +33,16 @@ Token Lexer::ConsumeToken() {
}
}

Token Lexer::PeekToken() {
auto pos = is.tellg();
auto token = ConsumeToken();
if (is.eof()) {
is.clear();
}
is.seekg(pos);
return token;
}

char Lexer::GetChar() {
char c = is.get();
if (c == '\n') {
Expand Down Expand Up @@ -160,19 +170,23 @@ std::unique_ptr<Token> Lexer::LexEOF() {
}
}

std::unique_ptr<Token> Lexer::LexOpName() {
std::unique_ptr<Token> Lexer::LexString() {
if (is.peek() != '"') {
return nullptr;
}
GetChar();
std::string token_opname = "";
std::string token_val = "";
while (is.peek() != '"') {
token_opname += GetChar();
char c = GetChar();
if (c == '\\' && is.peek() == '\"') {
c = GetChar();
}
token_val += c;
}
GetChar();
std::unique_ptr<Token> opname_token(
new Token{"\"" + token_opname + "\"", OPNAME});
return opname_token;
std::unique_ptr<Token> string_token(
new Token{"\"" + token_val + "\"", STRING});
return string_token;
}

bool Lexer::IsSpace(char c) {
Expand Down
3 changes: 2 additions & 1 deletion paddle/pir/core/parser/lexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ class Lexer {
explicit Lexer(std::istream& is) : is(is) {}
~Lexer() = default;
Token ConsumeToken();
Token PeekToken();
std::unique_ptr<Token> LexIdentifer();
std::unique_ptr<Token> LexNumberOrArraow();
std::unique_ptr<Token> LexEndTagOrNullVal();
std::unique_ptr<Token> LexValueId();
std::unique_ptr<Token> LexEOF();
std::unique_ptr<Token> LexOpName();
std::unique_ptr<Token> LexString();
char GetChar();
void SkipWhitespace();
bool IsEndTag(char);
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/parser/token.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ enum Token_type {
SDIGIT = 2,
ENDTAG = 3,
VALUEID = 4,
OPNAME = 5,
STRING = 5,
ARRAOW = 6,
NULL_ = 7,
};
Expand Down
43 changes: 35 additions & 8 deletions test/cpp/pir/core/TestParserText.txt
Original file line number Diff line number Diff line change
@@ -1,43 +1,70 @@

//CHECK attribute
(String)sdfgs.sdsd
(Array)[" File \"train.py\", line 225, in <module>",
" main(args)",
" File \"train.py\", line 197, in main",
" lr_scheduler, args.profiler_options)",
" File \"/home/PaddleClas/ppcls/static/program.py\", line 397, in run",
" fetch_list=fetch_list)",
" File \"/home/dy2stUpgrade/Paddle/build/python/paddle/fluid/executor.py\", line 1440, in run",
" use_prune=use_prune,",
" File \"/home/dy2stUpgrade/Paddle/build/python/paddle/fluid/executor.py\", line 1635, in _run_impl",
" scope,",
" File \"/home/dy2stUpgrade/Paddle/build/python/paddle/fluid/executor.py\", line 801, in get_program_and_executor",
" scope,",
" File \"/home/dy2stUpgrade/Paddle/build/python/paddle/fluid/executor.py\", line 866, in _get_program_and_executor",
" use_fetch_v2=True,",
" File \"/home/dy2stUpgrade/Paddle/build/python/paddle/fluid/executor.py\", line 411, in _add_feed_fetch_ops",
" attrs={'col': i},",
" File \"/home/dy2stUpgrade/Paddle/build/python/paddle/fluid/framework.py\", line 4056, in append_op",
" attrs=kwargs.get(\"attrs\", None),",
" File \"/home/dy2stUpgrade/Paddle/build/python/paddle/fluid/framework.py\", line 2818, in __init__",
" for frame in traceback.extract_stack():"]
//END

//CHECK type
f32
//END

//CHECK type
pd_op.tensor<256xf32>
//END

//CHECK program
{
(%0) = "builtin.get_parameter" () {parameter_name:(String)conv2d_0.w_0} : () -> pd_op.tensor<64x3x7x7xf32>
(%1) = "pd_op.feed" () {col:(Int32)0,is_persisable:(Array)[false],name:(String)data,stop_gradient:(Array)[true]} : () -> pd_op.tensor<-1x3x224x224xf32>
(%2) = "pd_op.conv2d" (%1, %0) {data_format:(String)NCHW,dilations:(Array)[(Int32)1,(Int32)1],groups:(Int32)1,is_persisable:(Array)[false],padding_algorithm:(String)EXPLICIT,paddings:(Array)[(Int32)3,(Int32)3],stop_gradient:(Array)[false],strides:(Array)[(Int32)2,(Int32)2]} : (pd_op.tensor<-1x3x224x224xf32>, pd_op.tensor<64x3x7x7xf32>) -> pd_op.tensor<-1x64x112x112xf32>
(%0) = "builtin.get_parameter" () {parameter_name:"conv2d_0.w_0"} : () -> pd_op.tensor<64x3x7x7xf32>
(%1) = "pd_op.feed" () {col:(Int32)0,is_persisable:(Array)[false],name:"data",stop_gradient:(Array)[true]} : () -> pd_op.tensor<-1x3x224x224xf32>
(%2) = "pd_op.conv2d" (%1, %0) {data_format:"NCHW",dilations:(Array)[(Int32)1,(Int32)1],groups:(Int32)1,is_persisable:(Array)[false],padding_algorithm:"EXPLICIT",paddings:(Array)[(Int32)3,(Int32)3],stop_gradient:(Array)[false],strides:(Array)[(Int32)2,(Int32)2]} : (pd_op.tensor<-1x3x224x224xf32>, pd_op.tensor<64x3x7x7xf32>) -> pd_op.tensor<-1x64x112x112xf32>
}
//END

//CHECK attribute
(Array)[(pd_op.DataType)bool,(pd_op.DataType)float32,(pd_op.DataType)float64,
(pd_op.DataType)complex64,(pd_op.DataType)complex128,(pd_op.DataType)Undefined,
(pd_op.DataType)Undefined,(pd_op.DataType)Undefined,(pd_op.DataType)Undefined,
(pd_op.DataType)bfloat16,(pd_op.DataType)uint8,(pd_op.DataType)uint32,(pd_op.DataType)int8,
(pd_op.DataType)uint16,(pd_op.DataType)int16,(pd_op.DataType)int32,(pd_op.DataType)uint64,(pd_op.DataType)int64]

//END

//CHECK attribute
(Array)[(pd_op.Place)Place(gpu:0),(pd_op.Place)Place(gpu_pinned),(pd_op.Place)Place(gpu_pinned),
(pd_op.Place)Place(xpu:0),(pd_op.Place)Place(ipu:0),(pd_op.Place)Place(:0),(pd_op.Place)Place(cpu)]

//END

//CHECK attribute
(Array)[(pd_op.DataLayout)NHWC,(pd_op.DataLayout)STRIDED,(pd_op.DataLayout)NCHW,(pd_op.DataLayout)Undefined(AnyLayout),
(pd_op.DataLayout)ONEDNN,(pd_op.DataLayout)SPARSE_COO,(pd_op.DataLayout)SPARSE_CSR,(pd_op.DataLayout)NDHWC,(pd_op.DataLayout)NCDHW,
(pd_op.DataLayout)PSTRING_UNION]
//END

//CHECK attribute
(Array)[(Double)1,(Int64)0,(String)1]
(Array)[(Double)1,(Int64)0,"1"]
//END

//CHECK type
vec[bf16,f64,b,i8,u8,i16,c64,c128]
//END

//CHECK attribute
(String)1
(Array)["\"","\\"","\\\"","\t\n\r",""]
//END
2 changes: 1 addition & 1 deletion test/cpp/pir/core/add_dialect_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ TEST(IrParserTest, AddAttribute) {

std::string op_str =
" (%0) = \"builtin.get_parameter\" () "
"{parameter_name:(String)conv2d_0.w_0,test:(tp.char)a} : () -> "
"{parameter_name:\"conv2d_0.w_0\",test:(tp.char)a} : () -> "
"pd_op.tensor<64x3x7x7xf32>";
std::stringstream ss;
ss << op_str;
Expand Down
57 changes: 51 additions & 6 deletions test/cpp/pir/core/ir_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,35 +60,58 @@ class ParserTest {
explicit ParserTest(std::ifstream& test_text) : test_text(test_text) {}
TestTask* GetTestTask();
bool ConsumeTestTask(TestTask* test_task, pir::IrContext* ctx);
std::string Peek(const size_t len);
std::string Get(const size_t len);
};

TestTask* ParserTest::GetTestTask() {
while (test_text.peek() == '\n' || test_text.peek() == ' ') {
test_text.get();
}

if (test_text.peek() == EOF) {
return nullptr;
}
std::string test_info;
while (test_text.peek() != '/') {

while (Peek(7) != "//CHECK" && test_text.peek() != EOF) {
test_text.get();
}
while (test_text.peek() != ' ') {

while (test_text.peek() != ' ' && test_text.peek() != EOF) {
test_text.get();
}

test_text.get();

std::string test_type_info;
while (test_text.peek() != '\n') {
while (test_text.peek() != '\n' && test_text.peek() != ' ' &&
test_text.peek() != EOF) {
test_type_info += test_text.get();
}
test_text.get();
while (test_text.peek() != '/' && test_text.peek() != EOF) {

while (test_text.peek() == '\n' || test_text.peek() == ' ') {
test_text.get();
}

std::string test_info;
while (Peek(5) != "//END" && test_text.peek() != EOF) {
test_info += test_text.get();
}

if (Peek(5) != "//END" || test_info.size() == 0) {
return nullptr;
}

Get(5);

if (test_type_info == "attribute") {
return new TestTask(AttributeTest, test_info);
} else if (test_type_info == "type") {
return new TestTask(TypeTest, test_info);
} else if (test_type_info == "program") {
return new TestTask(ProgramTest, test_info);
}

return nullptr;
}

Expand Down Expand Up @@ -135,6 +158,28 @@ bool ParserTest::ConsumeTestTask(TestTask* test_task, pir::IrContext* ctx) {
return true;
}

std::string ParserTest::Peek(const size_t len) {
std::string str;
auto pos = test_text.tellg();
str = Get(len);
if (test_text.eof()) {
test_text.clear();
}
test_text.seekg(pos);
return str;
}

std::string ParserTest::Get(const size_t len) {
std::string str;
for (size_t i = 0; i < len; i++) {
if (test_text.peek() == EOF) {
break;
}
str += test_text.get();
}
return str;
}

TEST(IrParserTest, TestParserByFile) {
pir::IrContext* ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<OperatorDialect>();
Expand Down

0 comments on commit 29c7be6

Please sign in to comment.