Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions app/tests/symbolic_evaluation_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,10 +775,6 @@ def test_warning_inappropriate_symbol(self):
'(0,002*6800*v)/1,2',
'(0.002*6800*v)/1.2'
),
(
'-∞',
'-inf'
),
(
'x.y',
'x*y'
Expand Down Expand Up @@ -1865,15 +1861,43 @@ def test_sum_in_answer(self, response, answer, value):
result = evaluation_function(response, answer, params)
assert result["is_correct"] is value

def test_exclamation_mark_for_factorial(self):
response = "3!"
answer = "factorial(3)"
@pytest.mark.parametrize(
"response, answer, value",
[
("3!", "factorial(3)", True),
("(n+1)!", "factorial(n+1)", True),
("n!", "factorial(n)", True),
("a!=b", "factorial(3)", False),
("2*n!", "2*factorial(n)", True),
("3!", "3!", True)
]
)
def test_exclamation_mark_for_factorial(self, response, answer, value):
params = {
"strict_syntax": False,
"elementary_functions": True,
}
result = evaluation_function(response, answer, params)
assert result["is_correct"] is True
assert result["is_correct"] is value

@pytest.mark.parametrize(
"response, answer, value",
[
("3!!", "factorial2(3)", True),
("(n+1)!!", "factorial2(n+1)", True),
("n!!", "factorial2(n)", True),
("a!=b", "factorial2(3)", False),
("2*n!!", "2*factorial2(n)", True),
("3!!", "3!!", True),
]
)
def test_double_exclamation_mark_for_factorial(self, response, answer, value):
params = {
"strict_syntax": False,
"elementary_functions": True,
}
result = evaluation_function(response, answer, params)
assert result["is_correct"] is value

def test_alternatives_to_input_symbols_takes_priority_over_elementary_function_alternatives(self):
answer = "Ef*exp(x)"
Expand Down
161 changes: 160 additions & 1 deletion app/utility/expression_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
from sympy.printing.latex import LatexPrinter
from sympy import Basic, Symbol, Equality, Function

from sympy import factorial as _sympy_factorial
from sympy.functions.combinatorial.factorials import factorial2 as _sympy_factorial2


import re
from typing import Dict, List, TypedDict

Expand Down Expand Up @@ -661,6 +665,149 @@ def preprocess_expression(name, expr, parameters):
success = False
return success, expr, abs_feedback

def convert_double_bang_factorial(s: str) -> str:
"""
Convert double postfix factorial (e.g., n!!, (x+1)!!, 3!!) to function form: factorial2(n), etc.
Safeguards:
- Does NOT treat '!=' specially (since we target '!!').
- Requires two consecutive '!' characters (no whitespace in between).
- Handles balanced parenthesis operands (e.g., '(... )!!').
- Handles identifiers and numeric literals.
"""
n = len(s)
i = 0
last = 0
out = []

while i < n:
ch = s[i]
if ch == '!' and (i + 1) < n and s[i + 1] == '!':
# Look left to find the operand (skip whitespace)
j = i - 1
while j >= 0 and s[j].isspace():
j -= 1
if j < 0:
# Nothing to the left; keep as-is
i += 1
continue

# Case 1: operand ends with ')': parenthesized group
if s[j] == ')':
depth = 1
k = j - 1
while k >= 0 and depth > 0:
if s[k] == ')':
depth += 1
elif s[k] == '(':
depth -= 1
k -= 1
if depth == 0:
L = k + 1 # index of '('
R = j # index of ')'
out.append(s[last:L])
out.append('factorial2(')
out.append(s[L:R+1])
out.append(')')
last = i + 2 # consume both '!'
i += 2
continue
else:
# Unbalanced parentheses; leave as-is
i += 1
continue

# Case 2: operand is an identifier/number ending at j
k = j
while k >= 0 and (s[k].isalnum() or s[k] in ('_', '.')):
k -= 1
L = k + 1
if L <= j:
out.append(s[last:L])
out.append('factorial2(')
out.append(s[L:j+1])
out.append(')')
last = i + 2
i += 2
continue
# If we get here, no valid operand token; fall through and keep scanning.

i += 1

out.append(s[last:])
return ''.join(out)

def convert_bang_factorial(s: str) -> str:
"""
Convert single postfix factorial (e.g., n!, (x+1)!, 3!) to function form: factorial(n), etc.
Safeguards:
- Does NOT convert '!='.
- Does NOT convert '!!' (handled by convert_double_bang_factorial).
"""
n = len(s)
i = 0
last = 0
out = []

while i < n:
ch = s[i]
if ch == '!':
# Skip '!=' and '!!' (the latter handled in a separate pass)
nxt = s[i+1] if i + 1 < n else ''
if nxt in ('=', '!'):
i += 1
continue

# Move left to find the operand (skip whitespace)
j = i - 1
while j >= 0 and s[j].isspace():
j -= 1
if j < 0:
i += 1
continue

# Parenthesized operand
if s[j] == ')':
depth = 1
k = j - 1
while k >= 0 and depth > 0:
if s[k] == ')':
depth += 1
elif s[k] == '(':
depth -= 1
k -= 1
if depth == 0:
L = k + 1
R = j
out.append(s[last:L])
out.append('factorial(')
out.append(s[L:R+1])
out.append(')')
last = i + 1
i += 1
continue
else:
i += 1
continue

# Identifier/number operand
k = j
while k >= 0 and (s[k].isalnum() or s[k] in ('_', '.')):
k -= 1
L = k + 1
if L <= j:
out.append(s[last:L])
out.append('factorial(')
out.append(s[L:j+1])
out.append(')')
last = i + 1
i += 1
continue

i += 1

out.append(s[last:])
return ''.join(out)


def parse_expression(expr_string, parsing_params):
'''
Expand All @@ -681,13 +828,25 @@ def parse_expression(expr_string, parsing_params):
separate_unsplittable_symbols = [(x, " "+x) for x in unsplittable_symbols]
substitutions = separate_unsplittable_symbols

symbol_dict = dict(symbol_dict)
symbol_dict.setdefault("factorial", _sympy_factorial)
symbol_dict.setdefault("factorial2", _sympy_factorial2)

if 'factorial' not in unsplittable_symbols or 'factorial2' not in unsplittable_symbols:
unsplittable_symbols = tuple(
list(unsplittable_symbols)
+ [s for s in ('factorial', 'factorial2') if s not in unsplittable_symbols]
)

parsed_expr_set = set()
for expr in expr_set:
expr = preprocess_according_to_chosen_convention(expr, parsing_params)
substitutions = list(set(substitutions))
substitutions.sort(key=substitutions_sort_key)
if parsing_params["elementary_functions"] is True:
substitutions += protect_elementary_functions_substitutions(expr)
expr = convert_double_bang_factorial(expr)
expr = convert_bang_factorial(expr)
substitutions = list(set(substitutions))
substitutions.sort(key=substitutions_sort_key)
expr = substitute(expr, substitutions)
Expand Down Expand Up @@ -717,4 +876,4 @@ def parse_expression(expr_string, parsing_params):
if len(expr_set) == 1:
return parsed_expr_set.pop()
else:
return parsed_expr_set
return parsed_expr_set