diff --git a/projects/vdk-plugins/vdk-jupyter/vdk-jupyterlab-extension/vdk_jupyterlab_extension/convert_job.py b/projects/vdk-plugins/vdk-jupyter/vdk-jupyterlab-extension/vdk_jupyterlab_extension/convert_job.py index 234ba260f5..cbbc8f3be8 100644 --- a/projects/vdk-plugins/vdk-jupyter/vdk-jupyterlab-extension/vdk_jupyterlab_extension/convert_job.py +++ b/projects/vdk-plugins/vdk-jupyter/vdk-jupyterlab-extension/vdk_jupyterlab_extension/convert_job.py @@ -1,10 +1,14 @@ # Copyright 2021-2023 VMware, Inc. # SPDX-License-Identifier: Apache-2.0 +import ast import glob +import logging import os import re import shutil +log = logging.getLogger() + def validate_dir(dir_path): if not os.path.isdir(dir_path): @@ -65,19 +69,33 @@ def process_files(self): for file in all_files: if file.endswith(".sql"): - with open(file) as f: - content = f.read() - self._code_structure.append(f'job_input.execute_query("""{content}""")') - self._removed_files.append(os.path.basename(file)) - os.remove(file) + self._process_sql_files(file) elif file.endswith(".py"): - with open(file) as f: - content = f.read() - - if re.search(r"def run\(job_input", content): - self._code_structure.append(content) - self._removed_files.append(os.path.basename(file)) - os.remove(file) + self._process_python_files(file) + + @staticmethod + def _has_run_function(file: str, content: str) -> bool: + tree = ast.parse(content, file) + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.FunctionDef): + if node.name == "run": + return True + return False + + def _process_python_files(self, file): + with open(file) as f: + content = f.read() + if self._has_run_function(file, content): + self._code_structure.append(content) + self._removed_files.append(os.path.basename(file)) + os.remove(file) + + def _process_sql_files(self, file): + with open(file) as f: + content = f.read() + self._code_structure.append(f'job_input.execute_query("""{content}""")') + self._removed_files.append(os.path.basename(file)) + os.remove(file) def cleanup(self): for file_name in self._removed_files: diff --git a/projects/vdk-plugins/vdk-jupyter/vdk-jupyterlab-extension/vdk_jupyterlab_extension/tests/test_convert_job_directory_processor.py b/projects/vdk-plugins/vdk-jupyter/vdk-jupyterlab-extension/vdk_jupyterlab_extension/tests/test_convert_job_directory_processor.py index da6592c692..abed8da35f 100644 --- a/projects/vdk-plugins/vdk-jupyter/vdk-jupyterlab-extension/vdk_jupyterlab_extension/tests/test_convert_job_directory_processor.py +++ b/projects/vdk-plugins/vdk-jupyter/vdk-jupyterlab-extension/vdk_jupyterlab_extension/tests/test_convert_job_directory_processor.py @@ -13,20 +13,45 @@ def setUp(self): self.temp_dir = tempfile.mkdtemp() self.sql_content = "SELECT * FROM table" self.py_content_run = """ - def run(job_input: IJobInput): - print("Hello, World!") +def run(job_input: IJobInput): + print("Hello, World!") """ self.py_content_without_run = """ - def hello(): - print("Hello, World!") +def hello(): + print("Hello, World!") +''' +commented out def run +def run(job_input) +''' + """ + self.py_content_run_multiline = """ +def run( + job_input: IJobInput +): + print('Hello, World!') + """ + self.py_content_run_spaces = """ +def run( job_input: IJobInput): + print('Hello, World!') + """ + self.py_content_run_in_a_class = """ +class X: + def run(job_input: IJobInput): + print("Hello, World!") """ - with open(os.path.join(self.temp_dir, "test.sql"), "w") as f: + with open(os.path.join(self.temp_dir, "10_test.sql"), "w") as f: f.write(self.sql_content) - with open(os.path.join(self.temp_dir, "test_run.py"), "w") as f: + with open(os.path.join(self.temp_dir, "20_test_run.py"), "w") as f: f.write(self.py_content_run) - with open(os.path.join(self.temp_dir, "test_without_run.py"), "w") as f: + with open(os.path.join(self.temp_dir, "30_test_without_run.py"), "w") as f: f.write(self.py_content_without_run) + with open(os.path.join(self.temp_dir, "40_test_multi_line_run.py"), "w") as f: + f.write(self.py_content_run_multiline) + with open(os.path.join(self.temp_dir, "50_test_spaces.py"), "w") as f: + f.write(self.py_content_run_spaces) + with open(os.path.join(self.temp_dir, "60_run_in_a_class.py"), "w") as f: + f.write(self.py_content_run_in_a_class) with open(os.path.join(self.temp_dir, "config.ini"), "w") as f: pass @@ -35,42 +60,74 @@ def hello(): def tearDown(self): shutil.rmtree(self.temp_dir) - def test_process_files(self): + def test_process_non_step_files_remain(self): self.processor.process_files() - expected_code_structure = [ - f'job_input.execute_query("""{self.sql_content}""")', - self.py_content_run, - ] - self.assertEqual(self.processor.get_code_structure(), expected_code_structure) - self.assertFalse(os.path.exists(os.path.join(self.temp_dir, "test.sql"))) - self.assertFalse(os.path.exists(os.path.join(self.temp_dir, "test_run.py"))) self.assertTrue( - os.path.exists(os.path.join(self.temp_dir, "test_without_run.py")) + os.path.exists(os.path.join(self.temp_dir, "30_test_without_run.py")) ) - self.assertTrue(os.path.exists(os.path.join(self.temp_dir, "config.ini"))) - - expected_removed_files = ["test.sql", "test_run.py"] - self.assertEqual( - set(self.processor.get_removed_files()), set(expected_removed_files) + self.assertTrue( + os.path.exists(os.path.join(self.temp_dir, "60_run_in_a_class.py")) ) + self.assertTrue(os.path.exists(os.path.join(self.temp_dir, "config.ini"))) def test_cleanup(self): self.processor.process_files() self.processor.cleanup() - self.assertFalse(os.path.exists(os.path.join(self.temp_dir, "test.sql"))) - self.assertFalse(os.path.exists(os.path.join(self.temp_dir, "test_run.py"))) + self.assertFalse(os.path.exists(os.path.join(self.temp_dir, "10_test.sql"))) + self.assertFalse(os.path.exists(os.path.join(self.temp_dir, "20_test_run.py"))) + self.assertFalse( + os.path.exists(os.path.join(self.temp_dir, "40_test_multi_line_run.py")) + ) + self.assertFalse( + os.path.exists(os.path.join(self.temp_dir, "50_test_spaces.py")) + ) + + self.assertTrue( + os.path.exists(os.path.join(self.temp_dir, "30_test_without_run.py")) + ) + self.assertTrue( + os.path.exists(os.path.join(self.temp_dir, "60_run_in_a_class.py")) + ) + self.assertTrue(os.path.exists(os.path.join(self.temp_dir, "config.ini"))) def test_get_code_structure(self): self.processor.process_files() expected_code_structure = [ f'job_input.execute_query("""{self.sql_content}""")', self.py_content_run, + self.py_content_run_multiline, + self.py_content_run_spaces, ] self.assertEqual(self.processor.get_code_structure(), expected_code_structure) def test_get_removed_files(self): self.processor.process_files() - expected_removed_files = ["test.sql", "test_run.py"] + expected_removed_files = [ + "10_test.sql", + "20_test_run.py", + "40_test_multi_line_run.py", + "50_test_spaces.py", + ] self.assertEqual( set(self.processor.get_removed_files()), set(expected_removed_files) ) + + def test_get_bad_python_file(self): + bad_job_dir = tempfile.mkdtemp() + try: + py_content_with_incorrect_syntax = """ + def run(job_input: IJobInput + + print(' Hello, World!') + """ + with open(os.path.join(bad_job_dir, "50_test_spaces.py"), "w") as f: + f.write(py_content_with_incorrect_syntax) + processor = ConvertJobDirectoryProcessor(bad_job_dir) + + try: + processor.process_files() + assert False, "Expected SyntaxError exception" + except SyntaxError as e: + assert "50_test_spaces.py" in e.filename + finally: + shutil.rmtree(bad_job_dir)