Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor state latex drawer #8874

Merged
merged 14 commits into from
Jan 17, 2023
140 changes: 44 additions & 96 deletions qiskit/visualization/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,117 +13,63 @@
Tools to create LaTeX arrays.
"""

import math
from fractions import Fraction
import numpy as np

from qiskit.exceptions import MissingOptionalLibraryError


def _num_to_latex(num, precision=5):
"""Takes a complex number as input and returns a latex representation
def _num_to_latex(raw_value, decimals=15, first_term=True, coefficient=False):
"""Convert a complex number to latex code suitable for a ket expression

Args:
num (numerical): The number to be converted to latex.
precision (int): If the real or imaginary parts of num are not close
to an integer, the number of decimal places to round to

raw_value (complex): Value to convert.
decimals (int): Number of decimal places to round to (default 15).
coefficient (bool): Whether the number is to be used as a coefficient
of a ket.
first_term (bool): If a coefficient, whether this number is the first
coefficient in the expression.
Returns:
str: Latex representation of num
str: latex code
"""
# Result is combination of maximum 4 strings in the form:
# {common_facstring} ( {realstring} {operation} {imagstring}i )
# common_facstring: A common factor between the real and imaginary part
# realstring: The real part (inc. a negative sign if applicable)
# operation: The operation between the real and imaginary parts ('+' or '-')
# imagstring: Absolute value of the imaginary parts (i.e. not inc. any negative sign).
# This function computes each of these strings and combines appropriately.

r = np.real(num)
i = np.imag(num)
common_factor = None

# try to factor out common terms in imaginary numbers
if np.isclose(abs(r), abs(i)) and not np.isclose(r, 0) and not np.isclose(i, 0):
common_factor = abs(r)
r = r / common_factor
i = i / common_factor

common_terms = {
1 / math.sqrt(2): "\\tfrac{1}{\\sqrt{2}}",
1 / math.sqrt(3): "\\tfrac{1}{\\sqrt{3}}",
math.sqrt(2 / 3): "\\sqrt{\\tfrac{2}{3}}",
math.sqrt(3 / 4): "\\sqrt{\\tfrac{3}{4}}",
1 / math.sqrt(8): "\\tfrac{1}{\\sqrt{8}}",
}

def _proc_value(val):
# This function converts a real value to a latex string
# First, see if val is close to an integer:
val_mod = np.mod(val, 1)
if np.isclose(val_mod, 0) or np.isclose(val_mod, 1):
# If so, return that integer
return str(int(np.round(val)))
# Otherwise, see if it matches one of the common terms
for term, latex_str in common_terms.items():
if np.isclose(abs(val), term):
if val > 0:
return latex_str
else:
return "-" + latex_str
# try to factorise val nicely
frac = Fraction(val).limit_denominator()
num, denom = frac.numerator, frac.denominator
if abs(num) + abs(denom) < 20:
# If fraction is 'nice' return
if val > 0:
return f"\\tfrac{{{abs(num)}}}{{{abs(denom)}}}"
else:
return f"-\\tfrac{{{abs(num)}}}{{{abs(denom)}}}"
else:
# Failing everything else, return val as a decimal
return "{:.{}f}".format(val, precision).rstrip("0")
import sympy # runtime import

# Get string (or None) for common factor between real and imag
if common_factor is None:
common_facstring = None
else:
common_facstring = _proc_value(common_factor)
raw_value = np.around(raw_value, decimals=decimals)
value = sympy.nsimplify(raw_value, rational=False)

# Get string for real part
realstring = _proc_value(r)
if isinstance(value, sympy.core.numbers.Rational) and value.denominator > 50:
# Avoid showing ugly fractions (e.g. 50498971964399/62500000000000)
value = value.evalf() # Display as float

# Get string for both imaginary part and operation between real and imaginary parts
if i > 0:
operation = "+"
imagstring = _proc_value(i)
else:
operation = "-"
imagstring = _proc_value(-i)
if imagstring == "1":
imagstring = "" # Don't want to return '1i', just 'i'

# Now combine the strings appropriately:
if imagstring == "0":
return realstring # realstring already contains the negative sign (if needed)
if realstring == "0":
# imagstring needs the negative sign adding
if operation == "-":
return f"-{imagstring}i"
else:
return f"{imagstring}i"
if common_facstring is not None:
return f"{common_facstring}({realstring} {operation} {imagstring}i)"
else:
return f"{realstring} {operation} {imagstring}i"
if isinstance(value, sympy.core.numbers.Float):
value = round(value, decimals)

element = sympy.latex(value, full_prec=False)

if not coefficient:
return element

def _matrix_to_latex(matrix, precision=5, prefix="", max_size=(8, 8)):
if isinstance(value, sympy.core.Add):
# element has two terms
element = f"({element})"

if element == "1":
element = ""

if element == "-1":
element = "-"

if not first_term and not element.startswith("-"):
element = f"+{element}"

return element


def _matrix_to_latex(matrix, decimals=10, prefix="", max_size=(8, 8)):
"""Latex representation of a complex numpy array (with maximum dimension 2)

Args:
matrix (ndarray): The matrix to be converted to latex, must have dimension 2.
precision (int): For numbers not close to integers, the number of decimal places
decimals (int): For numbers not close to integers, the number of decimal places
to round to.
prefix (str): Latex string to be prepended to the latex, intended for labels.
max_size (list(```int```)): Indexable containing two integers: Maximum width and maximum
Expand All @@ -149,7 +95,7 @@ def _elements_to_latex(elements):
# string from it; Each element separated by `&`
el_string = ""
for el in elements:
num_string = _num_to_latex(el, precision=precision)
num_string = _num_to_latex(el, decimals=decimals)
el_string += num_string + " & "
el_string = el_string[:-2] # remove trailing ampersands
return el_string
Expand Down Expand Up @@ -197,7 +143,7 @@ def _rows_to_latex(rows, max_width):
return out_string


def array_to_latex(array, precision=5, prefix="", source=False, max_size=8):
def array_to_latex(array, precision=10, prefix="", source=False, max_size=8):
"""Latex representation of a complex numpy array (with dimension 1 or 2)

Args:
Expand Down Expand Up @@ -239,10 +185,12 @@ def array_to_latex(array, precision=5, prefix="", source=False, max_size=8):
if array.ndim <= 2:
if isinstance(max_size, int):
max_size = (max_size, max_size)
outstr = _matrix_to_latex(array, precision=precision, prefix=prefix, max_size=max_size)
outstr = _matrix_to_latex(array, decimals=precision, prefix=prefix, max_size=max_size)
frankharkins marked this conversation as resolved.
Show resolved Hide resolved
else:
raise ValueError("array_to_latex can only convert numpy ndarrays of dimension 1 or 2")

outstr = _matrix_to_latex(array, decimals=precision, prefix=prefix, max_size=max_size)

if source is False:
try:
from IPython.display import Latex
Expand Down
94 changes: 40 additions & 54 deletions qiskit/visualization/state_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Optional, List, Union
from functools import reduce
import colorsys
import warnings
import numpy as np
from qiskit import user_config
from qiskit.quantum_info.states.statevector import Statevector
Expand All @@ -29,7 +30,7 @@
from qiskit.utils import optionals as _optionals
from qiskit.circuit.tools.pi_check import pi_check

from .array import array_to_latex
from .array import _num_to_latex, array_to_latex
from .utils import matplotlib_close_if_inline
from .exceptions import VisualizationError

Expand Down Expand Up @@ -1233,69 +1234,34 @@ def num_to_latex_ket(raw_value: complex, first_term: bool, decimals: int = 10) -
Returns:
String with latex code or None if no term is required
"""
import sympy # runtime import

if raw_value == 0:
value = 0
real_value = 0
imag_value = 0
else:
raw_value = np.around(raw_value, decimals=decimals)
value = sympy.nsimplify(raw_value, constants=(sympy.pi,), rational=False)
real_value = float(sympy.re(value))
imag_value = float(sympy.im(value))

element = ""
if np.abs(value) > 0:
latex_element = sympy.latex(value, full_prec=False)
two_term = real_value != 0 and imag_value != 0
if isinstance(value, sympy.core.Add):
# can happen for expressions like 1 + sqrt(2)
two_term = True
if two_term:
if first_term:
element = f"({latex_element})"
else:
element = f"+ ({latex_element})"
else:
if first_term:
if np.isreal(complex(value)) and value > 0:
element = latex_element
else:
element = latex_element
if element == "1":
element = ""
elif element == "-1":
element = "-"
else:

if imag_value == 0 and real_value > 0:
element = "+" + latex_element
elif real_value == 0 and imag_value > 0:
element = "+" + latex_element
else:
element = latex_element
if element == "+1":
element = "+"
elif element == "-1":
element = "-"

return element
else:
warnings.warn(
"~qiskit.visualization.state_visualization.num_to_latex_ket "
"is deprecated as of 0.23.0 and will be removed no earlier than 3 months "
"after the release.",
category=DeprecationWarning,
stacklevel=2,
)
if np.around(np.abs(raw_value), decimals=decimals) == 0:
return None
return _num_to_latex(raw_value, first_term=first_term, decimals=decimals, coefficient=True)


def numbers_to_latex_terms(numbers: List[complex], decimals: int = 10) -> List[str]:
"""Convert a list of numbers to latex formatted terms

The first non-zero term is treated differently. For this term a leading + is suppressed.

Args:
numbers: List of numbers to format
decimals: Number of decimal places to round to (default: 10).
Returns:
List of formatted terms
"""
warnings.warn(
"~qiskit.visualization.state_visualization.num_to_latex_terms "
ikkoham marked this conversation as resolved.
Show resolved Hide resolved
"is deprecated as of 0.23.0 and will be removed no earlier than 3 months "
"after the release.",
category=DeprecationWarning,
stacklevel=2,
)
first_term = True
terms = []
for number in numbers:
Expand All @@ -1306,6 +1272,26 @@ def numbers_to_latex_terms(numbers: List[complex], decimals: int = 10) -> List[s
return terms


def _numbers_to_latex_terms(numbers: List[complex], decimals: int = 10) -> List[str]:
"""Convert a list of numbers to latex formatted terms

The first non-zero term is treated differently. For this term a leading + is suppressed.

Args:
numbers: List of numbers to format
decimals: Number of decimal places to round to (default: 10).
Returns:
List of formatted terms
"""
first_term = True
terms = []
for number in numbers:
term = _num_to_latex(number, decimals=decimals, first_term=first_term, coefficient=True)
terms.append(term)
first_term = False
return terms


def _state_to_latex_ket(data: List[complex], max_size: int = 12, prefix: str = "") -> str:
"""Convert state vector to latex representation

Expand All @@ -1329,10 +1315,10 @@ def ket_name(i):
nonzero_indices = (
nonzero_indices[: max_size // 2] + [0] + nonzero_indices[-max_size // 2 + 1 :]
)
latex_terms = numbers_to_latex_terms(data[nonzero_indices], max_size)
latex_terms = _numbers_to_latex_terms(data[nonzero_indices], max_size)
nonzero_indices[max_size // 2] = None
else:
latex_terms = numbers_to_latex_terms(data[nonzero_indices], max_size)
latex_terms = _numbers_to_latex_terms(data[nonzero_indices], max_size)

latex_str = ""
for idx, ket_idx in enumerate(nonzero_indices):
Expand Down
13 changes: 13 additions & 0 deletions releasenotes/notes/latex-refactor-0745471ddecac605.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
---
deprecations:
- |
Deprecated two functions
(```~qiskit.visualization.state_visualization.num_to_latex_ket``` and
```~qiskit.visualization.state_visualization.num_to_latex_terms```).
other:
frankharkins marked this conversation as resolved.
Show resolved Hide resolved
- |
The latex array drawer (e.g. ```array_to_latex```,
```Statevector.draw('latex')```) now uses the same sympy function as the
ket-convention drawer. This means it may render some numbers differently
(e.g. may identify new factors, or rationalize denominators where it did
not previously). The default ```precision``` has been changed from 5 to 10.
9 changes: 5 additions & 4 deletions test/python/quantum_info/states/test_statevector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,13 +1182,14 @@ def test_number_to_latex_terms(self):
([-1, 1j], ["-", "+i"]),
([1e-16 + 1j], ["i"]),
([-1 + 1e-16 * 1j], ["-"]),
([-1, -1 - 1j], ["-", "+ (-1 - i)"]),
([-1, -1 - 1j], ["-", "+(-1 - i)"]),
([np.sqrt(2) / 2, np.sqrt(2) / 2], ["\\frac{\\sqrt{2}}{2}", "+\\frac{\\sqrt{2}}{2}"]),
([1 + np.sqrt(2)], ["(1 + \\sqrt{2})"]),
]
for numbers, latex_terms in cases:
terms = numbers_to_latex_terms(numbers, 15)
self.assertListEqual(terms, latex_terms)
with self.assertWarns(DeprecationWarning):
for numbers, latex_terms in cases:
terms = numbers_to_latex_terms(numbers, 15)
self.assertListEqual(terms, latex_terms)

def test_statevector_draw_latex_regression(self):
"""Test numerical rounding errors are not printed"""
Expand Down
7 changes: 4 additions & 3 deletions test/python/visualization/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,10 @@ def test_array_to_latex(self):
]
matrix = np.array(matrix)
exp_str = (
"\\begin{bmatrix}\\tfrac{1}{\\sqrt{2}}&\\tfrac{1}{16}&\\tfrac{1}{\\sqrt{8}}+3i&"
"\\tfrac{1}{2}(-1+i)\\\\\\tfrac{1}{3}(1+i)&\\tfrac{1}{\\sqrt{2}}i&34.321&"
"-\\tfrac{9}{2}\\\\\\end{bmatrix}"
"\\begin{bmatrix}\\frac{\\sqrt{2}}{2}&\\frac{1}{16}&"
"\\frac{\\sqrt{2}}{4}+3i&-\\frac{1}{2}+\\frac{i}{2}\\\\"
"\\frac{1}{3}+\\frac{i}{3}&\\frac{\\sqrt{2}i}{2}&34.321&-"
"\\frac{9}{2}\\\\\\end{bmatrix}"
)
result = array_to_latex(matrix, source=True).replace(" ", "").replace("\n", "")
self.assertEqual(exp_str, result)
Expand Down