Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor[next]: Use set_at & as_fieldop instead of closure in iterator tests #1691

Open
wants to merge 89 commits into
base: main
Choose a base branch
from

Conversation

tehrengruber
Copy link
Contributor

@tehrengruber tehrengruber commented Oct 16, 2024

Blocked by #1648

Small script to do the conversion:

import ast
import sys

def transform_closure_to_set_at(source):
    """
    Transforms `closure(<domain>, <stencil>, <out>, [<arg1>, <arg2>, ...])`
    into `domain = <domain>
    set_at(as_fieldop(<stencil>, domain)(<arg1>, <arg2>, ...), domain, <out>)`.
    """
    # Parse the input expression
    try:
        parsed = ast.parse(source, mode='eval')
    except SyntaxError as e:
        raise ValueError(f"Invalid Python expression: {e}")

    if not isinstance(parsed.body, ast.Call):
        raise ValueError("Expected a function call expression.")

    func_call = parsed.body

    if not isinstance(func_call.func, ast.Name) or func_call.func.id != 'closure':
        raise ValueError("Expected a 'closure' function call.")

    # Extract the arguments from the closure call
    if len(func_call.args) < 3:
        raise ValueError("Expected at least three arguments: <domain>, <stencil>, <out>.")

    domain = ast.unparse(func_call.args[0])
    stencil = ast.unparse(func_call.args[1])
    out = ast.unparse(func_call.args[2])
    args = ", ".join([ast.unparse(arg) for arg in func_call.args[3].elts])

    # Construct the transformed code as a string
    transformed = f"""domain = {domain}
set_at(as_fieldop({stencil}, domain)({args}), domain, {out})"""
    
    return transformed

def main():
    # Read input from stdin until a blank line is encountered
    lines = []
    for line in sys.stdin:
        stripped_line = line.strip()
        if not stripped_line:
            break
        lines.append(stripped_line)
    
    # Combine the lines into a single input expression
    input_expr = " ".join(lines)
    
    # Perform the transformation
    result = transform_closure_to_set_at(input_expr)
    print(result, end="")

if __name__ == "__main__":
    main()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants