diff --git a/src/gitingest/exceptions.py b/src/gitingest/exceptions.py index 8808cf77..3b01018d 100644 --- a/src/gitingest/exceptions.py +++ b/src/gitingest/exceptions.py @@ -49,3 +49,10 @@ class AlreadyVisitedError(Exception): def __init__(self, path: str) -> None: super().__init__(f"Symlink target already visited: {path}") + + +class InvalidNotebookError(Exception): + """Exception raised when a Jupyter notebook is invalid or cannot be processed.""" + + def __init__(self, message: str) -> None: + super().__init__(message) diff --git a/src/gitingest/notebook_utils.py b/src/gitingest/notebook_utils.py index c5590341..1a385ca4 100644 --- a/src/gitingest/notebook_utils.py +++ b/src/gitingest/notebook_utils.py @@ -2,11 +2,14 @@ import json import warnings +from itertools import chain from pathlib import Path from typing import Any +from gitingest.exceptions import InvalidNotebookError -def process_notebook(file: Path) -> str: + +def process_notebook(file: Path, include_output: bool = True) -> str: """ Process a Jupyter notebook file and return an executable Python script as a string. @@ -14,6 +17,8 @@ def process_notebook(file: Path) -> str: ---------- file : Path The path to the Jupyter notebook file. + include_output : bool + Whether to include cell outputs in the generated script, by default True. Returns ------- @@ -22,45 +27,127 @@ def process_notebook(file: Path) -> str: Raises ------ - ValueError - If an unexpected cell type is encountered. + InvalidNotebookError + If the notebook file is invalid or cannot be processed. """ - with file.open(encoding="utf-8") as f: - notebook: dict[str, Any] = json.load(f) + try: + with file.open(encoding="utf-8") as f: + notebook: dict[str, Any] = json.load(f) + except json.JSONDecodeError as e: + raise InvalidNotebookError(f"Invalid JSON in notebook: {file}") from e # Check if the notebook contains worksheets if worksheets := notebook.get("worksheets"): - # https://github.com/ipython/ipython/wiki/IPEP-17:-Notebook-Format-4#remove-multiple-worksheets - # "The `worksheets` field is a list, but we have no UI to support multiple worksheets. - # Our design has since shifted to heading-cell based structure, so we never intend to - # support the multiple worksheet model. The worksheets list of lists shall be replaced - # with a single list, called `cells`." - warnings.warn("Worksheets are deprecated as of IPEP-17.", DeprecationWarning) + warnings.warn( + "Worksheets are deprecated as of IPEP-17. Consider updating the notebook. " + "(See: https://github.com/jupyter/nbformat and " + "https://github.com/ipython/ipython/wiki/IPEP-17:-Notebook-Format-4#remove-multiple-worksheets " + "for more information.)", + DeprecationWarning, + ) if len(worksheets) > 1: - warnings.warn( - "Multiple worksheets are not supported. Only the first worksheet will be processed.", UserWarning - ) + warnings.warn("Multiple worksheets detected. Combining all worksheets into a single script.", UserWarning) + + cells = list(chain.from_iterable(ws["cells"] for ws in worksheets)) + + else: + cells = notebook["cells"] + + result = ["# Jupyter notebook converted to Python script."] + + for cell in cells: + if cell_str := _process_cell(cell, include_output=include_output): + result.append(cell_str) + + return "\n\n".join(result) + "\n" + + +def _process_cell(cell: dict[str, Any], include_output: bool) -> str | None: + """ + Process a Jupyter notebook cell and return the cell content as a string. - notebook = worksheets[0] + Parameters + ---------- + cell : dict[str, Any] + The cell dictionary from a Jupyter notebook. + include_output : bool + Whether to include cell outputs in the generated script + + Returns + ------- + str | None + The cell content as a string, or None if the cell is empty. + + Raises + ------ + ValueError + If an unexpected cell type is encountered. + """ + cell_type = cell["cell_type"] - result = [] + # Validate cell type and handle unexpected types + if cell_type not in ("markdown", "code", "raw"): + raise ValueError(f"Unknown cell type: {cell_type}") - for cell in notebook["cells"]: - cell_type = cell.get("cell_type") + cell_str = "".join(cell["source"]) - # Validate cell type and handle unexpected types - if cell_type not in ("markdown", "code", "raw"): - raise ValueError(f"Unknown cell type: {cell_type}") + # Skip empty cells + if not cell_str: + return None + + # Convert Markdown and raw cells to multi-line comments + if cell_type in ("markdown", "raw"): + return f'"""\n{cell_str}\n"""' + + # Add cell output as comments + if include_output and (outputs := cell.get("outputs")): + + # Include cell outputs as comments + output_lines = [] + + for output in outputs: + output_lines += _extract_output(output) + + for output_line in output_lines: + if not output_line.endswith("\n"): + output_line += "\n" + + cell_str += "\n# Output:\n# " + "\n# ".join(output_lines) + + return cell_str + + +def _extract_output(output: dict[str, Any]) -> list[str]: + """ + Extract the output from a Jupyter notebook cell. + + Parameters + ---------- + output : dict[str, Any] + The output dictionary from a Jupyter notebook cell. + + Returns + ------- + list[str] + The output as a list of strings. + + Raises + ------ + ValueError + If an unknown output type is encountered. + """ + output_type = output["output_type"] - str_ = "".join(cell.get("source", [])) - if not str_: - continue + match output_type: + case "stream": + return output["text"] - # Convert Markdown and raw cells to multi-line comments - if cell_type in ("markdown", "raw"): - str_ = f'"""\n{str_}\n"""' + case "execute_result" | "display_data": + return output["data"]["text/plain"] - result.append(str_) + case "error": + return [f"Error: {output['ename']}: {output['evalue']}"] - return "\n\n".join(result) + case _: + raise ValueError(f"Unknown output type: {output_type}") diff --git a/src/gitingest/query_ingestion.py b/src/gitingest/query_ingestion.py index 3396ca6e..2e1f292b 100644 --- a/src/gitingest/query_ingestion.py +++ b/src/gitingest/query_ingestion.py @@ -6,7 +6,12 @@ import tiktoken -from gitingest.exceptions import AlreadyVisitedError, MaxFileSizeReachedError, MaxFilesReachedError +from gitingest.exceptions import ( + AlreadyVisitedError, + InvalidNotebookError, + MaxFileSizeReachedError, + MaxFilesReachedError, +) from gitingest.notebook_utils import process_notebook MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB @@ -164,7 +169,7 @@ def _read_file_content(file_path: Path) -> str: with open(file_path, encoding="utf-8", errors="ignore") as f: return f.read() - except OSError as e: + except (OSError, InvalidNotebookError) as e: return f"Error reading file: {e}" diff --git a/tests/test_notebook_utils.py b/tests/test_notebook_utils.py index a0da1b17..6a23b926 100644 --- a/tests/test_notebook_utils.py +++ b/tests/test_notebook_utils.py @@ -70,16 +70,11 @@ def test_process_notebook_with_worksheets(write_notebook): def test_process_notebook_multiple_worksheets(write_notebook): """ Test a notebook containing multiple 'worksheets'. - - If multiple worksheets are present: - - Only process the first sheet's cells. - - DeprecationWarning for worksheets - - UserWarning for ignoring extra worksheets """ multi_worksheets = { "worksheets": [ {"cells": [{"cell_type": "markdown", "source": ["# First Worksheet"]}]}, - {"cells": [{"cell_type": "code", "source": ['print("Ignored Worksheet")']}]}, + {"cells": [{"cell_type": "code", "source": ["# Second Worksheet"]}]}, ] } @@ -93,15 +88,26 @@ def test_process_notebook_multiple_worksheets(write_notebook): nb_multi = write_notebook("multiple_worksheets.ipynb", multi_worksheets) nb_single = write_notebook("single_worksheet.ipynb", single_worksheet) - with pytest.warns(DeprecationWarning, match="Worksheets are deprecated as of IPEP-17."): - with pytest.warns(UserWarning, match="Multiple worksheets are not supported."): + with pytest.warns( + DeprecationWarning, match="Worksheets are deprecated as of IPEP-17. Consider updating the notebook." + ): + with pytest.warns( + UserWarning, match="Multiple worksheets detected. Combining all worksheets into a single script." + ): result_multi = process_notebook(nb_multi) - with pytest.warns(DeprecationWarning, match="Worksheets are deprecated as of IPEP-17."): + with pytest.warns( + DeprecationWarning, match="Worksheets are deprecated as of IPEP-17. Consider updating the notebook." + ): result_single = process_notebook(nb_single) # The second worksheet (with code) should have been ignored - assert result_multi == result_single, "Second worksheet was ignored, results match." + assert result_multi != result_single, "The multi-worksheet notebook should have more content." + assert len(result_multi) > len(result_single), "The multi-worksheet notebook should have more content." + assert "# First Worksheet" in result_single, "First worksheet content should be present." + assert "# Second Worksheet" not in result_single, "Second worksheet content should be absent." + assert "# First Worksheet" in result_multi, "First worksheet content should be present." + assert "# Second Worksheet" in result_multi, "Second worksheet content should be present." def test_process_notebook_code_only(write_notebook): @@ -204,3 +210,58 @@ def test_process_notebook_invalid_cell_type(write_notebook): with pytest.raises(ValueError, match="Unknown cell type: unknown"): process_notebook(nb_path) + + +def test_process_notebook_with_output(write_notebook): + """ + Test a notebook with code cells and outputs. + + The outputs should be included as comments if `include_output=True`. + """ + notebook_content = { + "cells": [ + { + "cell_type": "code", + "source": [ + "import matplotlib.pyplot as plt\n", + "print('my_data')\n", + "my_data = [1, 2, 3, 4, 5]\n", + "plt.plot(my_data)\n", + "my_data", + ], + "outputs": [ + {"output_type": "stream", "text": ["my_data"]}, + {"output_type": "execute_result", "data": {"text/plain": ["[1, 2, 3, 4, 5]"]}}, + {"output_type": "display_data", "data": {"text/plain": ["
"]}}, + ], + } + ] + } + + nb_path = write_notebook("with_output.ipynb", notebook_content) + with_output = process_notebook(nb_path, include_output=True) + without_output = process_notebook(nb_path, include_output=False) + + expected_source = "\n".join( + [ + "# Jupyter notebook converted to Python script.\n", + "import matplotlib.pyplot as plt", + "print('my_data')", + "my_data = [1, 2, 3, 4, 5]", + "plt.plot(my_data)", + "my_data\n", + ] + ) + expected_output = "\n".join( + [ + "# Output:", + "# my_data", + "# [1, 2, 3, 4, 5]", + "#
\n", + ] + ) + + expected_combined = expected_source + expected_output + + assert with_output == expected_combined, "Expected source code and output as comments." + assert without_output == expected_source, "Expected source code only."