Skip to content

Commit

Permalink
Add unit test for max_history magic option
Browse files Browse the repository at this point in the history
  • Loading branch information
akaihola committed Oct 1, 2024
1 parent b7a2842 commit 5426c18
Showing 1 changed file with 63 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from unittest.mock import patch
import os
from unittest.mock import Mock, patch

import pytest
from IPython import InteractiveShell
from IPython.core.display import Markdown
from jupyter_ai_magics.magics import AiMagics
from langchain_core.messages import AIMessage, HumanMessage
from pytest import fixture
from traitlets.config.loader import Config

Expand Down Expand Up @@ -48,3 +52,61 @@ def test_default_model_error_line(ip):
assert mock_run.called
cell_args = mock_run.call_args.args[0]
assert cell_args.model_id == "my-favourite-llm"


PROMPT = HumanMessage(
content=("Write code for me please\n\nProduce output in markdown format only.")
)
AI1 = AIMessage("ai1")
H1 = HumanMessage("h1")
AI2 = AIMessage("ai2")
H2 = HumanMessage("h2")
AI3 = AIMessage("ai3")


@pytest.mark.parametrize(
["transcript", "max_history", "expected_context"],
[
([], 3, [PROMPT]),
([AI1], 0, [PROMPT]),
([AI1], 1, [AI1, PROMPT]),
([H1, AI1], 0, [PROMPT]),
([H1, AI1], 1, [H1, AI1, PROMPT]),
([AI1, H1, AI2], 0, [PROMPT]),
([AI1, H1, AI2], 1, [H1, AI2, PROMPT]),
([AI1, H1, AI2], 2, [AI1, H1, AI2, PROMPT]),
([H1, AI1, H2, AI2], 0, [PROMPT]),
([H1, AI1, H2, AI2], 1, [H2, AI2, PROMPT]),
([H1, AI1, H2, AI2], 2, [H1, AI1, H2, AI2, PROMPT]),
([AI1, H1, AI2, H2, AI3], 0, [PROMPT]),
([AI1, H1, AI2, H2, AI3], 1, [H2, AI3, PROMPT]),
([AI1, H1, AI2, H2, AI3], 2, [H1, AI2, H2, AI3, PROMPT]),
([AI1, H1, AI2, H2, AI3], 3, [AI1, H1, AI2, H2, AI3, PROMPT]),
],
)
def test_max_history(ip, transcript, max_history, expected_context):
ip.extension_manager.load_extension("jupyter_ai_magics")
ai_magics = ip.magics_manager.registry["AiMagics"]
ai_magics.transcript = transcript
ai_magics.max_history = max_history
provider = ai_magics._get_provider("openrouter")
with patch.object(provider, "generate") as generate, patch.dict(
os.environ, OPENROUTER_API_KEY="123"
):
generate.return_value.generations = [[Mock(text="Leet code")]]
result = ip.run_cell_magic(
"ai",
"openrouter:anthropic/claude-3.5-sonnet",
cell="Write code for me please",
)
provider.generate.assert_called_once_with([expected_context])
assert isinstance(result, Markdown)
assert result.data == "Leet code"
assert result.filename is None
assert result.metadata == {
"jupyter_ai": {
"model_id": "anthropic/claude-3.5-sonnet",
"provider_id": "openrouter",
}
}
assert result.url is None

0 comments on commit 5426c18

Please sign in to comment.